In [37]:
def check_win(board, player):
    """
    Check if the player can win in the next move and return all such moves.
    """
    winning_moves = []
    winning_combinations = [(0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
                            (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
                            (0, 4, 8), (2, 4, 6)]  # diagonals
    for a, b, c in winning_combinations:
        if board[a] == board[b] == player and board[c] == ' ':
            winning_moves.append(c)
        if board[a] == board[c] == player and board[b] == ' ':
            winning_moves.append(b)
        if board[b] == board[c] == player and board[a] == ' ':
            winning_moves.append(a)
    return winning_moves

def check_block(board, player):
    """
    Check if the opponent can win in the next move and return all blocking moves.
    """
    opponent = 'O' if player == 'X' else 'X'
    return check_win(board, opponent)

def is_unblocked_line(board, player, a, b, c):
    """
    Check if a line (defined by indices a, b, c) is unblocked for the given player.
    """
    return ((board[a] == board[b] == player and board[c] == ' ') or
            (board[a] == board[c] == player and board[b] == ' ') or
            (board[b] == board[c] == player and board[a] == ' '))

def check_fork(board, player):
    """
    Check if the player can create a fork and return all such moves.
    """
    fork_moves = []
    winning_combinations = [(0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
                            (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
                            (0, 4, 8), (2, 4, 6)]  # diagonals
    for i in range(9):
        if board[i] == ' ':
            board[i] = player  # Temporarily place the player's marker
            unblocked_lines = 0  # Count of unblocked lines
            for a, b, c in winning_combinations:
                if is_unblocked_line(board, player, a, b, c):
                    unblocked_lines += 1
            if unblocked_lines >= 2:
                fork_moves.append(i)
            board[i] = ' '  # Reset the board
    return fork_moves

def is_two_in_a_row(board, player, a, b, c):
    return ((board[a] == player and board[b] == player and board[c] == ' ') or
            (board[a] == player and board[c] == player and board[b] == ' ') or
            (board[b] == player and board[c] == player and board[a] == ' '))

# Update the check_block_fork function to include the additional logic
def check_block_fork(board, player):
    opponent = 'O' if player == 'X' else 'X'
    forks = check_fork(board, opponent)

    # Check if opponent has two opposite corners
    if board[0] == board[8] == opponent:
        if board[2] == ' ' or board[6] == ' ':
            to_rtn = []
            for i in [1, 3, 5, 7]:
                if board[i] == ' ': to_rtn.append(i)
            return to_rtn
    if board[2] == board[6] == opponent:
        if board[0] == ' ' or board[8] == ' ':
            to_rtn = []
            for i in [1, 3, 5, 7]:
                if board[i] == ' ': to_rtn.append(i)
            return to_rtn

    if len(forks) > 1:
        two_in_a_row_moves = []
        winning_combinations = [(0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
                                (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
                                (0, 4, 8), (2, 4, 6)]  # diagonals
        for i in range(9):
            if board[i] == ' ':
                board[i] = player  # Temporarily place the player's marker
                for a, b, c in winning_combinations:
                    if is_two_in_a_row(board, player, a, b, c):
                        two_in_a_row_moves.append(i)
                        break  # No need to check further combinations for this move
                board[i] = ' '  # Reset the board
        # Return intersection of two_in_a_row_moves and forks
        return list(set(two_in_a_row_moves) & set(forks))
    if len(forks) == 1:
        return forks
    return []

def check_center(board):
    """
    Check if the center is free and return it as a move if it is.
    """
    return [4] if board[4] == ' ' else []

def check_opposite_corner(board, player):
    """
    Check if the opponent is in the corner, and the opposite corner is free, and return all such moves.
    """
    opponent = 'O' if player == 'X' else 'X'
    opposite_corners = [(0, 8), (2, 6), (6, 2), (8, 0)]
    moves = []
    for a, b in opposite_corners:
        if board[a] == opponent and board[b] == ' ':
            moves.append(b)
        if board[b] == opponent and board[a] == ' ':
            moves.append(a)
    return moves

def check_empty_corner(board):
    """
    Check for any empty corners and return all such moves.
    """
    corners = [0, 2, 6, 8]
    return [corner for corner in corners if board[corner] == ' ']

def check_empty_side(board):
    """
    Check for any empty sides and return all such moves.
    """
    sides = [1, 3, 5, 7]
    return [side for side in sides if board[side] == ' ']

# Update Main Function

def get_optimal_moves(board, player):
    """
    Get all optimal moves for the given board and player.
    """
    for check in [check_win, check_block, check_fork, check_block_fork]:
        moves = check(board, player)
        if moves:
            return moves

    for check in [check_center, check_empty_corner, check_empty_side]:
        moves = check(board)
        if moves:
            return moves

    for check in [check_opposite_corner]:
        moves = check(board, player)
        if moves:
            return moves

    return []  # Should never reach this point in a valid game of Tic-Tac-Toe

In [38]:
# Initialize list to store sequences of all finished games
finished_games = []

def is_winner(board, player):
    """
    Check if the player has won on the current board.
    """
    winning_combinations = [(0, 1, 2), (3, 4, 5), (6, 7, 8),  # rows
                            (0, 3, 6), (1, 4, 7), (2, 5, 8),  # columns
                            (0, 4, 8), (2, 4, 6)]  # diagonals
    for a, b, c in winning_combinations:
        if board[a] == board[b] == board[c] == player:
            return True
    return False

def simulate_game(board, move_sequence, next_player):
    """
    Simulate a game of Tic-Tac-Toe recursively.
    
    Parameters:
        board (list): The current game board.
        move_sequence (list): The sequence of moves made so far.
        next_player (str): The player to move next ('X' or 'O').
    """
    
    # Check for game over conditions (win or draw)
    if is_winner(board, 'X') or is_winner(board, 'O'):
        finished_games.append(move_sequence[:])
        return
    if ' ' not in board:
        finished_games.append(move_sequence[:])
        return
    
    # Optimal player's move
    if next_player == 'X':
        optimal_moves = get_optimal_moves(board, next_player)
        for move in optimal_moves:
            board[move] = next_player  # Make the move
            move_sequence.append(move)  # Record the move
            simulate_game(board, move_sequence, 'O')  # Recursive call
            board[move] = ' '  # Undo the move
            move_sequence.pop()  # Remove the last move from the sequence
    
    # All moves for the non-optimal player
    else:
        for move in range(9):
            if board[move] == ' ':
                board[move] = next_player  # Make the move
                move_sequence.append(move)  # Record the move
                simulate_game(board, move_sequence, 'X')  # Recursive call
                board[move] = ' '  # Undo the move
                move_sequence.pop()  # Remove the last move from the sequence

# Initialize board and move_sequence
initial_board = [' ' for _ in range(9)]
initial_move_sequence = []

# Start the simulation with 'X' going first
simulate_game(initial_board, initial_move_sequence, 'X')

# Show some of the finished games to verify correctness
finished_games[:10], len(finished_games)

([[4, 0, 2, 1, 6],
  [4, 0, 2, 3, 6],
  [4, 0, 2, 5, 6],
  [4, 0, 2, 6, 3, 1, 5],
  [4, 0, 2, 6, 3, 5, 8, 1, 7],
  [4, 0, 2, 6, 3, 5, 8, 7, 1],
  [4, 0, 2, 6, 3, 7, 5],
  [4, 0, 2, 6, 3, 8, 5],
  [4, 0, 2, 7, 6],
  [4, 0, 2, 8, 6]],
 488)

In [39]:
import pickle

# Save the list of lists to appropriate file
with open('finished_games.pkl', 'wb') as f:
    pickle.dump(finished_games, f)

# Load the list of lists from appropriate file
with open('finished_games.pkl', 'rb') as f:
    finished_games = pickle.load(f)

In [40]:
finished_games

[[4, 0, 2, 1, 6],
 [4, 0, 2, 3, 6],
 [4, 0, 2, 5, 6],
 [4, 0, 2, 6, 3, 1, 5],
 [4, 0, 2, 6, 3, 5, 8, 1, 7],
 [4, 0, 2, 6, 3, 5, 8, 7, 1],
 [4, 0, 2, 6, 3, 7, 5],
 [4, 0, 2, 6, 3, 8, 5],
 [4, 0, 2, 7, 6],
 [4, 0, 2, 8, 6],
 [4, 0, 6, 1, 2],
 [4, 0, 6, 2, 1, 3, 7],
 [4, 0, 6, 2, 1, 5, 7],
 [4, 0, 6, 2, 1, 7, 8, 3, 5],
 [4, 0, 6, 2, 1, 7, 8, 5, 3],
 [4, 0, 6, 2, 1, 8, 7],
 [4, 0, 6, 3, 2],
 [4, 0, 6, 5, 2],
 [4, 0, 6, 7, 2],
 [4, 0, 6, 8, 2],
 [4, 0, 8, 1, 2, 3, 5],
 [4, 0, 8, 1, 2, 3, 6],
 [4, 0, 8, 1, 2, 5, 6],
 [4, 0, 8, 1, 2, 6, 5],
 [4, 0, 8, 1, 2, 7, 5],
 [4, 0, 8, 1, 2, 7, 6],
 [4, 0, 8, 2, 1, 3, 7],
 [4, 0, 8, 2, 1, 5, 7],
 [4, 0, 8, 2, 1, 6, 7],
 [4, 0, 8, 2, 1, 7, 6, 3, 5],
 [4, 0, 8, 2, 1, 7, 6, 5, 3],
 [4, 0, 8, 3, 6, 1, 7],
 [4, 0, 8, 3, 6, 1, 2],
 [4, 0, 8, 3, 6, 2, 7],
 [4, 0, 8, 3, 6, 5, 7],
 [4, 0, 8, 3, 6, 5, 2],
 [4, 0, 8, 3, 6, 7, 2],
 [4, 0, 8, 5, 6, 1, 7],
 [4, 0, 8, 5, 6, 1, 2],
 [4, 0, 8, 5, 6, 2, 7],
 [4, 0, 8, 5, 6, 3, 7],
 [4, 0, 8, 5, 6, 3, 2],
 [4, 0, 8, 5, 6,

In [41]:
# Initialize list to store sequences of all finished games
finished_games_O_first = []
o_wins = 0

def simulate_game_O_first(board, move_sequence):
    """
    Simulate a game of Tic-Tac-Toe recursively where 'O' goes first.
    
    Parameters:
        board (list): The current game board.
        move_sequence (list): The sequence of moves made so far.
    """
    
    # Check for game over conditions (win or draw)
    if is_winner(board, 'X') or is_winner(board, 'O'):
        # if is_winner(board, 'X'): print("X won!")
        # if is_winner(board, 'O'): 
        #     print("O won!")
        finished_games_O_first.append(move_sequence[:])
        return
    if ' ' not in board:
        finished_games_O_first.append(move_sequence[:])
        return
    
    # All moves for the non-optimal player
    for move in range(9):
        if board[move] == ' ':
            board[move] = 'O'  # Make the move
            move_sequence.append(move)  # Record the move

            # Optimal player's move
            optimal_moves = get_optimal_moves(board, 'X')
            for x_move in optimal_moves:
                board[x_move] = 'X'  # Make the move
                move_sequence.append(x_move)  # Record the move 
                simulate_game_O_first(board, move_sequence)  # Recursive call
                board[x_move] = ' '  # Undo the move
                move_sequence.pop()  # Remove the last move from the sequence
                
            board[move] = ' '  # Undo the move
            move_sequence.pop()  # Remove the last move from the sequence

# Initialize board and move_sequence
initial_board = [' ' for _ in range(9)]
initial_move_sequence = []

# Start the simulation with 'O' going first
simulate_game_O_first(initial_board, initial_move_sequence)

# Show some of the finished games to verify correctness
finished_games_O_first[:10], len(finished_games_O_first)

([[0, 4, 1, 2, 3, 6],
  [0, 4, 1, 2, 5, 6],
  [0, 4, 1, 2, 6, 3, 7, 5],
  [0, 4, 1, 2, 6, 3, 8, 5],
  [0, 4, 1, 2, 7, 6],
  [0, 4, 1, 2, 8, 6],
  [0, 4, 2, 1, 3, 7],
  [0, 4, 2, 1, 5, 7],
  [0, 4, 2, 1, 6, 7],
  [0, 4, 2, 1, 8, 7]],
 656)

In [42]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
import torch

# Revised TicTacToeDataset class
class TicTacToeDataset(Dataset):
    def __init__(self, finished_games_X_first, finished_games_O_first):
        self.data = []
        
        # For games where X went first
        for game in finished_games_X_first:
            for i in range(1, len(game), 2):  # Only X's moves
                sub_seq = [9] + game[:i]
                self.data.append((sub_seq, game[i]))
        
        # For games where X went second
        for game in finished_games_O_first:
            for i in range(0, len(game), 2):  # Only X's moves
                sub_seq = [9] + game[:i+1]
                self.data.append((sub_seq, game[i+1]))
                
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sub_seq, target = self.data[idx]
        # Pad the sub_seq to have length 9
        padded_sub_seq = sub_seq + [10] * (10 - len(sub_seq))
        return torch.tensor(padded_sub_seq, dtype=torch.long).view(-1), torch.tensor(target, dtype=torch.long).view(-1)

# Create Dataset
dataset = TicTacToeDataset(finished_games, finished_games_O_first)

# Create DataLoader
dataloader = DataLoader(dataset, batch_size=4, shuffle=False)

# Fetch a batch to check
j = 800
for i, (X, y) in enumerate(dataloader):
    if i == j:
        print("X:", X)
        print("y:", y)
        break

X: tensor([[ 9,  7,  4,  1, 10, 10, 10, 10, 10, 10],
        [ 9,  7,  4,  1,  0,  3, 10, 10, 10, 10],
        [ 9,  7, 10, 10, 10, 10, 10, 10, 10, 10],
        [ 9,  7,  4,  1, 10, 10, 10, 10, 10, 10]])
y: tensor([[0],
        [8],
        [4],
        [0]])


In [43]:
from torch.utils.data import Dataset, DataLoader
import torch

# Define the TicTacToeDataset class
class TicTacToeDataset(Dataset):
    def __init__(self, finished_games_X_first, finished_games_O_first):
        self.data = []
        
        # For games where X went first
        for game in finished_games_X_first:
            for i in range(1, len(game), 2):  # Only X's moves
                sub_seq = [9] + game[:i]
                self.data.append((sub_seq, game[i]))
        
        # For games where X went second
        for game in finished_games_O_first:
            for i in range(0, len(game), 2):  # Only X's moves
                sub_seq = [9] + game[:i+1]
                self.data.append((sub_seq, game[i+1]))
                
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

# Custom collate function to handle sequences of varying lengths
def custom_collate(batch):
    sequences, labels = zip(*batch)
    
    # Pad sequences in the batch
    max_length = max([len(seq) for seq in sequences])
    padded_sequences = [seq + [10] * (max_length - len(seq)) for seq in sequences]  # Using 10 as the padding token
    
    # Convert to PyTorch tensors
    padded_sequences = torch.tensor(padded_sequences, dtype=torch.long)
    labels = torch.tensor(labels, dtype=torch.long)
    
    return padded_sequences, labels

# Initialize the DataLoader with the custom collate function
tic_tac_toe_dataset = TicTacToeDataset(finished_games, finished_games_O_first)
data_loader = DataLoader(tic_tac_toe_dataset, batch_size=4, shuffle=True, collate_fn=custom_collate)

# Show a sample batch to validate the Dataset and DataLoader
for batch in data_loader:
    sample_batch = batch
    break

sample_batch

(tensor([[ 9,  1, 10, 10, 10, 10],
         [ 9,  4,  7,  6,  2,  0],
         [ 9,  4,  3,  8,  0,  6],
         [ 9,  4,  0,  6,  2,  1]]),
 tensor([4, 5, 5, 8]))

In [54]:
import torch.nn as nn
import torch.nn.functional as F

# Modify the TicTacToeTransformer model to correctly pick logits for the last non-padding token
class TicTacToeTransformer(nn.Module):
    def __init__(self, d_model, nhead, num_layers, num_classes):
        super(TicTacToeTransformer, self).__init__()
        
        # Embedding Layer
        self.embedding = nn.Embedding(11, d_model)  # 11 classes: board positions (0-8), start (9), and padding (10)
        
        # Positional Encoding
        self.positional_encoding = nn.Parameter(torch.randn(1, 10, d_model))
        
        # Transformer Decoder
        self.transformer = nn.Transformer(d_model, nhead, num_layers, num_layers, dim_feedforward=d_model*4)
        
        # Output Layer
        self.fc_out = nn.Linear(d_model, num_classes)
        
    def forward(self, x, targets=None):
        # Robustly get the index of the first padding token (index = 10) for each sequence in the batch
        pad_idx = torch.tensor([seq.tolist().index(10) if 10 in seq.tolist() else seq.size(0) - 1 for seq in x])
        
        # Embedding
        x = self.embedding(x)
        
        # Positional Encoding (Added to the embedding)
        x += self.positional_encoding[:, :x.size(1), :]
        
        # Permute the tensor dimensions to match transformer's expected input shape
        x = x.permute(1, 0, 2)
        
        # Transformer Forward Pass
        tgt_mask = self.transformer.generate_square_subsequent_mask(x.size(0)).to(x.device)
        x = self.transformer(x, x, tgt_mask=tgt_mask)
        
        # Permute back to (batch_size, seq_length, d_model) for the output layer
        x = x.permute(1, 0, 2)
        
        # Output Layer
        # Use gathered indices to pick the logits corresponding to the last non-padding token
        logits = self.fc_out(x[torch.arange(x.size(0)), pad_idx, :])
        
        # Loss calculation, if targets are provided
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits, targets, ignore_index=10)  # 10 is the padding token
        
        return logits, loss

d_model = 64  # Dimension of the model
nhead = 4  # Number of heads in multi-head attention
num_layers = 2  # Number of transformer layers
num_classes = 9  # Number of board positions

# Initialize the model
model = TicTacToeTransformer(d_model, nhead, num_layers, num_classes)

# Perform a forward pass using the sample batch
sample_sequences, sample_labels = sample_batch
logits, loss = model(sample_sequences, sample_labels)

logits.shape, F.softmax(logits, dim=1), loss

(torch.Size([4, 9]),
 tensor([[0.1236, 0.0442, 0.0478, 0.1667, 0.0296, 0.2468, 0.2442, 0.0390, 0.0583],
         [0.0513, 0.1462, 0.0429, 0.0510, 0.0358, 0.2795, 0.2578, 0.0923, 0.0432],
         [0.0969, 0.1136, 0.0777, 0.0382, 0.0543, 0.3206, 0.0970, 0.0861, 0.1156],
         [0.0804, 0.0847, 0.1030, 0.1008, 0.0719, 0.3122, 0.0896, 0.0526, 0.1048]],
        grad_fn=<SoftmaxBackward0>),
 tensor(2.0469, grad_fn=<NllLossBackward0>))

In [52]:
import math
from torch import Tensor
from typing import Tuple

# Define the simplified TicTacToeTransformer model
class SimplifiedTicTacToeTransformer(nn.Module):
    def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int, nlayers: int, dropout: float = 0.5):
        super(SimplifiedTicTacToeTransformer, self).__init__()
        
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(d_model, dropout)
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, nlayers)
        self.embedding = nn.Embedding(ntoken, d_model)
        self.d_model = d_model
        self.linear = nn.Linear(d_model, 9)
        
        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)
        
    def forward(self, src: Tensor, src_mask: Tensor = None, targets: Tensor = None) -> Tuple[Tensor, Tensor]:
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        logits = self.linear(output)
        print(logits.shape)
        return logits

# Define the PositionalEncoding class again
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)
    
# Function to generate square subsequent mask
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

# Redefine the necessary hyperparameters
ntoken = 11  # 0-8 for board positions, 9 for start, 10 for padding
d_model = 64
nhead = 4
d_hid = 256
nlayers = 3
dropout = 0.2

# Initialize the simplified model again
simplified_model = SimplifiedTicTacToeTransformer(ntoken, d_model, nhead, d_hid, nlayers, dropout)

# Define the device variable
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create the square subsequent mask again with the correct size
src_mask = generate_square_subsequent_mask(sample_sequences.size(1)).to(device)

# Perform a forward pass using the sample batch
logits = simplified_model(sample_sequences.permute(1, 0), src_mask=src_mask, targets=sample_labels)

logits.shape, F.softmax(logits, dim=1)


torch.Size([6, 4, 9])


(torch.Size([6, 4, 9]),
 tensor([[[0.2807, 0.2355, 0.2865, 0.2034, 0.2825, 0.2711, 0.1565, 0.2041,
           0.2069],
          [0.1950, 0.3081, 0.1943, 0.3162, 0.1968, 0.1643, 0.2591, 0.2750,
           0.2542],
          [0.3123, 0.2164, 0.1833, 0.2451, 0.1711, 0.3713, 0.2827, 0.2103,
           0.3482],
          [0.2120, 0.2400, 0.3359, 0.2352, 0.3496, 0.1933, 0.3017, 0.3106,
           0.1907]],
 
         [[0.3099, 0.2837, 0.2939, 0.2390, 0.3274, 0.2378, 0.2928, 0.3799,
           0.2013],
          [0.2031, 0.2823, 0.2485, 0.3118, 0.2276, 0.1883, 0.2217, 0.1929,
           0.3940],
          [0.2362, 0.2099, 0.1452, 0.2836, 0.1700, 0.3191, 0.2195, 0.2289,
           0.2067],
          [0.2509, 0.2241, 0.3124, 0.1655, 0.2750, 0.2549, 0.2660, 0.1982,
           0.1980]],
 
         [[0.2143, 0.2171, 0.3191, 0.2872, 0.2843, 0.1812, 0.3155, 0.4129,
           0.1538],
          [0.2764, 0.2399, 0.1718, 0.2547, 0.2114, 0.2654, 0.2737, 0.1782,
           0.1823],
          [0.2919, 0

In [53]:
# Train the transformer
def train(model, iterator, optimizer, criterion, clip):
    model.train()
    
    epoch_loss = 0
    
    for i, (X, y) in enumerate(iterator):
        optimizer.zero_grad()
        
        # Create the square subsequent mask again with the correct size
        src_mask = generate_square_subsequent_mask(X.size(1)).to(device)
        
        # Forward Pass
        logits = model(X.permute(1, 0), src_mask=src_mask, targets=y)
        
        # Calculate loss
        loss = criterion(logits.view(-1, logits.size(-1)), y.view(-1))
        
        # Backpropagation
        loss.backward()
        
        # Clip the gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        
        # Update weights
        optimizer.step()
        
        epoch_loss += loss.item()
        
    return epoch_loss / len(iterator)

# Define params and train
ntoken = 11  # 0-8 for board positions, 9 for start, 10 for padding
d_model = 64
nhead = 4
d_hid = 256
nlayers = 3

# Initialize the model
model = SimplifiedTicTacToeTransformer(ntoken, d_model, nhead, d_hid, nlayers, dropout).to(device)

# Define the optimizer and criterion
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss(ignore_index=10)  # 10 is the padding token

# Define the clip value
clip = 1

# Train the model
for epoch in range(10):
    train_loss = train(model, data_loader, optimizer, criterion, clip)
    print("Epoch:", epoch+1, "Train Loss:", train_loss)

torch.Size([8, 4, 9])


ValueError: Expected input batch_size (32) to match target batch_size (4).

In [46]:
# Import necessary modules
import torch.nn.functional as F

# Define a simple MLP model for Tic-Tac-Toe
class TicTacToeMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(TicTacToeMLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x, targets=None):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        logits = self.fc_out(x)
        
        # Loss calculation, if targets are provided
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits, targets, ignore_index=-1)
        
        return logits, loss

# Initialize the MLP model with corrected input dimension
input_dim = sample_sequences.size(1)  # Input dimension should match the sequence length
hidden_dim = 128  # Arbitrary size for hidden layers
output_dim = 9  # Output dimension should match the vocabulary size

mlp_model = TicTacToeMLP(input_dim, hidden_dim, output_dim).to(device)

# Perform a forward pass using the sample batch
mlp_logits, mlp_loss = mlp_model(sample_sequences.float(), sample_labels)

mlp_logits.shape, F.softmax(mlp_logits, dim=1), mlp_loss

(torch.Size([4, 9]),
 tensor([[0.2941, 0.0840, 0.0923, 0.0213, 0.1440, 0.1753, 0.1583, 0.0093, 0.0213],
         [0.2907, 0.1156, 0.0643, 0.0552, 0.1204, 0.1470, 0.1305, 0.0328, 0.0435],
         [0.2183, 0.1062, 0.0841, 0.0501, 0.1135, 0.2261, 0.1321, 0.0402, 0.0294],
         [0.2347, 0.0868, 0.0692, 0.0600, 0.1250, 0.2060, 0.1210, 0.0559, 0.0415]],
        grad_fn=<SoftmaxBackward0>),
 tensor(2.1312, grad_fn=<NllLossBackward0>))

In [47]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
import torch

# Revised TicTacToeDataset class
class TicTacToeDataset(Dataset):
    def __init__(self, finished_games_X_first, finished_games_O_first):
        self.data = []
        
        # For games where X went first
        for game in finished_games_X_first:
            for i in range(1, len(game), 2):  # Only X's moves
                sub_seq = [9] + game[:i]
                self.data.append((sub_seq, game[i]))
        
        # # For games where X went second
        # for game in finished_games_O_first:
        #     for i in range(0, len(game), 2):  # Only X's moves
        #         sub_seq = [9] + game[:i+1]
        #         self.data.append((sub_seq, game[i+1]))
                
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sub_seq, target = self.data[idx]
        # Pad the sub_seq to have length 9
        padded_sub_seq = sub_seq + [10] * (10 - len(sub_seq))
        return torch.tensor(padded_sub_seq, dtype=torch.long).view(-1), torch.tensor(target, dtype=torch.long).view(-1)

# Create Dataset
dataset = TicTacToeDataset(finished_games, finished_games_O_first)

# Create DataLoader
dataloader = DataLoader(dataset, batch_size=4, shuffle=False)

# Fetch a batch to check
j = 20
for i, (X, y) in enumerate(dataloader):
    if i == j:
        print("X:", X)
        print("y:", y)
        break

X: tensor([[ 9,  4,  0,  8,  2,  1, 10, 10, 10, 10],
        [ 9,  4, 10, 10, 10, 10, 10, 10, 10, 10],
        [ 9,  4,  0,  8, 10, 10, 10, 10, 10, 10],
        [ 9,  4,  0,  8,  2,  1, 10, 10, 10, 10]])
y: tensor([[6],
        [0],
        [2],
        [7]])


In [48]:
# Correcting the custom collate function to handle the tensor conversion properly for padding
def custom_collate_fixed_length(batch):
    sequences, labels = zip(*batch)
    
    # Pad sequences to the maximum length of 10
    padded_sequences = [torch.cat((torch.tensor(seq, dtype=torch.long), 
                                   torch.tensor([10] * (10 - len(seq)), dtype=torch.long))) for seq in sequences]
    
    # Convert to PyTorch tensors
    padded_sequences = torch.stack(padded_sequences, dim=0)
    labels = torch.tensor(labels, dtype=torch.long)
    
    return padded_sequences, labels

# Initialize the DataLoader with the fixed-length custom collate function
data_loader_fixed = DataLoader(tic_tac_toe_dataset, batch_size=64, shuffle=True, collate_fn=custom_collate_fixed_length)

# Show a sample batch to validate the Dataset and DataLoader
for batch in data_loader_fixed:
    sample_batch_fixed = batch
    break

sample_batch_fixed

(tensor([[ 9,  2,  4,  6,  3,  7, 10, 10, 10, 10],
         [ 9,  4, 10, 10, 10, 10, 10, 10, 10, 10],
         [ 9,  6,  4,  2,  7,  0, 10, 10, 10, 10],
         [ 9,  3,  4,  1,  0,  8,  6,  7, 10, 10],
         [ 9,  4, 10, 10, 10, 10, 10, 10, 10, 10],
         [ 9,  4, 10, 10, 10, 10, 10, 10, 10, 10],
         [ 9,  4,  5,  2,  6,  1, 10, 10, 10, 10],
         [ 9,  4, 10, 10, 10, 10, 10, 10, 10, 10],
         [ 9,  4,  2,  5, 10, 10, 10, 10, 10, 10],
         [ 9,  5,  4,  1, 10, 10, 10, 10, 10, 10],
         [ 9,  4, 10, 10, 10, 10, 10, 10, 10, 10],
         [ 9,  4,  8,  0,  3,  1, 10, 10, 10, 10],
         [ 9,  2,  4,  5,  8,  0,  1,  3, 10, 10],
         [ 9,  4,  6,  1,  7,  2, 10, 10, 10, 10],
         [ 9,  4,  3,  0,  8,  1, 10, 10, 10, 10],
         [ 9,  4,  0,  8, 10, 10, 10, 10, 10, 10],
         [ 9,  4, 10, 10, 10, 10, 10, 10, 10, 10],
         [ 9,  3,  4,  5,  8,  6, 10, 10, 10, 10],
         [ 9,  4, 10, 10, 10, 10, 10, 10, 10, 10],
         [ 9,  5, 10, 10, 10, 1

In [49]:
from torch.optim import Adam

# Define a simple MLP model for Tic-Tac-Toe
class TicTacToeMLP(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(TicTacToeMLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, output_dim)
    
    def forward(self, x, targets=None):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        logits = self.fc_out(x)
        
        # Loss calculation, if targets are provided
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits, targets)
        
        return logits, loss

# Initialize the MLP model with corrected input and output dimensions
input_dim = 10  # Input dimension should match the maximum sequence length for tic-tac-toe, including start token
hidden_dim = 32  # Arbitrary size for hidden layers
output_dim = 9  # Output dimension should match the number of board positions

mlp_model = TicTacToeMLP(input_dim, hidden_dim, output_dim)

# Define the optimizer and loss function
optimizer = Adam(mlp_model.parameters(), lr=0.001)

# Training loop
num_epochs = 1000

for epoch in range(num_epochs):
    for i, (sequences, labels) in enumerate(data_loader_fixed):
        sequences, labels = sequences.float(), labels
        
        # Forward pass
        logits, loss = mlp_model(sequences, labels)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    if epoch % (num_epochs/10) == 0:
        print(f"Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{len(data_loader_fixed)}], Loss: {loss.item():.4f}")


  from .autonotebook import tqdm as notebook_tqdm


Epoch [1/1000], Step [56/56], Loss: 2.1892
Epoch [101/1000], Step [56/56], Loss: 1.6747
Epoch [201/1000], Step [56/56], Loss: 1.5074
Epoch [301/1000], Step [56/56], Loss: 1.5098
Epoch [401/1000], Step [56/56], Loss: 1.2258
Epoch [501/1000], Step [56/56], Loss: 1.2334
Epoch [601/1000], Step [56/56], Loss: 1.3643
Epoch [701/1000], Step [56/56], Loss: 1.3485
Epoch [801/1000], Step [56/56], Loss: 1.4411
Epoch [901/1000], Step [56/56], Loss: 1.2040
