In [1]:
# pip install nibabel monai

# Libraries

## Torch

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

## others

In [None]:
import os
import numpy as np
import nibabel as nib
from tqdm import tqdm
import matplotlib.pyplot as plt

# Model Architecture

The following code is done by Manpurwar Ganesh

## Swin Transformer Block

In [2]:

# Define the PatchPartition module

class PatchPartition(nn.Module):

    def __init__(self, in_channels=4, out_channels=48, patch_size=2):

        super(PatchPartition, self).__init__()

        self.patch_partition = nn.Conv3d(

            in_channels=in_channels,

            out_channels=out_channels,

            kernel_size=patch_size,

            stride=patch_size,

            padding=0

        )



    def forward(self, x):

        # Input shape: (B, H, W, D, C)

        x = x.permute(0, 4, 1, 2, 3)  # Convert to (B, C, H, W, D)

        x = self.patch_partition(x)   # Output shape: (B, 48, H/2, W/2, D/2)

        x = x.permute(0, 2, 3, 4, 1)  # Convert to (B, H/2, W/2, D/2, 48)

        return x



# Define the WindowMultiHeadSelfAttention module

class WindowMultiHeadSelfAttention(nn.Module):

    def __init__(self, dim, num_heads, window_size, mlp_dim):

        super(WindowMultiHeadSelfAttention, self).__init__()

        self.attention = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads)

        self.norm1 = nn.LayerNorm(dim)

        self.linear = nn.Linear(dim, dim)

        self.norm2 = nn.LayerNorm(dim)

        self.mlp = nn.Sequential(

            nn.Linear(dim, mlp_dim),

            nn.ReLU(),

            nn.Linear(mlp_dim, dim)

        )

        self.window_size = window_size



    def forward(self, x):

        B, H, W, D, C = x.shape

        x_norm = self.norm1(x)



        x_windowed = x_norm.view(B, H // self.window_size[0], self.window_size[0],

                                 W // self.window_size[1], self.window_size[1],

                                 D // self.window_size[2], self.window_size[2], C)

        x_windowed = x_windowed.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()

        x_windowed = x_windowed.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2], C)



        attn_output, _ = self.attention(x_windowed, x_windowed, x_windowed)

        attn_output = attn_output.view(B, H // self.window_size[0], W // self.window_size[1],

                                       D // self.window_size[2], self.window_size[0],

                                       self.window_size[1], self.window_size[2], C)

        attn_output = attn_output.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous()

        attn_output = attn_output.view(B, H, W, D, C)



        w_msa_output = x + attn_output

        x_linear = self.linear(w_msa_output)

        x_linear_norm = self.norm2(x_linear)

        mlp_output = self.mlp(x_linear_norm)



        output = w_msa_output + mlp_output

        return output



# Define the ShiftedWindowMultiHeadSelfAttention module

class ShiftedWindowMultiHeadSelfAttention(nn.Module):

    def __init__(self, dim, num_heads, window_size, shift_size, mlp_dim):

        super(ShiftedWindowMultiHeadSelfAttention, self).__init__()

        self.attention = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads)

        self.norm1 = nn.LayerNorm(dim)

        self.linear = nn.Linear(dim, dim)

        self.norm2 = nn.LayerNorm(dim)

        self.mlp = nn.Sequential(

            nn.Linear(dim, mlp_dim),

            nn.ReLU(),

            nn.Linear(mlp_dim, dim)

        )

        self.window_size = window_size

        self.shift_size = shift_size



    def forward(self, x):

        B, H, W, D, C = x.shape

        shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1], -self.shift_size[2]), dims=(1, 2, 3))

        x_norm = self.norm1(shifted_x)



        x_windowed = x_norm.view(B, H // self.window_size[0], self.window_size[0],

                                 W // self.window_size[1], self.window_size[1],

                                 D // self.window_size[2], self.window_size[2], C)

        x_windowed = x_windowed.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()

        x_windowed = x_windowed.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2], C)



        attn_output, _ = self.attention(x_windowed, x_windowed, x_windowed)

        attn_output = attn_output.view(B, H // self.window_size[0], W // self.window_size[1],

                                       D // self.window_size[2], self.window_size[0],

                                       self.window_size[1], self.window_size[2], C)

        attn_output = attn_output.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous()

        attn_output = attn_output.view(B, H, W, D, C)



        unshifted_attn_output = torch.roll(attn_output, shifts=(self.shift_size[0], self.shift_size[1], self.shift_size[2]), dims=(1, 2, 3))

        sw_msa_output = x + unshifted_attn_output

        x_linear = self.linear(sw_msa_output)

        x_linear_norm = self.norm2(x_linear)

        mlp_output = self.mlp(x_linear_norm)



        output = sw_msa_output + mlp_output

        return output



# Define the PatchMerging module

class PatchMerging(nn.Module):

    def __init__(self, dim):

        super(PatchMerging, self).__init__()

        self.reduction = nn.Linear(8 * dim, 2 * dim)



    def forward(self, x):

        B, H, W, D, C = x.shape

        x = x.view(B, H // 2, 2, W // 2, 2, D // 2, 2, C)

        x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()

        x = x.view(B, H // 2, W // 2, D // 2, 8 * C)

        x = self.reduction(x)

        return x  # Shape: (B, H/2, W/2, D/2, 2*C)

In [3]:
class SwinPipeline(nn.Module):

    def __init__(self, dim, num_heads, window_size, shift_size, mlp_dim):

        super(SwinPipeline, self).__init__()

        self.attention_module1 = WindowMultiHeadSelfAttention(dim=dim, num_heads=num_heads, window_size=window_size, mlp_dim=mlp_dim)

        self.attention_module2 = ShiftedWindowMultiHeadSelfAttention(dim=dim, num_heads=num_heads, window_size=window_size, shift_size=shift_size, mlp_dim=mlp_dim)

        self.patch_merging = PatchMerging(dim=dim)



    def forward(self, x):

        x = self.attention_module1(x)

        x = self.attention_module2(x)

        x = self.patch_merging(x)

        return x


In [4]:
B, H, W, D = 2, 64, 64, 64  # Example batch size and dimensions

input_tensor = torch.randn(B, H, W, D, 4)  # Input tensor with shape (B, H, W, D, 4)

patch_partition = PatchPartition(out_channels=48)

patch_output = patch_partition(input_tensor)

print("Output shape:", patch_output.shape)

Output shape: torch.Size([2, 32, 32, 32, 48])


In [5]:
input_tensor = patch_output



pipeline = SwinPipeline(dim=48, num_heads=4, window_size=(4, 4, 4), shift_size=(2, 2, 2), mlp_dim=96)

stage_1_output = pipeline(input_tensor)

print("Output shape of stage 1:", stage_1_output.shape)



# patch_out=dim=mlp_dim//2

input_tensor = stage_1_output



pipeline = SwinPipeline(dim=96, num_heads=4, window_size=(4, 4, 4), shift_size=(2, 2, 2), mlp_dim=192)

stage_2_output = pipeline(input_tensor)

print("Output shape of stage 2:", stage_2_output.shape)



input_tensor = stage_2_output



pipeline = SwinPipeline(dim=192, num_heads=4, window_size=(4, 4, 4), shift_size=(2, 2, 2), mlp_dim=384)

stage_3_output = pipeline(input_tensor)

print("Output shape of stage 3:", stage_3_output.shape)



input_tensor = stage_3_output



pipeline = SwinPipeline(dim=384, num_heads=4, window_size=(4, 4, 4), shift_size=(2, 2, 2), mlp_dim=768)

stage_4_output = pipeline(input_tensor)

print("Output shape of stage 4:", stage_4_output.shape)

Output shape of stage 1: torch.Size([2, 16, 16, 16, 96])
Output shape of stage 2: torch.Size([2, 8, 8, 8, 192])
Output shape of stage 3: torch.Size([2, 4, 4, 4, 384])
Output shape of stage 4: torch.Size([2, 2, 2, 2, 768])


In [6]:
patch_output = patch_output.permute(0,4,1,2,3)

patch_output.shape

torch.Size([2, 48, 32, 32, 32])

In [7]:
stage_4_output = stage_4_output.permute(0,4,1,2,3)

print(stage_4_output.shape)

torch.Size([2, 768, 2, 2, 2])


## stage 4

The rest of the code is written by Sai Aditya Kudipudi 

In [8]:



class BottleneckBlock(nn.Module):

    def __init__(self, in_channels, bottleneck_channels):

        super(BottleneckBlock, self).__init__()

        self.conv1 = nn.Conv3d(in_channels, bottleneck_channels, kernel_size=1)

        self.bn1 = nn.BatchNorm3d(bottleneck_channels)

        self.conv2 = nn.Conv3d(bottleneck_channels, bottleneck_channels, kernel_size=3, padding=1)

        self.bn2 = nn.BatchNorm3d(bottleneck_channels)

        self.conv3 = nn.Conv3d(bottleneck_channels, in_channels, kernel_size=1)

        self.bn3 = nn.BatchNorm3d(in_channels)

        self.relu = nn.ReLU(inplace=True)



    def forward(self, x):

        identity = x

        out = self.conv1(x)

        out = self.bn1(out)

        out = self.relu(out)



        out = self.conv2(out)

        out = self.bn2(out)

        out = self.relu(out)



        out = self.conv3(out)

        out = self.bn3(out)

        out += identity

        out = self.relu(out)



        return out



class ResNet3DBlock(nn.Module):

    def __init__(self, channels):

        super(ResNet3DBlock, self).__init__()

        self.conv1 = nn.Conv3d(channels, channels, kernel_size=3, padding=1)

        self.bn1 = nn.BatchNorm3d(channels)

        self.conv2 = nn.Conv3d(channels, channels, kernel_size=3, padding=1)

        self.bn2 = nn.BatchNorm3d(channels)



    def forward(self, x):

        identity = x

        out = self.conv1(x)

        out = self.bn1(out)

        out = F.relu(out)



        out = self.conv2(out)

        out = self.bn2(out)

        out += identity

        out = F.relu(out)

        return out



class FeatureExtractor(nn.Module):

    def __init__(self, in_channels, mid_channels):

        super(FeatureExtractor, self).__init__()

        # Reduce the number of channels from in_channels to mid_channels

        self.channel_reduction = nn.Conv3d(in_channels, mid_channels, kernel_size=1)

        # Use a 3D transpose convolution to upscale the spatial dimensions

        self.upsample = nn.ConvTranspose3d(mid_channels, mid_channels, kernel_size=2, stride=2)



    def forward(self, x):

        x = self.channel_reduction(x)  # Reduce channels

        x = self.upsample(x)  # Upsample spatial dimensions

        return x



class FinalPipeline(nn.Module):

    def __init__(self, in_channels=768, bottleneck_channels=384):

        super(FinalPipeline, self).__init__()

        self.bottleneck_block = BottleneckBlock(in_channels=in_channels, bottleneck_channels=bottleneck_channels)

        self.resnet_block = ResNet3DBlock(channels=in_channels)

        self.feature_extractor = FeatureExtractor(in_channels=in_channels, mid_channels=bottleneck_channels)



    def forward(self, x):

        # Apply bottleneck block

        x = self.bottleneck_block(x)

        # print("After BottleneckBlock:", x.shape)



        # Apply ResNet block

        x = self.resnet_block(x)

        # print("After ResNet3DBlock:", x.shape)



        # Apply feature extractor for channel reduction and upsampling

        x = self.feature_extractor(x)

        # print("After FeatureExtractor:", x.shape)



        return x



# Example usage:

H, W, D = 2, 2, 2  # Example input spatial dimensions for stage 4 output

batch_size = 2



# Create an example input tensor with shape (batch_size, 768, H, W, D)

stage_4_output = torch.randn(batch_size, 768, H, W, D)



# Instantiate and apply the pipeline

pipeline = FinalPipeline(in_channels=768, bottleneck_channels=384)

skip_connection_4 = pipeline(stage_4_output)



print("Final output shape:", skip_connection_4.shape)  # Expected: (batch_size, 384, H * 2, W * 2, D * 2)


Final output shape: torch.Size([2, 384, 4, 4, 4])


## Stage skip connections

In [9]:
stage_3_output = stage_3_output.permute(0,4,1,2,3)

print(stage_3_output.shape)

stage_2_output = stage_2_output.permute(0,4,1,2,3)

print(stage_2_output.shape)

stage_1_output = stage_1_output.permute(0,4,1,2,3)

print(stage_1_output.shape)

torch.Size([2, 384, 4, 4, 4])
torch.Size([2, 192, 8, 8, 8])
torch.Size([2, 96, 16, 16, 16])


In [10]:
class HiddenFeature_1(nn.Module):#used at end of stage 1 for converstion

    def __init__(self, in_channels, out_channels):

        super(HiddenFeature_1, self).__init__()



        # Upsample to double the spatial dimensions (H/2, W/2, D/2) -> (H, W, D)

        self.upsample = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)



        # Convolution to reduce channels from 4*C to C

        self.conv = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)



        # ReLU activation

        self.relu = nn.ReLU()



    def forward(self, x):

        # Upsample the spatial dimensions

        x = self.upsample(x)



        # Apply the convolution to reduce channels from 4*C to C

        x = self.conv(x)



        # Apply activation function

        x = self.relu(x)



        return x

In [11]:
class ResNet3DBlock(nn.Module):

    def __init__(self, channels):

        super(ResNet3DBlock, self).__init__()

        # First 3D convolutional layer with batch normalization and ReLU

        self.conv1 = nn.Conv3d(channels, channels, kernel_size=3, padding=1)

        self.bn1 = nn.BatchNorm3d(channels)

        # Second 3D convolutional layer with batch normalization

        self.conv2 = nn.Conv3d(channels, channels, kernel_size=3, padding=1)

        self.bn2 = nn.BatchNorm3d(channels)



    def forward(self, x):

        identity = x  # Save the input for the skip connection

        out = self.conv1(x)

        out = self.bn1(out)

        out = F.relu(out)



        out = self.conv2(out)

        out = self.bn2(out)



        # Add the identity (skip connection)

        out += identity

        out = F.relu(out)  # Final ReLU activation

        return out


In [12]:

class HiddenFeature(nn.Module):

    def __init__(self, channels):

        super(HiddenFeature, self).__init__()



        # A 3D convolution that preserves the input dimensions and channels

        self.conv = nn.Conv3d(in_channels=channels, out_channels=channels, kernel_size=3, padding=1)



        # ReLU activation

        self.relu = nn.ReLU()



    def forward(self, x):

        # Apply the convolution

        x = self.conv(x)



        # Apply the activation function

        x = self.relu(x)



        return x




In [13]:


class HiddenLayer_1(nn.Module):#

    def __init__(self, in_channels, out_channels):

        super(HiddenLayer_1, self).__init__()



        # Upsample to double the spatial dimensions (H/2, W/2, D/2) -> (H, W, D)

        self.upsample = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)



        # Convolution to reduce channels from 2*C to C

        self.conv = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)



        # ReLU activation

        self.relu = nn.ReLU()



    def forward(self, x):

        # Upsample the spatial dimensions

        x = self.upsample(x)



        # Apply the convolution to reduce channels from 2*C to C

        x = self.conv(x)



        # Apply activation function

        x = self.relu(x)



        return x




In [14]:
class HiddenLayer(nn.Module):

    def __init__(self, in_channels, out_channels):

        super(HiddenLayer, self).__init__()



        # Upsample spatial dimensions from (H/2, W/2, D/2) to (H, W, D)

        self.upsample = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=False)



        # 3D convolution to reduce channels from 2*C to C

        self.conv = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)



        # ReLU activation

        self.relu = nn.ReLU()



    def forward(self, x):

        # Upsample spatial dimensions

        x = self.upsample(x)



        # Reduce channels from 2*C to C

        x = self.conv(x)



        # Apply ReLU activation

        x = self.relu(x)



        return x

In [15]:
# Define a class that combines the operations in the example usage

class ModelPipeline(nn.Module):#used at stage 1,2,3

    def __init__(self, C):

        super(ModelPipeline, self).__init__()

        self.resnet_1 = ResNet3DBlock(channels=C)

        self.hidden_1 = HiddenFeature(channels=C)

        self.resnet_block_after_concat = ResNet3DBlock(channels=2 * C)

        self.hidden_feature_1 = HiddenFeature_1(in_channels=2 * C, out_channels=C // 2)



    def forward(self, x1, x2):

        hidden_1_output = self.hidden_1(x1)

        # print("Hidden feature output shape:", hidden_1_output.shape)



        resnet_1_output = self.resnet_1(hidden_1_output)

        # print("ResNet block 1 output shape:", resnet_1_output.shape)



        concatenated_output = torch.cat((resnet_1_output, x2), dim=1)

        # print("Concatenated output shape:", concatenated_output.shape)



        resnet_2_output = self.resnet_block_after_concat(concatenated_output)

        # print("ResNet block after concatenation output shape:", resnet_2_output.shape)



        output_tensor = self.hidden_feature_1(resnet_2_output)

        # print("Final output tensor shape:", output_tensor.shape)



        return output_tensor



# Example usage

H, W, D, C = 16, 16, 16, 96

x1 = torch.randn(1, C, H, W, D)

x2 = torch.randn(1, C, H, W, D)



model = ModelPipeline(C=C)

output = model(x1, x2)

In [16]:
class ModelPipeline_1(nn.Module):#used at stage 0

    def __init__(self, C):

        super(ModelPipeline_1, self).__init__()

        self.resnet_1 = ResNet3DBlock(channels=C)

        self.hidden_1 = HiddenFeature(channels=C)

        self.resnet_block_after_concat = ResNet3DBlock(channels=2 * C)

        self.hidden_layer_final = HiddenLayer(in_channels=2 * C, out_channels=C)



    def forward(self, x1, x2):



        hidden_1_output = self.hidden_1(x1)

        # Pass x1 through the first ResNet block

        resnet_1_output = self.resnet_1(hidden_1_output)



        # Concatenate the output of ResNet block with x2 along the channel dimension

        concatenated_output = torch.cat((hidden_1_output, x2), dim=1)



        # Pass the concatenated output through another ResNet block

        resnet_2_output = self.resnet_block_after_concat(concatenated_output)



        # Pass the output through the final HiddenLayer block to reduce channels

        final_output = self.hidden_layer_final(resnet_2_output)



        return final_output



H, W, D, C = 32, 32, 32, 48

x1 = torch.randn(1, C, H, W, D)

x2 = torch.randn(1, C, H, W, D)



model = ModelPipeline_1(C=C)



output_tensor = model(x1, x2)



# Print output shape

print("Output shape:", output_tensor.shape)

Output shape: torch.Size([1, 48, 64, 64, 64])


In [17]:
input_tensor = stage_3_output

skip_tensor = skip_connection_4

C = 384

# Create pipeline instance and forward pass

pipeline = ModelPipeline(C=C)

skip_connection_3 = pipeline(input_tensor, skip_tensor)



# Print final output shape

print("Final output shape:", skip_connection_3.shape)

Final output shape: torch.Size([2, 192, 8, 8, 8])


In [18]:
input_tensor = stage_2_output

skip_tensor = skip_connection_3

C = 192

# Create pipeline instance and forward pass

pipeline = ModelPipeline(C=C)

skip_connection_2 = pipeline(input_tensor, skip_tensor)



# Print final output shape

print("Final output shape:", skip_connection_2.shape)

Final output shape: torch.Size([2, 96, 16, 16, 16])


In [19]:
input_tensor = stage_1_output

skip_tensor = skip_connection_2

C = 96

# Create pipeline instance and forward pass

pipeline = ModelPipeline(C=C)

skip_connection_1 = pipeline(input_tensor, skip_tensor)



# Print final output shape

print("Final output shape:", skip_connection_1.shape)

Final output shape: torch.Size([2, 48, 32, 32, 32])


In [20]:
H, W, D, C = 32, 32, 32, 48

input_tensor = patch_output

skip_tensor = skip_connection_1



pipeline_0 = ModelPipeline_1(C=C)



skip_connection_0 = pipeline_0(input_tensor, skip_tensor)



# Print output shape

print("Output shape:", output_tensor.shape)

Output shape: torch.Size([1, 48, 64, 64, 64])


## Input Stage  pipeline

In [21]:
class FeaturePipeline(nn.Module):

    def __init__(self, in_channels=4, mid_channels=48, out_channels=48):

        super(FeaturePipeline, self).__init__()

        # Define the first 1x1 convolution to increase channels from in_channels to mid_channels

        self.conv1x1_increase = nn.Conv3d(in_channels=in_channels, out_channels=mid_channels, kernel_size=1)



        # Define a simple ResNet block with mid_channels as input/output

        self.resnet_block = ResNet3DBlock(channels=mid_channels)



        # Define the second 1x1 convolution to reduce channels from 2 * mid_channels to out_channels

        self.conv1x1_reduce = nn.Conv3d(in_channels=2 * mid_channels, out_channels=out_channels, kernel_size=1)



    def forward(self, input_tensor, skip_tensor):

        # Apply the first 1x1 convolution

        x = self.conv1x1_increase(input_tensor)

        # Apply the ResNet block

        x = self.resnet_block(x)

        # Concatenate with the skip tensor along the channel dimension

        x = torch.cat((x, skip_tensor), dim=1)

        # Apply the second 1x1 convolution

        x = self.conv1x1_reduce(x)

        return x



H, W, D, C = 64, 64, 64, 48

batch_size = 2

pipeline = FeaturePipeline(in_channels=4, mid_channels=48, out_channels=48)



input_tensor = torch.randn(batch_size, 4, H, W, D)

skip_tensor = skip_connection_0



output_tensor = pipeline(input_tensor, skip_tensor)

print("Output shape:", output_tensor.shape)

Output shape: torch.Size([2, 48, 64, 64, 64])


In [22]:
class HeadBlock(nn.Module):

    def __init__(self, in_channels=48, out_channels=3):

        super(HeadBlock, self).__init__()



        # First 3D convolution to refine the features

        self.conv1 = nn.Conv3d(in_channels, 64, kernel_size=3, padding=1)

        self.relu1 = nn.ReLU(inplace=True)



        # Second 3D convolution layer (optional for further refinement)

        self.conv2 = nn.Conv3d(64, 48, kernel_size=3, padding=1)

        self.relu2 = nn.ReLU(inplace=True)



        # Final 1x1x1 convolution to reduce to the desired output channels (3)

        self.conv3 = nn.Conv3d(48, out_channels, kernel_size=1)



        # Optionally, Softmax or Sigmoid can be applied on the output, depending on the task

        self.softmax = nn.Softmax(dim=1)  # For multi-class segmentation



    def forward(self, x):

        # Pass through the first conv layer and ReLU activation

        x = self.conv1(x)

        x = self.relu1(x)



        # Pass through the second conv layer and ReLU activation

        x = self.conv2(x)

        x = self.relu2(x)



        # Final 1x1x1 conv to project features to the desired output channels

        x = self.conv3(x)



        # Apply Softmax activation (for multi-class segmentation)

        x = self.softmax(x)



        return x


## Swin- UNETR Final Architecture

In [23]:
class SwinUnetr(nn.Module):

    def __init__(self):



      super(SwinUnetr, self).__init__()

      self.patch_partition = PatchPartition(out_channels = 48)



      # Define a simple ResNet block with mid_channels as input/output

      self.swin_stage_1= SwinPipeline(dim=48,num_heads=4, window_size=(4, 4, 4), shift_size=(2, 2, 2),mlp_dim=96)

      self.swin_stage_2= SwinPipeline(dim=96,num_heads=4, window_size=(4, 4, 4), shift_size=(2, 2, 2),mlp_dim=192)

      self.swin_stage_3= SwinPipeline(dim=192,num_heads=4, window_size=(4, 4, 4), shift_size=(2, 2, 2),mlp_dim=384)

      self.swin_stage_4= SwinPipeline(dim=384,num_heads=4, window_size=(4, 4, 4), shift_size=(2, 2, 2),mlp_dim=768)



      self.final_pipeline = FinalPipeline(in_channels=768, bottleneck_channels=384)

      self.intermediate_pipeline_3 = ModelPipeline(C=384)

      self.intermediate_pipeline_2 = ModelPipeline(C=192)

      self.intermediate_pipeline_1 = ModelPipeline(C=96)



      self.start_pipeline = ModelPipeline_1(C=48)



      self.feature_pipeline = FeaturePipeline(in_channels=4,mid_channels=48,out_channels=48)

      self.head_block = HeadBlock(in_channels=48, out_channels=3)











    def forward(self, input_tensor):





      B,H,W,D,C=input_tensor.shape

      patch_output = self.patch_partition(input_tensor)



      stage_1_output = self.swin_stage_1(patch_output)

      stage_2_output = self.swin_stage_2(stage_1_output)

      stage_3_output = self.swin_stage_3(stage_2_output)

      stage_4_output = self.swin_stage_4(stage_3_output)



      stage_4_output = stage_4_output.permute(0,4,1,2,3)

      stage_3_output = stage_3_output.permute(0,4,1,2,3)

      stage_2_output = stage_2_output.permute(0,4,1,2,3)

      stage_1_output = stage_1_output.permute(0,4,1,2,3)

      patch_output = patch_output.permute(0,4,1,2,3)



      skip_connection_4 = self.final_pipeline(stage_4_output)

      skip_connection_3 = self.intermediate_pipeline_3(stage_3_output,skip_connection_4)

      skip_connection_2 = self.intermediate_pipeline_2(stage_2_output,skip_connection_3)

      skip_connection_1 = self.intermediate_pipeline_1(stage_1_output,skip_connection_2)

      skip_connection_0 = self.start_pipeline(patch_output,skip_connection_1)

      input_tensor = input_tensor.permute(0,4,1,2,3)



      feature_output = self.feature_pipeline(input_tensor,skip_connection_0)

      segmentation_output = self.head_block(feature_output)









      return segmentation_output

In [24]:
swin_unetr = SwinUnetr()

input_tensor = torch.randn(2,64,64,64,4)

output_tensor = swin_unetr(input_tensor)

print(output_tensor.shape)

torch.Size([2, 3, 64, 64, 64])


# Dataset Preparation

This is from data reduction file by Kushwanth

In [25]:
class BRATS2021Dataset(Dataset):

    def __init__(self, root_dir, transform=None):

        """

        Args:

            root_dir (string): Directory with all patient folders containing .nii files.

            transform (callable, optional): Optional transform to apply on a sample.

        """

        self.root_dir = root_dir

        self.patients = sorted(os.listdir(root_dir))  # List of patient folders

        self.transform = transform



    def __len__(self):

        return len(self.patients)



    def __getitem__(self, idx):

        patient_dir = os.path.join(self.root_dir, self.patients[idx])



        # Define the modality filenames

        flair_path = os.path.join(patient_dir, f"{self.patients[idx]}_flair.nii")

        t1_path = os.path.join(patient_dir, f"{self.patients[idx]}_t1.nii")

        t1ce_path = os.path.join(patient_dir, f"{self.patients[idx]}_t1ce.nii")

        t2_path = os.path.join(patient_dir, f"{self.patients[idx]}_t2.nii")

        seg_path = os.path.join(patient_dir, f"{self.patients[idx]}_seg.nii")



        # Load modalities using nibabel

        flair = nib.load(flair_path).get_fdata(dtype=np.float32)

        t1 = nib.load(t1_path).get_fdata(dtype=np.float32)

        t1ce = nib.load(t1ce_path).get_fdata(dtype=np.float32)

        t2 = nib.load(t2_path).get_fdata(dtype=np.float32)

        



        # Stack modalities into a single tensor (e.g., shape: [4, H, W, D])

        modalities = torch.tensor(np.stack([flair, t1, t1ce, t2], axis=0))



        # Load segmentation

        seg = nib.load(seg_path).get_fdata(dtype=np.float32)

        segmentation = torch.tensor(seg, dtype=torch.long)



        # Apply transformations if any

        if self.transform:

            modalities = self.transform(modalities)

            segmentation = self.transform(segmentation)

        modalities = modalities.squeeze(dim=1)

        segmentation = segmentation.squeeze(dim=0)

        modalities = modalities.permute(1, 2, 3, 0)

        segmentation[segmentation == 4] = 2



        return modalities, segmentation



transform = transforms.Compose([

    transforms.Normalize(mean=[0.5, 0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5, 0.5])  # Normalize modalities

])


In [26]:
train_dir = "D:/swin-dataset/Processed_BraTS2021/train" 

train_dataset = BRATS2021Dataset(root_dir=train_dir)

train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)



test_dir = "D:/swin-dataset/Processed_BraTS2021/test"   

test_dataset = BRATS2021Dataset(root_dir=test_dir)

test_dataloader = DataLoader(test_dataset, batch_size=2, shuffle=True)



validation_dir = "D:swin-dataset/Processed_BraTS2021/val"     
validation_dataset = BRATS2021Dataset(root_dir=validation_dir)

validation_dataloader = DataLoader(validation_dataset, batch_size=2, shuffle=True)

In [27]:
train_img = next(iter(validation_dataloader))

In [28]:
len(train_img)

2

In [29]:
train_img[1].shape

torch.Size([2, 64, 64, 64])

In [30]:
torch.unique(train_img[1][0])

tensor([0, 1, 2])

# Loss Function

The rest is done by Aditya Kudupudi

In [31]:
import torch

import torch.nn as nn

import torch.nn.functional as F



class DiceLoss(nn.Module):

    def __init__(self, smooth=1e-6):

        super(DiceLoss, self).__init__()

        self.smooth = smooth



    def forward(self, preds, targets):

        """

        Computes the Dice Loss for multi-class segmentation.

        

        Args:

            preds (torch.Tensor): Predicted tensor of shape (B, C, H, W, D) with raw logits.

            targets (torch.Tensor): Ground truth tensor of shape (B, H, W, D) with class indices.

        

        Returns:

            torch.Tensor: The Dice Loss.

        """

        # Apply softmax to get class probabilities

        preds = F.softmax(preds, dim=1)  # Softmax along the channel axis

        

        # Get the number of classes from predictions

        num_classes = preds.shape[1]

        

        # Convert targets to one-hot encoding of shape (B, C, H, W, D)

        targets_one_hot = F.one_hot(targets, num_classes=num_classes)  # Shape: (B, H, W, D, C)

        targets_one_hot = targets_one_hot.permute(0, 4, 1, 2, 3).float()  # Shape: (B, C, H, W, D)

        

        # Compute Dice Loss per class

        dice_loss = 0

        for c in range(num_classes):

            pred_flat = preds[:, c].contiguous().view(-1)

            target_flat = targets_one_hot[:, c].contiguous().view(-1)

            intersection = (pred_flat * target_flat).sum()

            dice_score = (2. * intersection ) / (pred_flat.sum() + target_flat.sum() + self.smooth)

            dice_loss += 1 - dice_score



        # Average Dice Loss across all classes

        return dice_loss / num_classes


# Checkpoints setup

In [None]:
# Define checkpoint directory

# checkpoint_dir = 'D:\checkpoints'

# os.makedirs(checkpoint_dir, exist_ok=True)

In [32]:
import os

import re

import torch



def find_latest_checkpoint(checkpoint_dir):

    """

    Finds the checkpoint file with the highest epoch number in the given directory.

    Args:

        checkpoint_dir (str): Path to the directory containing checkpoint files.



    Returns:

        str: Path to the checkpoint file with the highest epoch.

    """

    # Regex to capture epoch numbers in filenames

    epoch_pattern = re.compile(r"model_epoch_(\d+)\.pth")

    

    latest_epoch = -1

    latest_checkpoint = None

    

    # Iterate over files in the directory

    for file_name in os.listdir(checkpoint_dir):

        match = epoch_pattern.search(file_name)

        if match:

            epoch = int(match.group(1))

            if epoch > latest_epoch:

                latest_epoch = epoch

                latest_checkpoint = os.path.join(checkpoint_dir, file_name)

    

    if latest_checkpoint is None:

        raise FileNotFoundError(f"No checkpoint files found in {checkpoint_dir}.")

    

    return latest_checkpoint, latest_epoch



def load_checkpoint(model, optimizer, checkpoint_path, device='cuda'):

    # Load the checkpoint

    checkpoint = torch.load(checkpoint_path, map_location=device)

    

    # Restore model state

    model.load_state_dict(checkpoint['model_state_dict'])

    

    # Restore optimizer state

    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    

    # Restore the starting epoch

    start_epoch = checkpoint['epoch']

    

    print(f"Checkpoint loaded. Resuming training from epoch {start_epoch}.")

    return start_epoch




In [33]:
print(f"PyTorch version: {torch.__version__}")

print(f"CUDA available: {torch.cuda.is_available()}")

print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'No GPU detected'}")


PyTorch version: 2.4.0
CUDA available: True
GPU: Tesla P100-PCIE-16GB


# Training Loop

In [34]:
# Initialize lists to store losses, learning rates, and epochs

train_losses = []

val_losses = []

learning_rates = []

epochs = []



# Define the device - use GPU if available, otherwise fall back to CPU

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



# Initialize the model and send it to the device

model = SwinUnetr().to(device)

dice_loss_fn = DiceLoss()

# Optimizer setup with initial learning rate 1e-2

optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [None]:
def train(model, train_loader, val_loader, num_epochs=10, device='cuda', resume_from_checkpoint=None):
    start_epoch = 0

    if resume_from_checkpoint:
        start_epoch = load_checkpoint(model, optimizer, resume_from_checkpoint, device)

    model.train()

    # Define the ReduceLROnPlateau scheduler
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3, verbose=True)

    for epoch in range(start_epoch, num_epochs):
        epoch_loss = 0.0
        epoch_lr = optimizer.param_groups[0]['lr']  # Get the initial learning rate for this epoch
        epochs.append(epoch + 1)  # Store the epoch number

        print(f"\nEpoch {epoch + 1}/{num_epochs} (Learning Rate: {epoch_lr})")

        # Loop over the batches in the train loader
        with tqdm(train_loader, desc=f"Training Epoch {epoch + 1}") as t:
            for batch_idx, (data, target) in enumerate(t):
                data, target = data.to(device), target.to(device)

                optimizer.zero_grad()  # Clear the gradients

                # Forward pass through the model
                output = model(data)

                # Compute the loss
                loss = dice_loss_fn(output, target)

                # Backward pass to compute gradients
                loss.backward()

                # Optimize the parameters (Adam optimizer)
                optimizer.step()

                # Track the loss and learning rate
                epoch_loss += loss.item()
                train_losses.append(loss.item())
                learning_rates.append(optimizer.param_groups[0]['lr'])

                # Update tqdm with loss info
                t.set_postfix(loss=loss.item(), lr=optimizer.param_groups[0]['lr'])

        # Print average epoch loss
        print(f"Epoch {epoch + 1} Loss: {epoch_loss / len(train_loader)}")

        # Optionally: Evaluate the model on the validation set
        val_loss = validate(model, val_loader, device)
        val_losses.append(val_loss)

        # Step the scheduler with the validation loss
        scheduler.step(val_loss)

        # Save checkpoint every 5 epochs
        if (epoch + 1) % 5 == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f'model_epoch_{epoch+1}.pth')
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss / len(train_loader),
                'val_loss': val_loss
            }, checkpoint_path)
            print(f"Checkpoint saved at {checkpoint_path}")

# Validation loop
def validate(model, val_loader, device='cuda'):
    model.eval()  # Set the model to evaluation mode

    val_loss = 0.0
    print("\nValidating...")
    with tqdm(val_loader, desc="Validation") as t:
        with torch.no_grad():  # No gradients during evaluation
            for data, target in t:
                data, target = data.to(device), target.to(device)
                output = model(data)
                loss = dice_loss_fn(output, target)
                val_loss += loss.item()

                # Update tqdm with loss info
                t.set_postfix(loss=loss.item())

    # Print the validation loss
    avg_val_loss = val_loss / len(val_loader)
    print(f"Validation Loss: {avg_val_loss}")

    model.train()  # Set the model back to training mode
    return avg_val_loss  # Return the validation loss for the scheduler

## Training 

In [None]:
## Directory containing checkpoints

checkpoint_dir = "path to dir"



# Find the latest checkpoint

# latest_checkpoint, latest_epoch = find_latest_checkpoint("/kaggle/working")



# Resume training

# print(f"Resuming training from checkpoint: {latest_checkpoint} (epoch {latest_epoch})")

# start_epoch = load_checkpoint(model, optimizer, latest_checkpoint, device=device)



train(model, train_dataloader, validation_dataloader, num_epochs=70,  resume_from_checkpoint=None,device=device)

# Testing

In [None]:
validate(model,test_dataloader,device,dice_loss_fn)

# plotting

In [None]:
# Plotting the losses and learning rates after training
def plot_metrics():

    # Plot training loss

    plt.figure(figsize=(12, 6))

    plt.subplot(1, 2, 1)

    plt.plot(train_losses, label='Training Loss', color='b')

    plt.title('Training Loss')

    plt.xlabel('Batch')

    plt.ylabel('Loss')

    plt.legend()



    # Plot validation loss

    plt.subplot(1, 2, 2)

    plt.plot(val_losses, label='Validation Loss', color='r')

    plt.title('Validation Loss')

    plt.xlabel('Epoch')

    plt.ylabel('Loss')

    plt.legend()



    plt.tight_layout()

    plt.show()




plot_metrics()
