In [8]:
# monai==1.3.2
# python==3.10.14

from __future__ import annotations

from collections.abc import Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
from swinunetr import SwinUNETR


class SwinUNETREncoder(SwinUNETR):
    """Encoder-only variant of ``SwinUNETR`` that emits one pooled feature vector per sample.

    The parent network is built unchanged, its ``decoder1`` .. ``decoder5`` stages are
    then deleted, and ``forward`` global-average-pools five encoder feature maps,
    concatenates them along the channel axis and passes the result through ``self.out``.

    ``self.out`` starts as ``nn.Identity`` and is meant to be replaced by a
    task-specific head.
    """

    def __init__(
        self,
        img_size: Sequence[int] | int,
        in_channels: int,
        out_channels: int,
        patch_size: int,
        depths: Sequence[int] = (2, 2, 2, 2),
        num_heads: Sequence[int] = (3, 6, 12, 24),
        feature_size: int = 24,
        norm_name: tuple | str = "instance",
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        dropout_path_rate: float = 0.0,
        normalize: bool = True,
        use_checkpoint: bool = False,
        spatial_dims: int = 3,
        downsample="merging",
        use_v2=False,
    ):
        """Build the parent ``SwinUNETR`` and strip the parts this encoder does not use.

        Every argument is forwarded positionally, in declaration order, to
        ``SwinUNETR.__init__``; see the project-local ``swinunetr`` module for
        their meaning.
        """
        super().__init__(
            img_size,
            in_channels,
            out_channels,
            patch_size,
            depths,
            num_heads,
            feature_size,
            norm_name,
            drop_rate,
            attn_drop_rate,
            dropout_path_rate,
            normalize,
            use_checkpoint,
            spatial_dims,
            downsample,
            use_v2,
        )
        # forward() below never calls the decoder stages, so remove them from the module.
        del self.decoder1, self.decoder2, self.decoder3, self.decoder4, self.decoder5
        self.out = nn.Identity()  # NOTE - should be overloaded by task-specific heads

    def fuse_representations(self, reps: list) -> torch.Tensor:
        """Global-average-pool each feature map and concatenate the results.

        Args:
            reps: feature maps shaped ``[b, n_i, *spatial]`` with at least one
                spatial axis. Both 2-D (``[b, n, j, k]``) and 3-D
                (``[b, n, j, k, l]``) maps are accepted.

        Returns:
            Tensor of shape ``[b, sum(n_i)]``.
        """
        # Averaging over every spatial axis equals adaptive average pooling to an
        # output size of 1, but does not depend on the number of spatial dims.
        # (adaptive_avg_pool2d only accepts 3-D/4-D inputs and so cannot handle
        # the 5-D maps of the default spatial_dims=3 configuration.)
        pooled = [rep.flatten(start_dim=2).mean(dim=-1) for rep in reps]
        return torch.cat(pooled, dim=1)

    def forward(self, x):
        """Encode ``x`` and return ``self.out`` applied to the fused feature vector.

        Shape notes carried over from the original author, written for a 3-D
        input (NOTE(review): the exact strides depend on ``patch_size`` and on the
        ``swinunetr`` implementation - confirm before relying on them):

            x:                    [b, c,            h,     w,     d]
            hidden_states_out[0]: [b, n_feats * 1,  h//2,  w//2,  d//2]
            hidden_states_out[1]: [b, n_feats * 2,  h//4,  w//4,  d//4]
            hidden_states_out[2]: [b, n_feats * 4,  h//8,  w//8,  d//8]
            hidden_states_out[3]: [b, n_feats * 8,  h//16, w//16, d//16]
            hidden_states_out[4]: [b, n_feats * 16, h//32, w//32, d//32]
            enc0:                 [b, n_feats * 1,  h,     w,     d]
            enc1:                 [b, n_feats * 1,  h//2,  w//2,  d//2]
            enc2:                 [b, n_feats * 2,  h//4,  w//4,  d//4]
            enc3:                 [b, n_feats * 4,  h//8,  w//8,  d//8]
            dec4:                 [b, n_feats * 16, h//32, w//32, d//32]
            fused_rep:            [b, n_feats * (1 + 1 + 2 + 4 + 16)]
            y:                    based on self.out
        """
        hidden_states_out = self.swinViT(x)
        enc0 = self.encoder1(x)
        enc1 = self.encoder2(hidden_states_out[0])
        enc2 = self.encoder3(hidden_states_out[1])
        enc3 = self.encoder4(hidden_states_out[2])
        # hidden_states_out[3] is intentionally not used; the deepest map goes through encoder10.
        dec4 = self.encoder10(hidden_states_out[4])
        fused_rep = self.fuse_representations([enc0, enc1, enc2, enc3, dec4])
        return self.out(fused_rep)


In [9]:
# model = SwinUNETREncoder(img_size=[512, 512], in_channels=1, out_channels=1, patch_size=8, feature_size=48, use_checkpoint=False, use_v2=False, spatial_dims=2)

# data = torch.rand((1, 1, 512, 512), dtype=torch.float32)

# out = model(data)

In [10]:
# Smoke test: build an SCUNet and run one random tensor through it, then print
# the shape of the result.
# NOTE(review): the output recorded for this cell ends in
#   RuntimeError: ... weight of size [64, 64, 1, 1], expected input[1, 1, 256, 256]
#   to have 64 channels, but got 1 channels instead
# so the forward pass below does not currently succeed. The message suggests the raw
# 1-channel input reaches a 64-channel 1x1 convolution inside `scunet` - TODO confirm
# in that module; the cause is not visible from this notebook.
from scunet import SCUNet

model = SCUNet(in_nc=1, dim=64, drop_path_rate=0.0, input_resolution=256)

# Random float32 tensor of size (1, 1, 256, 256); presumably (batch, channels,
# height, width), matching in_nc=1 and input_resolution=256 - TODO confirm.
data = torch.rand((1, 1, 256, 256), dtype=torch.float32)

out = model(data)
print(out.shape)

Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000
Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000
Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000
Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000
Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000
Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000
Block Initial Type: W, drop_path_rate:0.000000
Block Initial Type: SW, drop_path_rate:0.000000


RuntimeError: Given groups=1, weight of size [64, 64, 1, 1], expected input[1, 1, 256, 256] to have 64 channels, but got 1 channels instead