# DRN with skip connections from intermediate layers: each is average-pooled to 1x1 and concatenated with the backbone's pooled output before the final layer

In [21]:
from fastai.vision import flatten_model
from torch import nn
from fastai.callbacks.hooks import model_sizes
import numpy as np

In [22]:
def get_hook_idxs(module, size=(32, 32), sizes=None):
    """Get the indexes of the layers where the size of the activation changes.

    Index ``i`` is returned when the last dimension (width) of layer ``i``'s
    output differs from that of layer ``i + 1``. In other words, layer ``i``
    is the last layer before a resolution change.

    Args:
        module: model whose child-layer output shapes are probed with
            fastai's ``model_sizes``; ignored when ``sizes`` is given.
        size: spatial size of the probe input. The default ``(32, 32)`` is
            the value that used to be hard-coded.
        sizes: optional precomputed sequence of output shapes, one per child
            layer (e.g. the result of ``model_sizes``). Callers that already
            probed the model can pass it here to avoid a second forward pass.

    Returns:
        Increasing list of ``int`` layer indexes without duplicates. The list
        is empty when fewer than two layer shapes are available.
    """
    if sizes is None:
        sizes = model_sizes(module, size=size)
    widths = [shape[-1] for shape in sizes]
    # The previous version also prepended 0 when widths[0] != widths[1]. That
    # index is already produced by this pairwise comparison, so it came out
    # duplicated ([0, 0, ...]) and the same layer would be hooked twice.
    return [i for i, (cur, nxt) in enumerate(zip(widths, widths[1:])) if cur != nxt]

In [43]:
# Import DRN constructors from the local `drn` module and build a DRN-D-24.
# pretrained=False: presumably randomly initialised weights (no download) — verify in drn.py.
from drn import drn_d_22, drn_d_24, drn_c_26

module = drn_d_24(pretrained=False)

In [46]:
# Build a fastai body from the drn_d_24 constructor (pretrained=False).
# With fastai v1's create_body, an int `cut` keeps children[:cut], so cut=-1
# drops the model's last child — confirm against the installed fastai version.
from fastai.vision import create_body

module = create_body(drn_d_24, False, cut=-1)
module

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 16, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
  )
  (1): Sequential(
    (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
  )
  (2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
  )
  (3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padd

In [35]:
# Indexes of the body's layers where the activation width changes for a 32x32 probe.
get_hook_idxs(module )

[1, 2, 3, 8]

In [26]:
# Output shape of each child layer of the body for a 32x32 probe input.
model_sizes(module, size=(32, 32))

[torch.Size([1, 16, 32, 32]),
 torch.Size([1, 16, 32, 32]),
 torch.Size([1, 32, 16, 16]),
 torch.Size([1, 64, 8, 8]),
 torch.Size([1, 128, 4, 4]),
 torch.Size([1, 256, 4, 4]),
 torch.Size([1, 512, 4, 4]),
 torch.Size([1, 512, 4, 4]),
 torch.Size([1, 512, 4, 4]),
 torch.Size([1, 512, 1, 1])]

DRN(
  (layer0): Sequential(
    (0): Conv2d(3, 16, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
  )
  (layer1): Sequential(
    (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
  )
  (layer2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride

In [55]:
from fastai.callbacks import hook_outputs
import torch

In [102]:
class DRNModified(nn.Module):
    """Backbone classifier with pooled side branches taken from hooked layers.

    The backbone is built with fastai's ``create_body``. Every backbone layer
    selected by ``get_hook_idxs`` is hooked, and its output goes through
    Conv3x3 (64 channels) -> BatchNorm -> ReLU -> global average pool. The
    pooled side features are concatenated channel-wise with the backbone
    output, and a 1x1 conv maps the result to ``num_classes`` logits.

    Args:
        backbone_fn: model constructor passed to ``create_body``.
        num_classes: number of output logits.
        **kwargs: forwarded to ``create_body`` (e.g. ``pretrained``, ``cut``).
    """

    def __init__(self, backbone_fn, num_classes=100, **kwargs):
        super().__init__()
        self.backbone = create_body(backbone_fn, **kwargs)
        # Output shapes for a 32x32 probe; dim 1 (channels) sizes the side convs.
        self.sizes = model_sizes(self.backbone, size=(32, 32))
        self.hook_idxs = get_hook_idxs(self.backbone)
        # detach=False so gradients from the side branches reach the backbone.
        # fastai v1's hook_outputs defaults to detach=True, which cut them off.
        self.hooks = hook_outputs([self.backbone[i] for i in self.hook_idxs], detach=False)
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(self.sizes[i][1], 64, kernel_size=3, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool2d(1),
            )
            for i in self.hook_idxs
        ])
        final_in = 64 * len(self.hook_idxs) + self.sizes[-1][1]
        self.final = nn.Conv2d(final_in, num_classes, kernel_size=1)

    def forward(self, inp):
        """Return logits of shape ``(batch, num_classes)`` for ``inp``."""
        # Running the backbone fires the hooks, refreshing every ``stored`` activation.
        feats = [self.backbone(inp)]
        feats.extend(conv(hook.stored) for conv, hook in zip(self.convs, self.hooks))
        # The hooks must stay registered. Removing them here (as before) made
        # every later forward pass read the stale activations of the first batch.
        out = self.final(torch.cat(feats, 1))
        # flatten(1) instead of squeeze(): keeps the batch dim when batch size is 1.
        return out.flatten(1)

    def remove(self):
        """Deregister the forward hooks once the model is no longer needed."""
        for hook in self.hooks:
            hook.remove()

In [103]:
# Output shapes of the hooked layers (the side-branch inputs), driven by
# model.hook_idxs instead of hard-coded indexes so it follows the model.
tuple(model.sizes[i] for i in model.hook_idxs)

(torch.Size([1, 16, 32, 32]),
 torch.Size([1, 32, 16, 16]),
 torch.Size([1, 64, 8, 8]),
 torch.Size([1, 512, 4, 4]))

In [104]:
# Backbone layer indexes whose outputs feed the side branches.
model.hook_idxs

[1, 2, 3, 8]

In [105]:
# Side-branch stacks (conv -> BN -> ReLU -> adaptive avg pool), one per hooked layer.
model.convs

ModuleList(
  (0): Sequential(
    (0): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): AdaptiveAvgPool2d(output_size=1)
  )
  (1): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): AdaptiveAvgPool2d(output_size=1)
  )
  (2): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): AdaptiveAvgPool2d(output_size=1)
  )
  (3): Sequential(
    (0): Conv2d(512, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): AdaptiveAvgPool2

In [106]:
# GPU smoke test: 64 random 3x32x32 inputs; recorded output shape is (64, 100).
model = DRNModified(drn_d_24, pretrained=False, cut=-1).cuda()
model(torch.randn(64, 3, 32, 32).cuda()).shape

torch.Size([64, 100])

In [107]:
# Show cnn_learner's source (IPython `??` syntax) for reference.
from fastai.vision import cnn_learner
cnn_learner??

In [140]:
class Mod(nn.Module):
    """Backbone classifier with pooled side branches of configurable width.

    The backbone is built with fastai's ``create_body``. Every backbone layer
    selected by ``get_hook_idxs`` is hooked, and its output goes through
    Conv3x3 (``middle_size`` channels) -> BatchNorm -> ReLU -> global average
    pool. The pooled side features are concatenated channel-wise with the
    backbone output, and a 1x1 conv maps the result to ``num_classes`` logits.

    Args:
        backbone_fn: model constructor passed to ``create_body``.
        num_classes: number of output logits.
        middle_size: number of channels produced by each side branch.
        **kwargs: forwarded to ``create_body`` (e.g. ``pretrained``, ``cut``).
    """

    def __init__(self, backbone_fn, num_classes=100, middle_size=64, **kwargs):
        super().__init__()
        self.backbone = create_body(backbone_fn, **kwargs)
        # Output shapes for a 32x32 probe; dim 1 (channels) sizes the side convs.
        self.sizes = model_sizes(self.backbone, size=(32, 32))
        self.hook_idxs = get_hook_idxs(self.backbone)
        # detach=False so gradients from the side branches reach the backbone.
        # fastai v1's hook_outputs defaults to detach=True, which cut them off.
        self.hooks = hook_outputs([self.backbone[i] for i in self.hook_idxs], detach=False)
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(self.sizes[i][1], middle_size, kernel_size=3, padding=1),
                nn.BatchNorm2d(middle_size),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool2d(1),
            )
            for i in self.hook_idxs
        ])
        final_in = middle_size * len(self.hook_idxs) + self.sizes[-1][1]
        self.final = nn.Conv2d(final_in, num_classes, kernel_size=1)

    def forward(self, inp):
        """Return logits of shape ``(batch, num_classes)`` for ``inp``."""
        # Running the backbone fires the hooks, refreshing every ``stored`` activation.
        feats = [self.backbone(inp)]
        feats.extend(conv(hook.stored) for conv, hook in zip(self.convs, self.hooks))
        # The hooks must stay registered. Removing them here (as before) made
        # every later forward pass read the stale activations of the first batch.
        out = self.final(torch.cat(feats, 1))
        # flatten(1) instead of squeeze(): keeps the batch dim when batch size is 1.
        return out.flatten(1)

    def remove(self):
        """Deregister the forward hooks; a no-op if they were never created."""
        if hasattr(self, "hooks"):
            self.hooks.remove()

In [142]:
# GPU smoke test: 32 random 3x32x32 inputs; recorded output shape is (32, 100).
# (Previously `torch.randn(, 3, 32, 32)`, a syntax error; 32 matches the output.)
model = Mod(drn_d_24, pretrained=False, cut=-1).cuda()
model(torch.randn(32, 3, 32, 32).cuda()).shape

torch.Size([32, 100])

In [135]:
# Snapshot the activation currently stored by the first hook.
a = model.hooks[0].stored

In [137]:
# Second snapshot of the same hook's stored activation, compared with `a` below.
b = model.hooks[0].stored

In [138]:
# A non-zero sum shows the two snapshots differ (a zero sum would not prove equality).
(a - b).sum()

tensor(-55112.4531, device='cuda:0')

In [139]:
# Exact check: True only if both snapshots have the same shape and identical values.
torch.equal(a, b)

False

# Try EfficientNet models (efficientnet_pytorch)

In [None]:
from  efficientnet_pytorch import EfficientNet

In [None]:
# from_pretrained requires a model name ('efficientnet-b0' ... 'efficientnet-b7');
# calling it with no arguments raises a TypeError.
EfficientNet.from_pretrained('efficientnet-b0')

# Try HRNet