In [3]:
# Here we take care of paths.
# Make sure root project directory is named 'VESUVIUS_Challenge' for this to work

from pathlib import Path
import os
print('Starting path:' + os.getcwd())
if os.getcwd()[-18:] == 'VESUVIUS_Challenge':
    pass
else:
    PATH = Path().resolve().parents[0]
    os.chdir(PATH)

# make sure you are in the root folder of the project
print('Current path:' + os.getcwd())

Starting path:/Users/gregory/PROJECT_ML/VESUVIUS_Challenge/jupyter notebooks
Current path:/Users/gregory/PROJECT_ML/VESUVIUS_Challenge


In [4]:
import torch
import monai
from monai.visualize import matshow3d
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
import cv2
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from typing import Tuple, List
import albumentations as A
from albumentations.pytorch import ToTensorV2
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from Data_Modules.Vesuvius_Dataset import Vesuvius_Tile_Datamodule
from lit_models.Vesuvius_Lit_Model import Lit_Model
from pytorch_lightning.callbacks import ModelCheckpoint
import torch.nn as nn
import torchvision

from Models.PVT_model import PyramidVisionTransformerV2
from Models.Swin import SwinTransformer, SwinTransformerBlockV2, PatchMergingV2
import torch.nn as nn
from functools import partial
import timm


In [12]:
# need a tensor torch.Size([8, 16, 224, 224])
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x = torch.randn(8,16,256,256)

In [7]:
pvt = PyramidVisionTransformerV2(img_size=256,
                                  patch_size=4,
                                  in_chans=16,
                                  num_classes=1,
                                  embed_dims=[32, 64, 128, 256,],
                                num_heads=[1, 2, 4, 8],
                                  mlp_ratios=[4, 4, 4, 4],
                                  qkv_bias=True,
                                  qk_scale=None,
                                  drop_rate=0.,
                                attn_drop_rate=0.,
                                  drop_path_rate=0.1,
                                  norm_layer=partial(nn.LayerNorm, eps=1e-6),
                                  depths=[2, 2, 2, 2],
                                  sr_ratios=[8, 4, 2, 1]
                                 )




# PVT

In [8]:
pvt_output = pvt(x)

In [9]:
print('pvt outputs', len(pvt_output), 'tensors')
for t in pvt_output:
    print(t.shape)

pvt outputs 5 tensors
torch.Size([8, 16, 256, 256])
torch.Size([8, 32, 64, 64])
torch.Size([8, 64, 32, 32])
torch.Size([8, 128, 16, 16])
torch.Size([8, 256, 8, 8])


In [7]:
pvt_output[1].shape[1]

64

## SWIN

In [21]:
swin = SwinTransformer(
         #patch_size=[4, 4],
        embed_dim=96,
        depths=[2, 2, 4, ],
        num_heads=[2, 4, 8, ],
        window_size=[8, 8],
        #in_channels = 128,
        stochastic_depth_prob=0.2,
        #weights=weights,
        #progress=progress,
        block=SwinTransformerBlockV2,
        downsample_layer=PatchMergingV2,)

In [22]:
outputs = []
for out in pvt_output:
    swin_out = swin(out)
    outputs.append(swin_out)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (524288x16 and 96x288)

In [None]:
for out in outputs:
    print(out.shape)

In [15]:
class PVT_w_SWIM(nn.Module):
    def __init__(self, img_size = 256,in_channels =16, embed_dim =96  ):
        
        
        super().__init__()
        
        self.embed_dim =embed_dim
        self.pvt =  PyramidVisionTransformerV2(img_size=img_size,
                                  patch_size=4,
                                  in_chans=in_channels,
                                  num_classes=1,
                                  embed_dims=[32, 64, 128, 256,],
                                num_heads=[1, 2, 4, 8],
                                  mlp_ratios=[4, 4, 4, 4],
                                  qkv_bias=True,
                                  qk_scale=None,
                                  drop_rate=0.,
                                attn_drop_rate=0.,
                                  drop_path_rate=0.1,
                                  norm_layer=partial(nn.LayerNorm, eps=1e-6),
                                  depths=[2, 2, 2, 2],
                                  sr_ratios=[8, 4, 2, 1]
                                 ).to(DEVICE)
        
        self.swin = SwinTransformer(#patch_size=[4, 4],
                                    embed_dim=self.embed_dim,
                                    depths=[2, 2, 2, ],
                                    num_heads=[2, 4, 8, ],
                                    window_size=[8, 8],
                                    #in_channels = 128,
                                    stochastic_depth_prob=0.2,
                                    #weights=weights,
                                    #progress=progress,
                                    block=SwinTransformerBlockV2,
                                    downsample_layer=PatchMergingV2,).to(DEVICE)
        
        
        # for in channels out outchannles of PVT
        self.conv = [nn.Conv2d(in_channels, self.embed_dim, kernel_size=1, stride=1, padding=0).to(DEVICE),
                     nn.Conv2d(in_channels*2, self.embed_dim, kernel_size=1, stride=1, padding=0).to(DEVICE),
                     nn.Conv2d(in_channels*4, self.embed_dim, kernel_size=1, stride=1, padding=0).to(DEVICE),
                     nn.Conv2d(in_channels*8, self.embed_dim, kernel_size=1, stride=1, padding=0).to(DEVICE),
                     nn.Conv2d(in_channels*16, self.embed_dim, kernel_size=1, stride=1, padding=0).to(DEVICE),
                    ]
        
        
        self.head = SegmentationHead().to(DEVICE)
        

        
    def forward(self, x):
        # pass through PVT
        pvt_outs = self.pvt(x) # outputs 5 tensors
        
        
        # we run each pvt output thru SWIM
        # SWIM will outputs 5 tensors 
        swim_outs = []
        swim_outs.append(x)
        for conv, pvt_out in zip(self.conv, pvt_outs):
            pvt_conv = conv(pvt_out)
            swim_out = self.swin(pvt_conv)
            swim_outs.append(swim_out)
            
        # UPsample and use Segmentation head
        final_outs = self.head(swim_outs)
        
        
        
        return final_outs
        
        

class SegmentationHead(nn.Module):
    def __init__(self,in_channels=1936, out_channels=1):
        super().__init__()
        self.conv1x1 = nn.Conv2d(in_channels,1, kernel_size=1).to(DEVICE)
        

    def forward(self, swin_outputs):
        # Upsample
        upsampled_outputs = [torch.nn.functional.interpolate(out, size=(256, 256), mode='bilinear', align_corners=False) for out in swin_outputs]
        # Concatenate
        concatenated_output = torch.cat(upsampled_outputs, dim=1)  # Shape: (batch_size, N, 256, 256)
        
        # 1x1 convolution
        final_output = self.conv1x1(concatenated_output)  # Shape: (batch_size, 1, 256, 256)
        return final_output


In [16]:
x = torch.randn(8,16,256,256)
model = PVT_w_SWIM()
final_outs = model(x)

RuntimeError: Given groups=1, weight of size [1, 1920, 1, 1], expected input[8, 976, 256, 256] to have 1920 channels, but got 976 channels instead

In [21]:
final_outs.shape

torch.Size([8, 1, 256, 256])

NameError: name 'self' is not defined

In [14]:
class SegmentationHead(nn.Module):
    def __init__(self, out_channels=1):
        super().__init__()
        #self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, swin_outputs):
        # Upsample
        upsampled_outputs = [torch.nn.functional.interpolate(out, size=(256, 256), mode='bilinear', align_corners=False) for out in swin_outputs]
        # Concatenate
        concatenated_output = torch.cat(upsampled_outputs, dim=1)  # Shape: (batch_size, N, 256, 256)
        # 1x1 convolution
        #final_output = self.conv1x1(concatenated_output)  # Shape: (batch_size, 1, 256, 256)
        return concatenated_output

In [15]:
head = SegmentationHead()
final = head(outputs)

In [16]:
final.shape

torch.Size([8, 1920, 256, 256])

In [17]:
class HierarchicalSwinTransformer(nn.Module):
    def __init__(self, num_levels, **kwargs):
        super().__init__()
        self.swins = nn.ModuleList([SwinTransformer(**kwargs) for _ in range(num_levels)])

    def forward(self, xs):
        outputs = [swin(x) for swin, x in zip(self.swins, xs)]
        # Upsample and average outputs
        max_h, max_w = outputs[0].shape[-2:]
        outputs = [F.interpolate(out, size=(max_h, max_w)) for out in outputs]

In [18]:
class PVTSwinSegmenter(nn.Module):
    def __init__(self, img_channels=64, mask_channels=1):
        super().__init__()
        self.pvt = PyramidVisionTransformer(img_channels=img_channels)
        self.swin = SwinTransformer(in_channels=self.pvt.embed_dim)  # Ensure input channels match PVT output
        self.segmentation_head = nn.Sequential(
            nn.Conv2d(self.swin.embed_dim, mask_channels, kernel_size=1), 
            nn.Sigmoid()
        )
        
    def forward(self, x):
        x = self.pvt(x)
        x = self.swin(x)
        x = self.segmentation_head(x)
        return x