In [None]:
# default_exp data

In [None]:
#export
from fastai2.vision.all import *
from fastai2.vision.gan import *
from fastai2.vision.gan import _conv, _conv_args, DenseResBlock
from colorup.core import *

In [None]:
#hide
from nbdev.showdoc import *

# Learner

> Train the colorization model as a GAN

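We train the model adversarially (as a GAN): a U-Net generator predicts the two AB color channels from the single L (lightness) channel, and a critic learns to tell real L+AB images from images built with generated AB channels.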

## Prepare dataset

In [None]:
# Side length images are resized to; it also selects the `train_{img_size}` / `valid_{img_size}` folders.
img_size = 192

# --- hyper-parameters ---
aug_size = 92          # output size passed to `aug_transforms` for the batch augmentations
batch_size = 20
partial_n = 2 * 2048   # items per epoch for `partial_dataloaders`
# Intended samples per optimizer step (batch_size x accumulation steps).
# NOTE(review): not referenced anywhere else in this notebook, so no gradient accumulation is applied.
samples_per_update = 1 * batch_size

In [None]:
# load images
# Collect every image file under ../data/train_<img_size> and ../data/valid_<img_size>.
path = Path('..') / 'data'
items = get_image_files(path, folders=[f'{split}_{img_size}' for split in ('train', 'valid')])

In [None]:
# data augmentation
# Batch-level augmentations from fastai's `aug_transforms`, configured with `size=aug_size` and `min_scale=0.4`.
augment_tfms = aug_transforms(min_scale=0.4, size=aug_size)

In [None]:
# create dataloaders
# Item pipeline: open the file, resize, convert with `RGBToLAB`, tensorize, then split with `Split_L_AB`
# (presumably into a 1-channel L input and a 2-channel AB target, matching the generator's n_in/n_out).
# Files whose parent path contains 'valid' form the validation split; everything else is training.
dsrc = TfmdLists(
    items,
    tfms=[PILImage.create, Resize(img_size), RGBToLAB(), ToTensor(), Split_L_AB()],
    splits=FuncSplitter(lambda fn: 'valid' in str(fn.parent))(items),
)
# Batch pipeline: type adjustment, int -> float conversion, augmentations, then normalization.
dls = dsrc.partial_dataloaders(
    bs=batch_size,
    partial_n=partial_n,
    after_batch=[AdjustType(), IntToFloatTensor(), *augment_tfms,
                 Normalize.from_stats(mean=[0.5], std=[0.5])],
)

In [None]:
# Print stats: number of items that landed in each split.
for split_name, subset in (('training', dsrc.train), ('validation', dsrc.valid)):
    print(f'Number of images in {split_name} set: {len(subset)}')

In [None]:
# Display a sample batch to sanity-check the data pipeline.
dls.show_batch()

## Create generator

In [None]:
# U-Net generator (resnet34 encoder, trained from scratch): 1 input channel -> 2 output channels,
# output range set by `y_range`, self-attention enabled.
# NOTE(review): `Normalize.from_stats(mean=[0.5], std=[0.5])` in the batch pipeline may put targets outside
# (-0.5, 0.5) — confirm which tensors it is applied to.
config = unet_config(self_attention=True, y_range=(-0.5, 0.5))
generator = unet_learner(
    dls=dls,
    arch=resnet34,
    pretrained=False,
    n_in=1,
    n_out=2,
    config=config,
    loss_func=MSELossFlat(),
    cbs=[SaveModelCallback(fname='generator', with_opt=True)],
)

## Create discriminator

In [None]:
def gan_critic(n_channels=3, nf=128, n_blocks=3, p=0.15):
    """Build a convolutional critic for a `GAN`.

    Args:
        n_channels: channels of the images fed to the critic.
        nf: filters of the first convolution.
        n_blocks: number of strided downsampling stages after the dense block.
        p: dropout probability of the downsampling stages (the stem uses `p/2`).

    Returns:
        An `nn.Sequential` ending in a 1-channel conv with no activation (raw logits),
        averaged spatially by `nn.AdaptiveAvgPool2d(1)`.
    """
    # Stem: strided conv, light dropout, then a dense residual block that doubles the channel count.
    stem = [_conv(n_channels, nf, ks=4, stride=2),
            nn.Dropout2d(p/2),
            DenseResBlock(nf, **_conv_args)]
    ch = nf * 2

    body = []
    for idx in range(n_blocks):
        # Each stage: dropout, then a stride-2 conv doubling channels; only the first stage uses self-attention.
        body.append(nn.Dropout2d(p))
        body.append(_conv(ch, ch * 2, ks=4, stride=2, self_attention=idx == 0))
        ch *= 2

    head = [ConvLayer(ch, 1, ks=4, bias=False, padding=0, norm_type=NormType.Spectral, act_cls=None),
            nn.AdaptiveAvgPool2d(1)]
    return nn.Sequential(*stem, *body, *head)

In [None]:
# Critic learner. `GANLearner.from_learners` reads only its `.model` and `.loss_func`, so the callbacks
# and metrics given here are not carried into GAN training.
critic = Learner(
    dls,
    gan_critic(nf=64),
    loss_func=BCEWithLogitsLossFlat(),
    metrics=accuracy_multi,
    cbs=[SaveModelCallback(fname='critic', with_opt=True)],
)

## Create GAN

In [None]:
@patch
def begin_batch(self: GANTrainer):
    "Clip critic weights to `[-clip, clip]` when training with `clip` set; in critic mode, rebuild the batch for the critic."
    if self.training and self.clip is not None:
        lim = self.clip
        for param in self.critic.parameters():
            param.data.clamp_(-lim, lim)
    if self.gen_mode:
        return
    # Critic mode: the "real" image is input and target concatenated along channels,
    # and the original input becomes the target (the critic loss regenerates fakes from it).
    real_img = torch.cat([*self.xb, *self.yb], dim=1)
    self.learn.xb, self.learn.yb = (real_img,), self.xb

In [None]:
def gan_loss_from_critic(loss_crit):
    """Define generator and critic loss functions for a GAN from a binary critic loss `loss_crit`.

    `loss_crit(pred, target)` compares raw critic outputs with 0/1 targets
    (e.g. `BCEWithLogitsLossFlat`).

    Returns:
        `(_loss_G, _loss_C)` where
        - `_loss_G(fake_pred)` scores critic outputs on generated images against the "real" label 1;
        - `_loss_C(real_pred, fake_pred)` averages the real-vs-1 and fake-vs-0 losses.

    Targets are created with the prediction's full shape, so this works both with flattening losses
    and with plain element-wise losses on multi-dimensional critic outputs such as `(bs, 1, 1, 1)`.
    """
    def _loss_G(fake_pred):
        ones = fake_pred.new_ones(fake_pred.shape)
        return loss_crit(fake_pred, ones)

    def _loss_C(real_pred, fake_pred):
        # Real and fake predictions come from the same batch and must line up one-to-one.
        if real_pred.shape != fake_pred.shape:
            raise ValueError(f'real_pred and fake_pred must have the same shape, '
                             f'got {tuple(real_pred.shape)} and {tuple(fake_pred.shape)}')
        ones  = real_pred.new_ones(real_pred.shape)
        zeros = fake_pred.new_zeros(fake_pred.shape)
        return (loss_crit(real_pred, ones) + loss_crit(fake_pred, zeros)) / 2

    return _loss_G, _loss_C

In [None]:
class GANLoss(GANModule):
    """Wrapper around `crit_loss_func` and `gen_loss_func` for a conditional GAN.

    The critic is always evaluated on images built by concatenating the conditioning input
    with color channels (real or generated) along dim 1.

    Args:
        gen_loss_func: called as `gen_loss_func(fake_pred)`; returns the generator loss.
        crit_loss_func: called as `crit_loss_func(real_pred, fake_pred)`; returns the critic loss.
        gan_model: module exposing `.generator` and `.critic`.
        learn: the owning `Learner`; read for the current input batch `learn.xb`.

    NOTE(review): subclassing `GANModule` presumably makes the inherited forward dispatch to
    `generator`/`critic` according to the GAN mode — confirm against the installed fastai version.
    """
    def __init__(self, gen_loss_func, crit_loss_func, gan_model, learn):
        super().__init__()
        store_attr(self, 'gen_loss_func,crit_loss_func,gan_model,learn')

    def generator(self, output, target):
        "Evaluate the generated channels `output` with the critic, then apply `self.gen_loss_func`."
        # Rebuild the full image from the conditioning input and the generated channels.
        img_gen = torch.cat((*self.learn.xb, output), dim=1)
        fake_pred = self.gan_model.critic(img_gen)
        # `target` is unused: the generator is trained with the adversarial loss only.
        self.gen_loss = self.gen_loss_func(fake_pred)
        return self.gen_loss

    def critic(self, real_pred, input):
        "Create `fake_pred` with the generator from `input` and compare it to `real_pred` in `self.crit_loss_func`."
        # `detach()` stops the critic loss from back-propagating into the generator. Unlike
        # `requires_grad_(False)`, it does not raise when the output is a non-leaf tensor
        # (i.e. when the generator's parameters still require grad).
        fake = self.gan_model.generator(input).detach()
        img_gen = torch.cat((input, fake), dim=1)
        fake_pred = self.gan_model.critic(img_gen)
        self.crit_loss = self.crit_loss_func(real_pred, fake_pred)
        return self.crit_loss

In [None]:
@delegates()
class GANLearner(Learner):
    """A `Learner` suitable for GANs: wraps `generator` and `critic` in a `GANModule` trained by `GANTrainer`.

    Args:
        dls: the `DataLoaders` to train on.
        generator, critic: the two models.
        gen_loss_func, crit_loss_func: losses forwarded to `GANLoss`.
        switcher: callback deciding when to switch between generator and critic training;
            defaults to `FixedGANSwitcher(n_crit=5, n_gen=1)`.
        gen_first: forwarded to `GANTrainer` (per fastai: start training with the generator).
        switch_eval, show_img, clip: forwarded to `GANTrainer`.
        cbs, metrics: extra callbacks / metrics; `gen_loss` and `crit_loss` metrics are always appended.
        **kwargs: forwarded to `Learner.__init__`.
    """
    def __init__(self, dls, generator, critic, gen_loss_func, crit_loss_func, switcher=None, gen_first=False,
                 switch_eval=True, show_img=True, clip=None, cbs=None, metrics=None, **kwargs):
        gan = GANModule(generator, critic)
        # The loss keeps a back-reference to this learner to read the current input batch.
        loss_func = GANLoss(gen_loss_func, crit_loss_func, gan, self)
        if switcher is None: switcher = FixedGANSwitcher(n_crit=5, n_gen=1)
        # `gen_first` must be forwarded here, otherwise the argument is silently ignored.
        trainer = GANTrainer(clip=clip, switch_eval=switch_eval, gen_first=gen_first, show_img=show_img)
        cbs = L(cbs) + L(trainer, switcher)
        metrics = L(metrics) + L(*LossMetrics('gen_loss,crit_loss'))
        super().__init__(dls, gan, loss_func=loss_func, cbs=cbs, metrics=metrics, **kwargs)

    @classmethod
    def from_learners(cls, gen_learn, crit_learn, switcher=None, **kwargs):
        """Create a GAN from `gen_learn` and `crit_learn`.

        Only `gen_learn.dls`, `gen_learn.model`, `crit_learn.model` and `crit_learn.loss_func` are used;
        the callbacks and metrics of both learners are not carried over.
        """
        losses = gan_loss_from_critic(crit_learn.loss_func)
        return cls(gen_learn.dls, gen_learn.model, crit_learn.model, *losses, switcher=switcher, **kwargs)

# Expose `GANLearner.__init__`'s keyword arguments in `from_learners`' signature.
GANLearner.from_learners = delegates(to=GANLearner.__init__)(GANLearner.from_learners)

In [None]:
# Switch between critic and generator training adaptively instead of with the fixed default schedule.
# NOTE(review): per fastai, `critic_thresh` is the critic-loss level below which training switches back
# to the generator — confirm against the installed version.
switcher = AdaptiveGANSwitcher(critic_thresh=0.65)

In [None]:
# Assemble the GAN from the generator and critic learners, optimizing with Adam built with `mom=0.`.
# NOTE(review): a momentum schedule applied by the fit method (e.g. `fit_one_cycle`'s `moms`) may override `mom`.
learn = GANLearner.from_learners(
    generator,
    critic,
    opt_func=partial(Adam, mom=0.),
    switcher=switcher,
)

In [None]:
# Train with the one-cycle LR schedule for 100 epochs. `fit_one_cycle` also schedules momentum
# (from the learner's default `moms`, 0.95 -> 0.85 -> 0.95), which would override the `mom=0.`
# given to Adam when the learner was created; pin it to 0 so the optimizer really runs without momentum.
learn.fit_one_cycle(100, moms=(0., 0., 0.))

In [None]:
# Visually inspect results of the trained GAN learner.
learn.show_results()

# Export -

In [None]:
#hide
from nbdev.export import *
# Regenerate the library modules from the notebooks' `#export` cells (nbdev).
notebook2script()