In [1]:
from pathlib import Path
import random
from typing import List, Tuple

from fastai.callback.core import Callback
from fastai.callback.hook import hook_outputs
from fastai.callback.training import ShowGraphCallback
from fastai.vision.augment import Resize, ResizeMethod
from fastai.vision.core import PILImage
from fastai.vision.data import (DataBlock, get_image_files, ImageBlock, ImageDataLoaders,
                                RandomSplitter)
from fastai.vision.learner import unet_learner
from fastai.vision.models import unet, xresnet
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.vgg import vgg19

In [2]:
img_size = (224, 224)
n_channels = 3
bs = 8
# We only need a small subset of the CelebA dataset for this task
n_ds_items = 100
# Seed every RNG in play (PyTorch, Python, NumPy) so stochastic steps such as
# the random train/valid split and weight initialization repeat on
# "Restart & Run All". The last call returns None, so the cell shows no output.
seed = 42
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)

# Data

In [3]:
# Kaggle mount of CelebA's `img_align_celeba` images: a flat folder holding only images.
ds_path = Path(
    '/kaggle/input/celeba-dataset/img_align_celeba/img_align_celeba'
)

In [None]:
# `get_image_files` is too slow on the full CelebA folder. Its checks can be
# skipped because ds_path has a flat structure with only images inside.
# `ls()` returns files in filesystem order, which can differ between machines.
# Sort before slicing so the same n_ds_items images are picked every run.
# Input and target are the same image (autoencoder-style setup).
dblock = DataBlock(blocks=(ImageBlock, ImageBlock),
                   get_items=lambda path: sorted(path.ls())[:n_ds_items],
                   # Fixed seed keeps the train/valid split reproducible.
                   splitter=RandomSplitter(valid_pct=0.2, seed=42),
                   # Pass the full (h, w) tuple so non-square img_size values are honoured.
                   item_tfms=Resize(img_size))
# NOTE(review): `path` becomes the learner's working path (used by learner.save /
# lr_find). /kaggle/input is presumably read-only, so consider path='.' if you save models.
dls = dblock.dataloaders(ds_path, path=ds_path, bs=bs)

In [None]:
# Visual sanity check of a training batch. Input and target come from the same
# file, so each displayed input/target pair should look identical.
dls.show_batch(max_n=9)

# Loss function

In [None]:
class FeaturesCalculator:
    """Extract intermediate VGG feature maps, e.g. for style/content losses.

    Forward hooks are registered on ``vgg.features[idx]`` for every requested
    index. Each ``calc_*`` method runs ``vgg.features`` on the given batch and
    returns the tensors captured by the relevant hooks, in the order the
    indices were passed in.

    Args:
        vgg_style_layers_idx: indices into ``vgg.features`` whose outputs
            ``calc_style`` returns.
        vgg_content_layers_idx: indices into ``vgg.features`` whose outputs
            ``calc_content`` returns.
        vgg: network to extract features from. It must expose an indexable
            ``features`` module. Defaults to ``vgg19(pretrained=True)``.
        device: if given, the network is moved to this device.
        freeze: if True (default), disable gradients for the network's
            parameters. The network is only a fixed feature extractor, so its
            parameters need no ``.grad`` buffers. Gradients still flow back to
            the input image.

    Call ``remove()`` when done to detach the hooks from the network.
    """
    def __init__(self, vgg_style_layers_idx:List[int], vgg_content_layers_idx:List[int],
                 vgg:nn.Module=None, device:torch.device=None, freeze:bool=True):
        self.vgg = vgg19(pretrained=True) if vgg is None else vgg
        self.vgg.eval()
        if freeze:
            # These params are in no optimizer. Left trainable, backward() would
            # allocate their .grad buffers and keep accumulating into them forever.
            self.vgg.requires_grad_(False)
        if device is not None: self.vgg.to(device)
        modules_to_hook = [self.vgg.features[idx] for idx in (*vgg_style_layers_idx, *vgg_content_layers_idx)]
        # detach=False keeps the autograd graph so losses on these features can backprop.
        self.hooks = hook_outputs(modules_to_hook, detach=False)
        n_style = len(vgg_style_layers_idx)
        self.style_ftrs_hooks = self.hooks[:n_style]
        self.content_ftrs_hooks = self.hooks[n_style:]

    def remove(self) -> None:
        """Detach all registered forward hooks from the network."""
        self.hooks.remove()

    def _forward(self, img_t:torch.Tensor) -> None:
        # Only modules of `vgg.features` are hooked, so running the rest of the
        # network (pooling + classifier head) would be wasted work.
        self.vgg.features(img_t)

    def _get_hooks_out(self, hooks):
        return [h.stored for h in hooks]

    def calc_style(self, img_t:torch.Tensor) -> List[torch.Tensor]:
        """Return the style-layer activations for ``img_t``."""
        self._forward(img_t)
        return self._get_hooks_out(self.style_ftrs_hooks)

    def calc_content(self, img_t:torch.Tensor) -> List[torch.Tensor]:
        """Return the content-layer activations for ``img_t``."""
        self._forward(img_t)
        return self._get_hooks_out(self.content_ftrs_hooks)

    def calc_style_and_content(self, img_t:torch.Tensor) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Return ``(style_activations, content_activations)`` from a single forward pass."""
        self._forward(img_t)
        style_ftrs = self._get_hooks_out(self.style_ftrs_hooks)
        content_ftrs = self._get_hooks_out(self.content_ftrs_hooks)
        return style_ftrs, content_ftrs

In [None]:
# Content-only perceptual features: no style layers are hooked.
# NOTE(review): in torchvision's VGG19 layout, index 22 should be the ReLU after
# conv4_2 (relu4_2). Confirm with `print(ftrs_calc.vgg.features)`.
vgg_content_layers_idx = [22]
# Move VGG to the device the batches are sent to. FeaturesCalculator leaves the
# network where it is when no device is given, so training on a GPU would
# otherwise fail with a CPU/GPU tensor mismatch inside the loss.
ftrs_calc = FeaturesCalculator([], vgg_content_layers_idx, device=dls.device)

In [None]:
#n_mid_ftrs = 1000
#n_out_ftrs = 3
# encoder = xresnet.xresnet18(p=0.0, c_in=3, n_out=n_mid_ftrs, stem_szs=(32,32,64),
#                             widen=1.0, sa=False, 
#                             act_cls=defaults.activation, 
#                             ndim=2, ks=3, stride=2, 
#                             #**kwargs # Rest of kwargs go to ConvLayer
#                            )

# # Remove AvgPool2d. This could be optional
# encoder[8] = nn.Identity()
# # Remove final Flatten, Dropout and Linear
# encoder[9] = nn.Identity()
# encoder[10] = nn.Identity()
# encoder[11] = nn.Identity()
# model = unet.DynamicUnet(encoder, n_out_ftrs, img_size, blur=False, blur_final=True, self_attention=False,
#                          y_range=None, last_cross=True, bottle=False, 
#                          act_cls=defaults.activation,
#                          init=nn.init.kaiming_normal_, 
#                          norm_type=NormType.Batch 
#                         )


# Model: prefer `unet_learner` over building a `DynamicUnet` by hand (commented-out code above)

In [None]:
# A bare Callback: once registered with the learner it exposes learner attributes
# such as `.x` (the current input batch), which the loss below reads.
last_input_cb = Callback()
content_loss = nn.MSELoss(reduction='mean')


def loss_func(output, target):
    """Content (perceptual) loss between the model output and its input image.

    Both images go through ``ftrs_calc``. The loss is the MSE between their
    feature maps, summed over every configured content layer (with a single
    layer this is just that layer's MSE).
    """
    # last_input_cb.x is actually the same as `target` here, but if it wasn't this
    # would be the std way to access input to measure the preservation of the content.
    # The input's features are a fixed reference, so no autograd graph is built for them.
    with torch.no_grad():
        input_content_ftrs = ftrs_calc.calc_content(last_input_cb.x)
    output_content_ftrs = ftrs_calc.calc_content(output)
    return sum(content_loss(out_ftrs, in_ftrs)
               for out_ftrs, in_ftrs in zip(output_content_ftrs, input_content_ftrs))

# Learner

In [None]:
# U-Net with an untrained xresnet18 encoder, producing n_channels output channels.
# last_input_cb must be registered here so loss_func can read the current input (`.x`).
# NOTE(review): fastai appears to apply `normalize` only when pretrained=True. Confirm;
# it may be a no-op in this configuration.
learner = unet_learner(
    dls,
    xresnet.xresnet18,
    pretrained=False,
    normalize=True,
    n_out=n_channels,
    loss_func=loss_func,
    cbs=[last_input_cb, ShowGraphCallback()],
)

# Training

In [None]:
# Train for 5 epochs. The ShowGraphCallback given to the learner plots the losses.
learner.fit(n_epoch=5)

In [None]:
# Compare targets with model outputs on the validation set (ds_idx=1).
learner.show_results(ds_idx=1)

# Example of prediction

In [None]:
# Each validation item is an (input, target) pair of the same image; keep the input.
sample_img, _ = dls.valid_ds[0]
# `predict` returns a 3-tuple; the third element is displayed here.
pred_img = learner.predict(sample_img)[2]
pred_img.show()