Players throw coins; any player who throws tails is eliminated. Players want to win, i.e. to avoid being eliminated. If all remaining players throw tails, no one is eliminated. Some players are cheaters whose coins land heads more than 50% of the time.

There are also detectives that can see the history of every throw of each player. There are a few types of detectives. Their task is to kick out all cheaters while keeping fair players in. The cheaters' task is to find a coin that lets them win easily without getting caught. When a player gives up or is kicked out, it is revealed whether they were a cheater and, if so, which coin they were using.

Each game (a sequence of throwing rounds) starts with n players, where n is a hyperparameter defined below. When a player is eliminated from a game, they come back in the next one. Players have a chance of giving up, which is also a hyperparameter defined below. When a player leaves or is kicked out by the detectives, a new player joins, so the number of players is always n. The chance of a new player being a cheater is also a hyperparameter defined below.

In [1]:
# Imports
import random
import torch

# Run the detectives on the GPU when one is available, otherwise on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [2]:
# Hyperparameters
# n: how many players are seated at the table at any time
NUMBER_OF_PLAYERS: int = 16
# probability that each player gives up after a game
GIVE_UP_CHANCE: float = 0.1
# probability that a newly joining player is a cheater
CHEATER_CHANCE: float = 0.4

In [3]:
# Helper functions
def pad(arr: list[list], filler, min_len: int=-1) -> list[list]:
    """
    Left-pad every sub-list of ``arr`` with ``filler`` to a common length.

    The common length is the length of the longest sub-list, or ``min_len``
    when that is larger. New lists are returned; the inputs are not modified.
    An empty ``arr`` yields ``[]`` regardless of ``min_len``.
    """
    if not arr:
        return []
    target = max(min_len, max(len(sub) for sub in arr))
    padded = []
    for sub in arr:
        missing = target - len(sub)
        padded.append([filler] * missing + sub)
    return padded


def softmax(arr, *, round_dec=0):
    """
    Scale ``arr`` so its elements sum to 1, rounding each to ``round_dec`` places.

    NOTE: despite the name this is plain sum-normalisation (x / sum(arr)),
    not an exponential softmax. When the elements sum to zero, ``arr`` itself
    is returned unchanged (the same object, not a copy).
    """
    denominator = sum(arr)
    if denominator != 0:
        return [round(value / denominator, round_dec) for value in arr]
    return arr

In [4]:
class Player:
    """
    A fair player: the coin lands heads with probability ``chance`` (0.5).

    ``history`` records every throw (1 = heads, 0 = tails) and ``in_game``
    tells whether the player has not yet been knocked out of the current game.
    """

    def __init__(self):
        self.history: list[int] = []
        self.chance: float = 0.5
        self.in_game = True

    def throw(self) -> int:
        """Flip the coin once, record the result and return it (1 heads, 0 tails)."""
        outcome = 1 if random.random() < self.chance else 0
        self.history.append(outcome)
        return outcome

    def after_play(self, *args, **kwargs):
        """No-op hook; accepts and ignores any arguments."""
        return None

In [5]:
class Cheater(Player):
    """
    A cheating player whose coin is biased towards heads.

    The heads probability is drawn once per cheater as 0.5 + U / 2 with
    U = random.random(), i.e. somewhere between 0.5 and 1.0.
    """

    def __init__(self):
        super().__init__()
        bias = random.random() * 0.5
        self.chance: float = 0.5 + bias
        # Neither list is read or written anywhere else in the visible code.
        self.chance_history: list[float] = []
        self.suspect_history: list[float] = []
        self.suspect_history: list[float] = []

In [6]:
def get_player_type(player: Player | Cheater):
    """Return the class name of ``player`` (e.g. "Player" or "Cheater")."""
    player_class = type(player)
    return player_class.__name__


def encode_player_type(player: Player | Cheater):
    match player:
        case Cheater():
            return 1
        case Player():
            return 0
        case _:
            return type(player).__name__

In [7]:
class Detective(torch.nn.Module):
    """
    Base class for detectives: models that map players' throw histories to
    the probability that each player is a cheater.

    Subclasses must create their layers in ``__init__`` and then call
    ``post_init()`` so the optimizer is built over those parameters.
    """

    def __init__(self):
        super().__init__()
        # Holds the optimizer *class* until post_init() replaces it with an instance.
        self.optimizer = torch.optim.Adam
        self.optim_parameters = {
            "lr": 0.1
        }
        # BCELoss expects probabilities in [0, 1]; the visible subclasses end in a Sigmoid.
        self.loss_fn = torch.nn.BCELoss()

    def post_init(self):
        """Instantiate the optimizer over this module's (now defined) parameters."""
        self.optimizer = self.optimizer(self.parameters(), **self.optim_parameters)

    def step(self, pred, true, *, max_grad_norm: float = 50):
        """
        Run one optimisation step on a batch.

        Args:
            pred: predicted cheater probabilities, computed with autograd enabled.
            true: target labels (1.0 cheater, 0.0 fair), same shape as ``pred``.
            max_grad_norm: total gradient norm the gradients are clipped to
                before the update.

        Returns:
            The batch loss as a Python float (useful for monitoring training).
        """
        self.train()
        self.optimizer.zero_grad()
        loss = self.loss_fn(pred, true)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.parameters(), max_grad_norm)
        self.optimizer.step()
        return loss.item()

In [8]:
class ConvDetective(Detective):
    """
    Convolutional detective.

    Each padded history of length L is lifted to 16 channels by a 1x1
    convolution. Those channels are flattened into one single-channel sequence
    of length 16 * L, passed through two small convolutions, averaged over
    positions and mapped to a cheater probability.
    """

    def __init__(self):
        super().__init__()
        # (B, 1, L) -> (B, 16, L) -> (B, 16 * L)
        self.upsample = torch.nn.Sequential(
            torch.nn.Conv1d(in_channels=1, out_channels=16, kernel_size=1),
            torch.nn.ReLU(),
            torch.nn.Flatten(),
        )
        # (B, 1, 16 * L) -> (B, 2, 16 * L - 3) -> (B, 4, 16 * L - 6) -> (B, 4, 1)
        self.conv = torch.nn.Sequential(
            torch.nn.Conv1d(in_channels=1, out_channels=2, kernel_size=4),
            torch.nn.ReLU(),
            torch.nn.Conv1d(in_channels=2, out_channels=4, kernel_size=4),
            torch.nn.ReLU(),
            torch.nn.AdaptiveAvgPool1d(output_size=1),
        )
        # (B, 4, 1) -> (B, 4) -> (B, 1)
        self.linear = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(in_features=4, out_features=8),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=8, out_features=1),
            torch.nn.Sigmoid()
        )

        self.post_init()

    def forward(self, x):
        """
        Score a batch of throw histories.

        ``x`` is a list of histories of possibly different lengths. They are
        left-padded with -1 to a common length. Returns a 1-D tensor with one
        cheater probability per history.
        """
        histories = pad(list(x), -1)
        batch = torch.tensor(histories, dtype=torch.float32, device=device)
        lifted = self.upsample(batch.unsqueeze(1))
        pooled = self.conv(lifted.unsqueeze(1))
        return self.linear(pooled).squeeze(1)

In [9]:
class RnnDetective(Detective):
    """
    Recurrent detective: runs an RNN over each player's own throw sequence and
    scores the hidden state after the most recent throw.
    """

    def __init__(self):
        super().__init__()
        # batch_first=True because forward() feeds (batch, seq_len, 1). Without
        # it, the RNN would treat the players axis as time and each throw
        # position as a separate batch element. A player's score would then
        # depend on the other players' last throws instead of their own history.
        self.rnn = torch.nn.RNN(input_size=1, hidden_size=4, batch_first=True)

        self.linear = torch.nn.Sequential(
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=4, out_features=8),
            torch.nn.ReLU(),
            torch.nn.Linear(in_features=8, out_features=1),
            torch.nn.Sigmoid(),
        )

        self.post_init()

    def forward(self, x):
        """
        Score a batch of throw histories.

        ``x`` is a list of histories of possibly different lengths. They are
        left-padded with -1, so the last time step is every player's real
        latest throw. Returns a 1-D tensor with one cheater probability per
        history.
        """
        x = torch.tensor(pad(x, -1), dtype=torch.float32, device=device)  # (B, L)
        x = x.unsqueeze(2)  # (B, L, 1)
        x, _ = self.rnn(x)  # (B, L, 4)
        x = x[:, -1]  # output after the last throw: (B, 4)
        x = self.linear(x)  # (B, 1)
        return x.squeeze(1)

In [10]:
class Game:
    def __init__(self):
        """
        A game class
        """
        self.players: list[Player | Cheater] = []
        self.detectives: list[Detective] = [
            ConvDetective().to(device),
            RnnDetective().to(device),
        ]
        
        self.player_stats = {
            "Number of players": [0, 0],
            "Number of eliminated": [0, 0],
            "Playtime": [0, 0],
        }
        self.detective_stats = {
            "Number of parameters": [sum(p.numel() for p in d.parameters()) for d in self.detectives],
            "True positives": [0]*len(self.detectives),
            "True negatives": [0]*len(self.detectives),
            "False positives": [0]*len(self.detectives),
            "False negatives": [0]*len(self.detectives),
        }
        
        self.create_players()


    def create_players(self):
        for _ in range(NUMBER_OF_PLAYERS - len(self.players)):
            if random.random() < CHEATER_CHANCE:
                self.players.append(Cheater())
                self.player_stats["Number of players"][1] += 1
            else:
                self.players.append(Player())
                self.player_stats["Number of players"][0] += 1


    def play_round(self):
        tail_indexes: list[int] = []
        for i, player in enumerate(self.players):
            if not player.in_game:
                continue
            if player.throw() == 0:
                tail_indexes.append(i)
        return tail_indexes


    def play(self, *, silent=False):
        while len([0 for p in self.players if p.in_game]) > 1:
            losers = self.play_round()
            if len(losers) == len(self.players):
                continue
            for i in reversed(losers):
                self.players[i].in_game = False
        self.after_play(silent=silent)


    def after_play(self, *, silent=False):
        for player in self.players:
            self.player_stats["Playtime"][encode_player_type(player)] += 1
        
        x_train, y_train = [], []
        
        gave_up_indexes = []
        for i, _ in enumerate(self.players):
            if random.random() < GIVE_UP_CHANCE:
                gave_up_indexes.append(i)
        for i in reversed(gave_up_indexes):
            elimed_player = self.players.pop(i)
            x_train.append(elimed_player.history)
            y_train.append(encode_player_type(elimed_player))
            if not silent:
                print(f"A {get_player_type(elimed_player)} gave up!")
        
        scores = torch.empty((len(self.players), len(self.detectives)))
        for i, detective in enumerate(self.detectives):
            detective.eval()
            scores[:, i] = detective([p.history for p in self.players])
        
        elimed_idxs = []
        for i, (player, player_scores) in enumerate(zip(self.players.copy(), scores)):
            encoded_ptype = encode_player_type(player)
            for j, player_score in enumerate(player_scores):
                match (int(player_score > 0.5), encoded_ptype):
                    case (1, 1):
                        self.detective_stats["True positives"][j] += 1
                    case (0, 0):
                        self.detective_stats["True negatives"][j] += 1
                    case (1, 0):
                        self.detective_stats["False positives"][j] += 1
                    case (0, 1):
                        self.detective_stats["False negatives"][j] += 1
            
            if player_scores.mean() > 0.5:
                elimed_idxs.append(i)
        
        for i in reversed(elimed_idxs):
            elimed_player = self.players.pop(i)
            x_train.append(elimed_player.history)
            y_train.append(encode_player_type(elimed_player))
            if not silent:
                print(f"A {get_player_type(elimed_player)} got eliminated!")
            self.player_stats["Number of eliminated"][encode_player_type(elimed_player)] += 1
        
        y_train = torch.tensor(y_train, dtype=torch.float32, device=device)
        if x_train:
            for detective in self.detectives:
                outputs = detective(x_train)
                
                detective.step(outputs, y_train)
        
        for player in self.players:
            player.in_game = True
        self.create_players()
    
    
    def show_player_stats(self):
        heads = ["Players", "Cheaters"]
        keys, values = [list(x) for x in zip(*self.player_stats.items())]
        keys.append("Average playtime")
        playtime_idx = keys.index("Playtime")
        keys.pop(playtime_idx)
        values.append([round(y/x, 3) for x, y in zip(values[0], values[2])])
        values.pop(playtime_idx)
        
        width = max(max(map(len, keys)), max(map(len, heads))) + 2
        sep_row = "".join(["+", "-"*(3*width+2), "+"])
        
        print(sep_row)

        print("|", " "*width, "|", "|".join([x.center(width) for x in heads]), "|", sep="")

        print(sep_row)

        for k, vs in zip(keys, values):
            print("|", k.center(width),
                  "|", str(vs[0]).center(width), "|", str(vs[1]).center(width), "|", sep="")

        print(sep_row)
        
        print("|", "Softmax".center(len(sep_row)-2), "|", sep="")

        print(sep_row)

        for k, vs in zip(keys, [softmax(v, round_dec=3) for v in values]):
            print("|", k.center(width),
                  "|", str(vs[0]).center(width), "|", str(vs[1]).center(width), "|", sep="")

        print(sep_row)
    
    
    def show_detective_stats(self):
        heads = [type(d).__name__ for d in self.detectives]
        keys, values = [list(x) for x in zip(*self.detective_stats.items())]
        
        keys.insert(1, "Number eliminated")
        values.insert(1, [x+y for x, y in zip(values[1], values[3])])
        
        keys.insert(2, "Number passed")
        values.insert(2, [x+y for x, y in zip(values[3], values[5])])
        
        keys.insert(3, "Number correct")
        values.insert(3, [x+y for x, y in zip(values[3], values[4])])
        
        keys.insert(4, "Number wrong")
        values.insert(4, [x+y for x, y in zip(values[6], values[7])])
        
        
        width = max(max(map(len, heads)), max(map(len, keys))) + 2
        sep_row = "".join(["+", "-"*(3*width+2), "+"])
        
        print(sep_row)
        
        print("|", " "*width, "|", "|".join(h.center(width) for h in heads), "|", sep="")
        
        print(sep_row)
        
        for k, vs in zip(keys, values):
            print("|", k.center(width), "|", "|".join(str(x).center(width) for x in vs), "|", sep="")
        
        print(sep_row)
        
        print("|", "Softmax".center(len(sep_row)-2), "|", sep="")

        print(sep_row)

        for k, vs in zip(keys, [softmax(v, round_dec=3) for v in values]):
            print("|", k.center(width),
                  "|", str(vs[0]).center(width), "|", str(vs[1]).center(width), "|", sep="")

        print(sep_row)
        
    
    def show_stats(self):
        self.show_player_stats()
        self.show_detective_stats()


# Build a new game: two detectives plus a freshly created table of NUMBER_OF_PLAYERS players.
game = Game()

In [11]:
# Number of games to simulate; an int because it is used directly as a loop count.
num_plays = 1_000
for _ in range(num_plays):
    game.play(silent=True)

In [12]:
# Print the player statistics table followed by the detective statistics table.
game.show_stats()

+--------------------------------------------------------------------+
|                      |       Players        |       Cheaters       |
+--------------------------------------------------------------------+
|  Number of players   |         2056         |         1447         |
| Number of eliminated |         1124         |         778          |
|   Average playtime   |        4.609         |        4.508         |
+--------------------------------------------------------------------+
|                              Softmax                               |
+--------------------------------------------------------------------+
|  Number of players   |        0.587         |        0.413         |
| Number of eliminated |        0.591         |        0.409         |
|   Average playtime   |        0.506         |        0.494         |
+--------------------------------------------------------------------+
+--------------------------------------------------------------------+
|     