# Setup environment

In [None]:
!pip3 install hiddenlayer > /dev/null
!pip3 install ipdb > /dev/null
import json
import os
import numpy as np
import random
from datetime import datetime as dt

import torch as tr
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

from torchvision.datasets.folder import default_loader
from torchvision.transforms import functional as ft
from torchvision import transforms
from PIL import Image, ImageDraw, ImageFilter

from sklearn.model_selection import train_test_split
from sklearn.metrics import balanced_accuracy_score

import matplotlib

import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import hiddenlayer as hl

from google.colab import drive

from tqdm.notebook import tqdm

# cuDNN autotuning is switched on and deterministic mode off, so GPU results
# may differ between runs even though the generators below are seeded.
tr.backends.cudnn.deterministic = False
tr.backends.cudnn.benchmark = True
# Seed the torch and NumPy generators.
# NOTE(review): Python's own `random` module (used later to pick plot samples)
# is not seeded here.
tr.manual_seed(42)
np.random.seed(42)

In [None]:
# Mount Google Drive at ./drive (relative to the current working directory).
drive.mount('./drive')

In [None]:
# Switch to the training folder; the relative paths used by later cells
# (images/, models/, *.json) are resolved from here.
os.chdir("/content/drive/My Drive/Workspace/pic2sgf/training")

# Board extractor

In [None]:
class BoardExtractor():
    """Warp the quadrilateral spanned by four corner points onto a square image.

    The output is a square of side ``board_size * 24`` pixels. The four
    reference points the corners are mapped to sit 8 pixels inside each corner
    of that square.
    """

    def __init__(self, board_size):
        """Precompute the part of the transform that depends only on the output size.

        Args:
            board_size: multiplier for the output side length (the rest of the
                file passes 9, 13 or 19).
        """
        self.size = board_size * 24
        near, far = 8, self.size - 8
        # Reference points in homogeneous coordinates, one point per column.
        targets = np.array([[near, near, 1],
                            [far, near, 1],
                            [far, far, 1],
                            [near, far, 1]]).T
        first_three, fourth = targets[:, 0:3], targets[:, 3]
        # Scale the first three columns by the coefficients that express the
        # fourth point in terms of them, then invert once for reuse.
        coefficients = np.linalg.solve(first_three, fourth)
        self.T2 = np.linalg.inv(first_three * coefficients)

    def __call__(self, img, vertexs):
        """Extract the region delimited by ``vertexs`` from ``img``.

        Args:
            img: image object providing PIL-style ``transform`` and
                ``transpose`` methods.
            vertexs: array with one (x, y) corner per row, four rows.

        Returns:
            A tuple ``(board, row)``: the warped image rotated by 180 degrees,
            and the first two entries of the last row of the transform matrix
            after normalising it by its bottom-right element.
        """
        ones_row = np.array([[1.0, 1.0, 1.0, 1.0]])
        A = np.concatenate([vertexs.T, ones_row], axis=0)
        source_three, source_fourth = A[:, 0:3], A[:, 3]
        # Same column-scaling construction as in __init__, for the input corners.
        T1 = source_three * np.linalg.solve(source_three, source_fourth)
        T = np.matmul(T1, self.T2)
        T /= T[2, 2]
        warped = img.transform((self.size, self.size),
                               method=Image.PERSPECTIVE,
                               data=T.reshape(-1),
                               resample=Image.BILINEAR)
        return warped.transpose(Image.ROTATE_180), T[2, 0:2]

# Dataset

In [None]:
class BoardPositionDataset(Dataset):
    """Dataset of rectified board images with per-intersection targets.

    Each item is a tuple ``(image, position, displacement, wrong)``:

    * ``image``: the board warped to a square by ``BoardExtractor`` and
      converted with ``to_tensor``;
    * ``position``: a ``size x size`` LongTensor of labels shifted by +1
      (values 0..2), or ``None`` when the entry has no ``positions``;
    * ``displacement``: a ``2 x size x size`` tensor, the corner jitter
      bilinearly interpolated over the board (channel 0 = x, channel 1 = y);
    * ``wrong``: a 1-element tensor, 1 when one corner was deliberately moved
      far away from its true location, else 0.

    Args:
        metadata_files: JSON files, each a list of entries with keys
            ``filename``, ``corners``, ``size`` and optionally ``positions``.
        size: board size to keep; entries of any other size are skipped.
        augment: enable corner jitter and photometric / flip augmentation.
        create_wrong: with probability 0.2, corrupt one corner and flag the
            sample as wrong.
        only_labelled: skip entries without ``positions``.
    """

    def __init__(self, metadata_files, size, augment=True, create_wrong=True, only_labelled=True):
        super(BoardPositionDataset, self).__init__()
        self.board_extractor = BoardExtractor(size)
        self.augment_images = augment
        self.create_wrong = create_wrong
        self.only_labelled = only_labelled
        # Standard deviation (in pixels) of the random corner jitter.
        self.displacement_scale = 4
        self.images = []
        # Maps image index -> label tensor; only labelled entries are present.
        self.position = {}
        self.corners = []
        self.orig_corners = []
        self.filename = []
        self.size = []
        # Interpolates the 2x2 per-corner jitter to one value per intersection.
        self.upsampler = nn.Upsample(size, mode='bilinear', align_corners=True)

        for file in metadata_files:
            print(f'Loading {file}...')
            with open(file, 'r') as f:
                metadata = json.load(f)
            for entry in tqdm(metadata):
                if int(entry['size']) != size:
                    continue
                if self.only_labelled and ('positions' not in entry):
                    continue
                self.size.append(entry['size'])

                img = self.load_image("images/" + entry['filename'])
                self.orig_corners.append(entry['corners'])
                # Corners are divided by 100 and scaled by the image size,
                # i.e. they are read as percentages of width / height.
                corner = np.array(entry['corners'])/100
                corner[:, 0] *= img.size[0]
                corner[:, 1] *= img.size[1]

                self.corners.append(self.order_vertexs(corner, img.size))
                self.filename.append(entry['filename'])
                self.images.append(img)
                if 'positions' in entry:
                    # Labels are stored shifted by +1 so they can be used as
                    # class indices 0..2.
                    self.position[len(self.images) - 1] = tr.LongTensor(entry['positions']).reshape(size, size).permute(1,0).flip(1) + 1

    def load_image(self, filename):
        """Load an image and bring it to a 512x384 landscape layout.

        Images wider than tall are resized directly; all others are resized to
        384x512 and then rotated by 90 degrees.

        NOTE(review): the caller scales the corner percentages by the size of
        the rotated image without rotating the corners themselves -- confirm
        the metadata coordinates already refer to the rotated image.
        """
        img = default_loader(filename)
        if img.size[0] > img.size[1]:
            img = img.resize((512, 384), resample=Image.BILINEAR)
        else:
            img = img.resize((384, 512), resample=Image.BILINEAR).transpose(Image.ROTATE_90)
        return img

    def augment(self, img, pos, dis):
        """Apply random photometric changes, blur and horizontal flip.

        Does nothing unless ``self.augment_images`` is set.

        Args:
            img: rectified board image.
            pos: label tensor or ``None``.
            dis: ``2 x H x W`` displacement tensor.

        Returns:
            The (possibly transformed) ``(img, pos, dis)`` triple.
        """
        if self.augment_images:
            # Each factor is drawn from N(1, 0.3) and clipped to [0.3, 1.8].
            img = ft.adjust_brightness(img, np.clip(np.random.normal(loc=1, scale=0.3), 0.3, 1.8))
            img = ft.adjust_contrast(img, np.clip(np.random.normal(loc=1, scale=0.3), 0.3, 1.8))
            img = ft.adjust_gamma(img, np.clip(np.random.normal(loc=1, scale=0.3), 0.3, 1.8))
            img = ft.adjust_saturation(img, np.clip(np.random.normal(loc=1, scale=0.3), 0.3, 1.8))

            # Gaussian blur on 20% of the samples.
            if np.random.binomial(1, 0.2):
                img = img.filter(ImageFilter.GaussianBlur(radius = np.clip(np.random.normal(loc=0.5, scale=0.25), 0, 1)))

            # Horizontal flip on half of the samples: mirror the image, the
            # labels and the displacement field, and negate the x component.
            if np.random.binomial(1, 0.5):
                img, dis = ft.hflip(img), dis.flip(2)
                if pos is not None:
                    pos = pos.flip(1)
                dis[0,:,:] *= -1
        return img, pos, dis

    def order_vertexs(self, v, img_size):
        """Return the four corners in a canonical order.

        Corners are assigned greedily to the image corners (0,0), (w,0), (w,h)
        and (0,h), in that order, by nearest distance. A pass over the result
        then swaps neighbouring entries wherever the turning direction flips
        sign relative to the last accepted one.

        Args:
            v: float array with one (x, y) corner per row, four rows.
            img_size: ``(width, height)`` of the image.

        Returns:
            ``v`` with its rows reordered.
        """
        w, h = img_size
        vc = v.copy()
        idxs = np.ones(4).astype(int)
        idxs[0] = np.linalg.norm(vc, ord=2, axis=1).argmin()
        # Mark an assigned corner as infinitely far away so it is not reused.
        vc[idxs[0]] = np.array([float('inf'), float('inf')])

        idxs[1] = np.linalg.norm(vc - np.array([w,0]), ord=2, axis=1).argmin()
        vc[idxs[1]] = np.array([float('inf'), float('inf')])

        idxs[2] = np.linalg.norm(vc - np.array([w,h]), ord=2, axis=1).argmin()
        vc[idxs[2]] = np.array([float('inf'), float('inf')])

        idxs[3] = np.linalg.norm(vc - np.array([0,h]), ord=2, axis=1).argmin()
        vc[idxs[3]] = np.array([float('inf'), float('inf')])

        last_prod = 0
        for i in range(len(v)):
            nxt = (i + 1) % 4
            prev = v[idxs[(i-1)%4]] - v[idxs[i]]
            post = v[idxs[nxt]] - v[idxs[i]]
            # z component of post x prev, written out explicitly because
            # np.cross on 2-D vectors is deprecated in NumPy 2.x.
            cross_prod = post[0] * prev[1] - post[1] * prev[0]
            if cross_prod * last_prod < 0:
                # The original also did `i += 1` here, which has no effect in
                # a Python for loop; the next index is still visited.
                idxs[i], idxs[nxt] = idxs[nxt], idxs[i]
            else:
                last_prod = cross_prod
        return v[idxs]

    @staticmethod
    def _displace_random_corner(corner):
        """Move one random corner by 12..99 px per axis, each with a random sign.

        Modifies ``corner`` in place.
        """
        idx = np.random.randint(0, 4)
        magnitude = np.random.randint(12, 100, size=2)
        # One independent +/-1 per axis. The previous
        # `binomial([0, 1], 0.5)` used n=[0, 1], so the x sign was always -1.
        sign = 2 * np.random.binomial(1, 0.5, size=2) - 1
        corner[idx] += magnitude * sign

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        corner = np.copy(self.corners[i])
        wrong = 0
        if self.augment_images:
            displacement = np.random.normal(loc=0, scale=self.displacement_scale, size=(4, 2))
            corner += displacement
        else:
            # No jitter is applied, so the regression target must be zero
            # rather than unrelated random noise.
            displacement = np.zeros((4, 2))
        if self.create_wrong and np.random.binomial(1, 0.2):
            wrong = 1
            self._displace_random_corner(corner)
        image, _ = self.board_extractor(self.images[i], corner)

        # Arrange the per-corner jitter as a 2x2 grid per axis and interpolate
        # it to one value per intersection.
        displacement = tr.stack([tr.Tensor(displacement[[2,3,1,0], 0].reshape(2, 2)),
                                tr.Tensor(displacement[[2,3,1,0], 1].reshape(2, 2))], dim=0)
        displacement = self.upsampler(displacement.unsqueeze(0)).squeeze()

        position = self.position[i].clone() if i in self.position else None
        if self.augment_images:
            image, position, displacement = self.augment(image, position, displacement)
        return (
            ft.to_tensor(image),
            position,
            displacement,
            tr.Tensor([wrong])
        )

In [None]:
# One dataset per board size, each built from the same two metadata files;
# the constructor keeps only the entries whose size matches.
dataset09 = BoardPositionDataset(['metadata.json', 'self_labels.json'], 9)
dataset13 = BoardPositionDataset(['metadata.json', 'self_labels.json'], 13)
dataset19 = BoardPositionDataset(['metadata.json', 'self_labels.json'], 19)

## Plots

In [None]:
# Visual check: nine random samples with channel 0 of the displacement target
# drawn on top of the rectified board image.
dataset = dataset09

plt.figure(figsize=(16, 16))
for i in range(9):
    plt.subplot(3, 3, i+1)
    index = random.randrange(len(dataset))
    img, _, disp, _ = dataset[index]
    # to_tensor yields channels-first; imshow expects channels-last.
    plt.imshow(img.permute(1, 2, 0),
               extent=(0,1,0,1))
    # Same extent as the image so the overlay lines up with it.
    plt.imshow(disp[0,:,:],
               cmap='seismic',
               extent=(0,1,0,1),
               norm = Normalize(vmin = -10, vmax = 10),
               alpha=0.4)
    plt.title(dataset.filename[index])
plt.show()

In [None]:
# Visual check: nine random samples with the label grid (values 0..2) drawn on
# top of the rectified board image.
dataset = dataset09

plt.figure(figsize=(16, 16))
for i in range(9):
    plt.subplot(3, 3, i+1)
    index = random.randrange(len(dataset))
    img, bpos, _, _ = dataset[index]
    # to_tensor yields channels-first; imshow expects channels-last.
    plt.imshow(img.permute(1, 2, 0),
               extent=(0, 1, 0, 1))
    # Same extent as the image so the overlay lines up with it.
    plt.imshow(bpos,
               norm = Normalize(vmin = 0, vmax = 2),
               cmap='bwr',
               extent=(0,1,0,1),
               alpha=0.2)
    plt.title(dataset.filename[index])
plt.show()

# Model

In [None]:
class iblock(nn.Module):
    """Residual block: the input plus the output of a small convolutional path.

    The path is BatchNorm -> GELU -> Dropout2d(0.1) -> 3x3 conv -> GELU ->
    BatchNorm -> 3x3 conv. Both convolutions use padding 1 and keep the number
    of channels, so the output has the same shape as the input.

    Args:
        dims: number of input (and output) channels.
    """

    def __init__(self, dims):
        super().__init__()
        same_size_conv = dict(kernel_size=3, padding=1)
        stages = [
            nn.BatchNorm2d(dims),
            nn.GELU(),
            nn.Dropout2d(0.1),
            nn.Conv2d(dims, dims, **same_size_conv),
            nn.GELU(),
            nn.BatchNorm2d(dims),
            nn.Conv2d(dims, dims, **same_size_conv),
        ]
        self.conv_path = nn.Sequential(*stages)

    def forward(self, x):
        residual = self.conv_path(x)
        return x + residual


def Pooling(in_dim, out_dim):
    """Build a down-sampling stage that also changes the channel count.

    The stage is BatchNorm -> GELU -> 3x3 conv (padding 1) -> 2x2 max-pool, so
    the spatial size is halved and the channels go from ``in_dim`` to
    ``out_dim``.
    """
    stages = (
        nn.BatchNorm2d(in_dim),
        nn.GELU(),
        nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1),
        nn.MaxPool2d(2),
    )
    return nn.Sequential(*stages)


class Interpreter(nn.Module):
    """Convolutional network with three heads on a shared trunk.

    The trunk starts with a stride-3 3x3 convolution and contains three
    ``Pooling`` stages, so the spatial size shrinks by a factor of 24 overall.
    Between the stages sit groups of four ``iblock`` residual blocks at 12, 24,
    48 and 96 channels.

    ``forward`` returns ``(position, displacement, wrong)``:

    * ``position``: 3-channel map from the position head;
    * ``displacement``: 2-channel map from the displacement head;
    * ``wrong``: one value per sample in (0, 1), obtained by a global max over
      a 1-channel map followed by a sigmoid.
    """

    def __init__(self):
        super().__init__()

        # Stem, then four groups of residual blocks separated by Pooling stages.
        trunk = [
            nn.Conv2d(3, 32, kernel_size=3, stride=3),
            nn.GELU(), nn.BatchNorm2d(32),
            nn.Conv2d(32, 12, kernel_size=3, padding=1),
        ]
        width = 12
        for group in range(4):
            if group > 0:
                trunk.append(Pooling(width, 2 * width))
                width *= 2
            for _ in range(4):
                trunk.append(iblock(width))
        trunk.extend([nn.GELU(), nn.BatchNorm2d(96)])
        self.conv_blocks = nn.Sequential(*trunk)

        def head_layers(out_channels):
            # Shared 1x1-conv head layout: 96 -> 24 -> out_channels.
            return [
                nn.Conv2d(96, 24, kernel_size=1),
                nn.GELU(), nn.BatchNorm2d(24),
                nn.Conv2d(24, out_channels, kernel_size=1),
            ]

        self.displacement = nn.Sequential(*head_layers(2))
        self.position = nn.Sequential(*head_layers(3))
        self.wrong = nn.Sequential(
            *head_layers(1),
            nn.AdaptiveMaxPool2d(1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        features = self.conv_blocks(x)
        position = self.position(features)
        displacement = self.displacement(features)
        # Collapse the pooled 1x1 map to shape (batch, 1).
        wrong = self.wrong(features).view(-1, 1)
        return position, displacement, wrong

    def load(self, fname):
        """Load weights from ``fname``, keeping the tensors on the storage they are deserialised to."""
        state = tr.load(fname, map_location=lambda storage, loc: storage)
        self.load_state_dict(state)

    def save(self, fname):
        """Write the model's state dict to ``fname``."""
        tr.save(self.state_dict(), fname)

# Print the total number of parameters of a freshly constructed Interpreter.
print(sum(p.numel() for p in Interpreter().parameters()))

# Training

In [None]:
# Class weights for the position loss, inversely proportional to how often
# each label value (0, 1, 2) occurs across the three datasets.
#
# The labels are counted directly from the stored label tensors. Iterating the
# datasets themselves would run the perspective warp and the augmentation for
# every image only to count labels, and the counts are the same either way
# (a flip does not change how many of each label there are).
nc = [0.0, 0.0, 0.0]
for ds in [dataset09, dataset13, dataset19]:
    for pos in ds.position.values():
        for i in range(3):
            nc[i] += (pos == i).sum().item()
# Additive smoothing applied to classes 0 and 2 only.
smooth = 5000
tot = sum(nc) + 2 * smooth
nc[0] += smooth
nc[2] += smooth
class_weight = tr.Tensor([tot / nc[i] for i in range(3)]).cuda()
# Normalise so the three weights sum to 1.
class_weight /= class_weight.sum()

In [None]:
def acc(pos, pred):
    """Balanced accuracy of per-intersection class predictions.

    Args:
        pos: tensor of true class indices.
        pred: tensor of class scores with the classes along axis 1.

    Returns:
        ``balanced_accuracy_score`` over all flattened entries.
    """
    truth = pos.cpu().detach().numpy().reshape(-1)
    scores = pred.cpu().detach().numpy()
    guess = scores.argmax(axis=1).reshape(-1)
    return balanced_accuracy_score(truth, guess)

def fit(model, data_loader, optimizer, class_weight, backward=True):
    """Run one pass over ``data_loader``, optionally updating ``model``.

    Args:
        model: network returning ``(position, displacement, wrong)``.
        data_loader: yields ``(img, pos, disp, wr)`` batches.
        optimizer: optimiser stepped after each batch; only used when
            ``backward`` is true.
        class_weight: per-class weights for the position cross-entropy.
        backward: train (true) or only evaluate (false).

    Returns:
        ``(total_loss, pos_acc, total_disp_loss, wrong_acc)`` where
        ``total_loss`` is the batch-averaged position + wrong loss,
        ``total_disp_loss`` the batch-averaged displacement loss, and the two
        accuracies are balanced accuracies over the whole pass.
    """
    gtrue = []
    preds = []
    wtrue = []
    wpred = []
    total_loss = 0
    total_disp_loss = 0
    if backward:
        model.train()
    else:
        model.eval()
    # Follow the model's device instead of assuming CUDA.
    device = next(model.parameters()).device
    n_batches = len(data_loader)
    # No autograd graph is needed when only evaluating.
    with tr.set_grad_enabled(backward):
        for img, pos, disp, wr in data_loader:
            img, pos, disp, wr = img.to(device), pos.to(device), disp.to(device), wr.to(device)
            ppos, pdisp, pwr = model(img)

            if backward:
                optimizer.zero_grad()
            # Position and displacement targets are only meaningful for
            # samples whose corners were not deliberately corrupted.
            correct = (wr < 0.5).squeeze(1)
            if correct.any():
                pos_loss = nn.functional.cross_entropy(ppos[correct], pos[correct], weight=class_weight)
                disp_loss = nn.functional.mse_loss(pdisp[correct], disp[correct])
            else:
                # Both losses are NaN on an empty selection, and that NaN
                # would be back-propagated; contribute zero instead.
                pos_loss = disp_loss = pwr.new_zeros(())
            wrong_loss = nn.functional.binary_cross_entropy(pwr, wr)
            loss = pos_loss + wrong_loss + 0.1*disp_loss
            if backward:
                loss.backward()
                optimizer.step()

            gtrue.append(pos[correct].detach().cpu())
            preds.append(ppos[correct].detach().cpu())
            wtrue.append(wr.detach().cpu())
            wpred.append(pwr.detach().cpu())

            total_loss += (pos_loss + wrong_loss).item() / n_batches
            total_disp_loss += disp_loss.item() / n_batches
    pos_acc = acc(tr.cat(gtrue), tr.cat(preds))
    wrong_acc = balanced_accuracy_score(tr.cat(wtrue) > 0.5, tr.cat(wpred) > 0.5)
    return total_loss, pos_acc, total_disp_loss, wrong_acc

def evaluate_model(model, data_loader, class_weight):
    """Evaluate ``model`` on ``data_loader`` without updating it.

    Returns the same tuple as ``fit``. No optimiser is passed: ``fit`` only
    touches it when ``backward`` is true, so this no longer depends on a
    global ``optimizer`` having been defined.
    """
    return fit(model, data_loader, None, class_weight, backward=False)

In [None]:
# Build the network and move it to the GPU.
model = Interpreter().cuda()
# Re-seed the generators. NOTE(review): this happens after the model was
# constructed, so the initial weights depend on the generator state before
# this cell ran.
tr.manual_seed(42)
np.random.seed(42)

# Base name of the checkpoint (.pmt) and history (.hist) files under models/.
fname ='interpreter_32'
# Alternative Adam configuration, kept commented out:
# optimizer = tr.optim.Adam(model.parameters(), lr=0.002, weight_decay=1e-5)
# scheduler = tr.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
# SGD with momentum; the learning rate is halved after 10 scheduler steps
# without improvement of the monitored value.
optimizer = tr.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = tr.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)

In [None]:
# Training driver: trains on the three board sizes in turn until the mean
# training loss has not improved for 200 consecutive epochs.
hist = hl.History()
canvas = hl.Canvas()

# Resume from an existing checkpoint and its history when present.
if os.path.isfile("models/" + fname + ".pmt"):
    model.load("models/" + fname + ".pmt")
    hist.load("models/" + fname + ".hist")
    epoch = len(hist['best'].data) + 1
    mean_best = min(hist['best'].data)
else:
    epoch = 0
    mean_best = float('inf')

dataset09.augment_images = True
dataset13.augment_images = True
dataset19.augment_images = True

# One loader per board size; a batch never mixes sizes.
train09_loader = DataLoader(dataset09, batch_size=12, shuffle=True)
train13_loader = DataLoader(dataset13, batch_size=12, shuffle=True)
train19_loader = DataLoader(dataset19, batch_size=12, shuffle=True)
      
epochs_without_improvement = 0  
while epochs_without_improvement < 200:
    # One epoch = one full pass over each size, in the order 19, 9, 13.
    train19_loss, train19_acc, train19_dloss, train19_wacc = fit(model, train19_loader, optimizer, class_weight)
    train09_loss, train09_acc, train09_dloss, train09_wacc = fit(model, train09_loader, optimizer, class_weight)
    train13_loss, train13_acc, train13_dloss, train13_wacc = fit(model, train13_loader, optimizer, class_weight)
    
    mean_train_loss = (train09_loss + train13_loss + train19_loss) / 3

    # Scheduler, checkpointing and the stopping counter are all driven by the
    # training loss; no held-out data is used in this loop.
    scheduler.step(mean_train_loss)
    if mean_train_loss < mean_best:
        model.save("models/" + fname + ".pmt")
        # NOTE(review): the history is saved before this epoch is logged
        # below, so the file written here does not contain the current epoch.
        hist.save("models/" + fname + ".hist")
        mean_best = mean_train_loss
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1

    # Epoch 0 is neither logged nor plotted.
    if epoch > 0:
        hist.log(epoch, train_loss=round(mean_train_loss, 5),
                        best=round(mean_best, 5),
                
                        tr09_pos=round(train09_acc, 5),
                        tr13_pos=round(train13_acc, 5),
                        tr19_pos=round(train19_acc, 5),
                
                        tr09_disp=round(train09_dloss, 5),
                        tr13_disp=round(train13_dloss, 5),
                        tr19_disp=round(train19_dloss, 5),
                
                        tr09_wacc=round(train09_wacc, 5),
                        tr13_wacc=round(train13_wacc, 5),
                        tr19_wacc=round(train19_wacc, 5)
                )
        with canvas:
            canvas.draw_plot([hist["train_loss"],
                              hist["best"]])
            canvas.draw_plot([hist["tr09_pos"],
                              hist["tr13_pos"],
                              hist["tr19_pos"]])
            canvas.draw_plot([hist["tr09_wacc"],
                              hist["tr13_wacc"],
                              hist["tr19_wacc"]])
            canvas.draw_plot([hist["tr09_disp"],
                              hist["tr13_disp"],
                              hist["tr19_disp"]])
    epoch += 1

# After stopping: print the summary, persist the full history and reload the
# best checkpoint saved during training.
hist.summary()
hist.save("models/" + fname + ".hist")
model.load("models/" + fname + ".pmt")

2x20: 0.01679 - 5.36

3x16: 0.01124 - 4.94

3x12: 0.01930 - 4.32

4x12: 0.00954 - 4.45

# Test

In [None]:
# 19x19 entries from the two listed files, labelled or not, with augmentation
# and corner corruption switched off.
ds = BoardPositionDataset(['reddit.json', 'facebook_fixed.json'], 19, augment=False, create_wrong=False, only_labelled=False)

# Fresh model (not moved to the GPU) loaded from the saved checkpoint and put
# in evaluation mode.
model = Interpreter()
fname ='interpreter_32'
model.load("models/" + fname + ".pmt")
model = model.eval()

In [None]:
# Plot one page of 36 samples (page number `chunk`) with the predicted class
# of every intersection drawn over the rectified image.
chunk = 2
nplots = 1
plt.figure(figsize=(40, 40))
start = dt.now()
# Inference only: no autograd graph is needed.
with tr.no_grad():
    for i in range(36*chunk, min(36*(chunk+1), len(ds))):
        img, _, _, _ = ds[i]
        pred, _, _ = model(img.unsqueeze(0))
        pred = pred.detach().cpu()

        plt.subplot(6, 6, i+1 - 36*chunk)
        # Take the argmax of the raw scores. Rounding them first (as was done
        # before) creates ties between close scores, and argmax then resolves
        # those ties to class 0.
        pred = pred.squeeze()
        plt.imshow(
            img.squeeze().cpu().permute(1, 2, 0),
            extent=(1,19,1,19))
        plt.imshow(
            pred.argmax(axis=0),
            norm = Normalize(vmin = 0, vmax = 2),
            cmap='bwr',
            extent=(1,19,1,19),
            alpha=0.25)
        plt.title(ds.filename[i])
print(f'Elapsed: {dt.now() - start}')
plt.show()

# Relabelling

In [None]:
# Regenerate self_labels.json: predict the labels of every image from the two
# listed files with a saved model, skipping the files named in to_fix2.txt.
with open('to_fix2.txt', 'r') as f:
    # Strip only the newline; `line[:-1]` would cut a real character from a
    # final line that has no trailing newline.
    ignore = [line.rstrip('\n') + '.jpg' for line in f]
# Set for constant-time membership tests inside the loop.
ignored = set(ignore)

model = Interpreter()
fname ='interpreter_31'
model.load("models/" + fname + ".pmt")
model.eval()

res = []
# Inference only: no autograd graph is needed.
with tr.no_grad():
    for s in [9, 13, 19]:
        ds = BoardPositionDataset(['reddit.json', 'facebook_fixed.json'], s, augment=False, create_wrong=False, only_labelled=False)
        for i in range(len(ds)):
            if ds.filename[i] in ignored:
                continue
            img, _, _, _ = ds[i]
            pred, _, _ = model(img.unsqueeze(0))
            pred = pred.detach().cpu()

            # Argmax over the raw scores (rounding them first creates ties
            # that argmax resolves to class 0), then undo the
            # permute/flip the dataset applies when loading labels.
            pred = pred.squeeze().argmax(axis=0).transpose(0,1).flipud()
            res.append(
                {
                    'filename' : ds.filename[i],
                    # Undo the +1 shift the dataset applies to stored labels.
                    'positions' : [el-1 for row in pred.tolist() for el in row],
                    'corners' : ds.orig_corners[i],
                    'size' : ds.size[i]
                }
            )
# Context manager so the file is flushed and closed.
with open('self_labels.json', 'w') as out:
    json.dump(res, out, indent=2, separators=(',', ':'))

In [None]:
# Visual check of the generated labels: 25 random 19x19 samples from
# self_labels.json with the label grid drawn over the rectified image.
ds = BoardPositionDataset(['self_labels.json'], 19, augment=False, create_wrong=False, only_labelled=False)

plt.figure(figsize=(32, 32))
for i in range(25):
    plt.subplot(5, 5, i+1)
    index = random.randrange(len(ds))
    img, bpos, _, _ = ds[index]
    # to_tensor yields channels-first; imshow expects channels-last.
    plt.imshow(img.permute(1, 2, 0),
               extent=(0, 1, 0, 1))
    # Same extent as the image so the overlay lines up with it.
    plt.imshow(bpos,
               norm = Normalize(vmin = 0, vmax = 2),
               cmap='bwr',
               extent=(0,1,0,1),
               alpha=0.2)
    plt.title(ds.filename[index])
plt.show()

# Tilt distribution

In [None]:
# Collect, for every entry of metadata.json, the second value returned by
# BoardExtractor (the first two entries of the transform's last row).
board_extractor = {9 : BoardExtractor(9),
                   13: BoardExtractor(13),
                   19: BoardExtractor(19)}

filename = []
tilts = []
# The file only needs to be open while it is parsed.
with open('metadata.json', 'r') as f:
    metadata = json.load(f)
for entry in tqdm(metadata):
    img = default_loader("images/" + entry['filename'])
    # Landscape, portrait and square images each get a fixed size; unlike
    # BoardPositionDataset.load_image, portrait images are not rotated.
    if img.size[0] > img.size[1]:
        img = img.resize((512,384), resample=Image.BILINEAR)
    elif img.size[0] < img.size[1]:
        img = img.resize((384,512), resample=Image.BILINEAR)
    else:
        img = img.resize((384,384), resample=Image.BILINEAR)

    # Force a float array: with integer corner values the in-place scaling
    # below would otherwise be truncated to integers.
    corner = np.array(entry['corners'], dtype=float)
    corner[:,0] = corner[:,0] / 100 * img.size[0]
    corner[:,1] = corner[:,1] / 100 * img.size[1]

    filename.append(entry['filename'])
    # Cast the size like BoardPositionDataset does, in case it is stored as a
    # string in the metadata.
    _, tilt = board_extractor[int(entry['size'])](img, corner)
    tilts.append(tilt)

In [None]:
# Largest single value among all collected tilt terms (shown as the cell output).
np.concatenate(tilts, axis=0).max()