In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

import random
from dataclasses import dataclass

import time
from math import log, sqrt
from collections import defaultdict
from abc import ABC, abstractmethod
from itertools import chain

import sys
import json
import copy

from tqdm import tqdm

In [2]:
## Progress bar

def log_progress(sequence, every=None, size=None, name='Items'):
    from ipywidgets import IntProgress, HTML, VBox
    from IPython.display import display

    is_iterator = False
    if size is None:
        try:
            size = len(sequence)
        except TypeError:
            is_iterator = True
    if size is not None:
        if every is None:
            if size <= 200:
                every = 1
            else:
                every = int(size / 200)     # every 0.5%
    else:
        assert every is not None, 'sequence is iterator, set every'

    if is_iterator:
        progress = IntProgress(min=0, max=1, value=1)
        progress.bar_style = 'info'
    else:
        progress = IntProgress(min=0, max=size, value=0)
    label = HTML()
    box = VBox(children=[label, progress])
    display(box)

    index = 0
    try:
        for index, record in enumerate(sequence, 1):
            if index == 1 or index % every == 0:
                if is_iterator:
                    label.value = '{name}: {index} / ?'.format(
                        name=name,
                        index=index
                    )
                else:
                    progress.value = index
                    label.value = u'{name}: {index} / {size}'.format(
                        name=name,
                        index=index,
                        size=size
                    )
            yield record
    except:
        progress.bar_style = 'danger'
        raise
    else:
        progress.bar_style = 'success'
        progress.value = index
        label.value = "{name}: {index}".format(
            name=name,
            index=str(index or '?')
        )

bar = tqdm # tqdm if in terminal

In [3]:
## Network

class PolicyValueNet(nn.Module):
    def __init__(self, n_players, hidden_dim):
        super(PolicyValueNet, self).__init__()

        # CNN Branch for processing matrix M
        self.cnn_branch = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(8),  # Batch Norm after Conv
            nn.ReLU(),
            nn.Conv2d(in_channels=8, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),  # Batch Norm after Conv
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((4, 4)),  # Downsampling to a fixed size
            nn.Flatten(),
            nn.Linear(32 * 4 * 4, hidden_dim),
            nn.BatchNorm1d(hidden_dim),  # Batch Norm after Linear
            nn.ReLU()
        )

        # Fully-connected Branch for processing vector b
        self.fc_branch = nn.Sequential(
            nn.Linear(3, hidden_dim),
            nn.BatchNorm1d(hidden_dim),  # Batch Norm after Linear
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),  # Batch Norm after Linear
            nn.ReLU()
        )

        # Main Branch
        self.main_branch = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim * 2),
            nn.BatchNorm1d(hidden_dim * 2),  # Batch Norm after Linear
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.BatchNorm1d(hidden_dim),  # Batch Norm after Linear
            nn.ReLU()
        )

        # Policy Head
        self.policy_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim), 
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim), 
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()  # Output in range [0, 1]
        )

        # Value Head
        self.value_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim), 
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim), 
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()   # Output n-dim vector in range [0, 1]
        )

    def forward(self, M, b):
        M = self.cnn_branch(M)
        b = self.fc_branch(b)
        x = torch.cat((M, b), dim=-1)
        x = self.main_branch(x)
        p = self.policy_head(x)
        value = self.value_head(x)
        # print(f"Debug: Policy Head Output (p): {p}")  # Debugging output
        return p, value

def train_nn(model, batch_size, n_players, hidden_dim, input_state, target_policy, target_value):
    M, b = input_state
    M = torch.tensor(M, dtype=torch.float32)
    b = torch.tensor(b, dtype=torch.float32)
    target_policy = torch.tensor(target_policy, dtype=torch.float32).view(-1, 1)
    target_value = torch.tensor(target_value, dtype=torch.float32).view(-1, 1)


    # Define optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
    weights = torch.where(target_policy == 1., torch.tensor(1.1), torch.tensor(1.0))  # Adjust this value as needed
    policy_loss_fn = nn.BCELoss(weight=weights)  # Binary Cross-Entropy Loss with positive weight
    value_loss_fn = nn.MSELoss()

    # Training loop (few steps)
    for step in range(10):
        optimizer.zero_grad()

        # Forward pass
        pred_policy, pred_value = model(M, b)

        # Compute policy loss
        policy_loss = policy_loss_fn(pred_policy, target_policy)
    
        # Compute value loss
        value_loss = value_loss_fn(pred_value, target_value)

        # Compute L2 regularization term (sum of all parameters' squared values)
        l2_lambda = 1e-4
        l2_penalty = sum(w.pow(2.0).sum() for w in model.parameters())

        # Combine losses
        loss = policy_loss + value_loss*0.01 + l2_lambda * l2_penalty

        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Clip gradients
        optimizer.step()

        if step == 9:
            print(f"Loss = {loss.item():.4f}, "
                  f"Policy Loss = {policy_loss.item():.4f}, "
                  f"Value Loss = {value_loss.item():.4f}")

In [4]:
## Game

ACTION_TAKE = 0
ACTION_PASS = 1

def diff(first, second):
    second = set(second)
    return [item for item in first if item not in second]

@dataclass
class NoThanksConfig:
    min_card: int = 3
    max_card: int = 14
    n_omit_cards: int = 3
    n_players: int = 3
    start_coins: int = 4


class NoThanksBoard():
    def __init__(self, n_players = 3, config = NoThanksConfig):
        self.n_players = n_players
        self.min_card = config.min_card
        self.max_card = config.max_card
        self.full_deck = list(range(self.min_card, self.max_card+1))
        self.n_omit_cards = config.n_omit_cards
        self.n_cards = self.max_card - self.min_card + 1
        self.start_coins = config.start_coins
        random.seed(999)

            
    # state: ((player coins),(player cards),(card in play, coins in play, n_cards_remaining, current player))
    def starting_state(self, current_player = 0):
        coins = [self.start_coins for i in range(self.n_players)]
        cards = [[] for i in range(self.n_players)]

        card_in_play = random.choice(self.full_deck)
        
        coins_in_play = 0
        n_cards_in_deck = self.n_cards - 1 - self.n_omit_cards

        return coins, cards, (card_in_play, coins_in_play, n_cards_in_deck, current_player)

    def next_state(self, state, action):

        state = self.unpack_state(state)
        coins, cards, (card_in_play, coins_in_play, n_cards_in_deck, current_player) = state

        if action == ACTION_TAKE:
            cards[current_player].append(card_in_play)
            coins[current_player] += coins_in_play

            all_player_cards = [card for player_cards in cards for card in player_cards]
            cards_in_deck = diff(self.full_deck, all_player_cards)
            current_player = current_player
            
            if cards_in_deck and n_cards_in_deck > 0:   
                # random.shuffle(list(cards_in_deck))
                card_in_play = random.choice(cards_in_deck)
                n_cards_in_deck -= 1
            else:
                card_in_play = None
            coins_in_play = 0

        else:
            coins[current_player] -= 1
            coins_in_play += 1
            current_player += 1
        
        if current_player == self.n_players:
            current_player = 0

        next_state = coins, cards, (card_in_play, coins_in_play, n_cards_in_deck, current_player)
        return self.pack_state(next_state)
    
    def all_possible_next(self, state, action):
        if action == ACTION_PASS:
            return self.next_state(state, action)
        elif action == ACTION_TAKE:
            next_states = []
            state = self.unpack_state(state)
            coins, cards, (card_in_play, coins_in_play, n_cards_in_deck, current_player) = state
            cards[current_player].append(card_in_play)
            coins[current_player] += coins_in_play
            n_cards_in_deck -= 1
            coins_in_play = 0

            all_player_cards = [card for player_cards in cards for card in player_cards]
            cards_in_deck = diff(self.full_deck, all_player_cards)
            current_player = current_player
            
            if not cards_in_deck:
                return self.next_state(state, action)
            else:
                for card in cards_in_deck: 
                    card_in_play = card
                    if current_player == self.n_players:
                        current_player = 0
                    next_state = coins, cards, (card_in_play, coins_in_play, n_cards_in_deck, current_player)
                    next_state = self.pack_state(next_state)
                    next_states.append(next_state)
            
            return next_states

    def is_legal(self, state, action):
        coins, cards, (card_in_play, coins_in_play, n_cards_in_deck, current_player) = state

        if card_in_play is None:
            return False
        if coins[current_player] <= 0 and action == ACTION_PASS:
            return False
        else:
            return True

    def legal_actions(self, state):
        actions = []
        
        if self.is_legal(state, ACTION_TAKE):
            actions.append(ACTION_TAKE)

        if self.is_legal(state, ACTION_PASS):
            actions.append(ACTION_PASS)

        return actions

    def pack_state(self, state):
        coins, cards, details = state
        packed_state = tuple(coins), tuple(map(tuple, cards)), details
        return packed_state

    def unpack_state(self, packed_state):
        coins, cards, details = packed_state
        coins = list(coins)
        cards = list(map(list, cards))
        return coins, cards, details
    

    def standard_state(self, state):
        """
        Input state (packed or unpacked): ([coins], [[cards]], (card_in_play, coins_in_play, n_cards_in_deck, current_player))
        Transform state into the required format:
        1. Extract the state into M and b where M is the card/coin matrix and b is the vector (card_in_play, coins_in_play, n_cards_in_deck)
        2. Rotate M such that the first row corresponds to the current player
        3. Transform M into array of shape (3, n_players, 33)
        """
        coins, cards, (card_in_play, coins_in_play, n_cards_in_deck, current_player) = state
        
        # Step 1: Build the M matrix (n_players x 34)
        M = []
        for k in range(self.n_players):
            # Initialize the card representation for player k
            card_rep = [0] * self.n_cards  # 33 cards

            # Set 1 for each card player k has
            for card in cards[k]:
                card_rep[card - self.min_card] = 1  # Cards are indexed from 3 to 35

            # Add the number of coins for player k
            M.append(card_rep + [coins[k]])  # 33 card columns + 1 coin column
        
        # Step 2: Build the b vector (card_in_play, coins_in_play, n_cards_in_deck)
        b = np.array([card_in_play, coins_in_play, n_cards_in_deck])

        # Step 3: Rotate M such that the first row is the current player
        M_rotated = M[current_player:] + M[:current_player]  # Rotate the matrix
        
        # Step 4: Transform M into array of shape (3, n_players, 33)
        # where M[0] is the card matrix, M[1] is the coin matrix, M[2] is the card in play
        M_transformed = np.zeros((3, self.n_players, self.n_cards))
        M_rotated = np.array(M_rotated)
        M_transformed[0] = M_rotated[:, :-1]  # Card matrix
        M_transformed[1] = np.repeat(M_rotated[:, -1][:, np.newaxis], self.n_cards, axis=1)  # Coin matrix
        M_transformed[1] = M_transformed[1] / (self.start_coins * self.n_players)  # Normalize coins to be between 0 and 1
        
        card_in_play_onehot = np.zeros(self.n_cards)
        if self.min_card <= card_in_play <= self.max_card:
            card_in_play_onehot[card_in_play - self.min_card] = 1
        M_transformed[2] = np.tile(card_in_play_onehot, (self.n_players, 1))

        return M_transformed, b


    def is_ended(self, state):
        # print(state)
        coins, cards, (card_in_play, coins_in_play, n_cards_in_deck, current_player) = state

        if n_cards_in_deck == 0 and card_in_play == None:
            return True
        else:
            return False

    def compute_scores(self, state):
        state = self.unpack_state(state)
        coins, cards, (card_in_play, coins_in_play, n_cards_in_deck, current_player) = state

        scores = []

        for p_idx in range(self.n_players):
            cards[p_idx].sort()

            score = 0
            if cards[p_idx]:
                score += cards[p_idx][0]
                last_card = cards[p_idx][0]

                for card_idx in range(1, len(cards[p_idx])):
                    new_card = cards[p_idx][card_idx]

                    if not new_card == last_card + 1:
                        score += new_card
                    last_card = new_card

            score -= coins[p_idx]

            scores.append(score)

        return scores

    def winner(self, state):
        """Temporary winner: player with the lowest score even if the game is not ended."""
        state = self.unpack_state(state)
        coins, cards, (card_in_play, coins_in_play, n_cards_in_deck, current_player) = state

        if not self.is_ended(state):
            return None
        
        scores = self.compute_scores(state)
        min_score = 1000
        lowest_scorers = []
        # get lowest scorers (could be more than one)
        for i, score in enumerate(scores):
            if score < min_score:
                lowest_scorers = [i]
                min_score = score
            if score <= min_score:
                lowest_scorers.append(i)
        
        # if players are tied on lowest score, get the one with the fewest cards
        if len(lowest_scorers) > 1:
            min_n_cards = 1000
            for i in lowest_scorers:
                n_cards = len(cards[i])
                if n_cards < min_n_cards:
                    lowest_card_players = [i]
                    min_n_cards = n_cards
                elif n_cards <= min_n_cards:
                    lowest_card_players.append(i)

            if len(lowest_card_players) > 1:
                winner = lowest_card_players[0]
            else: # if still tied, pick a random winner (not the official rules)
                winner = random.choice(lowest_card_players) 
        else:
            winner = lowest_scorers[0]

        return winner

    
    def basic_display_state(self, state):
        coins, cards, (card_in_play, coins_in_play, n_cards_in_deck, current_player) = state

        print("Coins:           {0}".format(coins))
        print("Cards:           {0}".format(cards))
        print("Card in play:    {0}".format(card_in_play))
        print("Coins:           {0}".format(coins_in_play))
        print("Player:          {0}".format(current_player))

    def display_scores(self, state, players):
        scores = self.compute_scores(state)
        print("")
        print("--- Scores ---")
        for player in players:
            print("{:<10} {:<10}".format(
                player.name, scores[player.turn])
            )
        print("")

    def display_state(self, state, players):
        state = self.unpack_state(state)
        coins, cards, (card_in_play, coins_in_play, n_cards_in_deck, current_player) = state

        scores = self.compute_scores(state)

        def format_cards(card_list):
            return ", ".join(map(str, sorted(card_list)))

        player_labels = [player.name for player in players]
        card_strings = [format_cards(cards[i]) for i in range(self.n_players)]
        coin_strings = [str(coins[i]) for i in range(self.n_players)]
        score_strings = [str(scores[i]) for i in range(self.n_players)]

        max_card_len = max(20, max(len(card_str) for card_str in card_strings))

        print("")
        print("-" * (15 + max_card_len + 10 + 10 + 10))
        print("")
        print("{:<15} {:<{}} {:<10} {:<10}".format("Player", "Cards", max_card_len, "Coins", "Score"))
        print("-" * (15 + max_card_len + 10 + 10 + 10))

        for i in range(self.n_players):
            print("{:<15} {:<{}} {:<10} {:<10}".format(
                player_labels[i],
                card_strings[i],
                max_card_len,
                coin_strings[i],
                score_strings[i]
            ))

        print("-" * (15 + max_card_len + 10 + 10 + 10))
        print("\t\t In play: [{0}]".format(card_in_play))
        print("\t\t Cards remaining: {0}".format(n_cards_in_deck))
        print("\t\t   Coins: {0}".format(coins_in_play))
        print("")
            
    def pack_action(self, notation):
        if notation == "y" or notation == "Y":
            return ACTION_TAKE
        else:
            return ACTION_PASS

    def current_player(self, state):
        return state[2][3]
    
    def remaining_cards(self, state):
        s = self.unpack(state)

        # All cards in original deck
        full_deck = set(self.initial_deck)  # This should be fixed when game starts

        # Cards already taken by players
        taken_cards = set()
        for player_cards in s["player_cards"]:
            taken_cards.update(player_cards)

        # Card currently in play
        if s["card"] is not None:
            taken_cards.add(s["card"])

        # Cards already revealed (optional: e.g., discard pile if any)
        # taken_cards.update(s.get("discarded_cards", []))

        # Remaining = full_deck - taken
        remaining = full_deck - taken_cards
        return list(remaining)
    

In [5]:
## Players - UCT and PUCT Players [Open Loop MCTS]

class Player(ABC):
    """The abstract class for a player. A player can be an AI agent (bot) or human."""
    def __init__(self, game, turn):
        self.name = "Player " + str(turn)
        self.game = game
        self.turn = turn  # starting form 0 as convention in python
        assert self.turn < self.game.n_players, "Player turn out of range."

    @abstractmethod
    def get_action(self, state):
        pass


class BaseMCTSPlayer(Player, ABC):
    def __init__(self, game, turn, thinking_time=1, simNum=0, max_moves=200):
        super().__init__(game, turn)
        self.thinking_time = thinking_time
        self.simNum = simNum
        self.max_moves = max_moves
        self.max_depth = 0

    @abstractmethod
    def get_action(self, state):
        pass

    def score(self, state, player, legal_actions, plays, wins):
        total_ply = sum(plays[("decision", player, state, a)] for a in legal_actions)
        if total_ply == 0:
            return 0
        score = 0
        for action in legal_actions:
            key = ("decision", player, state, action)
            if plays[key]:
                score += (wins[key] / plays[key]) * (plays[key] / total_ply)
        return score


class UCTPlayer(BaseMCTSPlayer):
    def __init__(self, game, turn=0, thinking_time=1, simNum=0):
        super().__init__(game, turn, thinking_time, simNum)
        self.C = 1.4  # Exploration parameter

    def get_action(self, state):
        board = self.game
        player = board.current_player(state)
        legal_actions = board.legal_actions(state)

        if not legal_actions:
            return None, None
        if len(legal_actions) == 1:
            return legal_actions[0], 0

        plays = defaultdict(int)
        wins = defaultdict(int)
        games = 0

        if self.thinking_time > 0 and self.simNum == 0:
            start_time = time.perf_counter()
            while time.perf_counter() - start_time < self.thinking_time:
                self.run_simulation(state, board, plays, wins)
                games += 1
        else:
            for _ in range(self.simNum):
                self.run_simulation(state, board, plays, wins)
                games += 1

        random.shuffle(legal_actions)
        action = max(
            legal_actions,
            key=lambda a: plays.get((player, state, a), 0)
        )

        return action, self.score(state, player, legal_actions, plays, wins)

    def run_simulation(self, state, board, plays, wins):
        """Run a single MCTS simulation."""
        tree = set()
        player = board.current_player(state)

        # === Selection & Expansion ===
        for t in range(1, self.max_moves + 1):
            legal_actions = board.legal_actions(state)

            # Selection using UCB1 if data exists for all actions
            if all(plays.get((player, state, a)) for a in legal_actions):
                log_total = log(sum(plays[(player, state, a)] for a in legal_actions))
                action = max(
                    legal_actions,
                    key=lambda a: (
                        wins[(player, state, a)] / plays[(player, state, a)] +
                        self.C * sqrt(log_total / plays[(player, state, a)])
                    )
                )
                
            else:
                # Expansion – If any action is unexplored, take a random one
                action = random.choice(legal_actions)
                if (player, state, action) not in plays:
                    plays[(player, state, action)] = 0
                    wins[(player, state, action)] = 0
                    if t > self.max_depth:
                        self.max_depth = t

            tree.add((player, state, action))
            state = board.next_state(state, action)
            player = board.current_player(state)

            # Check for game-ending state
            winner = board.winner(state)
            if winner is not None:
                break

        # === Backpropagation ===
        for player, state, action in tree:
            plays[(player, state, action)] += 1
            if player == winner:
                wins[(player, state, action)] += 1


class PUCTPlayer(BaseMCTSPlayer):
    def __init__(self, game, turn=0, thinking_time=1, simNum=0):
        super().__init__(game, turn, thinking_time, simNum)
        self.C = 1.5 # c_puct exploration parameter
        self.prior = lambda state, action: 1 / len(self.game.legal_actions(state))
        self.value = None

    def get_action(self, state):
        board = self.game
        player = board.current_player(state)
        legal_actions = board.legal_actions(state)

        if not legal_actions:
            return None, None
        if len(legal_actions) == 1:
            return legal_actions[0], 0

        plays = defaultdict(int)
        wins = defaultdict(int)
        games = 0

        if self.thinking_time > 0 and self.simNum == 0:
            start_time = time.perf_counter()
            while time.perf_counter() - start_time < self.thinking_time:
                self.run_simulation(state, board, plays, wins)
                games += 1
        else:
            for _ in range(self.simNum):
                self.run_simulation(state, board, plays, wins)
                games += 1

        random.shuffle(legal_actions)
        action = max(
            legal_actions,
            key=lambda a: plays.get((player, state, a), 0)
        )

        return action, self.score(state, player, legal_actions, plays, wins)

    def run_simulation(self, state, board, plays, wins):
        """Run a single MCTS simulation."""
        tree = set()
        player = board.current_player(state)

        # === Selection & Expansion ===
        for t in range(1, self.max_moves + 1):
            legal_actions = board.legal_actions(state)

            # Selection using UCB1 if data exists for all actions
            if all(plays.get((player, state, a)) for a in legal_actions):
                total = sum(plays[(player, state, a)] for a in legal_actions)
                # for a in legal_actions:
                #     print("Check Total:", a, wins[(player, state, a)] / plays[(player, state, a)])
                action = max(
                    legal_actions,
                    key=lambda a: (
                        wins[(player, state, a)] / plays[(player, state, a)] +
                        self.C * self.prior(state, a) * sqrt(total / plays[(player, state, a)])
                    )
                )
            else:
                # Expansion – If any action is unexplored, take a random one
                action = random.choice(legal_actions)
                if (player, state, action) not in plays:
                    plays[(player, state, action)] = 0
                    wins[(player, state, action)] = 0
                    if t > self.max_depth:
                        self.max_depth = t

            tree.add((player, state, action)) # trajectory
            state = board.next_state(state, action)
            player = board.current_player(state)

            # Check for game-ending state
            winner = board.winner(state)
            if winner is not None:
                break

        # === Backpropagation ===
        for player, state, action in tree:
            plays[(player, state, action)] += 1
            if player == winner:
                wins[(player, state, action)] += 1

In [6]:
## Utils - prior functions

def nn_prior_fn(model, game):
    def nn_prior(state, action):
        with torch.no_grad():
            M, b = game.standard_state(state)
            M_tensor = torch.tensor(M, dtype=torch.float32).unsqueeze(0)
            b_tensor = torch.tensor(b, dtype=torch.float32).unsqueeze(0)
            policy = model(M_tensor, b_tensor)[0]
            prob = policy.item()
            return (prob if action == ACTION_PASS else 1 - prob)
    return nn_prior

def smart_prior_fn(game, p=0.98):
    def smart_prior(state, action):
        state = game.unpack_state(state)
        coins, cards, (card_in_play, coins_in_play, n_cards_in_deck, current_player) = state 
        other_cards = [i for i in list(chain.from_iterable(cards)) if i not in cards[current_player]]
        good_for_me = any(abs(card_in_play - card) < 2 for card in cards[current_player])
        good_for_them = any(abs(card_in_play - card) < 2 for card in other_cards)
        least_chip = min(coins)
        legal_actions = game.legal_actions(state)

        if good_for_me:
            if good_for_them:
                good_action = ACTION_TAKE
            else:
                good_action = ACTION_TAKE
                if least_chip > 2:
                    good_action = ACTION_PASS
        else:
            good_action = ACTION_PASS
            if coins[current_player] < 2 or abs(coins_in_play - card_in_play) < min(3, card_in_play//2):
                good_action = ACTION_TAKE

        if action not in legal_actions:
            return 0  # Invalid action for the state
        return p if action == good_action else (1 - p)
    
    return smart_prior


In [7]:
## Utils - play and performance evaluation

def play(game, players, display=True):    
    players.sort(key=lambda x: x.turn)
    current_player = players[0].turn
    
    state = game.starting_state(current_player=current_player)
    state = game.pack_state(state)

    while not game.is_ended(state):
        player = players[current_player]
        action, score = player.get_action(state)
        state = game.next_state(state, action)
        coins, cards, (card_in_play, coins_in_play, n_cards_in_deck, current_player) = game.unpack_state(state)
        if display:
            game.display_state(state, players)
        
        # print(game.standard_state(state))  
    winner = game.winner(state)
    if display:
        game.display_scores(state, players)
        print("Game ended. Player", winner, "wins!")

    return winner

def eval_performance(game, target_player, opponents, num_games=100, verbose=False):
    random.seed(time.time())
    target_player.name = "Target"
    players = [target_player] + opponents
    win = defaultdict(int)
    for i in bar(range(num_games)):
        target_player.turn = i % len(players)
        for j, player in enumerate(players):
            if player != target_player:
                player.turn = (i + j) % len(players)
            # print(f"Game {i+1}: Player {player.name} turn: {player.turn}")
        winner = play(game, players, display=False)
        win[players[winner].name] += 1
        if verbose and i in [150, 180, 210, 240, 270]:
            print(f"Number of wins for each player: {win}")

    print(f"Number of wins for each player: {win}")
    winrate = win[target_player.name] / num_games

    return winrate

In [8]:
## Training process functions

def self_play(game, players, times=1, to_file=None, smart=False):
    if smart:
        for player in players:
            player.prior = smart_prior_fn(game)

    data = {"state": [], "policy": [], "value": []}
    for _ in bar(range(times)):
        state = game.starting_state(current_player=0)
        state = game.pack_state(state)
        current_player = 0

        while not game.is_ended(state):
            player = players[current_player]
            action, score = player.get_action(state)
            state = game.next_state(state, action)
            coins, cards, (card_in_play, coins_in_play, n_cards_in_deck, current_player) = game.unpack_state(state)
            if card_in_play is not None:
                data["state"].append(game.standard_state(state))  # Append (M, b) as NumPy arrays
                data["policy"].append(action)
                data["value"].append(score)

    # Convert NumPy arrays in "state" to lists for JSON serialization
    if to_file is not None:
        serializable_data = {
            "state": [(M.tolist(), b.tolist()) for M, b in data["state"]],
            "policy": data["policy"],
            "value": data["value"]
        }
        with open(to_file, "w") as f:
            json.dump(serializable_data, f)

    return data

def rl_train(rounds=10, num_games=4, simNum=300, prior=False, ctd_from=0):
    game = NoThanksBoard(n_players=3)
    Player_0 = PUCTPlayer(game=game, turn=0, simNum=simNum)
    Player_1 = PUCTPlayer(game=game, turn=1, simNum=simNum)
    Player_2 = PUCTPlayer(game=game, turn=2, simNum=simNum)
    players = [Player_0, Player_1, Player_2]

    tester = PUCTPlayer(game=game, turn=0, simNum=simNum)

    batch_size = 512
    n_players = 3
    model = PolicyValueNet(n_players, hidden_dim=32)

    if prior == True:
        model.load_state_dict(torch.load(f'policy_value_net_rd{ctd_from}.pth', weights_only=True))
        model.eval()
        for player in players:
            player.prior = nn_prior_fn(model, game)
            player.simNum = simNum
        print(f"Model {ctd_from} loaded...Priors are updated for players.")

    for i in range(rounds):
        print(f"Round {i}: The bots are playing...")

        # smart = True if i == 0 else False
        smart = False
        data = self_play(game, players, times=num_games, smart=smart)
        
        # Combine and shuffle the data
        combined_data = list(zip(data["state"], data["policy"], data["value"]))
        random.shuffle(combined_data)
        data["state"], data["policy"], data["value"] = zip(*combined_data)

        states = data["state"]
        target_policy = np.array(data["policy"])
        target_value = np.array(data["value"])
        print("Training Data prepared.")
        num_samples = len(states)
        num_batches = num_samples // batch_size
        print(f"Total samples: {num_samples}, Batches: {num_batches}")

        backup_model = copy.deepcopy(model.state_dict())
            
        for batch_idx in range(num_batches):
            # Extract batch data
            batch_start = batch_idx * batch_size
            batch_end = batch_start + batch_size
            
            batch_states = states[batch_start:batch_end]
            batch_policies = target_policy[batch_start:batch_end]
            batch_values = target_value[batch_start:batch_end]
            
            # Reshape states into (M, b) format
            M = np.array([s[0] for s in batch_states])
            b = np.array([np.array(s[1], dtype=np.float32) for s in batch_states], dtype=np.float32)

            # Train the model on the batch
            model.train()
            print(f"Training batch {batch_idx + 1}/{num_batches}")
            train_nn(model, batch_size, n_players, 32, (M, b), batch_policies, batch_values)

        # Save the model
        torch.save(model.state_dict(), f'policy_value_net_rd{i+ctd_from}.pth')

        print(f"Round {i} completed. Evaluating performance...")
        
        # Update prior for the tester
        model.eval()
        tester.prior = nn_prior_fn(model, game)

        # Evaluate the performance of the trained model
        # need about 300 games to reach stable estimate of winrate
        winrate = eval_performance(game, tester, players[1:], num_games=150, verbose=False)
        if winrate > 0.4:
            print(f"Winrate of the Target Player: {winrate:.2%}; Model is accepted.")
            for player in players:
                player.prior = nn_prior_fn(model, game)
        else:
            print(f"Winrate of the Target Player: {winrate:.2%}; Model is rejected.")
            model.load_state_dict(backup_model)


In [9]:
## Training

rl_train(rounds=5, num_games=600, simNum=300, prior=False, ctd_from=0)

Round 0: The bots are playing...


100%|██████████| 600/600 [13:46<00:00,  1.38s/it]


Training Data prepared.
Total samples: 20960, Batches: 40
Training batch 1/40
Loss = 0.7456, Policy Loss = 0.6948, Value Loss = 0.2132
Training batch 2/40
Loss = 0.7048, Policy Loss = 0.6541, Value Loss = 0.2105
Training batch 3/40
Loss = 0.6879, Policy Loss = 0.6372, Value Loss = 0.2077
Training batch 4/40
Loss = 0.6603, Policy Loss = 0.6098, Value Loss = 0.2031
Training batch 5/40
Loss = 0.6314, Policy Loss = 0.5809, Value Loss = 0.2019
Training batch 6/40
Loss = 0.6091, Policy Loss = 0.5588, Value Loss = 0.1968
Training batch 7/40
Loss = 0.6067, Policy Loss = 0.5564, Value Loss = 0.1937
Training batch 8/40
Loss = 0.5684, Policy Loss = 0.5182, Value Loss = 0.1904
Training batch 9/40
Loss = 0.5607, Policy Loss = 0.5106, Value Loss = 0.1855
Training batch 10/40
Loss = 0.5134, Policy Loss = 0.4634, Value Loss = 0.1823
Training batch 11/40
Loss = 0.5095, Policy Loss = 0.4596, Value Loss = 0.1808
Training batch 12/40
Loss = 0.4901, Policy Loss = 0.4402, Value Loss = 0.1759
Training batch 

100%|██████████| 150/150 [49:49<00:00, 19.93s/it]


Number of wins for each player: defaultdict(<class 'int'>, {'Player 1': 65, 'Player 2': 84, 'Target': 1})
Winrate of the Target Player: 0.67%; Model is rejected.
Round 1: The bots are playing...


100%|██████████| 600/600 [14:16<00:00,  1.43s/it]


Training Data prepared.
Total samples: 20808, Batches: 40
Training batch 1/40
Loss = 0.7689, Policy Loss = 0.7181, Value Loss = 0.2138
Training batch 2/40
Loss = 0.7282, Policy Loss = 0.6775, Value Loss = 0.2095
Training batch 3/40
Loss = 0.6957, Policy Loss = 0.6451, Value Loss = 0.2066
Training batch 4/40
Loss = 0.6396, Policy Loss = 0.5890, Value Loss = 0.2020
Training batch 5/40
Loss = 0.6367, Policy Loss = 0.5863, Value Loss = 0.1999
Training batch 6/40
Loss = 0.6256, Policy Loss = 0.5752, Value Loss = 0.1948
Training batch 7/40
Loss = 0.5910, Policy Loss = 0.5407, Value Loss = 0.1916
Training batch 8/40
Loss = 0.5536, Policy Loss = 0.5034, Value Loss = 0.1894
Training batch 9/40
Loss = 0.5333, Policy Loss = 0.4832, Value Loss = 0.1849
Training batch 10/40
Loss = 0.5650, Policy Loss = 0.5150, Value Loss = 0.1820
Training batch 11/40
Loss = 0.4994, Policy Loss = 0.4494, Value Loss = 0.1787
Training batch 12/40
Loss = 0.4796, Policy Loss = 0.4297, Value Loss = 0.1764
Training batch 

100%|██████████| 150/150 [49:44<00:00, 19.90s/it]


Number of wins for each player: defaultdict(<class 'int'>, {'Player 1': 66, 'Player 2': 83, 'Target': 1})
Winrate of the Target Player: 0.67%; Model is rejected.
Round 2: The bots are playing...


100%|██████████| 600/600 [14:19<00:00,  1.43s/it]


Training Data prepared.
Total samples: 20725, Batches: 40
Training batch 1/40
Loss = 0.7536, Policy Loss = 0.7028, Value Loss = 0.2141
Training batch 2/40
Loss = 0.7363, Policy Loss = 0.6856, Value Loss = 0.2089
Training batch 3/40
Loss = 0.6935, Policy Loss = 0.6429, Value Loss = 0.2057
Training batch 4/40
Loss = 0.6565, Policy Loss = 0.6060, Value Loss = 0.2015
Training batch 5/40
Loss = 0.6249, Policy Loss = 0.5745, Value Loss = 0.1977
Training batch 6/40
Loss = 0.6104, Policy Loss = 0.5601, Value Loss = 0.1949
Training batch 7/40
Loss = 0.6048, Policy Loss = 0.5545, Value Loss = 0.1916
Training batch 8/40
Loss = 0.5719, Policy Loss = 0.5217, Value Loss = 0.1880
Training batch 9/40
Loss = 0.5447, Policy Loss = 0.4946, Value Loss = 0.1857
Training batch 10/40
Loss = 0.5160, Policy Loss = 0.4660, Value Loss = 0.1812
Training batch 11/40
Loss = 0.5076, Policy Loss = 0.4577, Value Loss = 0.1775
Training batch 12/40
Loss = 0.4936, Policy Loss = 0.4437, Value Loss = 0.1720
Training batch 

100%|██████████| 150/150 [50:07<00:00, 20.05s/it]


Number of wins for each player: defaultdict(<class 'int'>, {'Player 1': 68, 'Player 2': 82})
Winrate of the Target Player: 0.00%; Model is rejected.
Round 3: The bots are playing...


100%|██████████| 600/600 [14:41<00:00,  1.47s/it]


Training Data prepared.
Total samples: 21268, Batches: 41
Training batch 1/41
Loss = 0.7637, Policy Loss = 0.7129, Value Loss = 0.2131
Training batch 2/41
Loss = 0.7096, Policy Loss = 0.6589, Value Loss = 0.2092
Training batch 3/41
Loss = 0.6869, Policy Loss = 0.6363, Value Loss = 0.2072
Training batch 4/41
Loss = 0.6550, Policy Loss = 0.6045, Value Loss = 0.2020
Training batch 5/41
Loss = 0.6485, Policy Loss = 0.5980, Value Loss = 0.1988
Training batch 6/41
Loss = 0.6114, Policy Loss = 0.5611, Value Loss = 0.1954
Training batch 7/41
Loss = 0.5996, Policy Loss = 0.5493, Value Loss = 0.1930
Training batch 8/41
Loss = 0.5798, Policy Loss = 0.5296, Value Loss = 0.1886
Training batch 9/41
Loss = 0.5299, Policy Loss = 0.4798, Value Loss = 0.1853
Training batch 10/41
Loss = 0.5207, Policy Loss = 0.4707, Value Loss = 0.1813
Training batch 11/41
Loss = 0.4980, Policy Loss = 0.4480, Value Loss = 0.1792
Training batch 12/41
Loss = 0.4833, Policy Loss = 0.4335, Value Loss = 0.1752
Training batch 

100%|██████████| 150/150 [50:53<00:00, 20.36s/it]


Number of wins for each player: defaultdict(<class 'int'>, {'Player 1': 71, 'Player 2': 79})
Winrate of the Target Player: 0.00%; Model is rejected.
Round 4: The bots are playing...


100%|██████████| 600/600 [14:44<00:00,  1.47s/it]


Training Data prepared.
Total samples: 21239, Batches: 41
Training batch 1/41
Loss = 0.7510, Policy Loss = 0.7002, Value Loss = 0.2145
Training batch 2/41
Loss = 0.7230, Policy Loss = 0.6723, Value Loss = 0.2099
Training batch 3/41
Loss = 0.7009, Policy Loss = 0.6503, Value Loss = 0.2060
Training batch 4/41
Loss = 0.6665, Policy Loss = 0.6159, Value Loss = 0.2025
Training batch 5/41
Loss = 0.6307, Policy Loss = 0.5802, Value Loss = 0.1996
Training batch 6/41
Loss = 0.6085, Policy Loss = 0.5582, Value Loss = 0.1956
Training batch 7/41
Loss = 0.6054, Policy Loss = 0.5552, Value Loss = 0.1916
Training batch 8/41
Loss = 0.5762, Policy Loss = 0.5260, Value Loss = 0.1891
Training batch 9/41
Loss = 0.5524, Policy Loss = 0.5024, Value Loss = 0.1842
Training batch 10/41
Loss = 0.5384, Policy Loss = 0.4884, Value Loss = 0.1797
Training batch 11/41
Loss = 0.5031, Policy Loss = 0.4531, Value Loss = 0.1780
Training batch 12/41
Loss = 0.5049, Policy Loss = 0.4551, Value Loss = 0.1747
Training batch 

100%|██████████| 150/150 [50:52<00:00, 20.35s/it]

Number of wins for each player: defaultdict(<class 'int'>, {'Player 1': 68, 'Player 2': 81, 'Target': 1})
Winrate of the Target Player: 0.67%; Model is rejected.





In [10]:
# ## Evaluate Perfromance

# game = NoThanksBoard(n_players = 3)
# Player_0 = PUCTPlayer(game=game, turn=0, simNum=500)
# Player_1 = PUCTPlayer(game=game, turn=1, simNum=500)
# Player_2 = PUCTPlayer(game, turn=2, simNum=500)

# # model = PolicyValueNet(game.n_players, 32)
# # model.load_state_dict(torch.load('policy_value_net_rd0.pth', weights_only=True))
# # model.eval()

# # model = PolicyOnlyNet(game.n_players, 128)
# # model.load_state_dict(torch.load('policy_only_net.pth'))
# # model.eval()

# # new_prior = smart_prior_fn(game)
# # new_prior = nn_prior_fn(model, game)
# # Player_0.prior = smart_prior_fn(game)
# # Player_1.prior = new_prior
# # Player_2.prior = new_prior

# players = [Player_0, Player_1, Player_2]

# # rl_train(rounds=1, num_games=2, simNum=2000, prior=False, ctd_from=0)

# # play(game, players, display=True)

# winrate = eval_performance(game, Player_0, [Player_1, Player_2], num_games=300, verbose=True)

# print(f"Winrate of the Target Player: {winrate:.2%}")
