In [None]:
# Mount Google Drive: the dataset ('drive/MyDrive/fields/fields') and the
# checkpoints ('drive/MyDrive/*.pth') used below are read from / written to it.
# NOTE(review): those relative paths assume the working directory is /content
# (the Colab default) — confirm if running elsewhere.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Optional environment setup (kept commented out; uncomment when the runtime
# lacks these packages or needs a newer albumentations).
# !pip install --upgrade albumentations
# !pip install opencv-python-headless==4.1.2.30
# !pip install wandb
# Authenticate with Weights & Biases so wandb.init() in main() can log runs.
# NOTE(review): this is an interactive login; it presumably prompts for an API key.
!wandb login

In [None]:
import os
import random

import albumentations as A
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from PIL import Image
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms.functional import to_tensor

In [None]:
# Training hyper-parameters.
batch_size = 32
num_epochs = 400

# Seed the Python, NumPy and PyTorch RNGs (weight init, DataLoader shuffling,
# and the augmentation RNGs) so runs are repeatable. GPU kernels may still
# introduce some nondeterminism.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # also seeds all CUDA devices

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

In [None]:
class TicTacToeDataset(Dataset):
    """Single-field tic-tac-toe images read from a flat directory.

    The label is encoded in each file name as its second dot-separated
    component, e.g. ``"0001.-1.png"`` -> raw label ``-1``. Raw labels are
    shifted by +1 so they can be used directly as class indices
    (presumably raw labels are -1, 0 and 1 for the 3-class model — TODO confirm).

    Args:
        img_dir: Directory containing the image files.
        transform: Optional albumentations-style callable, invoked as
            ``transform(image=array)['image']``.
    """

    def __init__(self, img_dir, transform=None):
        self.img_dir = img_dir
        # Only regular files are samples: Drive/Colab can leave directories such
        # as .ipynb_checkpoints behind, which would break label parsing/opening.
        # Sorting makes the index -> file mapping stable, since os.listdir
        # order is arbitrary.
        self.imgs = sorted(
            name for name in os.listdir(img_dir)
            if os.path.isfile(os.path.join(img_dir, name))
        )
        self.n = len(self.imgs)
        self.transform = transform

    def __len__(self):
        return self.n

    @staticmethod
    def _label_from_name(file_name):
        """Return the class index encoded in ``file_name`` (raw label + 1)."""
        return int(file_name.split('.')[1]) + 1

    def __getitem__(self, idx):
        """Return ``(image_tensor, class_index)`` for sample ``idx``."""
        file_name = self.imgs[idx]
        label = self._label_from_name(file_name)
        # np.array() forces PIL to decode the pixels, so the file handle can be
        # released right away instead of whenever the Image is garbage-collected.
        with Image.open(os.path.join(self.img_dir, file_name)) as pil_image:
            image = np.array(pil_image)
        if self.transform:
            image = self.transform(image=image)['image']
        # to_tensor: HxW / HxWxC array -> CxHxW tensor (uint8 scaled to [0, 1]).
        image = to_tensor(image)
        return image, label

In [None]:
class Flatten(nn.Module):
    """Flatten (N, C, H, W) into (N, H*W*C) with the channel index varying fastest.

    This channels-last order (unlike ``nn.Flatten``) fixes how the following
    Linear layer's weights are laid out; changing it would silently scramble
    the meaning of weights in existing checkpoints.
    """

    def forward(self, x):
        channels_last = x.permute(0, 2, 3, 1)
        return channels_last.flatten(start_dim=1)


class Model(nn.Module):
    """Ten unpadded 3x3 conv blocks followed by a linear 3-class head.

    ``forward`` returns log-probabilities of shape (N, 3), suitable for
    ``F.nll_loss``.
    """

    def __init__(self):
        super().__init__()
        widths = [1, *range(32, 177, 16)]  # 1, 32, 48, ..., 176
        conv_blocks = [
            self._block(c_in, c_out, 3)
            for c_in, c_out in zip(widths, widths[1:])
        ]
        # Every unpadded 3x3 conv shrinks H and W by 2, so ten of them take a
        # 28x28 input (implied by this size) down to 8x8: 176 * 8 * 8 = 11264.
        self.seq = nn.Sequential(
            *conv_blocks,
            Flatten(),
            nn.Linear(176 * 8 * 8, 3, bias=False),
            nn.BatchNorm1d(3),
        )

    def _block(self, input_dim, output_dim, kernel_size):
        """Conv (no bias, no padding) -> BatchNorm2d -> ReLU."""
        layers = [
            nn.Conv2d(input_dim, output_dim, kernel_size, bias=False),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        logits = self.seq(x)
        return logits.log_softmax(dim=1)


In [None]:
model = Model().to(device)
# 1e-3 is also AdamW's default; it is named here so the logged config matches.
learning_rate = 1e-3
optimizer = AdamW(model.parameters(), lr=learning_rate)

# Training-time augmentation. Each augmentation fires independently with its
# own probability p; the final Resize (p=1.0) always runs, restoring the 28x28
# size that RandomCrop may have changed.
transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.RandomCrop(26, 26, p=0.4),
    A.InvertImg(p=0.18),
    A.ColorJitter(brightness=0.55, contrast=0.6, saturation=0.6, hue=0.6, p=0.4),
    A.Blur(blur_limit=2, p=0.22),
    A.CLAHE(p=0.18, tile_grid_size=(2, 2)),
    A.Sharpen(p=0.18, alpha=0.1, lightness=1.5),
    A.Emboss(p=0.18),
    A.MultiplicativeNoise(p=0.22),
    # Previously A.RandomBrightness(p=0.4) and A.RandomContrast(p=0.4). Those
    # deprecated wrappers (default limit 0.2) are written out directly here.
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0, p=0.4),
    A.RandomBrightnessContrast(brightness_limit=0, contrast_limit=0.2, p=0.4),
    A.RandomGamma(p=0.2),
    # albumentations expects an OpenCV interpolation flag; PIL's NEAREST (0)
    # has the same value as cv2.INTER_NEAREST.
    A.Resize(28, 28, p=1.0, interpolation=Image.NEAREST),
])

dataset = TicTacToeDataset('drive/MyDrive/fields/fields', transform=transform)
# The model ends in BatchNorm1d, which raises in training mode when a batch
# holds a single sample. Drop the final batch only when it would be that size.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=len(dataset) % batch_size == 1,
)

# Values come from the actual hyper-parameters instead of duplicated literals.
# NOTE(review): wandb.init() rebinds wandb.config to the run's own config, so
# this dict is only recorded if it is also passed as wandb.init(config=...).
wandb.config = {
    "learning_rate": learning_rate,
    "epochs": num_epochs,
    "batch_size": batch_size,
}

In [None]:
def main():
    """Train the notebook-level ``model`` for ``num_epochs`` epochs.

    Reads the globals ``model``, ``optimizer``, ``dataloader``, ``device``,
    ``num_epochs`` and ``batch_size``. It logs the per-step NLL loss, the epoch
    and the running count of examples seen to Weights & Biases. The wandb run is
    finished even if training raises or is interrupted.
    """
    # Passed to wandb.init so the run actually records its hyper-parameters.
    run_config = {
        "learning_rate": optimizer.param_groups[0]["lr"],
        "epochs": num_epochs,
        "batch_size": batch_size,
    }
    wandb.init(project="TicTacToeClassification", entity="robertfoerster",
               config=run_config)
    try:
        # Make sure BatchNorm layers use batch statistics while training.
        model.train()
        example_ct = 0
        for epoch in range(num_epochs):
            for images, labels in dataloader:
                images, labels = images.to(device), labels.to(device)
                output = model(images)  # log-probabilities
                loss = F.nll_loss(output, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                example_ct += len(images)
                wandb.log({
                    # .item() logs a plain float and drops the autograd graph.
                    "train/train_loss": loss.item(),
                    "train/epoch": epoch,
                    "train/example_ct": example_ct,
                })
    finally:
        wandb.finish()

In [None]:
# Resume from earlier checkpoints when they exist, so a fresh "Run All" also
# works on the very first run. map_location=device lets a checkpoint saved on
# a GPU runtime load on a CPU-only one.
MODEL_CKPT_PATH = 'drive/MyDrive/tictactoeField.pth'
OPTIM_CKPT_PATH = 'drive/MyDrive/optim.pth'

if os.path.exists(MODEL_CKPT_PATH):
    model.load_state_dict(torch.load(MODEL_CKPT_PATH, map_location=device))
else:
    print(f'No model checkpoint at {MODEL_CKPT_PATH}; starting from fresh weights.')

if os.path.exists(OPTIM_CKPT_PATH):
    optimizer.load_state_dict(torch.load(OPTIM_CKPT_PATH, map_location=device))
else:
    print(f'No optimizer checkpoint at {OPTIM_CKPT_PATH}; starting with fresh optimizer state.')

In [None]:
# Run the full training loop (logs to Weights & Biases); the next cell saves
# the resulting model and optimizer state to Drive.
main()

In [None]:
# Persist model weights and optimizer state to Drive. These are the same
# paths the checkpoint-loading cell reads, so training can resume later.
for state, path in (
    (model.state_dict(), 'drive/MyDrive/tictactoeField.pth'),
    (optimizer.state_dict(), 'drive/MyDrive/optim.pth'),
):
    torch.save(state, path)