In [None]:
# export
from fastai2.basics import *
from fastai2.callback.all import *
import torchvision
import torch.nn.functional as F

from fastai2_utils.pytorch.model import set_requires_grad

In [None]:
# default_exp loss

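# Loss
> Feature (perceptual) loss that compares frozen, pretrained VGG16-BN activations of a prediction and its target.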

## Feature Loss

In [None]:
# export
class FeatureLoss():
    '''Feature (perceptual) loss built on a frozen, pretrained VGG16-BN trunk.

    The returned loss is the mean of:
      * the L1 distance between the raw input and target, and
      * for each hooked VGG layer, the L1 distance between the input's and the target's
        activations at that layer, multiplied by the matching entry of `layer_wgts`.

    The hooked layers are the 3rd-5th `nn.MaxPool2d` modules of the VGG feature trunk
    (slice `[2:5]` of its MaxPool2d children).

    NOTE(review): no ImageNet mean/std normalization is applied here; presumably inputs are
    already normalized upstream — confirm against the data pipeline.
    '''
    def __init__(self, device, layer_wgts=(20, 70, 10)):
        '''Load the VGG16-BN feature trunk, freeze it and hook the selected layers.

        Args:
            device: device the VGG trunk is moved to; tensors passed to `__call__` must be
                on the same device.
            layer_wgts: one multiplicative weight per hooked layer (three by default).

        Raises:
            ValueError: if `layer_wgts` does not provide exactly one weight per hooked layer.
        '''
        # Copy into a fresh list so instances never share (or mutate) the default value.
        self.wgts = list(layer_wgts)
        self.m_feat = torchvision.models.vgg16_bn(True).features.to(device).eval()
        set_requires_grad([self.m_feat], False)  # VGG is a fixed feature extractor
        l_feat = [layer for layer in self.m_feat.children()
                  if isinstance(layer, nn.MaxPool2d)][2:5]
        # `zip` in __call__ would otherwise silently drop surplus weights or layers.
        if len(self.wgts) != len(l_feat):
            raise ValueError(
                f"layer_wgts has {len(self.wgts)} weights but {len(l_feat)} layers are hooked")
        # detach=False keeps the activations in the autograd graph so gradients reach `inp`.
        self.hooks = hook_outputs(l_feat, detach=False)

    def make_features(self, x):
        ''' return list of output from 3 layers of vgg16 '''
        # The forward pass itself is only run for its side effect of filling the hooks.
        self.m_feat(x)
        return self.hooks.stored

    def __call__(self, inp, targ):
        '''Compute the feature loss between `inp` and `targ` (both cast to float32).

        Returns a 0-dim tensor: the mean of the raw-pixel L1 term and the weighted
        per-layer feature L1 terms.
        '''
        inp = inp.float()
        targ = targ.float()
        inp_feats = self.make_features(inp)
        targ_feats = self.make_features(targ)

        # feat_losses = [mae of raw pixels, w1*mae of 1st hooked layer, w2*mae of 2nd, w3*mae of 3rd]
        feat_losses = [F.l1_loss(inp, targ)]
        feat_losses += [
            F.l1_loss(inp_feat, targ_feat) * w
            for inp_feat, targ_feat, w in zip(inp_feats, targ_feats, self.wgts)
        ]
        return torch.stack(feat_losses).mean(dim=0)

    def __del__(self):
        # __init__ may have failed before the hooks were registered (e.g. a weight download
        # error); guard so finalization does not raise AttributeError in that case.
        hooks = getattr(self, 'hooks', None)
        if hooks is not None:
            hooks.remove()

In [None]:
# Build the loss on CPU for the checks below; construction loads the pretrained VGG16-BN trunk.
loss = FeatureLoss(device='cpu')

In [None]:
# Identical input and target: every L1 term vanishes, so the loss is exactly 0,
# and it must still be part of the autograd graph of `inp`.
inp = torch.ones(16, 3, 64, 64, requires_grad=True)
targ = torch.ones_like(inp)
loss_val = loss(inp, targ)
test_eq(loss_val, 0)
test_eq(loss_val.requires_grad, True)

In [None]:
# All-zeros prediction vs all-ones target: the raw-pixel L1 term alone is 1.
# The tolerance of 1 makes this a loose band check around 1 rather than an exact value.
inp = torch.zeros(16, 3, 64, 64, requires_grad=True)
targ = torch.ones_like(inp)
loss_val = loss(inp, targ)
test_close(loss_val, 1, eps=1)
test_eq(loss_val.requires_grad, True)

## Export -

In [None]:
# hide
# Export the `# export` cells of the project notebooks into library modules via nbdev
# (the "Converted ..." lines below are this call's output).
from nbdev.export import notebook2script
notebook2script()

Converted 01_eda[script].ipynb.
Converted 01_gen_coco_tiny_data[script].ipynb.
Converted 02_data_coco.ipynb.
Converted 03_model.ipynb.
Converted 04_loss.ipynb.
Converted 90a_fulltest_train_lm.ipynb.
Converted 95a_train_lm[script].ipynb.
Converted Untitled.ipynb.
Converted index.ipynb.
