In [1]:
#default_exp layers

In [2]:
#export
import torch
from torch import nn
import torch.nn.functional as F

In [3]:
#export
class AdaptiveConcatPool2d(nn.Module):
    """Concatenate `AdaptiveMaxPool2d` and `AdaptiveAvgPool2d` outputs (FastAI-style).

    The two pooled tensors are joined along dim 1, max-pooled first, so the
    output has twice as many channels as the input.

    Args:
        size: target output size handed to both pooling layers; any falsy
            value (e.g. `None`) falls back to 1.
    """

    def __init__(self, size=None):
        super().__init__()
        self.size = size if size else 1
        self.ap = nn.AdaptiveAvgPool2d(self.size)
        self.mp = nn.AdaptiveMaxPool2d(self.size)

    def forward(self, x):
        pooled_max = self.mp(x)
        pooled_avg = self.ap(x)
        # Channel-wise concat: [max | avg]
        return torch.cat((pooled_max, pooled_avg), dim=1)

In [4]:
#export
class Mish(nn.Module):
    """Mish activation: `x * tanh(softplus(x))`, applied element-wise."""

    def forward(self, x):
        gate = torch.tanh(F.softplus(x))
        return x * gate

In [5]:
#export
def cut_model(model: nn.Module, n: int = -2):
    """Keep the direct children of `model` up to index `n` and wrap them in `nn.Sequential`.

    `n` is used as a slice end (`children[:n]`), so the default of -2 drops
    the last two children. The returned container holds the same module
    objects as `model`; nothing is copied.
    """
    children = list(model.children())
    return nn.Sequential(*children[:n])


def num_features_model(m: nn.Module, in_chs:int = 3):
    """Return the number of output features for `m` (the size of dim 1 of its output).

    The value is measured by pushing a dummy all-zero batch of shape
    `(32, in_chs, 120, 120)` through the model.

    Args:
        m: the model to probe. It is moved to the CPU as a side effect.
        in_chs: number of channels of the dummy input.

    Returns:
        `output.size()[1]` for the dummy batch.
    """
    m.to('cpu')
    # Remember each submodule's train/eval flag so it can be restored exactly,
    # including models that deliberately mix modes (e.g. frozen BatchNorm).
    previous_modes = [(mod, mod.training) for mod in m.modules()]
    # Probe in eval mode without autograd: in train mode the all-zero batch
    # would be folded into BatchNorm running statistics and Dropout would be
    # active, even though only the output shape is needed.
    m.eval()
    try:
        with torch.no_grad():
            dummy_inp = torch.zeros((32, in_chs, 120, 120))
            dummy_out = m(dummy_inp)
    finally:
        for mod, was_training in previous_modes:
            mod.training = was_training
    return dummy_out.size()[1]

In [6]:
# Demo: build an untrained ResNet-18 and derive an encoder from it by
# dropping its last two children. `model` wraps the same layer objects as
# `orig_model` (cut_model does not copy).
import torchvision
orig_model = torchvision.models.resnet18(pretrained=False)
model = cut_model(orig_model, -2)

In [7]:
# Feature count of the cut encoder vs. the full classifier
# (expected: 512 channels vs. 1000 outputs).
num_features_model(model), num_features_model(orig_model)

(512, 1000)

In [8]:
#export
def create_head(nf: int, n_out: int, lin_ftrs: int = 512, act: nn.Module = nn.ReLU(inplace=True)):
    """Create a custom classifier head in the style of FastAI.

    Layout: concat-pool -> flatten -> [BatchNorm, Dropout(0.25), act, Linear]
    -> [BatchNorm, Dropout(0.5), act, Linear]. Both linear layers have no bias.

    Args:
        nf: width of the flattened features entering the first BatchNorm1d,
            i.e. the channel count *after* `AdaptiveConcatPool2d`.
        n_out: number of outputs of the final linear layer.
        lin_ftrs: width of the hidden linear layer.
        act: activation module. The same instance is placed at both
            activation positions; the default is created once, when the
            function is defined.

    Returns:
        The head as an `nn.Sequential`.
    """
    hidden = lin_ftrs
    return nn.Sequential(
        AdaptiveConcatPool2d(),
        nn.Flatten(),
        # first block: nf -> hidden
        nn.BatchNorm1d(nf),
        nn.Dropout(0.25),
        act,
        nn.Linear(nf, hidden, bias=False),
        # second block: hidden -> n_out
        nn.BatchNorm1d(hidden),
        nn.Dropout(0.5),
        act,
        nn.Linear(hidden, n_out, bias=False),
    )

In [9]:
# Demo: show the layer layout of a head for 3 classes.
# NOTE: `nf` here is the raw encoder channel count (512). A head that actually
# consumes this encoder's output needs twice that, because AdaptiveConcatPool2d
# concatenates max- and avg-pooled channels (TransferLearningModel passes feats * 2).
create_head(num_features_model(model), n_out=3)

Sequential(
  (0): AdaptiveConcatPool2d(
    (ap): AdaptiveAvgPool2d(output_size=1)
    (mp): AdaptiveMaxPool2d(output_size=1)
  )
  (1): Flatten(start_dim=1, end_dim=-1)
  (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Dropout(p=0.25, inplace=False)
  (4): ReLU(inplace=True)
  (5): Linear(in_features=512, out_features=512, bias=False)
  (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (7): Dropout(p=0.5, inplace=False)
  (8): ReLU(inplace=True)
  (9): Linear(in_features=512, out_features=3, bias=False)
)

In [18]:
#export
class TransferLearningModel(nn.Module):
    """Transfer learning model: a truncated `encoder` followed by a custom head."""

    def __init__(self, encoder:nn.Module, c:int, cut:int=-2, **kwargs):
        """
        Args:
            encoder: the classifier whose leading layers extract the features
            c: number of output classes
            cut: slice end applied to the encoder's children by `cut_model`
            **kwargs: extra arguments forwarded to `create_head`
        """
        super().__init__()

        # Capture the class name before the encoder is re-wrapped in nn.Sequential.
        self.encoder_name = type(encoder).__name__
        self.encoder = cut_model(encoder, cut)
        self.c = c

        # The head starts with a concat pool (max + avg), so it receives twice
        # the number of channels the encoder emits.
        encoder_channels = num_features_model(self.encoder, in_chs=3)
        self.fc = create_head(encoder_channels * 2, n_out=c, **kwargs)

    @property
    def encoder_class_name(self):
        """Class name of the encoder that was passed to the constructor."""
        return self.encoder_name

    def forward(self, xb):
        features = self.encoder(xb)
        return self.fc(features)

In [19]:
# Demo: 5-class model on the ResNet encoder; `lin_ftrs` and `act` reach
# `create_head` through **kwargs.
m = TransferLearningModel(encoder=orig_model, c=5, cut=-2, lin_ftrs=255, act=nn.SiLU(inplace=True))

In [20]:
# Smoke test: one forward pass on a batch of zeros; expect an output of shape (32, 5).
dummy_inp = torch.zeros((32, 3, 120, 120))
dummy_out = m(dummy_inp)
dummy_out.size()

torch.Size([32, 5])

In [21]:
#export
def replace_activs(model, func, activs: list = [nn.ReLU, nn.SiLU]):
    """Recursively replace every submodule of `model` that is one of `activs` with `func`.

    The replacement happens in place on `model`, and the same `func` instance
    is installed at every matching position.

    Args:
        model: module whose children (at any depth) are searched.
        func: module to install in place of each matching activation.
        activs: activation classes to replace. The list is only read, never
            mutated, and it is forwarded unchanged to every nested call so a
            custom list applies at all depths.
    """
    targets = tuple(activs)
    for child_name, child in model.named_children():
        if isinstance(child, targets):
            # Rebinding an existing child name keeps its position in the parent.
            setattr(model, child_name, func)
        else:
            # Only descend into modules that were not replaced.
            replace_activs(child, func, activs)

In [22]:
# Swap the encoder's activations for Mish, then rebuild the model with a Mish head.
# NOTE: replace_activs mutates `orig_model` in place, so every later cell that
# uses `orig_model` sees the Mish version.
replace_activs(orig_model, func=Mish())
m = TransferLearningModel(encoder=orig_model, c=5, cut=-2, lin_ftrs=255, act=Mish())

In [23]:
# Smoke test for the Mish variant: the output shape should still be (32, 5).
dummy_inp = torch.zeros((32, 3, 120, 120))
dummy_out = m(dummy_inp)
dummy_out.size()

torch.Size([32, 5])

In [25]:
# Inspect the module tree; the encoder's activation slots should now print as Mish().
print(m)

TransferLearningModel(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Mish()
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): Mish()
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (

In [26]:
# Inspect the head: its first BatchNorm1d takes 1024 features (2 x 512 encoder
# channels, because of the concat pool) and the hidden width is lin_ftrs=255.
m.fc

Sequential(
  (0): AdaptiveConcatPool2d(
    (ap): AdaptiveAvgPool2d(output_size=1)
    (mp): AdaptiveMaxPool2d(output_size=1)
  )
  (1): Flatten(start_dim=1, end_dim=-1)
  (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Dropout(p=0.25, inplace=False)
  (4): Mish()
  (5): Linear(in_features=1024, out_features=255, bias=False)
  (6): BatchNorm1d(255, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (7): Dropout(p=0.5, inplace=False)
  (8): Mish()
  (9): Linear(in_features=255, out_features=5, bias=False)
)

In [39]:
#export
class BasicTransferLearningModel(nn.Module):
    """Transfer learning model with a plain pool + linear head, compatible with Snapmix."""

    def __init__(self, encoder:nn.Module, c:int, cut:int=-2, **kwargs):
        """
        Args:
            encoder: the classifier to extract features from; it must expose
                either `fc.in_features` or `classifier.in_features`
            c: number of output classes
            cut: slice end applied to the encoder's children by `cut_model`
            **kwargs: accepted for signature parity with `TransferLearningModel`;
                currently unused

        Raises:
            AttributeError: if neither `encoder.fc` nor `encoder.classifier`
                provides `in_features`.
        """
        super().__init__()

        # Read the head's input width before the original head is cut away.
        # Only a missing attribute triggers the fallback; other errors propagate.
        try:
            feats = encoder.fc.in_features
        except AttributeError:
            feats = encoder.classifier.in_features

        self.encoder_name = encoder.__class__.__name__
        # cut layers from the encoder
        self.encoder = cut_model(encoder, cut)

        # create the pooling + linear head for the model
        self.c = c
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(feats, self.c)

    @property
    def encoder_class_name(self):
        """Class name of the encoder that was passed to the constructor."""
        return self.encoder_name

    def forward(self, xb):
        feature_map = self.encoder(xb)
        # Global-average-pool the feature map and flatten it to (batch, channels)
        # so that it matches the input the linear layer expects.
        pooled = torch.flatten(self.pool(feature_map), 1)
        return self.fc(pooled)

In [40]:
# Demo: build the Snapmix-compatible variant (rebinds `m`). `orig_model` was
# already modified in place by replace_activs above, so its activations are Mish.
m = BasicTransferLearningModel(orig_model, c=5)

In [41]:
# The name is recorded before the encoder is cut, so this reports the original
# class ('ResNet') rather than the nn.Sequential wrapper.
m.encoder_class_name

'ResNet'

In [42]:
# Inspect the module tree of the basic model.
m

BasicTransferLearningModel(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Mish()
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): Mish()
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    