In [97]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import numpy as np
from collections import defaultdict
import os
from datetime import datetime
import time
from functools import lru_cache
from multiprocessing import Pool
import concurrent.futures

from utils import State, Action, is_terminal, change_state, get_all_valid_actions, terminal_utility, is_valid_action, invert

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

#we can stack arbitrary number of Residualblock
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels, eps=1e-5, momentum=0.09)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels, eps=1e-5, momentum=0.09)

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return F.relu(out + residual)  # Skip connection

#if using batchnorm, bias=False, can add additional FC/conv layer also
class UTTT(nn.Module):
    def __init__(self, residual=True, size=0, layers="MAX", batchnorm=True):
        super(UTTT, self).__init__()
        
        self.channels = 128 // (2**size)
        self.residual = residual
        self.fc_size1 = 512 if not size else 256 if layers=="MAX" else 128
        self.fc_size2 = 256 if not size else 128
        self.layers = layers
        self.batchnorm = batchnorm
        
        # Convolutional layers
        self.conv1 = nn.Conv2d(1, self.channels, kernel_size=3, stride=1, padding=1, bias= not self.batchnorm)
        self.bn1 = nn.BatchNorm2d(self.channels, eps=1e-5, momentum=0.09) if self.batchnorm else None
        
        self.conv2 = nn.Conv2d(self.channels, self.channels, kernel_size=3, stride=1, padding=1, bias= not self.batchnorm)
        self.bn2 = nn.BatchNorm2d(self.channels, eps=1e-5, momentum=0.09) if self.batchnorm else None
        self.conv3 = nn.Conv2d(self.channels, self.channels, kernel_size=3, stride=1, padding=1, bias= not self.batchnorm) if self.layers=="MAX" else None
        self.bn3 = nn.BatchNorm2d(self.channels, eps=1e-5, momentum=0.09) if self.layers=="MAX" and self.batchnorm else None
        
        self.res2 = ResidualBlock(self.channels) if self.residual else None
        self.res3 = ResidualBlock(self.channels) if self.residual else None

        # Fully connected layers
        self.fc1 = nn.Linear(self.channels * 9 * 9, self.fc_size1)  # Input size: 128 filters × 9 × 9
        self.fc2 = nn.Linear(self.fc_size1, self.fc_size2) if self.layers=="MAX" else None

        # Output layers
        self.pi = nn.Linear(self.fc_size2, 81)  # Policy output (Softmax for move probabilities)
        self.v = nn.Linear(self.fc_size2, 1)    # Value output (Tanh for game state evaluation)
    
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x))) if self.batchnorm else F.relu(self.conv1(x))

        # Residual connections
        if self.residual:
            x = self.res2(x)
            x = self.res3(x)
        else:
            x = F.relu(self.bn2(self.conv2(x))) if self.batchnorm else F.relu(self.conv2(x))
            if self.layers=="MAX":
                x = F.relu(self.bn3(self.conv3(x))) if self.batchnorm else F.relu(self.conv3(x))

        # Flatten and pass through FC layers
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        if self.layers=="MAX":
            x = F.relu(self.fc2(x))

        # Outputs
        pi = F.softmax(self.pi(x), dim=1)  # Policy head (logits, apply softmax in loss function)
        v = F.tanh(self.v(x))  # Value head (range [-1, 1])

        return pi, v

    
# Instantiate the model
model = UTTT()
model.to(device)

# Define the optimizer
def get_optimizer(model, lr=0.0001):
    return optim.Adam(model.parameters(), lr=lr)

# Define loss functions
def compute_policy_loss(predicted_policy, true_policy):
    #return nn.CrossEntropyLoss()(predicted_policy, true_policy)
    policy_loss_fn = torch.nn.CrossEntropyLoss()  # For policy (classification)
    return policy_loss_fn(predicted_policy, true_policy)

def compute_value_loss(predicted_value, true_value):
    #return nn.MSELoss()(predicted_value, true_value)
    value_loss_fn = torch.nn.MSELoss()  # For value (regression)
    return value_loss_fn(predicted_value, true_value)

# Print model summary
print(model)


UTTT(
  (conv1): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.09, affine=True, track_running_stats=True)
  (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.09, affine=True, track_running_stats=True)
  (conv3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn3): BatchNorm2d(128, eps=1e-05, momentum=0.09, affine=True, track_running_stats=True)
  (res2): ResidualBlock(
    (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.09, affine=True, track_running_stats=True)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.09, affine=True, track_running_stats=True)
  )
  (res3): ResidualBlock(
    (conv1): Conv2d(128,

In [16]:
def board_to_array(state):
    """
    Convert the Ultimate Tic-Tac-Toe board state into a 9x9 array, making sure it's hashable.
    
    Args:
        state: The current game state (State object).
        mini_board: The index of the mini-board where the next move should be made.
        current_player: The player whose turn it is (1 or -1).
    
    Returns:
        A NumPy array (2, 9, 9) representation of the board (with current player layer).
    """
    player = 1 if state.fill_num==1 else -1
    
    # Get the board from the state
    board = state.board  # This is a 3x3x3x3 ndarray

    # Flatten the 3x3x3x3 board into a 9x9 representation
    board_array = np.zeros((9, 9), dtype=np.float32)

    for i in range(3):
        for j in range(3):
            for k in range(3):
                for l in range(3):
                    board_array[i * 3 + k][j * 3 + l] = board[i][j][k][l]

    # Normalize the board values with flip
    board_array[board_array == 1] = 1.0 * player
    board_array[board_array == 2] = -1.0 * player
    board_array[board_array == 0] = 0.0

    return np.array(board_array)

def action_to_index(action):
    """
    Convert the action tuple (meta_row, meta_col, local_row, local_col) into an integer index for the 9x9 grid.

    Args:
        action: Tuple (meta_row, meta_col, local_row, local_col) representing the coordinates on the boards.
    
    Returns:
        Integer index corresponding to the 9x9 grid.
    """
    meta_row, meta_col, local_row, local_col = action
    
    # Map meta-board (meta_row, meta_col) to a flat index in the range [0, 8]
    meta_index = meta_row * 3 + meta_col  # 3x3 meta-board
    
    # Map local-board (local_row, local_col) to a flat index in the range [0, 8]
    local_index = local_row * 3 + local_col  # 3x3 local-board
    
    # Final index in the flattened 9x9 grid
    return meta_index * 9 + local_index  # 9x9 flattened grid

def index_to_action(index: int) -> Action:
    """
    Convert an action index back to the action tuple (meta_row, meta_col, local_row, local_col).
    
    Args:
        index (int): The action index.
        
    Returns:
        Action: The corresponding action tuple (meta_row, meta_col, local_row, local_col).
    """
    # Calculate the meta-row and meta-column from the index
    meta_row = index // 27
    meta_col = (index % 27) // 9
    
    # Calculate the local row and local column from the index
    local_row = (index % 9) // 3
    local_col = index % 3
    
    return (meta_row, meta_col, local_row, local_col)

def normalize_policy(policy):
    """
    Normalize the policy and ensure no NaN or infinite values.
    
    Args:
        policy: A numpy array containing the policy probabilities for each action.
        valid_mask: A mask array where valid actions are 1, invalid actions are 0.
    
    Returns:
        A normalized policy.
    """    
    # Check if any valid actions are left
    if np.sum(policy) > 0:
        policy = policy / np.sum(policy)  # Normalize
    else:
        # If no valid actions, reset to uniform distribution over valid actions
        policy = np.ones_like(policy) / len(policy)
    
    # Check for NaN or infinite values in the policy and reset if needed
    if np.any(np.isnan(policy)) or np.any(np.isinf(policy)):
        policy = np.ones_like(policy) / len(policy)  # Reset to uniform distribution
    
    return policy

def state_to_tensor(state):
    return torch.tensor(board_to_array(state), dtype=torch.float32).unsqueeze(0).unsqueeze(1).to(device)  # Add batch and channel dimensions


In [17]:
class Node:
    __slots__ = ('state', 'children', 'N', 'W', 'Q', 'P')
    def __init__(self, state):
        self.state = state
        #self.parent = parent
        #self.action = action
        #self.action_probability = None
        self.children = []
        
        self.N = 0
        #self.visit_count = 0
        
        self.W = 0
        #self.total_reward = 0 self.state_value_sum = 0.0
        
        self.Q = 0
        #self.q_value = 0 self.state_value_mean = 0.0
        
        self.P = None
        #self.prior_policy = None self.state_value = None

    def is_leaf(self):
        return len(self.children) == 0
    
    def is_terminal(self):
        return is_terminal(self.state)
    
    def no_valid_actions(self):
        return len(get_all_valid_actions(self.state)) == 0
    
    def best_child(self, cpuct): #cpuct=exploration_weight
        #return max(self.children, key=lambda child: child.W / child.N + cpuct * child.P * math.sqrt(math.log(self.N) / (1 + child.N)))
        #return max(self.children, key=lambda child: child.Q + cpuct * child.P * math.sqrt(self.N) / (1 + child.N))
        #values = np.array([child.W / child.N + cpuct * child.P * np.sqrt(np.log(self.N) / (1 + child.N)) for child in self.children])
        #values = np.array([child.Q + cpuct * child.P * np.sqrt(self.N) / (1 + child.N) for child in self.children])
        values = np.array([-child.Q + cpuct * child.P * np.sqrt(self.N) / (1 + child.N) for child in self.children])
        best_idx = np.argmax(values)
        return self.children[best_idx], best_idx
    
    def best_action(self, cpuct):
        return get_all_valid_actions(self.state)[self.best_child(cpuct)[1]]

    
class MCTS:
    __slots__ = ('model', 'cpuct', 'tree')
    def __init__(self, model, cpuct=2.0):
        self.model = model
        self.cpuct = cpuct
        self.tree = {}
    
    def search(self, root_state, simulations=200):
        root = self.get_or_create_node(root_state)
        
        if root.is_terminal() or root.no_valid_actions():
            return
        
        for _ in range(simulations):
            node = root
            path = [node]
            
            # Selection
            while not node.is_leaf():
                node = node.best_child(self.cpuct)[0]
                path.append(node)
            
            pi, v = self.model(self.get_state_tensor(node.state))
            v = v.item()
            
            # Expansion
            if not node.is_terminal() and not node.no_valid_actions():
                valid_actions = get_all_valid_actions(node.state)
                self.expand_node(node, valid_actions, pi)
            
            # Evaluation
            v = self.evaluate(node, v)
            
            # Backpropagation
            for node in reversed(path):
                node.N += 1
                node.W += v
                node.Q = node.W / node.N
                v *= -1
            #print(path)
        return root.best_action(self.cpuct)
    
    def expand_node(self, node, valid_actions, pi):
        valid_mask = np.zeros(81)
        action_idxs = [action_to_index(a) for a in valid_actions]
        valid_mask[action_idxs] = 1
        pi = pi.detach().cpu().numpy().reshape(81) * valid_mask  # Convert to numpy and apply valid mask
        pi = pi / np.sum(pi)  # Normalize the policy to sum to 1
        for action, action_idx in zip(valid_actions, action_idxs):
            child_state = change_state(node.state, action)
            child_node = Node(child_state)
            child_node.P = pi[action_idx]
            node.children.append(child_node)

    def evaluate(self, node, v):
        if is_terminal(node.state):
            #return 2*terminal_utility(node.state)-1
            #return 0 if 2*terminal_utility(node.state)-1==0 else 1
            return 0 if 2*terminal_utility(node.state)-1==0 else -1
            #remember v is utility to current player. so terminal state, current player to move loses so -1
            #otherwise, model predicts v which is utiilty for current player
            #then we need to use -child.Q since we are maximizing (from POV of current player)
            #i.e. current player wants to find most -ve child.Q since that is worst for child -> best for current
        return v
    
    def get_or_create_node(self, state):
        state_tuple = tuple(map(tuple, board_to_array(state)))
        if state_tuple not in self.tree:
            self.tree[state_tuple] = Node(state)
        return self.tree[state_tuple]
    
    #@lru_cache(maxsize=10000)
    def get_state_tensor(self, state):
        return torch.tensor(board_to_array(state), dtype=torch.float32).unsqueeze(0).unsqueeze(1).to(device)


In [18]:
#hyperparameters
train_episodes = 100
mcts_search = 200 #100 #400 #600
n_pit_network = 50 #20
threshold = 0.52 #0.50 #0.55
temperature = 0.05 #lower is more deterministic
playgames_before_training = 2 #4 #5 #25
parallel_games = 8
cpuct = 2
training_epochs = 4
learning_rate = 0.0001
save_model_path = 'training'


In [19]:
def get_action_probs(state, mcts, simulations=200):
    """
    Get the action probabilities from the MCTS search for a given board state.
    
    Args:
        state: The initial game state (State object).
        mcts: The MCTS instance used for search.
        simulations: Number of MCTS simulations to run.
    
    Returns:
        action_probs: A numpy array of size 81 representing the action probabilities.
    """
    # Perform MCTS search
    #best_action = mcts.search(state, simulations)
    mcts.search(state, simulations)
    
    print("Done one iteration of MCTS")
    
    # Initialize action probabilities
    action_probs = np.zeros(81)
    
    # Get visit counts for valid actions
    root_node = mcts.get_or_create_node(state)
    #child_Ns = {child.action: child.N for child in root_node.children}
    child_Ns = {get_all_valid_actions(root_node.state)[idx]: child.N for idx, child in enumerate(root_node.children)}
    
    if root_node.N > 0:
        for action, N in child_Ns.items():
            action_probs[action_to_index(action)] = N / root_node.N
    #print(state, action_probs)
    return action_probs


In [20]:
def playgame(mcts, simulations=200):
    """
    Simulate one game of Ultimate Tic-Tac-Toe, utilizing MCTS for decision-making.
    
    Args:
        mcts: The MCTS instance used for decision-making.
        simulations: Number of MCTS simulations per move.
    
    Returns:
        game_mem: A list of game memory (state, player, action probability, game result).
    """
    game_mem = []
    state = State()
    
    while True:
        #print(state)
        # Check if the game is over
        if len(get_all_valid_actions(state)) == 0 or is_terminal(state):
            #need to include terminal state? but what policy to put?
            #remember model predicts utility for current player, so
            #if result==1 (p1 win) and current player==1 then 1*1=1
            #if result==1 (p1 win) and current player==2 then 1*-1=-1
            #if result==-1 (p2 win) and current player==1 then -1*1=-1
            #if result==-1 (p2 win) and current player==2 then -1*-1=1
            result = 2*terminal_utility(state)-1
            for mem in game_mem:
                mem[3] = result * (1 if mem[1]==1 else -1)
            return game_mem

        # Get action probabilities using MCTS
        policy = get_action_probs(state, mcts, simulations)
        policy = policy / np.sum(policy)  # Normalize the policy
        #game_mem.append([board_to_array(state), state.fill_num, policy, None])
        game_mem.append([state, state.fill_num, policy, None])
        
        # Choose an action based on the policy
        action_index = np.random.choice(len(policy), p=policy)
        action = index_to_action(action_index)
        
        #print("Policy:", policy)
        #print("Chosen action:", action)
        
        state = change_state(state, action)


In [21]:
def get_action_probs_batch(states, mcts, simulations):
    """
    Get action probabilities from MCTS for a batch of states.
    
    Args:
        states: List of `State` objects.
        mcts: The MCTS instances used for search.
        simulations: Number of MCTS simulations per move.
    
    Returns:
        action_probs: A NumPy array of shape (num_games, 81).
    """
    num_games = len(states)
    action_probs = np.zeros((num_games, 81))

    # Run MCTS searches in parallel
    with Pool(processes=num_games) as pool:
        best_actions = pool.starmap(mcts.search, [(state, simulations) for state in states])

    print("Done one iteration of MCTS")

    # Compute visit counts for valid actions
    for i, state in enumerate(states):
        root_node = mcts.get_or_create_node(state)
        child_Ns = {get_all_valid_actions(root_node.state)[idx]: child.N for idx, child in enumerate(root_node.children)}

        if root_node.N > 0:
            for action, N in child_Ns.items():
                action_probs[i, action_to_index(action)] = N / root_node.N
    
    return action_probs


def playgames(mctss, num_games=8, simulations=200):
    """
    Run multiple games in parallel using batched MCTS.
    
    Args:
        mctss: List of MCTS instances
        num_games: Number of games to run in parallel.
        simulations: Number of MCTS simulations per move.
    
    Returns:
        games_mem: A list containing `num_games` game memory logs.
    """
    games_mem = [[] for _ in range(num_games)]
    states = [State() for _ in range(num_games)]
    active_games = np.ones(num_games, dtype=bool)  # Track which games are still running

    while active_games.any():
        # Apply actions and check game termination
        for i in range(num_games):
            if active_games[i]:
                if len(get_all_valid_actions(states[i])) == 0 or is_terminal(states[i]):
                    #remember model predicts utility for current player, so
                    #if result==1 (p1 win) and current player==1 then 1*1=1
                    #if result==1 (p1 win) and current player==2 then 1*-1=-1
                    #if result==-1 (p2 win) and current player==1 then -1*1=-1
                    #if result==-1 (p2 win) and current player==2 then -1*-1=1
                    result = 2*terminal_utility(states[i])-1
                    for mem in games_mem[i]:
                        mem[3] = result * (1 if mem[1]==1 else -1)
                    active_games[i] = False  # Mark game as finished
        
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_games) as executor:
            results = executor.map(lambda mcts,state: get_action_probs(state,mcts,simulations), mctss, states)
        
        policies = list(results)
        
        # Get action probabilities using batched MCTS
        #policies = get_action_probs_batch(states, mctss, simulations)
        #policies = np.nan_to_num([policy/np.sum(policy) for policy in policies], nan=1/81)
        
        # Sample actions using the policies
        #action_indices = np.array([np.random.choice(81, p=policy) for policy in policies])
        #actions = [index_to_action(int(idx)) for idx in action_indices]
        
        # Save state, player, policy in memory
        for i in range(num_games):
            if active_games[i]:
                policy = policies[i] / np.sum(policies[i])  # Normalize policy
                #games_mem[i].append([board_to_array(states[i]), states[i].fill_num, policies[i], None])
                #games_mem[i].append([board_to_array(states[i]), states[i].fill_num, policy, None])
                games_mem[i].append([states[i], states[i].fill_num, policy, None])
                action_index = np.random.choice(len(policy), p=policy)
                action = index_to_action(action_index)
                #states[i] = change_state(states[i], actions[i])
                states[i] = change_state(states[i], action)
    
    return games_mem


In [22]:
def test_playgame():
    mcts = MCTS(model, cpuct)
    start_time = time.time()
    p=[]
    for _ in range(parallel_games):
        p.append(playgame(mcts))
    end_time = time.time()
    print(end_time-start_time)

In [23]:
def test_playgames():
    mctss = [MCTS(model, cpuct) for _ in range(parallel_games)]
    start_time=time.time()
    ps=playgames(mctss,parallel_games)
    end_time=time.time()
    print(end_time-start_time)

In [24]:
def train_model(model, games_mem, epochs=4, batch_size=32, lr=0.0001):
    """
    Train the neural network using the game memory (states, policies, and results).
    
    Args:
        model: The neural network model to train.
        games_mem: A list of list of game memories, each containing [state_array, current_player, policy, result].
    
    Returns:
        model: The trained model.
    """
    print("Training Network")
    sumlen =sum([len(game_mem) for game_mem in games_mem])
    print("Length of game_mem:", sumlen)
    
    states = []
    policies = []
    values = []

    # Prepare the training data from game memory
    for game_mem in games_mem:
        for mem in game_mem:
            # Extract state, policy, and result (value)
            states.append(mem[0])  # mem[0] is the board state
            policies.append(mem[2])  # mem[2] is the action policy
            values.append(mem[3])   # mem[3] is the game result (1, 0, or -1)
        
    states = torch.stack([state_to_tensor(state) for state in states])
    policies = torch.tensor(policies, dtype=torch.float32).to(device)
    values = torch.tensor(values, dtype=torch.float32).to(device)
    
    #state = np.array(state)  # Convert the list of states to a numpy array
    #policy = np.array(policy)  # Convert the list of policies to a numpy array
    #value = np.array(value)  # Convert the list of values to a numpy array
        
    #states = torch.tensor(np.array(state), dtype=torch.float32).to(device)
    #policies = torch.tensor(np.array(policy), dtype=torch.float32).to(device)
    #values = torch.tensor(np.array(value), dtype=torch.float32).to(device)

    # Train the neural network on the collected data
    #states = states.unsqueeze(0)  # Add channel dimension, resulting shape: [batch_size, 1, 9, 9]
    optimizer = get_optimizer(model,lr)
    #Training loop
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")

        # Shuffle the game memory for each epoch
        indices = torch.randperm(sumlen)
        states = states[indices]
        policies = policies[indices]
        values = values[indices]

        # Mini-batch training
        for i in range(0, sumlen, batch_size):
            batch_states = states[i:i + batch_size]
            batch_policies = policies[i:i + batch_size]
            batch_values = values[i:i + batch_size]

            # Zero the gradients
            #optimizer.zero_grad()
            for param in model.parameters:
                param.grad=None
            
            total_loss = 0
            
            for j in range(batch_states.size(0)):
                # Get the single state for the current iteration
                #state_single = batch_states[j].unsqueeze(0).unsqueeze(1)  # Shape: [1, 1, 9, 9]
                #policy_single = batch_policies[j].unsqueeze(0)  # Shape: [1, 81]
                #value_single = batch_values[j]  # Shape: [1]
                state_single = batch_states[j]
                policy_single = batch_policies[j].unsqueeze(0)
                value_single = batch_values[j].unsqueeze(0).unsqueeze(1)
                
                # Forward pass for the single state
                predicted_policy, predicted_value = model(state_single)
                
                #print(predicted_policy.shape, policy_single.shape)
                #print(predicted_value.shape, value_single.shape)

                # Compute loss for the single state
                policy_loss = compute_policy_loss(predicted_policy, policy_single)
                value_loss = compute_value_loss(predicted_value, value_single)

                # Accumulate the total loss
                total_loss += (policy_loss + value_loss)

            # Backpropagation
            total_loss.backward()

            # Optimizer step (update weights)
            optimizer.step()

        print(f"Epoch {epoch + 1} completed.")
    return model

In [25]:
def pit(old_model, new_model):
    """
    Pits the old neural network (old_model) against the new one (new_model).
    The new network must win at least 52% of games to be accepted.

    Args:
        old_model: The old neural network.
        new_model: The newly trained neural network.

    Returns:
        True if the new network is better (win rate >= 52%), otherwise False.
    """
    print("Pitting networks...")

    old_model_wins = 0
    new_model_wins = 0
    total_games = n_pit_network
    nets = [None, old_model, new_model]

    for game in range(total_games):
        state = State()  # Start with an empty board
        mover = 1 if game%2==0 else -1

        while True:
            # Select which network plays
            net = nets[mover]
            
            # Get action probabilities from the network
            pi, _ = net(state_to_tensor(state))

            # Mask invalid actions
            valid_actions = get_all_valid_actions(state)
            if not valid_actions:
                break  # No valid actions left, game is a tie
   
            valid_mask = np.zeros(81)
            action_idxs = [action_to_index(a) for a in valid_actions]
            valid_mask[action_idxs] = 1
            pi = pi.detach().cpu().numpy().reshape(81) * valid_mask  # Convert to numpy and apply valid mask
            pi = pi / np.sum(pi)  # Normalize the policy to sum to 1
            
            #the problem here is that in inference mode, you should take best move
            #however, the normal alphazero method is to train the model to predict pi, v to improve MCTS
            #then during gameplay to use the best move from the improved MCTS
            #what is done here is instead to do MCTS and then use these to train the model to predict pi, v
            #and to use the best move from the predicted policy
            #since we need to train the model to predict v for use in minimax
            #therefore we cannot use best move here since best move will be deterministic
            # Choose the best move
            #action_index = int(np.argmax(policy))
            #action = torch.argmax(policy).item()
            action_index = np.random.choice(len(pi), p=pi) #or add dirichlet noise and take max
            
            # Execute the move
            #print(state, action_index)
            state = change_state(state, index_to_action(action_index))
            
            if is_terminal(state):
                #print(state)
                if terminal_utility(state) == 1.0:  # Player 1 wins
                    old_model_wins += 1 if game%2==0 else 0
                    new_model_wins += 1 if game%2==1 else 0
                if terminal_utility(state) == 0.0:  # Player 2 wins
                    old_model_wins += 1 if game%2==1 else 0
                    new_model_wins += 1 if game%2==0 else 0
                break  # Game over

            # Switch players            
            mover *= -1

    total_wins = old_model_wins + new_model_wins
    
    if total_wins == 0:
        print("All games ended in a tie.")
        now = datetime.utcnow().strftime("%Y-%m-%d_%H-%M-%S")
        filename = f'tictactoeTie{now}.pth'
        model_path = os.path.join(save_model_path, filename)
        torch.save(old_model.state_dict(), model_path)
        return False

    old_model_win_percent = old_model_wins / total_wins
    new_model_win_percent = new_model_wins / total_wins
    print(f"Old NN win rate: {old_model_win_percent:.2%} ({old_model_wins} / {total_wins} wins)")
    print(f"New NN win rate: {new_model_win_percent:.2%} ({new_model_wins} / {total_wins} wins)")

    if new_model_win_percent >= threshold:
        print("The new network is better!")
        now = datetime.utcnow().strftime("%Y-%m-%d_%H-%M-%S")
        filename = f'tictactoeWin{now}.pth'
        model_path = os.path.join(save_model_path, filename)
        torch.save(new_model.state_dict(), model_path)
        return True
    else:
        print("The new network lost.")
        return False


In [26]:
#model = UTTT()
#model.to(device)

def train(episodes=100, num_games=8, simulations=200):
    """
    Trains the neural network using self-play, MCTS, and reinforcement learning.
    Saves the best model based on self-play evaluations.
    """
    total_start_time = time.time()
    
    if os.path.isfile('temp.pth'):
        print("Resuming training from last checkpoint")
        model.load_state_dict(torch.load('temp.pth', map_location=device))
    else:
        print("Starting training from scratch")

    print("Starting training...")

    games_mem = []
    
    for episode in range(episodes):
        print(f"Episode {episode + 1}/{episodes}")
        start_time = time.time()
        #print(start_time)
        
        # Save the current model as a temporary model
        torch.save(model.state_dict(), 'temp.pth')
        
        # Load the saved model
        #old_model = nn.__class__()  # Instantiate a new model of the same class
        old_model = UTTT()
        old_model.load_state_dict(torch.load('temp.pth'))
        old_model.to(device)
        
        mctss = [MCTS(model, cpuct) for _ in range(parallel_games)]

        # Self-play to generate training data
        #for _ in range(playgames_before_training):
            #games_mem += playgame()
        games_mem += playgames(mctss, num_games=num_games, simulations=simulations)

        # Train the network with collected data
        train_model(model, games_mem)

        # Clear memory after training
        games_mem = []

        # Compare old vs. new network through self-play
        if pit(old_model, model):
            del old_model
        else:
            # If new NN is worse, revert back to old NN
            #model.load_state_dict(torch.load('temp.pth'))
            model.load_state_dict(old_model.state_dict())
            del old_model
        
        end_time = time.time()
        #print(end_time)
        
        #print((end_time-start_time)/60)
        print(f"Episode {episode + 1} took {(end_time - start_time) / 60} minutes.")

    # Save the final trained model
    now = datetime.utcnow().strftime("%Y-%m-%d_%H-%M-%S")
    filename = f'tictactoe_MCTS{episodes}_{now}.pth'
    model_path = os.path.join(save_model_path, filename)
    torch.save(model.state_dict(), model_path)
    torch.save(model.state_dict(), 'temp.pth')
    print(f"Training complete. Model saved as {filename}")
    
    total_end_time = time.time()
    print(f"Training {episodes} took {(total_end_time - total_start_time) / 3600} hours.")

In [14]:
train(50)

Resuming training from last checkpoint
Starting training...
Episode 1/20


KeyboardInterrupt: 

In [35]:
inference_model = UTTT()
inference_model.load_state_dict(torch.load('temp.pth'))
inference_model.to(device)

UTTT(
  (conv1): Conv2d(1, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.09, affine=True, track_running_stats=True)
  (res2): ResidualBlock(
    (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.09, affine=True, track_running_stats=True)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.09, affine=True, track_running_stats=True)
  )
  (res3): ResidualBlock(
    (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.09, affine=True, track_running_stats=True)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.09, affine=True, track_running_stats=True)
  )
  (fc1): Lin

In [36]:
def get_move(state, model):
    valid_actions = get_all_valid_actions(state)
    ev,best_action = float('inf'),None
    for action in valid_actions:
        next_state=change_state(state,action)
        _,v=model(state_to_tensor(next_state)) #this gives v from next player POV due to flip #v measures how good for the CURRENT PLAYER
        if v<ev: ev,best_action = v,action
    return best_action
            

In [38]:
class StudentAgent:
    def __init__(self):
        """Instantiates your agent.
        """

    def choose_action(self, state: State) -> Action:
        """Returns a valid action to be played on the board.
        Assuming that you are filling in the board with number 1.

        Parameters
        ---------------
        state: The board to make a move on.
        """
        #return best_move(state, 1, inference_model)
        return get_move(state, inference_model)

# Use this cell to test your agent in two full games against a random agent.
# The random agent will choose actions randomly among the valid actions.

class RandomStudentAgent(StudentAgent):
    def choose_action(self, state: State) -> Action:
        # If you're using an existing Player 1 agent, you may need to invert the state
        # to have it play as Player 2. Uncomment the next line to invert the state.
        # state = state.invert()

        # Choose a random valid action from the current game state
        return state.get_random_valid_action()

def run(your_agent: StudentAgent, opponent_agent: StudentAgent, start_num: int):
    your_agent_stats = {"timeout_count": 0, "invalid_count": 0}
    opponent_agent_stats = {"timeout_count": 0, "invalid_count": 0}
    turn_count = 0
    
    state = State(fill_num=start_num)
    
    while not state.is_terminal():
        #print(state)
        turn_count += 1

        agent_name = "your_agent" if state.fill_num == 1 else "opponent_agent"
        agent = your_agent if state.fill_num == 1 else opponent_agent
        stats = your_agent_stats if state.fill_num == 1 else opponent_agent_stats

        start_time = time.time()
        action = agent.choose_action(state.clone())
        end_time = time.time()
        
        random_action = state.get_random_valid_action()
        if end_time - start_time > 3:
            print(f"{agent_name} timed out!")
            stats["timeout_count"] += 1
            action = random_action
        if not state.is_valid_action(action):
            print(f"{agent_name} made an invalid action!")
            stats["invalid_count"] += 1
            action = random_action
                
        #print(action)
        print(inference_model(state_to_tensor(state)))
        state = state.change_state(action)

    print(f"== {your_agent.__class__.__name__} (1) vs {opponent_agent.__class__.__name__} (2) - First Player: {start_num} ==")
        
    if state.terminal_utility() == 1:
        print("You win!")
    elif state.terminal_utility() == 0:
        print("You lose!")
    else:
        print("Draw")

    for agent_name, stats in [("your_agent", your_agent_stats), ("opponent_agent", opponent_agent_stats)]:
        print(f"{agent_name} statistics:")
        print(f"Timeout count: {stats['timeout_count']}")
        print(f"Invalid count: {stats['invalid_count']}")
        
    print(f"Turn count: {turn_count}\n")
    #print(state)

your_agent = lambda: StudentAgent()
opponent_agent = lambda: RandomStudentAgent()

run(your_agent(), opponent_agent(), 1)
run(your_agent(), opponent_agent(), 2)

(tensor([[9.4793e-06, 3.2631e-04, 3.0510e-05, 7.2048e-03, 5.4976e-05, 1.2403e-02,
         6.2163e-05, 3.5673e-02, 6.2828e-02, 5.7124e-05, 3.9309e-06, 5.4976e-03,
         1.7686e-04, 2.3642e-03, 6.4218e-03, 9.1988e-04, 6.8387e-06, 3.5382e-02,
         3.2539e-05, 9.5501e-03, 3.5474e-05, 5.0363e-06, 1.6491e-03, 3.9291e-03,
         5.5464e-04, 4.8744e-04, 3.0758e-02, 1.2322e-05, 2.8126e-04, 8.7949e-05,
         7.6690e-05, 5.6819e-02, 1.8095e-04, 1.2417e-03, 2.1750e-02, 6.2042e-02,
         4.2854e-03, 1.4375e-05, 2.0925e-05, 8.4585e-02, 5.8950e-02, 7.7374e-03,
         1.6837e-03, 2.1506e-02, 1.0092e-04, 7.8711e-02, 3.1605e-04, 1.5606e-03,
         5.4704e-02, 2.3885e-06, 4.3532e-04, 3.5001e-02, 1.0134e-04, 1.0264e-03,
         2.7607e-05, 4.7374e-05, 1.3499e-02, 1.8256e-03, 2.3566e-02, 4.7220e-03,
         1.7754e-03, 5.3280e-02, 1.4972e-04, 5.1974e-02, 8.0774e-04, 6.0268e-06,
         1.5922e-02, 1.6171e-05, 1.7093e-03, 3.0057e-02, 7.0029e-04, 3.2271e-03,
         2.3708e-04, 2.2771

(tensor([[1.3705e-07, 2.2298e-05, 5.9701e-07, 6.6502e-03, 2.2418e-06, 7.2853e-03,
         2.8680e-06, 3.1006e-02, 9.3215e-02, 3.5768e-06, 4.0729e-08, 2.7073e-03,
         5.9405e-06, 2.9994e-04, 3.5471e-03, 5.6607e-05, 9.4456e-08, 5.2030e-02,
         6.1024e-07, 6.2822e-03, 1.0846e-06, 5.6613e-08, 3.3132e-04, 1.1459e-03,
         8.3484e-05, 7.3528e-05, 2.7069e-02, 3.1828e-07, 2.1982e-05, 3.8426e-06,
         2.6368e-06, 5.0909e-02, 9.2492e-06, 1.4531e-04, 1.6383e-02, 8.5444e-02,
         2.4794e-03, 2.5898e-07, 3.2783e-07, 9.7356e-02, 8.3853e-02, 2.3615e-03,
         2.1213e-04, 2.7746e-02, 8.5661e-06, 1.3462e-01, 3.5564e-05, 5.3804e-04,
         3.0421e-02, 2.3058e-08, 3.6016e-05, 2.3047e-02, 6.1839e-06, 1.6647e-04,
         1.2552e-06, 1.9623e-06, 9.0346e-03, 2.7838e-04, 1.8606e-02, 6.4127e-04,
         4.8978e-04, 5.8440e-02, 6.4809e-06, 2.1101e-02, 1.5566e-04, 1.2385e-07,
         6.6289e-03, 5.1969e-07, 3.1415e-04, 1.5309e-02, 6.6922e-05, 1.4611e-03,
         1.6837e-05, 6.2330

(tensor([[5.4943e-07, 1.4166e-05, 1.1206e-06, 1.5631e-03, 1.2977e-05, 6.4233e-03,
         1.2474e-05, 9.5695e-03, 1.8055e-01, 2.1042e-06, 2.3686e-07, 3.2661e-03,
         3.0154e-05, 1.9625e-03, 3.2956e-03, 1.3939e-04, 3.5477e-07, 3.2080e-02,
         1.4461e-06, 1.6143e-02, 1.6182e-06, 6.7382e-08, 3.7785e-04, 2.6791e-03,
         1.4945e-04, 1.1786e-04, 1.5263e-02, 1.2254e-06, 2.2752e-05, 6.3550e-06,
         2.9669e-06, 2.6657e-02, 1.6393e-05, 5.2627e-04, 1.0050e-02, 1.9268e-02,
         1.4795e-03, 1.8093e-07, 1.3842e-06, 1.9663e-01, 3.7487e-02, 3.4976e-03,
         4.2445e-04, 2.2297e-02, 3.3875e-05, 1.2792e-01, 4.2079e-05, 6.1559e-04,
         2.4862e-02, 9.0260e-08, 3.8493e-04, 2.8675e-02, 7.6108e-06, 1.7594e-04,
         6.2905e-07, 4.7474e-06, 5.8980e-03, 1.8998e-04, 1.1845e-02, 2.1314e-03,
         9.2285e-04, 2.7705e-02, 1.9664e-05, 9.3245e-02, 3.3245e-04, 6.0005e-07,
         1.5557e-02, 6.3960e-07, 1.0639e-03, 2.3089e-02, 5.3937e-05, 2.3465e-03,
         1.4607e-05, 6.7315

(tensor([[3.2732e-07, 1.0149e-05, 6.9700e-07, 1.2734e-03, 8.8505e-06, 5.5716e-03,
         8.5533e-06, 8.3866e-03, 1.8203e-01, 1.2963e-06, 1.4302e-07, 2.9640e-03,
         2.3655e-05, 1.8017e-03, 2.8216e-03, 1.1179e-04, 2.0959e-07, 3.4044e-02,
         8.7547e-07, 1.5152e-02, 1.0210e-06, 3.5531e-08, 3.1569e-04, 2.3720e-03,
         1.1199e-04, 9.5848e-05, 1.3889e-02, 8.1597e-07, 1.6497e-05, 3.9292e-06,
         1.8560e-06, 2.7134e-02, 1.1939e-05, 4.3699e-04, 9.0919e-03, 1.8172e-02,
         1.2607e-03, 1.0348e-07, 8.8453e-07, 2.0444e-01, 3.6758e-02, 3.0942e-03,
         3.7694e-04, 2.1176e-02, 2.5106e-05, 1.3079e-01, 3.2629e-05, 5.3443e-04,
         2.6647e-02, 4.7312e-08, 3.1878e-04, 2.8262e-02, 4.7880e-06, 1.3937e-04,
         4.1675e-07, 3.1624e-06, 5.4001e-03, 1.3821e-04, 1.1288e-02, 1.8499e-03,
         7.2577e-04, 2.5942e-02, 1.3969e-05, 9.7989e-02, 2.6197e-04, 3.8247e-07,
         1.5817e-02, 4.1235e-07, 9.6710e-04, 2.0995e-02, 3.7922e-05, 2.0653e-03,
         1.0582e-05, 5.4725

(tensor([[1.7146e-07, 6.9741e-06, 4.1849e-07, 1.0303e-03, 5.7473e-06, 4.3842e-03,
         5.2569e-06, 7.3733e-03, 1.7793e-01, 8.0494e-07, 7.6111e-08, 2.7988e-03,
         1.6100e-05, 1.6084e-03, 2.2401e-03, 8.6352e-05, 1.3121e-07, 2.9934e-02,
         5.2043e-07, 1.3932e-02, 6.1331e-07, 1.7994e-08, 2.5743e-04, 1.8214e-03,
         7.9778e-05, 7.1744e-05, 1.2746e-02, 5.1447e-07, 1.0588e-05, 2.2730e-06,
         1.1186e-06, 2.4381e-02, 7.0756e-06, 3.2641e-04, 7.7602e-03, 1.5964e-02,
         9.9820e-04, 5.0428e-08, 5.2991e-07, 2.2610e-01, 3.3086e-02, 2.9798e-03,
         3.2568e-04, 1.7443e-02, 1.7806e-05, 1.4166e-01, 2.2545e-05, 3.8610e-04,
         2.9209e-02, 2.2768e-08, 2.5156e-04, 2.6892e-02, 2.8615e-06, 9.9760e-05,
         2.5300e-07, 2.0183e-06, 4.5569e-03, 9.9232e-05, 1.1595e-02, 1.4991e-03,
         5.3751e-04, 2.8394e-02, 9.3248e-06, 1.0177e-01, 1.8712e-04, 2.2207e-07,
         1.4018e-02, 2.4541e-07, 7.8018e-04, 1.7221e-02, 2.8913e-05, 1.7155e-03,
         7.1006e-06, 4.8832

(tensor([[1.8491e-07, 7.1640e-06, 4.2636e-07, 1.0102e-03, 6.2812e-06, 4.8380e-03,
         5.7500e-06, 7.2116e-03, 1.7967e-01, 8.0329e-07, 7.5377e-08, 2.7540e-03,
         1.5759e-05, 1.7291e-03, 2.3249e-03, 8.7205e-05, 1.3685e-07, 3.2671e-02,
         5.8623e-07, 1.4786e-02, 6.0227e-07, 1.9547e-08, 2.7262e-04, 1.8523e-03,
         8.3420e-05, 7.6199e-05, 1.3381e-02, 5.2080e-07, 1.0771e-05, 2.3997e-06,
         1.1716e-06, 2.5153e-02, 7.3204e-06, 3.2142e-04, 7.9796e-03, 1.7700e-02,
         1.0727e-03, 5.5628e-08, 5.6333e-07, 2.1450e-01, 3.5281e-02, 2.6432e-03,
         3.1940e-04, 1.8078e-02, 1.8036e-05, 1.4534e-01, 2.4735e-05, 4.3366e-04,
         2.7644e-02, 2.5164e-08, 2.4566e-04, 2.5880e-02, 2.7676e-06, 9.5815e-05,
         2.5624e-07, 2.1358e-06, 4.4239e-03, 9.9486e-05, 1.1075e-02, 1.6343e-03,
         5.6799e-04, 2.6842e-02, 9.7504e-06, 1.0210e-01, 2.0510e-04, 2.2142e-07,
         1.4203e-02, 2.3801e-07, 8.1514e-04, 1.7679e-02, 2.8970e-05, 1.7211e-03,
         7.1219e-06, 4.7024

(tensor([[1.8203e-07, 7.1129e-06, 4.3999e-07, 1.0065e-03, 6.1767e-06, 5.0504e-03,
         5.6011e-06, 7.2050e-03, 1.8759e-01, 7.9983e-07, 7.0854e-08, 2.7158e-03,
         1.4581e-05, 1.6715e-03, 2.2907e-03, 8.5192e-05, 1.3219e-07, 3.1405e-02,
         5.6747e-07, 1.4539e-02, 5.6164e-07, 1.9770e-08, 2.7560e-04, 1.8180e-03,
         8.3618e-05, 7.7132e-05, 1.3850e-02, 4.8899e-07, 1.0487e-05, 2.3168e-06,
         1.1750e-06, 2.5689e-02, 6.7920e-06, 3.2659e-04, 8.0642e-03, 1.7517e-02,
         1.0357e-03, 5.4443e-08, 5.6330e-07, 2.1084e-01, 3.5247e-02, 2.6916e-03,
         3.0268e-04, 1.8377e-02, 1.7863e-05, 1.4283e-01, 2.3086e-05, 4.2260e-04,
         2.6970e-02, 2.4876e-08, 2.2744e-04, 2.5114e-02, 2.7231e-06, 8.9399e-05,
         2.4899e-07, 2.0934e-06, 4.2550e-03, 9.7890e-05, 1.0904e-02, 1.5916e-03,
         5.6413e-04, 2.5347e-02, 8.8388e-06, 1.0472e-01, 1.9591e-04, 2.2242e-07,
         1.3937e-02, 2.2108e-07, 7.9917e-04, 1.7695e-02, 2.8425e-05, 1.7273e-03,
         6.7147e-06, 4.5304

(tensor([[1.8441e-06, 6.0181e-05, 7.0787e-06, 4.9053e-03, 2.7891e-05, 1.0298e-02,
         3.4567e-05, 2.1742e-02, 9.4709e-02, 1.1017e-05, 9.6297e-07, 4.8633e-03,
         6.2557e-05, 2.3093e-03, 5.5804e-03, 3.7153e-04, 1.7107e-06, 4.3531e-02,
         6.2283e-06, 1.3310e-02, 8.2327e-06, 7.3103e-07, 8.4438e-04, 3.2849e-03,
         3.2369e-04, 2.8010e-04, 2.2606e-02, 5.3322e-06, 8.3343e-05, 2.7104e-05,
         1.8725e-05, 3.9190e-02, 5.6793e-05, 8.4537e-04, 1.6353e-02, 4.0199e-02,
         3.5844e-03, 1.5577e-06, 6.1715e-06, 1.4081e-01, 5.4549e-02, 5.1229e-03,
         8.5187e-04, 2.4638e-02, 7.8763e-05, 1.0639e-01, 1.1252e-04, 1.0799e-03,
         3.7970e-02, 5.1982e-07, 5.2052e-04, 4.0160e-02, 2.8515e-05, 4.6332e-04,
         3.5256e-06, 1.5872e-05, 1.0225e-02, 6.1494e-04, 1.9897e-02, 3.3100e-03,
         1.4999e-03, 4.4439e-02, 6.1772e-05, 6.1468e-02, 6.3683e-04, 2.4568e-06,
         1.6056e-02, 3.4609e-06, 1.4652e-03, 2.3883e-02, 2.1987e-04, 3.3197e-03,
         6.3808e-05, 1.1235

(tensor([[1.1175e-06, 3.6544e-05, 3.7083e-06, 3.1671e-03, 2.0980e-05, 9.3233e-03,
         2.4268e-05, 1.7304e-02, 1.1187e-01, 5.2013e-06, 6.1797e-07, 4.1337e-03,
         5.0879e-05, 2.1830e-03, 4.7398e-03, 2.9945e-04, 9.1990e-07, 3.8223e-02,
         3.3223e-06, 1.4035e-02, 4.8362e-06, 2.8034e-07, 6.1713e-04, 3.0579e-03,
         2.3763e-04, 1.9297e-04, 1.7584e-02, 3.1568e-06, 5.3493e-05, 1.4844e-05,
         1.0393e-05, 3.4082e-02, 3.7007e-05, 7.4952e-04, 1.4944e-02, 2.7328e-02,
         2.4304e-03, 7.8999e-07, 3.5064e-06, 1.6917e-01, 5.0022e-02, 4.7598e-03,
         7.0757e-04, 2.2769e-02, 5.0471e-05, 1.2192e-01, 6.7214e-05, 7.8983e-04,
         3.4821e-02, 2.8073e-07, 4.8265e-04, 3.5095e-02, 1.8046e-05, 3.9798e-04,
         1.6986e-06, 9.4195e-06, 8.5169e-03, 4.2816e-04, 1.5731e-02, 2.8862e-03,
         1.3028e-03, 4.2763e-02, 4.2428e-05, 7.1683e-02, 5.2316e-04, 1.4762e-06,
         1.6900e-02, 1.6807e-06, 1.2776e-03, 2.5805e-02, 1.3871e-04, 2.9715e-03,
         3.6999e-05, 8.9844

(tensor([[5.3025e-07, 2.1537e-05, 1.9442e-06, 2.6360e-03, 1.1903e-05, 5.9657e-03,
         1.3212e-05, 1.6191e-02, 1.2154e-01, 3.1253e-06, 2.8106e-07, 4.0691e-03,
         4.2553e-05, 1.9673e-03, 3.7752e-03, 2.3467e-04, 4.2643e-07, 3.4654e-02,
         1.8994e-06, 1.1357e-02, 2.9078e-06, 1.2212e-07, 3.8606e-04, 2.4654e-03,
         1.4707e-04, 1.2792e-04, 1.9018e-02, 1.7376e-06, 3.0264e-05, 9.9168e-06,
         5.1021e-06, 2.9886e-02, 2.2643e-05, 6.2735e-04, 1.1012e-02, 2.7142e-02,
         1.9119e-03, 3.4222e-07, 1.8714e-06, 1.9379e-01, 4.1168e-02, 3.8966e-03,
         4.9675e-04, 2.0944e-02, 3.1652e-05, 1.3762e-01, 4.7151e-05, 5.7540e-04,
         4.1050e-02, 1.1295e-07, 3.9154e-04, 3.2895e-02, 1.0333e-05, 2.4765e-04,
         9.1917e-07, 5.0383e-06, 7.2026e-03, 3.1084e-04, 1.5152e-02, 2.5229e-03,
         9.9453e-04, 4.1101e-02, 3.3125e-05, 6.9101e-02, 3.5519e-04, 6.2863e-07,
         1.5031e-02, 1.0178e-06, 9.7981e-04, 2.3906e-02, 9.6615e-05, 2.5776e-03,
         2.6415e-05, 7.6440

(tensor([[3.5058e-07, 1.3610e-05, 1.1063e-06, 1.8258e-03, 9.6487e-06, 4.8715e-03,
         1.0848e-05, 1.2469e-02, 1.3637e-01, 1.9006e-06, 2.0801e-07, 3.5255e-03,
         3.7206e-05, 2.0305e-03, 3.3266e-03, 1.6405e-04, 2.8250e-07, 3.8094e-02,
         1.2380e-06, 1.2532e-02, 1.9231e-06, 6.2348e-08, 2.9924e-04, 2.3810e-03,
         1.1660e-04, 9.6227e-05, 1.6228e-02, 1.2166e-06, 2.1359e-05, 6.2380e-06,
         2.7784e-06, 2.5742e-02, 1.6727e-05, 5.0420e-04, 9.1715e-03, 2.3038e-02,
         1.5421e-03, 1.6740e-07, 1.2532e-06, 2.0988e-01, 3.8019e-02, 3.3637e-03,
         4.2048e-04, 2.1031e-02, 2.5989e-05, 1.4356e-01, 3.8033e-05, 5.0294e-04,
         3.5086e-02, 6.6785e-08, 3.7027e-04, 2.9740e-02, 6.6154e-06, 1.9424e-04,
         5.6296e-07, 3.6561e-06, 6.4320e-03, 2.1392e-04, 1.3294e-02, 2.2600e-03,
         8.6631e-04, 4.0832e-02, 2.3727e-05, 7.5267e-02, 2.9841e-04, 3.8537e-07,
         1.5378e-02, 6.4881e-07, 9.0027e-04, 2.1477e-02, 6.0327e-05, 2.2313e-03,
         1.7931e-05, 6.7527

(tensor([[3.0690e-07, 1.1365e-05, 8.0618e-07, 1.4952e-03, 8.1518e-06, 4.6241e-03,
         9.3955e-06, 1.0844e-02, 1.5695e-01, 1.4952e-06, 1.7356e-07, 3.2567e-03,
         3.3054e-05, 1.8959e-03, 3.2158e-03, 1.3193e-04, 2.1211e-07, 3.7870e-02,
         9.8662e-07, 1.2613e-02, 1.5405e-06, 4.1113e-08, 2.5113e-04, 2.3261e-03,
         9.9422e-05, 8.2597e-05, 1.5368e-02, 8.9703e-07, 1.7893e-05, 4.7660e-06,
         2.0839e-06, 2.5035e-02, 1.4003e-05, 4.6836e-04, 9.0595e-03, 2.0708e-02,
         1.3215e-03, 1.2500e-07, 1.0357e-06, 2.1863e-01, 3.7380e-02, 3.1809e-03,
         3.7039e-04, 2.1057e-02, 2.1931e-05, 1.3646e-01, 3.3215e-05, 5.0089e-04,
         3.0643e-02, 5.0812e-08, 3.2174e-04, 2.8218e-02, 5.5796e-06, 1.7524e-04,
         4.5854e-07, 3.0447e-06, 5.8355e-03, 1.8096e-04, 1.2950e-02, 1.9952e-03,
         8.3380e-04, 3.5865e-02, 1.9043e-05, 7.7228e-02, 2.7459e-04, 3.2708e-07,
         1.5723e-02, 5.0632e-07, 8.5461e-04, 2.1209e-02, 4.7659e-05, 2.1016e-03,
         1.3466e-05, 5.9128

(tensor([[3.1751e-07, 1.1985e-05, 8.7412e-07, 1.5557e-03, 8.7975e-06, 5.0848e-03,
         9.9708e-06, 1.0417e-02, 1.5463e-01, 1.5562e-06, 1.8824e-07, 3.3794e-03,
         3.1935e-05, 1.9029e-03, 3.4183e-03, 1.4764e-04, 2.4378e-07, 3.6468e-02,
         1.0465e-06, 1.3557e-02, 1.5977e-06, 4.4003e-08, 2.8121e-04, 2.2606e-03,
         1.0755e-04, 9.4816e-05, 1.5547e-02, 9.6348e-07, 1.8434e-05, 5.0355e-06,
         2.2790e-06, 2.6118e-02, 1.3876e-05, 4.8500e-04, 9.0202e-03, 2.1407e-02,
         1.3563e-03, 1.3136e-07, 1.0689e-06, 2.0769e-01, 3.7452e-02, 3.4385e-03,
         3.9666e-04, 2.0971e-02, 2.3985e-05, 1.3592e-01, 3.4190e-05, 5.2008e-04,
         3.2306e-02, 5.3164e-08, 3.3032e-04, 3.0323e-02, 5.7940e-06, 1.7598e-04,
         4.9183e-07, 3.2187e-06, 5.9858e-03, 1.8949e-04, 1.2443e-02, 2.0713e-03,
         8.4082e-04, 3.6649e-02, 1.9098e-05, 8.3112e-02, 2.7560e-04, 3.7207e-07,
         1.5910e-02, 5.1294e-07, 8.6423e-04, 2.1062e-02, 5.0381e-05, 2.0743e-03,
         1.4350e-05, 5.6182

In [37]:
s=State()
print(inference_model(state_to_tensor(s)))
s1=change_state(s,(0,0,0,0))
s2=change_state(s,(1,1,1,1))

print(inference_model(state_to_tensor(s1)))

print(inference_model(state_to_tensor(s2)))

s3 = torch.tensor(np.ones((9,9)),dtype=torch.float32).unsqueeze(0).unsqueeze(1).to(device)
print(inference_model(s3))


s3 = torch.tensor(np.identity(9),dtype=torch.float32).unsqueeze(0).unsqueeze(1).to(device)
print(inference_model(s3))

(tensor([[9.4793e-06, 3.2631e-04, 3.0510e-05, 7.2048e-03, 5.4976e-05, 1.2403e-02,
         6.2163e-05, 3.5673e-02, 6.2828e-02, 5.7124e-05, 3.9309e-06, 5.4976e-03,
         1.7686e-04, 2.3642e-03, 6.4218e-03, 9.1988e-04, 6.8387e-06, 3.5382e-02,
         3.2539e-05, 9.5501e-03, 3.5474e-05, 5.0363e-06, 1.6491e-03, 3.9291e-03,
         5.5464e-04, 4.8744e-04, 3.0758e-02, 1.2322e-05, 2.8126e-04, 8.7949e-05,
         7.6690e-05, 5.6819e-02, 1.8095e-04, 1.2417e-03, 2.1750e-02, 6.2042e-02,
         4.2854e-03, 1.4375e-05, 2.0925e-05, 8.4585e-02, 5.8950e-02, 7.7374e-03,
         1.6837e-03, 2.1506e-02, 1.0092e-04, 7.8711e-02, 3.1605e-04, 1.5606e-03,
         5.4704e-02, 2.3885e-06, 4.3532e-04, 3.5001e-02, 1.0134e-04, 1.0264e-03,
         2.7607e-05, 4.7374e-05, 1.3499e-02, 1.8256e-03, 2.3566e-02, 4.7220e-03,
         1.7754e-03, 5.3280e-02, 1.4972e-04, 5.1974e-02, 8.0774e-04, 6.0268e-06,
         1.5922e-02, 1.6171e-05, 1.7093e-03, 3.0057e-02, 7.0029e-04, 3.2271e-03,
         2.3708e-04, 2.2771

In [34]:
inference_model=UTTT()
s=State()
print(inference_model(state_to_tensor(s)))
s1=change_state(s,(0,0,0,0))
s2=change_state(s,(1,1,1,1))

print(inference_model(state_to_tensor(s1)))

print(inference_model(state_to_tensor(s2)))

s3 = torch.tensor(np.ones((9,9)),dtype=torch.float32).unsqueeze(0).unsqueeze(1).to(device)
print(inference_model(s3))

(tensor([[0.0075, 0.0089, 0.0076, 0.0097, 0.0064, 0.0301, 0.0132, 0.0142, 0.0232,
         0.0073, 0.0046, 0.0184, 0.0207, 0.0106, 0.0129, 0.0115, 0.0052, 0.0167,
         0.0099, 0.0173, 0.0065, 0.0034, 0.0077, 0.0069, 0.0080, 0.0060, 0.0095,
         0.0078, 0.0132, 0.0083, 0.0074, 0.0180, 0.0073, 0.0109, 0.0168, 0.0231,
         0.0131, 0.0066, 0.0053, 0.0220, 0.0129, 0.0129, 0.0117, 0.0150, 0.0079,
         0.0225, 0.0107, 0.0134, 0.0141, 0.0042, 0.0097, 0.0150, 0.0080, 0.0108,
         0.0055, 0.0095, 0.0232, 0.0098, 0.0244, 0.0069, 0.0176, 0.0161, 0.0114,
         0.0302, 0.0082, 0.0067, 0.0172, 0.0077, 0.0125, 0.0283, 0.0149, 0.0194,
         0.0061, 0.0097, 0.0246, 0.0129, 0.0108, 0.0050, 0.0082, 0.0083, 0.0124]],
       grad_fn=<SoftmaxBackward0>), tensor([[-0.2573]], grad_fn=<TanhBackward0>))
(tensor([[0.0077, 0.0098, 0.0076, 0.0131, 0.0075, 0.0247, 0.0146, 0.0131, 0.0246,
         0.0085, 0.0046, 0.0139, 0.0181, 0.0098, 0.0131, 0.0127, 0.0063, 0.0160,
         0.0124, 0.0151

In [98]:
import json
torch.set_printoptions(precision=10,threshold=2000)
model = UTTT()  # Your model
state_dict = model.state_dict()
print(sum(p.numel() for p in model.parameters()))
print(state_dict)

6349010
OrderedDict([('conv1.weight', tensor([[[[ 3.1599280238e-01, -1.7453873158e-01, -1.5413686633e-01],
          [ 1.9262206554e-01, -2.7660986781e-01,  1.9655749202e-01],
          [-1.9012737274e-01, -9.0148687363e-02,  3.2760399580e-01]]],


        [[[ 2.5253325701e-01, -2.2546005249e-01, -2.3694178462e-01],
          [ 3.9737224579e-02, -7.8327856958e-02, -1.1168801785e-01],
          [-1.7405840755e-01,  2.2267961502e-01,  2.9447036982e-01]]],


        [[[-1.6848883033e-01, -1.7537355423e-02,  2.0572642982e-01],
          [ 2.1543502808e-02, -1.2507376075e-01,  5.1644644700e-03],
          [-2.2771283984e-01, -1.4232861996e-01,  8.6099468172e-02]]],


        [[[ 8.3433628082e-02,  2.2211501002e-01,  1.9809398055e-01],
          [ 9.6673093736e-02,  2.9852300882e-01,  5.6699715555e-02],
          [-2.6685258746e-01, -2.0141530037e-01, -1.8569605052e-01]]],


        [[[-8.7623476982e-02,  3.0630892515e-01,  1.6581630707e-01],
          [-2.1793437004e-01, -1.1966534704e-01, 

In [102]:
#torch.set_printoptions(precision=10,threshold=float("inf"))

model=UTTT(False,4,"MIN")
print(sum(p.numel() for p in model.parameters()))
print(model)
state_dict = model.state_dict()
print(state_dict)

94330
UTTT(
  (conv1): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(8, eps=1e-05, momentum=0.09, affine=True, track_running_stats=True)
  (conv2): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn2): BatchNorm2d(8, eps=1e-05, momentum=0.09, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=648, out_features=128, bias=True)
  (pi): Linear(in_features=128, out_features=81, bias=True)
  (v): Linear(in_features=128, out_features=1, bias=True)
)
OrderedDict([('conv1.weight', tensor([[[[ 0.2038275450, -0.1400500238, -0.0253705587],
          [ 0.0976642817,  0.2959917784, -0.2481583804],
          [-0.2956174612,  0.1206390485, -0.0556780919]]],


        [[[-0.1393942535, -0.2853843868, -0.1149618253],
          [ 0.1874097586, -0.0363383293,  0.0347592458],
          [ 0.0054167113,  0.0915203542, -0.2663037479]]],


        [[[ 0.3215827346, -0.3106656969,  0.1523028612],
          [ 0.303

In [68]:
state_dict_serializable = {k: v.tolist() for k, v in state_dict.items()}  
with open("model_weights.json", "w") as f:
    json.dump(state_dict_serializable, f)