Players throw coins; any player who throws tails gets eliminated. Players want to win, i.e. to avoid being eliminated. If all players throw tails, no one is eliminated. Some players are cheaters, and their coins land heads more than 50% of the time.

There is also a detective who can see the history of all throws of each player. His task is to kick out all cheaters while keeping fair players in. The cheaters' task is to find a coin that will allow them to win easily without getting caught. When a player gives up or is kicked out, it is revealed whether they were a cheater and, if so, what coin they were using.

Each round starts with n players, where n is a hyperparameter defined below. A player who is eliminated from a round comes back in the next one. Players have a chance of giving up, which is also a hyperparameter defined below. When a player leaves or gets kicked out by the detective, a new player joins, so the number of players is always n. The chance of a player being a cheater is also a hyperparameter defined below.

In [1]:
# Imports
import random
import torch

In [2]:
# Hyperparameters (read as module-level globals by Game)
NUMBER_OF_PLAYERS: int = 10 # n: roster size that Game.create_players tops the player list up to
GIVE_UP_CHANCE: float = 0.1 # per-player probability of giving up, drawn once per player after each game
CHEATER_CHANCE: float = 0.4 # probability that a newly created player is a Cheater

In [3]:
# Helper functions
def pad(arr: list[list], filler, min_len: int=-1) -> list[list]:
    """
    Left-pad every sub-list of arr with filler so that all rows share one length

    The common length is the longest row, or min_len when that is larger.
    An empty arr yields an empty list. Rows are copied, never modified in place.
    """
    if not arr:
        return []

    target_len = max(min_len, *(len(row) for row in arr))

    padded = []
    for row in arr:
        missing = target_len - len(row)
        # Padding goes in front, so the most recent entries stay right-aligned
        padded.append([filler] * missing + row)
    return padded


def softmax(arr, *, round_dec):
    """
    Return each value of arr as a rounded share of the total

    Despite the name, no exponentials are involved: every element is simply
    divided by sum(arr) and rounded to round_dec decimal places.

    When the total is 0 (e.g. counters that are all still zero) there is
    nothing to divide by, so every share is reported as 0.0 instead of
    raising ZeroDivisionError.
    """
    # Materialise once so one-shot iterables survive both the sum and the division pass
    values = list(arr)
    total = sum(values)
    if total == 0:
        return [0.0 for _ in values]
    return [round(x / total, round_dec) for x in values]

In [4]:
class Player:
    def __init__(self):
        """
        A fair player: unbiased coin, empty throw log, active in the game
        """
        self.chance: float = 0.5
        self.in_game = True
        self.history: list[int] = []


    def throw(self) -> int:
        """
        Flip the coin once, log the outcome and return it (1 = heads, 0 = tails)
        """
        outcome = 1 if random.random() < self.chance else 0
        self.history.append(outcome)
        return outcome


    def after_play(self, *args, **kwargs):
        """
        End-of-game hook; a fair player ignores whatever it is given
        """
        return None

In [5]:
class Cheater(Player):
    def __init__(self):
        """
        A cheating player whose coin is biased towards heads
        """
        super().__init__()
        # Heads probability drawn once per cheater, somewhere in [0.5, 1.0)
        bias = random.random() / 2
        self.chance = 0.5 + bias
        self.suspect_history: list[float] = []
        self.chance_history: list[float] = []


    def after_play(self, suspect_score):
        """
        Log the given suspect score together with the coin bias currently in use
        """
        self.suspect_history.append(suspect_score)
        self.chance_history.append(self.chance)


    def update_chance(self):
        """
        Placeholder: does nothing yet
        """
        return None

In [6]:
def get_player_type(player: Player | Cheater):
    """
    Return the class name of player as a string, e.g. "Player" or "Cheater"
    """
    player_cls = type(player)
    return player_cls.__name__


def encode_player_type(player: Player | Cheater):
    """
    Encode a player as a numeric label: 1 for a Cheater, 0 for a fair Player

    Anything that is neither falls through and gets its class name back instead.
    """
    # Cheater is tested first because it subclasses Player
    if isinstance(player, Cheater):
        return 1
    if isinstance(player, Player):
        return 0
    return type(player).__name__

In [7]:
class Detective(torch.nn.Module):
    def __init__(self):
        """
        A small 1-D convolutional classifier that scores throw histories

        Also owns the loss function and optimizer used to train it.
        """
        super().__init__()
        nn = torch.nn

        # Convolve over the throw sequence, then average-pool to one value per channel
        feature_layers = [
            nn.Conv1d(in_channels=1, out_channels=4, kernel_size=4),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(1),
        ]
        # Map the 4 pooled features to a single sigmoid score
        head_layers = [
            nn.Flatten(),
            nn.Linear(in_features=4, out_features=8),
            nn.ReLU(),
            nn.Linear(in_features=8, out_features=1),
            nn.Sigmoid(),
        ]
        self.layers = nn.Sequential(*feature_layers, *head_layers)

        self.loss_fn = nn.BCELoss()
        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.01)


    def forward(self, x):
        """
        Insert a channel axis at dim 1, run the network, and drop the trailing size-1 axis of the output
        """
        with_channel = x.unsqueeze(1)
        scores = self.layers(with_channel)
        return scores.squeeze(1)


# Single shared detective instance; Game.after_play reads it as a module-level global
detective = Detective()

In [8]:
class Game:
    def __init__(self):
        """
        A game class

        Holds the current roster and the running statistics. Every entry of
        stats is a two-element list indexed by encode_player_type:
        index 0 is for fair players, index 1 is for cheaters.
        """
        self.players: list[Player | Cheater] = []

        self.stats = {
            "num_players": [0, 0],
            "num_eliminated": [0, 0],
            "playtime": [0, 0],
        }

        self.create_players()


    def create_players(self):
        """
        Top the roster up to NUMBER_OF_PLAYERS

        Each newcomer is a Cheater with probability CHEATER_CHANCE, otherwise a
        fair Player; the matching num_players counter is incremented.
        """
        for _ in range(NUMBER_OF_PLAYERS - len(self.players)):
            if random.random() < CHEATER_CHANCE:
                self.players.append(Cheater())
                self.stats["num_players"][1] += 1
            else:
                self.players.append(Player())
                self.stats["num_players"][0] += 1


    def play_round(self):
        """
        Let every player still in the game throw once

        Returns the roster indexes of the players who threw tails.
        """
        return [
            i for i, player in enumerate(self.players)
            if player.in_game and player.throw() == 0
        ]


    def play(self, *, silent=False):
        """
        Play rounds until at most one player is left in the game, then run after_play

        silent is forwarded to after_play and suppresses its prints.
        """
        while True:
            active = sum(1 for p in self.players if p.in_game)
            if active <= 1:
                break
            losers = self.play_round()
            # All-tails rule: compare against the players still in the game, not
            # the whole roster, otherwise the last survivors can all be knocked
            # out together and the game ends without a winner.
            if len(losers) == active:
                continue
            for i in losers:
                self.players[i].in_game = False
        self.after_play(silent=silent)


    def after_play(self, *, silent=False):
        """
        Post-game bookkeeping: give-ups, detective kicks, detective training, roster refill

        Players who give up or get kicked reveal their true type, so their
        histories become the labelled batch for one optimizer step of the
        detective. Afterwards everyone left is put back in the game and the
        roster is topped up with new players.
        """
        for player in self.players:
            self.stats["playtime"][encode_player_type(player)] += 1

        x_train, y_train = [], []

        # One give-up draw per player, in roster order; removal happens
        # afterwards from the back so the remaining indexes stay valid.
        gave_up_indexes = [
            i for i in range(len(self.players))
            if random.random() < GIVE_UP_CHANCE
        ]
        for i in reversed(gave_up_indexes):
            elimed_player = self.players.pop(i)
            x_train.append(elimed_player.history)
            y_train.append(encode_player_type(elimed_player))
            if not silent:
                print(f"A {get_player_type(elimed_player)} gave up!")

        # If everybody gave up there is nobody left to score, and an empty
        # batch cannot be fed through the detective's convolution.
        if self.players:
            detective.eval()
            # Histories are left-padded with -1 to a common length of at least 4
            histories = torch.tensor(
                pad([player.history for player in self.players], -1, 4),
                dtype=torch.float32,
            )
            # Pure inference: no gradients are needed for the kick decision
            with torch.no_grad():
                scores = detective(histories)

            eliminated_indexes = torch.where(scores > 0.5)[0].tolist()
            for i in reversed(eliminated_indexes):
                elimed_player = self.players.pop(i)
                x_train.append(elimed_player.history)
                y_train.append(encode_player_type(elimed_player))
                if not silent:
                    print(f"A {get_player_type(elimed_player)} got eliminated!")
                self.stats["num_eliminated"][encode_player_type(elimed_player)] += 1

        if x_train:
            detective.train()
            detective.optimizer.zero_grad()

            inputs = torch.tensor(pad(x_train, -1, 4), dtype=torch.float32)
            targets = torch.tensor(y_train, dtype=torch.float32)
            outputs = detective(inputs)

            loss = detective.loss_fn(outputs, targets)
            loss.backward()
            detective.optimizer.step()

        for player in self.players:
            player.in_game = True
        self.create_players()


    def show_stats(self):
        """
        Print a table of the collected statistics

        For every row the raw totals are shown next to each type's share of the
        row sum. The raw playtime counter is replaced by "average playtime",
        i.e. playtime divided by the number of players of that type.
        """
        # Rows are looked up by key, so adding or reordering stats cannot
        # silently pair the wrong counters.
        rows = {key: counts for key, counts in self.stats.items() if key != "playtime"}
        rows["average playtime"] = [
            # No players of a type yet means no playtime either; report 0.0
            round(played / joined, 3) if joined else 0.0
            for joined, played in zip(self.stats["num_players"], self.stats["playtime"])
        ]
        # A row that sums to 0 has no meaningful shares; show zeros rather than divide by zero
        shares = {
            key: softmax(counts, round_dec=3) if sum(counts) else [0.0] * len(counts)
            for key, counts in rows.items()
        }

        width = max(map(len, rows)) + 2
        sep_row = "".join(["+", "-"*(5*width+4), "+"])

        print(sep_row)

        print("|", " "*width, "|", "Total".center(width*2+1), "|", "Softmax".center(width*2+1), "|", sep="")

        print("|", " "*width, "+", "-"*(4*width+3), "+", sep="")

        print("|", " "*width, "|", "Player".center(width), "|", "Cheater".center(width),
              "|", "Player".center(width), "|", "Cheater".center(width), "|", sep="")

        print(sep_row)

        for key, counts in rows.items():
            share = shares[key]
            print("|", key.center(width),
                  "|", str(counts[0]).center(width), "|", str(counts[1]).center(width),
                  "|", str(share[0]).center(width), "|", str(share[1]).center(width), "|", sep="")

        print(sep_row)


# Game instance used by the simulation and statistics cells below
game = Game()

In [9]:
# Run 10,000 full games; silent=True suppresses the per-player give-up/elimination prints
num_plays = 1e4
for _ in range(int(num_plays)):  # 1e4 is a float literal, so it is cast to int for range
    game.play(silent=True)

In [10]:
# Print the totals / shares table accumulated over all games played so far
game.show_stats()

+----------------------------------------------------------------------------------------------+
|                  |                Total                |               Softmax               |
|                  +---------------------------------------------------------------------------+
|                  |      Player      |     Cheater      |      Player      |     Cheater      |
+----------------------------------------------------------------------------------------------+
|   num_players    |       6953       |       4619       |      0.601       |      0.399       |
|  num_eliminated  |       290        |       1362       |      0.176       |      0.824       |
| average playtime |      9.537       |      7.293       |      0.567       |      0.433       |
+----------------------------------------------------------------------------------------------+
