# Board

In [1]:
import torch

class Board:
    def __init__(self):
        self.board = torch.zeros(9, dtype=torch.float32)
        self.sum = 0
    
    def printBoard(self):
        for i in range(3):
            print(self.board[i*3:(i+1)*3])

    def _play(self, id, val):
        if self.board[val].item() == 0:
            self.sum += 1
            self.board[val] = id
            return True
        return False
    
    def _checkWin(self, id, val):
        b = self.board  # shorthand for readability
        match val:
            case 0:
                if (b[0] == b[1]).item() and (b[1] == b[2]).item() and (b[0] == id).item():
                    return True
                if (b[0] == b[3]).item() and (b[3] == b[6]).item() and (b[0] == id).item():
                    return True
                if (b[0] == b[4]).item() and (b[4] == b[8]).item() and (b[0] == id).item():
                    return True
            case 1:
                if (b[0] == b[1]).item() and (b[1] == b[2]).item() and (b[0] == id).item():
                    return True
                if (b[1] == b[4]).item() and (b[4] == b[7]).item() and (b[1] == id).item():
                    return True
            case 2:
                if (b[0] == b[1]).item() and (b[1] == b[2]).item() and (b[0] == id).item():
                    return True
                if (b[2] == b[5]).item() and (b[5] == b[8]).item() and (b[2] == id).item():
                    return True
                if (b[2] == b[4]).item() and (b[4] == b[6]).item() and (b[2] == id).item():
                    return True
            case 3:
                if (b[3] == b[4]).item() and (b[4] == b[5]).item() and (b[3] == id).item():
                    return True
                if (b[0] == b[3]).item() and (b[3] == b[6]).item() and (b[0] == id).item():
                    return True
            case 4:
                if (b[3] == b[4]).item() and (b[4] == b[5]).item() and (b[3] == id).item():
                    return True
                if (b[1] == b[4]).item() and (b[4] == b[7]).item() and (b[1] == id).item():
                    return True
                if (b[0] == b[4]).item() and (b[4] == b[8]).item() and (b[0] == id).item():
                    return True
                if (b[2] == b[4]).item() and (b[4] == b[6]).item() and (b[2] == id).item():
                    return True
            case 5:
                if (b[3] == b[4]).item() and (b[4] == b[5]).item() and (b[3] == id).item():
                    return True
                if (b[2] == b[5]).item() and (b[5] == b[8]).item() and (b[2] == id).item():
                    return True
            case 6:
                if (b[6] == b[7]).item() and (b[7] == b[8]).item() and (b[6] == id).item():
                    return True
                if (b[0] == b[3]).item() and (b[3] == b[6]).item() and (b[0] == id).item():
                    return True
                if (b[2] == b[4]).item() and (b[4] == b[6]).item() and (b[2] == id).item():
                    return True
            case 7:
                if (b[6] == b[7]).item() and (b[7] == b[8]).item() and (b[6] == id).item():
                    return True
                if (b[1] == b[4]).item() and (b[4] == b[7]).item() and (b[1] == id).item():
                    return True
            case 8:
                if (b[6] == b[7]).item() and (b[7] == b[8]).item() and (b[6] == id).item():
                    return True
                if (b[2] == b[5]).item() and (b[5] == b[8]).item() and (b[2] == id).item():
                    return True
                if (b[0] == b[4]).item() and (b[4] == b[8]).item() and (b[0] == id).item():
                    return True            
        return False
    
    def play(self, id, val):
        if self._play(id, val):
            if self._checkWin(id,val):
                return True,"win"
            if self.sum == 9:
                return True,"draw"
            return False,"none"
        return True,"invalid"
    
    def clear(self):
        self.board = torch.zeros(9, dtype=torch.float32)
        self.sum = 0
    
    @property
    def board_state(self):
        return self.board
    


# Imports

In [2]:
import random
import torch
import torch.nn as nn
from torch.distributions import Categorical
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Board setup and test
board = Board()
# board.printBoard()
# board.board_state

# MODEL

In [3]:
# Hyper-parameters
input_size = 9
hidden_sizes = (256,256,128,32)
num_classes = 9
num_epochs = 200000
batch_size = 9
game_chances = 9
temperature = 0.95
learning_rate = 3e-4
gamma = 0.95

In [4]:
# Fully connected neural network with one hidden layer
class NeuralNet(nn.Module):
    def __init__(self, input_size, hidden_sizes, num_classes):
        super(NeuralNet, self).__init__()
        self.l1 = nn.Linear(2*input_size, hidden_sizes[0])
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.l3 = nn.Linear(hidden_sizes[1], hidden_sizes[2])
        self.l4 = nn.Linear(hidden_sizes[2], hidden_sizes[3])
        self.l5 = nn.Linear(hidden_sizes[3], num_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        xin = torch.stack((self.relu(x),self.relu(-x)), dim=1).flatten()
        logits = torch.zeros_like(x)  # Create a tensor of all 0s
        logits = logits.masked_fill(x != 0, float('-inf'))        
        out = self.l1(xin)
        out = self.relu(out)
        out = self.l2(out)
        out = self.relu(out)
        out = self.l3(out)
        out = self.relu(out)
        out = self.l4(out)
        out = self.relu(out)
        out = self.l5(out)
        out = out + logits
        out = self.softmax(out)
        return out

model = NeuralNet(input_size, hidden_sizes, num_classes).to(device)

## Tests

In [5]:
# Test model by checking if it avoids already used positions
board.clear()
print("Empty board:")
board.printBoard()

# Make a few moves manually
board.play(1, 0)  # X in position 0
board.play(-1, 4)  # O in position 4
print("\nBoard after manual moves:")
board.printBoard()

# Test model prediction on this board state
with torch.no_grad():
    prediction = model(board.board_state.unsqueeze(0).to(device))
    print(f"\nModel probabilities for each position:")
    for i in range(9):
        status = "OCCUPIED" if board.board_state[i] != 0 else "FREE"
        print(f"Position {i}: {prediction[0][i].item():.4f} ({status})")
    
    # Check if model assigns near-zero probability to occupied positions
    occupied_positions = (board.board_state != 0).nonzero().flatten()
    free_positions = (board.board_state == 0).nonzero().flatten()
    
    print(f"\nOccupied positions: {occupied_positions.tolist()}")
    print(f"Probabilities for occupied positions: {[prediction[0][i].item() for i in occupied_positions]}")
    print(f"Free positions: {free_positions.tolist()}")
    print(f"Probabilities for free positions: {[prediction[0][i].item() for i in free_positions]}")

Empty board:
tensor([0., 0., 0.])
tensor([0., 0., 0.])
tensor([0., 0., 0.])

Board after manual moves:
tensor([1., 0., 0.])
tensor([ 0., -1.,  0.])
tensor([0., 0., 0.])

Model probabilities for each position:
Position 0: 0.0000 (OCCUPIED)
Position 1: 0.1309 (FREE)
Position 2: 0.1580 (FREE)
Position 3: 0.1479 (FREE)
Position 4: 0.0000 (OCCUPIED)
Position 5: 0.1456 (FREE)
Position 6: 0.1503 (FREE)
Position 7: 0.1412 (FREE)
Position 8: 0.1261 (FREE)

Occupied positions: [0, 4]
Probabilities for occupied positions: [0.0, 0.0]
Free positions: [1, 2, 3, 5, 6, 7, 8]
Probabilities for free positions: [0.1308642029762268, 0.15797899663448334, 0.1479322910308838, 0.1456100344657898, 0.15028022229671478, 0.14118605852127075, 0.12614811956882477]


## Training

In [6]:


# Loss and optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Pre-compute discount cache (outside training loop)
discount_cache = {}
for length in range(1, game_chances + 1):
    discount_cache[length] = torch.pow(gamma, torch.arange(length, device=device, dtype=torch.float32))
cummulative_loss = 0.0
# 3) Training loop
for epoch in range(num_epochs):
    if random.random() < 0.5:
        active_player = 1
    else:
        active_player = 0
    board.clear()
    gamequeue = []
    boards = []
    for i in range(game_chances):
        boards.append(board.board_state.clone())
        y = model(board.board_state.unsqueeze(0).to(device))
        # Create mask for valid moves (free positions on board)
        valid_mask = (board.board_state == 0).float().unsqueeze(0).to(device)
        num_valid = valid_mask.sum()

        # Mix model predictions with uniform over valid moves
        uniform_prob = valid_mask / num_valid
        y = y * temperature + uniform_prob * (1 - temperature)
        m = Categorical(probs=y)
        Y_out = m.sample()
        status,win = board.play(2*(i%2) - 1, Y_out.item())
        if i % 2 == active_player:
            log_prob = m.log_prob(Y_out)
            gamequeue.append((i, log_prob))
        if status:
            break
    
    # Zero gradients before computing loss
    optimizer.zero_grad()
    
    if win == "invalid":
        log_prob = m.log_prob(Y_out)
        loss = -100*log_prob
        print("Invalid move made!")
        print("Board state:\n", boards[-1].reshape(3,3))
        print("All boards in game:")
        for b in boards:
            print(b.reshape(3,3))
        print("Move attempted:", Y_out.item())
        break
    elif len(gamequeue) == 0:
        continue
    else:
        if win == "draw":
            reward_value = 0.1  
        else:
            reward_value = 1 if (i % 2) == active_player else -1

        # Stack all log_probs into a single tensor
        log_probs = torch.stack([lp for _, lp in gamequeue[::-1]])

        # Use cached discounts
        discounts = discount_cache[len(gamequeue)]

        # Vectorized loss calculation
        loss = -(discounts * reward_value * log_probs).sum() / len(gamequeue)
    
    # calculate gradients = backward pass
    loss.backward()

    # update weights
    optimizer.step()
    with torch.no_grad():
        cummulative_loss += loss.item()
    if (epoch+1) % 10 == 0:
        print('epoch ', epoch+1, ': loss = ', loss.item(),'; avg loss = ', cummulative_loss/(epoch+1))

epoch  10 : loss =  6.272278785705566 ; avg loss =  1.3235433101654053
epoch  20 : loss =  -3.8676397800445557 ; avg loss =  -0.33181673288345337
epoch  30 : loss =  -5.170615196228027 ; avg loss =  0.2429970363775889
epoch  40 : loss =  -4.9537787437438965 ; avg loss =  0.28513353019952775
epoch  50 : loss =  -3.6543936729431152 ; avg loss =  -0.05240284562110901
epoch  60 : loss =  0.5492885708808899 ; avg loss =  0.20315142075220743
epoch  70 : loss =  4.912382125854492 ; avg loss =  0.5093419517789568
epoch  80 : loss =  -5.064320087432861 ; avg loss =  0.20259353071451186
epoch  90 : loss =  -5.558191299438477 ; avg loss =  0.0712949640221066
epoch  100 : loss =  6.227092742919922 ; avg loss =  -0.16420884788036347
epoch  110 : loss =  -5.4786376953125 ; avg loss =  -0.21443601792508907
epoch  120 : loss =  6.215312480926514 ; avg loss =  -0.16754782845576605
epoch  130 : loss =  -4.992851257324219 ; avg loss =  -0.3197252259804652
epoch  140 : loss =  -4.905048847198486 ; avg los

In [1]:
import pickle
from google.colab import files

# Assuming 'model' is the object you want to save
filename = 'model.pkl'
with open(filename, 'wb') as f:
    pickle.dump(model, f)

files.download(filename)
import pickle
from google.colab import drive

# Mount Drive (you'll be prompted for authorization once)
drive.mount('/content/drive')

# Save your model
filename = '/content/drive/MyDrive/model.pkl'  # adjust path as needed
with open(filename, 'wb') as f:
    pickle.dump(model, f)

print(f"Model saved to {filename}")


NameError: name 'model' is not defined

## PLAYING

In [12]:
def play_against_model():
    """Interactive function to play tic-tac-toe against the trained model"""
    board.clear()
    print("Let's play Tic-Tac-Toe! You are X (1), AI is O (-1)")
    print("Enter positions 0-8 (top-left to bottom-right)")
    print("Board positions:")
    print("0 | 1 | 2")
    print("3 | 4 | 5") 
    print("6 | 7 | 8")
    print()
    
    human_turn = True
    
    while True:
        board.printBoard()
        
        if human_turn:
            try:
                move = int(input("Your move (0-8): "))
                if board.board_state[move] != 0:
                    print("Position already taken! Try again.")
                    continue
                status, win = board.play(1, move)
            except (ValueError, IndexError):
                print("Invalid input! Enter a number 0-8.")
                continue
        else:
            # AI's turn
            print("AI is thinking...")
            with torch.no_grad():
                prediction = model(board.board_state.unsqueeze(0).to(device))
                valid_moves = (board.board_state == 0).nonzero().flatten()
                
                # Get probabilities for valid moves only
                valid_probs = prediction[0][valid_moves]
                # Choose move with highest probability
                best_move_idx = torch.argmax(valid_probs)
                ai_move = valid_moves[best_move_idx].item()
                
                print(f"AI chooses position {ai_move}")
                status, win = board.play(-1, ai_move)
        
        if status:
            board.printBoard()
            if win == "win":
                winner = "You" if human_turn else "AI"
                print(f"{winner} wins!")
            elif win == "draw":
                print("It's a draw!")
            break
            
        human_turn = not human_turn

# Start the game
play_against_model()

Let's play Tic-Tac-Toe! You are X (1), AI is O (-1)
Enter positions 0-8 (top-left to bottom-right)
Board positions:
0 | 1 | 2
3 | 4 | 5
6 | 7 | 8

tensor([0., 0., 0.])
tensor([0., 0., 0.])
tensor([0., 0., 0.])
tensor([1., 0., 0.])
tensor([0., 0., 0.])
tensor([0., 0., 0.])
AI is thinking...
AI chooses position 4
tensor([1., 0., 0.])
tensor([ 0., -1.,  0.])
tensor([0., 0., 0.])
tensor([1., 0., 1.])
tensor([ 0., -1.,  0.])
tensor([0., 0., 0.])
AI is thinking...
AI chooses position 8
tensor([1., 0., 1.])
tensor([ 0., -1.,  0.])
tensor([ 0.,  0., -1.])
Position already taken! Try again.
tensor([1., 0., 1.])
tensor([ 0., -1.,  0.])
tensor([ 0.,  0., -1.])
tensor([1., 1., 1.])
tensor([ 0., -1.,  0.])
tensor([ 0.,  0., -1.])
You wins!
