In [57]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
from collections import defaultdict
import os
from datetime import datetime
import time
from functools import lru_cache
from multiprocessing import Pool
import concurrent.futures

from utils import State, Action, is_terminal, change_state, get_all_valid_actions, terminal_utility, is_valid_action, invert

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Any number of these blocks can be stacked: shape in == shape out.
class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers wrapped by an identity skip connection.

    Stride 1 and padding 1 keep the spatial size, and the channel count is
    unchanged, so the input can be added directly to the block's output.
    Convs carry no bias because the following BatchNorm supplies the shift.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=False)
        bn_kwargs = dict(eps=1e-5, momentum=0.09)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(channels, **bn_kwargs)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(channels, **bn_kwargs)

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        # Skip connection, then the final non-linearity.
        return F.relu(hidden + x)

# With batchnorm enabled the convs drop their bias (BatchNorm's shift replaces it).
# Further FC/conv layers can be added following the same pattern.
class UTTT(nn.Module):
    """Policy/value network for Ultimate Tic-Tac-Toe.

    Input is a (batch, 1, 9, 9) board tensor (as built by state_to_tensor).
    The trunk is conv1 followed by either:
      * residual=True  -- two ResidualBlocks, keeping `channels` feature maps, or
      * residual=False -- conv2/conv3, widening to 2 * `channels` feature maps.
    Heads:
      * pi -- softmax probabilities over the 81 cells, shape (batch, 81)
      * v  -- tanh value in [-1, 1], shape (batch, 1)

    Args:
        residual: use the residual trunk instead of conv2/conv3.
        batchnorm: put BatchNorm after each plain conv (conv biases are then dropped).
    """

    def __init__(self, residual=True, batchnorm=True):
        super(UTTT, self).__init__()

        self.channels = 8
        self.residual = residual
        self.fc_size1 = 128
        self.batchnorm = batchnorm
        wide = self.channels * 2

        # Convolutional layers. Each BatchNorm must match the channel count of
        # the conv it follows (conv2/conv3 output `wide` channels).
        self.conv1 = nn.Conv2d(1, self.channels, kernel_size=3, stride=1, padding=1, bias=not self.batchnorm)
        self.bn1 = nn.BatchNorm2d(self.channels, eps=1e-5, momentum=0.09) if self.batchnorm else None

        self.conv2 = nn.Conv2d(self.channels, wide, kernel_size=3, stride=1, padding=1, bias=not self.batchnorm)
        self.bn2 = nn.BatchNorm2d(wide, eps=1e-5, momentum=0.09) if self.batchnorm else None

        self.conv3 = nn.Conv2d(wide, wide, kernel_size=3, stride=1, padding=1, bias=not self.batchnorm)
        self.bn3 = nn.BatchNorm2d(wide, eps=1e-5, momentum=0.09) if self.batchnorm else None

        self.res2 = ResidualBlock(self.channels) if self.residual else None
        self.res3 = ResidualBlock(self.channels) if self.residual else None

        # The FC input depends on which trunk is used. The residual trunk keeps
        # `channels` maps, while the plain trunk ends with `wide` maps on the 9x9 board.
        trunk_channels = self.channels if self.residual else wide
        self.fc1 = nn.Linear(trunk_channels * 9 * 9, self.fc_size1)

        # Output layers
        self.pi = nn.Linear(self.fc_size1, 81)  # policy head (softmax applied in forward)
        self.v = nn.Linear(self.fc_size1, 1)    # value head (tanh applied in forward)

    def _conv_bn_relu(self, conv, bn, x):
        """conv -> (BatchNorm when enabled) -> ReLU."""
        x = conv(x)
        if self.batchnorm:
            x = bn(x)
        return F.relu(x)

    def forward(self, x):
        x = self._conv_bn_relu(self.conv1, self.bn1, x)

        if self.residual:
            x = self.res3(self.res2(x))
        else:
            x = self._conv_bn_relu(self.conv2, self.bn2, x)
            x = self._conv_bn_relu(self.conv3, self.bn3, x)

        # Flatten and pass through the FC layer
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))

        # pi is returned as probabilities, not logits: MCTS and pit mask and
        # renormalise it directly.
        pi = F.softmax(self.pi(x), dim=1)
        v = torch.tanh(self.v(x))  # range [-1, 1]

        return pi, v

class _UTTT(nn.Module):
    def __init__(self, residual=True, batchnorm=True):
        super(_UTTT, self).__init__()
        
        self.channels = 8
        self.residual = residual
        self.fc_size1 = 128
        self.batchnorm = batchnorm
        
        # Convolutional layers
        self.conv1 = nn.Conv2d(1, self.channels, kernel_size=3, stride=1, padding=1, bias= not self.batchnorm)
        self.bn1 = nn.BatchNorm2d(self.channels, eps=1e-5, momentum=0.09) if self.batchnorm else None
        
        self.conv2 = nn.Conv2d(self.channels, self.channels*2, kernel_size=3, stride=1, padding=1, bias= not self.batchnorm)
        self.bn2 = nn.BatchNorm2d(self.channels, eps=1e-5, momentum=0.09) if self.batchnorm else None
        
        self.conv3 = nn.Conv2d(self.channels*2, self.channels*2, kernel_size=3, stride=1, padding=1, bias= not self.batchnorm)
        self.bn3 = nn.BatchNorm2d(self.channels, eps=1e-5, momentum=0.09) if self.batchnorm else None
        
        self.res2 = ResidualBlock(self.channels) if self.residual else None
        self.res3 = ResidualBlock(self.channels) if self.residual else None

        # Fully connected layers
        self.fc1 = nn.Linear(self.channels*2 * 9 * 9, self.fc_size1)  # Input size: 128 filters × 9 × 9

        # Output layers
        self.v = nn.Linear(self.fc_size1, 1)    # Value output (Tanh for game state evaluation)
    
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x))) if self.batchnorm else F.relu(self.conv1(x))

        # Residual connections
        if self.residual:
            x = self.res2(x)
            x = self.res3(x)
        else:
            x = F.relu(self.bn2(self.conv2(x))) if self.batchnorm else F.relu(self.conv2(x))
            x = F.relu(self.bn3(self.conv3(x))) if self.batchnorm else F.relu(self.conv3(x))

        # Flatten and pass through FC layers
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))

        # Outputs
        v = F.tanh(self.v(x))  # Value head (range [-1, 1])

        return v
    
# Global network used by train() and the self-play helpers: a plain
# convolutional trunk with no residual blocks and no BatchNorm.
# nn.Module.to returns the module itself, so the move to `device` can be chained.
model = UTTT(residual=False, batchnorm=False).to(device)

# Optimizer factory
def get_optimizer(model, lr=0.0001):
    """Build a fresh Adam optimizer over every parameter of `model`."""
    trainable = model.parameters()
    return optim.Adam(trainable, lr=lr)

# Define loss functions
def compute_policy_loss(predicted_policy, true_policy):
    """Cross-entropy between a target distribution and predicted *probabilities*.

    UTTT.forward already applies softmax to its policy head, so
    `predicted_policy` holds probabilities, not logits. nn.CrossEntropyLoss
    would apply log_softmax on top of that (a double softmax), which flattens
    the predictions and shrinks the gradients. Instead, take the log of the
    probabilities directly, clamped away from 0 so log() stays finite.

    Args:
        predicted_policy: (batch, num_actions) predicted probabilities.
        true_policy: (batch, num_actions) target distribution, e.g. MCTS visit frequencies.

    Returns:
        Scalar tensor: the per-sample cross-entropy averaged over the batch.
    """
    log_probs = torch.log(predicted_policy.clamp_min(1e-8))
    return -(true_policy * log_probs).sum(dim=1).mean()

def compute_value_loss(predicted_value, true_value):
    """Mean squared error between predicted and target values (regression loss)."""
    return F.mse_loss(predicted_value, true_value)

# Show the layer-by-layer architecture of the global network.
print(f"{model}")


UTTT(
  (conv1): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=1296, out_features=128, bias=True)
  (pi): Linear(in_features=128, out_features=81, bias=True)
  (v): Linear(in_features=128, out_features=1, bias=True)
)


In [None]:
def board_to_array(state):
    """
    Encode a State's board as a (9, 9) float32 array seen from the mover's side.

    The board is read as ``state.board[i][j][k][l]`` (a 3x3 grid of 3x3 boards)
    and laid out so that ``out[3*i + k, 3*j + l] = board[i][j][k][l]``.
    Cells marked 1 become +1 and cells marked 2 become -1 when
    ``state.fill_num == 1``. Otherwise the signs are swapped, so the side given
    by ``fill_num`` always appears as +1.

    The result is a fresh ndarray and is not hashable. Callers that need a dict
    key convert it, e.g. ``tuple(map(tuple, arr))``.

    Args:
        state: The current game state (State object).

    Returns:
        A new (9, 9) float32 NumPy array.
    """
    sign = 1.0 if state.fill_num == 1 else -1.0

    # (i, j, k, l) -> (i, k, j, l) -> (9, 9): row = 3*i + k, col = 3*j + l.
    # astype copies, so state.board itself is never modified.
    grid = np.asarray(state.board).transpose(0, 2, 1, 3).reshape(9, 9).astype(np.float32)

    mine = grid == 1
    theirs = grid == 2
    grid[mine] = sign
    grid[theirs] = -sign
    grid[grid == 0] = 0.0

    return grid

def action_to_index(action):
    """
    Flatten an action (meta_row, meta_col, local_row, local_col) into 0..80.

    The index is ``27*meta_row + 9*meta_col + 3*local_row + local_col``: the
    sub-board number times 9, plus the cell number inside that sub-board.
    index_to_action is the exact inverse. Note that this is NOT the row-major
    position of the cell in board_to_array's 9x9 layout. It is only a fixed
    labelling of the 81 policy outputs.

    Args:
        action: Tuple (meta_row, meta_col, local_row, local_col).

    Returns:
        Integer policy index.
    """
    meta_row, meta_col, local_row, local_col = action
    return 27 * meta_row + 9 * meta_col + 3 * local_row + local_col

def index_to_action(index: int) -> Action:
    """
    Inverse of action_to_index: map a policy index back to an action tuple.

    Args:
        index (int): The action index (0..80 for legal policy outputs).

    Returns:
        Action: (meta_row, meta_col, local_row, local_col).
    """
    # Peel the mixed-radix digits off from the most significant one:
    # 27 per meta-row, 9 per meta-column, 3 per local row.
    meta_row, remainder = divmod(index, 27)
    meta_col, remainder = divmod(remainder, 9)
    local_row, local_col = divmod(remainder, 3)
    return (meta_row, meta_col, local_row, local_col)

def normalize_policy(policy):
    """
    Rescale a non-negative score vector into a probability distribution.

    If the total is not positive (this includes a NaN total), or the rescaled
    vector contains NaN/inf, the result is a uniform distribution over ALL
    entries. It is not restricted to valid actions, since no mask is taken.

    Args:
        policy: A NumPy array of per-action scores.

    Returns:
        A NumPy array of the same length that sums to 1.
    """
    size = len(policy)
    total = np.sum(policy)

    if total > 0:
        normalized = policy / total
    else:
        normalized = np.ones_like(policy) / size

    if not np.all(np.isfinite(normalized)):
        normalized = np.ones_like(normalized) / size

    return normalized

def state_to_tensor(state):
    """Encode `state` as a (1, 1, 9, 9) float32 tensor on `device` (batch and channel axes added)."""
    grid = torch.tensor(board_to_array(state), dtype=torch.float32)
    return grid[None, None].to(device)


In [None]:
class Node:
    """One MCTS tree vertex: a game state plus its search statistics.

    Statistics, as maintained by MCTS.search:
      N -- visit count
      W -- sum of backed-up values
      Q -- mean value W / N
      P -- prior probability from the parent's policy (None for a root)
    """
    __slots__ = ('state', 'children', 'N', 'W', 'Q', 'P')

    def __init__(self, state):
        self.state = state
        # Filled by MCTS.expand_node in get_all_valid_actions order; best_action
        # relies on that ordering to map a child index back to its move.
        self.children = []
        self.N = 0
        self.W = 0
        self.Q = 0
        self.P = None

    def is_leaf(self):
        return not self.children

    def is_terminal(self):
        # The bare name resolves to the module-level is_terminal, not this method.
        return is_terminal(self.state)

    def no_valid_actions(self):
        valid = get_all_valid_actions(self.state)
        return len(valid) == 0

    def best_child(self, cpuct):
        """PUCT selection among the children; cpuct is the exploration weight.

        Each child's Q is from its own mover's perspective, so the parent
        maximises -Q plus the exploration bonus
        cpuct * P * sqrt(N_parent) / (1 + N_child).

        Returns:
            (child, index) of the highest-scoring child.
        """
        sqrt_parent_visits = np.sqrt(self.N)
        scores = np.array([
            -child.Q + cpuct * child.P * sqrt_parent_visits / (1 + child.N)
            for child in self.children
        ])
        winner = np.argmax(scores)
        return self.children[winner], winner

    def best_action(self, cpuct):
        """Move leading to best_child(cpuct) (PUCT score, not visit count)."""
        _, position = self.best_child(cpuct)
        return get_all_valid_actions(self.state)[position]

    
class MCTS:
    """AlphaZero-style Monte Carlo tree search guided by a (pi, v) network.

    Values are kept from the point of view of the player to move at each node.
    The network's v is read that way, backed-up values flip sign at every ply,
    and selection maximises -child.Q plus an exploration bonus (Node.best_child).
    """
    __slots__ = ('model', 'cpuct', 'tree')

    def __init__(self, model, cpuct=2.0):
        self.model = model
        self.cpuct = cpuct
        # Root nodes cached by encoded board (see get_or_create_node).
        # Non-root nodes are reachable only through Node.children.
        self.tree = {}

    def search(self, root_state, simulations=200):
        """Run `simulations` select/expand/evaluate/backup passes from root_state.

        Returns:
            The root's best action by PUCT score, or None when root_state is
            terminal or has no valid actions.
            NOTE(review): callers such as get_action_probs use the root's
            child visit counts instead of this return value.
        """
        root = self.get_or_create_node(root_state)

        if root.is_terminal() or root.no_valid_actions():
            return

        for _ in range(simulations):
            node = root
            path = [node]

            # Selection: follow PUCT down to an unexpanded node.
            while not node.is_leaf():
                node = node.best_child(self.cpuct)[0]
                path.append(node)

            # The network is only queried here (no training), so skip building
            # an autograd graph for every simulation.
            with torch.no_grad():
                pi, v = self.model(self.get_state_tensor(node.state))
            v = v.item()

            # Expansion
            if not node.is_terminal() and not node.no_valid_actions():
                valid_actions = get_all_valid_actions(node.state)
                self.expand_node(node, valid_actions, pi)

            # Evaluation: terminal nodes use the game result instead of v.
            v = self.evaluate(node, v)

            # Backpropagation: each ancestor sees the value from its own mover's side.
            for visited in reversed(path):
                visited.N += 1
                visited.W += v
                visited.Q = visited.W / visited.N
                v *= -1
        return root.best_action(self.cpuct)

    def expand_node(self, node, valid_actions, pi):
        """Create one child per valid action, with priors from the masked, renormalised policy.

        Children are appended in `valid_actions` order.
        """
        action_idxs = [action_to_index(a) for a in valid_actions]
        valid_mask = np.zeros(81)
        valid_mask[action_idxs] = 1
        priors = pi.detach().cpu().numpy().reshape(81) * valid_mask
        total = priors.sum()
        if total > 0 and np.isfinite(total):
            priors = priors / total
        else:
            # The network put (numerically) no mass on any legal move. Dividing
            # by the total would give NaN priors, so use uniform priors over
            # the legal moves instead.
            priors = valid_mask / valid_mask.sum()
        for action, action_idx in zip(valid_actions, action_idxs):
            child_node = Node(change_state(node.state, action))
            child_node.P = priors[action_idx]
            node.children.append(child_node)

    def evaluate(self, node, v):
        """Value of `node` for the player to move there.

        For a terminal node this is 0 for a draw (2*terminal_utility - 1 == 0)
        and -1 otherwise: the game ended on the opponent's move, so the player
        to move has lost. Otherwise the network estimate v is returned unchanged.
        """
        if is_terminal(node.state):
            return 0 if 2*terminal_utility(node.state)-1 == 0 else -1
        return v

    def get_or_create_node(self, state):
        """Return the cached root Node for `state`, creating it on first use.

        NOTE(review): the key is only the perspective-normalised board contents.
        If State carries more information (e.g. which sub-board must be played
        next), two different states with the same board share a node. Verify
        against utils.State.
        """
        state_tuple = tuple(map(tuple, board_to_array(state)))
        if state_tuple not in self.tree:
            self.tree[state_tuple] = Node(state)
        return self.tree[state_tuple]

    def get_state_tensor(self, state):
        """Network input for `state`; same encoding as the module-level state_to_tensor."""
        return state_to_tensor(state)


In [None]:
# ---- Hyperparameters --------------------------------------------------------
# Previously tried values are kept in the trailing comments.

# Self-play / search
train_episodes = 100
mcts_search = 200                # tried: 100, 400, 600
temperature = 0.05               # lower is more deterministic
cpuct = 2                        # PUCT exploration weight passed to MCTS
parallel_games = 8               # MCTS instances / self-play games run side by side
playgames_before_training = 2    # tried: 4, 5, 25 -- NOTE(review): only used in commented-out code below

# Network evaluation (pit)
n_pit_network = 50               # games per old-vs-new match; tried: 20
threshold = 0.52                 # new net's required share of decisive games; tried: 0.50, 0.55

# Optimisation
training_epochs = 4
learning_rate = 0.0001

# Output
save_model_path = 'training'     # directory for saved checkpoints


In [None]:
def get_action_probs(state, mcts, simulations=200):
    """
    Run MCTS from `state` and return the root's visit-count distribution.

    Args:
        state: The game state to search from (State object).
        mcts: The MCTS instance used for search.
        simulations: Number of MCTS simulations to run.

    Returns:
        action_probs: A length-81 NumPy array (indexed by action_to_index)
        holding child.N / root.N for each root child. It is all zeros when the
        root was never visited (e.g. a terminal state). The entries sum to
        (root.N - 1) / root.N, because the first simulation only expands the
        root, so callers renormalise.
    """
    mcts.search(state, simulations)

    print("Done one iteration of MCTS")

    action_probs = np.zeros(81)
    root_node = mcts.get_or_create_node(state)

    if root_node.N > 0:
        # MCTS.expand_node creates children in get_all_valid_actions order, so
        # one call plus zip pairs each child with its move. Calling
        # get_all_valid_actions once per child would be quadratic.
        valid_actions = get_all_valid_actions(root_node.state)
        for action, child in zip(valid_actions, root_node.children):
            action_probs[action_to_index(action)] = child.N / root_node.N
    return action_probs

def playgame(mcts, simulations=200):
    """
    Play one self-play game, choosing every move by sampling MCTS visit frequencies.

    Args:
        mcts: The MCTS instance used for decision-making.
        simulations: Number of MCTS simulations per move.

    Returns:
        A list of [state, fill_num, policy, result] records, one per move. The
        result is 2*terminal_utility - 1 of the final state (1 = player 1 won),
        sign-flipped for records where player 2 was to move, so it is the
        outcome from the mover's point of view.
    """
    history = []
    state = State()

    # Keep playing while there is a legal move and the game is not decided.
    while len(get_all_valid_actions(state)) != 0 and not is_terminal(state):
        probs = get_action_probs(state, mcts, simulations)
        probs = probs / np.sum(probs)
        history.append([state, state.fill_num, probs, None])

        chosen_index = np.random.choice(len(probs), p=probs)
        state = change_state(state, index_to_action(chosen_index))

    # Label every stored position with the final result from its mover's side.
    # The terminal state itself is not recorded: there is no policy target for it.
    outcome = 2*terminal_utility(state)-1
    for record in history:
        record[3] = outcome if record[1] == 1 else -outcome
    return history

def playgames(mctss, num_games=8, simulations=200):
    """
    Play several self-play games side by side, one MCTS instance per game.

    On every move, the MCTS searches of all still-running games are submitted
    to one shared thread pool. The searches are threaded, not batched through
    the network, and finished games are no longer searched.

    Args:
        mctss: List of MCTS instances; game i uses mctss[i] (needs len >= num_games).
        num_games: Number of games to play.
        simulations: Number of MCTS simulations per move.

    Returns:
        games_mem: A list of `num_games` game logs. Each log is a list of
        [state, fill_num, policy, result] records, where result is
        2*terminal_utility - 1 of the final state, sign-flipped for player-2
        records (the outcome from the mover's point of view).
    """
    games_mem = [[] for _ in range(num_games)]
    states = [State() for _ in range(num_games)]
    active_games = np.ones(num_games, dtype=bool)  # which games are still running

    # One pool for the whole run instead of one per move (max_workers must be >= 1).
    with concurrent.futures.ThreadPoolExecutor(max_workers=max(1, num_games)) as executor:
        while active_games.any():
            # Close out games that ended on the previous move.
            for i in range(num_games):
                if active_games[i] and (len(get_all_valid_actions(states[i])) == 0 or is_terminal(states[i])):
                    # The model predicts utility for the player to move, so flip
                    # the player-1-perspective result for player-2 records.
                    result = 2*terminal_utility(states[i])-1
                    for mem in games_mem[i]:
                        mem[3] = result * (1 if mem[1] == 1 else -1)
                    active_games[i] = False

            running = [i for i in range(num_games) if active_games[i]]
            if not running:
                break

            # Search only the unfinished games, in parallel.
            futures = {i: executor.submit(get_action_probs, states[i], mctss[i], simulations) for i in running}
            policies = {i: futures[i].result() for i in running}

            # Record each position and sample the next move from the visit distribution.
            for i in running:
                policy = policies[i] / np.sum(policies[i])
                games_mem[i].append([states[i], states[i].fill_num, policy, None])
                action_index = np.random.choice(len(policy), p=policy)
                states[i] = change_state(states[i], index_to_action(action_index))

    return games_mem

def test_playgame():
    """Time `parallel_games` sequential self-play games that share one MCTS instance."""
    searcher = MCTS(model, cpuct)
    began = time.time()
    finished_games = [playgame(searcher) for _ in range(parallel_games)]
    ended = time.time()
    print(ended - began)
    
def test_playgames():
    """Time one round of `parallel_games` threaded self-play games (one MCTS each)."""
    searchers = [MCTS(model, cpuct) for _ in range(parallel_games)]
    began = time.time()
    game_logs = playgames(searchers, parallel_games)
    ended = time.time()
    print(ended - began)
    


In [None]:
def train_model(model, games_mem, epochs=4, batch_size=32, lr=0.0001):
    """
    Train the network on self-play records (states, MCTS policies, game results).

    Args:
        model: The network to train; its forward returns (pi, v).
        games_mem: A list of game logs. Each log is a list of
            [state, current_player, policy, result] records, as produced by
            playgame/playgames.
        epochs: Number of passes over the data.
        batch_size: Records per optimizer step.
        lr: Adam learning rate (a fresh optimizer is created on every call).

    Returns:
        model: The same model object, trained in place.
    """
    print("Training Network")
    records = [mem for game_mem in games_mem for mem in game_mem]
    sumlen = len(records)
    print("Length of game_mem:", sumlen)
    if sumlen == 0:
        # Nothing to learn from (and torch.cat below would fail on an empty list).
        return model

    # state_to_tensor gives (1, 1, 9, 9) per record, so cat yields (N, 1, 9, 9).
    states = torch.cat([state_to_tensor(mem[0]) for mem in records])
    policies = torch.tensor(np.array([mem[2] for mem in records]), dtype=torch.float32).to(device)
    values = torch.tensor([mem[3] for mem in records], dtype=torch.float32).to(device)

    optimizer = get_optimizer(model, lr)
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")

        # Reshuffle all records every epoch.
        perm = torch.randperm(sumlen, device=states.device)
        states, policies, values = states[perm], policies[perm], values[perm]

        for i in range(0, sumlen, batch_size):
            batch_states = states[i:i + batch_size]
            batch_policies = policies[i:i + batch_size]
            batch_values = values[i:i + batch_size].unsqueeze(1)  # (b, 1), like the value head

            # Clears gradients by setting each param.grad to None.
            optimizer.zero_grad(set_to_none=True)

            # One forward pass for the whole mini-batch. BatchNorm, if enabled,
            # normalises over the mini-batch.
            predicted_policy, predicted_value = model(batch_states)

            # The loss helpers average over the batch. Multiplying by the batch
            # size optimises the SUM of per-sample losses over the mini-batch.
            n = batch_states.size(0)
            total_loss = (compute_policy_loss(predicted_policy, batch_policies)
                          + compute_value_loss(predicted_value, batch_values)) * n

            total_loss.backward()
            optimizer.step()

        print(f"Epoch {epoch + 1} completed.")
    return model

In [None]:
def _save_pit_checkpoint(net, prefix):
    """Save net's weights to save_model_path/<prefix><UTC timestamp>.pth, creating the directory if needed."""
    from datetime import timezone
    now = datetime.now(timezone.utc).strftime("%Y-%m-%d_%H-%M-%S")
    os.makedirs(save_model_path, exist_ok=True)
    torch.save(net.state_dict(), os.path.join(save_model_path, f'{prefix}{now}.pth'))


def pit(old_model, new_model):
    """
    Play n_pit_network games between the old and new networks and decide whether to accept the new one.

    The colours alternate: the old network moves first in even-numbered games.
    Each move is sampled from the mover's policy, masked to legal moves. Draws
    are ignored, and the new network is accepted when its share of the decisive
    games is >= `threshold`.

    Args:
        old_model: The previous network.
        new_model: The newly trained network.

    Returns:
        True if the new network is accepted, otherwise False. When every game is
        drawn, old_model is saved as 'tictactoeTie<timestamp>.pth' and False is
        returned. When the new network is accepted, it is saved as
        'tictactoeWin<timestamp>.pth'.
    """
    print("Pitting networks...")

    old_model_wins = 0
    new_model_wins = 0
    total_games = n_pit_network
    # `mover` (+1 / -1) indexes this list directly: nets[1] is old, nets[-1] is new.
    nets = [None, old_model, new_model]

    for game in range(total_games):
        state = State()  # Start with an empty board
        mover = 1 if game % 2 == 0 else -1

        while True:
            valid_actions = get_all_valid_actions(state)
            if not valid_actions:
                break  # No valid actions left, game is a tie

            # Inference only: no autograd graph needed.
            with torch.no_grad():
                pi, _ = nets[mover](state_to_tensor(state))

            valid_mask = np.zeros(81)
            valid_mask[[action_to_index(a) for a in valid_actions]] = 1
            pi = pi.cpu().numpy().reshape(81) * valid_mask
            total = pi.sum()
            # Renormalise over legal moves. If the net gave them no mass, fall
            # back to uniform instead of producing NaN probabilities.
            pi = pi / total if total > 0 else valid_mask / valid_mask.sum()

            # Sample rather than take the argmax, so repeated games between two
            # fixed networks are not all identical.
            action_index = np.random.choice(len(pi), p=pi)
            state = change_state(state, index_to_action(action_index))

            if is_terminal(state):
                utility = terminal_utility(state)
                if utility == 1.0:  # Player 1 wins
                    old_model_wins += 1 if game % 2 == 0 else 0
                    new_model_wins += 1 if game % 2 == 1 else 0
                if utility == 0.0:  # Player 2 wins
                    old_model_wins += 1 if game % 2 == 1 else 0
                    new_model_wins += 1 if game % 2 == 0 else 0
                break  # Game over

            mover *= -1  # Switch players

    total_wins = old_model_wins + new_model_wins

    if total_wins == 0:
        print("All games ended in a tie.")
        _save_pit_checkpoint(old_model, 'tictactoeTie')
        return False

    old_model_win_percent = old_model_wins / total_wins
    new_model_win_percent = new_model_wins / total_wins
    print(f"Old NN win rate: {old_model_win_percent:.2%} ({old_model_wins} / {total_wins} wins)")
    print(f"New NN win rate: {new_model_win_percent:.2%} ({new_model_wins} / {total_wins} wins)")

    if new_model_win_percent >= threshold:
        print("The new network is better!")
        _save_pit_checkpoint(new_model, 'tictactoeWin')
        return True
    else:
        print("The new network lost.")
        return False


In [None]:
def train(episodes=100, num_games=8, simulations=200):
    """
    Improve the global `model` by repeated self-play, training and pitting.

    Each episode does the following:
      1. Checkpoint the current weights to temp.pth and load them into a copy (old_model).
      2. Play `num_games` self-play games with MCTS.
      3. Train on the resulting records.
      4. Pit the trained model against old_model and revert if it is not accepted.

    The run resumes from temp.pth if that file exists. The final weights are
    saved under save_model_path and to temp.pth.
    """
    from datetime import timezone

    total_start_time = time.time()

    if os.path.isfile('temp.pth'):
        print("Resuming training from last checkpoint")
        model.load_state_dict(torch.load('temp.pth', map_location=device))
    else:
        print("Starting training from scratch")

    print("Starting training...")

    for episode in range(episodes):
        print(f"Episode {episode + 1}/{episodes}")
        start_time = time.time()

        # Checkpoint the current weights (also the resume point for a later run).
        torch.save(model.state_dict(), 'temp.pth')

        # The snapshot must use the SAME architecture as `model`. UTTT()'s
        # defaults (residual=True, batchnorm=True) would not match its state_dict.
        old_model = UTTT(model.residual, model.batchnorm)
        old_model.load_state_dict(torch.load('temp.pth', map_location=device))
        old_model.to(device)

        # One MCTS instance per self-play game.
        mctss = [MCTS(model, cpuct) for _ in range(num_games)]
        games_mem = playgames(mctss, num_games=num_games, simulations=simulations)

        train_model(model, games_mem, epochs=training_epochs, lr=learning_rate)

        # Keep the new weights only if they beat the old ones.
        if not pit(old_model, model):
            model.load_state_dict(old_model.state_dict())
        del old_model

        end_time = time.time()
        print(f"Episode {episode + 1} took {(end_time - start_time) / 60} minutes.")

    # Save the final trained model
    now = datetime.now(timezone.utc).strftime("%Y-%m-%d_%H-%M-%S")
    filename = f'tictactoe_MCTS{episodes}_{now}.pth'
    os.makedirs(save_model_path, exist_ok=True)
    model_path = os.path.join(save_model_path, filename)
    torch.save(model.state_dict(), model_path)
    torch.save(model.state_dict(), 'temp.pth')
    print(f"Training complete. Model saved as {filename}")

    total_end_time = time.time()
    print(f"Training {episodes} took {(total_end_time - total_start_time) / 3600} hours.")

In [None]:
# Rebuild the network with the architecture that produced temp.pth (the global
# model is UTTT(False, False)). UTTT()'s defaults would add residual/BatchNorm
# layers and fail to load those weights.
inference_model = UTTT(residual=False, batchnorm=False)
inference_model.load_state_dict(torch.load('temp.pth', map_location=device))
inference_model.to(device)
inference_model.eval()  # inference only (matters if BatchNorm is ever enabled)

def get_move(state, model):
    """Pick a move by one-ply lookahead on the network's value head.

    Each valid action is applied to ``state`` and the resulting position is scored
    with ``model``. The action whose successor gets the LOWEST value is returned.
    Per the original inline note, the value is from the point of view of the player
    to move in the successor position (the opponent), so minimising it favours us.

    Args:
        state: position to move from.
        model: callable mapping ``state_to_tensor(...)`` output to a
            ``(policy, value)`` pair; ``value`` must be convertible with ``float()``
            (e.g. a one-element tensor).

    Returns:
        The chosen action, or ``None`` when ``state`` has no valid actions.
    """
    def successor_value(action):
        # Only the value head is used; the policy output is ignored.
        _, value = model(state_to_tensor(change_state(state, action)))
        return float(value)

    # Pure inference: skip building an autograd graph for every candidate move.
    with torch.no_grad():
        # min() keeps the first action among ties (same as a strict "<" scan) and
        # still returns a valid action even if some value comes back as NaN.
        return min(get_all_valid_actions(state), key=successor_value, default=None)

class StudentAgent:
    """Agent that delegates every move to ``get_move`` with the global ``inference_model``."""

    def __init__(self):
        """Create the agent; it keeps no per-instance state.
        """

    def choose_action(self, state: State) -> Action:
        """Return a valid action to play on ``state``, filling cells with number 1.

        Parameters
        ---------------
        state: The board to make a move on.

        Returns
        ---------------
        The action selected by ``get_move`` using ``inference_model``.
        """
        chosen = get_move(state, inference_model)
        return chosen

# Use this cell to test your agent in two full games against a random agent.
# The random agent will choose actions randomly among the valid actions.

class RandomStudentAgent(StudentAgent):
    """Baseline opponent that picks randomly among the valid actions."""

    def choose_action(self, state: State) -> Action:
        """Return a random valid action for ``state``.

        To reuse an existing Player 1 agent as Player 2 instead, invert the
        state before choosing (``state = state.invert()``).
        """
        random_choice = state.get_random_valid_action()
        return random_choice

def run(your_agent: StudentAgent, opponent_agent: StudentAgent, start_num: int, time_limit: float = 3):
    """Play one full game between two agents and print the outcome and statistics.

    ``your_agent`` moves whenever ``state.fill_num == 1``; ``opponent_agent`` moves
    otherwise. A move that takes longer than ``time_limit`` seconds, or that is not
    a valid action, is replaced by a random valid action and counted against the
    agent that produced it.

    Parameters
    ---------------
    your_agent: agent playing as fill number 1.
    opponent_agent: agent playing every other turn.
    start_num: fill number of the player who moves first.
    time_limit: seconds allowed per move before it counts as a timeout (default 3).
    """
    your_agent_stats = {"timeout_count": 0, "invalid_count": 0}
    opponent_agent_stats = {"timeout_count": 0, "invalid_count": 0}
    turn_count = 0

    state = State(fill_num=start_num)

    while not state.is_terminal():
        turn_count += 1

        is_yours = state.fill_num == 1
        agent_name = "your_agent" if is_yours else "opponent_agent"
        agent = your_agent if is_yours else opponent_agent
        stats = your_agent_stats if is_yours else opponent_agent_stats

        # The agent receives a clone of the state; perf_counter is the
        # monotonic clock intended for measuring elapsed time.
        start_time = time.perf_counter()
        action = agent.choose_action(state.clone())
        elapsed = time.perf_counter() - start_time

        # Fallback moves are only drawn when actually needed.
        if elapsed > time_limit:
            print(f"{agent_name} timed out!")
            stats["timeout_count"] += 1
            action = state.get_random_valid_action()
        if not state.is_valid_action(action):
            print(f"{agent_name} made an invalid action!")
            stats["invalid_count"] += 1
            action = state.get_random_valid_action()

        state = state.change_state(action)

    print(f"== {your_agent.__class__.__name__} (1) vs {opponent_agent.__class__.__name__} (2) - First Player: {start_num} ==")

    utility = state.terminal_utility()
    if utility == 1:
        print("You win!")
    elif utility == 0:
        print("You lose!")
    else:
        print("Draw")

    for agent_name, stats in [("your_agent", your_agent_stats), ("opponent_agent", opponent_agent_stats)]:
        print(f"{agent_name} statistics:")
        print(f"Timeout count: {stats['timeout_count']}")
        print(f"Invalid count: {stats['invalid_count']}")

    print(f"Turn count: {turn_count}\n")

# Agent factories: the classes are already callables, so no lambda wrappers are
# needed (PEP 8 discourages binding lambdas to names). A fresh agent is built per game.
your_agent = StudentAgent
opponent_agent = RandomStudentAgent

# Play one game with each side moving first.
for first_player in (1, 2):
    run(your_agent(), opponent_agent(), first_player)

In [55]:
torch.set_printoptions(precision=10, threshold=200000)

# Compare UTTT(residual=False, batchnorm=False) with _UTTT(False, False):
# trainable parameter counts, then state_dict key names, then the _UTTT layer summary.
model1 = UTTT(False, False)
model2 = _UTTT(False, False)

def _trainable_param_count(net):
    """Total number of scalar parameters in ``net`` with ``requires_grad`` set."""
    return sum(param.numel() for param in net.parameters() if param.requires_grad)

for net in (model1, model2):
    print(_trainable_param_count(net))

for net in (model1, model2):
    print(list(net.state_dict().keys()))

print(model2)

180162
169713
['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'conv3.weight', 'conv3.bias', 'fc1.weight', 'fc1.bias', 'pi.weight', 'pi.bias', 'v.weight', 'v.bias']
['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'conv3.weight', 'conv3.bias', 'fc1.weight', 'fc1.bias', 'v.weight', 'v.bias']
_UTTT(
  (conv1): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (fc1): Linear(in_features=1296, out_features=128, bias=True)
  (v): Linear(in_features=128, out_features=1, bias=True)
)


In [56]:
# Write model2's state_dict as text; the torch print options set earlier
# (precision / threshold) control how each tensor is rendered.
with open("model_weights.txt", "w") as weights_file:
    weights_file.write(str(model2.state_dict()) + "\n")


In [44]:
# Round-trip check: a freshly built _UTTT must accept model2's weights key-for-key
# (the last expression displays the load result).
model3 = _UTTT(False, False)
coeffs = model2.state_dict()
model3.load_state_dict(coeffs)

<All keys matched successfully>