In [None]:
from fastai.vision.all import *

In [None]:
# Fetch the PETS dataset (`URLs.PETS`) and work from its `images` sub-folder.
path = untar_data(URLs.PETS)/'images'
fnames = get_image_files(path)

# Labelling regex: captures the last path component up to its trailing `_<digits>` suffix.
pat = r'/([^/]+)_\d+.*'

# Batch-level transforms: augmentations at size 224 with warping disabled, then
# normalisation with the ImageNet statistics.
batch_tfms = [*aug_transforms(size=224, max_warp=0), Normalize.from_stats(*imagenet_stats)]
# Item-level transform: random resized crop to 460 with a fixed 1:1 ratio and min_scale=0.75.
item_tfms = RandomResizedCrop(460, min_scale=0.75, ratio=(1.,1.))
bs=64

pets = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=RandomSplitter(),
    # Reuse `pat` rather than repeating the literal, so the two can never drift apart.
    get_y=RegexLabeller(pat=pat),
    item_tfms=item_tfms,
    batch_tfms=batch_tfms,
)
dls = pets.dataloaders(path, bs=bs)

In [None]:
!pip install timm >> /dev/null

In [None]:
from timm import create_model
# Plain timm model loaded with pretrained weights; used below to show that a raw
# timm model cannot be indexed or measured with len() the way the fastai model can.
net = create_model("vit_tiny_patch16_224", pretrained=True)

In [None]:
# Reference learner built by fastai on a resnet18; its `.model` is inspected below.
learn = vision_learner(dls, models.resnet18)

In [None]:
#| hide
from IPython.display import Markdown as md
def make_definition(word, pos, meaning, example):
    """
    Build the styled HTML for one glossary entry and wrap it in `md`.

    The markup is four consecutive elements: `word` as an `<h4>` heading, then
    `pos`, `meaning` and `example` each in their own `<p>`. The values are
    interpolated as-is; no HTML escaping is applied here.
    """
    fragments = [
        f'<h4 class="anchored" style="color: #1f194c; display: flex; justify-content: space-between; margin-top: 20px;">{word}</h4>',
        f'<p style="color: #575a7b;">{pos}</p>',
        f'<p class="word-meaning" style="color: #575a7b;">{meaning}</p>',
        f'<p class="word-example" style="color: #575a7b; font-style: italic; border-left: 5px solid #3fb618; padding-left: 20px; margin-top: 30px;">{example}</p>',
    ]
    return md("".join(fragments))

In [None]:
#| echo: false
make_definition("body", "noun", "The backbone of a neural network, typically pretrained", "The body of a Resnet 34 model")

<h4 class="anchored" style="color: #1f194c; display: flex; justify-content: space-between; margin-top: 20px;">body</h4><p style="color: #575a7b;">noun</p><p class="word-meaning" style="color: #575a7b;">The backbone of a neural network, typically pretrained</p><p class="word-example" style="color: #575a7b; font-style: italic; border-left: 5px solid #3fb618; padding-left: 20px; margin-top: 30px;">The body of a Resnet 34 model</p>

In [None]:
#| echo: false
make_definition(
    "head", "noun", 
    "The last, or last few, layers of a neural network; typically consists of everything after the final pooling layer",
    "Predictions from the model are the outputs from the head of the network"
)

<h4 class="anchored" style="color: #1f194c; display: flex; justify-content: space-between; margin-top: 20px;">head</h4><p style="color: #575a7b;">noun</p><p class="word-meaning" style="color: #575a7b;">The last, or last few, layers of a neural network; typically consists of everything after the final pooling layer</p><p class="word-example" style="color: #575a7b; font-style: italic; border-left: 5px solid #3fb618; padding-left: 20px; margin-top: 30px;">Predictions from the model are the outputs from the head of the network</p>

In [None]:
# Last element of the fastai-built model, i.e. its head (layers listed in the output below).
learn.model[-1]

Sequential(
  (0): AdaptiveConcatPool2d(
    (ap): AdaptiveAvgPool2d(output_size=1)
    (mp): AdaptiveMaxPool2d(output_size=1)
  )
  (1): fastai.layers.Flatten(full=False)
  (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Dropout(p=0.25, inplace=False)
  (4): Linear(in_features=1024, out_features=512, bias=False)
  (5): ReLU(inplace=True)
  (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (7): Dropout(p=0.5, inplace=False)
  (8): Linear(in_features=512, out_features=37, bias=False)
)

In [None]:
# Deliberate failure: the raw timm model does not support indexing (TypeError shown below).
net[-1]

TypeError: 'VisionTransformer' object is not subscriptable

In [None]:
# The fastai-built model has exactly two top-level elements.
len(learn.model)

2

In [None]:
# Deliberate failure: the raw timm model has no len() either (TypeError shown below).
len(net)

TypeError: object of type 'VisionTransformer' has no len()

In [None]:
class MyModel(nn.Module):
    "Two `nn.Linear(1,1)` layers stored as the named attributes `l1` and `l2`"
    def __init__(self):
        # `nn.Module.__init__` has to run before any submodule is assigned as an
        # attribute, otherwise the assignment itself raises.
        super().__init__()
        self.l1 = nn.Linear(1,1)
        self.l2 = nn.Linear(1,1)
    def forward(self, x):
        "Apply `l1` and then `l2` to `x`"
        return self.l2(self.l1(x))

In [None]:
class MyModel(nn.Sequential):
    "Two `nn.Linear(1,1)` layers registered positionally, so they can be reached by index"
    def __init__(self):
        super().__init__(nn.Linear(1,1), nn.Linear(1,1))

In [None]:
# `net` is rebound here to the nn.Sequential-based MyModel, whose layers can be indexed.
net = MyModel()
net[0], net[1]

(Linear(in_features=1, out_features=1, bias=True),
 Linear(in_features=1, out_features=1, bias=True))

In [None]:
def custom_cut_model(model:nn.Module, cut:typing.Union[int, typing.Callable]):
    """
    Cut `model` down according to `cut`.

    An integer `cut` keeps the children of `model` selected by the slice `[:cut]`
    and wraps them in a new `nn.Sequential`. A callable `cut` is handed `model`
    and whatever it returns is passed straight back. Any other `cut` raises
    `NameError`.
    """
    if isinstance(cut, int):
        kept_children = list(model.children())[:cut]
        return nn.Sequential(*kept_children)
    if callable(cut):
        return cut(model)
    raise NameError("`cut` must either be an integer or a function")

In [None]:
class CustomTimmBody(nn.Module):
    """
    Wraps a `timm` model so it can be used as the body of a network.

    `model` is kept whole when `cut` is None; otherwise it is replaced by the
    result of `custom_cut_model(model, cut)`. `pretrained` and `n_in` are
    accepted but not used anywhere in this class.
    """
    def __init__(self, model, pretrained:bool=True, cut=None, n_in:int=3):
        super().__init__()
        # Raw `pool_size` entry of the model's default config (None when absent);
        # only its truthiness matters, and it selects the path taken in `forward`.
        self.needs_pooling = model.default_cfg.get('pool_size', None)
        self.model = model if cut is None else custom_cut_model(model, cut)

    def forward(self, x):
        "Return `self.model.forward_features(x)` when `needs_pooling` is truthy, else `self.model(x)`"
        run = self.model.forward_features if self.needs_pooling else self.model
        return run(x)

In [None]:
# Wrap a pretrained timm ViT as the body and put it in training mode.
# NOTE(review): num_classes=0 presumably strips timm's own classifier so the model
# emits features only (the head below is sized from `body.model.num_features`) — confirm.
body = CustomTimmBody(
    create_model("vit_tiny_patch16_224", pretrained=True, num_classes=0, in_chans=3)
).train()

In [None]:
# Head sized from the body's feature count to the number of classes in `dls`.
# pool=None: the resulting head (printed below) starts at BatchNorm1d with no pooling/flatten layers.
head = create_head(body.model.num_features, dls.c, pool=None)

In [None]:
head

Sequential(
  (0): BatchNorm1d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (1): Dropout(p=0.25, inplace=False)
  (2): Linear(in_features=192, out_features=512, bias=False)
  (3): ReLU(inplace=True)
  (4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): Dropout(p=0.5, inplace=False)
  (6): Linear(in_features=512, out_features=37, bias=False)
)

In [None]:
# Random dummy input of shape (2, 3, 224, 224) for smoke-testing body + head.
x = torch.randn(2,3,224,224)

In [None]:
# Smoke test: push the dummy input through the body, then the head, and show the
# result together with its shape.
out = head(body(x))
out, out.shape

(tensor([[-0.0650, -0.1741,  0.1089, -1.1668, -0.6229,  0.8892,  0.4859, -0.1704,
          -1.4127,  0.7338,  1.0354,  0.6033,  0.3576, -0.2332,  0.7073, -0.7090,
           0.3852, -0.3440,  0.4645,  0.4209,  1.2090,  0.3201,  0.6480, -1.4800,
           0.7253, -0.1806,  0.7261,  0.6329,  0.5336, -1.4665, -0.9681, -0.3387,
          -0.3044, -0.6216,  2.3369, -0.0941,  0.3703],
         [-0.4785,  1.2014, -0.2310,  1.4840, -0.4752,  0.3363,  0.1472, -0.1076,
           0.8156, -0.6819, -0.6366, -0.0721, -0.8710,  0.2871, -0.4673,  0.5040,
           0.5288,  1.5585, -0.3499,  0.5983, -0.1188,  0.1523, -0.7708,  0.8939,
          -0.0318, -0.8048, -0.2581,  0.5921,  0.1012,  0.1626,  0.2249,  0.4605,
           0.1858, -0.4212, -0.0047,  0.6470, -0.7384]], grad_fn=<MmBackward0>),
 torch.Size([2, 37]))

In [None]:
apply_init?

[0;31mSignature:[0m [0mapply_init[0m[0;34m([0m[0mm[0m[0;34m,[0m [0mfunc[0m[0;34m=[0m[0;34m<[0m[0mfunction[0m [0mkaiming_normal_[0m [0mat[0m [0;36m0x7f21f43d5630[0m[0;34m>[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m Initialize all non-batchnorm layers of `m` with `func`.
[0;31mFile:[0m      /opt/conda/lib/python3.10/site-packages/fastai/torch_core.py
[0;31mType:[0m      function


In [None]:
# Initialise the head's non-batchnorm layers; with no `func` given, the default shown
# in the signature above (kaiming_normal_) is used.
apply_init(head)

In [None]:
# Run the same dummy input through the body and the freshly initialised head.
head(body(x))

tensor([[ 0.2204, -3.4587, -0.5113, -1.4922, -1.2036,  3.9744, -1.5592, -1.1304,
          1.1073,  0.4745,  1.4827,  0.8954, -2.0673,  0.3289,  1.6994,  0.0623,
          1.7268,  2.5922, -1.4811, -1.4121,  0.7921,  1.5231,  1.2327, -0.0762,
          0.5696, -1.2702,  3.3962, -2.2976,  2.4296, -0.0874, -0.0975,  0.0168,
          2.2922,  2.0433,  1.1191,  1.1637, -2.1250],
        [ 1.1871,  0.2985,  2.6397, -2.9931,  3.5329, -3.3390,  3.3316, -0.8618,
          0.0611,  1.0972, -1.8489, -3.1779,  0.2882,  1.3150,  0.7034, -0.7141,
         -0.5197, -3.5473,  1.0325,  1.3873,  2.3772, -3.8408, -0.3776,  0.0446,
         -1.7974,  1.3227, -0.8745,  3.6397, -2.2262, -0.2738,  1.7177,  0.8619,
         -3.6088, -4.8258,  0.2685,  2.7378,  1.7348]], grad_fn=<MmBackward0>)

In [None]:
#| echo: false
make_definition("split", "adjective", "An arrangement of groups of layers by some criteria", "The model was split between the body and the head")

<h4 class="anchored" style="color: #1f194c; display: flex; justify-content: space-between; margin-top: 20px;">split</h4><p style="color: #575a7b;">adjective</p><p class="word-meaning" style="color: #575a7b;">An arrangement of groups of layers by some criteria</p><p class="word-example" style="color: #575a7b; font-style: italic; border-left: 5px solid #3fb618; padding-left: 20px; margin-top: 30px;">The model was split between the body and the head</p>

In [None]:
#| echo: false
make_definition("freeze", "verb", "To make certain layers of a model untrainable", "We froze the backbone of the pretrained model, but not the head")

<h4 class="anchored" style="color: #1f194c; display: flex; justify-content: space-between; margin-top: 20px;">freeze</h4><p style="color: #575a7b;">verb</p><p class="word-meaning" style="color: #575a7b;">To make certain layers of a model untrainable</p><p class="word-example" style="color: #575a7b; font-style: italic; border-left: 5px solid #3fb618; padding-left: 20px; margin-top: 30px;">We froze the backbone of the pretrained model, but not the head</p>

In [None]:
def my_split_func(model:nn.Module):
    "Two parameter groups: those of `model[0]`, and those of everything from index 1 onwards"
    first, rest = model[0], model[1:]
    return L(first, rest).map(params)

In [None]:
def splitter(model):
    "Parameter groups for a two-part model: `model[0]` first, `model[1]` second"
    first_part, second_part = model[0], model[1]
    return L(first_part, second_part).map(params)

In [None]:
# Assemble body and head into a single two-element Sequential. `splitter` is handed to
# the Learner so that the parameters of element 0 and element 1 form separate groups.
learn = Learner(
    dls,
    nn.Sequential(body, head),
    splitter=splitter
)

In [None]:
# Print only the last 250 characters of the summary (the parameter counts onward).
print(learn.summary()[-250:])

Total trainable params: 5,605,056
Total non-trainable params: 0

Optimizer used: <function Adam>
Loss function: FlattenedLoss of CrossEntropyLoss()

Callbacks:
  - TrainEvalCallback
  - CastToTensor
  - Recorder
  - ProgressCallback


In [None]:
# Freeze the model; the summary printed next shows the trainable parameter count
# dropping from 5,605,056 to 128,256.
learn.freeze()

In [None]:
# Tail of the summary again; the slice is longer (295 characters), presumably to make
# room for the extra "Model frozen ..." line.
print(learn.summary()[-295:])

Total trainable params: 128,256
Total non-trainable params: 5,476,800

Optimizer used: <function Adam>
Loss function: FlattenedLoss of CrossEntropyLoss()

Model frozen up to parameter group #1

Callbacks:
  - TrainEvalCallback
  - CastToTensor
  - Recorder
  - ProgressCallback
