In [None]:
import torch
import torch.nn as nn
from torch.optim import AdamW
from torchvision.transforms.functional import to_tensor
import os
from PIL import Image
from tqdm.notebook import trange
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
# import wandb
import random
import albumentations as A
import matplotlib.pyplot as plt

In [None]:
BATCH_SIZE = 256
EPOCHS = 30
# Seed for python, numpy and torch RNGs so weight init, DataLoader shuffling
# and (assuming albumentations draws from the global random/numpy RNGs,
# which is version-dependent) augmentation are repeatable between runs.
# cuDNN kernels may still introduce some GPU nondeterminism.
SEED = 42

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds every CUDA device
print(DEVICE)

In [None]:
class Model(nn.Module):
    """Small all-convolutional classifier for 24x24 RGB inputs.

    Ten unpadded 3x3 conv blocks each trim 2 pixels, shrinking a 24x24 input
    to 4x4 and leaving 174 * 4 * 4 = 2784 features. A bias-free linear layer
    followed by BatchNorm1d maps those features to 3 class scores. The
    forward pass returns log-probabilities, suitable for ``F.nll_loss``.
    """

    def __init__(self):
        super().__init__()
        # Consecutive entries form the (in, out) channels of each conv block.
        # Construction order (blocks first, then head) is kept identical so
        # state_dict keys (seq.0 ... seq.12) and parameter init order match.
        widths = (3, 32, 48, 64, 80, 96, 112, 128, 144, 164, 174)
        trunk = [
            self._block(c_in, c_out, 3)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        head = [
            nn.Flatten(),
            nn.Linear(2784, 3, bias=False),
            nn.BatchNorm1d(3),
        ]
        self.seq = nn.Sequential(*trunk, *head)

    def _block(self, input_dim, output_dim, kernel_size):
        """Conv2d (no bias, since BatchNorm follows) -> BatchNorm2d -> ReLU."""
        layers = [
            nn.Conv2d(input_dim, output_dim, kernel_size, bias=False),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return per-class log-probabilities of shape (batch, 3)."""
        scores = self.seq(x)
        return F.log_softmax(scores, dim=1)

In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """Image/label pairs stored as ``<idx>.jpg`` + ``<idx>.txt`` in one directory.

    Samples are addressed by integer index, so the directory is expected to
    hold ``0.jpg`` ... ``(n-1).jpg``, each with a matching ``.txt`` file that
    contains the integer label. ``n`` is the number of ``.jpg`` files found
    when the dataset is constructed.

    Args:
        dir: Directory containing the image and label files.
        transform: Optional albumentations-style callable, invoked as
            ``transform(image=array)`` and returning a dict with an
            ``'image'`` entry.
    """

    def __init__(self, dir, transform=None):
        self.dir = dir
        self.transform = transform
        # NOTE(review): assumes the .jpg files are numbered 0..n-1 without
        # gaps; a gap makes __getitem__ fail on a missing file.
        self.n = sum(1 for name in os.listdir(dir) if name.endswith('.jpg'))

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        """Return ``(image_tensor, label)`` for sample ``idx``.

        The image is loaded into a numpy array, passed through ``transform``
        when one is set, and converted by ``to_tensor`` into a float CHW
        tensor. The label is the integer parsed from ``<idx>.txt``.
        """
        # Context manager closes the underlying file once pixels are copied.
        with Image.open(os.path.join(self.dir, f"{idx}.jpg")) as img:
            image = np.array(img)
        with open(os.path.join(self.dir, f"{idx}.txt")) as f:
            label = int(f.read())
        if self.transform is not None:
            image = self.transform(image=image)['image']
        # Bug fix: the augmented ``image`` must be returned; the previous code
        # returned to_tensor(img), i.e. the original un-augmented picture.
        # ascontiguousarray guards against negative-stride views, which
        # torch.from_numpy (used inside to_tensor) rejects.
        return to_tensor(np.ascontiguousarray(image)), label

In [None]:
# Training-time augmentation: random flips / rotations / crops plus
# photometric noise, then a resize back to the 24x24 input the model expects.
transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.RandomCrop(22, 22, p=0.4),
    A.Blur(blur_limit=2, p=0.22),
    A.CLAHE(p=0.11, tile_grid_size=(1, 1)),
    A.Sharpen(p=0.18, alpha=0.1, lightness=1.5),
    A.Emboss(p=0.05),
    A.MultiplicativeNoise(p=0.22),
    # A.RandomBrightness / A.RandomContrast were deprecated and later removed
    # from albumentations. They were thin wrappers over RandomBrightnessContrast
    # with the other limit fixed to 0 and a default limit of 0.2.
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0, p=0.4),
    A.RandomBrightnessContrast(brightness_limit=0, contrast_limit=0.2, p=0.4),
    A.ColorJitter(p=1, hue=0.05),
    A.RandomGamma(p=0.08),
    # albumentations expects an OpenCV interpolation flag; Image.NEAREST (0)
    # happens to equal cv2.INTER_NEAREST (0), so this is nearest-neighbour.
    A.Resize(24, 24, p=1.0, interpolation=Image.NEAREST),
])

model = Model().to(DEVICE)
optimizer = AdamW(model.parameters(), lr=0.001)

# TODO(review): absolute local path; consider making this configurable.
train_dataset = CustomDataset(
    dir='/home/robert/Documents/GitHub/OutSmarted/data/ConnectFour/fields',
    transform=transform
)
print(len(train_dataset))
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
def main():
    """Train the global ``model`` on ``train_loader`` for ``EPOCHS`` epochs.

    Optimises NLL loss on the model's log-probabilities with the global
    ``optimizer``, printing the current batch loss every 350 batches
    (including the first batch of every epoch).
    """
    log_every = 350
    model.train()
    for _ in trange(EPOCHS):
        for step, batch in enumerate(train_loader):
            images, targets = (t.to(DEVICE) for t in batch)
            loss = F.nll_loss(model(images), targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if step % log_every == 0:
                print(f"loss example: {loss.item()}")

def show():
    """Interactively inspect samples and the model's prediction for them.

    Repeatedly reads a sample index from stdin. For each index it shows the
    raw image and one random augmentation of it, prints the stored label, and
    prints the class the model predicts for the augmented image. Enter an
    empty line or ``q`` to stop; the model's train/eval mode is restored on
    exit.
    """
    fields_dir = '/home/robert/Documents/GitHub/OutSmarted/data/ConnectFour/fields'
    was_training = model.training
    model.eval()
    try:
        while True:
            raw = input().strip()
            if raw.lower() in ('', 'q'):
                break
            idx = int(raw)

            # Context manager closes the file once pixels are copied out.
            with Image.open(os.path.join(fields_dir, f'{idx}.jpg')) as img:
                original = np.array(img)
            plt.imshow(original)
            plt.show()

            augmented = transform(image=original)['image']
            plt.imshow(augmented)
            plt.show()

            with open(os.path.join(fields_dir, f'{idx}.txt')) as f:
                print(f.read())

            # Inference only: no_grad skips building an autograd graph.
            with torch.no_grad():
                batch = to_tensor(np.ascontiguousarray(augmented)).unsqueeze(0).to(DEVICE)
                output = model(batch)
            print(torch.argmax(output).item())
    finally:
        model.train(was_training)

In [None]:
# Set RESUME = True to continue from the checkpoint files 'connectFourField.pt'
# (model) and 'optimF.pth' (optimizer) instead of starting from fresh weights.
RESUME = False
if RESUME:
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only box.
    model.load_state_dict(torch.load('connectFourField.pt', map_location=DEVICE))
    optimizer.load_state_dict(torch.load('optimF.pth', map_location=DEVICE))
main()

In [None]:
# Persist the trained weights and the optimizer state so training can be
# resumed later; keys are the output file paths.
checkpoint_files = {
    'connectFourField.pt': model.state_dict(),
    'optimF.pth': optimizer.state_dict(),
}
for checkpoint_path, state in checkpoint_files.items():
    torch.save(state, checkpoint_path)