In [None]:
%matplotlib inline

In [None]:
from functools import partial
from itertools import cycle
import math
from multiprocessing import Pool, cpu_count
from pathlib import Path
import sys

# Put the parent directory first on sys.path (presumably where the local `loop`
# package lives). The untouched path is kept in `old_path`; because the guard
# skips the block once `old_path` exists, re-running this cell does not keep
# prepending the same entry.
if 'old_path' not in globals():
    old_path = sys.path.copy()
    sys.path = [Path.cwd().parent.as_posix()] + old_path

In [None]:
import torch
from torch import nn
from torch import optim
from torch.utils.data import Dataset
from torch.nn import functional as F
from torchvision import transforms as T
from torchvision import models
from torchvision.datasets.folder import pil_loader, is_image_file

import matplotlib.pyplot as plt
import pandas as pd
import PIL.Image
import PIL.ImageDraw
import numpy as np

In [None]:
from loop import train_classifier, make_phases
from loop import callbacks as C
from loop.schedule import CosineAnnealingSchedule
from loop.config import defaults

In [None]:
# Train on the first GPU when one is available; fall back to CPU so the
# notebook still runs (slowly) on machines without CUDA.
defaults.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Reference (x, y) extent that stroke coordinates are scaled from when rendering.
RAW_SIZE = (256, 256)
# Side length the doodles are rendered at before the torchvision transforms.
IMAGE_SIZE = 128
CSV_PATH = Path.home() / 'data' / 'quick_draw' / 'prepared'
# (mean, std) per RGB channel, unpacked into T.Normalize.
IMAGENET_STATS = (
    [0.485, 0.456, 0.406],
    [0.229, 0.224, 0.225],
)
# Stroke palette used by the 'rgb' renderer, cycled stroke by stroke.
COLORS = [
    '#0095EF', '#3C50B1', '#6A38B3',
    '#A224AD', '#F31D64', '#FE433C',
]

In [None]:
class OneCycleSchedule:
    """Linear warm-up followed by cosine annealing, advanced one step per ``update()``.

    The first ``floor(t * linear_pct)`` steps rise linearly from
    ``eta_min + (eta_max - eta_min) / t_linear`` up to exactly ``eta_max``;
    every later step is delegated to a ``CosineAnnealingSchedule`` built with
    ``t_max = ceil(t * (1 - linear_pct)) + 1``.

    Args:
        t: total number of steps the cycle is meant to span.
        linear_pct: fraction of ``t`` spent in the linear warm-up.
        eta_max: peak value reached at the end of the warm-up.
        eta_min: lower bound; ``None`` means ``eta_max / div_factor``.
        div_factor: divisor used to derive ``eta_min`` when it is not given.
    """

    def __init__(self, t, linear_pct=0.2, eta_max=1.0, eta_min=None, div_factor=100):
        eta_min = eta_max / div_factor if eta_min is None else eta_min

        self.t, self.linear_pct = t, linear_pct
        self.eta_max, self.eta_min = eta_max, eta_min

        # Warm-up length, and a cosine phase one step longer than the remainder.
        self.t_linear = int(math.floor(t * linear_pct))
        self.t_cosine = int(math.ceil(t * (1 - linear_pct))) + 1

        self.cosine = CosineAnnealingSchedule(eta_min, eta_max, t_max=self.t_cosine, t_mult=1)

        span = eta_max - eta_min
        # Only reached when t_linear > 0, so the division is safe.
        self.linear = lambda step: step * span / self.t_linear + eta_min

        self.iter = 0

    def update(self, **kwargs):
        """Advance by one step and return the new value; keyword arguments are ignored."""
        self.iter += 1
        warming_up = self.iter <= self.t_linear
        return self.linear(self.iter) if warming_up else self.cosine.update()

In [None]:
def generate_schedule(schedule, n):
    """Call ``schedule.update()`` ``n`` times and return the values in order.

    If ``update()`` is stateful (as ``OneCycleSchedule.update`` is), the
    schedule is advanced by ``n`` steps as a side effect.
    """
    values = []
    for _ in range(n):
        values.append(schedule.update())
    return values

In [None]:
def plot_schedule(schedule, n=1000, **fig_kwargs):
    """Plot the first ``n`` values produced by ``schedule`` and return the axes.

    The values come from ``generate_schedule``, which calls
    ``schedule.update()`` ``n`` times, so pass a fresh schedule if it is used
    again afterwards.

    Args:
        schedule: object with an ``update()`` method returning a number.
        n: number of steps to plot; ``0`` yields an empty plot.
        **fig_kwargs: forwarded to ``plt.subplots`` (e.g. ``figsize``).

    Returns:
        The matplotlib ``Axes`` holding the curve.
    """
    values = generate_schedule(schedule, n)
    _, ax = plt.subplots(1, 1, **fig_kwargs)
    # range(len(values)) instead of zip(*enumerate(...)): an empty schedule no
    # longer raises on tuple unpacking.
    ax.plot(range(len(values)), values, label='schedule')
    ax.set(title='Schedule', xlabel='step', ylabel='value')
    ax.legend()
    return ax

In [None]:
# Preview a 1000-step one-cycle schedule (20% linear warm-up, peak 1.0).
plot_schedule(
    OneCycleSchedule(t=1000, linear_pct=0.2, eta_max=1.0),
    n=1000,
);

In [None]:
class ImageRenderer:
    """Converts string with strokes into PIL image.

    Args:
        mode: ``'b/w'`` (every stroke in ``fg``) or ``'rgb'`` (stroke i gets
            the i-th palette color, cycling). Any other value silently falls
            back to ``'b/w'``.
        bg: background color.
        fg: stroke color used in ``'b/w'`` mode.
        lw: line width in pixels.
        colors: palette for ``'rgb'`` mode; defaults to ``COLORS``.
    """

    def __init__(self, mode='b/w', bg='black', fg='white', lw: int=4,
                 colors=None):

        mode = mode if mode in ('b/w', 'rgb') else 'b/w'

        self.render_fn = {
            'b/w': render_bw,
            'rgb': render_rgb
        }[mode]

        self.mode = mode
        self.bg = bg
        self.fg = fg
        self.lw = lw
        # Keep the palette itself: `colors` is rebuilt as a fresh cycle on every
        # render so a drawing's first stroke always receives the first color.
        self.palette = list(colors or COLORS)
        self.colors = cycle(self.palette)

    def render(self, strokes: str, image_size: tuple):
        """Render ``strokes`` on a canvas of ``image_size`` = ``(width, height)``.

        Coordinates are scaled by ``image_size / RAW_SIZE`` per axis.
        """
        x_ref, y_ref = RAW_SIZE
        x_max, y_max = image_size
        ratio = x_max/float(x_ref), y_max/float(y_ref)
        # Previously one shared cycle kept advancing across calls, so the same
        # drawing got different colors depending on what was rendered before.
        self.colors = cycle(self.palette)
        return self.render_fn(self, strokes, ratio, image_size)


def render_bw(renderer, strokes, ratio, image_size):
    """Draw every stroke in the single foreground color ``renderer.fg``.

    Args:
        renderer: object providing ``bg``, ``fg`` and ``lw`` attributes.
        strokes: ``'|'``-separated segments; each segment is a comma-separated
            list of integers read as consecutive ``x, y`` points.
        ratio: ``(x_ratio, y_ratio)`` multiplied into every coordinate.
        image_size: ``(width, height)`` of the canvas.

    Returns:
        A new RGB ``PIL.Image`` with each segment drawn as connected lines.
    """
    bg, fg, lw = renderer.bg, renderer.fg, renderer.lw

    x_ratio, y_ratio = ratio
    canvas = PIL.Image.new('RGB', image_size, color=bg)
    draw = PIL.ImageDraw.Draw(canvas)

    for segment in strokes.split('|'):
        if not segment.strip():
            continue  # empty drawing / trailing separator: int('') would raise
        coords = [int(v) for v in segment.split(',')]
        # Pair x, y values; a dangling odd coordinate is ignored, as before.
        points = [(int(x * x_ratio), int(y * y_ratio))
                  for x, y in zip(coords[0::2], coords[1::2])]
        # One line per consecutive pair of points, in linear time (the old loop
        # re-sliced the list on every step). Single-point segments draw nothing.
        for start, end in zip(points, points[1:]):
            draw.line(start + end, fill=fg, width=lw)

    return canvas


def render_rgb(renderer, strokes, ratio, image_size):
    """Draw each stroke in the next color pulled from ``renderer.colors``.

    Args:
        renderer: object providing ``colors``, ``bg`` and ``lw``. ``colors`` is
            consumed one item per stroke via ``zip``, so it should be an
            endless iterator (``ImageRenderer`` uses ``itertools.cycle``); a
            finite sequence would silently drop the strokes beyond its length.
        strokes: ``'|'``-separated segments; each segment is a comma-separated
            list of integers read as consecutive ``x, y`` points.
        ratio: ``(x_ratio, y_ratio)`` multiplied into every coordinate.
        image_size: ``(width, height)`` of the canvas.

    Returns:
        A new RGB ``PIL.Image`` with each segment drawn as connected lines.
    """
    colors, bg, lw = renderer.colors, renderer.bg, renderer.lw

    x_ratio, y_ratio = ratio
    canvas = PIL.Image.new('RGB', image_size, color=bg)
    draw = PIL.ImageDraw.Draw(canvas)

    for segment, color in zip(strokes.split('|'), colors):
        if not segment.strip():
            # Empty segment: nothing to draw, but its color is still consumed
            # so later strokes keep the color matching their position.
            continue
        coords = [int(v) for v in segment.split(',')]
        # Pair x, y values; a dangling odd coordinate is ignored, as before.
        points = [(int(x * x_ratio), int(y * y_ratio))
                  for x, y in zip(coords[0::2], coords[1::2])]
        # One line per consecutive pair of points, in linear time.
        for start, end in zip(points, points[1:]):
            draw.line(start + end, fill=color, width=lw)

    return canvas

In [None]:
# Renderer used by default by every Doodles dataset: palette-colored strokes
# on a white background.
default_renderer = ImageRenderer(mode='rgb', bg='white')

In [None]:
class Doodles(Dataset):
    """Quick, Draw! doodles read from per-category CSV files and rendered lazily.

    Every ``*.csv`` file in ``root/'train'`` (or ``root/'valid'``) is read with
    ``read_category`` in a process pool. The frames are concatenated; the
    ``word`` column becomes the target and the ``drawing`` column is the stroke
    string passed to ``renderer.render`` in ``__getitem__``.

    Args:
        root: dataset root containing ``train/`` and ``valid/`` sub-folders.
        train: read ``train/`` when true, otherwise ``valid/``.
        subset_size: rows kept from each CSV file (``None`` keeps all rows).
        image_size: render size; an int (square) or a ``(width, height)`` tuple.
        renderer: object whose ``render(strokes, image_size)`` draws a sample.
        transforms: optional callable applied to the rendered image.
        classes: optional explicit class list (e.g. ``train_ds.classes``), so
            label indices agree across splits. Defaults to the sorted unique
            words found in the files. A word missing from it raises ``KeyError``.

    Raises:
        FileNotFoundError: if the split folder contains no CSV files.
    """

    def __init__(self, root: Path, train: bool=True,
                 subset_size: int=None, image_size=RAW_SIZE,
                 renderer=default_renderer, transforms=None, classes=None):

        root = Path(root)
        subfolder = root/('train' if train else 'valid')
        if isinstance(image_size, int):
            image_size = image_size, image_size

        # Sorted so the sample order does not depend on filesystem listing order.
        files = sorted(subfolder.glob('*.csv'))
        if not files:
            # pd.concat([]) would otherwise fail with an unrelated-looking error.
            raise FileNotFoundError(f'no CSV files found in {subfolder}')

        worker = partial(read_category, subset_size)
        # One task per file: no point starting more processes than files.
        with Pool(min(cpu_count(), len(files))) as pool:
            data = pool.map(worker, files)

        merged = pd.concat(data)
        targets = merged.word.values
        classes = np.unique(targets) if classes is None else np.asarray(classes)
        class2idx = {v: k for k, v in enumerate(classes)}
        labels = np.array([class2idx[c] for c in targets])

        self.root = root
        self.train = train
        self.subset_size = subset_size
        self.image_size = image_size
        self.renderer = renderer
        self.data = merged.drawing.values
        self.classes = classes
        self.class2idx = class2idx
        self.labels = labels
        self.transforms = transforms

    def __len__(self):
        """Number of doodles across all category files."""
        return len(self.data)

    def __getitem__(self, item):
        """Render doodle ``item`` and return ``(image, label_index)``."""
        strokes, target = self.data[item], self.labels[item]
        img = self.renderer.render(strokes, self.image_size)
        if self.transforms is not None:
            img = self.transforms(img)
        return img, target
    
    
class TestImagesFolder(Dataset):
    """Unlabelled images from one directory, e.g. for inference.

    Every item is ``(image, pseudolabel)``. The image is loaded with ``loader``
    and shrunk in place with ``thumbnail`` to fit within ``image_size``
    (aspect ratio kept, never enlarged).

    Args:
        path: directory whose image files (per ``is_image_file``) are used.
        image_size: bounding box; an int (square) or ``(width, height)``.
        loader: callable turning a file path into a PIL image.
        pseudolabel: constant returned as every item's label.

    Raises:
        NotADirectoryError: if ``path`` is not an existing directory.
        ValueError: if the directory contains no image files.
    """

    def __init__(self, path, image_size=RAW_SIZE,
                 loader=pil_loader, pseudolabel=0):

        path = Path(path)

        if isinstance(image_size, int):
            image_size = image_size, image_size

        # Explicit checks instead of asserts, which vanish under `python -O`.
        if not path.is_dir():
            raise NotADirectoryError(f'Not a directory: {path}')

        # Sorted for a deterministic item order. The emptiness check is done on
        # the listing itself: a directory's st_size is non-zero even when empty.
        images = sorted(file for file in path.iterdir() if is_image_file(str(file)))
        if not images:
            raise ValueError(f'Directory contains no images: {path}')

        self.path = path
        self.image_size = image_size
        self.loader = loader
        self.images = images
        self.pseudolabel = pseudolabel

    def __len__(self):
        """Number of image files found in the directory."""
        return len(self.images)

    def __getitem__(self, item):
        """Load image ``item``, shrink it to fit ``image_size``, return ``(img, pseudolabel)``."""
        img = self.loader(self.images[item])
        # LANCZOS is the filter the removed ANTIALIAS alias (Pillow >= 10) referred to.
        img.thumbnail(self.image_size, PIL.Image.LANCZOS)
        return img, self.pseudolabel
    
    
def read_category(subset_size, path):
    """Read one category CSV, optionally keeping only its first rows.

    Args:
        subset_size: number of leading rows to keep, or ``None`` for all rows.
        path: CSV path (or any source ``pd.read_csv`` accepts).

    Returns:
        A DataFrame with at most ``subset_size`` rows and a fresh 0-based index.
    """
    # `nrows` stops parsing after the requested rows. It replaces the manual
    # chunk loop, which was quadratic and relied on DataFrame.append (removed
    # in pandas 2.0). `nrows=None` reads the whole file.
    return pd.read_csv(path, nrows=subset_size)

In [None]:
def flat_model(model):
    """Converts model with nested modules into single list of modules.

    Leaf modules (those without children) are collected depth-first in
    registration order and wrapped in one ``nn.Sequential``. Only leaves are
    kept, so parameters registered directly on a module that also has
    children are not represented in the result.
    """

    def leaves(module):
        kids = list(module.children())
        if kids:
            for kid in kids:
                yield from leaves(kid)
        else:
            yield module

    return nn.Sequential(*leaves(model))

In [None]:
def as_sequential(model):
    """Wrap the direct children of ``model`` (one level only) in ``nn.Sequential``.

    The children are reused, not copied. The wrapper simply calls them in
    order, so any work the original ``forward`` did outside its child modules
    is not reproduced.
    """
    children = model.children()
    return nn.Sequential(*children)

In [None]:
def get_output_shape(model, image_size=128):
    """Pass a dummy input through the model to get its per-sample output shape.

    The dummy batch's channel count comes from the first leaf module (found
    depth-first) that exposes ``in_channels``, e.g. the first convolution.
    The forward pass runs under ``torch.no_grad()`` with all submodules in
    eval mode, so BatchNorm running statistics are not touched. Each
    submodule's previous train/eval flag is restored afterwards.

    Args:
        model: module accepting a ``(1, C, H, W)`` batch.
        image_size: spatial size of the dummy input; an int (square) or
            ``(H, W)``. Defaults to 128, the previously hard-coded value.

    Returns:
        The output size without the batch dimension, as a list of ints.

    Raises:
        ValueError: if no leaf module exposes ``in_channels``.
    """
    if isinstance(image_size, int):
        image_size = (image_size, image_size)

    leaves = (m for m in model.modules() if next(m.children(), None) is None)
    first = next((m for m in leaves if hasattr(m, 'in_channels')), None)
    if first is None:
        raise ValueError('cannot infer input channels: no module exposes `in_channels`')

    # Match the model's device/dtype (default CPU float if it has no parameters).
    param = next(model.parameters(), None)
    dummy_input = torch.zeros(
        (1, first.in_channels, *image_size),
        device=None if param is None else param.device,
        dtype=None if param is None else param.dtype)

    # In train mode the all-zero batch would be folded into BatchNorm running
    # stats, corrupting a pretrained backbone, so evaluate in eval mode.
    modes = [(m, m.training) for m in model.modules()]
    model.eval()
    try:
        with torch.no_grad():
            out = model(dummy_input)
    finally:
        for m, was_training in modes:
            m.training = was_training
    return list(out.size())[1:]

In [None]:
class AdaptiveConcatPool2d(nn.Module):
    """Adaptive max- and average-pooling concatenated along dimension 1.

    For ``(N, C, H, W)`` input the result is ``(N, 2*C, size, size)``, with
    the max-pooled channels first and the average-pooled channels after.
    """

    def __init__(self, size=1):
        super().__init__()
        self.avg = nn.AdaptiveAvgPool2d(size)
        self.max = nn.AdaptiveMaxPool2d(size)

    def forward(self, x):
        max_pooled = self.max(x)
        avg_pooled = self.avg(x)
        return torch.cat((max_pooled, avg_pooled), dim=1)

In [None]:
class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) one: ``(N, ...) -> (N, -1)``.

    ``reshape`` is used instead of ``view`` so non-contiguous inputs (e.g.
    after a transpose/permute) work as well; contiguous inputs still get a view.
    """

    def forward(self, x):
        return x.reshape(x.size(0), -1)

In [None]:
def init_weights(m):
    """Initialise a single module in place; intended for ``Module.apply``.

    Dispatch is by class name, checked in this order:

    * contains ``'Conv'``: Kaiming-normal weight (``mode='fan_out'``), zero bias;
    * contains ``'BatchNorm'``: weight 1, bias 1e-3;
    * contains ``'Linear'``: Kaiming-normal weight (default mode), zero bias.

    Other modules are left unchanged. A missing weight or bias is skipped
    instead of crashing the whole ``apply`` call. This covers ``bias=False``,
    ``affine=False``, or a container whose name merely contains a keyword.
    """
    name = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    with torch.no_grad():
        if 'Conv' in name:
            if weight is not None:
                nn.init.kaiming_normal_(weight, mode='fan_out')
            if bias is not None:
                nn.init.zeros_(bias)
        elif 'BatchNorm' in name:
            if weight is not None:
                nn.init.constant_(weight, 1)
            if bias is not None:
                nn.init.constant_(bias, 1e-3)
        elif 'Linear' in name:
            if weight is not None:
                nn.init.kaiming_normal_(weight)
            if bias is not None:
                nn.init.zeros_(bias)

In [None]:
def leaky_linear(ni, no, dropout=None, bn=True):
    """Build ``[BatchNorm1d] -> [Dropout] -> Linear -> LeakyReLU`` as one Sequential.

    Args:
        ni: input features.
        no: output features.
        dropout: dropout probability; the layer is omitted when ``None`` or not positive.
        bn: prepend ``BatchNorm1d(ni)`` when truthy.

    The LeakyReLU uses slope 0.01 and runs in place.
    """
    head = [nn.BatchNorm1d(ni)] if bn else []
    if dropout is not None and dropout > 0:
        head.append(nn.Dropout(dropout))
    return nn.Sequential(*head, nn.Linear(ni, no), nn.LeakyReLU(0.01, True))

In [None]:
class Classifier(nn.Module):
    """Pretrained backbone from ``arch`` topped with a freshly initialised head.

    ``arch(True)`` is flattened one level with ``as_sequential``. Its last two
    children are dropped (presumably global pooling and the final fc layer for
    torchvision ResNets; confirm for other architectures) and replaced by
    ``AdaptiveConcatPool2d -> Flatten -> leaky_linear(.., 512) ->
    leaky_linear(512, 256) -> Linear(256, n_classes)``.

    Args:
        n_classes: number of output logits.
        arch: callable returning the backbone model; called as ``arch(True)``.
        init_fn: applied via ``Module.apply`` to the new head only;
            ``None`` keeps PyTorch's default initialisation.
    """

    def __init__(self, n_classes, arch=models.resnet18, init_fn=init_weights):
        super().__init__()

        # NOTE(review): positional True is the legacy `pretrained` flag; recent
        # torchvision versions expect `weights=...` — confirm for the installed version.
        model = arch(True)
        seq_model = as_sequential(model)
        backbone = seq_model[:-2]
        out_shape = get_output_shape(backbone)
        # Concat pooling stacks max- and avg-pooled features, doubling the channels.
        input_size = out_shape[0] * 2

        self.backbone = backbone
        self.top = nn.Sequential(
            AdaptiveConcatPool2d(),
            Flatten(),
            leaky_linear(input_size, 512, 0.25),
            leaky_linear(512, 256, 0.5),
            nn.Linear(256, n_classes)
        )

        self.init(init_fn)

    def freeze_backbone(self, freeze=True, bn=True):
        """Set ``requires_grad`` on the backbone's parameters.

        Args:
            freeze: ``True`` disables gradients, ``False`` re-enables them.
            bn: when ``False``, every BatchNorm layer in the backbone, including
                those nested inside residual blocks, is skipped and keeps its
                current ``requires_grad`` state.
        """
        # Walk all submodules, not just the top-level children: BatchNorm layers
        # nested in containers were previously frozen even with bn=False.
        for module in self.backbone.modules():
            if not bn and 'BatchNorm' in module.__class__.__name__:
                continue
            # recurse=False: each parameter is handled by the module that owns it.
            for p in module.parameters(recurse=False):
                p.requires_grad = not freeze

    def forward(self, x):
        """Run the backbone, then the head; returns unnormalised logits."""
        return self.top(self.backbone(x))

    def init(self, fn=None):
        """Apply ``fn`` to every module of the head (backbone untouched); no-op for ``None``."""
        if fn is None:
            return
        self.top.apply(fn)

In [None]:
epochs, batch_size = 1, 300
# Network input resolution: doodles are rendered at IMAGE_SIZE and then
# resized to this size by the dataset transforms.
image_size = 224
# Rows read from each per-category CSV (passed as `subset_size`).
n_train, n_valid = 200, 50

In [None]:
# Doodles are rendered at IMAGE_SIZE, then resized to `image_size` for the
# network; only the training split gets random augmentation.
train_ds = Doodles(
    CSV_PATH, 
    train=True, 
    subset_size=n_train,
    image_size=IMAGE_SIZE, 
    transforms=T.Compose([
        T.Pad(4, padding_mode='reflect'),
        T.Resize(image_size),
        # NOTE(review): newer torchvision releases renamed `fillcolor` to `fill`
        # (which then needs a numeric value such as (255, 255, 255)).
        T.RandomAffine(degrees=5, 
                       translate=(0.1, 0.1), 
                       scale=(0.9, 1.1),
                       fillcolor='white'),
        T.RandomResizedCrop(image_size, scale=(0.8, 1.0)),
        T.ToTensor(),
        T.Normalize(*IMAGENET_STATS)
    ])
)

valid_ds = Doodles(
    CSV_PATH, 
    train=False,
    subset_size=n_valid, 
    image_size=IMAGE_SIZE,
    transforms=T.Compose([
        T.Pad(4, padding_mode='reflect'),
        T.Resize(image_size),
        T.ToTensor(),
        T.Normalize(*IMAGENET_STATS)
    ])
)

# The scheduler runs with mode='batch' (presumably one update per optimizer
# step), so the cycle length must be counted in batches over all epochs.
# Sizing it with len(train_ds) (samples) made the cycle about batch_size times
# longer than an epoch, so the LR never left the start of the warm-up.
# NOTE(review): ceil assumes make_phases keeps the last partial batch.
steps_per_epoch = math.ceil(len(train_ds) / batch_size)

cb = C.CallbacksGroup([
    C.History(),
    C.RollingLoss(),
    C.StreamLogger(),
    C.ProgressBar(),
    C.Accuracy(),
    C.Scheduler(
        OneCycleSchedule(
            t=epochs * steps_per_epoch,
            linear_pct=0.2,
            eta_max=1.0,
            div_factor=100
        ),
        mode='batch'
    )
])

In [None]:
phases = make_phases(train_ds, valid_ds, batch_size)
# Size the head from the data instead of hard-coding 340 categories, so a
# reduced category set cannot silently mismatch the output layer.
model = Classifier(len(train_ds.classes), arch=models.resnet50)
# Train only the new head first; the pretrained backbone stays frozen.
model.freeze_backbone()
opt = optim.Adam(model.parameters(), lr=1e-2)
train_classifier(model, opt, phases, cb, epochs=epochs)