In [1]:
# Here we take care of paths.
# The notebook has to run from the project root so that the project-local
# packages imported further down can be found.
# Make sure root project directory is named 'VESUVIUS_Challenge' for this to work

from pathlib import Path
import os

PROJECT_ROOT_NAME = 'VESUVIUS_Challenge'

print('Starting path:' + os.getcwd())
start_dir = Path.cwd()

# Compare complete directory names (not a fixed-length string suffix) and look
# at every ancestor, so the cell works from any depth inside the project and
# re-running it never climbs above the project root.
project_root = next(
    (candidate for candidate in (start_dir, *start_dir.parents)
     if candidate.name == PROJECT_ROOT_NAME),
    None,
)

if project_root is None:
    # No directory with the expected name was found: keep the previous
    # behaviour of moving up exactly one level, but say so explicitly.
    # `.parent` (unlike `.parents[0]`) is also safe at the filesystem root.
    project_root = start_dir.resolve().parent
    print("Warning: no '" + PROJECT_ROOT_NAME + "' directory in the current path; "
          "falling back to the parent directory")

# PATH is kept as a module-level name for backward compatibility.
PATH = project_root
os.chdir(PATH)

# make sure you are in the root folder of the project
print('Current path:' + os.getcwd())

Starting path:/Users/gregory/PROJECT_ML/VESUVIUS_Challenge/jupyter notebooks
Current path:/Users/gregory/PROJECT_ML/VESUVIUS_Challenge


In [2]:
import torch
import monai
from monai.visualize import matshow3d
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
import cv2
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from typing import Tuple, List
import albumentations as A
from albumentations.pytorch import ToTensorV2
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from Data_Modules.Vesuvius_Dataset import Vesuvius_Tile_Datamodule
from lit_models.Vesuvius_Lit_Model import Lit_Model
from pytorch_lightning.callbacks import ModelCheckpoint
import torch.nn as nn
import torchvision

from Models.PVT_model import PyramidVisionTransformerV2
from Models.Swin import SwinTransformer, SwinTransformerBlockV2, PatchMergingV2
import torch.nn as nn
from functools import partial
import timm


2023-05-16 13:53:34,739 - Created a temporary directory at /var/folders/wc/60y8v25x3ns_jgsx6clbdb180000gn/T/tmp9tgvv5t4
2023-05-16 13:53:34,740 - Writing /var/folders/wc/60y8v25x3ns_jgsx6clbdb180000gn/T/tmp9tgvv5t4/_remote_module_non_scriptable.py


In [3]:
# Random dummy batch used only to inspect output shapes:
# 8 samples x 16 channels x 256 x 256 pixels (16 and 256 match the
# `in_chans` and `img_size` passed to the PVT in the next cell).
dummy_batch_shape = (8, 16, 256, 256)
x = torch.randn(dummy_batch_shape)

In [4]:
# Pyramid Vision Transformer v2 backbone, configured for the dummy batch
# above (16 input channels, 256 x 256 images). All four per-stage lists
# (embed_dims, num_heads, mlp_ratios, depths, sr_ratios) have one entry per stage.
pvt_settings = dict(
    img_size=256,
    patch_size=4,
    in_chans=16,
    num_classes=1,
    embed_dims=[64, 128, 256, 512],
    num_heads=[16, 2, 4, 8],
    mlp_ratios=[4, 4, 4, 4],
    qkv_bias=True,
    qk_scale=None,
    drop_rate=0.,
    attn_drop_rate=0.,
    drop_path_rate=0.1,
    norm_layer=partial(nn.LayerNorm, eps=1e-6),
    depths=[3, 4, 6, 3],
    sr_ratios=[8, 4, 2, 1],
)
pvt = PyramidVisionTransformerV2(**pvt_settings)




# PVT

In [5]:
# Run the dummy batch through the PVT backbone; the result is indexed and
# iterated as a sequence of per-stage feature maps in the following cells.
pvt_output = pvt(x)

In [6]:
# Summarise the backbone result: how many feature maps came back, then the
# shape of each one on its own line.
stage_count = len(pvt_output)
print('pvt outputs', stage_count, 'tensors')
for feature_map in pvt_output:
    print(feature_map.shape)

pvt outputs 4 tensors
torch.Size([8, 64, 64, 64])
torch.Size([8, 128, 32, 32])
torch.Size([8, 256, 16, 16])
torch.Size([8, 512, 8, 8])


In [7]:
# Size of dimension 1 of the first PVT feature map; the same value is passed
# as `embed_dim` when the Swin model is built below.
pvt_output[0].size(1)

64

## SWIN

In [8]:
# Swin model whose embedding width is taken from dimension 1 of the first
# PVT feature map, so that it can be fed `pvt_output[0]`.
# `in_channels`, `weights` and `progress` are deliberately not passed.
first_stage_width = pvt_output[0].shape[1]
swin_settings = dict(
    embed_dim=first_stage_width,
    depths=[2, 2, 6, 2],
    num_heads=[4, 8, 16, 32],
    window_size=[8, 8],
    stochastic_depth_prob=0.2,
    block=SwinTransformerBlockV2,
    downsample_layer=PatchMergingV2,
)
swin = SwinTransformer(**swin_settings)

In [9]:
# Feed only the first PVT feature map into the Swin model.
# NOTE(review): the recorded output of this cell ends in KeyboardInterrupt,
# so `swin_output` was never assigned in the saved run.
swin_output= swin(pvt_output[0])

torch.Size([8, 64, 64, 64])
torch.Size([8, 256, 128, 128])


KeyboardInterrupt: 

In [None]:
# Show the shape of the Swin result; this needs the previous cell to have
# run to completion so that `swin_output` exists.
print(swin_output.shape)

In [None]:
class HierarchicalSwinTransformer(nn.Module):
    """Apply one SwinTransformer per feature level and fuse the results by averaging.

    Args:
        num_levels: number of feature levels; one independent SwinTransformer
            is created for each of them.
        **kwargs: forwarded unchanged to every SwinTransformer constructor.
    """

    def __init__(self, num_levels, **kwargs):
        super().__init__()
        self.swins = nn.ModuleList([SwinTransformer(**kwargs) for _ in range(num_levels)])

    def forward(self, xs):
        """Process each level with its own Swin model and return the averaged result.

        Args:
            xs: sequence of input tensors, one per level, in the same order
                as the Swin models were created.

        Returns:
            The element-wise mean of all per-level outputs after resizing them
            to the spatial size of the first level's output.

        Raises:
            ValueError: if the number of inputs differs from the number of
                levels, or if there are no levels at all.
        """
        xs = list(xs)
        # zip() would silently drop surplus inputs or models, so check first.
        if len(xs) != len(self.swins):
            raise ValueError(
                f"expected {len(self.swins)} inputs (one per level), got {len(xs)}"
            )
        if not xs:
            raise ValueError("at least one level is required")

        outputs = [swin(x) for swin, x in zip(self.swins, xs)]
        # Upsample and average outputs
        # NOTE(review): assumes every Swin output is a (B, C, H, W) tensor with
        # the same B and C, and that the first level has the largest H and W -
        # confirm against Models.Swin.
        max_h, max_w = outputs[0].shape[-2:]
        # `nn.functional` is used because `F` is not imported in this notebook.
        outputs = [nn.functional.interpolate(out, size=(max_h, max_w)) for out in outputs]
        return torch.stack(outputs, dim=0).mean(dim=0)

In [None]:
class PVTSwinSegmenter(nn.Module):
    """Segmentation model: PVT v2 backbone -> Swin model -> 1x1 conv + sigmoid head.

    The default sub-models use the same constructor arguments as the
    stand-alone `pvt` and `swin` objects built earlier in this notebook,
    except that the PVT's `in_chans` is taken from `img_channels`.

    Args:
        img_channels: number of channels of the input image.
        mask_channels: number of channels of the predicted mask.
        pvt: optional pre-built backbone. It must return a sequence of
            feature maps; only the first one is used. Built with the default
            configuration when omitted.
        swin: optional pre-built module applied to the first backbone
            feature map. Built with the default configuration when omitted.
    """

    # Per-stage widths of the default backbone; the first entry is also the
    # `embed_dim` of the default Swin model, so the two fit together.
    PVT_EMBED_DIMS = (64, 128, 256, 512)

    def __init__(self, img_channels=64, mask_channels=1, pvt=None, swin=None):
        super().__init__()
        if pvt is None:
            # Only PyramidVisionTransformerV2 is imported in this notebook, and
            # its input-channel argument is called `in_chans`.
            pvt = PyramidVisionTransformerV2(
                img_size=256,
                patch_size=4,
                in_chans=img_channels,
                num_classes=1,
                embed_dims=list(self.PVT_EMBED_DIMS),
                num_heads=[16, 2, 4, 8],
                mlp_ratios=[4, 4, 4, 4],
                qkv_bias=True,
                qk_scale=None,
                drop_rate=0.,
                attn_drop_rate=0.,
                drop_path_rate=0.1,
                norm_layer=partial(nn.LayerNorm, eps=1e-6),
                depths=[3, 4, 6, 3],
                sr_ratios=[8, 4, 2, 1],
            )
        if swin is None:
            # NOTE(review): the recorded forward pass of this Swin configuration
            # was interrupted before finishing, so it is unverified end to end.
            swin = SwinTransformer(
                embed_dim=self.PVT_EMBED_DIMS[0],  # Ensure input channels match PVT output
                depths=[2, 2, 6, 2],
                num_heads=[4, 8, 16, 32],
                window_size=[8, 8],
                stochastic_depth_prob=0.2,
                block=SwinTransformerBlockV2,
                downsample_layer=PatchMergingV2,
            )
        self.pvt = pvt
        self.swin = swin
        # LazyConv2d infers its input channels on the first forward pass, so no
        # assumption about the Swin output width is needed here. Run one
        # forward pass before creating an optimizer or loading a state dict.
        # NOTE(review): presumes the Swin output is a (B, C, H, W) tensor - confirm.
        self.segmentation_head = nn.Sequential(
            nn.LazyConv2d(mask_channels, kernel_size=1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return per-pixel mask probabilities in [0, 1] for the input batch `x`."""
        features = self.pvt(x)
        # The backbone returns one feature map per stage; the Swin model
        # consumes only the first of them.
        x = self.swin(features[0])
        x = self.segmentation_head(x)
        return x