In [None]:
from nbdev import *
%nbdev_default_export loss

Cells will be exported to srthesis.loss,
unless a different module is specified after an export flag: `%nbdev_export special.module`


# Losses

> Loss functions used as the training objective: a VGG16 feature (perceptual) loss and a plain L1 loss

In [None]:
%nbdev_export_internal
import fastai
from fastai.vision import *
from fastai.callbacks import *
from torchvision.models import vgg16_bn

## Feature loss

Feature loss (also called perceptual loss) compares activations of a pretrained model. Both the target and the prediction are run through a VGG16 (with batch norm) model. Activations are taken from the layers just before the last three MaxPool2d layers and compared with L1 loss. In addition, there is a standard pixel-level L1 loss. The Gram matrices of the activations are also compared with L1 loss: $ G=AA^{T}/(chw) $, where $A$ is the activation map reshaped to channels × pixels.

$$ L = L_1(y, \hat y) + L_{feat} + 5\times10^{3} \cdot L_{Gram} $$

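where, for each selected layer $i$ with weight $w_i$:

- $A_i = VGG_{16}(y)_i$ and $\hat A_i = VGG_{16}(\hat y)_i$ are the activations of the target and the prediction, reshaped to $c \times hw$.
- $G_i = A_i A_i^{T}/(chw)$ and $\hat G_i = \hat A_i \hat A_i^{T}/(chw)$ are their Gram matrices.

$$ L_{feat} = \sum_i w_i \, L_1(A_i, \hat A_i) $$

$$ L_{Gram} = \sum_i w_i^{2} \, L_1(G_i, \hat G_i) $$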

The constant weight of the Gram-matrix term ($5\times10^{3}$) is taken from the fast.ai course.

In [None]:
%nbdev_export_internal

def _gram_matrix(x):
    n,c,h,w = x.size()
    x = x.view(n, c, -1)
    return (x @ x.transpose(1,2))/(c*h*w)   

In [None]:
# Inspect which VGG16_bn layers end each conv block (the layer right before every MaxPool2d).
# Only the layer structure is needed here, so the model stays on the CPU: this avoids holding a
# second VGG copy in GPU memory for the whole session and keeps the cell runnable without a GPU.
vgg_m = vgg16_bn(True).features.eval()
requires_grad(vgg_m, False)

blocks = [i-1 for i,o in enumerate(children(vgg_m)) if isinstance(o,nn.MaxPool2d)]
blocks, [vgg_m[i] for i in blocks]

([5, 12, 22, 32, 42],
 [ReLU(inplace=True),
  ReLU(inplace=True),
  ReLU(inplace=True),
  ReLU(inplace=True),
  ReLU(inplace=True)])

In [None]:
%nbdev_export
class FeatureLoss(nn.Module):
    """Perceptual loss on VGG16_bn: pixel L1 + weighted feature L1 + weighted Gram-matrix L1.

    `layer_wgts` holds one weight per hooked layer (the last layers of the last three VGG conv
    blocks). Layer `i` contributes `layer_wgts[i] * L1(features)` and
    `layer_wgts[i]**2 * 5e3 * L1(gram matrices)`. After every `forward` call the individual
    terms are kept in `feat_losses` and, keyed by `metric_names`, in `metrics`.
    """
    def __init__(self, layer_wgts=(20, 70, 10)):
        super().__init__()

        self.m_feat = models.vgg16_bn(True).features.cuda().eval()
        requires_grad(self.m_feat, False)
        # Index of the layer right before every MaxPool2d, i.e. the last layer of each conv block.
        blocks = [
            i - 1
            for i, o in enumerate(children(self.m_feat))
            if isinstance(o, nn.MaxPool2d)
        ]
        # Only the last three conv blocks are used for the feature and Gram terms.
        layer_ids = blocks[2:5]
        self.loss_features = [self.m_feat[i] for i in layer_ids]
        # detach=False: the stored activations must keep their graph so gradients reach `input`.
        self.hooks = hook_outputs(self.loss_features, detach=False)
        # Copy into a fresh list so neither the default nor the caller's sequence is shared.
        self.wgts = list(layer_wgts)
        # `forward` zips hooked layers with weights, so only this many layers actually contribute.
        n_used = min(len(layer_ids), len(self.wgts))
        # One name per entry of `feat_losses`, in the same order: pixel, features, Gram matrices.
        # (Previously the Gram names were missing, so `metrics` silently dropped the Gram terms.)
        self.metric_names = (['pixel']
                             + [f'feat_{i}' for i in range(n_used)]
                             + [f'gram_{i}' for i in range(n_used)])
        self.base_loss = F.l1_loss

    def _make_features(self, x, clone=False):
        "Run `x` through the VGG trunk and return the hooked activations (cloned if `clone`)."
        self.m_feat(x)
        return [(o.clone() if clone else o) for o in self.hooks.stored]

    def forward(self, input, target):
        "Return the summed loss of `input` vs `target`, recording each term in `feat_losses`/`metrics`."
        # Target activations are cloned so they are independent of the tensors held by the hooks.
        out_feat = self._make_features(target, clone=True)
        in_feat = self._make_features(input)
        pixel = self.base_loss(input, target)
        feat = [self.base_loss(f_in, f_out) * w
                for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        # 5e3 is the fixed Gram-term weight taken from the fast.ai course.
        gram = [self.base_loss(_gram_matrix(f_in), _gram_matrix(f_out)) * w**2 * 5e3
                for f_in, f_out, w in zip(in_feat, out_feat, self.wgts)]
        self.feat_losses = [pixel] + feat + gram

        self.metrics = dict(zip(self.metric_names, self.feat_losses))
        return sum(self.feat_losses)

    def __del__(self):
        # If __init__ failed before `hooks` was assigned (e.g. no GPU or a download error),
        # there is nothing to remove; don't raise a second AttributeError during cleanup.
        hooks = getattr(self, 'hooks', None)
        if hooks is not None:
            hooks.remove()

## L1 loss

In [None]:
%nbdev_export
class L1Loss(nn.Module):
    "Mean absolute error between `input` and `target`, exposing its value via `metrics` like `FeatureLoss`."

    def __init__(self):
        super().__init__()
        self.base_loss = F.l1_loss
        self.metric_names = ['l1 loss']

    def forward(self, input, target):
        "Compute the L1 loss and record it in `feat_losses` and `metrics`."
        loss = self.base_loss(input, target)
        self.feat_losses = [loss]
        self.metrics = {name: term for name, term in zip(self.metric_names, self.feat_losses)}
        return sum(self.feat_losses)

## Use as `feat_loss`

In [None]:
%nbdev_export
# Exported FeatureLoss instance with per-layer weights for the three hooked VGG blocks.
# NOTE(review): building it at module level loads VGG16_bn onto the GPU on every import.
feat_loss = FeatureLoss(layer_wgts=[5, 15, 2])

In [None]:
%nbdev_export
# Exported plain-L1 criterion; exposes `metrics` / `metric_names` like `feat_loss`.
l1_loss = L1Loss()

In [None]:
%nbdev_hide
notebook2script()  # convert the project's notebooks into library modules (see the "Converted ..." output)

Converted 00_core.ipynb.
Converted 01_utils.ipynb.
Converted 0__template.ipynb.
Converted 10_data.ipynb.
Converted 11_div2k.ipynb.
Converted 12_realsr.ipynb.
Converted 20_metrics.ipynb.
Converted 21_loss.ipynb.
Converted 22_callbacks.ipynb.
Converted 23_tensorboard.ipynb.
Converted 31_generator_learner.ipynb.
Converted 32_critic_learner.ipynb.
Converted 41_generator_pretraining.ipynb.
Converted 42_critic_pretraining.ipynb.
Converted 43_gan_training.ipynb.
Converted EXPERIMENTS JOURNAL.ipynb.
Converted augmentations.ipynb.
Converted graphs-tests.ipynb.
Converted sr reference.ipynb.
