In [1]:
# Reload edited project modules (e.g. src.*) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [63]:
import timm
import torch
from pprint import pprint
from unittest.mock import patch
from acsconv.models.convnext import convnext_small
from acsconv.operators import ACSConv

from src.model.swin_transformer_v2_pseudo_3d import SwinTransformerV2Pseudo3d, map_pretrained_2d_to_pseudo_3d
from src.model.smp import Unet, patch_first_conv
from src.model.unet_3d_acs import AcsConvnextWrapper, UNet3dAcs, ACSConverterTimm
from src.utils.utils import FeatureExtractorWrapper, get_num_layers, get_feature_channels
from src.data.datamodules import SurfaceVolumeDatamodule
from src.model.modules import backbone_name_to_params
from src.utils.lr_scheduler import PiecewiceFactorsLRScheduler, LinearLRSchedulerPiece, CosineLRSchedulerPiece

# Pseudo 3D

In [3]:
# Reference 2D Swin V2 feature extractor with pretrained weights; its state dict is
# compared with / mapped onto the pseudo-3D model in the following cells.
# NOTE: timm's keyword for input channels is `in_chans`. The previous `in_channels=1`
# was not recognised and had no effect: the model still takes 3 channels
# (patch_embed.proj.weight is [96, 3, 4, 4]), which is what the mapping relies on.
model_2d = timm.create_model(
    'swinv2_tiny_window8_256.ms_in1k',
    features_only=True,
    pretrained=True,
    in_chans=3,
)
x = torch.randn(1, 3, 256, 256)
y = model_2d(x)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [4]:
# Build the pseudo-3D Swin V2 through the regular timm factory by temporarily replacing
# the SwinTransformerV2 class with SwinTransformerV2Pseudo3d for the duration of the call.
# Weights are random here (pretrained=False); 2D pretrained weights are mapped in later.
with patch('timm.models.swin_transformer_v2.SwinTransformerV2', SwinTransformerV2Pseudo3d):
    model_pseudo_3d = timm.create_model(
        'swinv2_tiny_window8_256.ms_in1k', 
        features_only=True,
        pretrained=False,
        window_size=(8, 8, 16),  # presumably per spatial axis, same order as img_size (last = depth/Z)
        img_size=(256, 256, 64),
    )
# Input is (N, C, 256, 256, 64): depth (Z) is the last axis, matching img_size above.
x = torch.randn(1, 3, 256, 256, 64)
y = model_pseudo_3d(x)

In [5]:
# Report, for every tensor of the 2D checkpoint, whether the pseudo-3D model has a tensor
# with the same key and shape ("OK"), the same key with another shape (new shape printed),
# or no such key at all ("NOT FOUND").
model_2d_state_dict = model_2d.state_dict()
model_pseudo_3d_state_dict = model_pseudo_3d.state_dict()
for key, value in model_2d_state_dict.items():
    target = model_pseudo_3d_state_dict.get(key)
    if target is None:
        status = 'NOT FOUND'
    elif target.shape == value.shape:
        status = 'OK'
    else:
        status = target.shape
    print(f'{key}: {value.shape} -> {status}')

patch_embed.proj.weight: torch.Size([96, 3, 4, 4]) -> torch.Size([96, 3, 4, 4, 4])
patch_embed.proj.bias: torch.Size([96]) -> OK
patch_embed.norm.weight: torch.Size([96]) -> OK
patch_embed.norm.bias: torch.Size([96]) -> OK
layers_0.blocks.0.attn.logit_scale: torch.Size([3, 1, 1]) -> OK
layers_0.blocks.0.attn.q_bias: torch.Size([96]) -> OK
layers_0.blocks.0.attn.v_bias: torch.Size([96]) -> OK
layers_0.blocks.0.attn.cpb_mlp.0.weight: torch.Size([512, 2]) -> torch.Size([512, 3])
layers_0.blocks.0.attn.cpb_mlp.0.bias: torch.Size([512]) -> OK
layers_0.blocks.0.attn.cpb_mlp.2.weight: torch.Size([3, 512]) -> OK
layers_0.blocks.0.attn.qkv.weight: torch.Size([288, 96]) -> OK
layers_0.blocks.0.attn.proj.weight: torch.Size([96, 96]) -> OK
layers_0.blocks.0.attn.proj.bias: torch.Size([96]) -> OK
layers_0.blocks.0.norm1.weight: torch.Size([96]) -> OK
layers_0.blocks.0.norm1.bias: torch.Size([96]) -> OK
layers_0.blocks.0.mlp.fc1.weight: torch.Size([384, 96]) -> OK
layers_0.blocks.0.mlp.fc1.bias: tor

The mismatching tensors are the weights of `patch_embed.proj` (Conv2d -> Conv3d) and of `layers_0.blocks.0.attn.cpb_mlp.0` (the MLP that maps relative coordinates to the relative position bias, which gets an extra input for the Z dimension). The corresponding biases have matching shapes.

- Conv layer's weight: `torch.Size([96, 3, 4, 4]) -> torch.Size([96, 3, 4, 4, 4])`

- MLP's weight: `torch.Size([512, 2]) -> torch.Size([512, 3])`

For the conv layer, the proposal is to repeat the weights along the new Z dimension, scale them down by the patch size along Z (4), and keep the bias unchanged. With this initialization, if the image is simply repeated along Z, the 3D patch embedding equals the 2D patch embedding of the non-repeated patch: the 4 identical slices each contribute 1/4 of the original 2D response.

For the relative position bias MLP, the proposal is to initialize the weights for the new (Z) input as the mean of the weights for the existing two inputs, and to keep the bias unchanged. Unlike the conv layer, this mapping gives no invariance guarantee.

**Note**: further investigation is needed to determine whether the low rank of the resulting weights is a problem.

In [6]:
# Transfer the pretrained 2D weights to the pseudo-3D model. The call prints the tensors
# whose shapes differ (see output); those are presumably initialized as proposed above.
# NOTE(review): unclear whether model_pseudo_3d is modified in place or a new model is returned.
model = map_pretrained_2d_to_pseudo_3d(model_2d, model_pseudo_3d)

patch_embed.proj.weight: torch.Size([96, 3, 4, 4]) -> torch.Size([96, 3, 4, 4, 4])
layers_0.blocks.0.attn.cpb_mlp.0.weight: torch.Size([512, 2]) -> torch.Size([512, 3])


In [11]:
# Names (not values) of all parameters in the mapped pseudo-3D model.
pprint([param_name for param_name, _ in model.named_parameters()])

['patch_embed.proj.weight',
 'patch_embed.proj.bias',
 'patch_embed.norm.weight',
 'patch_embed.norm.bias',
 'layers_0.blocks.0.attn.logit_scale',
 'layers_0.blocks.0.attn.q_bias',
 'layers_0.blocks.0.attn.v_bias',
 'layers_0.blocks.0.attn.cpb_mlp.0.weight',
 'layers_0.blocks.0.attn.cpb_mlp.0.bias',
 'layers_0.blocks.0.attn.cpb_mlp.2.weight',
 'layers_0.blocks.0.attn.qkv.weight',
 'layers_0.blocks.0.attn.proj.weight',
 'layers_0.blocks.0.attn.proj.bias',
 'layers_0.blocks.0.norm1.weight',
 'layers_0.blocks.0.norm1.bias',
 'layers_0.blocks.0.mlp.fc1.weight',
 'layers_0.blocks.0.mlp.fc1.bias',
 'layers_0.blocks.0.mlp.fc2.weight',
 'layers_0.blocks.0.mlp.fc2.bias',
 'layers_0.blocks.0.norm2.weight',
 'layers_0.blocks.0.norm2.bias',
 'layers_0.blocks.1.attn.logit_scale',
 'layers_0.blocks.1.attn.q_bias',
 'layers_0.blocks.1.attn.v_bias',
 'layers_0.blocks.1.attn.cpb_mlp.0.weight',
 'layers_0.blocks.1.attn.cpb_mlp.0.bias',
 'layers_0.blocks.1.attn.cpb_mlp.2.weight',
 'layers_0.blocks.1.attn

In [28]:
# Forward a random volume through the mapped model and inspect per-stage feature shapes.
# Observed outputs are 4D and channels-last, e.g. (1, 64, 64, 96), with no depth axis.
# NOTE(review): confirm that collapsing the Z dimension inside the model is intended.
x = torch.randn(1, 3, 256, 256, 64)
y = model(x)
[y_.shape for y_ in y]

[torch.Size([1, 64, 64, 96]),
 torch.Size([1, 32, 32, 192]),
 torch.Size([1, 16, 16, 384]),
 torch.Size([1, 8, 8, 768])]

In [61]:
# Sanity check: the number of feature stages is the same with and without the wrapper.
get_num_layers(model), get_num_layers(FeatureExtractorWrapper(model))

(4, 4)

In [63]:
# Per-stage feature channels must agree between the raw model and its wrapper.
(
    get_feature_channels(model, (3, 256, 256, 64)),
    get_feature_channels(FeatureExtractorWrapper(model), (3, 256, 256, 64)),
)

((96, 192, 384, 768), (96, 192, 384, 768))

In [173]:
# Channels ordered from the deepest to the shallowest stage (presumably decoder order).
get_feature_channels(FeatureExtractorWrapper(model), input_shape=(3, 256, 256, 64))[::-1]

(768, 384, 192, 96)

In [14]:
# 2D U-Net head on top of the pseudo-3D encoder. encoder_channels are computed from the
# raw model; the check above shows they equal those of the wrapped encoder.
unet = Unet(
    encoder=FeatureExtractorWrapper(model),
    encoder_channels=get_feature_channels(model, input_shape=(3, 256, 256, 64)),
    classes=2,
    upsampling=4,  # presumably compensates the 1/4 resolution of the first stage (64 vs 256)
)

In [16]:
# Parameter names of the U-Net; encoder weights appear under the `encoder.model.` prefix.
pprint([param_name for param_name, _ in unet.named_parameters()])

['encoder.model.patch_embed.proj.weight',
 'encoder.model.patch_embed.proj.bias',
 'encoder.model.patch_embed.norm.weight',
 'encoder.model.patch_embed.norm.bias',
 'encoder.model.layers_0.blocks.0.attn.logit_scale',
 'encoder.model.layers_0.blocks.0.attn.q_bias',
 'encoder.model.layers_0.blocks.0.attn.v_bias',
 'encoder.model.layers_0.blocks.0.attn.cpb_mlp.0.weight',
 'encoder.model.layers_0.blocks.0.attn.cpb_mlp.0.bias',
 'encoder.model.layers_0.blocks.0.attn.cpb_mlp.2.weight',
 'encoder.model.layers_0.blocks.0.attn.qkv.weight',
 'encoder.model.layers_0.blocks.0.attn.proj.weight',
 'encoder.model.layers_0.blocks.0.attn.proj.bias',
 'encoder.model.layers_0.blocks.0.norm1.weight',
 'encoder.model.layers_0.blocks.0.norm1.bias',
 'encoder.model.layers_0.blocks.0.mlp.fc1.weight',
 'encoder.model.layers_0.blocks.0.mlp.fc1.bias',
 'encoder.model.layers_0.blocks.0.mlp.fc2.weight',
 'encoder.model.layers_0.blocks.0.mlp.fc2.bias',
 'encoder.model.layers_0.blocks.0.norm2.weight',
 'encoder.mode

In [186]:
# End-to-end forward: a 3D input volume yields a 2D map with one channel per class
# at full XY resolution ((1, 2, 256, 256) in the output).
x = torch.randn(1, 3, 256, 256, 64)
y = unet(x)
y.shape

torch.Size([1, 2, 256, 256])

# ACS

### Shapes

In [14]:
# ACS ConvNeXt-S with pretrained weights, run end-to-end on a 3D input; the output is
# the 1000-way classification head. Input here is presumably (N, C, D, H, W) with depth
# before H/W (compare the wrapper cell below, which takes depth last).
encoder = convnext_small(in_chans=3, pretrained=True)
x = torch.randn(1, 3, 64, 256, 256)
y = encoder(x)
y.shape

torch.Size([1, 1000])

In [15]:
# Raw outputs of the 3-channel model, for eyeballing against the 1-channel variant below.
# NOTE(review): displays all 1000 values; `y[0, :10]` would suffice for a visual check.
y

tensor([[ 3.7097e-03, -8.1073e-02,  4.6252e-01,  2.4873e-01,  4.7153e-01,
          2.9299e-01,  1.8045e-01, -1.7074e-01, -6.1954e-02, -4.5682e-03,
          2.8239e-02,  1.2265e-01,  6.0563e-02,  2.2887e-01,  1.2799e-01,
         -5.6556e-02, -7.1557e-03, -1.0235e-01, -9.4845e-02,  1.4535e-01,
          1.9142e-01,  2.9389e-01, -2.9258e-02,  1.2216e-01,  7.8282e-02,
         -1.8379e-01,  1.9690e-01,  4.8824e-03,  2.8430e-02,  2.6966e-01,
          1.4927e-01,  2.4046e-01,  2.0935e-01,  6.4321e-02, -1.2264e-01,
         -1.6131e-01,  1.8348e-01,  2.9164e-02,  1.7949e-01,  1.0048e-01,
          1.6094e-01,  4.1961e-02, -6.7309e-02, -5.5643e-02, -1.0218e-01,
         -2.0490e-01,  1.1845e-01,  2.3018e-01, -1.2406e-02, -1.5821e-01,
          5.8660e-02,  1.4072e-02,  1.4914e-01,  7.7491e-03,  1.5666e-01,
          2.9254e-01,  4.1319e-02, -7.7670e-03,  1.6499e-01,  2.4006e-01,
          2.1093e-01,  4.5539e-03, -1.6025e-02, -6.4127e-02,  1.5546e-01,
         -6.7125e-02,  1.4844e-01, -5.

In [16]:
# Adapt the ConvNeXt stem to single-channel input by patching its first conv (ACSConv)
# from 3 to 1 input channel; pretrained=True presumably derives the new weights from
# the pretrained 3-channel ones.
# patch_first_conv modifies the encoder in place (its result is not assigned), so the
# encoder is rebuilt here: re-running this cell on an already-patched encoder would
# otherwise try to patch a 1-channel conv as if it had 3 channels.
encoder = convnext_small(in_chans=3, pretrained=True)
patch_first_conv(
    encoder,
    new_in_channels=1,
    default_in_channels=3,
    pretrained=True,
    conv_type=ACSConv,
)
x = torch.randn(1, 1, 64, 256, 256)
y = encoder(x)
y.shape

torch.Size([1, 1000])

In [17]:
# Raw outputs of the 1-channel (patched) model.
# NOTE(review): displays all 1000 values; `y[0, :10]` would suffice for a visual check.
y

tensor([[-8.9908e-01, -6.5073e-01,  1.9427e-01,  4.3452e-01, -2.8106e-01,
          3.4318e-01, -8.6568e-02, -2.7676e-01,  2.6204e-02, -3.6346e-01,
         -7.5107e-01,  5.8831e-01, -4.1110e-01,  2.0318e-01, -9.2363e-01,
         -1.7910e-01, -3.5732e-01, -6.9077e-01,  3.4564e-01, -7.9706e-01,
          2.5459e-01, -6.7878e-01, -1.9979e-02, -2.4043e-01, -7.7108e-04,
          2.3837e-01,  1.1973e+00,  2.3980e-01,  2.3720e-01,  5.7646e-01,
         -3.7150e-01,  4.0539e-01,  5.6372e-02,  6.8230e-01,  1.3508e+00,
          2.4499e-01,  3.5656e-01, -4.2827e-01,  1.0896e+00,  2.5458e-01,
          8.6731e-01, -1.3666e-02, -9.2752e-02, -4.7640e-02,  8.0678e-01,
         -4.4065e-01,  1.2023e+00,  2.6664e-01,  8.7790e-03,  4.0668e-02,
         -3.6794e-01, -4.3836e-01,  9.1000e-01,  3.7706e-01,  4.3348e-01,
          2.6114e-01,  3.8635e-01, -1.3854e-01,  1.6246e-02,  7.2587e-01,
          1.0760e+00,  2.9684e-02, -8.4969e-02,  5.1115e-01, -2.4269e-01,
          6.6350e-01,  7.4615e-01,  4.

In [18]:
# Wrap the ACS ConvNeXt so it returns multi-scale features instead of class outputs.
# Here the input is (N, C, 256, 256, 64), depth last, and the outputs keep depth last:
# every spatial axis is downsampled by 4/8/16/32 (64x64x16 ... 8x8x2).
encoder_wrapped = AcsConvnextWrapper(encoder)
x = torch.randn(1, 1, 256, 256, 64)
y = encoder_wrapped(x)
[y_.shape for y_ in y]

[torch.Size([1, 96, 64, 64, 16]),
 torch.Size([1, 192, 32, 32, 8]),
 torch.Size([1, 384, 16, 16, 4]),
 torch.Size([1, 768, 8, 8, 2])]

In [19]:
# Raw feature values of all stages.
# NOTE(review): large output; summary stats (e.g. `[y_.mean() for y_ in y]`) are easier to read.
y

[tensor([[[[[-6.6157e-01,  2.7441e+00,  1.5920e-01,  ...,  1.4625e+00,
              2.6964e+00,  3.8220e+00],
            [ 5.5644e+00,  1.4376e+00,  1.9432e+00,  ...,  3.5560e+00,
              1.8596e+00,  7.5981e-01],
            [ 1.4291e+00,  5.2879e-02,  3.2384e+00,  ...,  2.1880e+00,
              1.3040e+00,  4.1368e+00],
            ...,
            [ 2.6612e+00,  3.1980e+00,  4.3057e+00,  ...,  3.2451e+00,
             -2.7839e-01,  2.5487e+00],
            [ 7.6664e-01,  3.8418e+00, -3.1098e-02,  ...,  1.7310e+00,
              7.8505e-01, -6.2494e-02],
            [ 3.7845e+00, -6.2155e-01, -1.7732e-01,  ...,  4.8696e+00,
              3.5863e+00,  9.8867e-01]],
 
           [[ 1.0292e+00,  2.1494e+00,  2.6683e+00,  ...,  1.7510e+00,
              4.8057e+00,  1.4246e+00],
            [ 2.3860e+00, -9.7124e-01, -4.4606e-01,  ...,  2.2809e+00,
             -1.3427e+00,  1.4697e+00],
            [ 3.4688e+00, -9.8714e-01,  2.8996e+00,  ...,  3.9024e+00,
             -2.0115e

### Test gradients

In [None]:
# Configuration for the gradient smoke test below.
# NOTE(review): this cell shows `In [None]` (not executed in the saved run), yet later
# cells depend on these names — run it first after a kernel restart.
backbone_name = 'convnext_small_1k'  # key into backbone_name_to_params
in_channels = 32  # used as the depth (Z) size of the probe shape in get_feature_channels; presumably must match crop_size_z / img_size_z
lr = 1e-4
weight_decay = 1e-6
grad_accum_steps = 1  # scheduler milestones are divided by this
max_epochs = 16

In [3]:
# Surface-volume training data from three downscaled fragments; val_dir_indices selects
# the validation split. NOTE(review): absolute path — consider a configurable data dir.
fragments_train_root = '/workspace/data/fragments_downscaled_2/train'
datamodule = SurfaceVolumeDatamodule(
    surface_volume_dirs=[f'{fragments_train_root}/{fragment_id}' for fragment_id in (1, 2, 3)],
    surface_volume_dirs_test=None,
    val_dir_indices=[2],
    # spatial crop / resize settings (XY and Z)
    crop_size=256,
    crop_size_z=32,
    img_size=256,
    img_size_z=32,
    resize_xy='crop',
    use_imagenet_stats=True,
    # batching / loader settings
    batch_size_full=None,
    batch_size_full_apply_epoch=None,
    batch_size=8,
    num_workers=8,
    pin_memory=False,
    prefetch_factor=2,
    persistent_workers=True,
)
datamodule.setup()
dataloader = datamodule.train_dataloader()

In [4]:
# ACS ConvNeXt-S encoder with a single-channel stem, wrapped for multi-scale features,
# plus a 3D U-Net decoder configured from backbone_name_to_params[backbone_name].
encoder = convnext_small(in_chans=3, pretrained=True)
patch_first_conv(
    encoder,
    new_in_channels=1,
    default_in_channels=3,
    pretrained=True,
    conv_type=ACSConv,
)
encoder = AcsConvnextWrapper(encoder)

backbone_params = backbone_name_to_params[backbone_name]
# Probe shape: (1 channel, *img_size, depth), with depth taken from `in_channels`.
encoder_channels = get_feature_channels(
    encoder,
    input_shape=(1, *backbone_params['img_size'], in_channels),
)
model = UNet3dAcs(
    encoder=encoder,
    encoder_channels=encoder_channels,
    decoder_mid_channels=backbone_params['decoder_mid_channels'],
    decoder_out_channels=backbone_params['decoder_out_channels'],
    classes=1,
    upsampling=backbone_params['upsampling'],
).cuda()

In [5]:
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

# Milestones are expressed in optimizer steps: all batches over all epochs, divided by
# the gradient-accumulation factor.
total_steps = len(dataloader) * max_epochs
milestone_fractions = (0, 0.1, 1.0)
milestones = [int(fraction * total_steps / grad_accum_steps) for fraction in milestone_fractions]

# Two pieces: linear from 0.1 to 1.0 over the first 10% of steps, then cosine from 1.0
# to 0.01 (presumably factors applied to the base lr).
scheduler = PiecewiceFactorsLRScheduler(
    optimizer,
    milestones=milestones,
    pieces=[
        LinearLRSchedulerPiece(1e-1, 1.0),
        CosineLRSchedulerPiece(1.0, 1e-2),
    ],
)

In [6]:
# Smoke-test a few optimisation steps on real batches: the loss should stay finite and
# predictions NaN-free. Gradients of the last step are left in place for inspection.
for i, batch in enumerate(dataloader):
    # Start of an accumulation window: clear gradients from the previous optimizer step.
    if i % grad_accum_steps == 0:
        optimizer.zero_grad()
    x, y = batch['image'].cuda(), batch['mask_2'].cuda()
    y_hat = model(x)
    loss = torch.nn.functional.binary_cross_entropy_with_logits(
        y_hat.squeeze(1).float().flatten(),
        y.float().flatten(),
    )
    # Scale so the accumulated gradient is the mean over the window.
    (loss / grad_accum_steps).backward()
    # Step once per `grad_accum_steps` batches: the scheduler milestones were computed in
    # optimizer steps (total batches / grad_accum_steps), so stepping on every batch would
    # run the LR schedule too fast whenever grad_accum_steps > 1.
    if (i + 1) % grad_accum_steps == 0:
        optimizer.step()
        scheduler.step()
    # NOTE(review): the per-stage tuples interleaved in the output are printed from
    # inside the model, not by this cell.
    print(loss.item(), torch.isnan(y_hat).any())
    if i >= 10:
        break

[(0, 25165824, torch.Size([8, 96, 8, 64, 64]), tensor(0, device='cuda:0')), (1, 6291456, torch.Size([8, 192, 4, 32, 32]), tensor(0, device='cuda:0')), (2, 1572864, torch.Size([8, 384, 2, 16, 16]), tensor(0, device='cuda:0')), (3, 393216, torch.Size([8, 768, 1, 8, 8]), tensor(0, device='cuda:0'))]
0.75528883934021 tensor(False, device='cuda:0')
[(0, 25165824, torch.Size([8, 96, 8, 64, 64]), tensor(0, device='cuda:0')), (1, 6291456, torch.Size([8, 192, 4, 32, 32]), tensor(0, device='cuda:0')), (2, 1572864, torch.Size([8, 384, 2, 16, 16]), tensor(0, device='cuda:0')), (3, 393216, torch.Size([8, 768, 1, 8, 8]), tensor(0, device='cuda:0'))]
0.7573579549789429 tensor(False, device='cuda:0')
[(0, 25165824, torch.Size([8, 96, 8, 64, 64]), tensor(0, device='cuda:0')), (1, 6291456, torch.Size([8, 192, 4, 32, 32]), tensor(0, device='cuda:0')), (2, 1572864, torch.Size([8, 384, 2, 16, 16]), tensor(0, device='cuda:0')), (3, 393216, torch.Size([8, 768, 1, 8, 8]), tensor(0, device='cuda:0'))]
0.715015

In [7]:
# Per-parameter gradient summary after the smoke test: element count and mean.
# Parameters that received no gradient (e.g. unused in the forward pass) are skipped.
for name, param in model.named_parameters():
    grad = param.grad
    if grad is None:
        continue
    print(name, grad.numel(), grad.mean())

encoder.model.downsample_layers.0.0.weight 1536 tensor(4.1122e-06, device='cuda:0')
encoder.model.downsample_layers.0.0.bias 96 tensor(-3.8805e-11, device='cuda:0')
encoder.model.downsample_layers.0.1.weight 96 tensor(-0.0002, device='cuda:0')
encoder.model.downsample_layers.0.1.bias 96 tensor(-2.6401e-05, device='cuda:0')
encoder.model.downsample_layers.1.0.weight 96 tensor(-0.0004, device='cuda:0')
encoder.model.downsample_layers.1.0.bias 96 tensor(0.0003, device='cuda:0')
encoder.model.downsample_layers.1.1.weight 73728 tensor(2.7031e-07, device='cuda:0')
encoder.model.downsample_layers.1.1.bias 192 tensor(-1.2043e-05, device='cuda:0')
encoder.model.downsample_layers.2.0.weight 192 tensor(-5.2687e-05, device='cuda:0')
encoder.model.downsample_layers.2.0.bias 192 tensor(-6.2077e-06, device='cuda:0')
encoder.model.downsample_layers.2.1.weight 294912 tensor(-5.7902e-08, device='cuda:0')
encoder.model.downsample_layers.2.1.bias 384 tensor(-5.3963e-06, device='cuda:0')
encoder.model.down

### `timm` conversion

In [64]:
# Check that ACSConverterTimm turns a timm 2D feature extractor into a 3D one: each stage
# keeps its channels and XY size and gains a depth axis (3D input is (N, C, D, H, W)).
# Distinct names are used so this cell does not overwrite the `backbone_name` and `model`
# globals the training cells rely on; forwards run without autograd to save memory.
backbone_names = [
    'convnext_small.in12k_ft_in1k_384',
]
with torch.no_grad():
    for timm_backbone_name in backbone_names:
        timm_model = timm.create_model(
            timm_backbone_name,
            features_only=True,
            pretrained=True,
        )
        x = torch.randn(1, 3, 256, 256)
        y = timm_model(x)
        print('Original: ', [y_.shape for y_ in y])

        model_converted = ACSConverterTimm(timm_model)
        x = torch.randn(1, 3, 64, 256, 256)
        y = model_converted(x)
        print('Converted: ', [y_.shape for y_ in y])

Original:  [torch.Size([1, 96, 64, 64]), torch.Size([1, 192, 32, 32]), torch.Size([1, 384, 16, 16]), torch.Size([1, 768, 8, 8])]
Converted:  [torch.Size([1, 96, 16, 64, 64]), torch.Size([1, 192, 8, 32, 32]), torch.Size([1, 384, 4, 16, 16]), torch.Size([1, 768, 2, 8, 8])]
