# Solve Tic-Tac-Toe using a Classification Neural Network

## TTT implementation checklist

- Make sure that we have a working implementation of TTT:
  - Penalize a player who chooses an invalid move (they lose their turn).
  - Provide a fully programmatic interface, with the option to turn off printing of the full game board.
  - Unoccupied spaces need to have value 0.
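The checklist items can be sketched in pure Python as follows. This assumes a flat, row-major list of 9 cells where 0 is empty and 1/2 are the players' marks. `new_board`, `apply_move` and `next_player` are hypothetical helpers, not part of `lib.game`. An invalid move leaves the board unchanged, and because the turn always passes to the other player, the offender simply loses their turn.

```python
EMPTY = 0  # unoccupied squares are 0 so they map cleanly onto network inputs

def new_board():
    """Return an empty 3x3 board flattened row-major into 9 cells."""
    return [EMPTY] * 9

def apply_move(board, square, player):
    """Place `player`'s mark on `square` (0-8) if the move is legal.

    Returns True if the move was applied. An off-board or occupied square
    leaves the board unchanged and returns False.
    """
    if not 0 <= square < 9 or board[square] != EMPTY:
        return False
    board[square] = player
    return True

def next_player(player):
    """Alternate between players 1 and 2."""
    return 2 if player == 1 else 1

# Usage: player 2 tries to take the center that player 1 already holds.
board = new_board()
player = 1
apply_move(board, 4, player)       # legal: player 1 takes the center
player = next_player(player)
ok = apply_move(board, 4, player)  # illegal: square 4 is occupied
player = next_player(player)        # the turn passes back to player 1 regardless
```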

## Draft Design

- Classification problem
- Loss function: cross entropy loss
- Optimizer: SGD
- For now, the AI is always "O".
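Treating move selection as 9-way classification means cross-entropy simply scores how much probability the network gives to the square that was actually played. The sketch below runs one SGD step on a single made-up example. A plain `nn.Linear(9, 9)` stands in for the real network, and the board, target square and learning rate are illustrative assumptions.

```python
import torch
from torch import nn

torch.manual_seed(1)

model = nn.Linear(9, 9)                  # stand-in for the real network
loss_fn = nn.CrossEntropyLoss()          # takes raw logits + class indices
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Hypothetical example: a board (1 = AI, 2 = opponent, 0 = empty)
# and the square index (0-8) the AI played from it.
board = torch.tensor([[1., 2., 0., 0., 1., 0., 0., 0., 2.]])
target = torch.tensor([3])

logits = model(board)
loss_before = loss_fn(logits, target)

# Cross-entropy is the negative log-probability of the target square.
manual = -torch.log_softmax(logits, dim=1)[0, target.item()]
assert torch.isclose(loss_before, manual)

# One SGD step moves the weights to make square 3 more likely.
optimizer.zero_grad()
loss_before.backward()
optimizer.step()

with torch.no_grad():
    loss_after = loss_fn(model(board), target)
print(loss_before.item(), loss_after.item())
```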

### Neural Network

- 9 inputs: one per square on the board
- 9 outputs: one score per square; the highest-scoring square is where the AI should move
- 1 hidden layer with 16 neurons (`TttModel` below currently stacks two)
- ReLU activations between layers
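The network exactly as drafted (one hidden layer of 16 ReLU units) can be sketched as below. The class name `OneHiddenLayerTtt` is hypothetical. Mapping the argmax index back to `(row, col)` with `divmod` assumes the board is flattened row-major, which the draft does not state. Note that the argmax can land on an occupied square; the checklist's lose-your-turn rule covers that case.

```python
import torch
from torch import nn

class OneHiddenLayerTtt(nn.Module):
    """9 board squares in -> 16 hidden ReLU units -> 9 move logits out."""
    def __init__(self, hidden=16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(9, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 9),
        )

    def forward(self, x):
        return self.net(x)

model = OneHiddenLayerTtt()
# (9*16 + 16) weights/biases into the hidden layer + (16*9 + 9) out of it
n_params = sum(p.numel() for p in model.parameters())

boards = torch.zeros(4, 9)         # a batch of 4 empty boards
logits = model(boards)             # shape (4, 9): one logit per square
squares = logits.argmax(dim=1)     # chosen square index 0-8 for each board
moves = [divmod(int(s), 3) for s in squares]  # (row, col), assuming row-major
```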

### Training Loop

- 100 **generations** to start
- Within each generation:
  - Play 100 **matches**. Each match takes 5-9 moves, so a generation should produce roughly 500-900 **frames**, which is the number that matters.
    - The AI should always be player 1.
    - The AI only analyzes its own play.
  - Record each frame of a match in an array.
  - For matches the AI **WINS**, append those frames to the generation's array.
  - The array of winning frames provides the `X` train and test data for the NN.
  - Take the argmax of the network's output to get the position to move to.
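The draft leaves the labels implicit. One reading is that each winning frame is paired with the square the AI chose from it, giving `(X, y)` pairs for classification. The pure-Python sketch below follows the loop above under that assumption. It does not use `lib.game`; `winner`, `play_random_match` and `run_generation` are hypothetical stand-ins, and both sides move at random.

```python
import random

AI_PLAYER_ID = 1  # the AI always moves first, as in the draft
LINES = [(0, 1, 2), (3, 4, 5), (6, 7, 8),   # rows
         (0, 3, 6), (1, 4, 7), (2, 5, 8),   # columns
         (0, 4, 8), (2, 4, 6)]              # diagonals

def winner(board):
    """Return the id (1 or 2) of a player with three in a row, else 0."""
    for a, b, c in LINES:
        if board[a] != 0 and board[a] == board[b] == board[c]:
            return board[a]
    return 0

def play_random_match(rng):
    """Play one match where both sides pick random empty squares.

    Returns (winner_id, frames, moves): frames[i] is a copy of the board just
    before the AI's i-th move and moves[i] is the square (0-8) it chose.
    winner_id is 0 for a draw.
    """
    board = [0] * 9
    frames, moves = [], []
    player = AI_PLAYER_ID
    while True:
        empty = [i for i, v in enumerate(board) if v == 0]
        if not empty:
            return 0, frames, moves
        square = rng.choice(empty)
        if player == AI_PLAYER_ID:
            frames.append(board.copy())
            moves.append(square)
        board[square] = player
        if winner(board):
            return player, frames, moves
        player = 2 if player == 1 else 1

def run_generation(num_matches, seed=0):
    """Play `num_matches` matches; keep (frame, move) pairs from AI wins only."""
    rng = random.Random(seed)
    X, y = [], []
    for _ in range(num_matches):
        won_by, frames, moves = play_random_match(rng)
        if won_by == AI_PLAYER_ID:
            X.extend(frames)
            y.extend(moves)
    return X, y

X, y = run_generation(100)
print(len(X), "training frames from AI wins")
```

Once a model is trained, the AI's `rng.choice(empty)` would be replaced by the argmax of the network's output for the current board.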




```python
import torch
from torch import nn
from sklearn.model_selection import train_test_split
import torchmetrics
import matplotlib.pyplot as plt
import sys
sys.path.insert(0, '/home/malone/Code/pytorch-bourke')
from helper_functions import plot_decision_boundary

from lib.game import Game
import lib.state

RANDOM_SEED = 1
NUM_GENERATION = 1
NUM_MATCHES = 100
AI_PLAYER_ID = 1

def aiMove(game):
    # Placeholder strategy until the model is wired in: always returns (1, 1).
    return (1,1)

# Set up the game; hideGameOutput suppresses the board printout.
game = Game(aiMove, 'random', hideGameOutput=True, randomSeed=RANDOM_SEED)

# Winning frames collected so far, one flattened 9-square board per row.
X = torch.empty((0, 9))

def runMatch(game, X):
    """Play one match; if the AI wins, append that match's frames to X."""
    winner = game.runGame()
    # print("winner is player", winner)

    if winner == AI_PLAYER_ID:
        # Record the frames as the training set. Keep them as float32 so they
        # can be fed straight to the network. No .squeeze(): it would turn a
        # single-frame match into a 1-D tensor that torch.cat cannot append.
        frames = torch.tensor(game.frames, dtype=torch.float32)
        frames = frames.reshape(frames.shape[0], 9)
        X = torch.cat((X, frames), dim=0)
    return X

class TttModel(nn.Module):
    """Maps a flattened board (9 values) to 9 move logits."""
    def __init__(self):
        super().__init__()
        self.linear_stack = nn.Sequential(
            nn.Linear(in_features=9, out_features=16),
            nn.ReLU(),
            nn.Linear(in_features=16, out_features=16),
            nn.ReLU(),
            nn.Linear(in_features=16, out_features=9),
        )

    def forward(self, x):
        return self.linear_stack(x)


# Training loop (so far it only collects the AI's winning frames into X)
for generation in range(NUM_GENERATION):
    for match in range(NUM_MATCHES):
        X = runMatch(game, X)
print(X)
```

```text
tensor([[0., 1., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [1., 1., 0.,  ..., 0., 0., 0.],
        ...,
        [2., 1., 0.,  ..., 0., 0., 0.],
        [2., 1., 0.,  ..., 0., 0., 0.],
        [2., 1., 0.,  ..., 0., 0., 2.]])
```
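The cell stops once `X` is collected and the model is never trained. One way to continue is sketched below using the same architecture as `TttModel`, cross-entropy and SGD. The data is a placeholder: random boards labeled with their first empty square, chosen only so the loop has a learnable target. These are not the AI's real moves. The 80/20 split uses `torch.randperm`; the notebook's imported `train_test_split` would work equally well.

```python
import torch
from torch import nn

torch.manual_seed(1)

# Placeholder data: random boards (0 = empty, 1 = AI, 2 = opponent) with a
# stand-in label (the first empty square) so training has something to learn.
boards = torch.randint(0, 3, (300, 9)).float()
boards = boards[(boards == 0).any(dim=1)]       # keep boards with a free square
labels = (boards == 0).float().argmax(dim=1)    # index of first empty square

# 80/20 train/test split via a random permutation.
perm = torch.randperm(len(boards))
n_train = int(0.8 * len(boards))
train_idx, test_idx = perm[:n_train], perm[n_train:]
X_train, y_train = boards[train_idx], labels[train_idx]
X_test, y_test = boards[test_idx], labels[test_idx]

# Same layer sizes as TttModel: 9 -> 16 -> 16 -> 9.
model = nn.Sequential(
    nn.Linear(9, 16), nn.ReLU(),
    nn.Linear(16, 16), nn.ReLU(),
    nn.Linear(16, 9),
)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

losses = []
model.train()
for epoch in range(200):                        # full-batch gradient steps
    loss = loss_fn(model(X_train), y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

model.eval()
with torch.inference_mode():
    test_pred = model(X_test).argmax(dim=1)     # predicted square per board
    test_acc = (test_pred == y_test).float().mean().item()
print(f"train loss {losses[0]:.3f} -> {losses[-1]:.3f}, test acc {test_acc:.2f}")
```

With real data, `labels` would come from the square the AI actually chose in each winning frame rather than from this stand-in rule.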
