# Pretrained GAN

In [None]:
import fastai
from fastai import *
from fastai.vision import *
from fastai.callbacks import *

from torchvision.models import vgg16_bn

# Pin CUDA work to a specific GPU. A bare `torch.cuda.set_device(2)` errors on
# machines with fewer than three GPUs (or a CPU-only torch build), so only pin
# when that device actually exists; otherwise PyTorch's default device is kept.
GPU_ID = 2
if torch.cuda.device_count() > GPU_ID:
    torch.cuda.set_device(GPU_ID)

In [None]:
# Pets dataset root (fetched via fastai's `untar_data`), split into the
# high-res originals and the folder of degraded copies written by `crappify`.
path = untar_data(URLs.PETS)
path_hr, path_lr = path/'images', path/'crappy'

## Critic data

In [None]:
def crappify(fn,i):
    """Save a degraded copy of image `fn` at the mirrored location under `path_lr`.

    The image is resized with `resize_to(img, 96, use_min=True)` (presumably
    shorter side -> 96px, aspect ratio kept) using bilinear resampling,
    converted to RGB, and saved with a random `quality` in [10, 70] so the
    copies carry varying compression artifacts (for lossy formats such as JPEG).

    Args:
        fn: path of a source image under `path_hr`; its relative location is
            reproduced under `path_lr`.
        i: item index, presumably supplied by fastai's `parallel` (see the
            commented-out call below); unused.
    """
    dest = path_lr/fn.relative_to(path_hr)
    dest.parent.mkdir(parents=True, exist_ok=True)
    # Open the source in a `with` block so its file handle is released as soon
    # as the resized copy exists, instead of lingering until garbage
    # collection once per image across the whole dataset.
    with PIL.Image.open(fn) as orig:
        targ_sz = resize_to(orig, 96, use_min=True)
        img = orig.resize(targ_sz, resample=PIL.Image.BILINEAR).convert('RGB')
    img.save(dest, quality=random.randint(10,70))

In [None]:
# One-time preprocessing: uncomment and run to write a degraded copy of every
# image under `path_hr` into `path_lr` via `crappify`. Both the critic data and
# the generator data below read from `path_lr` ('crappy'), so it must exist.
# il = ImageItemList.from_folder(path_hr)
# parallel(crappify, il.items)

In [None]:
bs, size = 32, 128
arch = models.resnet34  # not used by the custom critic below
# Critic items: every image in the `images` and `crappy` subfolders, with a
# reproducible 10% validation split.
src = (ImageItemList
       .from_folder(path, include=['images', 'crappy'])
       .random_split_by_pct(0.1, seed=42))

In [None]:
# Label each image by its folder, pinning the class order explicitly
# (index 0 = 'crappy', 1 = 'images'). CriticLoss pushes the generator toward
# target 1, so class 1 must be the real images; without `classes` the order
# is inferred from the folder names and would silently flip on a rename.
ll = src.label_from_folder(classes=['crappy', 'images'])

In [None]:
# Augment (max_zoom=2), resize to `size`, batch, and normalize with ImageNet stats.
data_crit = (ll
             .transform(get_transforms(max_zoom=2.), size=size)
             .databunch(bs=bs)
             .normalize(imagenet_stats))

# NOTE(review): `critic()` below does not read `data.c` (it always emits one
# value per image), so this override looks vestigial — confirm before removing.
data_crit.c = 3

In [None]:
# Eyeball 4 rows of validation images with their folder labels.
data_crit.show_batch(rows=4, ds_type=DatasetType.Valid)

## Train critic

In [None]:
def conv(ni:int, nf:int, ks:int=3, stride:int=1, **kwargs):
    """fastai `conv_layer` preset used by the critic: spectral norm, leaky=0.2.

    `ks`/`stride` default to a 3x3, stride-1 conv. Any extra keyword arguments
    (e.g. `padding`, `bias`, `self_attention`) are forwarded unchanged.
    """
    preset = dict(ks=ks, stride=stride, leaky=0.2, norm_type=NormType.Spectral)
    return conv_layer(ni, nf, **preset, **kwargs)

def critic(n_channels:int=3, nf:int=128, n_blocks:int=3, p:float=0.05):
    """Build the critic: an image batch in, one raw score (logit) per image out.

    Layers, in order:
      * stem: 4x4 stride-2 conv (`n_channels` -> `nf`), Dropout2d(p/2), 3x3 conv;
      * `n_blocks` blocks of Dropout2d(p) + 4x4 stride-2 conv that doubles the
        channel count (self-attention on the first block only);
      * a 4x4 conv down to one channel with no bias and no activation, global
        max pooling, and a full flatten.

    Every conv goes through `conv` (spectral norm, leaky=0.2).

    Args:
        n_channels: channels of the input images.
        nf: channel width after the stem.
        n_blocks: number of channel-doubling blocks.
        p: dropout probability; a fraction, hence `float` (default 0.05).
    """
    layers = [
        conv(n_channels, nf, ks=4, stride=2),
        nn.Dropout2d(p/2),
        conv(nf, nf)]
    for i in range(n_blocks):
        layers += [
            nn.Dropout2d(p),
            conv(nf, nf*2, ks=4, stride=2, self_attention=(i==0))]
        nf *= 2
    # Width is now nf * 2**n_blocks: collapse to a single score map, keep each
    # image's maximum, and flatten to one value per image.
    layers += [
        conv(nf, 1, ks=4, bias=False, padding=0, use_activ=False),
        nn.AdaptiveMaxPool2d(1),
        Flatten(full=True)]
    return nn.Sequential(*layers)

In [None]:
# Train the critic as a real-vs-crappy classifier: `critic()` emits one raw
# logit per image (no final activation), hence BCE-with-logits.
learn = Learner(data_crit, critic(), metrics=accuracy_thresh, loss_func=BCEWithLogitsFlat())

In [None]:
# 8 epochs of one-cycle training with max lr 1e-3.
learn.fit_one_cycle(8, 1e-3)

In [None]:
# Checkpoint the critic; reloaded under this name to build the critic loss.
learn.save('critic-pre')

## Pre-train generator

In [None]:
bs = 32
size = 128
arch = models.resnet34
# Generator inputs: the degraded images in `path_lr` (their targets in
# `path_hr` are attached by `get_data`), with a reproducible 10% validation split.
src = (ImageImageList
       .from_folder(path_lr)
       .random_split_by_pct(0.1, seed=42))

In [None]:
def get_data(bs,size):
    """Build the generator's paired DataBunch (degraded input -> original target).

    Each item of the global `src` is labelled with the file of the same name
    in `path_hr`. Transforms are applied to input and target alike
    (`tfm_y=True`), and both are normalized with ImageNet stats (`do_y=True`).

    Args:
        bs: batch size.
        size: image size passed to `transform`.

    Returns:
        The DataBunch, with `c` set to 3.
    """
    def hr_target(x): return path_hr/x.name

    labelled = src.label_from_func(hr_target)
    augmented = labelled.transform(get_transforms(max_zoom=2.), size=size, tfm_y=True)
    data = augmented.databunch(bs=bs).normalize(imagenet_stats, do_y=True)
    # Presumably read by `unet_learner` as the number of output channels (RGB) — confirm.
    data.c = 3
    return data

In [None]:
# Generator data at the batch size / image size set above.
data_gen = get_data(bs,size)

In [None]:
wd = 1e-3
# U-Net generator on the `arch` backbone, pretrained on plain pixel MSE.
learn = unet_learner(
    data_gen, arch,
    loss_func=MSELossFlat(),
    wd=wd, blur=True, norm_type=NormType.Weight,
)

In [None]:
# First stage, before `learn.unfreeze()` (so presumably only the non-frozen
# layers train); no lr given, so fastai's default max lr is used.
learn.fit_one_cycle(2, pct_start=0.8)

In [None]:
# Make the whole generator trainable for fine-tuning.
learn.unfreeze()

In [None]:
# Fine-tune with discriminative lrs spread from 1e-6 (earliest layers) to 1e-3 (last).
learn.fit_one_cycle(2, slice(1e-6,1e-3))

In [None]:
# Visual check of the MSE-pretrained generator (8 rows).
learn.show_results(rows=8)

In [None]:
# Checkpoint the MSE-pretrained generator; reloaded below as the starting point.
learn.save('gen-pre')

## Train generator with critic loss

In [None]:
# Rebuild the critic architecture in eval mode (dropout off) and load the
# weights saved as 'critic-pre'; from here on it only scores generator outputs.
learn_crit = Learner(data_crit, critic().eval(), loss_func=BCEWithLogitsFlat()).load('critic-pre')

In [None]:
# Re-assert eval mode in case building/loading the Learner changed it
# (likely redundant with `.eval()` above, but cheap).
learn_crit.model = learn_crit.model.eval()

In [None]:
class CriticLoss(nn.Module):
    """Generator loss: pixel MSE plus a weighted term from a frozen critic.

    The critic scores the generator output against an all-ones target
    (presumably the "real" class of the critic data — confirm against its
    label order), so minimizing this pushes outputs toward what the critic
    calls real.

    Args:
        critic: object exposing `.model` (an nn.Module mapping an image batch
            to one logit per image) and `.loss_func(pred, target)`, e.g. a
            fastai Learner.
        mult: weight applied to the critic term; may be changed later via
            `.mult`.
    """
    def __init__(self, critic, mult=1.):
        super().__init__()
        # If `critic` is not an nn.Module (a fastai Learner presumably isn't),
        # storing it does not register the critic's parameters here.
        self.critic = critic
        # Freeze the critic: gradients still flow *through* it to the
        # generator output, but its own weights never receive gradients.
        for param in self.critic.model.parameters():
            param.requires_grad_(False)
        # A frozen critic should score in inference mode (dropout off);
        # enforce it here instead of relying on callers to remember.
        self.critic.model.eval()
        # Names for the entries of `self.metrics`, presumably consumed by the
        # LossMetrics callback the generator Learner is built with.
        self.metric_names = ['pixel','critic']
        self.mult = mult

    def forward(self, input, target):
        """Return `mse(input, target) + mult * critic_loss(critic(input), ones)`.

        Both terms are also stored in `self.metrics`, keyed by `metric_names`.
        """
        pred = self.critic.model(input)
        critic_targ = pred.new_ones(pred.shape[0])
        critic_loss = self.critic.loss_func(pred, critic_targ)*self.mult
        px_loss = F.mse_loss(input,target)
        self.metrics = dict(zip(self.metric_names, [px_loss, critic_loss]))
        return px_loss + critic_loss

In [None]:
# Generator loss: pixel MSE + 0.01 * critic loss, with the critic frozen.
loss_func = CriticLoss(learn_crit, mult=0.01)

In [None]:
wd = 1e-3
# Same U-Net settings as the MSE pretraining, but trained on the critic-based
# loss (LossMetrics presumably logs its 'pixel'/'critic' terms), starting
# from the 'gen-pre' weights.
learn = unet_learner(
    data_gen, arch,
    loss_func=loss_func, callback_fns=LossMetrics,
    wd=wd, blur=True, norm_type=NormType.Weight,
).load('gen-pre')

In [None]:
# Grab one training batch to sanity-check `loss_func` by hand.
# NOTE(review): in some fastai v1 versions `one_batch` defaults to
# `denorm=True` and returns de-normalized images, whereas the generator and
# critic see normalized tensors during training — confirm, and pass
# `denorm=False` if the installed version supports it.
x,y = data_gen.one_batch()

In [None]:
# Move the sample batch to the GPU (where the critic presumably lives) and
# cut it from any autograd graph.
x = x.cuda().detach()
y = y.cuda().detach()

In [None]:
# Total loss (pixel + critic) of the degraded inputs against their targets.
loss_func(x, y)

In [None]:
# Baseline: with input == target the pixel term is 0, leaving only the critic term.
loss_func(y, y)

In [None]:
# Learning-rate range test for training with the critic loss.
learn.lr_find()

In [None]:
# Plot loss vs. learning rate; skip_end=11 drops the tail of the curve.
learn.recorder.plot(skip_end=11)

In [None]:
# One cycle with the critic weighted 0.01 and fastai's default max lr (none given).
learn.fit_one_cycle(1, pct_start=0.6)

In [None]:
# NOTE(review): raises the critic weight from 0.01 to 100 (a 10,000x jump) —
# confirm this is intended rather than e.g. 1.0.
learn.loss_func.mult = 100.

In [None]:
# One more cycle under the new weighting; slice(1e-3) sets the max lr of the
# last layer group (earlier groups presumably get a scaled-down lr in fastai).
learn.fit_one_cycle(1, slice(1e-3), pct_start=0.5)

In [None]:
# Visual check of the generator after critic-loss training.
learn.show_results()

In [None]:
# Same visual check, with 8 rows.
learn.show_results(rows=8)

## fin