In [None]:
from fastai.vision.all import *

In [None]:
path = untar_data(URLs.PETS)/'images'
fnames = get_image_files(path)
pat = r'/([^/]+)_\d+.*'
batch_tfms = [*aug_transforms(size=224, max_warp=0), Normalize.from_stats(*imagenet_stats)]
item_tfms = RandomResizedCrop(460, min_scale=0.75, ratio=(1.,1.))
bs=64

pets = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
     get_items=get_image_files,
     splitter=RandomSplitter(),
     get_y=RegexLabeller(pat = r'/([^/]+)_\d+.*'),
     item_tfms=item_tfms,
     batch_tfms=batch_tfms
)
dls = pets.dataloaders(path, bs=bs)

In [None]:
!pip install timm >> /dev/null

In [None]:
from timm import create_model

# A pretrained ViT straight from timm, created without a `num_classes` override,
# i.e. with timm's own classifier attached.
net = create_model(model_name="vit_tiny_patch16_224", pretrained=True)

In [None]:
# Reference fastai learner with a ResNet-18 (`models.resnet18`); its `learn.model` is
# inspected below and compared against the timm model `net`.
# NOTE: `learn` is rebound further down to the custom timm body + head learner.
learn = vision_learner(dls, models.resnet18)

In [None]:
#| hide
from IPython.display import Markdown as md
def make_definition(word, pos, meaning, example):
    """
    Render a dictionary-style glossary entry as HTML, wrapped in IPython's `Markdown` (`md`).

    Args:
        word: the term being defined (rendered as an `<h4>` heading).
        pos: part of speech, e.g. "noun" or "verb" (rendered in italics).
        meaning: the definition text.
        example: an example sentence using the term.

    Returns:
        `md(html)` for the assembled HTML, so the entry renders when the call is the
        last expression of a cell.
    """
    # NOTE(review): everything after `color:` had been lost from the original inline
    # styles, leaving unterminated string literals (a SyntaxError). The colours below
    # are reconstructions — adjust them to match the site's theme.
    html = f'<h4 class="anchored" style="color:#E97451;margin-bottom:0;">{word}</h4>'
    html += f'<p style="color:#9E9E9E;margin-top:0;margin-bottom:0;"><em>{pos}</em></p>'
    html += f'<p class="word-meaning" style="color:#212121;margin-top:4px;margin-bottom:0;">{meaning}</p>'
    html += f'<p class="word-example" style="color:#9E9E9E;margin-top:4px;">{example}</p>'
    return md(html)

In [None]:
#| echo: false
make_definition(
    word="body",
    pos="noun",
    meaning="The backbone of a neural network, typically pretrained",
    example="The body of a Resnet 34 model",
)

In [None]:
#| echo: false
make_definition(
    word="head",
    pos="noun",
    meaning="The last, or last few, layers of a neural network; typically consists of everything after the final pooling layer",
    example="Predictions from the model are the outputs from the head of the network",
)

In [None]:
# Last top-level child of the fastai-built ResNet model — presumably the classifier
# head (vision_learner is expected to build `nn.Sequential(body, head)`; confirm in the output).
learn.model[-1]

In [None]:
# `net` is the timm ViT created above: a plain `nn.Module` rather than an
# `nn.Sequential`, so it cannot be indexed like the fastai model. Catch and show the
# error so the notebook still runs top to bottom.
try:
    net[-1]
except TypeError as err:
    print(f"{type(err).__name__}: {err}")

In [None]:
# Number of top-level children of the fastai model (expected: 2 — body and head).
len(learn.model)

In [None]:
# Likewise, a timm model defines no `__len__`, so `len` fails on it; show the error
# instead of aborting the run.
try:
    len(net)
except TypeError as err:
    print(f"{type(err).__name__}: {err}")

In [None]:
class MyModel(nn.Module):
    """
    Two linear layers written as a plain `nn.Module`: the layers are attributes, so the
    model cannot be indexed or measured with `len` the way an `nn.Sequential` can.

    NOTE: the next cell redefines `MyModel` as an `nn.Sequential` for comparison.
    """
    def __init__(self):
        # `nn.Module.__init__` must run before any submodule is assigned; without it
        # PyTorch raises "cannot assign module before Module.__init__() call".
        super().__init__()
        self.l1 = nn.Linear(1,1)
        self.l2 = nn.Linear(1,1)

    def forward(self, x):
        return self.l2(self.l1(x))

In [None]:
class MyModel(nn.Sequential):
    """
    The same two linear layers expressed as an `nn.Sequential`: the layers are passed
    straight to the parent constructor, so they can be indexed (`model[0]`) and counted.
    """
    def __init__(self):
        super().__init__(nn.Linear(1,1), nn.Linear(1,1))

In [None]:
# NOTE: this rebinds `net`, which previously held the timm ViT.
net = MyModel()
# Iterating an `nn.Sequential` yields its layers in order -> (layer 0, layer 1).
tuple(net)

In [None]:
def custom_cut_model(model:nn.Module, cut:typing.Union[int, typing.Callable]):
    """
    Truncate `model` according to `cut`.

    - int: keep the first `cut` direct children, wrapped in a new `nn.Sequential`
      (negative values drop children from the end, as with list slicing).
    - callable: return whatever `cut(model)` returns.

    Any other `cut` raises `NameError` (the exception type is kept as-is because
    callers may rely on it).
    """
    if isinstance(cut, int):
        kept_children = list(model.children())[:cut]
        return nn.Sequential(*kept_children)
    if callable(cut):
        return cut(model)
    raise NameError("`cut` must either be an integer or a function")

In [None]:
class CustomTimmBody(nn.Module):
    """
    A small wrapper that turns a `timm` model into a body (backbone) for a custom head.

    Args:
        model: the instantiated backbone, typically from `timm.create_model`.
        pretrained: accepted for signature compatibility; not used here.
        cut: if given, `model` is truncated with `custom_cut_model(model, cut)`.
        n_in: accepted for signature compatibility; not used here.

    `needs_pooling` holds the model config's `pool_size` (None when absent). When it is
    truthy and the wrapped module provides `forward_features`, that method is called so
    the unpooled features are returned; otherwise the wrapped module is simply called.
    """
    def __init__(
        self, 
        model, 
        pretrained:bool=True, 
        cut=None, 
        n_in:int=3
    ):
        super().__init__()
        # `default_cfg` is a timm attribute; fall back to an empty config so that
        # modules without it can be wrapped too.
        cfg = getattr(model, 'default_cfg', None) or {}
        self.needs_pooling = cfg.get('pool_size', None)
        self.model = model if cut is None else custom_cut_model(model, cut)

    def forward(self, x):
        # A cut model is a plain `nn.Sequential` without `forward_features`, so only
        # take that path when the wrapped module actually provides it.
        if self.needs_pooling and hasattr(self.model, 'forward_features'):
            return self.model.forward_features(x)
        return self.model(x)

In [None]:
# `num_classes=0` asks timm for the ViT without its classification layer, so the
# body ends at the feature extractor; the custom head is attached separately.
body = CustomTimmBody(
    model=create_model("vit_tiny_patch16_224", pretrained=True, num_classes=0, in_chans=3),
)
body.train();

In [None]:
# Build a fastai head sized to the body's feature width (`num_features`) and the number
# of classes (`dls.c`). `pool=None` skips fastai's pooling layer — presumably because the
# timm model (created with `num_classes=0`) already returns pooled features; confirm.
head = create_head(body.model.num_features, dls.c, pool=None)

In [None]:
# Inspect the layers fastai generated for the head.
head

In [None]:
# Dummy batch of 2 three-channel 224x224 inputs for smoke-testing body + head.
# A seeded local generator makes the values (and outputs below) reproducible on
# every re-run without touching PyTorch's global RNG state.
x = torch.randn(2, 3, 224, 224, generator=torch.Generator().manual_seed(42))

In [None]:
# Smoke-test body + head on the dummy batch and check the output shape
# (expected: batch size x number of classes).
# NOTE(review): `body` was put in train mode above and a freshly built `head` is in
# train mode by default, so any dropout is active and any BatchNorm running statistics
# are updated by this forward pass.
out = head(body(x))
out, out.shape

In [None]:
# `apply_init?` only opens IPython's pager, whose content is not saved with the
# notebook; `help` prints the signature and docstring into the cell output instead.
help(apply_init)

In [None]:
# Re-initialise the weights of `head` in place, using `apply_init` with its default
# initialisation function.
apply_init(head)

In [None]:
# Same forward pass after re-initialising the head, for comparison with the earlier output.
head(body(x))

In [None]:
#| echo: false
# The meaning ("An arrangement of groups ...") is a noun sense, so label it as a noun
# and give an example sentence that uses the word that way.
make_definition(
    "split", "noun",
    "An arrangement of groups of layers by some criteria",
    "The model's split places the body and the head in separate groups",
)

In [None]:
#| echo: false
make_definition(
    word="freeze",
    pos="verb",
    meaning="To make certain layers of a model untrainable",
    example="We froze the backbone of the pretrained model, but not the head",
)

In [None]:
def my_split_func(model:nn.Module):
    "Split `model` into two parameter groups: its first child, and all remaining children together"
    first, rest = model[0], model[1:]
    groups = L(first, rest)
    return groups.map(params)

In [None]:
def splitter(model):
    "Return two parameter groups: the body's (`model[0]`) and the head's (`model[1]`)"
    body_params = params(model[0])
    head_params = params(model[1])
    return L(body_params, head_params)

In [None]:
# Combine the timm body and the fastai head into one model; `splitter` defines the two
# parameter groups (body, head) that `freeze` operates on.
# NOTE: this rebinds `learn`, replacing the ResNet-18 learner created earlier.
learn = Learner(
    dls=dls,
    model=nn.Sequential(body, head),
    splitter=splitter,
)

In [None]:
# Show only the tail of the summary (presumably the parameter totals and the
# trainable / frozen status).
# NOTE(review): slicing the last 250 *characters* is brittle — the slice can start
# mid-line and depends on the exact summary formatting; consider slicing lines instead.
print(learn.summary()[-250:])

In [None]:
# Freeze every parameter group except the last one — with `splitter` above that should
# leave only the head trainable (compare the summary tail before and after).
learn.freeze()

In [None]:
# Summary tail again after freezing; the slice is longer (295 chars) than before,
# presumably to include the extra frozen-state text.
# NOTE(review): character-based slicing is brittle; slicing lines would be more robust.
print(learn.summary()[-295:])