In [None]:
#default_exp networks

In [None]:
import timm

In [None]:
#export
import torch
from torch import nn
import torch.nn.functional as F

from src.core import *
from src.layers import *

In [None]:
#export
class TransferLearningModel(nn.Module):
    "Classifier built from a (cut) pretrained `encoder` followed by a freshly created head"
    def __init__(self, encoder:nn.Module, c:int, cut:int=-2, **kwargs):
        """
        Args:
            encoder: the classifier whose layers are reused as a feature extractor
            c: number of output classes
            cut: passed to `cut_model` to choose which encoder layers to keep
            **kwargs: forwarded unchanged to `create_head`
        """
        super().__init__()

        # Remember the original class name; after `cut_model` the module type changes.
        self.encoder_name = type(encoder).__name__

        body = cut_model(encoder, cut)
        self.encoder = body

        # NOTE(review): the `* 2` presumably accounts for `create_head` concatenating
        # two poolings (e.g. avg + max); confirm against `src.layers.create_head`.
        n_feats = num_features_model(body, in_chs=3) * 2

        self.c = c
        self.fc = create_head(n_feats, n_out=c, **kwargs)

    @property
    def encoder_class_name(self):
        "Class name of the encoder as it was passed in, before cutting."
        return self.encoder_name

    def forward(self, xb):
        "Run `xb` through the cut encoder, then classify the resulting features with the head."
        features = self.encoder(xb)
        return self.fc(features)

In [None]:
encoder_r18 = timm.create_model('resnet18')
m = TransferLearningModel(encoder_r18, c=5, cut=-2)
print(m)

TransferLearningModel(
  (encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): ReLU(inplace=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, mom

In [None]:
# Shape smoke test: a batch of 32 RGB 120x120 images should map to (32, n_classes) logits.
dummy_inp = torch.zeros((32, 3, 120, 120))
with torch.no_grad():  # shape check only, no autograd graph needed
    dummy_out = m(dummy_inp)
assert tuple(dummy_out.shape) == (32, m.c), dummy_out.shape
dummy_out.size()

torch.Size([32, 5])

In [None]:
#export
#TODO: add midlevel classification branch in learning.
class SnapMixTransferLearningModel(nn.Module):
    "Transfer-learning model with an explicit pool -> `ls` -> `fc` head, compatible with SnapMix"
    def __init__(self, encoder:nn.Module, c:int, cut:int=-2, **kwargs):
        """
        Args:
            encoder: the classifier to extract features from; its final linear layer
                must be exposed as `fc` or `classifier` (it is read before cutting)
            c: number of output classes
            cut: passed to `cut_model` to choose which encoder layers to keep
            **kwargs: accepted for signature parity with `TransferLearningModel`; ignored

        Raises:
            AttributeError: if neither `encoder.fc.in_features` nor
                `encoder.classifier.in_features` exists
        """
        super(SnapMixTransferLearningModel, self).__init__()

        # Must be read before `cut_model`, which may drop the classifier layer.
        # NOTE(review): assumes the cut encoder still emits this many channels,
        # which presumably holds for the default cut=-2; confirm for other cuts.
        feats = self._encoder_in_features(encoder)

        self.encoder_name = encoder.__class__.__name__
        # cut layers from the encoder
        self.encoder = cut_model(encoder, cut)

        # custom head: global average pool -> (flatten, BN, dropout) -> linear
        self.c     = c
        self.pool  = nn.AdaptiveAvgPool2d((1, 1))

        layers  = [nn.Flatten(), nn.BatchNorm1d(feats), nn.Dropout(0.5)]
        self.ls = nn.Sequential(*layers)

        self.fc = nn.Linear(feats, self.c)

    @staticmethod
    def _encoder_in_features(encoder:nn.Module) -> int:
        "Input width of `encoder`'s final linear layer, looked up as `fc` first, then `classifier`."
        for attr in ('fc', 'classifier'):
            layer = getattr(encoder, attr, None)
            n_in  = getattr(layer, 'in_features', None)
            if n_in is not None: return n_in
        raise AttributeError(
            f"{encoder.__class__.__name__} has neither `fc.in_features` nor "
            "`classifier.in_features`; cannot infer the head width")

    def mid_forward(self, xb, detach=True):
        "Placeholder for the planned mid-level classification branch; currently does nothing and returns None."
        pass

    @property
    def encoder_class_name(self):
        "Class name of the encoder as it was passed in, before cutting."
        return self.encoder_name

    def forward(self, xb):
        "Encode `xb`, pool the feature maps, then apply the `ls` block and the linear classifier."
        fmps = self.encoder(xb)
        x = self.pool(fmps)
        x = self.ls(x)
        return self.fc(x)

In [None]:
encoder_b0 = timm.create_model('efficientnet_b0')
m = SnapMixTransferLearningModel(encoder_b0, c=5, cut=-2)
print(m)

SnapMixTransferLearningModel(
  (encoder): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): SiLU(inplace=True)
    (3): Sequential(
      (0): Sequential(
        (0): DepthwiseSeparableConv(
          (conv_dw): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (act1): SiLU(inplace=True)
          (se): SqueezeExcite(
            (conv_reduce): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (act1): SiLU(inplace=True)
            (conv_expand): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
          )
          (conv_pw): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   

In [None]:
# Shape smoke test: a batch of 32 RGB 120x120 images should map to (32, n_classes) logits.
dummy_inp = torch.zeros((32, 3, 120, 120))
with torch.no_grad():  # shape check only, no autograd graph needed
    dummy_out = m(dummy_inp)
assert tuple(dummy_out.shape) == (32, m.c), dummy_out.shape
dummy_out.size()

torch.Size([32, 5])

In [None]:
from src.lightning.core import params

In [None]:
# Parameters of the trainable head: the linear classifier plus the `ls` block.
param_list = params(m.fc) + params(m.ls)
# Show shapes only; the raw tensors are large and not informative to read.
[tuple(p.shape) for p in param_list]

[Parameter containing:
 tensor([[-0.0150, -0.0100, -0.0052,  ..., -0.0071,  0.0052, -0.0142],
         [-0.0220, -0.0195, -0.0164,  ..., -0.0273, -0.0100,  0.0051],
         [-0.0248, -0.0279,  0.0106,  ..., -0.0233, -0.0221,  0.0018],
         [ 0.0131,  0.0140,  0.0067,  ..., -0.0050, -0.0079,  0.0048],
         [-0.0221, -0.0176,  0.0182,  ..., -0.0011,  0.0056, -0.0138]],
        requires_grad=True),
 Parameter containing:
 tensor([ 0.0102,  0.0055,  0.0102, -0.0047,  0.0215], requires_grad=True),
 Parameter containing:
 tensor([1., 1., 1.,  ..., 1., 1., 1.], requires_grad=True),
 Parameter containing:
 tensor([0., 0., 0.,  ..., 0., 0., 0.], requires_grad=True)]

In [None]:
# Smoke test: the head parameters can be handed to an optimizer.
# NOTE: no backward pass has run, so every `.grad` is None and `step()` leaves the weights unchanged.
opt = torch.optim.SGD(params=param_list, lr=1e-3)
opt.step()

In [None]:
#hide
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted 00a_lightning.core.ipynb.
Converted 00b_fastai.core.ipynb.
Converted 01_mixmethods.ipynb.
Converted 02_layers.ipynb.
Converted 02a_networks.ipynb.
Converted 03_optimizers.ipynb.
Converted index.ipynb.
