# Setup environment

In [None]:
!pip3 install hiddenlayer > /dev/null
!pip3 install ipdb > /dev/null
import json
import os
import numpy as np
import random
import time

import torch as tr
import torch.nn as nn
from torch.nn import BCELoss
from scipy import stats
from torch.utils.data import Dataset, DataLoader, random_split

from torchvision.datasets.folder import default_loader
from torchvision.transforms import functional as ft
from torchvision import transforms
from PIL import Image, ImageDraw, ImageFilter

import matplotlib
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
import hiddenlayer as hl

from google.colab import drive

from tqdm.notebook import tqdm

from scipy import ndimage

# Favour speed over run-to-run reproducibility: non-deterministic cuDNN
# kernels are allowed and cuDNN may benchmark/auto-select convolution algorithms.
tr.backends.cudnn.deterministic = False
tr.backends.cudnn.benchmark = True

In [None]:
# Mount Google Drive under ./drive (relative to the current working directory).
drive.mount('./drive')

In [None]:
# Work from the training folder so the relative paths used below
# ("images/", "models/" and the metadata JSON files) resolve inside it.
os.chdir("/content/drive/My Drive/Workspace/pic2sgf/training")

# Dataset

## Metadata fixing

In [None]:
# Rewrite the corner annotations of the portrait pictures listed in
# facebook.json and store the result in facebook_fixed.json.
# Corner coordinates appear to be percentages (0-100) of the image size.
# NOTE(review): pictures with width == height are left untouched here, while
# VertexSegmenterDataset.load_image rotates them -- confirm square images
# do not occur or are annotated accordingly.
with open('facebook.json', 'r') as src:
    metadata = json.load(src)

for entry in tqdm(metadata):
    picture = default_loader("images/" + entry['filename'])
    width, height = picture.size[0], picture.size[1]
    if not width < height:
        # Landscape (or square) picture: the annotation is kept as is.
        continue
    # Portrait picture: map every corner (x, y) to (y, 100 - x).
    entry['corners'] = [[pt[1], 100 - pt[0]] for pt in entry['corners']]

with open('facebook_fixed.json', 'w') as dst:
    json.dump(metadata, dst)

## Dataset loading

In [None]:
class VertexSegmenterDataset(Dataset):
    """In-memory dataset of photos and their corner-segmentation labels.

    Every metadata file is a JSON list of entries with a ``filename`` (looked
    up under ``images/``) and ``corners`` (points given as percentages of the
    image width/height).  All images are loaded and all labels are rendered
    eagerly in the constructor; ``__getitem__`` returns a freshly augmented
    ``(image, label)`` pair of tensors.

    Args:
        metadata_files: iterable of paths to the JSON metadata files.
    """

    def __init__(self, metadata_files):
        # Zero-argument super() so that Dataset itself is not skipped in the MRO.
        super().__init__()
        self.filename = []
        self.corner_coords = []
        self.label = []
        self.image = []
        # Seed last applied to NumPy inside a DataLoader worker process
        # (see _reseed_numpy_in_worker); stays None in the main process.
        self._worker_seed = None

        for file in metadata_files:
            with open(file, 'r') as f:
                metadata = json.load(f)
            self.filename.extend([entry['filename'] for entry in metadata])
            self.corner_coords.extend([entry['corners'] for entry in metadata])

        self.image = [self.load_image("images/" + filename) for filename in tqdm(self.filename)]
        self.label = [self.gen_label(corner) for corner in self.corner_coords]

    def load_image(self, filename):
        """Load an image and return it as a 512x384 landscape PIL image.

        Images wider than tall are resized directly; all others are resized
        to 384x512 and then transposed with ``Image.ROTATE_90``.
        """
        # NOTE(review): a square image takes the rotation branch here, whereas
        # the metadata-fixing cell only rotates corners of strictly portrait
        # images -- confirm square inputs do not occur.
        img = default_loader(filename)
        if img.size[0] > img.size[1]:
            img = img.resize((512, 384), resample=Image.BILINEAR)
        else:
            img = img.resize((384, 512), resample=Image.BILINEAR).transpose(Image.ROTATE_90)
        return img

    def gen_label(self, corners):
        """Render the 256x192 RGB label image for one list of corners.

        The polygon is filled with (255, 0, 0) and outlined with
        (255, 255, 0); each corner gets a small white dot.  Hence the red
        channel covers the whole polygon, the green channel its outline and
        the blue channel only the corner dots.  The result is Gaussian-blurred.
        """
        # Labels are drawn at half the 512x384 image resolution.
        w, h = int(512 / 2), int(384 / 2)
        im = Image.new('RGB', (w, h))
        draw = ImageDraw.Draw(im)
        coords = [(int(c[0] * w / 100), int(c[1] * h / 100)) for c in corners]
        draw.polygon(coords, fill=(255, 0, 0), outline=(255, 255, 0))
        for c in coords:
            draw.ellipse([c[0]-1, c[1]-1, c[0]+1, c[1]+1], fill=(255, 255, 255))
        im = im.filter(ImageFilter.GaussianBlur(radius = 2))
        return im

    def augment(self, img, lbl):
        """Apply random photometric and geometric augmentation.

        Args:
            img: PIL image (network input).
            lbl: PIL label image; receives the same flips and rotation.

        Returns:
            ``(img, lbl)`` as tensors; every label channel is rescaled so
            that its maximum is 1 (an all-zero channel stays all-zero).
        """
        # Photometric jitter on the input only; factors ~ N(1, 0.2) clipped to [0, 2].
        img = ft.adjust_brightness(img, np.clip(np.random.normal(loc=1, scale=0.2), 0, 2))
        img = ft.adjust_contrast(img, np.clip(np.random.normal(loc=1, scale=0.2), 0, 2))
        img = ft.adjust_gamma(img, np.clip(np.random.normal(loc=1, scale=0.2), 0, 2))
        img = ft.adjust_saturation(img, np.clip(np.random.normal(loc=1, scale=0.2), 0, 2))

        # Occasionally (10%) blur the input.
        if np.random.binomial(1, 0.1):
            img = img.filter(ImageFilter.GaussianBlur(radius = np.clip(np.random.normal(loc=2, scale=1), 0, 2)))

        if np.random.binomial(1, 0.5):
            img = ft.vflip(img)
            lbl = ft.vflip(lbl)

        if np.random.binomial(1, 0.5):
            img = ft.hflip(img)
            lbl = ft.hflip(lbl)

        # Rotate by a random angle in [-90, 90) with an expanded canvas, then
        # squeeze the result back to the original size.
        angle = np.random.random()*180 - 90
        img_size, lbl_size = img.size, lbl.size
        # Upper bound 256 so that 255 can actually be drawn (randint excludes it).
        fill_color = tuple(int(np.random.randint(0, 256)) for _ in range(3))
        # NOTE(review): newer torchvision releases renamed `resample` to
        # `interpolation`; adjust if the installed version rejects it.
        img = ft.rotate(img, angle, resample=Image.NEAREST, expand=True, fill=fill_color)
        lbl = ft.rotate(lbl, angle, resample=Image.BICUBIC, expand=True)
        img = img.resize(img_size, resample=Image.BILINEAR)
        lbl = lbl.resize(lbl_size, resample=Image.BICUBIC)

        img, lbl = ft.to_tensor(img), ft.to_tensor(lbl)
        # Per-channel maximum; clamped so an empty channel does not turn the
        # whole label into NaN through a division by zero.
        mv = lbl.max(dim=1)[0].max(dim=1)[0].clamp(min=1e-8)
        lbl = lbl / mv.unsqueeze(1).unsqueeze(1)
        return img, lbl

    def _reseed_numpy_in_worker(self):
        """Give NumPy a distinct seed inside each DataLoader worker.

        augment() draws from NumPy's global RNG.  Worker processes can start
        with a copy of the parent's NumPy state, which would make all workers
        (and every epoch) replay the same augmentations.  The per-worker
        torch seed differs per worker and per epoch, so it is reused here.
        No-op in the main process.
        """
        info = tr.utils.data.get_worker_info()
        if info is None:
            return
        if getattr(self, '_worker_seed', None) != info.seed:
            self._worker_seed = info.seed
            np.random.seed(info.seed % (2 ** 32))

    def __len__(self):
        return len(self.filename)

    def __getitem__(self, i):
        self._reseed_numpy_in_worker()
        image = self.image[i]
        label = self.label[i]
        return self.augment(image, label)

# Build the dataset from the three metadata files; the constructor loads
# every listed image and renders every label up front.
dataset = VertexSegmenterDataset(['facebook_fixed.json', 'reddit.json', 'metadata.json'])

In [None]:
# Visual sanity check: 40 random augmented samples on a 10x4 grid, each with
# its label drawn semi-transparently on top of the image.
plt.figure(figsize=(20, 40))
for nplot in range(1, 41):
    plt.subplot(10, 4, nplot)
    index = random.randrange(len(dataset))
    img, lbl = dataset[index]
    plt.title(dataset.filename[index])
    # Tensors are channel-first; imshow expects channel-last.
    for layer, opacity in ((img, None), (lbl, 0.25)):
        plt.imshow(layer.permute(1, 2, 0),
                   extent=(0, 1, 0, 1),
                   alpha=opacity)
plt.show()

# Models

In [None]:
class iblock(nn.Module):
    """Residual block that keeps the channel count and spatial size.

    The residual path is BatchNorm -> GELU -> 3x3 conv -> GELU -> BatchNorm
    -> 3x3 conv (padding 1); its output is added to the input.

    Args:
        dims: number of input (and output) channels.
    """

    def __init__(self, dims):
        super().__init__()
        layers = [
            nn.BatchNorm2d(dims),
            nn.GELU(),
            nn.Conv2d(dims, dims, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(dims),
            nn.Conv2d(dims, dims, kernel_size=3, padding=1),
        ]
        self.conv_path = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the residual path."""
        residual = self.conv_path(x)
        return x + residual


def Pooling(in_dim, out_dim):
    """Build a learned 2x downsampling stage.

    Returns an ``nn.Sequential`` of BatchNorm, GELU and a 2x2 convolution
    with stride 2 mapping ``in_dim`` channels to ``out_dim`` channels.
    """
    norm = nn.BatchNorm2d(in_dim)
    activation = nn.GELU()
    reduce = nn.Conv2d(in_dim, out_dim, kernel_size=2, stride=2)
    return nn.Sequential(norm, activation, reduce)


class Segmenter(nn.Module):
    """U-Net style segmenter producing a 3-channel sigmoid map.

    A stride-2 stem is followed by four encoder stages (each followed by a
    stride-2 ``Pooling``), a bottleneck, and four decoder stages that
    upsample by 2 and add the matching encoder output.  The output therefore
    has half the input resolution; input height and width should be
    multiples of 32 so the skip connections line up.
    """

    def __init__(self):
        super().__init__()
        widths = (10, 20, 40, 80)

        # Attributes are created in a fixed order so that parameter order and
        # state_dict keys stay compatible with saved checkpoints.
        self.downscale = nn.ModuleList(
            [Pooling(c_in, c_out) for c_in, c_out in zip(widths, (20, 40, 80, 80))])
        # Five parameter-free upsamplers are registered; forward() uses four.
        self.upscale = nn.ModuleList(
            [nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
             for _ in range(5)])

        self.pre_cnn = nn.Sequential(
            nn.Conv2d(3, 12, kernel_size=2, stride=2),
            nn.BatchNorm2d(12),
            nn.GELU(),
            nn.Conv2d(12, 10, kernel_size=1),
        )

        self.in_cnn = nn.ModuleList(
            [nn.Sequential(iblock(c), iblock(c), iblock(c)) for c in widths])

        self.bottom = nn.Sequential(
            iblock(80),
            iblock(80),
            nn.BatchNorm2d(80),
            nn.GELU(),
            nn.Conv2d(80, 80, kernel_size=1),
        )

        decoder = []
        for c_in, c_out in ((80, 40), (40, 20), (20, 10), (10, None)):
            stage = [iblock(c_in), iblock(c_in), nn.GELU(), nn.BatchNorm2d(c_in)]
            if c_out is not None:
                # Every stage but the last narrows the channels with a 1x1 conv.
                stage.append(nn.Conv2d(c_in, c_out, kernel_size=1))
            decoder.append(nn.Sequential(*stage))
        self.out_cnn = nn.ModuleList(decoder)

        self.last_cnn = nn.Sequential(nn.Conv2d(10, 3, kernel_size=1), nn.Sigmoid())

    def forward(self, x):
        """Run the network on a batch of 3-channel images."""
        x = self.pre_cnn(x)

        # Encoder: remember each stage's output before downsampling.
        skips = []
        for encode, shrink in zip(self.in_cnn, self.downscale):
            x = encode(x)
            skips.append(x)
            x = shrink(x)

        x = self.bottom(x)

        # Decoder: upsample, add the most recent skip, refine.
        for grow, decode in zip(self.upscale, self.out_cnn):
            x = decode(grow(x) + skips.pop())

        return self.last_cnn(x)

    def load(self, fname):
        """Load parameters from ``fname``; storages are kept on the CPU while loading."""
        state = tr.load(fname, map_location=lambda storage, loc: storage)
        self.load_state_dict(state)

    def save(self, fname):
        """Write the model's state_dict to ``fname``."""
        tr.save(self.state_dict(), fname)

# Report the model size: total number of elements over all parameters of a
# freshly constructed Segmenter.
npar = sum(p.numel() for p in Segmenter().parameters())
print(f"{npar} parameters")

In [None]:
class DiceLoss(tr.nn.Module):
    """Soft Dice loss computed per channel and scaled to 0-100.

    Args:
        layer_weights: per-channel weights applied to the second return value.
    """

    def __init__(self, layer_weights):
        super().__init__()
        self.layer_weights = layer_weights
        self.smooth = 1e-6

    def forward(self, pred, target):
        """Return ``(per_channel_loss, weighted_per_channel_loss)``.

        Sums run over dimensions 2 and 3 of the inputs; the result is
        averaged over dimension 0.
        """
        overlap = (pred * target).sum(2).sum(2)
        pred_mass = pred.sum(2).sum(2)
        target_mass = target.sum(2).sum(2)

        numerator = 2 * overlap + self.smooth
        denominator = pred_mass + target_mass + self.smooth
        per_sample = 1 - numerator / denominator

        per_channel = 100 * per_sample.mean(0)
        weighted = per_channel * self.layer_weights
        return per_channel, weighted

In [None]:
class IoTLoss(tr.nn.Module):
    """Soft intersection-over-union style loss per channel, scaled to 0-100.

    The intersection is the sum of ``pred * target`` and the union is the sum
    of the element-wise maximum of ``pred`` and ``target``.

    Args:
        layer_weights: per-channel weights applied to the second return value.
    """

    def __init__(self, layer_weights):
        super().__init__()
        self.layer_weights = layer_weights
        self.smooth = 1e-6

    def forward(self, pred, target):
        """Return ``(per_channel_loss, weighted_per_channel_loss)``."""
        overlap = (pred * target).sum(2).sum(2)
        envelope = tr.max(pred, target).sum(2).sum(2)

        ratio = (overlap + self.smooth) / (envelope + self.smooth)
        per_sample = 1 - ratio

        per_channel = 100 * per_sample.mean(0)
        weighted = per_channel * self.layer_weights
        return per_channel, weighted

In [None]:
class WBCELoss(tr.nn.Module):
    """Per-channel binary cross-entropy, scaled by 1000.

    Args:
        layer_weights: per-channel weights applied to the second return value.
    """

    def __init__(self, layer_weights):
        super().__init__()
        self.layer_weights = layer_weights
        self.bce = BCELoss(reduction='none')

    def forward(self, pred, target):
        """Return ``(per_channel_loss, weighted_per_channel_loss)``."""
        elementwise = self.bce(pred, target)
        # Average over dimensions 2 and 3, then over dimension 0.
        per_channel = 1000 * elementwise.mean(2).mean(2).mean(0)
        weighted = per_channel * self.layer_weights
        return per_channel, weighted

# Training

## Base training

In [None]:
# Fresh segmenter on the GPU.
model = Segmenter().cuda()

# Base name of the checkpoint/history files written under models/.
fname ='segmenter_2'

In [None]:
# Base training of `model`.
#
# Trains until the smoothed corner loss (channel 2 of the loss) has not
# improved for 300 epochs.  An existing checkpoint models/<fname>.pmt is
# resumed together with its history.  The best model is only written to disk
# after epoch 50.
hist = hl.History()
canvas = hl.Canvas()

epoch = 0
best_corner = float('inf')
best_total = float('inf')
# Exponentially smoothed epoch losses.  They stay None until the first epoch
# has finished (fresh run) or are seeded from the resumed history; previously
# they were undefined on a fresh run and the loop failed with a NameError.
train_total = None
train_corner = None

if os.path.isfile("models/" + fname + ".pmt"):
    model.load("models/" + fname + ".pmt")
    hist.load("models/" + fname + ".hist")
    epoch = len(hist['best_total'].data) + 1
    best_total = min(hist['best_total'].data)
    best_corner = min(hist['best_corner'].data)
    train_total = best_total
    train_corner = best_corner


# Earlier configurations, kept for reference:

# loss_func = WBCELoss(tr.Tensor([1.0, 0.5, 0.1]).cuda())
# optimizer = tr.optim.Adam(model.parameters(), lr=0.02, weight_decay=1e-5)
# scheduler = tr.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=10, min_lr=1e-9)
# train_loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

# loss_func = IoTLoss(tr.Tensor([1.0, 1.0, 1.0]).cuda())
# optimizer = tr.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-5)
# scheduler = tr.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=10, min_lr=1e-9)
# train_loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

loss_func = IoTLoss(tr.Tensor([0.2, 0.01, 1.0]).cuda())
optimizer = tr.optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-6, momentum=0.9)
scheduler = tr.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=20, min_lr=1e-9)
train_loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2)

epochs_without_improvement = 0
while epochs_without_improvement < 300:
    model.train()
    # Mean losses of this epoch (average over batches).
    epoch_total = 0.0
    epoch_corner = 0.0
    for img, lbl in train_loader:
        img, lbl = img.cuda(), lbl.cuda()

        optimizer.zero_grad()
        pred = model(img)
        loss, wloss = loss_func(pred, lbl)

        # Optimise the weighted loss; report the unweighted one.
        wloss.mean().backward()
        optimizer.step()

        epoch_total += loss.mean().item() / len(train_loader)
        epoch_corner += loss[2].item() / len(train_loader)

    # Smooth with the previous value (factor 0.5); the very first epoch of a
    # fresh run has nothing to smooth against.
    if train_total is None:
        train_total, train_corner = epoch_total, epoch_corner
    else:
        train_total = 0.5 * train_total + 0.5 * epoch_total
        train_corner = 0.5 * train_corner + 0.5 * epoch_corner

    scheduler.step(train_corner)
    if train_corner < best_corner:
        if epoch > 50:
            model.save("models/" + fname + ".pmt")
            hist.save("models/" + fname + ".hist")
        best_corner = train_corner
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
    best_total = min(train_total, best_total)

    if epoch > 0:
        hist.log(epoch, train_total = train_total,
                        best_total = best_total,
                        train_corner= train_corner,
                        best_corner= best_corner)

        with canvas:
            canvas.draw_plot([hist["train_total"],
                              hist["best_total"]])
            canvas.draw_plot([hist["train_corner"],
                              hist["best_corner"]])
    epoch += 1

hist.summary()
hist.save("models/" + fname + ".hist")
# Go back to the best checkpoint, if one was ever written (none exists when
# every improvement happened during the first 50 epochs).
if os.path.isfile("models/" + fname + ".pmt"):
    model.load("models/" + fname + ".pmt")
else:
    print("No checkpoint was saved; keeping the current weights.")


ReLU: 54.1864

GELU: 53.6563

last_128: 54.2754

last_64: 53.7557

w10: 52.8409

w12: 52.4468

## Fine-tuning

In [None]:
# Fine-tuning of the base model.
#
# Starts from models/<fname>.pmt and trains with plain SGD (no momentum).
# Gradients are accumulated over the whole epoch and a single optimizer step
# is taken per epoch.  Stops after 100 epochs without improvement of the
# smoothed corner loss; the best model is written to models/<fname>_ft.pmt
# (only after epoch 10).
hist = hl.History()
canvas = hl.Canvas()

model.load("models/" + fname + ".pmt")
epoch = 0
best_total = float('inf')
best_corner = float('inf')


loss_func = IoTLoss(tr.Tensor([0.1, 0.1, 1.0]).cuda())
optimizer = tr.optim.SGD(model.parameters(), lr=0.01, weight_decay=1e-6)
scheduler = tr.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=10, min_lr=1e-9)
train_loader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2)
epochs_without_improvement = 0

# Starting values of the smoothed losses.
# NOTE(review): presumably the approximate losses reached by the base
# training -- confirm before reusing with another checkpoint.
train_total = 40
train_corner = 53
while epochs_without_improvement < 100:
    model.train()
    # Exponential smoothing: half of the previous value plus half of this
    # epoch's mean loss (added batch by batch below).
    train_total *= 0.5
    train_corner *= 0.5
    # NOTE(review): batch_loss is never read afterwards.
    batch_loss = 0
    # Gradients are cleared once per epoch, not once per batch.
    optimizer.zero_grad()
    for img, lbl in train_loader:
        img, lbl = img.cuda(), lbl.cuda()
        pred = model(img)
        
        loss, wloss = loss_func(pred, lbl)
        # Scale by the number of batches so the accumulated gradient is the
        # gradient of the epoch-mean loss.
        wloss = wloss.mean() / len(train_loader)
        wloss.backward()

        train_total += 0.5 * loss.mean().data.item() / len(train_loader)
        train_corner += 0.5 * loss[2].data.item() / len(train_loader)

    # One parameter update per epoch, using the accumulated gradient.
    optimizer.step()

    
    scheduler.step(train_corner)
    if train_corner < best_corner:
        if epoch > 10:
            model.save("models/" + fname + "_ft.pmt")
            hist.save("models/" + fname + "_ft.hist")
        best_corner = train_corner
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
    best_total = min(train_total, best_total)
        
    if epoch > 0:
        hist.log(epoch, train_total = train_total,
                        best_total = best_total,
                        train_corner= train_corner,
                        best_corner= best_corner)

        with canvas:
            canvas.draw_plot([hist["train_total"],
                              hist["best_total"]])
            canvas.draw_plot([hist["train_corner"],
                              hist["best_corner"]])
    epoch += 1

hist.summary()
hist.save("models/" + fname + "_ft.hist")
# NOTE(review): this raises if no fine-tuned checkpoint was ever written
# (checkpoints are only saved after epoch 10).
model.load("models/" + fname + "_ft.pmt")

# Test


## Corner detector

In [None]:
# Load the segmenter_2 checkpoint into a fresh model, move it to the CPU
# and store its weights again as models/segmenter_3.pmt.
fname ='segmenter_2'
params_path = "models/" + fname + ".pmt"
unet = Segmenter()
unet.load(params_path)
unet = unet.cpu()
unet.save('models/segmenter_3.pmt')

In [None]:
# Checkpoint read by CornerDetector.__init__ (through the global params_path).
fname ='segmenter_2'
params_path = "models/" + fname + ".pmt"

class CornerDetector():
    """Locate the four corners of the segmented region in an image.

    Wraps a ``Segmenter`` whose weights are read from the module-level
    ``params_path``.

    Args:
        gpu: run the network on CUDA when True.
    """

    def __init__(self, gpu=False):
        self.unet = Segmenter()
        self.unet.load(params_path)
        self.unet.eval()
        if gpu:
            self.unet = self.unet.cuda()
        self.gpu = gpu

    def segment(self, image):
        """Run the network on one image.

        Args:
            image: PIL image (or anything ``to_tensor`` accepts).

        Returns:
            NumPy array with the network output for that image, squeezed;
            values below 0.1 are set to 0.
        """
        tensor = ft.to_tensor(image).unsqueeze(0)
        if self.gpu:
            tensor = tensor.cuda()
        # Inference only: do not record an autograd graph.
        with tr.no_grad():
            segmentation = self.unet(tensor)
        segmentation = segmentation.detach().cpu().numpy().squeeze()
        segmentation[segmentation < 0.1] = 0.0
        return segmentation

    def detect_corner(self, image, seg):
        """Extract four ordered corner positions from a segmentation.

        Channel 0 of ``seg`` is used to find the largest connected region;
        corner blobs of channel 2 outside that region are discarded.  The
        four blobs with the highest peak value are reduced to their
        value-weighted centroids.

        Args:
            image: object with a ``size`` attribute ``(width, height)``,
                used to order the corners.
            seg: array as returned by ``segment``; it is not modified.

        Returns:
            Array of shape (4, 2) with ``(x, y)`` positions, ordered by
            ``order_vertexs``.

        Raises:
            ValueError: when fewer than four corner blobs are found.
        """
        segmentation = np.copy(seg)

        region_labels, nregions = ndimage.label(segmentation[0])
        if nregions == 0:
            raise ValueError("Missing 4 corners.")
        # Pixel count per label; label 0 is background and must not win.
        region_sizes = np.bincount(region_labels.ravel())
        region_sizes[0] = 0
        largest_region = region_sizes.argmax()

        segmentation = segmentation[2]
        segmentation[region_labels != largest_region] = 0.0

        ccomponent, ncomponent = ndimage.label(segmentation)
        if ncomponent < 4:
            raise ValueError(f"Missing {4 - ncomponent} corners.")

        vertexs = -np.ones((4, 2))
        for i in range(4):
            # Blob containing the current global maximum.
            peak = np.unravel_index(segmentation.argmax(), segmentation.shape)
            mask = ccomponent == ccomponent[peak]
            rows, cols = np.nonzero(mask)
            w = segmentation[mask]
            w = w / w.sum()
            vertexs[i] = np.array([(rows * w).sum(), (cols * w).sum()])
            # Remove the blob so the next iteration finds the next peak.
            segmentation[mask] = 0.0

        # (row, col) -> (x, y); the factor 2 maps the half-resolution
        # network output back to image coordinates.
        vertexs = 2 * vertexs[:, [1, 0]]
        idxs = self.order_vertexs(vertexs, image.size)
        return vertexs[idxs]

    def order_vertexs(self, v, img_size):
        """Return indices that order four points around the image.

        Points are assigned greedily to the image corners (0, 0), (w, 0),
        (w, h) and (0, h), in that order, by Euclidean distance.  Afterwards
        neighbouring indices are swapped wherever the turning direction
        flips sign with respect to the previously accepted one.

        Args:
            v: array of shape (4, 2) with ``(x, y)`` points.
            img_size: ``(width, height)`` of the image.

        Returns:
            Integer array of four indices into ``v``.
        """
        w, h = img_size
        # Float copy: already assigned points are masked out with inf.
        vc = np.array(v, dtype=float)
        idxs = np.ones(4).astype(int)
        for k, target in enumerate(((0, 0), (w, 0), (w, h), (0, h))):
            idxs[k] = np.linalg.norm(vc - np.array(target), ord=2, axis=1).argmin()
            vc[idxs[k]] = np.inf

        last_prod = 0
        for i in range(len(v)):
            prev = v[idxs[(i-1)%4]] - v[idxs[i]]
            post = v[idxs[(i+1)%4]] - v[idxs[i]]
            # z component of the 2-D cross product post x prev.
            cross_prod = post[0] * prev[1] - post[1] * prev[0]
            if cross_prod * last_prod < 0:
                # The original code also did `i += 1` here, which has no
                # effect inside a for loop; behaviour is kept as it was.
                idxs[i], idxs[(i+1)%4] = idxs[(i+1)%4], idxs[i]
            else:
                last_prod = cross_prod
        return idxs


## Run tests

In [None]:
# Qualitative check: run the detector on 32 augmented samples (dataset
# indices 350..381) and show the predicted quadrilateral together with the
# corner channel of the segmentation.
vertex_detector = CornerDetector(gpu=False)

plt.figure(figsize=(20,40))
for i in range(8*4):
    plt.subplot(8, 4, i+1)
    img, _ = dataset[i+350]
    img = ft.to_pil_image(img)

    segmentation = vertex_detector.segment(img)
    try:
        pred_vertex = vertex_detector.detect_corner(img, segmentation)
    except Exception:
        # Detection failures are expected here; Ctrl-C and similar are no
        # longer swallowed as they were with a bare `except:`.
        print('Corner missing!')
        continue
    draw = ImageDraw.Draw(img)
    draw.polygon([(float(x), float(y)) for x, y in pred_vertex])
    del draw
    plt.imshow(img, extent = (0,1,0,1))
    plt.imshow(segmentation[2],
                cmap='jet',
                norm=Normalize(0, 1),
                extent = (0,1,0,1),
                alpha=0.5)
plt.show()

In [None]:
# Quantitative evaluation on the un-augmented dataset images.
#
# For every image the predicted corners are compared with the annotated ones
# (both in 512x384 pixel coordinates).  `dist` ends up as a flat array of
# per-corner distances, `nmc` counts images where detection failed, and up to
# 20 images with an error above 30 px are plotted.
vertex_detector = CornerDetector(gpu=False)

nplots = 0
dist = []
plt.figure(figsize=(20,40))
nmc = 0
for i in tqdm(range(len(dataset))):
    img = dataset.image[i]
    # Float dtype: the in-place scaling below fails on an integer array.
    true_vertex = np.array(dataset.corner_coords[i], dtype=float)
    segmentation = vertex_detector.segment(img)
    try:
        pred_vertex = vertex_detector.detect_corner(img, segmentation)
    except Exception:
        nmc += 1
        continue

    # Percent coordinates -> pixels of the 512x384 image.
    true_vertex[:,0] *= 512 / 100
    true_vertex[:,1] *= 384 / 100

    # Put the annotation in the same corner order as the prediction.
    vo = vertex_detector.order_vertexs(true_vertex, img.size)
    true_vertex = true_vertex[vo]

    d = np.sqrt(((pred_vertex - true_vertex)**2).sum(1))
    dist.append(d)

    if (d > 30).any() and nplots < 20:
        plt.subplot(5, 4, nplots+1)
        nplots += 1
        # Draw on a copy: `img` is the image cached inside the dataset and
        # drawing on it would permanently alter the training data.
        overlay = img.copy()
        draw = ImageDraw.Draw(overlay)
        draw.polygon([(float(x), float(y)) for x, y in pred_vertex])
        del draw
        plt.imshow(overlay, extent = (0,1,0,1))
        plt.imshow(segmentation[2],
                    cmap='jet',
                    norm=Normalize(0, 1),
                    extent = (0,1,0,1),
                    alpha=0.5)
        plt.title(dataset.filename[i])
print(f"Missing corners: {nmc}")
# np.concatenate raises on an empty list (every detection failed).
dist = np.concatenate(dist) if dist else np.empty(0)
plt.show()

In [None]:
# Mean per-corner error, ignoring corners that are 50 px or more off.
dist[dist<50].mean()

In [None]:
# Distribution of the per-corner errors (counts and bin edges, default bins).
np.histogram(dist)