In [None]:
#default_exp layer_features

In [None]:
#export
from fastai2.basics import *
from fastai2.vision.all import *
from fastai2.callback.all import *
from torchvision.models import vgg16, vgg19
from faststyle import *

In [None]:
#export
# Activations captured by `LayerFeats`: `style` holds the style-layer outputs, `content` the content-layer outputs.
Features = namedtuple('Features', ['style', 'content'])

In [None]:
#export
def _prepare_model(m):
  m = m.to(default_device()).eval()
  for p in m.parameters(): p.requires_grad=False
  return m

In [None]:
#export
def _get_layers(m, idxs):
  return [m[i] for i in idxs]

In [None]:
#export
# Normalization transform built from fastai's `imagenet_stats`; every `FeatModels` factory passes it
# as the input transform (`tfms`) of its pretrained feature extractor.
_imagenet_norm = Normalize.from_stats(*imagenet_stats)

In [None]:
#export
class FeatModels:
  "Factories for pretrained feature extractors; each returns the keyword arguments expected by `LayerFeats`."
  @staticmethod
  def _build(arch, stl_idxs, cnt_idxs):
    "Load pretrained `arch`, keep its `features` stack and select the style/content layers by index."
    m = arch(pretrained=True).features
    return dict(m=m, stl_ls=_get_layers(m, stl_idxs), cnt_ls=_get_layers(m, cnt_idxs), tfms=[_imagenet_norm])

  @staticmethod
  def vgg16():
    "VGG16 features with style layers (1, 11, 18, 25) and content layer 20."
    # Inside the method body `vgg16` is torchvision's constructor, not this staticmethod.
    # NOTE(review): in torchvision's layout these indices should be relu1_1, relu3_1, relu4_1, relu5_1
    # (content: relu4_2). Unlike `vgg19`, no block-2 layer (index 6) is used for style -- confirm intended.
    return FeatModels._build(vgg16, (1, 11, 18, 25), (20,))

  @staticmethod
  def vgg19():
    "VGG19 features with style layers (1, 6, 11, 20, 29) and content layer 22."
    # NOTE(review): in torchvision's layout these indices should be relu1_1..relu5_1
    # (content: relu4_2) -- verify against the installed torchvision version.
    return FeatModels._build(vgg19, (1, 6, 11, 20, 29), (22,))

In [None]:
#export
class LayerFeats:
  "Run inputs through a frozen model and collect the outputs of chosen style and content layers via forward hooks."
  def __init__(self, m, stl_ls, cnt_ls, tfms=None):
    self.m, self.tfms = _prepare_model(m), Pipeline(tfms)
    # detach=False so the stored activations are not detached from the autograd graph.
    self.stl_hooks = hook_outputs(stl_ls, detach=False)
    self.cnt_hooks = hook_outputs(cnt_ls, detach=False)

  def __call__(self, x):
    "Apply `tfms` to `x`, run the model, and return the hooked activations as `Features`."
    self.m(self.tfms(x))  # the model's own output is unused; the hooks record what we need
    return Features(style=self.stl_hooks.stored, content=self.cnt_hooks.stored)

  def remove(self):
    "Remove all forward hooks this object registered, so the layers stop recording outputs."
    self.stl_hooks.remove()
    self.cnt_hooks.remove()

  def __enter__(self): return self
  def __exit__(self, *args): self.remove()

  @classmethod
  def from_feat_m(cls, feat_m):
    "Build from a factory such as `FeatModels.vgg19` that returns this class's constructor kwargs."
    return cls(**feat_m())

In [None]:
# Extract VGG19 style/content activations from one of the example style images.
get_feats = LayerFeats.from_feat_m(FeatModels.vgg19)
style_img = TensorImage.create('../examples/styles/abstract.jpg')
feats = get_feats(style_img)

In [None]:
# `Features` has two fields; the vgg19 factory hooks 5 style layers and 1 content layer.
test_eq((len(feats), len(feats.style), len(feats.content)), (2, 5, 1))

## Export -

In [None]:
#hide
from nbdev.export import *
# Export the `#export` cells of every notebook in the project into the library modules.
notebook2script()

Converted 00_core.ipynb.
Converted 01_data.ipynb.
Converted 02_layer_features.ipynb.
Converted 03_loss.ipynb.
Converted 04_learner.ipynb.
Converted 04_models.ipynb.
Converted 06_callback.ipynb.
