# Solve Tic-Tac-Toe using a Classification Neural Network

## TTT implementation checklist

- Make sure that we have a working implementation of TTT
  - Penalize a player who chooses an invalid move (they lose their turn)
  - Needs to have a fully programmatic interface with the ability to turn off printing of the full game board.
  - Unoccupied spaces need to have value 0.
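
The checklist above can be illustrated with a minimal sketch. It assumes a flat, row-major board and a hypothetical `Board` class with a `move` method; none of these names come from the notebook's `lib.game.Game`, whose interface may differ.

```python
EMPTY = 0  # unoccupied squares hold 0, as the checklist requires


class Board:
    """Hypothetical 3x3 board for illustration; not lib.game.Game."""

    def __init__(self):
        self.cells = [EMPTY] * 9  # flat, row-major
        self.current = 1          # player 1 moves first

    def move(self, position):
        """Try to place the current player's mark at position 0..8.

        Returns True if the move was valid. Either way the turn passes to
        the other player, so an invalid move costs the mover a turn.
        """
        valid = 0 <= position < 9 and self.cells[position] == EMPTY
        if valid:
            self.cells[position] = self.current
        self.current = 2 if self.current == 1 else 1
        return valid


board = Board()
board.move(4)       # player 1 takes the centre
ok = board.move(4)  # player 2 picks an occupied square and loses the turn
print(ok, board.cells, board.current)
# prints: False [0, 0, 0, 0, 1, 0, 0, 0, 0] 1
```

Because `move` returns a plain boolean and prints nothing, a training loop can drive it without any board output.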

## Draft Design

- Classification problem
- Loss function: cross entropy loss
- Optimizer: SGD
- For now, the AI is always "O"

### Neural Network

- 9 inputs: each square on the board
- 9 outputs: the position that the AI should move to 
- 1 hidden layer with 16 neurons
- ReLU activations between the linear layers
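
As a rough sketch of how the draft design and network fit together, the snippet below builds a 9-16-9 network and runs a single cross-entropy/SGD step. The batch of boards and the target squares are made up for illustration, and the encoding (0 = empty, 1 and 2 = players) is an assumption.

```python
import torch
from torch import nn

torch.manual_seed(0)

# 9 board squares in, 16 hidden units, 9 move logits out.
model = nn.Sequential(
    nn.Linear(9, 16),
    nn.ReLU(),
    nn.Linear(16, 9),
)
loss_fn = nn.CrossEntropyLoss()  # expects raw logits and class indices
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Made-up batch: 4 boards and the square index (0..8) played from each one.
X = torch.tensor(
    [
        [0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 2, 0, 0, 0, 0, 0, 0, 0],
        [1, 2, 1, 2, 0, 0, 0, 0, 0],
        [1, 2, 1, 2, 1, 2, 0, 0, 0],
    ],
    dtype=torch.float32,
)
y = torch.tensor([0, 2, 4, 6])  # int64 class indices

model.train()
logits = model(X)  # shape (4, 9): one logit per square
loss = loss_fn(logits, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(logits.shape, loss.item())
```

`nn.CrossEntropyLoss` applies log-softmax internally, so the network's last layer stays linear and emits raw logits.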

### Training Loop

- 100 **generations** to start
- For loop within each generation:
  - Play 100 **matches**; at 5-9 moves per match, that is 500-900 **frames**.
    - the AI should always be player 1 in a match
    - the AI only analyzes its own play
  - Record each frame of a match in an array.
  - For matches that the AI **WINS**, add the match's frames to an array for the generation.
  - The array of winning frames is the X train and test data for the NN.
  - Calculate the argmax of the NN output to get the position to move to
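
The argmax step in the list above can be sketched as follows. The `pick_move` helper and its `mask_occupied` option are hypothetical: masking occupied squares is one possible alternative to penalizing invalid moves, not part of the design described here. The board encoding (0 = empty) follows the checklist.

```python
import torch


def pick_move(logits, board, mask_occupied=True):
    """Return the 0-based square with the highest logit.

    With mask_occupied=True, squares whose board value is non-zero are
    excluded, so the result is a free square whenever one exists.
    """
    if mask_occupied:
        logits = logits.masked_fill(board != 0, float("-inf"))
    return int(logits.argmax())


board = torch.tensor([1, 0, 2, 0, 1, 0, 0, 0, 2], dtype=torch.float32)
logits = torch.tensor([3.0, 0.1, 0.2, 0.5, 2.0, 0.3, 0.9, 0.4, 0.0])

raw = pick_move(logits, board, mask_occupied=False)  # 0, already occupied
legal = pick_move(logits, board)                     # 6, best free square
row, col = divmod(legal, 3)
print(raw, legal, (row, col))
# prints: 0 6 (2, 0)
```

A plain argmax can land on an occupied square, which under the checklist's rule costs the AI its turn.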




In [77]:
import torch
from torch import nn
from sklearn.model_selection import train_test_split
import torchmetrics
import matplotlib.pyplot as plt
import random
import sys
sys.path.insert(0, '/home/malone/Code/pytorch-bourke')
from helper_functions import plot_decision_boundary

from lib.game import Game
import lib.state

RANDOM_SEED = 1
NUM_GENERATION = 10
NUM_MATCHES = 100
AI_PLAYER_ID = 1
NUM_EXPLORE = 5

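def aiMove(game):
    # Flatten the board into the 9-feature vector the model expects
    board = torch.tensor(game.state.board, dtype=torch.float32).reshape(9)
    # Inference only: no gradients are needed while choosing a move
    with torch.inference_mode():
        logits = model(board)
    # argmax is 0-based; board keys are 1-9, so shift by one and pass a plain int
    move = int(logits.argmax()) + 1

    return game.keyToXy(move)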

def runMatch(game, X):
    model.eval()
    winner = game.runGame()

    if winner == AI_PLAYER_ID:
        # record the frames of a winning match as the training set;
        # reshape(-1, 9) keeps one row per frame even for a single frame
        frames = torch.tensor(game.frames, dtype=torch.float32).reshape(-1, 9)
        X = torch.cat((X, frames), dim=0)
    return X, winner

class TttModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_stack = nn.Sequential(
            nn.Linear(in_features=9, out_features=16),
            nn.ReLU(),
            nn.Linear(in_features=16, out_features=16),
            nn.ReLU(),
            nn.Linear(in_features=16, out_features=9),
        )
    def forward(self, x):
        return self.linear_stack(x)

random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)


# set up the NN
model = TttModel()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        
# play and training loop
for generation in range(NUM_GENERATION):
    # Play loop
    winCount = 0
    X = torch.Tensor()
    if generation >= NUM_EXPLORE:
        # enough exploration data collected: let the model choose moves
        player1 = aiMove
    else:
        # start out playing randomly until we have some data to work with
        player1 = 'random'
    
    print("Generation", generation, ": player1:", player1)
    game = Game(player1, 'random', hideGameOutput=True, randomSeed=None)
    
    for match in range(NUM_MATCHES):
        X, winner = runMatch(game, X)
        if winner == AI_PLAYER_ID:
            winCount += 1

    print(f"Completed {NUM_MATCHES} matches | AI won {winCount} matches | X size: {X.shape}")
    # print(X[:50])
    
    # Train and testing
    
    # X_train, X_test, y_train, y_test = train_test_split(
    #     X,
    #     y,
    #     test_size=0.2,
    #     random_state=RANDOM_SEED)

    ## Training
    
    
    
    ## Testing



Generation 0 : player1: random
Completed 100 matches | AI won 36 matches | X size: torch.Size([238, 9])
Generation 1 : player1: random
Completed 100 matches | AI won 32 matches | X size: torch.Size([205, 9])
Generation 2 : player1: random
Completed 100 matches | AI won 51 matches | X size: torch.Size([335, 9])
Generation 3 : player1: random
Completed 100 matches | AI won 43 matches | X size: torch.Size([280, 9])
Generation 4 : player1: random
Completed 100 matches | AI won 40 matches | X size: torch.Size([255, 9])
Generation 5 : player1: <function aiMove at 0x7fc24b6fc940>
Completed 100 matches | AI won 0 matches | X size: torch.Size([0])
Generation 6 : player1: <function aiMove at 0x7fc24b6fc940>
Completed 100 matches | AI won 0 matches | X size: torch.Size([0])
Generation 7 : player1: <function aiMove at 0x7fc24b6fc940>
Completed 100 matches | AI won 0 matches | X size: torch.Size([0])
Generation 8 : player1: <function aiMove at 0x7fc24b6fc940>
Completed 100 matches | AI won 0 matche