In [1]:
import torch.nn as nn
import torch
from dataclasses import dataclass
from torchvision.models import efficientnet_v2_s
from torch import Tensor
from torchvision.ops.misc import Conv2dNormActivation, SqueezeExcitation    
from typing import Callable, Optional, Tuple, List, Dict, Any, Sequence, Union
import math
import copy
from torchvision.ops import StochasticDepth
from functools import partial

In [2]:
import torch

def _make_divisible(v, divisor=8, min_value=None):
    """
    This function takes an input value `v` and ensures that it is divisible by
    the specified `divisor`. Optionally, a `min_value` can be provided to ensure
    that the returned value is not less than a certain minimum value.
    """
    if min_value is None:
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v


In [3]:
@dataclass
class _MBConvConfig:
    expand_ratio: float
    kernel: int
    stride: int
    input_channels: int
    out_channels: int
    num_layers: int
    block: Callable[..., nn.Module]

    @staticmethod
    def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int:
        return _make_divisible(channels * width_mult, 8, min_value)


class MBConvConfig(_MBConvConfig):
    """Config for an MBConv stage with width/depth multipliers applied.

    Stores information listed at Table 1 of the EfficientNet paper & Table 4 of the
    EfficientNetV2 paper. Channel counts are scaled by `width_mult` (rounded with
    `adjust_channels`), the layer count by `depth_mult` (rounded up with
    `adjust_depth`). `block` defaults to `MBConv`.
    """

    def __init__(
        self,
        expand_ratio: float,
        kernel: int,
        stride: int,
        input_channels: int,
        out_channels: int,
        num_layers: int,
        width_mult: float = 1.0,
        depth_mult: float = 1.0,
        block: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        scaled_in = self.adjust_channels(input_channels, width_mult)
        scaled_out = self.adjust_channels(out_channels, width_mult)
        scaled_layers = self.adjust_depth(num_layers, depth_mult)
        block_factory = MBConv if block is None else block
        super().__init__(
            expand_ratio=expand_ratio,
            kernel=kernel,
            stride=stride,
            input_channels=scaled_in,
            out_channels=scaled_out,
            num_layers=scaled_layers,
            block=block_factory,
        )

    @staticmethod
    def adjust_depth(num_layers: int, depth_mult: float):
        """Scale a stage's layer count by `depth_mult`, rounding up to an int."""
        scaled = num_layers * depth_mult
        return int(math.ceil(scaled))


In [4]:
class FusedMBConvConfig(_MBConvConfig):
    """Config for a Fused-MBConv stage (Table 4 of the EfficientNetV2 paper).

    Unlike `MBConvConfig`, no width/depth multipliers are applied: channel and
    layer counts are stored exactly as given. `block` defaults to `FusedMBConv`.
    """

    def __init__(
        self,
        expand_ratio: float,
        kernel: int,
        stride: int,
        input_channels: int,
        out_channels: int,
        num_layers: int,
        block: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super().__init__(
            expand_ratio=expand_ratio,
            kernel=kernel,
            stride=stride,
            input_channels=input_channels,
            out_channels=out_channels,
            num_layers=num_layers,
            block=FusedMBConv if block is None else block,
        )


class MBConv(nn.Module):
    """Mobile inverted bottleneck: [1x1 expand] -> depthwise kxk -> SE -> 1x1 linear project.

    The residual connection (with stochastic depth on the branch) is used only
    when the block has stride 1 and keeps the channel count.

    Args:
        cnf: Stage configuration (expand_ratio, kernel, stride, channel counts).
        stochastic_depth_prob: Probability passed to `StochasticDepth` in "row" mode.
        norm_layer: Normalization layer factory used by every conv.
        se_layer: Squeeze-and-excitation factory, called as
            `se_layer(expanded_channels, squeeze_channels, activation=...)`.

    Raises:
        ValueError: If `cnf.stride` is outside [1, 2].
    """

    def __init__(
        self,
        cnf: MBConvConfig,
        stochastic_depth_prob: float,
        norm_layer: Callable[..., nn.Module],
        se_layer: Callable[..., nn.Module] = SqueezeExcitation,
    ) -> None:
        super().__init__()

        if not 1 <= cnf.stride <= 2:
            raise ValueError("illegal stride value")

        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

        act = nn.SiLU
        hidden = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
        stages: List[nn.Module] = []

        # 1x1 pointwise expansion; skipped when expand_ratio leaves the width unchanged.
        if hidden != cnf.input_channels:
            stages.append(
                Conv2dNormActivation(
                    cnf.input_channels, hidden, kernel_size=1, norm_layer=norm_layer, activation_layer=act
                )
            )

        # Depthwise conv (groups == channels); this is where the stride is applied.
        stages.append(
            Conv2dNormActivation(
                hidden,
                hidden,
                kernel_size=cnf.kernel,
                stride=cnf.stride,
                groups=hidden,
                norm_layer=norm_layer,
                activation_layer=act,
            )
        )

        # SE bottleneck width is a quarter of the *input* width, at least 1.
        stages.append(se_layer(hidden, max(1, cnf.input_channels // 4), activation=partial(nn.SiLU, inplace=True)))

        # Linear 1x1 projection (no activation).
        stages.append(
            Conv2dNormActivation(hidden, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None)
        )

        self.block = nn.Sequential(*stages)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
        self.out_channels = cnf.out_channels

    def forward(self, input: Tensor) -> Tensor:
        out = self.block(input)
        if not self.use_res_connect:
            return out
        out = self.stochastic_depth(out)
        out += input
        return out


class FusedMBConv(nn.Module):
    """Fused-MBConv block: MBConv's 1x1 expand + depthwise pair replaced by one regular kxk conv.

    If expand_ratio changes the width: kxk conv (expand, strided, activated) -> 1x1 linear projection.
    Otherwise: a single kxk conv (strided, activated) maps input_channels -> out_channels.
    The residual connection (with stochastic depth on the branch) is used only
    when the block has stride 1 and keeps the channel count.

    Args:
        cnf: Stage configuration (expand_ratio, kernel, stride, channel counts).
        stochastic_depth_prob: Probability passed to `StochasticDepth` in "row" mode.
        norm_layer: Normalization layer factory used by every conv.

    Raises:
        ValueError: If `cnf.stride` is outside [1, 2].
    """

    def __init__(
        self,
        cnf: FusedMBConvConfig,
        stochastic_depth_prob: float,
        norm_layer: Callable[..., nn.Module],
    ) -> None:
        super().__init__()

        if not 1 <= cnf.stride <= 2:
            raise ValueError("illegal stride value")

        self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels

        act = nn.SiLU
        hidden = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio)
        stages: List[nn.Module]

        if hidden == cnf.input_channels:
            # No expansion: one kxk conv goes straight to the output width.
            stages = [
                Conv2dNormActivation(
                    cnf.input_channels,
                    cnf.out_channels,
                    kernel_size=cnf.kernel,
                    stride=cnf.stride,
                    norm_layer=norm_layer,
                    activation_layer=act,
                )
            ]
        else:
            stages = [
                # fused expand: full kxk conv instead of 1x1 + depthwise
                Conv2dNormActivation(
                    cnf.input_channels,
                    hidden,
                    kernel_size=cnf.kernel,
                    stride=cnf.stride,
                    norm_layer=norm_layer,
                    activation_layer=act,
                ),
                # linear 1x1 projection (no activation)
                Conv2dNormActivation(
                    hidden, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None
                ),
            ]

        self.block = nn.Sequential(*stages)
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
        self.out_channels = cnf.out_channels

    def forward(self, input: Tensor) -> Tensor:
        out = self.block(input)
        if not self.use_res_connect:
            return out
        out = self.stochastic_depth(out)
        out += input
        return out

In [5]:
class EfficientNet(nn.Module):
    def __init__(
        self,
        in_channels: int,
        inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
        stochastic_depth_prob: float = 0.2,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        last_channel: Optional[int] = None,
    ) -> None:
        """
        EfficientNet V1 and V2 feature extractor (no pooling or classifier head).

        Layout: 3x3 stem conv (stride 1) -> one nn.Sequential per entry of
        `inverted_residual_setting` -> 1x1 conv head. `forward` returns the
        head's feature map.

        Args:
            in_channels (int): Number of channels of the input tensor.
            inverted_residual_setting (Sequence[Union[MBConvConfig, FusedMBConvConfig]]): Network structure,
                one config per stage.
            stochastic_depth_prob (float): The stochastic depth probability; block i of N total
                blocks uses stochastic_depth_prob * i / N.
            norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use.
                Defaults to nn.BatchNorm2d.
            last_channel (Optional[int]): Output channels of the final 1x1 conv. Defaults to
                4 * inverted_residual_setting[-1].out_channels.

        Raises:
            ValueError: If `inverted_residual_setting` is empty.
            TypeError: If it is not a Sequence of `_MBConvConfig` instances.
        """
        super().__init__()

        if not inverted_residual_setting:
            raise ValueError("The inverted_residual_setting should not be empty")
        elif not (
            isinstance(inverted_residual_setting, Sequence)
            and all(isinstance(s, _MBConvConfig) for s in inverted_residual_setting)
        ):
            raise TypeError("The inverted_residual_setting should be List[MBConvConfig]")

        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        layers: List[nn.Module] = []

        # building first layer (stride 1, so the stem keeps the input resolution)
        firstconv_output_channels = inverted_residual_setting[0].input_channels
        layers.append(
            Conv2dNormActivation(
                in_channels, firstconv_output_channels, kernel_size=3, stride=1, norm_layer=norm_layer, activation_layer=nn.SiLU
            )
        )

        # building inverted residual blocks
        total_stage_blocks = sum(cnf.num_layers for cnf in inverted_residual_setting)
        stage_block_id = 0
        for cnf in inverted_residual_setting:
            stage: List[nn.Module] = []
            for _ in range(cnf.num_layers):
                # copy to avoid modifications. shallow copy is enough
                block_cnf = copy.copy(cnf)

                # Only the first block of a stage changes width/stride; the rest
                # map out_channels -> out_channels at stride 1.
                if stage:
                    block_cnf.input_channels = block_cnf.out_channels
                    block_cnf.stride = 1

                # stochastic depth probability grows linearly with the block's global index
                sd_prob = stochastic_depth_prob * float(stage_block_id) / total_stage_blocks

                stage.append(block_cnf.block(block_cnf, sd_prob, norm_layer))
                stage_block_id += 1

            layers.append(nn.Sequential(*stage))

        # building last several layers
        lastconv_input_channels = inverted_residual_setting[-1].out_channels
        lastconv_output_channels = last_channel if last_channel is not None else 4 * lastconv_input_channels
        layers.append(
            Conv2dNormActivation(
                lastconv_input_channels,
                lastconv_output_channels,
                kernel_size=1,
                norm_layer=norm_layer,
                activation_layer=nn.SiLU,
            )
        )

        self.features = nn.Sequential(*layers)

        # Weight init. The nn.Linear branch is kept from the original classifier
        # setup; NOTE(review): presumably unused now that the head is removed.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                init_range = 1.0 / math.sqrt(m.out_features)
                nn.init.uniform_(m.weight, -init_range, init_range)
                nn.init.zeros_(m.bias)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # Feature trunk only: no avgpool / classifier, so a spatial feature map is returned.
        return self.features(x)

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)

In [6]:
def _efficientnet(
    inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]],
    dropout: float,
    last_channel: Optional[int],
    progress: bool,
    in_channels: int = 3,
    **kwargs: Any,
) -> EfficientNet:
    """Build an `EfficientNet` feature extractor from a stage configuration.

    Args:
        inverted_residual_setting: Per-stage block configs.
        dropout: Unused. This `EfficientNet` has no classifier head, so there is
            no dropout to configure; the parameter is kept so existing call sites
            still work. It is NOT a stochastic-depth probability; pass
            `stochastic_depth_prob=` through kwargs for that.
        last_channel: Output channels of the final 1x1 conv (None means 4x the
            last stage's out_channels).
        progress: Unused (no pretrained weights are loaded here).
        in_channels: Number of channels of the input tensor. Defaults to 3.
        **kwargs: Forwarded to `EfficientNet` (e.g. `stochastic_depth_prob`, `norm_layer`).

    Returns:
        The constructed model.
    """
    # EfficientNet's first positional parameter is the input channel count,
    # followed by the stage list; pass them in that order.
    model = EfficientNet(
        in_channels,
        inverted_residual_setting,
        last_channel=last_channel,
        **kwargs,
    )

    return model


In [7]:
# expand_ratio: float,
# kernel: int,
# stride: int,
# input_channels: int,
# out_channels: int,
# num_layers: int,

In [8]:
# Look up a class object by name in the notebook's global namespace.
class_name = "FusedMBConvConfig"  # Example string representation of a class

class_obj = globals().get(class_name)
# Prints the class (e.g. <class '__main__.FusedMBConvConfig'>) or a not-found message.
print(class_obj if class_obj else f"Class '{class_name}' not found.")


<class '__main__.FusedMBConvConfig'>


In [9]:

def create_inverted_residual_setting(config_dict, registry: Optional[Dict[str, Any]] = None):
    """Instantiate stage configs from a name-keyed description.

    Args:
        config_dict: Mapping of stage name -> {"class_name": str, "args": sequence,
            "kwargs": dict}. "args" and "kwargs" are optional and default to empty.
            Stages are built in the mapping's iteration order.
        registry: Optional mapping of class name -> class used to resolve
            "class_name". When None, the name is looked up in this module's
            globals and only `_MBConvConfig` subclasses are accepted, so a
            config entry cannot invoke an arbitrary global callable.

    Returns:
        List of constructed config objects, in order. Entries whose class cannot
        be resolved are reported with a printed message and skipped.
    """
    inverted_residual_setting = []
    for config_value in config_dict.values():
        class_name = config_value['class_name']
        if registry is not None:
            config_cls = registry.get(class_name)
        else:
            candidate = globals().get(class_name)
            is_config_class = isinstance(candidate, type) and issubclass(candidate, _MBConvConfig)
            config_cls = candidate if is_config_class else None

        if config_cls is None:
            print(f"Class '{class_name}' not found.")
            continue

        args = config_value.get('args', ())
        kwargs = config_value.get('kwargs', {})
        inverted_residual_setting.append(config_cls(*args, **kwargs))
    return inverted_residual_setting


In [10]:
# One Fused-MBConv stage per row:
# (expand_ratio, kernel, stride, input_channels, out_channels, num_layers)
inverted_residual_setting = [
    FusedMBConvConfig(*stage)
    for stage in (
        (1, 3, 1, 24, 24, 1),
        (2, 3, 2, 24, 48, 1),
        (2, 3, 1, 48, 96, 1),
        (2, 3, 1, 96, 128, 2),
    )
]

In [15]:
# Seed torch's global RNG so this random input (and any random weight init run
# after this cell) is reproducible under Restart & Run All.
torch.manual_seed(0)
test_tensor = torch.randn(1, 1, 100, 100)  # (batch, channels, height, width) for the 1-channel model

In [16]:
model = EfficientNet(in_channels=1, inverted_residual_setting=inverted_residual_setting)

In [17]:
model(test_tensor).size()

torch.Size([1, 512, 50, 50])

In [18]:
print(str(model))

EfficientNet(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(1, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): FusedMBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      )
    )
    (2): Sequential(
      (0): FusedMBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(24, 48, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
            (1): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, t

In [None]:
import torch
import torch.nn as nn
import torch.nn.init as init
import numpy as np
import torch.nn.functional as F
import pdb
import random

from ..base_model import BaseModel
from ..modules import SetBlockWrapper, HorizontalPoolingPyramid, PackSequenceWrapper, SeparateFCs, SeparateBNNecks, conv1d, mlp_sigmoid, conv_bn, SetBlock, MSTE, ATA, SSFL, BasicConv1d, BasicConv2d, TemporalShift



class Cell_Model(torch.nn.Module):
    """Gait model built on a per-frame EfficientNet backbone (work in progress, CSTL-style head).

    Args:
        inverted_residual_setting: Stage configs passed to the EfficientNet backbone.
        in_channels: Channels per frame fed to the backbone. Defaults to 3.

    NOTE(review): `forward` uses `self.multi_scale`, `self.adaptive_aggregation`,
    `self.salient_learning` and `self.FCs`, none of which are created in
    `__init__`; they must be added (presumably the CSTL modules — TODO confirm)
    before `forward` can run.
    """

    def __init__(self, inverted_residual_setting, in_channels=3):
        super().__init__()
        # backbone: a 2-D CNN applied to every frame independently (see forward)
        self.backbone = EfficientNet(in_channels, inverted_residual_setting)

    def forward(self, inputs):
        ipts, labs, _, _, seqL = inputs
        sils = ipts[0]
        del ipts

        # Bring the input to n, s, c, h, w (in cstl paper: B x N x H x W - batch, num of frames, height, width).
        if sils.dim() == 4:
            # assumes a 4-D input is n, s, h, w (no channel axis) — TODO confirm
            sils = sils.unsqueeze(2)
        else:
            # n, c, s, h, w -> n, s, c, h, w
            sils = sils.permute(0, 2, 1, 3, 4).contiguous()
        n, s, c, h, w = sils.size()

        # The backbone is a 2-D CNN: fold frames into the batch axis, run it,
        # then unfold back to n, s, c', h', w'.
        x = self.backbone(sils.view(n * s, c, h, w))
        _, c_out, h_out, w_out = x.size()
        x = x.view(n, s, c_out, h_out, w_out)

        x = x.max(-1)[0] + x.mean(-1)  # Global Max pooling + Global Average Pooling over width
        # n, s, c, k (K: as in paper: the number of horizontal division feature parts that correspond to body parts in some extent.)
        t_f, t_s, t_l = self.multi_scale(x)  # n, s, c, K
        aggregated_feature = self.adaptive_aggregation(t_f, t_s, t_l)  # K, n, c
        part_classification, weighted_part_feature, selected_part_feature = self.salient_learning(t_f, t_s, t_l)  # K, n, c
        feature = torch.cat([aggregated_feature, weighted_part_feature, selected_part_feature], -1)
        feature = feature.matmul(self.FCs)  # p, n, c (assuming the features above are K, n, c)
        feature = feature.permute(1, 2, 0).contiguous()  # n, c, p
        retval = {
            'training_feat': {
                'triplet': {'embeddings': feature, 'labels': labs},
                'cstl_cross_entropy': {'part_prob': part_classification, 'label': labs}
            },
            'visual_summary': {
                'image/sils': sils.view(n*s, c, h, w)
            },
            'inference_feat': {
                'embeddings': feature
            }
        }
        return retval


In [19]:
x = torch.randn((1, 30, 3, 100, 100))

In [20]:
x.max(dim=-1).values.shape

torch.Size([1, 30, 3, 100])