In [4]:
import sys
sys.path.append("../..")

import torch
import torch.nn as nn
from torch.nn import functional as F
from src.models import *

In [5]:
class TransparentGAP(nn.Module):
    """Pass-through layer that records a globally average-pooled copy of its input.

    The input is returned untouched, so the layer can sit anywhere inside an
    ``nn.Sequential`` without altering the data flow. The pooled tensor from
    the most recent call is kept on ``self.feature_maps``.
    """

    def forward(self, x):
        """Store the 1x1 average-pooled ``x`` on ``self.feature_maps`` and return ``x`` itself."""
        pooled = F.adaptive_avg_pool2d(x, output_size=1)
        self.feature_maps = pooled
        return x

class AttentionGAP(nn.Module):
    """Placeholder for an attention-based pooling layer; the forward pass is not written yet.

    n_classes: number of target classes, stored on the instance (defaults to 28).
    """

    def __init__(self, n_classes=28):
        super().__init__()
        self.n_classes = n_classes

    def forward(self, x):
        """Not implemented: always raises NotImplementedError."""
        # No pooling logic exists yet. Raising explicitly keeps the class
        # definition syntactically valid and makes accidental use fail loudly.
        raise NotImplementedError("AttentionGAP.forward is not implemented yet")
    
class DynamicGapNet(nn.Module):
    """Classifier fed by global-average-pooled features tapped from several encoder layers.

    A ``TransparentGAP`` is inserted after each selected encoder layer. On a
    forward pass the pooled outputs of all taps are concatenated along dim 1,
    passed through dropout and then through a single linear layer.
    """

    def __init__(self, encoder, n_classes, gap_layer_idxs=None, dropout_prob=0.25):
        """
        encoder: nn.Sequential object
        n_classes: number of outputs of the final linear layer
        gap_layer_idxs: indices of the encoder layers after which a GAP layer
            will be placed; ``None`` places one after every layer
        dropout_prob: dropout probability applied to the concatenated features
        """
        super().__init__()
        # NOTE(review): `fastai` is not imported explicitly in the visible cells;
        # presumably it is provided by `from src.models import *` — confirm.
        sizes, *_ = fastai.callbacks.model_sizes(encoder)
        if gap_layer_idxs is None:
            gap_layer_idxs = range(len(encoder))
        layers = []
        total_n_features = 0
        # Plain list of the taps, in insertion order. The same objects are also
        # placed in `self.model`, which is what registers them as submodules.
        self.gaps = []
        for i, (size, module) in enumerate(zip(sizes, encoder)):
            layers.append(module)
            if i in gap_layer_idxs:
                # assumes size[1] is the channel count of this layer's output
                # (i.e. outputs are laid out as (N, C, H, W)) — TODO confirm
                total_n_features += size[1]
                gap = TransparentGAP()
                self.gaps.append(gap)
                layers.append(gap)
        self.model = nn.Sequential(*layers)
        self.fc = nn.Linear(total_n_features, n_classes)
        self.dropout_prob = dropout_prob

    def forward(self, x):
        """Run the encoder on ``x`` and classify the concatenated pooled features.

        Returns the output of ``self.fc``, one row per input sample.
        """
        # Only the side effect matters here: each tap stores its pooled features.
        self.model(x)
        features = torch.cat([gap.feature_maps for gap in self.gaps], dim=1)
        # Collapse the trailing 1x1 spatial dims while always keeping the batch
        # (and feature) dimension, so batch size 1 needs no special case.
        features = features.flatten(1)
        # Tie dropout to the module's mode: F.dropout defaults to training=True
        # and would otherwise keep dropping features after `.eval()`.
        features = F.dropout(features, p=self.dropout_prob, training=self.training)
        logits = self.fc(features)
        return logits


def one_level_flatten(model):
    """Return a new ``nn.Sequential`` with directly nested ``nn.Sequential`` children unrolled.

    Only one level is expanded: a child that is an ``nn.Sequential`` is replaced
    by its own children in order; every other child is kept as-is. The module
    objects themselves are reused, not copied.
    """
    unrolled = [
        leaf
        for child in model
        for leaf in (child if isinstance(child, nn.Sequential) else (child,))
    ]
    return nn.Sequential(*unrolled)


def gapnet_resnet34_four_channel_input_backbone(pretrained=True, n_classes=28):
    """Build a ``DynamicGapNet`` on top of ``resnet34_four_channel_input_one_fc``.

    pretrained: forwarded to ``resnet34_four_channel_input_one_fc``.
    n_classes: number of outputs of the resulting network; defaults to 28.

    The last child of the base model is dropped, nested ``nn.Sequential``
    children are unrolled one level, and a GAP layer is placed after every
    layer of the flattened encoder from index 4 onward.
    """
    # Index of the first flattened-encoder layer that gets a GAP layer.
    # NOTE(review): presumably this skips the conv/bn/relu/maxpool stem so that
    # taps start at the first residual block — confirm against the model layout.
    first_gap_layer_idx = 4
    encoder = resnet34_four_channel_input_one_fc(pretrained)[:-1]
    flattened_encoder = one_level_flatten(encoder)
    resnet34_gapnet = DynamicGapNet(
        flattened_encoder,
        n_classes=n_classes,
        gap_layer_idxs=range(first_gap_layer_idx, len(flattened_encoder)),
    )
    return resnet34_gapnet

In [10]:
# Build the model returned by resnet34_four_channel_input_one_fc (pretrained=True)
# so that its layer layout can be inspected in the next cell.
encoder = resnet34_four_channel_input_one_fc(pretrained=True)

In [11]:
# Bare expression: the cell output is the module tree of `encoder`.
encoder

Sequential(
  (0): Conv2d(4, 64, kernel_size=(7, 7), stride=(3, 3), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, ker