In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
from collections import defaultdict
import os
from datetime import datetime
import time

from utils import State, Action, is_terminal, change_state, get_all_valid_actions, terminal_utility, is_valid_action, invert

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class ResidualBlock(nn.Module):
    """
    Two 3x3 convolution + BatchNorm layers wrapped by an identity shortcut.

    stride=1, padding=1 and equal in/out channel counts keep the tensor shape
    unchanged, which is what allows the input to be added back to the output.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=False)
        bn_kwargs = dict(eps=1e-5, momentum=0.09)
        # Attribute names are state_dict keys of saved checkpoints; keep them (and
        # their creation order, which fixes the random-init sequence) unchanged.
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(channels, **bn_kwargs)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(channels, **bn_kwargs)

    def forward(self, x):
        shortcut = x
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        # Add the shortcut before the final non-linearity.
        return F.relu(hidden + shortcut)

class UTTT(nn.Module):
    """
    Policy/value network for Ultimate Tic-Tac-Toe.

    Input:  (batch, 1, 9, 9) boards, i.e. board_to_array output with batch and
            channel dimensions added. The spatial size must be 9x9 because fc1
            expects channels * 9 * 9 features.
    Output: (pi, v)
        pi: (batch, 81) move PROBABILITIES -- softmax is applied here, so callers
            (and loss functions) receive probabilities, not logits.
        v:  (batch, 1) value squashed to [-1, 1] by tanh.

    Attribute names (conv1, bn1, res2, res3, fc1, fc2, pi, v) are the state_dict
    keys of saved checkpoints; keep them stable.
    """

    def __init__(self, channels=128):
        super().__init__()

        # Stem: one input plane -> `channels` feature maps; padding=1 keeps 9x9.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels, eps=1e-5, momentum=0.09)
        self.res2 = ResidualBlock(channels)
        self.res3 = ResidualBlock(channels)

        # Fully connected trunk over the flattened channels x 9 x 9 features.
        self.fc1 = nn.Linear(channels * 9 * 9, 512)
        self.fc2 = nn.Linear(512, 256)

        # Heads: 81 move scores (softmaxed in forward) and a scalar value (tanh).
        self.pi = nn.Linear(256, 81)
        self.v = nn.Linear(256, 1)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))

        # Two residual blocks.
        x = self.res2(x)
        x = self.res3(x)

        # Flatten (batch, C, 9, 9) -> (batch, C*81) and run the FC trunk.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))

        pi = F.softmax(self.pi(x), dim=1)  # probabilities over the 81 moves (not logits)
        v = torch.tanh(self.v(x))          # value in [-1, 1]

        return pi, v

    
# Instantiate the network and move its parameters to the selected device
# (nn.Module.to returns the module itself, so the result can be bound directly).
model = UTTT().to(device)

# Define the optimizer
def get_optimizer(model, lr=0.0001):
    """Build a fresh Adam optimizer over every parameter of `model` with learning rate `lr`."""
    trainable = model.parameters()
    return optim.Adam(trainable, lr=lr)

# Define loss functions
def compute_policy_loss(predicted_policy, true_policy, eps=1e-8):
    """
    Cross-entropy between a target move distribution and the predicted one.

    UTTT.forward already applies softmax to its policy head, so `predicted_policy`
    holds probabilities. nn.CrossEntropyLoss expects unnormalized logits and would
    apply log_softmax a second time, flattening the distribution and distorting the
    gradient; here the log is taken of the probabilities directly.

    Args:
        predicted_policy: (batch, n_actions) probabilities, e.g. UTTT's `pi` output.
        true_policy: (batch, n_actions) target distribution (MCTS visit shares).
        eps: floor applied before the log so zero probabilities don't yield -inf.

    Returns:
        Scalar tensor: batch mean of -sum(true_policy * log(predicted_policy)).
    """
    log_probs = torch.log(predicted_policy.clamp_min(eps))
    return -(true_policy * log_probs).sum(dim=1).mean()

def compute_value_loss(predicted_value, true_value):
    """
    Mean squared error between predicted and target values.

    The value head returns shape (batch, 1) while targets are often (batch,) or a
    0-dim scalar. Passing those straight to MSELoss broadcasts (batch, 1) against
    (batch,) into (batch, batch) and averages cross terms. Both sides are
    flattened first so elements are compared one-to-one.

    Args:
        predicted_value: tensor of predicted values, any shape with N elements.
        true_value: tensor of target values with the same number of elements.

    Returns:
        Scalar tensor with the mean squared error.
    """
    predicted = predicted_value.reshape(-1)
    target = true_value.reshape(-1).to(predicted.dtype)
    return F.mse_loss(predicted, target)

# Show the layer-by-layer module summary (nn.Module's repr).
print(repr(model))


UTTT(
  (conv1): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.09, affine=True, track_running_stats=True)
  (res2): ResidualBlock(
    (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.09, affine=True, track_running_stats=True)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.09, affine=True, track_running_stats=True)
  )
  (res3): ResidualBlock(
    (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.09, affine=True, track_running_stats=True)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.09, affine=True, track_running_stats=True)
  )
  (fc1): Lin

In [2]:
# ---- Hyperparameters ----

# Self-play and search
train_episodes = 100              # NOTE(review): not read by the cells shown; train() takes its own `episodes` argument
mcts_search = 600                 # MCTS simulations run before each self-play move
cpuct = 4                         # exploration constant in the PUCT selection formula
temperature = 0.05                # lower is more deterministic; NOTE(review): not read by the cells shown
playgames_before_training = 4     # self-play games collected per training episode

# Evaluation (pit)
n_pit_network = 50                # games between old and new network (previously 20)
threshold = 0.52                  # share of decisive games the new net must win (tried 0.50, 0.55)

# Training
training_epochs = 4
learning_rate = 0.0001            # 1e-4 for first 250 episodes, 5e-5 for next 250 episodes
save_model_path = 'training'      # directory for checkpoint files

In [3]:
def board_to_array(state):
    """
    Flatten the 3x3x3x3 Ultimate Tic-Tac-Toe board into a 9x9 float32 array seen
    from the perspective of the player to move.

    Cell board[i][j][k][l] is placed at row i*3 + k, column j*3 + l. When
    state.fill_num == 1, marks equal to 1 become +1.0 and marks equal to 2 become
    -1.0; otherwise the signs are swapped. Empty cells (0) stay 0.0 and any other
    value is copied through unchanged.

    The returned ndarray is NOT hashable; callers build a tuple of tuples from it
    when they need a dictionary key. The input board is not modified.

    Args:
        state: The current game state (State object) exposing `board` and `fill_num`.

    Returns:
        np.ndarray of shape (9, 9) and dtype float32.
    """
    player = 1.0 if state.fill_num == 1 else -1.0

    # Reorder axes (i, j, k, l) -> (i, k, j, l) so rows are indexed by (i, k) and
    # columns by (j, l); reshaping then yields row i*3+k, column j*3+l.
    grid = np.asarray(state.board, dtype=np.float32).transpose(0, 2, 1, 3).reshape(9, 9)

    # Masks are taken from the unmodified grid, so the two replacements cannot interfere.
    oriented = np.where(grid == 1, player, np.where(grid == 2, -player, grid))
    return oriented.astype(np.float32, copy=False)


In [4]:
# MCTS statistics. States are keyed by the tuple-of-tuples form of
# board_to_array(state); edge tables are keyed by (state_key, action).
P = {}                    # prior policy per state (81-long array from the network)
Ns = defaultdict(int)     # visit count per state
Nsa = defaultdict(int)    # visit count per (state, action)
W = defaultdict(float)    # summed backed-up value per (state, action)
Q = defaultdict(float)    # mean value per (state, action) = W / Nsa



In [5]:
def action_to_index(action):
    """
    Flatten an action (meta_row, meta_col, local_row, local_col) into an int in [0, 80].

    The result is meta_row*27 + meta_col*9 + local_row*3 + local_col: the mini-board
    number (meta_row*3 + meta_col) times 9 plus the cell number inside that
    mini-board. index_to_action is the inverse mapping.

    NOTE(review): if an action (i, j, k, l) addresses state.board[i][j][k][l], this
    ordering groups each mini-board's cells together and differs from the row-major
    position (i*27 + k*9 + j*3 + l) in board_to_array's 9x9 grid.

    Args:
        action: 4-tuple (meta_row, meta_col, local_row, local_col).

    Returns:
        The flat policy index.
    """
    mr, mc, lr_, lc = action
    # Horner form of mr*27 + mc*9 + lr_*3 + lc.
    return ((mr * 3 + mc) * 3 + lr_) * 3 + lc


In [6]:
def index_to_action(index: int) -> Action:
    """
    Inverse of action_to_index: turn a flat policy index back into
    (meta_row, meta_col, local_row, local_col).

    The index is split as (mini-board number, cell number) = divmod(index, 9), and
    each of those is split into (row, col) with divmod(..., 3).

    Args:
        index (int): The action index.

    Returns:
        Action: The tuple (meta_row, meta_col, local_row, local_col).
    """
    board_number, cell_number = divmod(index, 9)
    meta_row, meta_col = divmod(board_number, 3)
    local_row, local_col = divmod(cell_number, 3)
    return (meta_row, meta_col, local_row, local_col)


In [7]:
def normalize_policy(policy, valid_mask=None):
    """
    Normalize a policy to a probability distribution, guarding against NaN/inf.

    Args:
        policy: array-like of non-negative move weights.
        valid_mask: optional array of the same length, 1 for legal moves and 0 for
            illegal ones. When given, the policy is masked before normalizing, and
            the fallback distribution is uniform over legal moves only.

    Returns:
        A float ndarray summing to 1. Falls back to a uniform distribution (over legal
        moves when `valid_mask` has any, otherwise over all entries) when the total
        weight is not positive or the result contains NaN/inf.
    """
    policy = np.asarray(policy, dtype=float)
    mask = None if valid_mask is None else np.asarray(valid_mask, dtype=float)
    if mask is not None:
        policy = policy * mask

    total = np.sum(policy)
    # NaN totals fail `> 0` as well, so they also take the fallback path.
    if total > 0:
        normalized = policy / total
        if np.all(np.isfinite(normalized)):
            return normalized

    if mask is not None and np.sum(mask) > 0:
        return mask / np.sum(mask)
    return np.ones_like(policy) / len(policy)



In [8]:
def mcts(state):
    """
    Run one Monte-Carlo Tree Search simulation from `state` and back up its result.

    Statistics live in the module-level tables P (priors), Ns (state visits),
    Nsa (edge visits), W (summed values) and Q (mean values), keyed by the
    tuple-of-tuples form of board_to_array(state).

    Values are kept in negamax form on a [-1, 1] scale, from the point of view of
    the player to move:
      * the network's value head is trained on 0 / 0.5 / 1 targets for the player
        to move (see playgame), so its output is mapped with 2*v - 1;
      * terminal_utility is 1.0 when the fill_num == 1 player wins, 0.0 when the
        other player wins and 0.5 for a draw (see playgame / pit), so it is
        re-expressed for the player who made the final move.
    Mixing raw utilities (0..1, always player 1's view) with negated network
    values (-1..0) made e.g. the second player rank a draw above a win.

    Args:
        state: The current game state (State object).

    Returns:
        The value of `state` for the player who moved into it, i.e. the negation
        of its value for the player to move (used by the parent node).
    """
    possible_actions = get_all_valid_actions(state)
    if len(possible_actions) == 0:
        # No legal move: score as a draw. No action was taken, so there are no
        # statistics to update (the old code referenced an undefined best_action).
        return 0.0

    sArray = board_to_array(state)
    # NOTE(review): the key encodes only the (perspective-flipped) board, so
    # positions that differ only in which mini-board is forced share statistics.
    sTuple = tuple(map(tuple, sArray))

    if sTuple not in P:
        # Expansion: evaluate the new leaf with the network and end this simulation.
        return -_expand_leaf(sTuple, sArray, possible_actions)

    # Selection with PUCT: Q + cpuct * P * sqrt(N(s)) / (1 + N(s, a)).
    priors = P[sTuple]
    sqrt_visits = math.sqrt(Ns[sTuple])

    def puct(action):
        edge = (sTuple, action)
        return Q[edge] + cpuct * priors[action_to_index(action)] * (sqrt_visits / (1 + Nsa[edge]))

    # max() keeps the first maximum, matching the original strict '>' scan.
    best_action = max(possible_actions, key=puct)

    next_state = change_state(state, best_action)
    if is_terminal(next_state):
        v = _terminal_value_for_mover(state, terminal_utility(next_state))
    else:
        v = mcts(next_state)  # value of best_action for the player to move in `state`

    # Backup.
    edge = (sTuple, best_action)
    W[edge] += v
    Ns[sTuple] += 1
    Nsa[edge] += 1
    Q[edge] = W[edge] / Nsa[edge]

    return -v


def _expand_leaf(sTuple, sArray, possible_actions):
    """Evaluate an unseen state with the network, store its masked prior and zeroed
    edge statistics, and return its [-1, 1] value for the player to move."""
    state_tensor = torch.tensor(sArray, dtype=torch.float32).unsqueeze(0).unsqueeze(1)  # [1, 1, 9, 9]
    # Inference only: avoid building an autograd graph for every tree node.
    # NOTE(review): the model stays in its current (train) mode, so BatchNorm uses this
    # single sample's statistics and updates its running stats.
    with torch.no_grad():
        pi, v = model(state_tensor.to(device))

    valid_mask = np.zeros(81)
    for action in possible_actions:
        valid_mask[action_to_index(action)] = 1

    prior = pi.cpu().numpy().reshape(81) * valid_mask
    total = np.sum(prior)
    if total > 0 and np.isfinite(total):
        prior = prior / total
    else:
        # The network put (numerically) no mass on legal moves: use a uniform prior
        # over them instead of storing NaNs.
        prior = valid_mask / np.sum(valid_mask)

    P[sTuple] = prior
    Ns[sTuple] = 1
    for action in possible_actions:
        Q[(sTuple, action)] = 0
        Nsa[(sTuple, action)] = 0
        W[(sTuple, action)] = 0

    # Value head was trained on 0..1 targets; map onto the [-1, 1] negamax scale.
    return 2.0 * v.item() - 1.0


def _terminal_value_for_mover(state, utility):
    """Convert terminal_utility (1.0: fill_num 1 won, 0.0: other player won, 0.5: draw)
    into a [-1, 1] value for the player to move in `state` (the one who just moved)."""
    mover_utility = utility if state.fill_num == 1 else 1.0 - utility
    return 2.0 * mover_utility - 1.0


In [9]:
def get_action_probs(init_board, n_simulations=None, verbose=True):
    """
    Run MCTS from `init_board` and turn the root's visit counts into a move distribution.

    Args:
        init_board: The root game state (State object). It is passed unchanged to
            every simulation; this assumes change_state returns a new State rather
            than mutating its argument -- TODO confirm in utils.
        n_simulations: Number of MCTS simulations; defaults to the global `mcts_search`.
        verbose: Print a progress line after the search (on by default for
            backward compatibility; pass False to keep notebook output small).

    Returns:
        action_probs: numpy array of size 81 indexed by action_to_index, holding
        each legal move's share of the root's edge visits (sums to 1). If no
        legal move was visited (e.g. a single simulation that only expanded the
        root), a uniform distribution over legal moves is returned. It stays all
        zeros when there are no legal moves. Illegal moves always get 0.
    """
    if n_simulations is None:
        n_simulations = mcts_search

    for _ in range(n_simulations):
        mcts(init_board)

    if verbose:
        print("Done one iteration of MCTS")

    # Same key construction as in mcts.
    sTuple = tuple(map(tuple, board_to_array(init_board)))
    possible_actions = get_all_valid_actions(init_board)

    action_probs = np.zeros(81)
    for action in possible_actions:
        # .get avoids inserting zero entries into the defaultdict.
        action_probs[action_to_index(action)] = Nsa.get((sTuple, action), 0)

    total_visits = action_probs.sum()
    if total_visits > 0:
        # Normalize by edge visits: Ns also counts the expansion visit, so
        # Nsa / Ns would not sum to 1.
        return action_probs / total_visits

    for action in possible_actions:
        action_probs[action_to_index(action)] = 1.0
    if action_probs.sum() > 0:
        action_probs /= action_probs.sum()
    return action_probs


In [10]:
def playgame():
    """
    Play one self-play game of Ultimate Tic-Tac-Toe, sampling every move from the
    MCTS visit distribution.

    Returns:
        game_mem: one [board_array, fill_num, policy, result] entry per move, where
            board_array is board_to_array of the position,
            fill_num is the player to move,
            policy is the normalized 81-long MCTS distribution, and
            result is the final outcome for that player: 1.0 win, 0.0 loss, 0.5 draw.
    """
    game_mem = []
    real_board = State()  # empty board

    while True:
        valid_actions = get_all_valid_actions(real_board)
        if len(valid_actions) == 0:
            # No legal move left and the game was not flagged terminal: score a draw.
            # This is checked BEFORE searching; previously it ran after a move from
            # this same position had been chosen and could never trigger.
            for record in game_mem:
                record[3] = 0.5
            return game_mem

        visit_probs = get_action_probs(real_board)
        total = np.sum(visit_probs)
        if total > 0:
            policy = visit_probs / total
        else:
            # An all-zero search result would normalize to NaNs and make
            # np.random.choice raise; fall back to uniform over legal moves.
            policy = np.zeros(len(visit_probs))
            for action in valid_actions:
                policy[action_to_index(action)] = 1.0
            policy /= policy.sum()

        game_mem.append([board_to_array(real_board), real_board.fill_num, policy, None])

        action = index_to_action(np.random.choice(len(policy), p=policy))
        next_state = change_state(real_board, action)

        if is_terminal(next_state):
            # terminal_utility: 1.0 = fill_num 1 won, 0.0 = other player won, 0.5 = draw.
            utility = terminal_utility(next_state)
            for record in game_mem:
                record[3] = utility if record[1] == 1 else 1.0 - utility
            return game_mem

        real_board = next_state


In [11]:
def train_nn(nn, game_mem, epochs=4, batch_size=32, lr=0.0001):
    """
    Fit the network to self-play data: the policy head towards the MCTS visit
    distributions, the value head towards the game results.

    Args:
        nn: The network to train (updated in place). Inside this function the name
            shadows the module-level `torch.nn` alias.
        game_mem: list of [state_array (9x9), current_player, policy (81,), result] entries.
        epochs: Number of passes over game_mem.
        batch_size: Number of samples whose losses are summed per optimizer step.
        lr: Adam learning rate; a fresh optimizer is created on every call.

    Returns:
        The same `nn` object after training.
    """
    print("Training Network")
    print("Length of game_mem:", len(game_mem))

    states = torch.tensor(np.array([mem[0] for mem in game_mem]), dtype=torch.float32).to(device)
    policies = torch.tensor(np.array([mem[2] for mem in game_mem]), dtype=torch.float32).to(device)
    values = torch.tensor(np.array([mem[3] for mem in game_mem]), dtype=torch.float32).to(device)

    # Ensure BatchNorm/autograd run in training mode even if a caller switched to eval().
    nn.train()
    optimizer = get_optimizer(nn, lr)
    n_samples = len(game_mem)

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")

        # Reshuffle the samples every epoch.
        indices = torch.randperm(n_samples, device=states.device)
        states = states[indices]
        policies = policies[indices]
        values = values[indices]

        for start in range(0, n_samples, batch_size):
            optimizer.zero_grad()
            total_loss = 0

            # Samples are forwarded one at a time and their losses summed.
            for j in range(start, min(start + batch_size, n_samples)):
                state_single = states[j].unsqueeze(0).unsqueeze(1)  # [1, 1, 9, 9]
                policy_single = policies[j].unsqueeze(0)            # [1, 81]
                # The value head returns [1, 1]; match that shape instead of a 0-dim
                # target, which MSELoss would broadcast (with a warning).
                value_single = values[j].view(1, 1)

                predicted_policy, predicted_value = nn(state_single)

                policy_loss = compute_policy_loss(predicted_policy, policy_single)
                value_loss = compute_value_loss(predicted_value, value_single)
                total_loss += policy_loss + value_loss

            total_loss.backward()
            optimizer.step()

        print(f"Epoch {epoch + 1} completed.")
    return nn

In [12]:
def pit(nn, new_nn):
    """
    Play n_pit_network games between the old network (nn) and the newly trained one
    (new_nn) and decide whether the new network should replace the old one.

    Each network picks moves by sampling from its own policy head restricted to
    legal moves (no search). nn makes the first move in even-numbered games and
    new_nn in odd-numbered games. Draws are not counted towards the win rate.

    Side effects: writes a checkpoint under save_model_path. It saves the old
    network ("tictactoeTie<timestamp>.pth") when no game was decisive, or the new
    network ("tictactoeWin<timestamp>.pth") when it is accepted.

    Args:
        nn: The old neural network.
        new_nn: The newly trained neural network.

    Returns:
        True if new_nn won at least a `threshold` share of the decisive games,
        otherwise False.
    """
    print("Pitting networks...")

    nn_wins = 0
    new_nn_wins = 0
    total_games = n_pit_network

    for game in range(total_games):
        state = State()  # Start with an empty board
        nets = [None, nn, new_nn]  # nets[1] -> old net, nets[-1] -> new net
        mover = 1 if game % 2 == 0 else -1
        opener_is_old = game % 2 == 0

        while True:
            valid_moves = get_all_valid_actions(state)
            if len(valid_moves) == 0:
                break  # No legal move left: the game counts as a tie

            net = nets[mover]
            state_tensor = torch.tensor(board_to_array(state), dtype=torch.float32).unsqueeze(0).unsqueeze(1)
            with torch.no_grad():  # pure inference: no autograd graph needed
                policy, _ = net(state_tensor.to(device))

            legal = np.zeros(81)
            for action in valid_moves:
                legal[action_to_index(action)] = 1
            policy = policy.cpu().numpy().reshape(81) * legal
            total = policy.sum()
            if total > 0 and np.isfinite(total):
                policy = policy / total
            else:
                # Network put (numerically) no mass on legal moves: sample uniformly
                # instead of crashing on a NaN distribution.
                policy = legal / legal.sum()

            # Sample instead of taking the argmax. The networks are trained to predict
            # pi and v from MCTS (v is later used for minimax), and deterministic argmax
            # play would make every pit game identical.
            action_index = np.random.choice(len(policy), p=policy)
            state = change_state(state, index_to_action(action_index))

            if is_terminal(state):
                utility = terminal_utility(state)
                if utility == 1.0:  # Player 1 wins -> credited to this game's opener
                    if opener_is_old:
                        nn_wins += 1
                    else:
                        new_nn_wins += 1
                elif utility == 0.0:  # Player 2 wins
                    if opener_is_old:
                        new_nn_wins += 1
                    else:
                        nn_wins += 1
                break  # Game over

            mover *= -1  # Switch players

    total_wins = nn_wins + new_nn_wins

    if total_wins == 0:
        print("All games ended in a tie.")
        _save_pit_checkpoint(nn, 'tictactoeTie')
        return False

    nn_win_percent = nn_wins / total_wins
    new_nn_win_percent = new_nn_wins / total_wins
    print(f"Old NN win rate: {nn_win_percent:.2%} ({nn_wins} / {total_wins} wins)")
    print(f"New NN win rate: {new_nn_win_percent:.2%} ({new_nn_wins} / {total_wins} wins)")

    if new_nn_win_percent >= threshold:
        print("The new network is better!")
        _save_pit_checkpoint(new_nn, 'tictactoeWin')
        return True

    print("The new network lost.")
    return False


def _save_pit_checkpoint(net, prefix):
    """Save net's state_dict as <save_model_path>/<prefix><UTC timestamp>.pth, creating the directory if needed."""
    from datetime import timezone  # datetime.utcnow() is deprecated since Python 3.12
    now = datetime.now(timezone.utc).strftime("%Y-%m-%d_%H-%M-%S")
    os.makedirs(save_model_path, exist_ok=True)
    torch.save(net.state_dict(), os.path.join(save_model_path, f'{prefix}{now}.pth'))


In [13]:
# Fresh network for the training run below; train() resumes from 'temp.pth'
# when that checkpoint exists.
model = UTTT().to(device)
def train(episodes=100):
    """
    Train the network with self-play, MCTS and network-vs-network evaluation.

    Each episode runs four steps:
      1. snapshot the current weights to 'temp.pth' (also the resume checkpoint);
      2. collect playgames_before_training self-play games;
      3. train on them with the config-cell training_epochs / learning_rate;
      4. pit the trained net against the snapshot.
    If the new network is accepted, the MCTS tables are reset because they were
    built with the old weights. Otherwise the snapshot weights are restored.
    The final model is saved under save_model_path and to 'temp.pth'.

    Args:
        episodes: Number of self-play/train/pit episodes to run.
    """
    global model
    global Q, Nsa, Ns, W, P  # MCTS data structures

    if os.path.isfile('temp.pth'):
        print("Resuming training from last checkpoint")
        model.load_state_dict(torch.load('temp.pth', map_location=device))
    else:
        print("Starting training from scratch")

    print("Starting training...")

    for episode in range(episodes):
        print(f"Episode {episode + 1}/{episodes}")
        start_time = time.time()
        print(start_time)

        # Snapshot the current weights; the copy is the opponent in pit().
        torch.save(model.state_dict(), 'temp.pth')
        old_model = UTTT()
        old_model.load_state_dict(torch.load('temp.pth', map_location=device))
        old_model.to(device)

        # Self-play to generate training data.
        game_mem = []
        for _ in range(playgames_before_training):
            game_mem += playgame()

        # Use the config-cell hyperparameters rather than train_nn's defaults.
        train_nn(model, game_mem, epochs=training_epochs, lr=learning_rate)
        del game_mem

        if pit(old_model, model):
            # New network accepted: the cached search statistics came from the old one.
            P = {}
            Ns = defaultdict(int)   # Visit count for states
            Q = defaultdict(float)  # Q-value for (state, action)
            Nsa = defaultdict(int)  # Visit count for (state, action)
            W = defaultdict(float)  # Total reward for (state, action)
        else:
            # New network rejected: revert to the snapshot weights.
            model.load_state_dict(old_model.state_dict())
        del old_model

        end_time = time.time()
        print(end_time)
        print((end_time - start_time) / 60)

    # Save the final trained model.
    from datetime import timezone  # datetime.utcnow() is deprecated since Python 3.12
    now = datetime.now(timezone.utc).strftime("%Y-%m-%d_%H-%M-%S")
    filename = f'tictactoe_MCTS{episodes}_{now}.pth'
    os.makedirs(save_model_path, exist_ok=True)
    model_path = os.path.join(save_model_path, filename)
    torch.save(model.state_dict(), model_path)
    torch.save(model.state_dict(), 'temp.pth')
    print(f"Training complete. Model saved as {filename}")


In [27]:
# Run 25 self-play / training / pit episodes.
train(episodes=25)

Resuming training from last checkpoint
Starting training...
Episode 1/25
1742461545.9377208
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteratio

Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
D

Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
D

Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Training Network
Length of game_mem: 238
Epoch 1/4
Epoch 1 completed.
Epoch 2/4
Epoch 2 completed.
Epoch 3/4
Epoch 3 completed.
Epoch 4/4
Epoch 4 completed.
Pitting networks...
Old NN win rate: 37.84% (14 / 37 wins)
New NN win rate: 62.16% (23 / 37 wins)
The new network is better!
1742466262.2446344
21.694873666763307
Episode 5/25
1742466262.2450721
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS


Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
D

Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
D

Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Training Network
Length of game_mem: 209
Epoch 1/4
Epoch 1 completed.
Epoch 2/4
Epoch 2 completed.
Epoch 3/4
Epoch 3 completed.
Epoch 4/4
Epoch 4 completed.
Pitting networks...
Old NN win rate: 53.66% (22 / 41 wins)
New NN win rate: 46.34% (19 / 41 wins)
The new network lost.
1742470690.4976492
18.8553413550059
Episode 9/25
1742470690.4984748
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done on

Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
D

Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
D

Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
D

Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
T

Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
D

Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
D

Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Training Network
Length of game_mem: 191
Epoch 1/4
Epoch 1 completed.
Epoch 2/4
Epoch 2 completed.
Epoch 3/4
Epoch 3 completed.
Epoch 4/4
Epoch 4 completed.
Pitting networks...
Old NN win rate: 48.84% (21 / 43 wins)
New NN win rate: 51.16% (22 / 43 wins)
The new network lost.
1742481038.381137
16.497522739569344
Episode 18/25
1742481038.3813875
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done 

Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
D

Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
D

Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Training Network
Length of game_mem: 207
Epoch 1/4
Epoch 1 completed.
Epoch 2/4
Epoch 2 completed.
Epoch 3/4
Epoch 3 completed.
Epoch 4/4
Epoch 4 completed.
Pitting networks...
Old NN win rate: 41.46% (17 / 41 wins)
New NN win rate: 58.54% (24 / 41 wins)
The new network is better!
1742485514.920664
16.603967142105102
Episode 22/25
1742485514.9218504
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS


Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
D

Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
D

Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Done one iteration of MCTS
Training Network
Length of game_mem: 202
Epoch 1/4
Epoch 1 completed.
Epoch 2/4
Epoch 2 completed.
Epoch 3/4
Epoch 3 completed.
Epoch 4/4
Epoch 4 completed.
Pitting networks...
Old NN win rate: 47.73% (21 / 44 wins)
New NN win rate: 52.27% (23 / 44 wins)
The new network is better!
1742490197.191203
18.792184126377105
Training complete. Model saved as tictactoe_MCTS25_2025-03-20_17-03-17.pth


In [28]:
# Load trained weights for playing/evaluation.
# map_location lets a checkpoint written on a GPU be loaded on a CPU-only
# machine (and vice versa).
# NOTE(review): the training run above reports saving
# `tictactoe_MCTS25_2025-03-20_17-03-17.pth`; confirm 'temp.pth' is the
# checkpoint you intend to evaluate.
inference_model = UTTT()
inference_model.load_state_dict(torch.load('temp.pth', map_location=device))
inference_model.to(device)
# UTTT contains BatchNorm layers. In training mode they normalise each forward
# pass with that call's batch statistics and keep updating their running
# averages. eval() switches them to the learned running statistics. The call
# returns the module, so the cell still displays the model summary.
inference_model.eval()

UTTT(
  (conv1): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.09, affine=True, track_running_stats=True)
  (res2): ResidualBlock(
    (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.09, affine=True, track_running_stats=True)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.09, affine=True, track_running_stats=True)
  )
  (res3): ResidualBlock(
    (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.09, affine=True, track_running_stats=True)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.09, affine=True, track_running_stats=True)
  )
  (fc1): Lin

In [29]:
def state_to_tensor(state):
    """Turn a game state into a float32 network input on `device`.

    The array produced by `board_to_array(state)` gets two leading singleton
    axes: batch first, then channel. This matches the single input channel
    of UTTT's first convolution.
    """
    board = torch.tensor(board_to_array(state), dtype=torch.float32)
    batched = board[None, None]  # same as .unsqueeze(0).unsqueeze(1)
    return batched.to(device)

def minimax(state, depth, alpha, beta, maximizing_player, model):
    """
    Minimax algorithm with Alpha-Beta pruning.

    Leaves are scored with the value output of `model`. A leaf is a terminal
    state, or any node once the remaining depth is exhausted. Interior nodes
    take the max (maximizing player) or min (minimizing player) over their
    children.

    Args:
        state: The current game state.
        depth: Remaining search depth. Any value <= 0 is treated as a leaf.
            Without this, a caller passing ``depth - 1`` with ``depth == 0``
            would get an unbounded search down to terminal positions.
        alpha: Alpha value for pruning (best already found for maximizing player).
        beta: Beta value for pruning (best already found for minimizing player).
        maximizing_player: Boolean indicating if it's the maximizing player's turn.
        model: The neural network model that predicts (pi, v).

    Returns:
        The best score/evaluation for the given state, as produced by `model`
        (not converted to a Python number here).
    """
    if is_terminal(state) or depth <= 0:
        # Pure inference: no_grad avoids building an autograd graph for
        # every leaf of the search tree.
        with torch.no_grad():
            _, value = model(state_to_tensor(state))
        # NOTE(review): get_move treats v as being from the point of view of
        # the player to move in the evaluated state. Here v is compared as if
        # it were always from the maximizing player's point of view. Confirm
        # the value head's convention before relying on this search.
        return value

    valid_moves = get_all_valid_actions(state)

    if maximizing_player:
        max_eval = -float("inf")
        for action in valid_moves:
            new_state = change_state(state, action)
            score = minimax(new_state, depth - 1, alpha, beta, False, model)
            max_eval = max(max_eval, score)
            alpha = max(alpha, score)
            if beta <= alpha:  # Prune: the minimizer will never allow this line
                break
        return max_eval
    else:
        min_eval = float("inf")
        for action in valid_moves:
            new_state = change_state(state, action)
            score = minimax(new_state, depth - 1, alpha, beta, True, model)
            min_eval = min(min_eval, score)
            beta = min(beta, score)
            if beta <= alpha:  # Prune: the maximizer will never allow this line
                break
        return min_eval

def best_move(state, depth, model):
    """
    Choose an action by running alpha-beta minimax beneath every legal move.

    Args:
        state: Position to move from; the root is always treated as a max node.
        depth: Total search depth counting the root move (children are
            searched with ``depth - 1``).
        model: Network handed to `minimax` for leaf evaluation.

    Returns:
        The first action that reaches the highest minimax score, or None when
        `state` has no valid actions.
    """
    alpha, beta = -float("inf"), float("inf")
    top_action, top_score = None, -float("inf")

    for move in get_all_valid_actions(state):
        child = change_state(state, move)
        # A child is searched as a maximizing node only when player 1 is to move there.
        score = minimax(child, depth - 1, alpha, beta, child.fill_num == 1, model)
        if score > top_score:
            top_action, top_score = move, score
        # The root maximizes, so only alpha tightens as siblings are explored.
        alpha = max(alpha, score)

    return top_action


In [30]:
def get_move(state, model):
    """Greedy one-ply move choice using the network's value output.

    After a move the side to move flips, so the value returned for
    `next_state` is from the next player's point of view. The best move for
    us is therefore the one with the smallest predicted value. A move that
    ends the game with a win for the player to move is returned immediately,
    without consulting the network.

    Args:
        state: Position to move from.
        model: Network returning (pi, v) for a state tensor.

    Returns:
        The chosen action, or None when there are no valid actions.
    """
    valid_actions = get_all_valid_actions(state)
    # terminal_utility is scored from player 1's side (1 = player 1 wins,
    # 0 = player 2 wins); `run` prints results on that basis.
    winning_utility = 1 if state.fill_num == 1 else 0
    best_value, best_action = float('inf'), None
    # Pure inference: don't build an autograd graph for every candidate move.
    with torch.no_grad():
        for action in valid_actions:
            next_state = change_state(state, action)
            if is_terminal(next_state) and terminal_utility(next_state) == winning_utility:
                return action  # immediate win; no need to ask the network
            _, v = model(state_to_tensor(next_state))  # v is from next player POV due to flip
            v = float(v)
            if v < best_value:
                best_value, best_action = v, action
    return best_action
            

In [52]:
class StudentAgent:
    """Agent that picks its moves with the trained `inference_model`.

    Moves are chosen greedily through `get_move`: one network evaluation per
    candidate move, with no deeper search.
    """

    def __init__(self):
        """Instantiates your agent; no per-instance state is required."""

    def choose_action(self, state: State) -> Action:
        """Return a valid action to be played on `state`.

        This agent assumes it fills the board with the number 1.

        Parameters
        ---------------
        state: The board to make a move on.
        """
        # A search-based alternative is `best_move(state, 1, inference_model)`.
        chosen = get_move(state, inference_model)
        return chosen

# Test harness: play two full games of your agent against a random baseline.
# The baseline picks uniformly among the valid actions.

class RandomStudentAgent(StudentAgent):
    """Opponent that plays a random valid move every turn."""

    def choose_action(self, state: State) -> Action:
        # An existing Player 1 agent reused as Player 2 would first need
        # `state = state.invert()`. A random mover is indifferent to its side.
        random_move = state.get_random_valid_action()
        return random_move

def run(your_agent: StudentAgent, opponent_agent: StudentAgent, start_num: int):
    """Play one complete game between two agents and print a summary.

    `your_agent` moves whenever fill_num is 1; `opponent_agent` moves
    otherwise. `start_num` chooses who moves first. If an agent takes longer
    than 3 seconds, or returns an action the state rejects, its move is
    replaced by a random valid action and the incident is counted. The
    result is reported from `your_agent`'s side: utility 1 is a win, 0 is a
    loss, anything else is a draw.
    """
    labels = ("your_agent", "opponent_agent")
    players = (your_agent, opponent_agent)
    counters = tuple({"timeout_count": 0, "invalid_count": 0} for _ in players)

    turns_played = 0
    state = State(fill_num=start_num)

    while not state.is_terminal():
        turns_played += 1
        idx = 0 if state.fill_num == 1 else 1
        label, player, counter = labels[idx], players[idx], counters[idx]

        tic = time.time()
        action = player.choose_action(state.clone())
        elapsed = time.time() - tic

        # A random fallback is drawn every turn, whether or not it is used.
        fallback = state.get_random_valid_action()
        if elapsed > 3:
            print(f"{label} timed out!")
            counter["timeout_count"] += 1
            action = fallback
        if not state.is_valid_action(action):
            print(f"{label} made an invalid action!")
            counter["invalid_count"] += 1
            action = fallback

        state = state.change_state(action)

    print(f"== {your_agent.__class__.__name__} (1) vs {opponent_agent.__class__.__name__} (2) - First Player: {start_num} ==")

    utility = state.terminal_utility()
    if utility == 1:
        print("You win!")
    elif utility == 0:
        print("You lose!")
    else:
        print("Draw")

    for label, counter in zip(labels, counters):
        print(f"{label} statistics:")
        print(f"Timeout count: {counter['timeout_count']}")
        print(f"Invalid count: {counter['invalid_count']}")

    print(f"Turn count: {turns_played}\n")

def your_agent():
    """Build a fresh StudentAgent for one game."""
    return StudentAgent()


def opponent_agent():
    """Build a fresh RandomStudentAgent for one game."""
    return RandomStudentAgent()


# One game with each side moving first.
for first_player in (1, 2):
    run(your_agent(), opponent_agent(), first_player)

== StudentAgent (1) vs RandomStudentAgent (2) - First Player: 1 ==
Draw
your_agent statistics:
Timeout count: 0
Invalid count: 0
opponent_agent statistics:
Timeout count: 0
Invalid count: 0
Turn count: 60

== StudentAgent (1) vs RandomStudentAgent (2) - First Player: 2 ==
Draw
your_agent statistics:
Timeout count: 0
Invalid count: 0
opponent_agent statistics:
Timeout count: 0
Invalid count: 0
Turn count: 58



In [37]:
# Sanity check: play out a random game and run the model on every position,
# confirming each intermediate state is accepted. The evaluation of the
# final position is displayed.
s = State()
while True:
    print(s)
    evaluation = inference_model(state_to_tensor(s))
    if is_terminal(s):
        break
    action = s.get_random_valid_action()
    print(action)
    s = s.change_state(action)
evaluation

State(
    board=
        0 0 0 | 0 0 0 | 0 0 0
        0 0 0 | 0 0 0 | 0 0 0
        0 0 0 | 0 0 0 | 0 0 0
        ---------------------
        0 0 0 | 0 0 0 | 0 0 0
        0 0 0 | 0 0 0 | 0 0 0
        0 0 0 | 0 0 0 | 0 0 0
        ---------------------
        0 0 0 | 0 0 0 | 0 0 0
        0 0 0 | 0 0 0 | 0 0 0
        0 0 0 | 0 0 0 | 0 0 0, 
    local_board_status=
        [[0 0 0]
         [0 0 0]
         [0 0 0]], 
    prev_local_action=None, 
    fill_num=1
)

(0, 0, 1, 2)
State(
    board=
        0 0 0 | 0 0 0 | 0 0 0
        0 0 1 | 0 0 0 | 0 0 0
        0 0 0 | 0 0 0 | 0 0 0
        ---------------------
        0 0 0 | 0 0 0 | 0 0 0
        0 0 0 | 0 0 0 | 0 0 0
        0 0 0 | 0 0 0 | 0 0 0
        ---------------------
        0 0 0 | 0 0 0 | 0 0 0
        0 0 0 | 0 0 0 | 0 0 0
        0 0 0 | 0 0 0 | 0 0 0, 
    local_board_status=
        [[0 0 0]
         [0 0 0]
         [0 0 0]], 
    prev_local_action=(1, 2), 
    fill_num=2
)

(1, 2, 2, 0)
State(
    board=
    

(1, 1, 2, 2)
State(
    board=
        0 2 2 | 0 0 2 | 0 1 2
        2 0 1 | 0 0 0 | 2 1 0
        1 0 2 | 1 1 1 | 1 1 2
        ---------------------
        0 1 1 | 1 1 0 | 1 1 0
        1 2 0 | 2 0 0 | 0 0 0
        2 0 1 | 2 0 1 | 2 0 0
        ---------------------
        1 2 2 | 1 0 1 | 2 2 1
        2 0 2 | 0 2 0 | 1 2 2
        1 1 0 | 0 2 2 | 0 2 1, 
    local_board_status=
        [[0 1 1]
         [0 0 0]
         [0 0 2]], 
    prev_local_action=(2, 2), 
    fill_num=2
)

(2, 0, 2, 2)
State(
    board=
        0 2 2 | 0 0 2 | 0 1 2
        2 0 1 | 0 0 0 | 2 1 0
        1 0 2 | 1 1 1 | 1 1 2
        ---------------------
        0 1 1 | 1 1 0 | 1 1 0
        1 2 0 | 2 0 0 | 0 0 0
        2 0 1 | 2 0 1 | 2 0 0
        ---------------------
        1 2 2 | 1 0 1 | 2 2 1
        2 0 2 | 0 2 0 | 1 2 2
        1 1 2 | 0 2 2 | 0 2 1, 
    local_board_status=
        [[0 1 1]
         [0 0 0]
         [2 0 2]], 
    prev_local_action=(2, 2), 
    fill_num=1
)

(1, 2, 2, 1)
State(


(tensor([[2.2859e-03, 3.4782e-09, 1.0948e-02, 2.7600e-06, 1.1569e-06, 7.1274e-02,
          2.4060e-08, 1.0381e-09, 2.0469e-01, 1.0673e-05, 6.2828e-08, 5.8707e-03,
          7.5934e-08, 9.2144e-08, 5.8572e-02, 1.3714e-09, 1.7948e-09, 5.2803e-03,
          2.1606e-08, 5.7949e-02, 1.1157e-01, 3.4340e-08, 2.7338e-03, 1.2535e-03,
          3.5510e-09, 3.2869e-05, 1.4927e-03, 1.3997e-08, 2.4681e-03, 1.5641e-07,
          9.5258e-09, 1.7156e-03, 1.7313e-08, 2.1413e-08, 1.1071e-02, 8.5786e-02,
          1.1501e-07, 4.2648e-02, 2.8255e-03, 4.1153e-09, 3.8017e-03, 3.0554e-06,
          1.7885e-02, 1.6206e-08, 1.6770e-02, 2.5489e-03, 2.0295e-09, 3.7610e-02,
          2.3160e-04, 7.7432e-09, 3.3313e-09, 8.0548e-03, 1.4141e-05, 4.2675e-04,
          3.6534e-09, 3.5217e-09, 2.7222e-06, 7.4703e-09, 1.1672e-05, 5.9299e-04,
          1.2366e-01, 3.3770e-03, 6.0208e-04, 8.8406e-05, 3.9605e-09, 1.0991e-02,
          6.3034e-03, 1.0582e-03, 1.3227e-02, 8.9470e-09, 1.5516e-09, 3.0423e-09,
          2.6925