In [1]:
from fastai2.basics import *
from fastai2.callback.all import *
from fastai2.vision.all import *

In [2]:
def get_convs(m):
  "Return every `nn.Conv2d` reachable from `m` (including `m` itself), in depth-first pre-order."
  found, pending = [], [m]
  while pending:
    layer = pending.pop()
    if isinstance(layer, nn.Conv2d): found.append(layer)
    # Push children reversed so they are popped (visited) in their original order.
    pending.extend(reversed(list(layer.children())))
  return found

In [3]:
# Pretrained VGG16 (batch-norm variant); its conv filters are the source of the
# weights copied into the xresnet further down.
m_vgg = vgg16_bn(pretrained=True)

In [4]:
# Collect the VGG conv layers, then show how many there are and the first three.
vgg_convs = get_convs(m_vgg)
len(vgg_convs), vgg_convs[:3]

(13,
 [Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
  Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
  Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))])

In [5]:
# Target network: an xresnet50 created without pretrained weights.
m = xresnet50(pretrained=False)

In [6]:
# Collect the xresnet conv layers, then show how many there are and the first three.
convs = get_convs(m)
len(convs), convs[:3]

(55,
 [Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
  Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
  Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)])

In [7]:
def freeze_layer(m):
  "Turn off gradient tracking for every parameter of `m` (in place; returns nothing)."
  for param in m.parameters():
    param.requires_grad_(False)

In [8]:
# Display the full list of VGG conv layers to inspect their channel sizes.
vgg_convs

[Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))]

In [9]:
# Bucket the VGG conv weights by input-channel count, then stack each bucket
# along the output-filter axis (dim 0) so every key maps to one weight tensor.
conv_ws = {}
for conv in vgg_convs:
  conv_ws.setdefault(conv.in_channels, []).append(conv.weight.data)
conv_ws = {ni: torch.cat(ws, dim=0) for ni, ws in conv_ws.items()}

In [10]:
# Copy pretrained VGG filters into every xresnet conv whose weight layout is
# compatible, then freeze that conv.
for conv in convs:
  ni = conv.in_channels
  if ni not in conv_ws: continue
  vgg_conv = conv_ws[ni]
  # Compare every non-output weight dim (in-channels per group, kernel height
  # AND width). Checking only the kernel height would let a non-square or
  # grouped conv through and make `copy_` fail on a shape mismatch.
  if conv.weight.shape[1:] != vgg_conv.shape[1:]: continue
  print('doing')
  # Only as many filters as both sides have are transferred.
  nf = min(conv.out_channels, vgg_conv.shape[0])
  conv.weight.data[:nf].copy_(vgg_conv[:nf])
  # NOTE(review): the whole layer is frozen even when nf < conv.out_channels,
  # which would leave the un-copied filters at their initial values for good.
  freeze_layer(conv)
  assert (conv.weight.data[:nf] == vgg_conv[:nf]).all()

doing
doing
doing
doing
doing
doing
doing
doing
doing
doing
doing
doing
doing
doing
doing
doing
doing


## Train

In [11]:
# Maps each Imagenette class-folder id to a human-readable label.
lbl2name = dict(n01440764='tench',
                n02102040='english springer',
                n02979186='cassette player',
                n03000684='chainsaw',
                n03028079='church',
                n03394916='french horn',
                n03417042='garbage truck',
                n03425413='gas pump',
                n03445777='golf ball',
                n03888257='parachute',
                )

In [12]:
# Fetch the Imagenette dataset and list the top level of the returned path.
path = untar_data(URLs.IMAGENETTE)
path.ls()

(#3) [/home/lgvaz/.fastai/data/imagenette/log.csv,/home/lgvaz/.fastai/data/imagenette/val,/home/lgvaz/.fastai/data/imagenette/train]

In [13]:
# Image -> category pipeline. The label is the parent folder name passed
# through `lbl2name.get`; the validation set is the `val` folder.
dblock = DataBlock(
  blocks=(ImageBlock, CategoryBlock),
  get_items=get_image_files,
  get_y=[parent_label, lbl2name.get],
  splitter=GrandparentSplitter(valid_name='val'),
)

In [14]:
# Batches of 64 images resized to 128, with flip augmentation and
# ImageNet-statistics normalisation applied per batch.
dbunch = dblock.databunch(
  path,
  bs=64,
  item_tfms=[Resize(128)],
  batch_tfms=[Flip(), Normalize(*imagenet_stats)],
)

In [15]:
def opt_func(*args, **kwargs):
  "Build a `Lookahead`-wrapped `RAdam`; the `mom`, `wd` and `eps` defaults can be overridden through `kwargs`."
  # Merge the defaults with the caller's keywords instead of passing both:
  # literal keywords next to **kwargs raise TypeError ("multiple values for
  # keyword argument") as soon as a caller supplies one of these names.
  hypers = {'mom': 0.95, 'wd': 1e-2, 'eps': 1e-6, **kwargs}
  return Lookahead(RAdam(*args, **hypers))

In [16]:
# Learner around the partially VGG-initialised xresnet, trained with label
# smoothing and converted to mixed precision.
learn = Learner(
  dbunch, m, LabelSmoothingCrossEntropy(),
  metrics=[accuracy],
  opt_func=opt_func,
).to_fp16()

In [17]:
# Re-collect the convs from the Learner's model to inspect them after wrapping.
convs2 = get_convs(learn.model)

In [18]:
# Sanity check: list `requires_grad` for the first conv's parameters
# (all False means the layer is still frozen inside the Learner).
[p.requires_grad for p in convs2[0].parameters()]

[False]

In [19]:
# Train for 5 epochs at lr 4e-3 with weight decay 1e-2, using fit_flat_cos with pct_start=.72.
learn.fit_flat_cos(5, 4e-3, pct_start=.72, wd=1e-2)

epoch,train_loss,valid_loss,accuracy,time
0,2.268738,2.866354,0.424,00:56
1,1.887301,1.802849,0.754,00:50
2,1.737028,1.922075,0.67,00:50
3,1.631754,1.862088,0.722,00:50
4,1.403864,1.407076,0.868,00:50
