In [3]:
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from torchvision.models import efficientnet_v2_l, efficientnet_v2_m, efficientnet_v2_s  
#from .cbam import CBAMBlock
        
    
class EffnetV2_L(torch.nn.Module):
    def __init__(self, out_features = 7, in_channels = 1, dropout = 0.4, use_sigmoid = False, use_attention = False):
        super().__init__()
        
        self.use_sigmoid = use_sigmoid
        self.use_attention = use_attention
        self.dropout = dropout
        self.out_features = out_features
        self.in_channels = in_channels
        self.model = efficientnet_v2_l(weights = 'EfficientNet_V2_L_Weights.IMAGENET1K_V1')
        self.model.features[0] = torch.nn.Conv2d(self.in_channels, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        self.model.avgpool = torch.nn.Identity()
        self.model.classifier = torch.nn.Sequential(nn.Dropout(self.dropout), nn.Linear(1280, self.out_features))
        self.sigmoid = torch.nn.Sigmoid()
        if self.use_attention:
            self.spatial_attention = SpatialAttention(feature_map_size = 16, n_channels=1280)
        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
        
        
    def count_params(self):
        
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
        
    def forward(self, x):
        if self.use_attention:
            x = self.model.features(x)
            x = self.spatial_attention(x)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.model.classifier(x)
        else:
            x = self.model.features(x)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.model.classifier(x)
        if self.use_sigmoid:
            x = self.sigmoid(x)
        return x

In [12]:
model = efficientnet_v2_l(weights = 'EfficientNet_V2_L_Weights.IMAGENET1K_V1', layer_norm = torch.nn.GroupNorm())

TypeError: __init__() missing 2 required positional arguments: 'num_groups' and 'num_channels'

In [11]:
import torch.nn as nn
import torchvision.models as models

model = models.efficientnet_v2_s()

# Create a list of all batch normalization modules in the model
bn_modules = [module for module in model.modules() if isinstance(module, nn.BatchNorm2d)]

# Remove all batch normalization modules from the model
for bn_module in bn_modules:
    print(bn_module)
    bn_module_index = list(model.modules()).index(bn_module)
    model_modules = list(model.modules())
    model_modules[bn_module_index] = nn.Identity()
    new_model = nn.Sequential(*model_modules)
    model = new_model

print(model)



BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
BatchNorm2d(96, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
BatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
BatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
BatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
BatchNorm2d(48, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
BatchNorm2d(64, eps=0.001, momentum=

KeyboardInterrupt: 

In [22]:
import torch
import torchvision.models as models
import torch.nn as nn

# Load the pre-trained EfficientNet-B0 model
model = models.efficientnet_b0(pretrained=True)


# Access the block containing batch normalization


  f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "


In [83]:
# Access the batch normalization layer
for feature in model.features:
    for layer in feature:
            if isinstance(layer, nn.BatchNorm2d):
                setattr(layer, 'track_running_stats', False)  # Disable running stats tracking
                setattr(model, name, nn.Identity())
            if isinstance(layer, nn.Sequential):
                for i in layer:
                    if isinstance(i, nn.BatchNorm2d):
                        i = torch.nn.Identity()

In [84]:
model.named_modules

<bound method Module.named_modules of EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=False)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): FusedMBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): FusedMBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(24, eps=0.001, momentum=0.1, af

In [67]:
for name, module in model.named_modules():
    if isinstance(module, nn.BatchNorm2d):
        setattr(module, 'track_running_stats', False)  # Disable running stats tracking
        setattr(model, name, nn.Identity())  # Replace the module with identity

RuntimeError: OrderedDict mutated during iteration

In [68]:
bn_layers = []
for name, module in model.named_modules():
    if isinstance(module, nn.BatchNorm2d):
        bn_layers.append(name)

# Iterate over the batch normalization layers and replace with identity
for name in bn_layers:
    module = getattr(model, name)
    setattr(module, 'track_running_stats', False)  # Disable running stats tracking
    setattr(model, name, nn.Identity())  

In [70]:
for name, module in model.named_modules():
    if isinstance(module, nn.Sequential) and len(module) == 3:
        # Replace the batchnorm layer with an identity layer
        module[1] = nn.Identity()

In [82]:
# Load the pre-trained EfficientNet-v2-S model
model = models.efficientnet_v2_s(pretrained=True)

# Iterate over the ConvBnAct2d blocks and remove batch normalization
for name, module in model.named_modules():
    print(name, module)
    if isinstance(module, nn.ConvBnAct2d):
        # Replace the ConvBnAct2d block with Conv2d + SiLU
        new_module = nn.Sequential(
            module.conv,
            nn.SiLU()
        )
        setattr(model, name, new_module)

 EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): FusedMBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): FusedMBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
 

AttributeError: module 'torch.nn' has no attribute 'ConvBnAct2d'

In [6]:
"""
Creates a EfficientNetV2 Model as defined in:
Mingxing Tan, Quoc V. Le. (2021). 
EfficientNetV2: Smaller Models and Faster Training
arXiv preprint arXiv:2104.00298.
import from https://github.com/d-li14/mobilenetv2.pytorch
"""

import torch
import torch.nn as nn
import math

__all__ = ['effnetv2_s', 'effnetv2_m', 'effnetv2_l', 'effnetv2_xl']


def _make_divisible(v, divisor, min_value=None):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
    It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
    :param v:
    :param divisor:
    :param min_value:
    :return:
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


# SiLU (Swish) activation function
if hasattr(nn, 'SiLU'):
    SiLU = nn.SiLU
else:
    # For compatibility with old PyTorch versions
    class SiLU(nn.Module):
        def forward(self, x):
            return x * torch.sigmoid(x)

 
class SELayer(nn.Module):
    def __init__(self, inp, oup, reduction=4, num_groups = 4):
        super(SELayer, self).__init__()
        self.num_groups = num_groups
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
                nn.Linear(oup, _make_divisible(inp // reduction, 8)),
                SiLU(),
                nn.Linear(_make_divisible(inp // reduction, 8), oup),
                nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y


def conv_3x3_bn(inp, oup, stride, num_groups = 4):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        nn.GroupNorm(num_channels=oup, num_groups=num_groups),
        SiLU()
    )


def conv_1x1_bn(inp, oup, num_groups = 4):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        nn.GroupNorm(num_channels=oup, num_groups=num_groups),
        SiLU()
    )


class MBConv(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio, use_se, num_groups = 4):
        super(MBConv, self).__init__()
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.identity = stride == 1 and inp == oup
        if use_se:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                nn.GroupNorm(num_channels=hidden_dim, num_groups=num_groups),
                SiLU(),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                nn.GroupNorm(num_channels=hidden_dim, num_groups=num_groups),
                SiLU(),
                SELayer(inp, hidden_dim),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.GroupNorm(num_channels=oup, num_groups=num_groups),
            )
        else:
            self.conv = nn.Sequential(
                # fused
                nn.Conv2d(inp, hidden_dim, 3, stride, 1, bias=False),
                nn.GroupNorm(num_channels=hidden_dim, num_groups=num_groups),
                SiLU(),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                nn.GroupNorm(num_channels=oup, num_groups=num_groups),
            )


    def forward(self, x):
        if self.identity:
            return x + self.conv(x)
        else:
            return self.conv(x)


class EffNetV2(nn.Module):
    def __init__(self, cfgs, num_classes=1000, width_mult=1.):
        super(EffNetV2, self).__init__()
        self.cfgs = cfgs

        # building first layer
        input_channel = _make_divisible(32 * width_mult, 8)
        layers = [conv_3x3_bn(1, input_channel, 2)]
        # building inverted residual blocks
        block = MBConv
        for t, c, n, s, use_se in self.cfgs:
            output_channel = _make_divisible(c * width_mult, 8)
            for i in range(n):
                layers.append(block(input_channel, output_channel, s if i == 0 else 1, t, use_se))
                input_channel = output_channel
        self.features = nn.Sequential(*layers)
        self.features.add_module('last_conv', conv_1x1_bn(input_channel, 1280))
        # building last several layers
        output_channel = 1280
        #self.conv = conv_1x1_bn(input_channel, output_channel)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(output_channel, num_classes)

        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        print(x.shape)
        #x = self.conv(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.GroupNorm):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.001)
                m.bias.data.zero_()


def effnetv2_s(**kwargs):
    """
    Constructs a EfficientNetV2-S model
    """
    cfgs = [
        # t, c, n, s, SE
        [1,  24,  2, 1, 0],
        [4,  48,  4, 2, 0],
        [4,  64,  4, 2, 0],
        [4, 128,  6, 2, 1],
        [6, 160,  9, 1, 1],
        [6, 256, 15, 2, 1],
    ]
    return EffNetV2(cfgs, **kwargs)


def effnetv2_m(**kwargs):
    """
    Constructs a EfficientNetV2-M model
    """
    cfgs = [
        # t, c, n, s, SE
        [1,  24,  3, 1, 0],
        [4,  48,  5, 2, 0],
        [4,  80,  5, 2, 0],
        [4, 160,  7, 2, 1],
        [6, 176, 14, 1, 1],
        [6, 304, 18, 2, 1],
        [6, 512,  5, 1, 1],
    ]
    return EffNetV2(cfgs, **kwargs)


def effnetv2_l(**kwargs):
    """
    Constructs a EfficientNetV2-L model
    """
    cfgs = [
        # t, c, n, s, SE
        [1,  32,  4, 1, 0],
        [4,  64,  7, 2, 0],
        [4,  96,  7, 2, 0],
        [4, 192, 10, 2, 1],
        [6, 224, 19, 1, 1],
        [6, 384, 25, 2, 1],
        [6, 640,  7, 1, 1],
    ]
    return EffNetV2(cfgs, **kwargs)


def effnetv2_xl(**kwargs):
    """
    Constructs a EfficientNetV2-XL model
    """
    cfgs = [
        # t, c, n, s, SE
        [1,  32,  4, 1, 0],
        [4,  64,  8, 2, 0],
        [4,  96,  8, 2, 0],
        [4, 192, 16, 2, 1],
        [6, 256, 24, 1, 1],
        [6, 512, 32, 2, 1],
        [6, 640,  8, 1, 1],
    ]
    return EffNetV2(cfgs, **kwargs)

In [7]:
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from torchvision.models import efficientnet_v2_l
#from .cbam import CBAMBlock

class SpatialAttention(torch.nn.Module):
    def __init__(self, feature_map_size = 16, n_channels = 1280):
        super().__init__()
    
        self.n_channels = n_channels
        self.feature_map_size = feature_map_size
        self.keys = torch.nn.Conv2d(self.n_channels, self.n_channels, kernel_size=1, stride=1, padding=0)
        self.queries = torch.nn.Conv2d(self.n_channels, self.n_channels, kernel_size=1, stride=1, padding=0)
        self.values = torch.nn.Conv2d(self.n_channels, self.n_channels, kernel_size=1, stride=1, padding=0)
        self.refine = torch.nn.Conv2d(self.n_channels, self.n_channels, kernel_size=1, stride=1, padding=0)
        self.softmax = torch.nn.Softmax2d()
        self.alpha = torch.nn.Parameter(torch.zeros(1))
        
    def forward(self, x):
        
        attended_features = torch.matmul(self.softmax(torch.matmul(self.keys(x).view(x.size(0), self.n_channels, -1).permute(0, 2, 1), 
                                                                   self.queries(x).view(x.size(0), self.n_channels, -1))/self.n_channels**0.5), 
                                         self.values(x).view(x.size(0), self.n_channels, -1).permute(0, 2, 1)) # (batch_size, feature_map_size * feature_map_size, n_channels)
        attended_features = attended_features.permute(0, 2, 1).view(x.size(0), self.n_channels, self.feature_map_size, self.feature_map_size) # (batch_size, n_channels, feature_map_size, feature_map_size)
        attended_features = self.refine(attended_features)
        attended_features = self.alpha * attended_features + x
        
        return attended_features
    
    
    
class KeyFrameAttention(torch.nn.Module):
    def __init__(self, n_frames = 4, n_channels = 1280):
        super().__init__()
    
    
        self.n_frames = n_frames
        self.n_channels = n_channels
        self.keys = torch.nn.Linear(self.n_channels, self.n_channels)
        self.queries = torch.nn.Linear(self.n_channels, self.n_channels)
        self.values = torch.nn.Linear(self.n_channels, self.n_channels)
        self.refine = torch.nn.Linear(self.n_channels, self.n_channels)
        self.softmax = torch.nn.Softmax(dim=-1)
        #self.alpha = torch.nn.Parameter(torch.zeros(1))
        
    def forward(self, x, Mask = None):
        # x shape: 
        #print('x shape', x.shape)
        keys = self.keys(x) # (batch_size, n_frames, n_channels)
        #print('keys', keys.shape)
        queries = self.queries(x) # (batch_size, n_frames, n_channels)
        #print('queries', queries)
        values = self.values(x) # (batch_size, n_frames, n_channels)
        #print('values', values)
        matmul = torch.matmul(queries, keys.permute(0, 2, 1)).float() # (batch_size, n_channels, n_frames)
        #print('matmul', matmul.shape)
        if Mask is not None:
            matmul = matmul.masked_fill(Mask == 0, -1e20)
            #print('matmul masked', matmul)
        #print('matmul shape', matmul.shape)
        softmax = self.softmax(matmul/(self.n_channels) ** 0.5) # (batch_size, n_channels, n_frames)
        attention_map = torch.matmul(values.permute(0, 2, 1), softmax) # (batch_size, n_channels, n_frames)
        #print('attention_map', attention_map.shape)
        #print('attention_map', attention_map)
        attended_features = self.refine(attention_map.permute(0, 2, 1)) # (batch_size, n_frames, n_channels)
        #print('attended_features', attended_features)
        #attended_features = attended_features + x
        #print('attended_features', attended_features)
        #attended_features = attended_features[:, :org_seq_len, :]
        #attended_features = attended_features.permute(0, 2, 1)
        #print('attended_features', attended_features)
        #print('attended_features', attended_features.shape)
        #print('attended_features', attended_features)
        return attended_features


    

class EffnetV2_Key_Frame(torch.nn.Module):
    def __init__(self, out_features = 7, in_channels = 1, dropout = 0.4, use_sigmoid = False, 
                 use_attention = True, use_key_frame_attention = False,
                 max_len = 4):
        super().__init__()
        
        self.max_len = max_len
        self.use_key_frame_attention = use_key_frame_attention
        self.use_sigmoid = use_sigmoid 
        self.use_attention = use_attention
        self.dropout = dropout
        self.out_features = out_features
        self.in_channels = in_channels
        self.model = effnetv2_l()
        #self.model.features[0] = torch.nn.Conv2d(self.in_channels, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        self.model.avgpool = torch.nn.Identity()
        self.model.classifier = torch.nn.Sequential(nn.Dropout(self.dropout), nn.Linear(1280, self.out_features))
        self.sigmoid = torch.nn.Sigmoid()
        if self.use_attention:
            self.spatial_attention = SpatialAttention(feature_map_size = 16, n_channels=1280)
        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
        if self.use_key_frame_attention:
            self.key_frame_attention = KeyFrameAttention(n_frames = 4, n_channels = 1280)
        
    def features_padding(self, features, max_length, split_sizes):
        padded_output = []
        for i, feature in enumerate(features):
            padding_length = max_length - split_sizes[i]
            padded_seq = torch.nn.functional.pad(feature, (0, padding_length, 0, 0, 0, 0), mode='constant', value=0)
            padded_output.append(padded_seq)
        padded_output = torch.stack(padded_output)
        return padded_output
    
    def mask_sequence(self, padded_seq, org_seq_lens):  
        
        # Define the padded sequence
        padded_seq_len = padded_seq.size(-1)
        batch_len = padded_seq.size(0)
        # Define the mask tensor
        mask = torch.zeros((batch_len, padded_seq_len, padded_seq_len), dtype=torch.float32)
        
        # Set the non-padding elements to 1's
        for i in range(batch_len):
            mask[i, :, :org_seq_lens[i]] = 1
        
        
        return mask
        
    def count_params(self):
        
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
        
    def forward(self, x, org_seq_len):

        #print('shape of x', x.shape)
        x = torch.cat(x)
        
        features = self.model.features(x)
        if self.use_attention:
            features = self.spatial_attention(features)
        
        #print('features shape', features.shape)
        features = self.avgpool(features).squeeze(-1).permute(2, 1, 0)
        print('features shape', features.shape)
        #print('features shape', features.shape)
        tensor_list = torch.split(features, split_size_or_sections = org_seq_len, dim=2)
        
        features_padded = self.features_padding(tensor_list, self.max_len, org_seq_len)
        #print('features_padded shape', features_padded.shape)
        #print('features_padded shape', features_padded.shape)
        features_padded = features_padded.squeeze(1)
        #print('features_padded shape', features_padded.shape)
        mask = self.mask_sequence(features_padded, org_seq_lens = org_seq_len).cuda() if torch.cuda.is_available() else self.mask_sequence(features_padded, org_seq_lens = org_seq_len)
        #print('mask shape', mask.shape)
        
        if self.use_key_frame_attention:
            x = self.key_frame_attention(features_padded.permute(0, 2, 1), Mask = mask)
            x = self.model.classifier(x)
        else:
            #print(features_padded.shape)
            features_padded = features_padded.mean(dim = 1)
            #print(features_padded.shape)
            x = self.model.classifier(features_padded)
            #print('x shape in else', x.shape)
            #print(x.shape)
        if self.use_sigmoid:
            x = self.sigmoid(x)
        x = x.mean(dim = 1)
    #print('x', x.shape)
        return x


        
    
class EffnetV2_L(torch.nn.Module):
    def __init__(self, out_features = 7, in_channels = 1, dropout = 0.4, use_sigmoid = False, use_attention = False):
        super().__init__()
        
        self.use_sigmoid = use_sigmoid
        self.use_attention = use_attention
        self.dropout = dropout
        self.out_features = out_features
        self.in_channels = in_channels
        self.model = efficientnet_v2_l(weights = 'EfficientNet_V2_L_Weights.IMAGENET1K_V1')
        self.model.features[0] = torch.nn.Conv2d(self.in_channels, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        self.model.avgpool = torch.nn.Identity()
        self.model.classifier = torch.nn.Sequential(nn.Dropout(self.dropout), nn.Linear(1280, self.out_features))
        self.sigmoid = torch.nn.Sigmoid()
        if self.use_attention:
            self.spatial_attention = SpatialAttention(feature_map_size = 16, n_channels=1280)
        self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
        
        
    def count_params(self):
        
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
        
    def forward(self, x):
        if self.use_attention:
            x = self.model.features(x)
            x = self.spatial_attention(x)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.model.classifier(x)
        else:
            x = self.model.features(x)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.model.classifier(x)
        if self.use_sigmoid:
            x = self.sigmoid(x)
        return x


class EffnetV2_L_pos_encoding(torch.nn.Module):
    def __init__(self, out_features = 7, in_channels = 1):
        super().__init__()
        
        
        self.out_features = out_features
        self.in_channels = in_channels
        self.model = efficientnet_v2_l(weights = 'EfficientNet_V2_L_Weights.IMAGENET1K_V1')
        self.model.features[0] = torch.nn.Conv2d(self.in_channels, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        #self.model.classifier = torch.nn.Sequential(nn.Dropout(0.4), nn.Linear(1280, self.out_features))
        self.model.classifier = torch.nn.Identity()
        self.classifier = torch.nn.Sequential(nn.Dropout(0.4), nn.Linear(1280, self.out_features))
        
        
        # Positional embedding 
        max_len = 1000
        num_hiddens = 1280
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(
            -1, 1) / torch.pow(10000, torch.arange(
            0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)
        self.P = self.P.squeeze(0)
        
        
    def count_params(self):
        
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
        
    def forward(self, x):
        img = x[0]
        pos = x[1].cpu()
        features = self.model(img)
        
        positions = self.P[pos, :].cuda()
        out = features + positions
        out = self.classifier(out)
        return out, features
        




        
# test_tensor = torch.rand(1, 1, 448, 448).cuda()
# pos = torch.tensor(50).cuda()
# model = EffnetV2_L(out_features = 7, in_channels = 1).cuda()

# print(model((test_tensor, pos)).shape)


class EffnetV2_L_meta(torch.nn.Module):
    def __init__(self, out_features = 7, in_channels = 1, dropout = 0.4):
        super().__init__()
        
        
        self.dropout = dropout
        self.out_features = out_features
        self.in_channels = in_channels
        self.model = efficientnet_v2_l(weights = 'EfficientNet_V2_L_Weights.IMAGENET1K_V1')
        self.model.features[0] = torch.nn.Conv2d(self.in_channels, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        self.model.classifier = torch.nn.Identity()
        #self.model.classifier = torch.nn.Sequential(nn.Dropout(0.4), nn.Linear(1280, self.out_features))
        self.classifier = torch.nn.Sequential(nn.Linear(1280 + 4, 512),
                                             nn.BatchNorm1d(512),
                                             torch.nn.SiLU(),
                                             nn.Dropout(self.dropout),
                                             torch.nn.Linear(512, out_features = self.out_features),) # 1280 + 64 meta feature (days, frame_location)
        self.meta = torch.nn.Sequential(nn.Linear(2, 4),
                                        nn.BatchNorm1d(4),
                                        nn.SiLU(),)
                                        
    def count_params(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
    
    def forward(self, x):
        img = x[0]
        meta = x[1]
        meta = self.meta(meta)
        features = self.model(img)
        features = torch.cat([features, meta], dim = 1)
        out = self.classifier(features)
        return out
    
    
model = EffnetV2_Key_Frame(out_features = 1, in_channels = 1, dropout = 0.4, use_key_frame_attention=True)
split_sizes = [3, 2, 4, 4]
org_batch = [torch.rand(i, 1, 512, 512) for i in split_sizes]
print(model(org_batch, split_sizes).shape)

features shape torch.Size([1, 1280, 13])
torch.Size([4, 1])


In [3]:
model = effnetv2_l()


In [4]:
test_tensor = torch.randn(1, 1, 512, 512)

In [5]:
model(test_tensor)

torch.Size([1, 1280, 16, 16])


tensor([[ 9.9423e-03, -2.5810e-02, -1.0292e-03,  9.8495e-03,  1.2715e-02,
          7.1520e-03,  1.2836e-03,  1.2975e-03, -6.1024e-04, -2.2802e-02,
         -5.2287e-03,  2.5930e-02, -4.5033e-03,  3.7363e-03, -9.2677e-03,
         -1.3247e-03, -2.0822e-03, -2.6916e-03, -5.2659e-03, -6.9323e-03,
          1.5951e-02, -4.7287e-03, -3.4655e-03,  1.9892e-02, -5.3976e-03,
          2.0525e-02,  7.8521e-03, -3.6106e-04,  7.0998e-03,  1.2976e-02,
          2.3873e-02, -5.0131e-03,  9.7873e-03, -1.7515e-02, -2.5771e-03,
          6.6768e-03,  1.8742e-02,  2.1666e-03, -2.3195e-02, -2.2833e-02,
          5.6540e-03,  6.0438e-03, -4.6009e-03, -2.5070e-03,  1.7450e-02,
         -1.2427e-02,  2.4086e-03, -1.5538e-02, -8.4707e-03, -3.8669e-04,
          1.1338e-02, -1.0951e-02,  3.0101e-03, -4.0156e-03, -7.6427e-03,
          1.6752e-02, -5.6868e-03,  2.5995e-03, -2.1607e-02, -1.9860e-02,
          7.2404e-03,  3.4970e-03,  2.7539e-02, -6.5138e-03,  3.5483e-03,
          1.4101e-02, -8.6467e-03, -3.

In [21]:
from torchvision.models import efficientnet_v2_l

In [75]:
model = efficientnet_v2_l()


In [76]:
sum(p.numel() for p in model.parameters() if p.requires_grad)

118515272

In [23]:
print(model)

EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): FusedMBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
      (1): FusedMBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  