In [1]:
import os
import random
from collections import Counter

import albumentations as A
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.transforms.functional as FT
import wandb
from PIL import Image
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torchvision.transforms.functional import to_tensor
from tqdm import tqdm

In [2]:
BATCH_SIZE = 128
EPOCHS = 1000
SEED = 42

# Seed the global RNGs behind weight init, DataLoader shuffling and (depending
# on the albumentations version) augmentation sampling, to reduce run-to-run
# variation. GPU kernels may still be nondeterministic.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds every CUDA device

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

cuda


In [None]:
# Layer specification consumed by BoardModel, designed for 168x168
# single-channel input. Entry kinds:
#   (out_channels, kernel_size, stride) -> CNNBlock (padding 1 for 3x3, 0 for 1x1)
#   ["B", num_repeats]                  -> ResidualBlock with num_repeats units
#   "S"                                 -> prediction head (ScalePrediction)
# Spatial size after each conv: 168 -> 84 -> 42 -> 21 -> 11 -> 6.
config = [
    (32, 3, 1),   # 168x168
    (64, 3, 2),   # 84x84
    ["B", 1],
    (128, 3, 2),  # 42x42
    ["B", 2],
    (256, 3, 2),  # 21x21
    ["B", 4],
    (256, 3, 2),  # 11x11
    ["B", 4],
    (256, 3, 2),  # 6x6
    # End of the Darknet-53-style backbone. This is a reduced variant: widths
    # are capped at 256 and repeats are 1/2/4/4/2 (Darknet-53 uses 1/2/8/8/4).
    ["B", 2],
    (128, 1, 1),
    (64, 3, 1),
    "S",
]
# Model output: 36 values per image = 9 boxes x 4 coordinates, reshaped to
# (-1, 9, 4) when computing the loss.

class CNNBlock(nn.Module):
    """Convolution with an optional BatchNorm + LeakyReLU(0.1) tail.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        bn_act: when True (default) the block computes ``leaky(bn(conv(x)))``
            and the conv has no bias; when False it is just ``conv(x)`` with bias.
        **kwargs: forwarded unchanged to ``nn.Conv2d`` (kernel_size, stride,
            padding, ...).
    """

    def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
        super().__init__()
        self.use_bn_act = bn_act
        # A bias is redundant right before BatchNorm, so only the plain conv gets one.
        self.conv = nn.Conv2d(in_channels, out_channels, bias=not self.use_bn_act, **kwargs)
        # Always created, even when bn_act is False, so the bn entries appear
        # in the state_dict for both variants.
        self.bn = nn.BatchNorm2d(out_channels)
        self.leaky = nn.LeakyReLU(0.1)

    def forward(self, x):
        features = self.conv(x)
        if not self.use_bn_act:
            return features
        return self.leaky(self.bn(features))

class ResidualBlock(nn.Module):
    """A stack of ``num_repeats`` bottleneck units that preserve shape.

    Each unit is a 1x1 CNNBlock down to ``channels // 2`` followed by a 3x3
    (padding 1) CNNBlock back up to ``channels``. With ``use_residual`` each
    unit's output is added to its input; otherwise the units are simply
    chained without skip connections.
    """

    def __init__(self, channels, use_residual=True, num_repeats=1):
        super().__init__()
        bottleneck = channels // 2
        self.layers = nn.ModuleList(
            nn.Sequential(
                CNNBlock(channels, bottleneck, kernel_size=1),
                CNNBlock(bottleneck, channels, kernel_size=3, padding=1),
            )
            for _ in range(num_repeats)
        )
        self.use_residual = use_residual
        self.num_repeats = num_repeats

    def forward(self, x):
        for unit in self.layers:
            transformed = unit(x)
            x = x + transformed if self.use_residual else transformed
        return x

class ScalePrediction(nn.Module):
    """Prediction head: 3x3 conv, 1x1 bottleneck conv, flatten, linear, sigmoid.

    The incoming feature map must be ``grid_size x grid_size``. With the
    defaults (6x6 grid, 16 bottleneck channels) the linear layer sees
    16 * 6 * 6 = 576 features and emits 36 values in (0, 1).

    Args:
        in_channels: channels of the incoming feature map.
        num_outputs: length of the output vector per sample.
        grid_size: spatial side length of the incoming feature map.
        head_channels: channels of the 1x1 conv before flattening.
    """

    def __init__(self, in_channels, num_outputs=36, grid_size=6, head_channels=16):
        super().__init__()
        # Layer order is unchanged so state_dict keys (pred.0, pred.1, pred.3)
        # stay compatible with previously saved checkpoints.
        self.pred = nn.Sequential(
            CNNBlock(in_channels, in_channels, kernel_size=3, padding=1),
            CNNBlock(in_channels, head_channels, bn_act=False, kernel_size=1),
            nn.Flatten(),
            # Derived from the parameters instead of a hard-coded 576, so the
            # head can follow a different input resolution or bottleneck width.
            nn.Linear(head_channels * grid_size * grid_size, num_outputs),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.pred(x)

class BoardModel(nn.Module):
    """Board detector assembled from a layer config list.

    The layers are applied in order and the result is flattened to
    ``(batch, features)``. With the module-level ``config`` and 168x168 input
    this is 36 values per image.

    Args:
        in_channels: number of input image channels (1 for grayscale).
        layer_config: optional list of layer specs; ``None`` uses the
            module-level ``config``. Entries are ``(out_channels, kernel_size,
            stride)`` tuples (CNNBlock), ``["B", num_repeats]`` lists
            (ResidualBlock), or a string such as ``"S"`` (prediction head).
    """

    def __init__(self, in_channels=1, layer_config=None):
        super().__init__()
        self.in_channels = in_channels
        self.layers = self._create_conv_layers(layer_config)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return torch.flatten(x, start_dim=1)

    def _create_conv_layers(self, layer_config=None):
        """Translate the layer config into an ``nn.ModuleList``, in order.

        Raises:
            ValueError: if an entry is not a tuple, list or string.
        """
        if layer_config is None:
            layer_config = config
        layers = nn.ModuleList()
        in_channels = self.in_channels

        for module in layer_config:
            if isinstance(module, tuple):
                out_channels, kernel_size, stride = module
                layers.append(
                    CNNBlock(
                        in_channels,
                        out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        # Padding 1 keeps 3x3 convs size-preserving at stride 1;
                        # 1x1 convs need none.
                        padding=1 if kernel_size >= 2 else 0,
                    )
                )
                in_channels = out_channels

            elif isinstance(module, list):
                num_repeats = module[1]
                layers.append(ResidualBlock(in_channels, num_repeats=num_repeats))

            elif isinstance(module, str):
                # Head: non-residual bottleneck unit, 1x1 conv halving the
                # channels, then the fully connected prediction layer.
                layers += [
                    ResidualBlock(in_channels, use_residual=False, num_repeats=1),
                    CNNBlock(in_channels, in_channels // 2, kernel_size=1),
                    ScalePrediction(in_channels // 2),
                ]

            else:
                # Previously such entries were silently skipped, producing a
                # network that did not match the config.
                raise ValueError(f"Unsupported layer config entry: {module!r}")
        return layers

In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """Grayscale board images paired with bounding-box label files.

    Sample ``i`` is read from ``{dir}/{i}.jpg`` and ``{dir}/{i}.txt``, so the
    files are expected to be numbered ``0 .. n-1``. Each non-blank label line
    is ``class x y width height``. Boxes are sorted by (class, x, y, w, h) so
    their order is deterministic, and the class column is then dropped.

    Args:
        dir: directory containing the image/label pairs.
        transform: optional albumentations-style callable, invoked as
            ``transform(image=ndarray)['image']``.
    """

    def __init__(self, dir, transform=None):
        self.dir = dir
        self.transform = transform
        self.n = sum(1 for name in os.listdir(dir) if name.endswith('.jpg'))

    def __len__(self):
        return self.n

    @staticmethod
    def _load_boxes(label_path):
        """Parse a label file into a float32 tensor of shape ``(num_boxes, 4)``.

        Raises:
            ValueError: if a non-blank line does not contain exactly 5 numbers.
        """
        rows = []
        with open(label_path) as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank / trailing lines
                # Parse every field as float: int("1.0") raises, and an
                # all-integer file would otherwise yield an int64 tensor that
                # MSELoss cannot compare against float predictions.
                class_label, x, y, width, height = map(float, fields)
                rows.append([class_label, x, y, width, height])
        rows.sort()
        return torch.tensor([row[1:] for row in rows], dtype=torch.float32)

    def __getitem__(self, index):
        boxes = self._load_boxes(os.path.join(self.dir, f"{index}.txt"))

        img_path = os.path.join(self.dir, f"{index}.jpg")
        # The context manager closes the file; convert() returns a loaded copy.
        with Image.open(img_path) as img:
            image = img.convert('L')
        if self.transform:
            image = self.transform(image=np.array(image))['image']
        image = to_tensor(image)

        return image, boxes

In [None]:
model = BoardModel().to(DEVICE)
optimizer = AdamW(
    model.parameters()
)
# Summed (not batch-averaged) squared error over all predicted box coordinates.
criterion = torch.nn.MSELoss(reduction='sum')

# Pixel-level augmentations only (no flips, crops or resizes), so the box
# labels stay valid without being transformed alongside the image.
transform = A.Compose([
    A.InvertImg(p=0.1),
    A.ColorJitter(brightness=0.55, contrast=0.6, saturation=0.6, hue=0.6, p=0.4),
    A.GaussNoise(p=0.12),
    A.Blur(blur_limit=3, p=0.22),
    A.GlassBlur(max_delta=1, iterations=1, p=0.14),
    A.CLAHE(p=0.22, tile_grid_size=(4, 4)),
    A.Sharpen(p=0.18, alpha=0.2, lightness=1.5),
    A.Emboss(p=0.18),
    A.Equalize(p=0.04),
    A.MultiplicativeNoise(p=0.22),
    # RandomBrightness / RandomContrast were deprecated aliases of
    # RandomBrightnessContrast (default limit 0.2) and are gone from newer
    # albumentations releases; these two lines are their direct equivalents.
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0, p=0.22),
    A.RandomBrightnessContrast(brightness_limit=0, contrast_limit=0.2, p=0.22),
    A.RandomGamma(p=0.22),
    A.Solarize(threshold=128, p=0.2),
])

train_dataset = CustomDataset(
    "drive/MyDrive/boardsTTT",
    transform=transform
)
print(len(train_dataset))

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [None]:
def main():
    """Train ``model`` on ``train_loader`` for ``EPOCHS`` epochs, logging to W&B.

    Model output is reshaped to ``(batch, 9, 4)`` and compared against the box
    targets with ``criterion``. Loss and epoch are logged every batch. The W&B
    run is finished even if training raises or is interrupted.
    """
    wandb.init(project="TicTacToeYOLO", entity="robertfoerster")
    try:
        model.train()
        for epoch in range(EPOCHS):
            for x, y in train_loader:
                x, y = x.to(DEVICE), y.to(DEVICE)
                out = model(x).reshape(-1, 9, 4)
                loss = criterion(out, y)

                # Log a plain float rather than the graph-carrying tensor.
                wandb.log({
                    'train/train_loss': loss.item(),
                    'train/epoch': epoch,
                })

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
    finally:
        # Without this, a KeyboardInterrupt (common in a 1000-epoch notebook
        # run) would leave the W&B run open.
        wandb.finish()

@torch.no_grad()
def show(num_images=9):
    """Plot predicted boxes for the first ``num_images`` images of one batch.

    Draws one batch from ``train_loader`` and runs ``model`` on it in eval
    mode. The model is left in eval mode.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.eval()
    batch = next(iter(train_loader), None)
    if batch is None:
        raise ValueError("train_loader is empty; nothing to show")
    x, _ = batch
    x = x.to(DEVICE)
    out = model(x).reshape(-1, 9, 4)
    out, x = out.cpu().numpy(), x.cpu().numpy()
    # A batch can hold fewer images than requested (e.g. a small dataset).
    for idx in range(min(num_images, len(x))):
        plot_image(x[idx], out[idx])

def plot_image(image, boxes, image_shape=None):
    """Show a grayscale image with boxes drawn as red rectangles.

    Args:
        image: array-like grayscale image, e.g. ``(H, W)`` or ``(1, H, W)``.
        boxes: iterable of ``(x_center, y_center, width, height)`` given as
            fractions of the image width/height.
        image_shape: optional ``(H, W)`` to reshape ``image`` to. By default
            the last two axes of ``image`` are used, or 168x168 for flat input.
    """
    im = np.asarray(image)
    if image_shape is None:
        image_shape = im.shape[-2:] if im.ndim >= 2 else (168, 168)
    im = im.reshape(image_shape)
    height, width = im.shape

    fig, ax = plt.subplots(1)
    ax.imshow(im, cmap='gray')

    for box in boxes:
        # Convert center-based fractional coordinates to a top-left pixel corner.
        upper_left_x = box[0] - box[2] / 2
        upper_left_y = box[1] - box[3] / 2
        rect = patches.Rectangle(
            (upper_left_x * width, upper_left_y * height),
            box[2] * width,
            box[3] * height,
            linewidth=1,
            edgecolor="r",
            facecolor="none",
        )
        ax.add_patch(rect)
    plt.show()
    # Release the figure; callers invoke this once per image in a loop.
    plt.close(fig)
            

In [None]:
# Long-running: trains for EPOCHS epochs, logging every batch to Weights & Biases.
main()

In [None]:
# Save the trained model weights and the optimizer state as two separate files.
torch.save(model.state_dict(), "drive/MyDrive/ttttestM.pth")
torch.save(optimizer.state_dict(), "drive/MyDrive/ttttestO.pth")

In [None]:
# Visualize predicted boxes on one batch from train_loader (puts model in eval mode).
show()