# Setup environment

In [None]:
!pip3 install hiddenlayer > /dev/null
!pip3 install ipdb > /dev/null
import json
import os
import numpy as np
import random

import torch as tr
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

from torchvision.datasets.folder import default_loader
from torchvision.transforms import functional as ft
from torchvision import transforms
from PIL import Image, ImageDraw, ImageFilter

from sklearn.model_selection import train_test_split
from sklearn.metrics import balanced_accuracy_score

import matplotlib

import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import hiddenlayer as hl

from google.colab import drive

from tqdm.notebook import tqdm

# Seed the NumPy and PyTorch global generators with the same fixed value.
np.random.seed(42)
tr.manual_seed(42)
# Let cuDNN auto-tune its kernels; deterministic kernel selection is not
# requested here.
tr.backends.cudnn.benchmark = True
tr.backends.cudnn.deterministic = False

In [None]:
# Mount Google Drive under ./drive, relative to the current working directory.
drive.mount('./drive')

In [None]:
# Work from the training folder so that the relative paths used below
# ('metadata.json', 'images/...', 'models/...') resolve inside it.
# NOTE(review): this absolute path only matches the './drive' mount above when
# the notebook was started in /content - confirm.
os.chdir("/content/drive/My Drive/Workspace/pic2sgf/train")

# Board extractor

In [None]:
class BoardExtractor():
    """Warp the quadrilateral spanned by four image points onto a square image.

    The output image is ``board_size * 16`` pixels on each side and its four
    reference points sit 8 pixels in from the corners of that square.
    """

    def __init__(self, board_size):
        """Precompute the output-square half of the perspective transform.

        Args:
            board_size: number of lines on the board (e.g. 9, 13 or 19).
        """
        self.size = board_size * 16
        near, far = 8, self.size - 8
        # Homogeneous coordinates of the four reference points, one per column.
        reference = np.array([[near, near, 1],
                              [far, near, 1],
                              [far, far, 1],
                              [near, far, 1]]).T
        first_three, fourth = reference[:, 0:3], reference[:, 3]
        # Scale the first three columns so that they add up to the fourth one,
        # then invert the resulting 3x3 matrix.
        scaled = first_three * np.linalg.solve(first_three, fourth)
        self.T2 = np.linalg.inv(scaled)

    def __call__(self, img, vertexs):
        """Extract the board from ``img``.

        Args:
            img: PIL image to sample from.
            vertexs: array of shape (4, 2) with the (x, y) pixel coordinates
                of the four board corners.

        Returns:
            A tuple ``(board, tilt)``: the warped image rotated by 180 degrees
            and the first two entries of the last row of the normalised 3x3
            transform.
        """
        homogeneous = np.concatenate([vertexs.T, np.ones((1, 4))], axis=0)
        first_three, fourth = homogeneous[:, 0:3], homogeneous[:, 3]
        T1 = first_three * np.linalg.solve(first_three, fourth)
        T = T1 @ self.T2
        # Normalise so that the bottom-right entry equals one.
        T /= T[2, 2]
        output_shape = (self.size, self.size)
        board = img.transform(output_shape,
                              method=Image.PERSPECTIVE,
                              data=T.reshape(-1),
                              resample=Image.BILINEAR)
        rotated = board.transpose(Image.ROTATE_180)
        return rotated, T[2, 0:2]

# Dataset

## Class

In [None]:
class BoardPositionDataset(Dataset):
    """Board photos of one board size with per-intersection labels.

    Every metadata entry whose ``size`` matches is loaded eagerly: the photo
    is resized, its corner annotations (given as percentages of the image
    width/height) are converted to pixels, and its position labels are stored
    as a ``(size, size)`` LongTensor shifted by +1.

    Each item is a tuple ``(image, position, displacement)``:
      * ``image``: the board warped by ``BoardExtractor``, as a float tensor;
      * ``position``: the label tensor described above;
      * ``displacement``: a ``(2, size, size)`` map of the random corner
        jitter (x in channel 0, y in channel 1), bilinearly interpolated from
        the four corners. It is all zeros when ``augment_images`` is False.
    """

    def __init__(self, metadata_file, size):
        """Load every entry of ``metadata_file`` whose board size is ``size``.

        Args:
            metadata_file: path to a JSON list of entries with the keys
                'size', 'filename', 'corners' and 'positions'.
            size: board size (number of lines) to keep.
        """
        super(BoardPositionDataset, self).__init__()
        self.board_extractor = BoardExtractor(size)
        with open(metadata_file, 'r') as f:
            metadata = json.load(f)
        # Toggles both the photometric/flip augmentation and the corner jitter.
        self.augment_images = True
        # Standard deviation, in pixels, of the Gaussian corner jitter.
        self.displacement_scale = 4
        self.images = []
        self.position = []
        self.corners = []
        self.filename = []
        # Spreads the 2x2 corner displacements over the whole size x size grid.
        self.upsampler = nn.Upsample(size, mode='bilinear', align_corners=True)
        for entry in tqdm(metadata):
            if entry['size'] != size:
                continue

            img = self.load_image("images/" + entry['filename'])
            # Force a float array: with integer JSON values the pixel
            # conversion below would be truncated and the in-place jitter in
            # __getitem__ would fail to cast.
            corner = np.array(entry['corners'], dtype=float)
            corner[:, 0] = corner[:, 0] / 100 * img.size[0]
            corner[:, 1] = corner[:, 1] / 100 * img.size[1]

            self.filename.append(entry['filename'])
            self.corners.append(corner)
            self.images.append(img)
            # NOTE(review): the transpose + flip presumably aligns the label
            # grid with the rotated board image - confirm against the data.
            labels = tr.LongTensor(entry['positions']).reshape(size, size)
            self.position.append(labels.permute(1, 0).flip(1) + 1)

    def load_image(self, filename):
        """Load an image and resize it to 512x384, 384x512 or 384x384.

        The target size is chosen from the orientation of the original image
        (landscape, portrait or square respectively).
        """
        img = default_loader(filename)
        width, height = img.size
        if width > height:
            target = (512, 384)
        elif width < height:
            target = (384, 512)
        else:
            target = (384, 384)
        return img.resize(target, resample=Image.BILINEAR)

    def augment(self, img, pos, dis):
        """Randomly perturb one sample; a no-op when ``augment_images`` is False.

        Applies brightness, contrast, gamma and saturation jitter, an
        occasional Gaussian blur, and a horizontal flip (probability 0.5)
        that is mirrored in the labels and in the displacement map.

        Returns:
            The (possibly modified) ``(img, pos, dis)`` triple.
        """
        if self.augment_images:
            img = ft.adjust_brightness(img, np.clip(np.random.normal(loc=1, scale=0.3), 0.3, 1.8))
            img = ft.adjust_contrast(img, np.clip(np.random.normal(loc=1, scale=0.3), 0.3, 1.8))
            img = ft.adjust_gamma(img, np.clip(np.random.normal(loc=1, scale=0.3), 0.3, 1.8))
            img = ft.adjust_saturation(img, np.clip(np.random.normal(loc=1, scale=0.3), 0.3, 1.8))

            if np.random.binomial(1, 0.2):
                img = img.filter(ImageFilter.GaussianBlur(radius=np.clip(np.random.normal(loc=0.5, scale=0.25), 0, 1)))

            if np.random.binomial(1, 0.5):
                img, pos, dis = ft.hflip(img), pos.flip(1), dis.flip(2)
                # Mirroring the image reverses the sign of x displacements.
                dis[0, :, :] *= -1
            # Vertical flip is intentionally disabled:
            # if np.random.binomial(1, 0.5):
            #     img, pos, dis = ft.vflip(img), pos.flip(0), dis.flip(1)
            #     dis[1,:,:] *= -1
        return img, pos, dis

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.images)

    def __getitem__(self, i):
        """Return ``(image, position, displacement)`` for sample ``i``."""
        corner = np.copy(self.corners[i])
        if self.augment_images:
            displacement = np.random.normal(loc=0, scale=self.displacement_scale, size=(4, 2))
            corner += displacement
        else:
            # The corners are left untouched, so the regression target must be
            # zero instead of random noise unrelated to the image.
            displacement = np.zeros((4, 2))
        image, _ = self.board_extractor(self.images[i], corner)

        # Arrange the four corner offsets on a 2x2 grid, one channel per axis.
        # NOTE(review): the [2, 3, 1, 0] order presumably accounts for the
        # 180-degree rotation applied by the extractor - confirm.
        displacement = tr.stack([tr.Tensor(displacement[[2, 3, 1, 0], 0].reshape(2, 2)),
                                 tr.Tensor(displacement[[2, 3, 1, 0], 1].reshape(2, 2))], dim=0)
        displacement = self.upsampler(displacement.unsqueeze(0)).squeeze()

        position = self.position[i].clone()
        image, position, displacement = self.augment(image, position, displacement)
        return ft.to_tensor(image), position, displacement

In [None]:
# One dataset per supported board size, all read from the same metadata file.
dataset09, dataset13, dataset19 = [
    BoardPositionDataset('metadata.json', board_size)
    for board_size in (9, 13, 19)
]

# Hold out a fixed number of samples per board size (16 / 4 / 8) for testing.
# The subsets keep referencing their parent dataset, so its `augment_images`
# flag applies to the train and test splits alike.
train09, test09 = random_split(dataset09, [len(dataset09) - 16, 16])
train13, test13 = random_split(dataset13, [len(dataset13) - 4, 4])
train19, test19 = random_split(dataset19, [len(dataset19) - 8, 8])

## Plots

In [None]:
dataset = dataset09

# Show nine random samples with the x-displacement map (channel 0) drawn
# semi-transparently on top of the warped board.
plt.figure(figsize=(16, 16))
for slot in range(1, 10):
    plt.subplot(3, 3, slot)
    index = random.randrange(len(dataset))
    img, _, disp = dataset[index]
    unit_square = (0, 1, 0, 1)
    plt.imshow(img.permute(1, 2, 0), extent=unit_square)
    plt.imshow(disp[0, :, :],
               cmap='seismic',
               extent=unit_square,
               norm=Normalize(vmin=-10, vmax=10),
               alpha=0.4)
    plt.title(dataset.filename[index])
plt.show()

In [None]:
dataset = dataset13

# Show nine random samples with the label grid (values 0..2) drawn
# semi-transparently on top of the warped board.
plt.figure(figsize=(16, 16))
for slot in range(1, 10):
    plt.subplot(3, 3, slot)
    index = random.randrange(len(dataset))
    img, bpos, _ = dataset[index]
    unit_square = (0, 1, 0, 1)
    plt.imshow(img.permute(1, 2, 0), extent=unit_square)
    plt.imshow(bpos,
               norm=Normalize(vmin=0, vmax=2),
               cmap='bwr',
               extent=unit_square,
               alpha=0.2)
    plt.title(dataset.filename[index])
plt.show()

# Model

In [None]:
class iblock(nn.Module):
    """Residual block: two 3x3 convolutions added back onto the input.

    The channel count and the spatial size are preserved.
    """

    def __init__(self, dims):
        """Build the residual path for feature maps with ``dims`` channels."""
        super().__init__()
        layers = [
            nn.BatchNorm2d(dims),
            nn.GELU(),
            nn.Conv2d(dims, dims, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(dims),
            nn.Conv2d(dims, dims, kernel_size=3, padding=1),
        ]
        self.conv_path = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional path."""
        residual = self.conv_path(x)
        return x + residual


class GlobalFeatures(nn.Module):
    """Mix a globally pooled descriptor back into every spatial location.

    A 1x1-convolution branch is average-pooled to a single vector, broadcast
    over the feature map, concatenated with the input and projected back to
    ``in_dim`` channels.
    """

    def __init__(self, in_dim, dim):
        """
        Args:
            in_dim: channels of the incoming (and outgoing) feature map.
            dim: channels of the pooled global descriptor.
        """
        super().__init__()
        pooled_branch = [
            nn.Conv2d(in_dim, dim, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(dim),
            nn.Conv2d(dim, dim, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(dim),
            nn.AdaptiveAvgPool2d(1),
        ]
        self.input = nn.Sequential(*pooled_branch)
        fuse = [
            nn.Conv2d(in_dim + dim, in_dim, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(in_dim),
        ]
        self.mixer = nn.Sequential(*fuse)

    def forward(self, x):
        """Return a tensor with the same shape as ``x``."""
        height, width = x.shape[2], x.shape[3]
        pooled = self.input(x)
        # Stretch the 1x1 descriptor back to the spatial size of the input.
        tiled = nn.functional.interpolate(pooled, size=(height, width))
        return self.mixer(tr.cat([x, tiled], dim=1))


def Pooling(in_dim, out_dim):
    """Build a stage that changes the channel count and halves the resolution.

    Returns an ``nn.Sequential`` of BatchNorm, GELU, a 3x3 convolution from
    ``in_dim`` to ``out_dim`` channels and a 2x2 max-pool.
    """
    stages = (
        nn.BatchNorm2d(in_dim),
        nn.GELU(),
        nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
        nn.MaxPool2d(2),
    )
    return nn.Sequential(*stages)


class Interpreter(nn.Module):
    """Fully convolutional network with a classification and a regression head.

    The trunk downsamples by a factor of 16 (one stride-2 convolution followed
    by three pooling stages), so a ``board_size * 16`` pixel input yields one
    output cell per board intersection. The heads produce 3 class logits and
    a 2-channel displacement per cell.
    """

    def __init__(self):
        super().__init__()
        # Stem followed by three (pool, residual, residual) stages.
        trunk = [nn.Conv2d(3, 12, kernel_size=2, stride=2),
                 iblock(12), iblock(12)]
        for c_in, c_out in ((12, 24), (24, 48), (48, 96)):
            trunk += [Pooling(c_in, c_out), iblock(c_out), iblock(c_out)]
        trunk += [GlobalFeatures(96, 48), nn.GELU(), nn.BatchNorm2d(96)]
        self.conv_blocks = nn.Sequential(*trunk)

        self.displacement = nn.Conv2d(96, 2, kernel_size=1)

        classifier = [
            nn.Conv2d(96, 48, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(48),
            nn.Conv2d(48, 24, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(24),
            nn.Conv2d(24, 3, kernel_size=1),
        ]
        self.position = nn.Sequential(*classifier)

    def forward(self, x):
        """Return ``(position_logits, displacement)`` for a batch of images."""
        features = self.conv_blocks(x)
        position = self.position(features)
        displacement = self.displacement(features)
        return position, displacement

    def load(self, fname):
        """Load parameters from ``fname``, keeping storages where deserialised."""
        keep_storage = lambda storage, loc: storage
        state = tr.load(fname, map_location=keep_storage)
        self.load_state_dict(state)

    def save(self, fname):
        """Write the model's state dict to ``fname``."""
        state = self.state_dict()
        tr.save(state, fname)

# Report the total number of parameters of a freshly built model.
print(sum(weights.numel() for weights in Interpreter().parameters()))

# Training

## Auxiliary

In [None]:
# Per-class cell counts over all datasets, used to weight the cross-entropy
# loss inversely to class frequency.
# The counts are read straight from the stored label tensors: indexing the
# dataset instead would warp and augment every photo just to count labels
# (the horizontal flip does not change the counts).
nc = tr.zeros(3, dtype=tr.float64)
for ds in [dataset09, dataset13, dataset19]:
    for pos in ds.position:
        # bincount gives the number of cells per label; keep classes 0..2.
        nc += tr.bincount(pos.reshape(-1), minlength=3)[:3]
tot = nc.sum()
class_weight = (tot / nc).float().cuda()

In [None]:
def acc(pos, pred):
    """Return the balanced accuracy of per-cell class predictions.

    Args:
        pos: tensor of ground-truth class indices.
        pred: tensor of class scores with the classes along dimension 1.
    """
    labels = pos.detach().cpu().numpy().ravel()
    guesses = pred.detach().cpu().numpy().argmax(axis=1).ravel()
    return balanced_accuracy_score(labels, guesses)

def fit(model, data_loader, optimizer, class_weight):
    """Train ``model`` for one pass over ``data_loader``.

    The loss is the class-weighted cross-entropy of the position head plus
    0.1 times the mean squared error of the displacement head.

    Args:
        model: network returning ``(position_logits, displacement)``.
        data_loader: yields ``(image, position, displacement)`` batches.
        optimizer: optimizer stepped once per batch.
        class_weight: per-class weights for the cross-entropy term.

    Returns:
        ``(mean_loss, balanced_accuracy, mean_displacement_loss)`` where the
        means are taken over batches.
    """
    gtrue = []
    preds = []
    total_loss = 0
    total_disp_loss = 0
    n_batches = len(data_loader)
    for img, pos, disp in data_loader:
        img, pos, disp = img.cuda(), pos.cuda(), disp.cuda()
        ppos, pdisp = model(img)

        optimizer.zero_grad()
        pos_loss = nn.functional.cross_entropy(ppos, pos, weight=class_weight)
        disp_loss = nn.functional.mse_loss(pdisp, disp)
        loss = pos_loss + 0.1*disp_loss
        loss.backward()
        optimizer.step()

        gtrue.append(pos)
        # Detach before storing: keeping the raw logits would hold every
        # batch's autograd graph in memory until the end of the epoch.
        preds.append(ppos.detach())

        total_loss += loss.item() / n_batches
        total_disp_loss += disp_loss.item() / n_batches
    pos_acc = acc(tr.cat(gtrue), tr.cat(preds))
    return total_loss, pos_acc, total_disp_loss

def evaluate_model(model, data_loader, class_weight):
    """Evaluate ``model`` on ``data_loader`` without updating it.

    Uses the same loss as ``fit`` (weighted cross-entropy plus 0.1 times the
    displacement MSE). The caller is responsible for switching the model to
    evaluation mode.

    Args:
        model: network returning ``(position_logits, displacement)``.
        data_loader: yields ``(image, position, displacement)`` batches.
        class_weight: per-class weights for the cross-entropy term.

    Returns:
        ``(balanced_accuracy, mean_loss, mean_displacement_loss)`` where the
        means are taken over batches. Note that the order differs from the
        tuple returned by ``fit``.
    """
    gtrue = []
    preds = []
    total_loss = 0
    total_disp_loss = 0
    n_batches = len(data_loader)
    # No gradients are needed for evaluation; skipping them saves memory.
    with tr.no_grad():
        for img, pos, disp in data_loader:
            img, pos, disp = img.cuda(), pos.cuda(), disp.cuda()
            ppos, pdisp = model(img)

            pos_loss = nn.functional.cross_entropy(ppos, pos, weight=class_weight)
            disp_loss = nn.functional.mse_loss(pdisp, disp)
            total_loss += (pos_loss + 0.1*disp_loss).item() / n_batches
            # Average over batches, like total_loss and like fit(); summing
            # made the value scale with the number of batches.
            total_disp_loss += disp_loss.item() / n_batches

            gtrue.append(pos)
            preds.append(ppos)
    test_acc = acc(tr.cat(gtrue), tr.cat(preds))
    return test_acc, total_loss, total_disp_loss

## Training

In [None]:
# Base name of the checkpoint/history files kept under models/.
fname = 'interpreter_35'

# Build the network first, then reset both global generators so that the
# random draws made from here on start from a fixed state.
model = Interpreter().cuda()
np.random.seed(42)
tr.manual_seed(42)

In [None]:
# Adam with a small weight decay; the learning rate is halved whenever the
# monitored value has not decreased for 10 consecutive scheduler steps.
optimizer = tr.optim.Adam(model.parameters(),
                          lr=1e-3,
                          weight_decay=1e-5)
scheduler = tr.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                    mode='min',
                                                    factor=0.5,
                                                    patience=10)

In [None]:
# Training driver: resumes from an existing checkpoint when one is found,
# trains until the mean test loss has not improved for 200 consecutive epochs,
# then reloads the best checkpoint.
hist = hl.History()
canvas = hl.Canvas()

# Resume from the saved parameters and history if a checkpoint exists.
if os.path.isfile("models/" + fname + ".pmt"):
    model.load("models/" + fname + ".pmt")
    hist.load("models/" + fname + ".hist")
    epoch = len(hist['mean_best'].data) + 1
    # NOTE(review): inside the loop the history is saved *before* the current
    # epoch is logged, so a history written together with a checkpoint does
    # not contain that checkpoint's loss; this minimum may therefore be stale
    # - confirm.
    mean_best = min(hist['mean_best'].data)
else:
    epoch = 0
    mean_best = float('inf')

dataset09.augment_images = True
dataset13.augment_images = True
dataset19.augment_images = True

# The test subsets wrap the same dataset objects as the train subsets, so the
# augmentation flag set above also applies to the test samples.
train09_loader = DataLoader(train09, batch_size=8, shuffle=True)
train13_loader = DataLoader(train13, batch_size=8, shuffle=True)
train19_loader = DataLoader(train19, batch_size=8, shuffle=True)
test09_loader = DataLoader(test09, batch_size=8, shuffle=True)
test13_loader = DataLoader(test13, batch_size=8, shuffle=True)
test19_loader = DataLoader(test19, batch_size=8, shuffle=True)
      
# Early stopping: give up after 200 epochs without a new best mean test loss.
epochs_without_improvement = 0  
while epochs_without_improvement < 200:
    # One training pass per board size, in the order 9, 19, 13.
    model.train()
    train09_loss, train09_acc, train09_dloss = fit(model, train09_loader, optimizer, class_weight)
    train19_loss, train19_acc, train19_dloss = fit(model, train19_loader, optimizer, class_weight)
    train13_loss, train13_acc, train13_dloss = fit(model, train13_loader, optimizer, class_weight)

    # evaluate_model returns (accuracy, loss, displacement loss), a different
    # order than fit.
    model.eval()
    test09_acc, test09_loss, test09_dloss = evaluate_model(model, test09_loader, class_weight)
    test19_acc, test19_loss, test19_dloss = evaluate_model(model, test19_loader, class_weight)
    test13_acc, test13_loss, test13_dloss = evaluate_model(model, test13_loader, class_weight)
    
    # Unweighted mean of the three test losses drives the scheduler,
    # checkpointing and early stopping.
    mean_loss = (test09_loss + test13_loss + test19_loss) / 3

    scheduler.step(mean_loss)
    if mean_loss < mean_best:
        # New best: persist the parameters and the history recorded so far
        # (the current epoch is only logged further below).
        model.save("models/" + fname + ".pmt")
        hist.save("models/" + fname + ".hist")
        mean_best = mean_loss
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1

    hist.log(epoch, mean_loss=mean_loss,
                    mean_best=mean_best,
                    train09_acc=train09_acc,
                    test09_acc =test09_acc,
                    train13_acc=train13_acc,
                    test13_acc =test13_acc,
                    train19_acc=train19_acc,
                    test19_acc =test19_acc,
                    test09_dloss=test09_dloss,
                    test13_dloss=test13_dloss,
                    test19_dloss=test19_dloss
             )

    # Redraw the live plots on every epoch except epoch 0.
    if epoch > 0:
        with canvas:
            canvas.draw_plot([hist["mean_loss"],
                              hist["mean_best"]])
            canvas.draw_plot([hist["train09_acc"],
                              hist["train13_acc"],
                              hist["train19_acc"]])
            canvas.draw_plot([hist["test09_acc"],
                              hist["test13_acc"],
                              hist["test19_acc"]])
            canvas.draw_plot([hist["test09_dloss"],
                              hist["test13_dloss"],
                              hist["test19_dloss"]])
    epoch += 1

# Store the complete history, then restore the best checkpoint.
hist.summary()
hist.save("models/" + fname + ".hist")
model.load("models/" + fname + ".pmt")

12 - 2 - 495233 = 0.08741

10 - 2 - 344455 = 0.1020

10 - 3 - 498355 = 0.0987

12 - 1 - 273833 = 0.1222

12 - 3 - 716633 = 0.0901

12 - 2 - 525473 - MP = 0.07692

# Test

In [None]:
# Plot up to 16 boards on which the saved model misclassifies at least one
# intersection, with the predicted classes drawn on top of the board image.
model.load("models/" + fname + ".pmt")
model.eval()

ds = dataset19
# ds.augment_images = False

max_plots = 16
nplots = 1
plt.figure(figsize=(20, 20))
# Inference only: no autograd graph is needed.
with tr.no_grad():
    for i, sample in enumerate(ds):
        img, pos = sample[0].unsqueeze(0).cuda(), sample[1].unsqueeze(0)
        pred, _ = model(img)
        pred = pred.cpu()

        if acc(pos, pred) < 1.0:
            plt.subplot(4, 4, nplots)
            nplots += 1
            plt.imshow(img.squeeze().cpu().permute(1, 2, 0),
                       extent=(0,1,0,1))
            # Take the argmax of the raw logits: rounding them first can
            # create ties and show a class other than the one that was scored.
            plt.imshow(pred.squeeze(0).argmax(dim=0),
                       norm=Normalize(vmin=0, vmax=2),
                       cmap='bwr',
                       extent=(0,1,0,1),
                       alpha=0.25)
            plt.title(ds.filename[i])
            # Stop once the 4x4 grid is full instead of evaluating the rest
            # of the dataset for nothing.
            if nplots > max_plots:
                break

plt.show()

# Tilt distribution

In [None]:
# Collect, for every annotated photo, the two perspective terms returned by
# the board extractor (stored in `tilts`) together with its file name.
board_extractor = {9 : BoardExtractor(9),
                   13: BoardExtractor(13),
                   19: BoardExtractor(19)}

filename = []
tilts = []
# The file only needs to stay open while the JSON is parsed.
with open('metadata.json', 'r') as f:
    metadata = json.load(f)

for entry in tqdm(metadata):
    img = default_loader("images/" + entry['filename'])
    # Same resizing rule as BoardPositionDataset.load_image.
    width, height = img.size
    if width > height:
        img = img.resize((512,384), resample=Image.BILINEAR)
    elif width < height:
        img = img.resize((384,512), resample=Image.BILINEAR)
    else:
        img = img.resize((384,384), resample=Image.BILINEAR)

    # Corners are percentages of the image size. Force a float array so that
    # integer JSON values are not truncated by the pixel conversion.
    corner = np.array(entry['corners'], dtype=float)
    corner[:,0] = corner[:,0] / 100 * img.size[0]
    corner[:,1] = corner[:,1] / 100 * img.size[1]

    filename.append(entry['filename'])
    _, tilt = board_extractor[entry['size']](img, corner)
    tilts.append(tilt)

In [None]:
# Largest (signed) perspective term over all images.
np.max(np.concatenate(tilts, axis=0))