In [2]:
import torch 
import torch.nn as nn 

from typing import Callable

In [4]:
def num_groups(group_size: int | None, channels: int):
    if not group_size:  # 0 or None
        return 1  # normal conv with 1 group
    else:
        # NOTE group_size == 1 -> depthwise conv
        assert channels % group_size == 0
        return channels // group_size


def make_divisible(v: int, divisor: int = 8, min_value: int | None = None, round_limit: float = .9):
    min_value = min_value or divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < round_limit * v:
        new_v += divisor
    return new_v

In [6]:
class SqueezeExcite(nn.Module):
    """ Squeeze-and-Excitation w/ specific features for EfficientNet/MobileNet family

    Args:
        in_chs (int): input channels to layer
        rd_ratio (float): ratio of squeeze reduction
        act_layer (nn.Module): activation layer of containing block
        gate_layer (Callable): attention gate function
        force_act_layer (nn.Module): override block's activation fn if this is set/bound
        rd_round_fn (Callable): specify a fn to calculate rounding of reduced chs
    """

    def __init__(
            self,
            in_chs: int,
            rd_ratio: float = 0.25,
            rd_channels: int | None = None,
            act_layer: Callable = nn.ReLU,
            gate_layer: Callable = nn.Sigmoid,
            force_act_layer: Callable | None = None,
            rd_round_fn: Callable | None = None,
    ):
        super(SqueezeExcite, self).__init__()
        if rd_channels is None:
            rd_round_fn = rd_round_fn or round
            rd_channels = rd_round_fn(in_chs * rd_ratio)
        act_layer = force_act_layer or act_layer
        self.conv_reduce = nn.Conv3d(in_chs, rd_channels, 1, bias=True)
        self.act1 = act_layer()
        self.conv_expand = nn.Conv3d(rd_channels, in_chs, 1, bias=True)
        self.gate = gate_layer()

    def forward(self, x):
        x_se = x.mean((2, 3, 4), keepdim=True)
        x_se = self.conv_reduce(x_se)
        x_se = self.act1(x_se)
        x_se = self.conv_expand(x_se)
        return x * self.gate(x_se)

In [7]:
class LayerScale3d(nn.Module):
    """Learnable per-channel scaling applied along dim 1 of a 5D tensor.

    Args:
        dim: number of channels (length of the learnable ``gamma`` vector).
        init_values: initial value of every entry of ``gamma``.
        inplace: when True, the input tensor is scaled in place via ``mul_``.
    """

    def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(torch.ones(dim) * init_values)

    def forward(self, x):
        # Reshape gamma so it broadcasts over every dim except the channel dim.
        scale = self.gamma.view(1, -1, 1, 1, 1)
        if self.inplace:
            return x.mul_(scale)
        return x * scale

In [27]:
class UniversalInvertedResidual3d(nn.Module):
    """Inverted-residual block built from 3D convolutions with optional depthwise stages.

    Stages, in order:
        1. ``dw_start``: optional grouped conv + norm (no activation)
        2. ``pw_exp``: 1x1x1 expansion conv + norm + activation
        3. ``dw_mid``: optional grouped conv + norm + activation
        4. ``se``: optional squeeze-and-excitation
        5. ``pw_proj``: 1x1x1 projection conv + norm (no activation)
        6. ``dw_end``: optional grouped conv + norm + activation
        7. ``layer_scale``: optional learnable per-channel scale
        8. residual add with the (optionally downsampled) input when ``has_skip``

    The stride is applied by exactly one grouped conv: ``dw_mid`` if present,
    otherwise ``dw_start``, otherwise ``dw_end``.

    Args:
        in_chs: number of input channels.
        out_chs: number of output channels.
        dw_kernel_size_start: kernel size of the first grouped conv; 0 disables it.
        dw_kernel_size_mid: kernel size of the middle grouped conv; 0 disables it.
        dw_kernel_size_end: kernel size of the last grouped conv; 0 disables it.
        stride: an int (used for all three spatial dims) or a 3-tuple.
        dilation: dilation passed to every grouped conv.
        padding: padding passed to every grouped conv.
        group_size: channels per group for the grouped convs (1 -> depthwise,
            0/None -> regular conv).
        noskip: disable the residual connection even when shapes would allow it.
        exp_ratio: expansion ratio; the expanded width is
            ``make_divisible(in_chs * exp_ratio)``.
        act_layer: activation constructor.
        norm_layer: normalization constructor, called with a channel count.
        se_layer: when True, insert a ``SqueezeExcite`` after ``dw_mid``.
        conv_kwargs: accepted for interface compatibility; currently unused.
        layer_scale_init_value: initial layer-scale value; ``None`` disables layer scale.

    Raises:
        AssertionError: if any stride component is > 1 but all three grouped
            convs are disabled (nothing could apply the stride).
    """

    def __init__(
        self,
        in_chs: int,
        out_chs: int,
        dw_kernel_size_start: int = 0,
        dw_kernel_size_mid: int = 3,
        dw_kernel_size_end: int = 0,
        stride: int | tuple[int, int, int] = (2, 1, 1),
        dilation: int = 1,
        padding: int = 1,
        group_size: int = 1,
        noskip: bool = False,
        exp_ratio: float = 1.0,
        act_layer: Callable = nn.ReLU,
        norm_layer: Callable = nn.BatchNorm3d,
        se_layer: bool = False,
        conv_kwargs: dict | None = None,
        layer_scale_init_value: float | None = 1e-5,
    ):

        super().__init__()
        # Normalize the stride to a 3-tuple once; every later check works per component.
        # (Comparing the tuple itself against an int either raises or is always False.)
        if isinstance(stride, int):
            stride = (stride, stride, stride)
        else:
            stride = tuple(stride)
        is_strided = any(s > 1 for s in stride)

        # A residual connection needs matching channels, and the shortcut can only be
        # brought to the output resolution for strides of 1 or 2 along every dim.
        self.has_skip = in_chs == out_chs and all(s in (1, 2) for s in stride) and not noskip
        if is_strided:
            # Only the grouped convs carry the stride, so at least one must exist.
            assert dw_kernel_size_start or dw_kernel_size_mid or dw_kernel_size_end

        if dw_kernel_size_start:
            # The start conv only strides when there is no mid conv to do it.
            dw_start_stride = stride if not dw_kernel_size_mid else 1
            dw_start_groups = num_groups(group_size, in_chs)
            self.dw_start = nn.Sequential(
                nn.Conv3d(in_chs, in_chs, kernel_size=dw_kernel_size_start, stride=dw_start_stride, dilation=dilation, padding=padding, groups=dw_start_groups, bias=False),
                norm_layer(in_chs),
                # no activation
            )
        else:
            self.dw_start = nn.Identity()

        mid_chs = make_divisible(in_chs * exp_ratio)
        self.pw_exp = nn.Sequential(
            nn.Conv3d(in_chs, mid_chs, 1, bias=False),
            norm_layer(mid_chs),
            act_layer(),
        )

        if dw_kernel_size_mid:
            groups = num_groups(group_size, mid_chs)
            self.dw_mid = nn.Sequential(
                nn.Conv3d(mid_chs, mid_chs, dw_kernel_size_mid, stride=stride, dilation=dilation, padding=padding, groups=groups, bias=False),
                norm_layer(mid_chs),
                act_layer(),
            )
        else:
            # forward() always calls dw_mid, so it must exist even when disabled.
            self.dw_mid = nn.Identity()

        self.se = SqueezeExcite(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()

        self.pw_proj = nn.Sequential(
            nn.Conv3d(mid_chs, out_chs, 1, bias=False),
            norm_layer(out_chs),
            # no activation
        )

        if dw_kernel_size_end:
            # The end conv only strides when neither earlier grouped conv exists.
            dw_end_stride = stride if not dw_kernel_size_start and not dw_kernel_size_mid else 1
            dw_end_groups = num_groups(group_size, out_chs)
            self.dw_end = nn.Sequential(
                nn.Conv3d(out_chs, out_chs, dw_kernel_size_end, stride=dw_end_stride, dilation=dilation, padding=padding, groups=dw_end_groups, bias=False),
                norm_layer(out_chs),
                act_layer(),
            )
        else:
            self.dw_end = nn.Identity()

        if layer_scale_init_value is not None:
            self.layer_scale = LayerScale3d(out_chs, layer_scale_init_value)
        else:
            self.layer_scale = nn.Identity()

        if is_strided and self.has_skip:
            # Pool the shortcut with the same per-dim stride as the main branch so the
            # residual add sees matching shapes. ceil_mode keeps odd input sizes aligned
            # with a kernel-3 / padding-1 strided conv; for even sizes it changes nothing.
            self.downsample = nn.Sequential(
                nn.AvgPool3d(stride, stride, ceil_mode=True, count_include_pad=False),
                nn.Conv3d(in_chs, out_chs, 1, padding=0, bias=False),
                norm_layer(out_chs),
            )
        else:
            self.downsample = nn.Identity()

    def forward(self, x):
        """Run the block on ``x`` (5D input, channels on dim 1) and return the result."""
        shortcut = x
        x = self.dw_start(x)
        x = self.pw_exp(x)
        x = self.dw_mid(x)
        x = self.se(x)
        x = self.pw_proj(x)
        x = self.dw_end(x)
        x = self.layer_scale(x)
        if self.has_skip:
            x = x + self.downsample(shortcut)
        return x

In [28]:
# Example block: 64 -> 64 channels, 3x3x3 grouped convs at all three positions,
# stride 2, 4x expansion, GELU activations and squeeze-and-excitation enabled.
block = UniversalInvertedResidual3d(
    in_chs=64,
    out_chs=64,
    dw_kernel_size_start=3,
    dw_kernel_size_mid=3,
    dw_kernel_size_end=3,
    stride=2,
    dilation=1,
    padding=1,
    exp_ratio=4.0,
    act_layer=nn.GELU,
    norm_layer=nn.BatchNorm3d,
    se_layer=True,
)
block

UniversalInvertedResidual3d(
  (dw_start): Sequential(
    (0): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), groups=64, bias=False)
    (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (pw_exp): Sequential(
    (0): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
    (1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): GELU(approximate='none')
  )
  (dw_mid): Sequential(
    (0): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1), groups=256, bias=False)
    (1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): GELU(approximate='none')
  )
  (se): SqueezeExcite(
    (conv_reduce): Conv3d(256, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (act1): GELU(approximate='none')
    (conv_expand): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(1, 1, 1))
    (gate): Sigmoid()
  )
  (pw_proj): Se

In [29]:
# Shape check: push a random batch of 2 samples, 64 channels, 64^3 voxels through the block.
x = torch.randn(2, 64, 64, 64, 64)
block(x).shape

torch.Size([2, 64, 32, 32, 32])