In [1]:
from fastai2.basics import *
from fastai2.callback.all import *
from fastai2.vision.all import *

In [2]:
def get_convs(m):
  """Return every `nn.Conv2d` found in `m`, including `m` itself.

  Modules are visited depth-first in pre-order (a module before its
  children, children in `m.children()` order), so the resulting list
  follows the order in which the layers are declared.
  """
  def _walk(module):
    # Emit the module itself first, then descend into each child in turn.
    if isinstance(module, nn.Conv2d):
      yield module
    for child in module.children():
      yield from _walk(child)
  return list(_walk(m))

In [3]:
# Reference network: VGG-16 with batch norm, created with pretrained=True.
m_vgg = vgg16_bn(pretrained=True)

In [4]:
# Collect every Conv2d of the VGG model; display the count and the first three layers.
vgg_convs = get_convs(m_vgg); len(vgg_convs), vgg_convs[:3]

(13,
 [Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
  Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
  Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))])

In [5]:
# Model that is handed to the Learner below: xresnet50 created with pretrained=False.
m = xresnet50(pretrained=False)

In [6]:
# Collect every Conv2d of the xresnet50 model; display the count and the first three layers.
convs = get_convs(m); len(convs), convs[:3]

(55,
 [Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
  Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
  Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)])

In [7]:
def freeze_layer(m):
  "Disable gradient tracking on every parameter yielded by `m.parameters()`."
  for param in m.parameters():
    param.requires_grad_(False)

In [8]:
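# Disabled experiment: copy the first 32 filters of VGG's first conv (3 -> 64)
# into the first xresnet conv (3 -> 32), then check that the copy took effect.
# conv, vgg_conv = convs[0], vgg_convs[0]
# conv.weight.data.copy_(vgg_conv.weight.data[:32])
# # freeze_layer(conv)
# assert (conv.weight.data == vgg_conv.weight.data[:32]).all()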

## Train

In [9]:
# Human-readable names for the ten class ids used as labels in this notebook.
# The DataBlock's `get_y` pipeline looks each id up here via `lbl2name.get`.
lbl2name = dict(n01440764='tench',
                n02102040='english springer',
                n02979186='cassette player',  # spelling fixed (was 'cassete player')
                n03000684='chainsaw',
                n03028079='church',
                n03394916='french horn',
                n03417042='garbage truck',
                n03425413='gas pump',
                n03445777='golf ball',
                n03888257='parachute',
                )

In [10]:
# Fetch the Imagenette dataset through fastai's `untar_data` and list the resulting directory.
path = untar_data(URLs.IMAGENETTE); path.ls()

(#3) [/home/lgvaz/.fastai/data/imagenette/log.csv,/home/lgvaz/.fastai/data/imagenette/val,/home/lgvaz/.fastai/data/imagenette/train]

In [11]:
# Image -> category pipeline:
#   - items are gathered with `get_image_files`;
#   - the train/valid split uses GrandparentSplitter with 'val' as the validation name
#     (the dataset directory contains `train` and `val`);
#   - the target is `parent_label` followed by a lookup in `lbl2name`, i.e. the
#     readable class name rather than the raw id.
dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
          get_items=get_image_files,
          splitter=GrandparentSplitter(valid_name='val'),
          get_y=[parent_label, lbl2name.get])

In [12]:
# Build the data loaders from `path`: batch size 64, every item resized to 128,
# then batch-level Flip and normalization with `imagenet_stats`.
dbunch = dblock.databunch(path,
                          bs=64,
                          item_tfms=[Resize(128)],
                          batch_tfms=[Flip(), Normalize(*imagenet_stats)])

In [13]:
def opt_func(*args, **kwargs):
  """Create a `RAdam` optimizer wrapped in `Lookahead`.

  All positional and keyword arguments are forwarded to `RAdam`.
  `mom`, `wd` and `eps` default to 0.95, 1e-2 and 1e-6; a caller may
  override any of them through `kwargs`.
  """
  # Merge defaults with caller kwargs instead of passing both explicitly:
  # a keyword supplied twice would raise TypeError at call time.
  hypers = dict(mom=0.95, wd=1e-2, eps=1e-6)
  hypers.update(kwargs)
  return Lookahead(RAdam(*args, **hypers))

In [14]:
# Learner over the xresnet50 model `m`: label-smoothing cross-entropy loss, accuracy metric,
# the RAdam + Lookahead optimizer factory defined above, converted with `to_fp16()`.
learn = Learner(dbunch, m, LabelSmoothingCrossEntropy(), metrics=[accuracy], 
                opt_func=opt_func).to_fp16()

In [15]:
# Collect the convs again, this time from the model held by the Learner.
convs2 = get_convs(learn.model)

In [16]:
# Sanity check: list `requires_grad` for the first conv's parameters (True means it will be trained).
[p.requires_grad for p in convs2[0].parameters()]

[True]

In [17]:
# Train for 5 epochs at lr 4e-3 with pct_start=.72 and wd=1e-2.
# NOTE(review): no random seed is set anywhere above, so the logged metrics will vary between runs.
learn.fit_flat_cos(5, 4e-3, pct_start=.72, wd=1e-2)

epoch,train_loss,valid_loss,accuracy,time
0,2.287325,2.158326,0.602,01:15
1,1.956262,1.889363,0.694,01:00
2,1.762481,1.714435,0.758,00:51
3,1.631533,1.983889,0.688,00:53
4,1.420746,1.432582,0.866,00:51
