This notebook is a shortcut version of the No Thanks! program: the policy-only network is trained once, on self-play data generated by UCT players.

In [1]:
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import random
from dataclasses import dataclass, field

import time
from math import log, sqrt
from collections import deque

import sys
import json
from multiprocess import Pool


In [2]:
## Progress bar

def log_progress(sequence, every=None, size=None, name='Items'):
    """Yield the items of ``sequence`` while updating an ipywidgets progress bar.

    Args:
        sequence: Iterable to walk through. If it has no ``len()`` and ``size``
            is not given, it is treated as an iterator of unknown length.
        every: Refresh the widget every ``every`` items. Defaults to 1 for
            sizes up to 200, otherwise ``int(size / 200)``. Must be supplied
            when the length is unknown.
        size: Total number of items; taken from ``len(sequence)`` when omitted.
        name: Label prefix shown next to the counter.

    Yields:
        Each element of ``sequence``, in order.

    Raises:
        AssertionError: If the length is unknown and ``every`` was not given.
    """
    # Imported lazily so the notebook dependencies are only needed when a
    # progress bar is actually displayed.
    from ipywidgets import IntProgress, HTML, VBox
    from IPython.display import display

    # Work out whether the total length is known.
    is_iterator = False
    if size is None:
        try:
            size = len(sequence)
        except TypeError:
            is_iterator = True
    if size is not None:
        if every is None:
            if size <= 200:
                every = 1
            else:
                every = int(size / 200)     # every 0.5%
    else:
        assert every is not None, 'sequence is iterator, set every'

    # Unknown length: show a full "info" bar; known length: a real 0..size bar.
    if is_iterator:
        progress = IntProgress(min=0, max=1, value=1)
        progress.bar_style = 'info'
    else:
        progress = IntProgress(min=0, max=size, value=0)
    label = HTML()
    box = VBox(children=[label, progress])
    display(box)

    index = 0
    try:
        for index, record in enumerate(sequence, 1):
            # Refresh on the first item and then on every ``every``-th item.
            if index == 1 or index % every == 0:
                if is_iterator:
                    label.value = '{name}: {index} / ?'.format(
                        name=name,
                        index=index
                    )
                else:
                    progress.value = index
                    label.value = u'{name}: {index} / {size}'.format(
                        name=name,
                        index=index,
                        size=size
                    )
            yield record
    except:
        # Any error raised while iterating (including one thrown into the
        # generator by the consumer) turns the bar red and is re-raised.
        progress.bar_style = 'danger'
        raise
    else:
        # Normal exhaustion: mark success and show the final count
        # ('?' when nothing was iterated).
        progress.bar_style = 'success'
        progress.value = index
        label.value = "{name}: {index}".format(
            name=name,
            index=str(index or '?')
        )

In [3]:
## Network

class PolicyOnlyNet(nn.Module):
    """Two-branch network that outputs a single probability per state.

    The matrix input ``M`` (3 channels) goes through a small CNN, the vector
    input ``b`` (3 features) through a fully-connected stack; both embeddings
    are concatenated, mixed by the main branch and mapped to one sigmoid
    output by the policy head.

    Args:
        n_players: Accepted for interface compatibility; not used when
            building the layers.
        hidden_dim: Width of the hidden representations.
    """

    @staticmethod
    def _conv_unit(in_channels, out_channels):
        """Return [Conv2d 3x3 (same padding), BatchNorm2d, ReLU] as a list."""
        return [
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]

    @staticmethod
    def _dense_unit(in_features, out_features):
        """Return [Linear, BatchNorm1d, ReLU] as a list."""
        return [
            nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.ReLU(),
        ]

    def __init__(self, n_players, hidden_dim):
        super().__init__()
        conv, dense = self._conv_unit, self._dense_unit

        # NOTE: layer order inside every Sequential is kept stable so that the
        # state_dict keys (e.g. "cnn_branch.8.weight") match saved checkpoints.

        # CNN branch for the matrix input M.
        self.cnn_branch = nn.Sequential(
            *conv(3, 8),
            *conv(8, 32),
            nn.AdaptiveAvgPool2d((4, 4)),  # fixed 4x4 map regardless of input size
            nn.Flatten(),
            *dense(32 * 4 * 4, hidden_dim),
        )

        # Fully-connected branch for the vector input b.
        self.fc_branch = nn.Sequential(
            *dense(3, hidden_dim),
            *dense(hidden_dim, hidden_dim),
        )

        # Main branch operating on the concatenated embeddings.
        self.main_branch = nn.Sequential(
            *dense(hidden_dim * 2, hidden_dim * 2),
            *dense(hidden_dim * 2, hidden_dim),
            *dense(hidden_dim, hidden_dim),
        )

        # Policy head: three dense units followed by a sigmoid output in [0, 1].
        self.policy_head = nn.Sequential(
            *dense(hidden_dim, hidden_dim),
            *dense(hidden_dim, hidden_dim),
            *dense(hidden_dim, hidden_dim),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )
        # (No value head: this network is policy-only.)

    def forward(self, M, b):
        """Return the policy output of shape (batch, 1) for inputs ``M`` and ``b``."""
        matrix_features = self.cnn_branch(M)
        vector_features = self.fc_branch(b)
        joint = torch.cat((matrix_features, vector_features), dim=-1)
        return self.policy_head(self.main_branch(joint))

def train_nn(model, batch_size, n_players, hidden_dim, input_state, target_policy,
             n_steps=10, lr=0.0001, l2_lambda=1e-4, pass_weight=2.0, verbose=True):
    """Run a few Adam steps on one batch with a weighted BCE policy loss.

    A fresh Adam optimizer is created on every call, so optimizer state is not
    carried over between batches.

    Args:
        model: Module called as ``model(M, b)`` and returning probabilities of
            shape (batch, 1).
        batch_size: Unused; kept for interface compatibility.
        n_players: Unused; kept for interface compatibility.
        hidden_dim: Unused; kept for interface compatibility.
        input_state: Pair ``(M, b)`` of array-likes convertible to float32
            tensors.
        target_policy: Array-like of 0/1 labels; reshaped to (batch, 1).
        n_steps: Number of gradient steps taken on this batch (default 10).
        lr: Adam learning rate (default 1e-4).
        l2_lambda: Factor of the explicit L2 penalty over all model parameters.
        pass_weight: Loss weight applied to samples whose label equals 1;
            all other samples get weight 1.
        verbose: If true, print the losses of the final step.

    Returns:
        None. ``model`` is updated in place.
    """
    M, b = input_state
    M = torch.tensor(M, dtype=torch.float32)
    b = torch.tensor(b, dtype=torch.float32)
    target_policy = torch.tensor(target_policy, dtype=torch.float32).view(-1, 1)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Per-sample weights: label 1 is weighted ``pass_weight`` times label 0.
    weights = torch.where(target_policy == 1., torch.tensor(pass_weight), torch.tensor(1.0))
    policy_loss_fn = nn.BCELoss(weight=weights)

    for step in range(n_steps):
        optimizer.zero_grad()

        pred_policy = model(M, b)
        policy_loss = policy_loss_fn(pred_policy, target_policy)

        # Explicit L2 regularisation: sum of squares of every parameter.
        l2_penalty = sum(w.pow(2.0).sum() for w in model.parameters())
        loss = policy_loss + l2_lambda * l2_penalty

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Clip gradients
        optimizer.step()

        # Report only once, after the last step (was tied to the literal 9).
        if verbose and step == n_steps - 1:
            print(f"Loss = {loss.item():.4f}, "
                  f"Policy Loss = {policy_loss.item():.4f}"
                  )

In [4]:
## Game

# Integer encoding of the two moves available on a turn.
ACTION_TAKE = 0
ACTION_PASS = 1

# Seed Python's global `random` module so runs of this cell are repeatable.
random.seed(999)

def diff(first, second):
    """Return the items of ``first`` that do not occur in ``second``.

    Order and duplicates of ``first`` are preserved.
    """
    excluded = set(second)
    remaining = []
    for element in first:
        if element in excluded:
            continue
        remaining.append(element)
    return remaining

@dataclass
class NoThanksConfig:
    """Game parameters; ``start_coins`` is derived from ``n_players``."""
    min_card: int = 3
    max_card: int = 35
    n_omit_cards: int = 9
    n_players: int = 3
    start_coins: int = field(init=False)  # filled in by __post_init__

    def __post_init__(self):
        self.start_coins = self.calculate_start_coins()

    def calculate_start_coins(self):
        """Return the starting coins per player for the configured player count.

        Raises:
            ValueError: If ``n_players`` is not between 3 and 7.
        """
        count = self.n_players
        if 3 <= count <= 5:
            return 11
        if count == 6:
            return 9
        if count == 7:
            return 7
        raise ValueError("Number of players must be between 3 and 7")

class NoThanksBoard():
    """Rules engine for the card game No Thanks!.

    A state is the triple ``(coins, cards, details)`` where ``coins`` holds one
    coin count per player, ``cards`` holds one collection of card values per
    player and ``details`` is
    ``(card_in_play, coins_in_play, n_cards_in_deck, current_player)``.
    "Packed" states use nested tuples (hashable); "unpacked" states use lists.
    """

    def __init__(self, n_players = 3, config = None):
        """Create a board.

        Args:
            n_players: Number of players at the table.
            config: Optional ``NoThanksConfig``. When omitted, a config for
                ``n_players`` is built, so the starting coins always match the
                player count (a default object created at definition time
                would always describe a 3-player game).
        """
        if config is None:
            config = NoThanksConfig(n_players=n_players)
        self.n_players = n_players
        self.min_card = config.min_card
        self.max_card = config.max_card
        self.full_deck = list(range(self.min_card, self.max_card+1))
        self.n_omit_cards = config.n_omit_cards
        self.n_cards = self.max_card - self.min_card + 1
        self.start_coins = config.start_coins
        # NOTE: constructing a board re-seeds Python's global RNG.
        random.seed(999)

    def reward_dict(self):
        """Return ``{rank: reward}`` for the player count (``None`` outside 3..7)."""
        rewards_by_player_count = {
            3: {1: 1, 2: 0, 3: -1},
            4: {1: 1, 2: 0.5, 3: -0.5, 4: -1},
            5: {1: 1, 2: 0.5, 3: 0, 4: -0.5, 5: -1},
            6: {1: 1, 2: 0.75, 3: 0.5, 4: -0.5, 5: -0.75, 6: -1},
            7: {1: 1, 2: 0.75, 3: 0.5, 4: 0, 5: -0.5, 6: -0.75, 7: -1},
        }
        return rewards_by_player_count.get(self.n_players)

    # state: ((player coins),(player cards),(card in play, coins in play, n_cards_remaining, current player))
    def starting_state(self, current_player = 0):
        """Return an unpacked initial state with a randomly drawn first card."""
        coins = [self.start_coins for i in range(self.n_players)]
        cards = [[] for i in range(self.n_players)]

        card_in_play = random.choice(self.full_deck)

        coins_in_play = 0
        # One card is already face up and ``n_omit_cards`` are never dealt.
        n_cards_in_deck = self.n_cards - 1 - self.n_omit_cards

        return coins, cards, (card_in_play, coins_in_play, n_cards_in_deck, current_player)

    def next_state(self, state, action):
        """Apply ``action`` to ``state`` and return the packed successor.

        TAKE gives the card and the coins on it to the current player, who
        stays on turn, and draws a random new card (``None`` when the deck
        counter is exhausted). Any other action is treated as PASS: the
        current player pays one coin and the turn moves to the next player.
        The input state is not modified.
        """
        state = self.unpack_state(state)
        coins, cards, (card_in_play, coins_in_play, n_cards_in_deck, current_player) = state

        if action == ACTION_TAKE:
            cards[current_player].append(card_in_play)
            coins[current_player] += coins_in_play

            all_player_cards = [card for player_cards in cards for card in player_cards]
            cards_in_deck = diff(self.full_deck, all_player_cards)

            if cards_in_deck and n_cards_in_deck > 0:
                card_in_play = random.choice(cards_in_deck)
                n_cards_in_deck -= 1
            else:
                card_in_play = None
            coins_in_play = 0

        else:
            coins[current_player] -= 1
            coins_in_play += 1
            current_player += 1

        if current_player == self.n_players:
            current_player = 0

        next_state = coins, cards, (card_in_play, coins_in_play, n_cards_in_deck, current_player)
        return self.pack_state(next_state)

    def all_possible_next(self, state, action):
        """Enumerate the successors of ``state`` under ``action``.

        Returns:
            For PASS, the single packed successor. For TAKE, a list with one
            packed successor per card that could be drawn next; when no card
            can be drawn, the single packed successor from ``next_state``.
            ``None`` for any other action.
        """
        if action == ACTION_PASS:
            return self.next_state(state, action)
        elif action == ACTION_TAKE:
            coins, cards, (card_in_play, coins_in_play, n_cards_in_deck, current_player) = \
                self.unpack_state(state)
            cards[current_player].append(card_in_play)
            coins[current_player] += coins_in_play

            all_player_cards = [card for player_cards in cards for card in player_cards]
            cards_in_deck = diff(self.full_deck, all_player_cards)

            if not cards_in_deck or n_cards_in_deck <= 0:
                # Nothing left to draw. Delegate using the ORIGINAL state:
                # the local copies above already contain the taken card, and
                # re-applying TAKE to them would count the card and its coins
                # twice.
                return self.next_state(state, action)

            # The taker stays on turn; only the revealed card differs.
            return [
                self.pack_state((coins, cards, (card, 0, n_cards_in_deck - 1, current_player)))
                for card in cards_in_deck
            ]

    def is_legal(self, state, action):
        """Return whether ``action`` may be played in ``state``.

        Nothing is legal without a card in play; PASS needs at least one coin.
        """
        coins, cards, (card_in_play, coins_in_play, n_cards_in_deck, current_player) = state

        if card_in_play is None:
            return False
        if coins[current_player] <= 0 and action == ACTION_PASS:
            return False
        return True

    def legal_actions(self, state):
        """Return the legal actions in ``state`` (TAKE listed before PASS)."""
        return [action for action in (ACTION_TAKE, ACTION_PASS) if self.is_legal(state, action)]

    def pack_state(self, state):
        """Convert a state to nested tuples so it can be used as a dict key."""
        coins, cards, details = state
        packed_state = tuple(coins), tuple(map(tuple, cards)), details
        return packed_state

    def unpack_state(self, packed_state):
        """Return a copy of the state whose coins/cards are fresh lists."""
        coins, cards, details = packed_state
        coins = list(coins)
        cards = list(map(list, cards))
        return coins, cards, details

    def standard_state(self, state):
        """Encode a state as network inputs ``(M, b)``.

        ``M`` has shape ``(3, n_players, n_cards)`` with rows rotated so that
        row 0 is the current player:
          * ``M[0]``: 1 where the player holds the card,
          * ``M[1]``: the player's coins divided by the total number of coins
            in the game, repeated along the card axis,
          * ``M[2]``: one-hot of the card in play, repeated for every player.
        ``b`` is ``[card_in_play, coins_in_play, n_cards_in_deck]``.

        A state without a card in play is encoded with an all-zero ``M[2]``
        and ``0`` in ``b[0]`` instead of raising.
        """
        coins, cards, (card_in_play, coins_in_play, n_cards_in_deck, current_player) = state
        n_cards = self.n_cards

        # One row per player: card indicators followed by the coin count.
        rows = []
        for k in range(self.n_players):
            card_rep = [0] * n_cards
            for card in cards[k]:
                card_rep[card - self.min_card] = 1
            rows.append(card_rep + [coins[k]])

        card_value = 0 if card_in_play is None else card_in_play
        b = np.array([card_value, coins_in_play, n_cards_in_deck])

        # Rotate so the current player comes first.
        rotated = np.array(rows[current_player:] + rows[:current_player])

        M_transformed = np.zeros((3, self.n_players, n_cards))
        M_transformed[0] = rotated[:, :-1]  # Card matrix
        M_transformed[1] = np.repeat(rotated[:, -1][:, np.newaxis], n_cards, axis=1)  # Coin matrix
        M_transformed[1] = M_transformed[1] / (self.start_coins * self.n_players)  # Normalize to [0, 1]

        card_in_play_onehot = np.zeros(n_cards)
        if card_in_play is not None and self.min_card <= card_in_play <= self.max_card:
            card_in_play_onehot[card_in_play - self.min_card] = 1
        M_transformed[2] = np.tile(card_in_play_onehot, (self.n_players, 1))

        return M_transformed, b

    def is_ended(self, state):
        """Return True when the deck counter is zero and no card is in play."""
        coins, cards, (card_in_play, coins_in_play, n_cards_in_deck, current_player) = state
        return n_cards_in_deck == 0 and card_in_play is None

    def compute_scores(self, state):
        """Return one score per player (lower is better).

        Only the lowest card of each run of consecutive cards counts; every
        coin held subtracts one point. The input state is not modified.
        """
        coins, cards, _details = self.unpack_state(state)

        scores = []
        for p_idx in range(self.n_players):
            score = 0
            previous = None
            for card in sorted(cards[p_idx]):
                if previous is None or card != previous + 1:
                    score += card
                previous = card
            scores.append(score - coins[p_idx])

        return scores

    def winner(self, state):
        """Return the index of the winning player, or ``None`` if not ended.

        The winner has the lowest score. Ties are broken by the fewest cards
        held and then by the lowest player index, so the result is
        deterministic for a given state.
        """
        state = self.unpack_state(state)
        coins, cards, details = state

        if not self.is_ended(state):
            return None

        scores = self.compute_scores(state)
        best_score = min(scores)
        lowest_scorers = [i for i, score in enumerate(scores) if score == best_score]
        if len(lowest_scorers) == 1:
            return lowest_scorers[0]

        fewest_cards = min(len(cards[i]) for i in lowest_scorers)
        for i in lowest_scorers:
            if len(cards[i]) == fewest_cards:
                return i

    def reward_rank(self, state):
        """Return per-player rewards by score rank (ties go to the lower index)."""
        scores = self.compute_scores(state)
        rank = sorted(range(len(scores)), key=lambda k: scores[k])
        rewards = self.reward_dict()
        value = [rewards[rank.index(player) + 1] for player in range(self.n_players)]
        return np.array(value)

    def reward_winloss(self, state):
        """Return +1 for the winner and -1 for everyone else.

        If the game has not ended there is no winner and every entry is -1.
        """
        winning_player = self.winner(state)  # evaluated once for all players
        value = [1 if winning_player == player else -1 for player in range(self.n_players)]
        return np.array(value)

    def reward_score(self, state):
        """Return the negated scores, so that higher is better."""
        return np.array([-score for score in self.compute_scores(state)])

    def basic_display_state(self, state):
        """Print the raw fields of ``state``."""
        coins, cards, (card_in_play, coins_in_play, n_cards_in_deck, current_player) = state

        print("Coins:           {0}".format(coins))
        print("Cards:           {0}".format(cards))
        print("Card in play:    {0}".format(card_in_play))
        print("Coins:           {0}".format(coins_in_play))
        print("Player:          {0}".format(current_player))

    def display_scores(self, state):
        """Print the current score of every player."""
        scores = self.compute_scores(state)
        print("")
        print("--- Scores ---")
        for i in range(self.n_players):
            print("Player {0}: {1}".format(i, scores[i]))
        print("")

    def display_state(self, state, human_player=None):
        """Print a table of cards, coins and scores plus the card in play.

        Args:
            state: Packed or unpacked state.
            human_player: Index of the player to tag with "(You)", if any.
        """
        state = self.unpack_state(state)
        coins, cards, (card_in_play, coins_in_play, n_cards_in_deck, current_player) = state

        scores = self.compute_scores(state)

        def format_cards(card_list):
            return ", ".join(map(str, sorted(card_list)))

        player_labels = [f"Player {i}" + (" (You)" if i == human_player else "") for i in range(self.n_players)]
        card_strings = [format_cards(cards[i]) for i in range(self.n_players)]
        coin_strings = [str(coins[i]) for i in range(self.n_players)]
        score_strings = [str(scores[i]) for i in range(self.n_players)]

        # The cards column is at least 20 wide and grows with the longest hand.
        max_card_len = max(20, max(len(card_str) for card_str in card_strings))
        rule = "-" * (15 + max_card_len + 10 + 10 + 10)

        print("")
        print(rule)
        print("")
        print("{:<15} {:<{}} {:<10} {:<10}".format("Player", "Cards", max_card_len, "Coins", "Score"))
        print(rule)

        for i in range(self.n_players):
            print("{:<15} {:<{}} {:<10} {:<10}".format(
                player_labels[i],
                card_strings[i],
                max_card_len,
                coin_strings[i],
                score_strings[i]
            ))

        print(rule)
        print("\t\t In play: [{0}]".format(card_in_play))
        print("\t\t Cards remaining: {0}".format(n_cards_in_deck))
        print("\t\t   Coins: {0}".format(coins_in_play))
        print("")

    def pack_action(self, notation):
        """Map "y"/"Y" to TAKE; any other notation means PASS."""
        if notation in ("y", "Y"):
            return ACTION_TAKE
        return ACTION_PASS

    def current_player(self, state):
        """Return the index of the player to move in ``state``."""
        return state[2][3]
    

In [5]:
## Players

class Player:
    """Base class for anything that can choose a move: a bot or a human.

    Subclasses must override ``get_action``.
    """

    def __init__(self, name, game, turn):
        """Store the identity, the board and the seat index (0-based)."""
        self.name, self.game, self.turn = name, game, turn
        # The seat must exist at this table.
        assert self.turn < self.game.n_players, "Player turn out of range."

    def get_action(self, state):
        """Return the move for ``state``; not implemented in the base class."""
        raise NotImplementedError
    
class UCTPlayer(Player):
    """Monte Carlo Tree Search Player (UCT, no prior)"""
    def __init__(self, game, thinking_time=1, turn=0):
        # NOTE: Player.__init__ is not called here, so no ``name`` attribute
        # is set and the seat-range assertion of the base class is skipped.
        assert thinking_time > 0
        self.turn = turn
        self.game = game
        self.thinking_time = thinking_time  # search budget per move, compared against time.perf_counter() deltas
        self.max_moves = 200                # cap on the length of one simulated playout
        self.C = 1.4  # Exploration parameter for UCB1
        self.max_depth = 0                  # deepest step at which a new statistics entry was created

    def get_action(self, state):
        """Search from ``state`` and return ``(action, win_rate)``.

        Returns ``(None, None)`` when there is no legal action and
        ``(action, 0)`` without searching when exactly one action is legal.
        ``win_rate`` is wins / plays of the chosen action at the root.
        """
        board = self.game
        player = board.current_player(state)
        legal_actions = board.legal_actions(state)
        
        if not legal_actions:
            return None, None
        if len(legal_actions) == 1:
            return legal_actions[0], 0
        
        # Initialize visit and win counts, keyed by (player, state, action).
        # They are rebuilt from scratch for every move.
        plays, wins = {}, {}
        games = 0
        start_time = time.perf_counter()

        # Run MCTS for the specified thinking time
        while time.perf_counter() - start_time < self.thinking_time:
            self.run_simulation(state, board, plays, wins)
            games += 1

        # Choose the most visited action; shuffling first makes ties random.
        random.shuffle(legal_actions)
        action = max(
            legal_actions,
            key=lambda a: plays.get((player, state, a), 1)
        )
        # print("UCT:", 0, wins.get((player, state, 0), 0) / plays.get((player, state, 0), 1))
        # print("UCT:", 1, wins.get((player, state, 1), 0) / plays.get((player, state, 1), 1))
        # print("Player:", player, "Max depth searched:", self.max_depth, "Games played:", games)
        return action, wins.get((player, state, action), 0) / plays.get((player, state, action), 1)

    def run_simulation(self, state, board, plays, wins):
        """Run a single MCTS simulation.

        Plays one game forward from ``state`` (at most ``max_moves`` steps),
        then adds one play to every visited (player, state, action) entry and
        one win to the entries whose player won. ``plays`` and ``wins`` are
        updated in place.
        """
        tree = set()
        player = board.current_player(state)

        # === Selection & Expansion ===
        for t in range(1, self.max_moves + 1):
            legal_actions = board.legal_actions(state)

            # Selection using UCB1 if data exists for all actions
            # (an entry with 0 plays counts as "no data").
            if all(plays.get((player, state, a)) for a in legal_actions):
                log_total = log(sum(plays[(player, state, a)] for a in legal_actions))
                action = max(
                    legal_actions,
                    key=lambda a: (
                        wins[(player, state, a)] / plays[(player, state, a)] +
                        self.C * sqrt(log_total / plays[(player, state, a)])
                    )
                )
                
            else:
                # Expansion – If any action is unexplored, take a random one
                # and create its statistics entry. This happens at every step
                # of the playout, not only at the first unexplored node.
                action = random.choice(legal_actions)
                if (player, state, action) not in plays:
                    plays[(player, state, action)] = 0
                    wins[(player, state, action)] = 0
                    if t > self.max_depth:
                        self.max_depth = t

            tree.add((player, state, action))
            state = board.next_state(state, action)
            player = board.current_player(state)

            # Check for game-ending state
            winner = board.winner(state)
            if winner is not None:
                break

        # === Backpropagation ===
        # ``winner`` stays None if the move cap was hit, so no win is credited.
        for player, state, action in tree:
            plays[(player, state, action)] += 1
            if player == winner:
                wins[(player, state, action)] += 1
    
class PUCTPlayer(Player):
    """Monte Carlo Tree Search Player (Prior given by NN)

    Selection maximises ``win_rate + C * prior(state, a) * sqrt(N / n_a)``
    where ``N`` is the total number of plays of the state and ``n_a`` the
    plays of action ``a``. ``prior`` defaults to a uniform distribution over
    the legal actions and can be replaced by any ``f(state, action)``.
    """
    def __init__(self, game, thinking_time=1, turn=0, n_simulations=1000):
        """
        Args:
            game: Board object providing the game rules.
            thinking_time: Validated and stored, but the search budget is
                given by ``n_simulations``, not by wall-clock time.
            turn: Seat index of this player.
            n_simulations: Number of simulations run per move (default 1000,
                the value that used to be hard-coded in ``get_action``).
        """
        # NOTE: Player.__init__ is not called, so no ``name`` attribute is set.
        assert thinking_time > 0
        assert n_simulations > 0
        self.turn = turn
        self.game = game
        self.thinking_time = thinking_time
        self.n_simulations = n_simulations
        self.max_moves = 200  # cap on the length of one simulated playout
        self.C = 4  # Exploration parameter for PUCT
        self.max_depth = 0
        self.prior = lambda state, action: 1 / len(self.game.legal_actions(state))  # Default prior
        self.value = None

    def get_action(self, state):
        """Search from ``state`` and return ``(action, win_rate)``.

        Returns ``(None, None)`` when there is no legal action and
        ``(action, 0)`` without searching when exactly one action is legal.
        """
        board = self.game
        player = board.current_player(state)
        legal_actions = board.legal_actions(state)

        if not legal_actions:
            return None, None
        if len(legal_actions) == 1:
            return legal_actions[0], 0

        # Visit and win counts keyed by (player, state, action); rebuilt for
        # every move.
        plays, wins = {}, {}

        for _ in range(self.n_simulations):
            self.run_simulation(state, board, plays, wins)

        # Choose the most visited action; shuffling first makes ties random.
        random.shuffle(legal_actions)
        action = max(
            legal_actions,
            key=lambda a: plays.get((player, state, a), 1)
        )

        return action, wins.get((player, state, action), 0) / plays.get((player, state, action), 1)

    def run_simulation(self, state, board, plays, wins):
        """Run a single MCTS simulation.

        Plays one game forward from ``state`` (at most ``max_moves`` steps),
        then adds one play to every visited (player, state, action) entry and
        one win to the entries whose player won. ``plays`` and ``wins`` are
        updated in place.
        """
        tree = set()
        player = board.current_player(state)

        # === Selection & Expansion ===
        for t in range(1, self.max_moves + 1):
            legal_actions = board.legal_actions(state)

            # Selection using the PUCT score if data exists for all actions
            # (an entry with 0 plays counts as "no data").
            if all(plays.get((player, state, a)) for a in legal_actions):
                total = sum(plays[(player, state, a)] for a in legal_actions)
                action = max(
                    legal_actions,
                    key=lambda a: (
                        wins[(player, state, a)] / plays[(player, state, a)] +
                        self.C * self.prior(state, a) * sqrt(total / plays[(player, state, a)])
                    )
                )
            else:
                # Expansion – If any action is unexplored, take a random one
                action = random.choice(legal_actions)
                if (player, state, action) not in plays:
                    plays[(player, state, action)] = 0
                    wins[(player, state, action)] = 0
                    if t > self.max_depth:
                        self.max_depth = t

            tree.add((player, state, action)) # trajectory
            state = board.next_state(state, action)
            player = board.current_player(state)

            # Check for game-ending state
            winner = board.winner(state)
            if winner is not None:
                break

        # === Backpropagation ===
        # ``winner`` stays None if the move cap was hit, so no win is credited.
        for player, state, action in tree:
            plays[(player, state, action)] += 1
            if player == winner:
                wins[(player, state, action)] += 1

In [6]:
def self_play(game, players, times=1, to_file=None):
    """Play ``times`` full games and collect (state, action) training pairs.

    Each sample pairs the encoded state in which a decision was made with the
    action chosen in that state. (Recording after the move had been applied
    would pair every action with the state it led to instead.)

    Args:
        game: Board object providing the game rules.
        players: Sequence of players indexed by seat; each must offer
            ``get_action(state) -> (action, info)``.
        times: Number of games to play.
        to_file: Optional path; when given, the data is also written as JSON
            with the arrays converted to nested lists.

    Returns:
        dict with keys ``"state"`` (list of ``(M, b)`` pairs as returned by
        ``game.standard_state``) and ``"policy"`` (list of actions).
    """
    data = {"state": [], "policy": []}
    for _ in log_progress(range(times)):
        state = game.pack_state(game.starting_state(current_player=0))

        while not game.is_ended(state):
            mover = players[game.current_player(state)]
            action, _ = mover.get_action(state)
            if action is None:
                # No legal move although the game is not flagged as ended:
                # stop this game rather than applying an undefined action.
                break

            # Record the decision BEFORE applying it.
            data["state"].append(game.standard_state(state))  # (M, b) as NumPy arrays
            data["policy"].append(action)

            state = game.next_state(state, action)

    # Convert NumPy arrays in "state" to lists for JSON serialization
    if to_file is not None:
        serializable_data = {
            "state": [(M.tolist(), b.tolist()) for M, b in data["state"]],
            "policy": data["policy"]
        }
        with open(to_file, "w") as f:
            json.dump(serializable_data, f)

    return data

def parallel_self_play(args):
    """Pool worker: run ``self_play`` and save the result to a per-process file.

    Args:
        args: Tuple ``(game, players, times, file_prefix, process_id)``.

    Returns:
        The data dict produced by ``self_play``; it is also written to
        ``"{file_prefix}_process{process_id}.json"``.
    """
    game, players, times, file_prefix, process_id = args
    # Workers created by fork start with a copy of the parent's global RNG
    # state; give every task its own reproducible stream so the processes do
    # not replay the same random draws.
    random.seed(999 + process_id)
    return self_play(game, players, times, to_file=f"{file_prefix}_process{process_id}.json")

def parallel_self_play_nosave(args):
    """Pool worker: run ``self_play`` without writing anything to disk.

    Args:
        args: Tuple ``(game, players, times, process_id)``.

    Returns:
        The data dict produced by ``self_play``.
    """
    game, players, times, process_id = args
    # Workers created by fork start with a copy of the parent's global RNG
    # state; give every task its own reproducible stream so the processes do
    # not replay the same random draws.
    random.seed(999 + process_id)
    time.sleep(0.25)
    print(f"Process {process_id} started.")
    return self_play(game, players, times)


def rl_train_nosave(from_file=None, num_processes=4, num_games=4, prior=False):
    """Generate UCT self-play data in parallel and train the policy network.

    Args:
        from_file: Currently unused; kept for interface compatibility.
        num_processes: Number of worker processes used for self-play.
        num_games: Total number of games to play. Games are spread as evenly
            as possible over the workers, so no game is dropped when the
            total is not a multiple of ``num_processes``.
        prior: If truthy, start from the weights in ``policy_only_net.pth``
            instead of a freshly initialised network.

    Returns:
        The trained ``PolicyOnlyNet``. Its weights are also saved to
        ``policy_only_net.pth`` after every training batch.
    """
    n_players = 3
    hidden_dim = 128
    batch_size = 32
    weights_path = 'policy_only_net.pth'

    game = NoThanksBoard(n_players=n_players)
    players = [UCTPlayer(game=game, turn=turn) for turn in range(n_players)]

    model = PolicyOnlyNet(n_players, hidden_dim=hidden_dim)
    if prior:
        model.load_state_dict(torch.load(weights_path, weights_only=True))

    # Spread the games over the workers; the first ``extra`` workers play one
    # more game each.
    base, extra = divmod(num_games, num_processes)
    args = [
        (game, players, base + (1 if process_id < extra else 0), process_id)
        for process_id in range(num_processes)
    ]

    # Use Pool to collect data from all processes
    with Pool(num_processes) as pool:
        results = pool.map(parallel_self_play_nosave, args)

    # Combine results from all processes
    data = {"state": [], "policy": []}
    for result in results:
        data["state"].extend(result["state"])
        data["policy"].extend(result["policy"])

    # Shuffle states and labels together.
    combined_data = list(zip(data["state"], data["policy"]))
    if not combined_data:
        # zip(*[]) cannot be unpacked into two sequences; nothing to train on.
        print("No training samples were generated; skipping training.")
        return model
    random.shuffle(combined_data)
    states, policies = zip(*combined_data)
    target_policy = np.array(policies)

    print("Training Data prepared.")
    num_samples = len(states)
    # Only full batches are used; a trailing partial batch is dropped.
    num_batches = num_samples // batch_size
    print(f"Total samples: {num_samples}, Batches: {num_batches}")

    model.train()
    for batch_idx in range(num_batches):
        batch_start = batch_idx * batch_size
        batch_end = batch_start + batch_size

        batch_states = states[batch_start:batch_end]
        batch_policies = target_policy[batch_start:batch_end]

        # Stack the per-sample (M, b) pairs into batch arrays.
        M = np.array([s[0] for s in batch_states])
        b = np.array([np.array(s[1], dtype=np.float32) for s in batch_states], dtype=np.float32)

        print(f"Training batch {batch_idx + 1}/{num_batches}")
        train_nn(model, batch_size, n_players, hidden_dim, (M, b), batch_policies)

        # Checkpoint after every batch.
        torch.save(model.state_dict(), weights_path)
    return model

In [None]:
# Generate 800 self-play games across 4 worker processes and train the policy network on them.
rl_train_nosave(from_file=None, num_processes=4, num_games=800)

Process 0 started.
Process 1 started.
Process 2 started.
Process 3 started.


VBox(children=(HTML(value=''), IntProgress(value=0, max=200)))

VBox(children=(HTML(value=''), IntProgress(value=0, max=200)))

VBox(children=(HTML(value=''), IntProgress(value=0, max=200)))

VBox(children=(HTML(value=''), IntProgress(value=0, max=200)))