In [1]:
# Here we take care of paths.
# Make sure root project directory is named 'VESUVIUS_Challenge' for this to work

from pathlib import Path
import os
PROJECT_ROOT_NAME = 'VESUVIUS_Challenge'


def find_project_root(start, root_name=PROJECT_ROOT_NAME):
    """Locate the project root at or above ``start``.

    The closest directory (``start`` itself first, then each ancestor) whose
    name ends with ``root_name`` is returned. The suffix match mirrors the
    original string comparison on the tail of the working directory.

    Args:
        start: Directory to start searching from (``str`` or ``Path``).
        root_name: Name (suffix) identifying the project root directory.

    Returns:
        The matching directory as a ``Path``. When nothing matches, the
        immediate parent of ``start`` is returned, which is what this cell
        did before it searched further up the tree.
    """
    start = Path(start)
    for candidate in (start, *start.parents):
        if candidate.name.endswith(root_name):
            return candidate
    # No directory with the expected name: keep the historical one-level-up fallback.
    return start.parent


print('Starting path:' + os.getcwd())

# Works from the root itself, from 'jupyter notebooks/', or from any deeper sub-folder.
PATH = find_project_root(Path().resolve())
os.chdir(PATH)

# make sure you are in the root folder of the project
print('Current path:' + os.getcwd())

Starting path:/Users/gregory/PROJECT_ML/VESUVIUS_Challenge/jupyter notebooks
Current path:/Users/gregory/PROJECT_ML/VESUVIUS_Challenge


In [2]:
import torch
import monai
from monai.visualize import matshow3d
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
import cv2
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from typing import Tuple, List
import albumentations as A
from albumentations.pytorch import ToTensorV2
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from Data_Modules.Vesuvius_Dataset import Vesuvius_Tile_Datamodule
from lit_models.Vesuvius_Lit_Model import Lit_Model
from pytorch_lightning.callbacks import ModelCheckpoint
import torch.nn as nn
import torchvision

from Models.PVT_model import PyramidVisionTransformerV2
from Models.Swin import SwinTransformer, SwinTransformerBlockV2, PatchMergingV2
import torch.nn as nn
from functools import partial
import timm


2023-05-16 16:56:05,087 - Created a temporary directory at /var/folders/wc/60y8v25x3ns_jgsx6clbdb180000gn/T/tmpv0l_x52i
2023-05-16 16:56:05,088 - Writing /var/folders/wc/60y8v25x3ns_jgsx6clbdb180000gn/T/tmpv0l_x52i/_remote_module_non_scriptable.py


In [3]:
# Dummy input batch: a tensor of torch.Size([8, 16, 256, 256]).
# The 16 channels and the 256x256 spatial size match the `in_chans=16` / `img_size=256`
# passed to the PVT built in the next cell.

x = torch.randn(8,16,256,256)

In [18]:
# Stand-alone PVT v2 backbone, built only to inspect the feature maps it returns.
# `patch_size` is deliberately not passed, so the model's own default applies.
# NOTE(review): the first stage uses 16 attention heads here, whereas PVT_w_SWIM
# uses 1 for the same stage — confirm which value is intended.
pvt = PyramidVisionTransformerV2(
    img_size=256,
    in_chans=16,
    num_classes=1,
    # Per-stage settings (four stages).
    embed_dims=[64, 128, 256, 512],
    num_heads=[16, 2, 4, 8],
    mlp_ratios=[4, 4, 4, 4],
    depths=[3, 4, 6, 3],
    sr_ratios=[8, 4, 2, 1],
    # Attention / regularisation settings.
    qkv_bias=True,
    qk_scale=None,
    drop_rate=0.0,
    attn_drop_rate=0.0,
    drop_path_rate=0.1,
    norm_layer=partial(nn.LayerNorm, eps=1e-6),
)




## PVT

In [19]:
# Forward the dummy batch through the PVT; the result is a sequence of feature maps
# (its length and the shape of each entry are printed in the next cell).
pvt_output = pvt(x)

In [6]:
# Report how many feature maps the PVT returned, then the size of each one.
print('pvt outputs', len(pvt_output), 'tensors')
for t in pvt_output:
    print(t.size())

pvt outputs 5 tensors
torch.Size([8, 16, 256, 256])
torch.Size([8, 64, 64, 64])
torch.Size([8, 128, 32, 32])
torch.Size([8, 256, 16, 16])
torch.Size([8, 512, 8, 8])


In [7]:
# Channel count (dim 1) of the second PVT feature map; 64 in the recorded run.
pvt_output[1].shape[1]

64

## SWIN

In [8]:
# Stand-alone SWIN (V2 blocks / V2 patch merging) with three stages.
# `patch_size`, `in_channels`, `weights` and `progress` are intentionally not passed,
# so whatever defaults SwinTransformer defines for them are used.
swin = SwinTransformer(
    embed_dim=96,
    depths=[2, 2, 4],
    num_heads=[2, 4, 8],
    window_size=[8, 8],
    stochastic_depth_prob=0.2,
    block=SwinTransformerBlockV2,
    downsample_layer=PatchMergingV2,
)

In [9]:
# Run every PVT feature map through the same SWIN instance, keeping the original order.
outputs = [swin(feature_map) for feature_map in pvt_output]

In [10]:
# One line per SWIN output, showing its size.
for out in outputs:
    print(out.size())

torch.Size([8, 384, 64, 64])
torch.Size([8, 384, 16, 16])
torch.Size([8, 384, 8, 8])
torch.Size([8, 384, 4, 4])
torch.Size([8, 384, 2, 2])


In [20]:
class SegmentationHead(nn.Module):
    """Fuse the raw input with upsampled SWIN feature maps into a one-channel map.

    Every feature map is upsampled by a transposed convolution whose kernel
    size equals its stride. The upsampled maps are concatenated with the raw
    input along the channel axis and reduced to one channel by a 1x1
    convolution.

    Args:
        in_channels: Channel count of the raw input (``swin_outputs[0]``).
        feature_channels: Channel count of every SWIN feature map.
        scale_factors: Upsampling factor for ``swin_outputs[1:]``, in order.
            One ``deconv{i}`` layer (numbered from 1) is created per entry.

    NOTE(review): with the default arguments the transposed convolutions alone
    hold about 3.2 billion weights
    (384 * 384 * (4**2 + 16**2 + 32**2 + 64**2 + 128**2)). Consider an
    interpolation-based upsampling if memory is a concern.
    """

    def __init__(self, in_channels=16, feature_channels=384,
                 scale_factors=(4, 16, 32, 64, 128)):
        super().__init__()
        self.num_levels = len(scale_factors)

        # Registered as deconv1, deconv2, ... so sub-module names stay the same
        # as when the five layers were written out by hand.
        for level, factor in enumerate(scale_factors, start=1):
            self.add_module(
                f"deconv{level}",
                nn.ConvTranspose2d(in_channels=feature_channels,
                                   out_channels=feature_channels,
                                   kernel_size=factor, stride=factor, padding=0),
            )

        # 1x1 convolution to reduce the number of channels to 1
        self.conv1x1 = nn.Conv2d(
            in_channels=in_channels + feature_channels * self.num_levels,
            out_channels=1, kernel_size=1)

    def forward(self, swin_outputs):
        """Upsample, concatenate and project.

        Args:
            swin_outputs: Sequence whose first entry is the raw input and whose
                following ``len(scale_factors)`` entries are the feature maps to
                upsample. Extra trailing entries are ignored.

        Returns:
            Tensor with a single channel at the spatial size of ``swin_outputs[0]``.

        Raises:
            ValueError: If fewer than ``1 + len(scale_factors)`` tensors are given.
        """
        required = self.num_levels + 1
        if len(swin_outputs) < required:
            raise ValueError(
                f"SegmentationHead expects at least {required} tensors, "
                f"got {len(swin_outputs)}")

        # Entry 0 is already at the target spatial size; entries 1.. are upsampled.
        upsampled = [getattr(self, f"deconv{level}")(swin_outputs[level])
                     for level in range(1, required)]

        # Channels after concatenation: in_channels + feature_channels * num_levels
        concatenated_output = torch.cat([swin_outputs[0], *upsampled], dim=1)

        return self.conv1x1(concatenated_output)


class PVT_w_SWIM(nn.Module):
    """PVT v2 backbone whose outputs each go through one shared SWIN, then a fusion head.

    Args:
        img_size: Forwarded to ``PyramidVisionTransformerV2`` as ``img_size``.
        in_channels: Channel count of the input; forwarded to the PVT as
            ``in_chans`` and to the head, which concatenates the raw input
            with the upsampled SWIN feature maps.
    """

    def __init__(self, img_size = 256,in_channels =16,  ):
        super().__init__()
        self.pvt = PyramidVisionTransformerV2(img_size=img_size,
                                              patch_size=4,
                                              in_chans=in_channels,
                                              num_classes=1,
                                              embed_dims=[64, 128, 256, 512],
                                              num_heads=[1, 2, 4, 8],
                                              mlp_ratios=[4, 4, 4, 4],
                                              qkv_bias=True,
                                              qk_scale=None,
                                              drop_rate=0.,
                                              attn_drop_rate=0.,
                                              drop_path_rate=0.1,
                                              norm_layer=partial(nn.LayerNorm, eps=1e-6),
                                              depths=[2, 2, 2, 2],
                                              sr_ratios=[8, 4, 2, 1]
                                              )

        self.swin = SwinTransformer(embed_dim=96,
                                    depths=[2, 2, 4],
                                    num_heads=[2, 4, 8],
                                    window_size=[8, 8],
                                    stochastic_depth_prob=0.2,
                                    block=SwinTransformerBlockV2,
                                    downsample_layer=PatchMergingV2,)

        # The head concatenates the raw input, so it must know the input channel count.
        self.head = SegmentationHead(in_channels=in_channels)

    def forward(self, x):
        """Return the head's single-channel output for the batch ``x``."""
        # The PVT returns a sequence of feature maps (5 in the recorded notebook run).
        pvt_outs = self.pvt(x)

        # Raw input first, then one SWIN output per PVT feature map:
        # 6 tensors in total when the PVT returns 5.
        swim_outs = [x]
        swim_outs.extend(self.swin(pvt_out) for pvt_out in pvt_outs)

        # Upsample, concatenate and project to a single channel.
        return self.head(swim_outs)


In [12]:
# Smoke test: draw a random batch, build the combined model with its default
# arguments, and run the batch through it once.
x = torch.randn(8,16,256,256)
model = PVT_w_SWIM()
final_outs = model(x)

In [13]:
# Shape of the model output; torch.Size([8, 1, 256, 256]) in the recorded run.
final_outs.shape

torch.Size([8, 1, 256, 256])

In [14]:
class SegmentationHead(nn.Module):
    """Fuse the raw input with upsampled SWIN feature maps into a one-channel map.

    This cell re-defines a class that is also defined in the cell containing
    ``PVT_w_SWIM``. Keep the two definitions in sync, or delete one of them.

    Every feature map is upsampled by a transposed convolution whose kernel
    size equals its stride. The upsampled maps are concatenated with the raw
    input along the channel axis and reduced to one channel by a 1x1
    convolution.

    Args:
        in_channels: Channel count of the raw input (``swin_outputs[0]``).
        feature_channels: Channel count of every SWIN feature map.
        scale_factors: Upsampling factor for ``swin_outputs[1:]``, in order.
            One ``deconv{i}`` layer (numbered from 1) is created per entry.

    NOTE(review): with the default arguments the transposed convolutions alone
    hold about 3.2 billion weights
    (384 * 384 * (4**2 + 16**2 + 32**2 + 64**2 + 128**2)). Consider an
    interpolation-based upsampling if memory is a concern.
    """

    def __init__(self, in_channels=16, feature_channels=384,
                 scale_factors=(4, 16, 32, 64, 128)):
        super().__init__()
        self.num_levels = len(scale_factors)

        # Registered as deconv1, deconv2, ... so sub-module names stay the same
        # as when the five layers were written out by hand.
        for level, factor in enumerate(scale_factors, start=1):
            self.add_module(
                f"deconv{level}",
                nn.ConvTranspose2d(in_channels=feature_channels,
                                   out_channels=feature_channels,
                                   kernel_size=factor, stride=factor, padding=0),
            )

        # 1x1 convolution to reduce the number of channels to 1
        self.conv1x1 = nn.Conv2d(
            in_channels=in_channels + feature_channels * self.num_levels,
            out_channels=1, kernel_size=1)

    def forward(self, swin_outputs):
        """Upsample, concatenate and project.

        Args:
            swin_outputs: Sequence whose first entry is the raw input and whose
                following ``len(scale_factors)`` entries are the feature maps to
                upsample. Extra trailing entries are ignored.

        Returns:
            Tensor with a single channel at the spatial size of ``swin_outputs[0]``.

        Raises:
            ValueError: If fewer than ``1 + len(scale_factors)`` tensors are given.
        """
        required = self.num_levels + 1
        if len(swin_outputs) < required:
            raise ValueError(
                f"SegmentationHead expects at least {required} tensors, "
                f"got {len(swin_outputs)}")

        # Entry 0 is already at the target spatial size; entries 1.. are upsampled.
        upsampled = [getattr(self, f"deconv{level}")(swin_outputs[level])
                     for level in range(1, required)]

        # Channels after concatenation: in_channels + feature_channels * num_levels
        concatenated_output = torch.cat([swin_outputs[0], *upsampled], dim=1)

        return self.conv1x1(concatenated_output)


RuntimeError: Given transposed=1, weight of size [384, 384, 4, 4], expected input[1, 1, 256, 256] to have 384 channels, but got 1 channels instead

In [None]:
# NOTE(review): `outs` is not assigned anywhere in the visible notebook, so this cell
# raises NameError on a fresh kernel. Possibly `final_outs` was intended — confirm or delete.
outs.shape

In [None]:
class HierarchicalSwinTransformer(nn.Module):
    """Apply one SWIN per input level and average the resized results.

    Args:
        num_levels: Number of independent ``SwinTransformer`` instances to build.
        **kwargs: Forwarded unchanged to every ``SwinTransformer`` constructor.
    """

    def __init__(self, num_levels, **kwargs):
        super().__init__()
        self.swins = nn.ModuleList([SwinTransformer(**kwargs) for _ in range(num_levels)])

    def forward(self, xs):
        """Run level ``i`` of ``xs`` through SWIN ``i`` and average the outputs.

        Every output is resized to the spatial size of the first output
        (``interpolate`` with its default mode) before averaging.

        Args:
            xs: Iterable with exactly one input tensor per SWIN.

        Returns:
            Element-wise mean of the resized outputs. This assumes all SWIN
            outputs share batch and channel dimensions — TODO confirm.

        Raises:
            ValueError: If ``xs`` is empty or its length differs from the
                number of SWIN modules.
        """
        xs = list(xs)
        if not xs or len(xs) != len(self.swins):
            raise ValueError(
                f"expected {len(self.swins)} input tensors (one per level), got {len(xs)}")

        outputs = [swin(x) for swin, x in zip(self.swins, xs)]

        # Upsample and average outputs
        max_h, max_w = outputs[0].shape[-2:]
        # `nn.functional` is used because no `F` alias is imported in this notebook.
        resized = [nn.functional.interpolate(out, size=(max_h, max_w)) for out in outputs]
        return torch.stack(resized, dim=0).mean(dim=0)

In [None]:
class PVTSwinSegmenter(nn.Module):
    """Sketch of a PVT -> SWIN -> 1x1-conv + sigmoid segmentation pipeline.

    Args:
        img_channels: Passed to the PVT constructor as ``img_channels``.
        mask_channels: Number of output channels of the final 1x1 convolution.

    NOTE(review): this cell has never been executed. The visible imports only
    provide ``PyramidVisionTransformerV2``, not ``PyramidVisionTransformer``.
    The ``img_channels`` / ``in_channels`` keyword arguments and the
    ``embed_dim`` attributes read below are unverified against the project's
    model classes. The PVT used earlier in this notebook returns several
    tensors rather than one — confirm before use.
    """

    def __init__(self, img_channels=64, mask_channels=1):
        super().__init__()

        backbone = PyramidVisionTransformer(img_channels=img_channels)
        self.pvt = backbone

        # Ensure input channels match PVT output
        refiner = SwinTransformer(in_channels=backbone.embed_dim)
        self.swin = refiner

        # Project SWIN features to the mask channels and squash into (0, 1).
        to_mask = nn.Conv2d(refiner.embed_dim, mask_channels, kernel_size=1)
        self.segmentation_head = nn.Sequential(to_mask, nn.Sigmoid())

    def forward(self, x):
        """Chain PVT, SWIN and the segmentation head on ``x``."""
        return self.segmentation_head(self.swin(self.pvt(x)))