In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import timm
import torch
from unittest.mock import patch

from src.model.swin_transformer_v2_pseudo_3d import SwinTransformerV2Pseudo3d, map_pretrained_2d_to_pseudo_3d

In [20]:
# Reference 2D Swin V2 backbone from timm with pretrained weights; `features_only=True`
# requests the feature-extraction variant of the model rather than the classifier.
model_2d = timm.create_model(
    'swinv2_tiny_window8_256.ms_in1k', 
    features_only=True,
    pretrained=True,
)
# Smoke test: one forward pass on a random batch of a single 3x256x256 image.
x = torch.randn(1, 3, 256, 256)
y = model_2d(x)

In [4]:
# Temporarily replace timm's `SwinTransformerV2` class with the pseudo-3D variant so that
# `timm.create_model` builds the same named architecture from `SwinTransformerV2Pseudo3d`.
# No pretrained weights are loaded here (pretrained=False).
with patch('timm.models.swin_transformer_v2.SwinTransformerV2', SwinTransformerV2Pseudo3d):
    model_pseudo_3d = timm.create_model(
        'swinv2_tiny_window8_256.ms_in1k', 
        features_only=True,
        pretrained=False,
        window_size=(8, 8, 16),
        img_size=(256, 256, 64),
    )
# Smoke test: one forward pass on a random volume whose spatial size equals `img_size`.
x = torch.randn(1, 3, 256, 256, 64)
y = model_pseudo_3d(x)

In [5]:
# Walk the 2D state dict and report, for every entry, whether the pseudo-3D model has an
# entry of the same name and whether the two shapes agree.
model_2d_state_dict = model_2d.state_dict()
model_pseudo_3d_state_dict = model_pseudo_3d.state_dict()
for key, value in model_2d_state_dict.items():
    if key not in model_pseudo_3d_state_dict:
        print(f'{key}: {value.shape} -> NOT FOUND')
        continue
    target_shape = model_pseudo_3d_state_dict[key].shape
    if value.shape != target_shape:
        print(f'{key}: {value.shape} -> {target_shape}')
    else:
        print(f'{key}: {value.shape} -> OK')

patch_embed.proj.weight: torch.Size([96, 3, 4, 4]) -> torch.Size([96, 3, 4, 4, 4])
patch_embed.proj.bias: torch.Size([96]) -> OK
patch_embed.norm.weight: torch.Size([96]) -> OK
patch_embed.norm.bias: torch.Size([96]) -> OK
layers_0.blocks.0.attn.logit_scale: torch.Size([3, 1, 1]) -> OK
layers_0.blocks.0.attn.q_bias: torch.Size([96]) -> OK
layers_0.blocks.0.attn.v_bias: torch.Size([96]) -> OK
layers_0.blocks.0.attn.cpb_mlp.0.weight: torch.Size([512, 2]) -> torch.Size([512, 3])
layers_0.blocks.0.attn.cpb_mlp.0.bias: torch.Size([512]) -> OK
layers_0.blocks.0.attn.cpb_mlp.2.weight: torch.Size([3, 512]) -> OK
layers_0.blocks.0.attn.qkv.weight: torch.Size([288, 96]) -> OK
layers_0.blocks.0.attn.proj.weight: torch.Size([96, 96]) -> OK
layers_0.blocks.0.attn.proj.bias: torch.Size([96]) -> OK
layers_0.blocks.0.norm1.weight: torch.Size([96]) -> OK
layers_0.blocks.0.norm1.bias: torch.Size([96]) -> OK
layers_0.blocks.0.mlp.fc1.weight: torch.Size([384, 96]) -> OK
layers_0.blocks.0.mlp.fc1.bias: tor

The mismatches are the weights of the `patch_embed.proj` (Conv2d -> Conv3d) and `layers_0.blocks.0.attn.cpb_mlp.0` (first layer of the relative position bias MLP, which gains an input for the Z dim) layers; the bias shapes of both layers match. 

- Conv layer's weight: `torch.Size([96, 3, 4, 4]) -> torch.Size([96, 3, 4, 4, 4])`

- MLP's weight: `torch.Size([512, 2]) -> torch.Size([512, 3])`

For the conv layer, the proposal is to repeat the 2D weights along the new third spatial (Z) dimension, scale them down by the patch size along the Z dim (4), and keep the bias term intact. E.g., if an image is simply repeated along the Z dim, the 3D patch embedding then equals the 2D patch embedding of the non-repeated patch.

For the relative position bias MLP, the proposal is to compute the weights for the new dimension as the mean of the weights of the two existing dimensions and to keep the bias intact. No comparable invariance holds in this case.

**Note**: additional investigation is needed into whether the low rank of the resulting weights is a problem.

In [6]:
# Map the pretrained 2D model onto the pseudo-3D one with the project helper; per the recorded
# output below it reports the parameters whose shapes differ between the two models.
model = map_pretrained_2d_to_pseudo_3d(model_2d, model_pseudo_3d)

patch_embed.proj.weight: torch.Size([96, 3, 4, 4]) -> torch.Size([96, 3, 4, 4, 4])
layers_0.blocks.0.attn.cpb_mlp.0.weight: torch.Size([512, 2]) -> torch.Size([512, 3])


In [28]:

# Forward a random volume through the mapped model and list the shape of every returned
# feature map (the recorded output shows four channels-last maps).
x = torch.randn(1, 3, 256, 256, 64)
y = model(x)
[y_.shape for y_ in y]

[torch.Size([1, 64, 64, 96]),
 torch.Size([1, 32, 32, 192]),
 torch.Size([1, 16, 16, 384]),
 torch.Size([1, 8, 8, 768])]

In [119]:
from timm.layers.format import nhwc_to, Format
from torch import nn

class FeatureExtractorWrapper(nn.Module):
    """Adapter that lets a feature extractor act as the encoder of a segmentation model.

    Every feature map returned by the wrapped model is passed through timm's ``nhwc_to``
    with the ``NCHW`` target format, so downstream 2D convolutions receive channels-first
    tensors.

    Attributes:
        model: The wrapped feature extractor.
        output_stride: Hard-coded to 32; read by ``SegmentationModel.check_input_shape``.
    """

    def __init__(self, model):
        super().__init__()
        self.output_stride = 32
        self.model = model

    def __iter__(self):
        # Delegate iteration to the wrapped model so name-based helpers see the same entries.
        return iter(self.model)

    def forward(self, x):
        """Run the wrapped model on `x` and return its feature maps converted to NCHW."""
        converted = []
        for feature_map in self.model(x):
            converted.append(nhwc_to(feature_map, Format('NCHW')))
        return converted

In [61]:
def get_num_layers(model):
    """Count the entries yielded by iterating `model` whose name contains 'layers'."""
    return sum(1 for name in model if 'layers' in name)
# Sanity check: the raw model and its wrapper should report the same count.
get_num_layers(model), get_num_layers(FeatureExtractorWrapper(model))

(4, 4)

In [63]:
def get_feature_channels(model, input_shape):
    """Probe `model` with a random batch and return the channel count of each feature map.

    The model is switched to eval mode for the probe and its previous train/eval mode is
    restored afterwards, even if the forward pass raises. The probe runs under
    ``torch.no_grad()`` because only the output shapes are needed.

    Args:
        model: Module returning an iterable of tensors. Outputs of a
            ``FeatureExtractorWrapper`` are read with the channel axis at index 1; outputs of
            any other model are read with the channel axis at index 3.
        input_shape: Input shape without the batch dimension, e.g. ``(3, 256, 256, 64)``.

    Returns:
        Tuple with one channel count per returned feature map, in the model's output order.
    """
    was_training = model.training
    model.eval()
    try:
        # Shapes only: skip autograd bookkeeping for the throw-away forward pass.
        with torch.no_grad():
            features = model(torch.randn(1, *input_shape))
    finally:
        # Put the model back into whatever mode the caller had it in.
        model.train(was_training)

    channel_index = 1 if isinstance(model, FeatureExtractorWrapper) else 3
    return tuple(feature.shape[channel_index] for feature in features)
# Sanity check: channel counts read from the raw outputs (channel axis 3) and from the
# wrapped outputs (channel axis 1) should be identical.
get_feature_channels(model, (3, 256, 256, 64)), \
get_feature_channels(FeatureExtractorWrapper(model), (3, 256, 256, 64))

((96, 192, 384, 768), (96, 192, 384, 768))

In [148]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Union
try:
    from inplace_abn import InPlaceABN
except ImportError:
    InPlaceABN = None


def initialize_decoder(module):
    """Re-initialise, in place, the layers found under `module`.

    * ``nn.Conv2d``: Kaiming-uniform weights (fan-in, ReLU gain), zero bias.
    * ``nn.BatchNorm2d``: weight set to 1, bias set to 0.
    * ``nn.Linear``: Xavier-uniform weights, zero bias.

    Any other module type is left untouched.
    """
    for layer in module.modules():
        if isinstance(layer, nn.BatchNorm2d):
            nn.init.constant_(layer.weight, 1)
            nn.init.constant_(layer.bias, 0)
            continue

        if isinstance(layer, nn.Conv2d):
            nn.init.kaiming_uniform_(layer.weight, mode="fan_in", nonlinearity="relu")
        elif isinstance(layer, nn.Linear):
            nn.init.xavier_uniform_(layer.weight)
        else:
            continue

        # Conv and linear layers share the same bias treatment.
        if layer.bias is not None:
            nn.init.constant_(layer.bias, 0)


def initialize_head(module):
    """Re-initialise, in place, every ``nn.Linear`` / ``nn.Conv2d`` found under `module`.

    Weights are drawn with Xavier-uniform initialisation and biases, when present, are
    zeroed. Other module types are left untouched.
    """
    head_layer_types = (nn.Linear, nn.Conv2d)
    for layer in module.modules():
        if not isinstance(layer, head_layer_types):
            continue
        nn.init.xavier_uniform_(layer.weight)
        if layer.bias is not None:
            nn.init.constant_(layer.bias, 0)


class Conv2dReLU(nn.Sequential):
    """Sequential stack of a 2D convolution, a normalisation layer and an activation.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution (and of the normalisation layer).
        kernel_size: Convolution kernel size.
        padding: Convolution padding.
        stride: Convolution stride.
        use_batchnorm: Truthy -> ``nn.BatchNorm2d`` followed by an in-place ReLU; falsy ->
            no normalisation (``nn.Identity``) followed by an in-place ReLU; the string
            ``"inplace"`` -> ``InPlaceABN`` with the activation fused into it. The
            convolution has a bias only when this is falsy.

    Raises:
        RuntimeError: If ``use_batchnorm == "inplace"`` but ``inplace_abn`` is not installed.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        padding=0,
        stride=1,
        use_batchnorm=True,
    ):
        wants_inplace_abn = use_batchnorm == "inplace"
        if wants_inplace_abn and InPlaceABN is None:
            raise RuntimeError(
                "In order to use `use_batchnorm='inplace'` inplace_abn package must be installed. "
                "To install see: https://github.com/mapillary/inplace_abn"
            )

        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            bias=not (use_batchnorm),
        )

        if wants_inplace_abn:
            # InPlaceABN applies the activation itself, so no separate ReLU is needed.
            normalisation = InPlaceABN(out_channels, activation="leaky_relu", activation_param=0.0)
            activation = nn.Identity()
        else:
            normalisation = nn.BatchNorm2d(out_channels) if use_batchnorm else nn.Identity()
            activation = nn.ReLU(inplace=True)

        super().__init__(convolution, normalisation, activation)

class ArgMax(nn.Module):
    """Module form of ``torch.argmax`` over a fixed dimension.

    Args:
        dim: Dimension to reduce; ``None`` (default) returns the index into the flattened input.
    """

    def __init__(self, dim=None):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        return x.argmax(dim=self.dim)


class Clamp(nn.Module):
    """Module form of ``torch.clamp`` with fixed bounds.

    Args:
        min: Lower bound (default 0).
        max: Upper bound (default 1).
    """

    def __init__(self, min=0, max=1):
        super().__init__()
        self.min = min
        self.max = max

    def forward(self, x):
        return x.clamp(min=self.min, max=self.max)


class Activation(nn.Module):
    """Wrap an output activation that is selected by name or supplied as a factory.

    Args:
        name: ``None`` or one of ``"identity"``, ``"sigmoid"``, ``"softmax2d"``, ``"softmax"``,
            ``"logsoftmax"``, ``"tanh"``, ``"argmax"``, ``"argmax2d"``, ``"clamp"``;
            alternatively any callable, which is invoked as ``name(**params)`` to build the
            module.
        **params: Keyword arguments forwarded to the selected constructor. ``"sigmoid"`` and
            ``"tanh"`` ignore them; ``"softmax2d"`` and ``"argmax2d"`` fix ``dim=1``.

    Raises:
        ValueError: If ``name`` is neither ``None``, a known key, nor callable.
    """

    def __init__(self, name, **params):

        super().__init__()

        # Lazy factories: only the requested activation is ever constructed.
        factories = {
            "identity": lambda: nn.Identity(**params),
            "sigmoid": lambda: nn.Sigmoid(),
            "softmax2d": lambda: nn.Softmax(dim=1, **params),
            "softmax": lambda: nn.Softmax(**params),
            "logsoftmax": lambda: nn.LogSoftmax(**params),
            "tanh": lambda: nn.Tanh(),
            "argmax": lambda: ArgMax(**params),
            "argmax2d": lambda: ArgMax(dim=1, **params),
            "clamp": lambda: Clamp(**params),
        }

        if name is None:
            self.activation = nn.Identity(**params)
        elif isinstance(name, str) and name in factories:
            self.activation = factories[name]()
        elif callable(name):
            self.activation = name(**params)
        else:
            raise ValueError(
                f"Activation should be callable/sigmoid/softmax/logsoftmax/tanh/"
                f"argmax/argmax2d/clamp/None; got {name}"
            )

    def forward(self, x):
        return self.activation(x)

class SCSEModule(nn.Module):
    """Sum of a channel-gated and a spatially-gated copy of the input.

    ``cSE`` pools the input to 1x1, squeezes it to ``in_channels // reduction`` channels and
    expands it back through a sigmoid, giving one weight per channel. ``sSE`` is a 1x1
    convolution to a single channel through a sigmoid, giving one weight per location.

    Args:
        in_channels: Channels of the input tensor.
        reduction: Divisor applied to ``in_channels`` for the squeezed width.
    """

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        squeezed_channels = in_channels // reduction
        channel_gate = [
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, squeezed_channels, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed_channels, in_channels, 1),
            nn.Sigmoid(),
        ]
        self.cSE = nn.Sequential(*channel_gate)
        spatial_gate = [nn.Conv2d(in_channels, 1, 1), nn.Sigmoid()]
        self.sSE = nn.Sequential(*spatial_gate)

    def forward(self, x):
        channel_weights = self.cSE(x)
        spatial_weights = self.sSE(x)
        return x * channel_weights + x * spatial_weights

class Attention(nn.Module):
    """Attention layer selected by name.

    Args:
        name: ``None`` for a pass-through ``nn.Identity`` or ``"scse"`` for ``SCSEModule``.
        **params: Keyword arguments forwarded to the selected constructor.

    Raises:
        ValueError: If ``name`` is anything else.
    """

    def __init__(self, name, **params):
        super().__init__()

        # Reject unknown names up front, then pick between the two supported layers.
        if name is not None and name != "scse":
            raise ValueError("Attention {} is not implemented".format(name))

        self.attention = nn.Identity(**params) if name is None else SCSEModule(**params)

    def forward(self, x):
        return self.attention(x)


class SegmentationHead(nn.Sequential):
    """Final projection to mask channels: convolution, optional upsampling, activation.

    Args:
        in_channels: Channels of the decoder output.
        out_channels: Number of mask channels.
        kernel_size: Convolution kernel size; padding is ``kernel_size // 2``.
        activation: Passed to ``Activation``.
        upsampling: Bilinear upsampling factor, applied only when greater than 1.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1):
        half_kernel = kernel_size // 2
        projection = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=half_kernel)
        if upsampling > 1:
            resize = nn.UpsamplingBilinear2d(scale_factor=upsampling)
        else:
            resize = nn.Identity()
        super().__init__(projection, resize, Activation(activation))


class ClassificationHead(nn.Sequential):
    """Auxiliary classifier: global pooling, flatten, optional dropout, linear, activation.

    Args:
        in_channels: Channels of the feature map being classified.
        classes: Number of output classes.
        pooling: ``"avg"`` or ``"max"`` global pooling.
        dropout: Dropout probability; a falsy value disables dropout.
        activation: Passed to ``Activation``.

    Raises:
        ValueError: If ``pooling`` is not ``"max"`` or ``"avg"``.
    """

    def __init__(self, in_channels, classes, pooling="avg", dropout=0.2, activation=None):
        if pooling not in ("max", "avg"):
            raise ValueError("Pooling should be one of ('max', 'avg'), got {}.".format(pooling))

        if pooling == "avg":
            pool = nn.AdaptiveAvgPool2d(1)
        else:
            pool = nn.AdaptiveMaxPool2d(1)

        stages = [
            pool,
            nn.Flatten(),
            nn.Dropout(p=dropout, inplace=True) if dropout else nn.Identity(),
            nn.Linear(in_channels, classes, bias=True),
            Activation(activation),
        ]
        super().__init__(*stages)

class SegmentationModel(torch.nn.Module):
    """Base class wiring an encoder, a decoder and output heads together.

    Subclasses are expected to set ``encoder``, ``decoder``, ``segmentation_head`` and
    ``classification_head`` (which may be ``None``).
    """

    def initialize(self):
        """Re-initialise the decoder and the head(s); the encoder is left as it is."""
        initialize_decoder(self.decoder)
        initialize_head(self.segmentation_head)
        if self.classification_head is not None:
            initialize_head(self.classification_head)

    def check_input_shape(self, x):
        """Raise unless the last two dimensions of `x` are divisible by the encoder stride.

        Raises:
            RuntimeError: Naming the offending sizes and the next valid (padded) shape.
        """
        h, w = x.shape[-2:]
        output_stride = self.encoder.output_stride
        if h % output_stride == 0 and w % output_stride == 0:
            return

        # Round each size up to the next multiple of the stride (unchanged if already valid).
        new_h, new_w = (-(-size // output_stride) * output_stride for size in (h, w))
        raise RuntimeError(
            f"Wrong input shape height={h}, width={w}. Expected image height and width "
            f"divisible by {output_stride}. Consider pad your images to shape ({new_h}, {new_w})."
        )

    def forward(self, x):
        """Run `x` through the encoder, the decoder and the head(s).

        Returns:
            The masks, or a ``(masks, labels)`` pair when a classification head is present;
            the labels are computed from the last encoder feature map.
        """
        self.check_input_shape(x)

        features = self.encoder(x)
        masks = self.segmentation_head(self.decoder(*features))

        if self.classification_head is None:
            return masks

        labels = self.classification_head(features[-1])
        return masks, labels

    @torch.no_grad()
    def predict(self, x):
        """Inference helper: put the model in eval mode and run ``forward`` without gradients.

        Args:
            x: Input batch tensor accepted by ``forward``.

        Return:
            Whatever ``forward`` returns for `x`.

        """
        if self.training:
            self.eval()

        return self.forward(x)


class DecoderBlock(nn.Module):
    """One decoder stage: 2x nearest upsampling, optional skip fusion, two 3x3 conv blocks.

    Args:
        in_channels: Channels of the tensor arriving from the previous (deeper) stage.
        skip_channels: Channels of the skip tensor concatenated after upsampling (0 if none).
        out_channels: Channels produced by this stage.
        use_batchnorm: Forwarded to both ``Conv2dReLU`` blocks.
        attention_type: Forwarded to both ``Attention`` layers.
    """

    def __init__(
        self,
        in_channels,
        skip_channels,
        out_channels,
        use_batchnorm=True,
        attention_type=None,
    ):
        super().__init__()
        fused_channels = in_channels + skip_channels
        conv_options = dict(kernel_size=3, padding=1, use_batchnorm=use_batchnorm)

        self.conv1 = Conv2dReLU(fused_channels, out_channels, **conv_options)
        self.attention1 = Attention(attention_type, in_channels=fused_channels)
        self.conv2 = Conv2dReLU(out_channels, out_channels, **conv_options)
        self.attention2 = Attention(attention_type, in_channels=out_channels)

    def forward(self, x, skip=None):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        if skip is not None:
            # The first attention layer only runs on the fused (upsampled + skip) tensor.
            x = self.attention1(torch.cat([x, skip], dim=1))
        return self.attention2(self.conv2(self.conv1(x)))


class CenterBlock(nn.Sequential):
    """Two stacked 3x3 ``Conv2dReLU`` blocks: in_channels -> out_channels -> out_channels."""

    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        conv_options = dict(kernel_size=3, padding=1, use_batchnorm=use_batchnorm)
        super().__init__(
            Conv2dReLU(in_channels, out_channels, **conv_options),
            Conv2dReLU(out_channels, out_channels, **conv_options),
        )

class UnetDecoder(nn.Module):
    """Chain of ``DecoderBlock`` stages that upsamples the deepest encoder feature map.

    The first encoder entry is dropped and the remaining ones are consumed deepest-first:
    the deepest is the head and the others are skip connections. The final block has no skip.

    Args:
        encoder_channels: Channel counts of the encoder feature maps, shallowest first.
        decoder_channels: Output channels of each decoder block.
        n_blocks: Number of blocks; must equal ``len(decoder_channels)``.
        use_batchnorm: Forwarded to every block (and to the centre block).
        attention_type: Forwarded to every block.
        center: Insert a ``CenterBlock`` before the first block instead of an identity.

    Raises:
        ValueError: If ``n_blocks`` differs from ``len(decoder_channels)``.
    """

    def __init__(
        self,
        encoder_channels,
        decoder_channels,
        n_blocks=5,
        use_batchnorm=True,
        attention_type=None,
        center=False,
    ):
        super().__init__()

        if n_blocks != len(decoder_channels):
            raise ValueError(
                "Model depth is {}, but you provide `decoder_channels` for {} blocks.".format(
                    n_blocks, len(decoder_channels)
                )
            )

        # Everything except entry 0, walked from the deepest stage to the shallowest.
        deepest_first = encoder_channels[:0:-1]
        head_channels = deepest_first[0]

        # Per-block channel plan; the last block receives no skip connection.
        block_inputs = [head_channels, *decoder_channels[:-1]]
        block_skips = [*deepest_first[1:], 0]
        block_outputs = decoder_channels

        if center:
            self.center = CenterBlock(head_channels, head_channels, use_batchnorm=use_batchnorm)
        else:
            self.center = nn.Identity()

        self.blocks = nn.ModuleList(
            [
                DecoderBlock(
                    in_ch,
                    skip_ch,
                    out_ch,
                    use_batchnorm=use_batchnorm,
                    attention_type=attention_type,
                )
                for in_ch, skip_ch, out_ch in zip(block_inputs, block_skips, block_outputs)
            ]
        )

    def forward(self, *features):
        # Same ordering as in __init__: skip entry 0, deepest feature map first.
        deepest_first = features[:0:-1]

        x = self.center(deepest_first[0])

        # Blocks beyond the available skips run without one.
        skips = list(deepest_first[1:])
        skips += [None] * (len(self.blocks) - len(skips))
        for decoder_block, skip in zip(self.blocks, skips):
            x = decoder_block(x, skip)

        return x


class Unet(SegmentationModel):
    """Unet_ is a fully convolution neural network for image semantic segmentation. Consist of *encoder*
    and *decoder* parts connected with *skip connections*. Encoder extract features of different spatial
    resolution (skip connections) which are used by decoder to define accurate segmentation mask. Use *concatenation*
    for fusing decoder blocks with skip connections.

    Unlike the variant that builds its encoder from a name, this class receives an already
    constructed encoder together with the channel counts of its feature maps.

    Args:
        encoder: Module that maps an input batch to a list of feature maps ordered from the
            shallowest to the deepest stage. It must expose an ``output_stride`` attribute,
            which ``check_input_shape`` reads.
        encoder_channels: Channel count of every feature map returned by **encoder**, in the same
            order. The first feature map is not used by the decoder.
        decoder_use_batchnorm: If **True**, BatchNorm2d layer between Conv2D and Activation layers
            is used. If **"inplace"** InplaceABN will be used, allows to decrease memory consumption.
            Available options are **True, False, "inplace"**
        decoder_channels: Sequence of integers giving the output channels of each decoder block, or the
            string **"same"** to mirror the encoder: one block per encoder stage below the deepest
            one, each producing that stage's channel count (deepest first).
        decoder_attention_type: Attention module used in decoder of the model. Available options are
            **None** and **scse** (https://arxiv.org/abs/1808.08127).
        classes: A number of classes for output mask (or you can think as a number of channels of output mask)
        activation: An activation function to apply after the final convolution layer.
            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
                **callable** and **None**.
            Default is **None**
        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
            on top of encoder if **aux_params** is not **None** (default). Supported params:
                - classes (int): A number of classes
                - pooling (str): One of "max", "avg". Default is "avg"
                - dropout (float): Dropout factor in [0, 1)
                - activation (str): An activation function to apply "sigmoid"/"softmax"
                    (could be **None** to return logits)

    Raises:
        ValueError: If **decoder_channels** is a string other than **"same"**.

    Returns:
        ``torch.nn.Module``: Unet

    .. _Unet:
        https://arxiv.org/abs/1505.04597

    """

    def __init__(
        self,
        encoder,
        encoder_channels,
        decoder_use_batchnorm: bool = True,
        decoder_channels: Union[List[int], str] = (256, 128, 64),
        decoder_attention_type: Optional[str] = None,
        classes: int = 1,
        activation: Optional[Union[str, callable]] = None,
        aux_params: Optional[dict] = None,
    ):
        super().__init__()

        self.encoder = encoder
        encoder_channels = list(encoder_channels)

        # Resolve the string option BEFORE converting to a list: list('same') would turn it
        # into ['s', 'a', 'm', 'e'] and the keyword could never be recognised.
        if isinstance(decoder_channels, str):
            if decoder_channels != 'same':
                raise ValueError(
                    "`decoder_channels` should be a sequence of ints or 'same', got {!r}".format(
                        decoder_channels
                    )
                )
            # Mirror the encoder deepest-first, leaving out the deepest stage (the decoder head).
            # This yields exactly as many widths as UnetDecoder builds blocks, so the
            # segmentation head below is sized for the real decoder output.
            decoder_channels = encoder_channels[-2::-1]
        else:
            decoder_channels = list(decoder_channels)

        self.decoder = UnetDecoder(
            encoder_channels=encoder_channels,
            decoder_channels=decoder_channels,
            n_blocks=len(decoder_channels),
            use_batchnorm=decoder_use_batchnorm,
            center=False,
            attention_type=decoder_attention_type,
        )
        self.segmentation_head = SegmentationHead(
            in_channels=decoder_channels[-1],
            out_channels=classes,
            activation=activation,
            kernel_size=3,
        )

        if aux_params is not None:
            # `forward` feeds the last encoder feature map to this head, so size it from the
            # supplied channel list rather than from an encoder attribute that may not exist.
            self.classification_head = ClassificationHead(in_channels=encoder_channels[-1], **aux_params)
        else:
            self.classification_head = None

        self.initialize()


In [149]:
# Feature channel counts of the wrapped encoder, reversed so the last stage comes first.
get_feature_channels(FeatureExtractorWrapper(model), input_shape=(3, 256, 256, 64))[::-1]

(768, 384, 192, 96)

In [150]:
# `Unet.__init__` takes the encoder's per-stage channel counts as `encoder_channels`
# (it has no `input_shape` parameter), so pass the probed channels under that name.
unet = Unet(
    encoder=FeatureExtractorWrapper(model),
    encoder_channels=get_feature_channels(model, input_shape=(3, 256, 256, 64)),
)

In [151]:
# Display the module tree of the assembled model.
unet

Unet(
  (encoder): FeatureExtractorWrapper(
    (model): FeatureListNet(
      (patch_embed): PatchEmbedPseudo3d(
        (proj): Conv3d(3, 96, kernel_size=(4, 4, 4), stride=(4, 4, 4))
        (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      )
      (layers_0): SwinTransformerV2StagePseudo3d(
        (downsample): Identity()
        (blocks): ModuleList(
          (0): SwinTransformerV2BlockPseudo3d(
            (attn): WindowAttentionPseudo3d(
              (cpb_mlp): Sequential(
                (0): Linear(in_features=3, out_features=512, bias=True)
                (1): ReLU(inplace=True)
                (2): Linear(in_features=512, out_features=3, bias=False)
              )
              (qkv): Linear(in_features=96, out_features=288, bias=False)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=96, out_features=96, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
              (softmax): Softmax(d

In [152]:
# End-to-end smoke test: a random volume goes in, and the shape of the returned mask is shown
# (a single-channel 2D map in the recorded output).
x = torch.randn(1, 3, 256, 256, 64)
y = unet(x)
y.shape

torch.Size([1, 1, 64, 64])