In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell runs, so edits to the project's `src` package are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import timm
import torch
from unittest.mock import patch

from src.model.swin_transformer_v2_pseudo_3d import SwinTransformerV2Pseudo3d, map_pretrained_2d_to_pseudo_3d
from src.model.smp import Unet
from src.utils.utils import FeatureExtractorWrapper, get_num_layers, get_feature_channels

In [20]:
# Reference 2D backbone: SwinV2-tiny created through timm with its pretrained
# weights, in feature-extraction mode (`features_only=True`).
model_2d = timm.create_model(
    'swinv2_tiny_window8_256.ms_in1k',
    features_only=True,
    pretrained=True,
)

# Smoke-test the forward pass on a random 2D input (batch=1, 3 channels, 256x256).
# Only the fact that it runs matters here, so skip autograd bookkeeping.
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    y = model_2d(x)

In [4]:
# Build the pseudo-3D backbone from the same timm model name by temporarily
# replacing timm's SwinTransformerV2 class with the project's
# SwinTransformerV2Pseudo3d. The patch is only active inside the `with` block.
# `pretrained=False`: no checkpoint is loaded at creation time.
with patch('timm.models.swin_transformer_v2.SwinTransformerV2', SwinTransformerV2Pseudo3d):
    model_pseudo_3d = timm.create_model(
        'swinv2_tiny_window8_256.ms_in1k',
        features_only=True,
        pretrained=False,
        window_size=(8, 8, 16),
        img_size=(256, 256, 64),
    )

# Smoke-test the forward pass on a random volume whose spatial size matches
# `img_size` above. No gradients are needed for this check, and disabling them
# avoids keeping activations for a 256x256x64 input alive.
x = torch.randn(1, 3, 256, 256, 64)
with torch.no_grad():
    y = model_pseudo_3d(x)

In [5]:
# For every parameter/buffer of the 2D model, report whether the pseudo-3D
# model has an entry under the same key and, if so, whether the shapes agree.
model_2d_state_dict = model_2d.state_dict()
model_pseudo_3d_state_dict = model_pseudo_3d.state_dict()
for key, value in model_2d_state_dict.items():
    # Key missing from the pseudo-3D model altogether.
    if key not in model_pseudo_3d_state_dict:
        print(f'{key}: {value.shape} -> NOT FOUND')
        continue
    # Key present but the tensor shape differs: show source -> target shape.
    if value.shape != model_pseudo_3d_state_dict[key].shape:
        print(f'{key}: {value.shape} -> {model_pseudo_3d_state_dict[key].shape}')
        continue
    # Key present with identical shape.
    print(f'{key}: {value.shape} -> OK')

patch_embed.proj.weight: torch.Size([96, 3, 4, 4]) -> torch.Size([96, 3, 4, 4, 4])
patch_embed.proj.bias: torch.Size([96]) -> OK
patch_embed.norm.weight: torch.Size([96]) -> OK
patch_embed.norm.bias: torch.Size([96]) -> OK
layers_0.blocks.0.attn.logit_scale: torch.Size([3, 1, 1]) -> OK
layers_0.blocks.0.attn.q_bias: torch.Size([96]) -> OK
layers_0.blocks.0.attn.v_bias: torch.Size([96]) -> OK
layers_0.blocks.0.attn.cpb_mlp.0.weight: torch.Size([512, 2]) -> torch.Size([512, 3])
layers_0.blocks.0.attn.cpb_mlp.0.bias: torch.Size([512]) -> OK
layers_0.blocks.0.attn.cpb_mlp.2.weight: torch.Size([3, 512]) -> OK
layers_0.blocks.0.attn.qkv.weight: torch.Size([288, 96]) -> OK
layers_0.blocks.0.attn.proj.weight: torch.Size([96, 96]) -> OK
layers_0.blocks.0.attn.proj.bias: torch.Size([96]) -> OK
layers_0.blocks.0.norm1.weight: torch.Size([96]) -> OK
layers_0.blocks.0.norm1.bias: torch.Size([96]) -> OK
layers_0.blocks.0.mlp.fc1.weight: torch.Size([384, 96]) -> OK
layers_0.blocks.0.mlp.fc1.bias: tor

The layers that do not match are `patch_embed.proj` (Conv2d -> Conv3d) and `layers_0.blocks.0.attn.cpb_mlp.0` (the first layer of the relative position bias MLP, which gains an input for the Z dim). Only their weights differ in shape; the bias shapes match.

- Conv layer's weight: `torch.Size([96, 3, 4, 4]) -> torch.Size([96, 3, 4, 4, 4])`

- MLP's weight: `torch.Size([512, 2]) -> torch.Size([512, 3])`

For the conv layer, the proposal is to repeat the 2D weights along the new third spatial (Z) dimension, scale them down by the patch size along Z (4), and keep the bias term intact. E.g., if the image is simply repeated along the Z dim, the 3D patch embedding equals the 2D patch embedding of the non-repeated patch.

For the relative position bias MLP, the proposal is to set the weights for the new (Z) input dimension to the mean of the weights of the existing two dimensions and keep the bias intact. There is no analogous invariance in this case.

**Note**: whether the low rank of the resulting weights is a problem needs additional investigation.

In [6]:
# Map the pretrained 2D weights onto the pseudo-3D model with the project helper.
# NOTE(review): whether the helper returns a new module or `model_pseudo_3d`
# modified in place is not visible from this notebook — confirm in
# src.model.swin_transformer_v2_pseudo_3d.
model = map_pretrained_2d_to_pseudo_3d(model_2d, model_pseudo_3d)

patch_embed.proj.weight: torch.Size([96, 3, 4, 4]) -> torch.Size([96, 3, 4, 4, 4])
layers_0.blocks.0.attn.cpb_mlp.0.weight: torch.Size([512, 2]) -> torch.Size([512, 3])


In [28]:
# Shape check of the weight-mapped model on a random volume
# (batch=1, 3 channels, 256x256x64).
x = torch.randn(1, 3, 256, 256, 64)
# Only the output shapes are inspected, so run without building an autograd graph.
with torch.no_grad():
    y = model(x)
# One shape per returned feature map.
[y_.shape for y_ in y]

[torch.Size([1, 64, 64, 96]),
 torch.Size([1, 32, 32, 192]),
 torch.Size([1, 16, 16, 384]),
 torch.Size([1, 8, 8, 768])]

In [61]:
# Compare the layer count reported for the bare model with the count reported
# for the same model behind FeatureExtractorWrapper.
num_layers_bare = get_num_layers(model)
num_layers_wrapped = get_num_layers(FeatureExtractorWrapper(model))
(num_layers_bare, num_layers_wrapped)

(4, 4)

In [63]:
# Compare the per-stage feature channels reported for the bare model with
# those reported for the wrapped model, using the same input shape for both.
probe_shape = (3, 256, 256, 64)
channels_bare = get_feature_channels(model, probe_shape)
channels_wrapped = get_feature_channels(FeatureExtractorWrapper(model), probe_shape)
(channels_bare, channels_wrapped)

((96, 192, 384, 768), (96, 192, 384, 768))

In [173]:
# Feature channels of the wrapped model, displayed in reversed order.
wrapped_model = FeatureExtractorWrapper(model)
wrapped_channels = get_feature_channels(wrapped_model, input_shape=(3, 256, 256, 64))
wrapped_channels[::-1]

(768, 384, 192, 96)

In [184]:
# Assemble a 2-class U-Net on top of the weight-mapped pseudo-3D backbone.
# The encoder is the wrapped model; the channel counts are queried from the
# bare `model` for a (3, 256, 256, 64) input.
unet_encoder = FeatureExtractorWrapper(model)
unet_encoder_channels = get_feature_channels(model, input_shape=(3, 256, 256, 64))
unet = Unet(
    encoder=unet_encoder,
    encoder_channels=unet_encoder_channels,
    classes=2,
)

In [185]:
# Display the repr of the assembled U-Net to inspect its module structure.
unet

Unet(
  (encoder): FeatureExtractorWrapper(
    (model): FeatureListNet(
      (patch_embed): PatchEmbedPseudo3d(
        (proj): Conv3d(3, 96, kernel_size=(4, 4, 4), stride=(4, 4, 4))
        (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      )
      (layers_0): SwinTransformerV2StagePseudo3d(
        (downsample): Identity()
        (blocks): ModuleList(
          (0): SwinTransformerV2BlockPseudo3d(
            (attn): WindowAttentionPseudo3d(
              (cpb_mlp): Sequential(
                (0): Linear(in_features=3, out_features=512, bias=True)
                (1): ReLU(inplace=True)
                (2): Linear(in_features=512, out_features=3, bias=False)
              )
              (qkv): Linear(in_features=96, out_features=288, bias=False)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=96, out_features=96, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
              (softmax): Softmax(d

In [186]:
# End-to-end shape check: run the U-Net on a random volume
# (batch=1, 3 channels, 256x256x64) and display the output shape.
x = torch.randn(1, 3, 256, 256, 64)
# Only the output shape is inspected, so disable autograd to avoid retaining
# activations for the whole encoder/decoder.
with torch.no_grad():
    y = unet(x)
y.shape

torch.Size([1, 2, 256, 256])