In [1]:
# Here we take care of paths.
# Make sure root project directory is named 'VESUVIUS_Challenge' for this to work

from pathlib import Path
import os
print('Starting path:' + os.getcwd())
if os.getcwd()[-18:] == 'VESUVIUS_Challenge':
    pass
else:
    PATH = Path().resolve().parents[0]
    os.chdir(PATH)

# make sure you are in the root folder of the project
print('Current path:' + os.getcwd())

Starting path:/Users/gregory/PROJECT_ML/VESUVIUS_Challenge/jupyter notebooks
Current path:/Users/gregory/PROJECT_ML/VESUVIUS_Challenge


In [134]:
import torch
import monai
#from monai.visualize import matshow3d
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
import cv2
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from typing import Tuple, List
import albumentations as A
from albumentations.pytorch import ToTensorV2
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from Data_Modules.Vesuvius_Dataset import Vesuvius_Tile_Datamodule
from lit_models.Vesuvius_Lit_Model import Lit_Model
from pytorch_lightning.callbacks import ModelCheckpoint
import torch.nn as nn
from Models.PVT2 import PyramidVisionTransformerV2, Up, OutConv
import torch.nn as nn
from functools import partial
import torchvision
import torch.nn.functional as F
from Models.Swin import SwinTransformer, SwinTransformerBlockV2, PatchMergingV2
from lit_models.scratch_models import FPNDecoder
from Models.PreBackbone_3D import PreBackbone_3D


In [3]:
PATCH_SIZE = 256
Z_DIM = 8
COMPETITION_DATA_DIR_str =  "kaggle/input/vesuvius-challenge-ink-detection/"


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "mps")

# change to the line below if not using Apple's M1 or chips
# DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 3d Convolutions

#### to decrease z_dim/channels

In [129]:
class PreBackbone_3D(nn.Module):
    def __init__(self, batch_norm = True ):
        
        
        
        super(PreBackbone_3D, self).__init__()
        
        self.leaky_relu = nn.LeakyReLU( inplace=True)
        self.batch_norm = batch_norm
        
        self.conv = nn.Conv3d(in_channels=1,
                            out_channels=1,
                             kernel_size = (3, 1, 1),
                             stride=(1, 1, 1),
                             padding= (1, 0, 0)
                             )
        
        torch.nn.init.xavier_uniform_(self.conv.weight)
        torch.nn.init.zeros_(self.conv.bias)
         
        self.pool = nn.AvgPool3d(kernel_size = (2,1,1), stride=(2,1,1))
        self.batch_norm = torch.nn.BatchNorm3d( num_features=1, momentum=0.9)
        
        
        
        
    

    def forward(self, x):
        x = x.unsqueeze(1) # (B,C,H,W) -> (B, 1 C, H, W)
    
        
        y = self.conv(x)
        y = self.pool(y)
        y = self.leaky_relu(y)
        if self.batch_norm:
            y = self.batch_norm(y)  # (B, 1 C, H, W) -> (B, 1 C/2, H, W)
            
       
        y = self.conv(y)
        y = self.pool(y)
        y = self.leaky_relu(y)
        if self.batch_norm:
            y = self.batch_norm(y)  # (B, 1 C, H, W) -> (B, 1 C/2, H, W)
            
            
        y = self.conv(y)
        y = self.pool(y)
        y = self.leaky_relu(y)
        if self.batch_norm:
            y = self.batch_norm(y)  # (B, 1 C, H, W) -> (B, 1 C/2, H, W)
            
        y = self.conv(y)
        y = self.pool(y)
        y = self.leaky_relu(y)
        if self.batch_norm:
            y = self.batch_norm(y)  # (B, 1 C, H, W) -> (B, 1 C/2, H, W)
            
        return y
      


In [133]:
dummy = torch.randn(5,64,256,256)
pre_model = PreBackbone_3D()
pre_out = pre_model(dummy)
print(pre_out.shape)

torch.Size([5, 1, 4, 256, 256])


In [4]:
class PVT_w_FPN(nn.Module):
    def __init__(self, in_channels,  embed_dims=[  64, 128, 256, 512], n_classes=1, ):
        super().__init__()
        
        self.embed_dims = embed_dims
        
       
        self.pvt = PyramidVisionTransformerV2(img_size = PATCH_SIZE,
                                  patch_size = 4,
                                  in_chans = Z_DIM,
                                  num_classes = 1,
                                  embed_dims = embed_dims,
                                num_heads=[1, 2, 4, 8],
                                  mlp_ratios=[4, 4, 4, 4],
                                  qkv_bias=True,
                                  qk_scale=None,
                                  drop_rate=0.,
                                attn_drop_rate=0.,
                                  drop_path_rate=0.1,
                                  norm_layer=partial(nn.LayerNorm, eps=1e-3),
                                #norm_layer=nn.LayerNorm,          
                                  depths=[2, 2, 2,2],
                                  sr_ratios=[1, 1, 1, 1]
                                 ).to(DEVICE) 
        
        self.FPN = FPNDecoder(
                            in_channels = Z_DIM,
                            encoder_channels = embed_dims ,
                            encoder_depth=5,
                            pyramid_channels=256,
                            segmentation_channels=128,
                            dropout=0.2,
                            merge_policy="cat",).to(DEVICE) 
        
       

    def forward(self, x):
        #x = x.unsqueeze(1)
        #x = self.pre_model3d(x)
        #x = x.squeeze(1)
        
        pvt_outs = self.pvt(x)
        
        logits = self.FPN(*pvt_outs)
        
       
       
            
        
        return logits












dummy = torch.randn(5,8,256,256).to(DEVICE) 
model = PVT_w_FPN(in_channels =8 ,  embed_dims=[ 64, 128, 256, 512])
out = model(dummy)
print(out.shape)

# BACKBONE

In [5]:
model = PVT_w_FPN(in_channels =8 ,  embed_dims=[ 64, 128, 256, 512])
dummy = torch.randn(5,8,256,256).to(DEVICE) 
pvt_outs = model.pvt(dummy)
print('pvt outputs')
for t in pvt_outs:
    print(t.shape)

pvt outputs
torch.Size([5, 8, 256, 256])
torch.Size([5, 64, 64, 64])
torch.Size([5, 128, 32, 32])
torch.Size([5, 256, 16, 16])
torch.Size([5, 512, 8, 8])


# NECK FPN

In [6]:
fpn_outs = model.FPN(*pvt_outs)
print(fpn_outs.shape)

torch.Size([5, 1, 256, 256])


In [7]:
from lit_models.Vesuvius_Lit_Model import dice_coef_torch

In [8]:
dummy_y = torch.ones(5,1,256,256)
loss = dice_coef_torch(fpn_outs, dummy_y)

In [9]:
loss

tensor(0.1686, device='mps:0', grad_fn=<RsubBackward1>)

# MLP HEAD

In [10]:
class MLP(nn.Module):
    """
    Linear Embedding
    """
    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x



class SegFormerHead(nn.Module):
    """
    SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers
    """
    def __init__(self, z_dim, in_channels, embedding_dim, dropout= 0, feature_strides=None, **kwargs):
        super().__init__()
        self.in_channels = in_channels
        self.feature_strides = feature_strides
        self.num_classes = 1
        self.dropout = dropout


        #decoder_params = kwargs['decoder_params']
        self.embedding_dim = embedding_dim

        self.linear_c4 = MLP(input_dim=self.in_channels[-1], embed_dim=self.embedding_dim)
        self.linear_c3 = MLP(input_dim=self.in_channels[-2], embed_dim=self.embedding_dim)
        self.linear_c2 = MLP(input_dim=self.in_channels[-3], embed_dim=self.embedding_dim)
        self.linear_c1 = MLP(input_dim=self.in_channels[-4], embed_dim=self.embedding_dim)
        self.linear_c0 = MLP(input_dim=z_dim, embed_dim=self.embedding_dim)

        self.conv_fuse = nn.Sequential(
                    nn.ConvTranspose2d(
                        embedding_dim*5, embedding_dim, kernel_size=1, stride=1),
                    torch.nn.SyncBatchNorm(embedding_dim, eps=1e-04, momentum=0.1),
                    #nn.GroupNorm(32, segmentation_channels, eps=1e-03),
                    nn.GELU(),
                    nn.ConvTranspose2d(
                        embedding_dim, embedding_dim, kernel_size=1, stride=1),
                ).to(DEVICE)

        self.linear_pred = nn.Conv2d(embedding_dim, self.num_classes, kernel_size=1)
        self.dropout = nn.Dropout2d(p=self.dropout, inplace=True)

    def forward(self, *features):
        #x = self._transform_inputs(inputs)  # len=4, 1/4,1/8,1/16,1/32
        c0, c1, c2, c3, c4,  = features
        
        print(c0.shape,c1.shape, c2.shape, c3.shape, c4.shape)

        ############## MLP decoder on C1-C4 ###########
        n, _, h, w = c4.shape

        _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3])
        #_c4 = resize(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False)
        _c4 =  F.interpolate(_c4, size=c0.size()[2:], scale_factor=None, mode='bilinear',align_corners=False)

        _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3])
        #_c3 = resize(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False)
        _c3  =  F.interpolate(_c3, size=c0.size()[2:], scale_factor=None, mode='bilinear',align_corners=False)

        _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3])
        #_c2 = resize(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False)
        _c2 = F.interpolate(_c2, size=c0.size()[2:], scale_factor=None, mode='bilinear',align_corners=False)

        _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3])
        _c1 = F.interpolate(_c1 , size=c0.size()[2:], scale_factor=None, mode='bilinear',align_corners=False)
        
        _c0 =  self.linear_c0(c0).permute(0,2,1).reshape(n, -1, c0.shape[2], c0.shape[3])
        
        print('one', _c0.shape, _c1.shape, _c2.shape, _c3.shape, _c4.shape)
        
        cc =  torch.cat([_c4, _c3, _c2, _c1, _c0], dim=1)
        print(cc.shape)

        _c = self.conv_fuse(cc)

        x = self.dropout(_c)
        x = self.linear_pred(x)

        return x

In [11]:
S_Head = SegFormerHead( z_dim = 8, in_channels =[  64, 128, 256, 512] , embedding_dim=128 ).to(DEVICE)

In [12]:
mlp_out = S_Head(*pvt_outs)

torch.Size([5, 8, 256, 256]) torch.Size([5, 64, 64, 64]) torch.Size([5, 128, 32, 32]) torch.Size([5, 256, 16, 16]) torch.Size([5, 512, 8, 8])
one torch.Size([5, 128, 256, 256]) torch.Size([5, 128, 256, 256]) torch.Size([5, 128, 256, 256]) torch.Size([5, 128, 256, 256]) torch.Size([5, 128, 256, 256])
torch.Size([5, 640, 256, 256])


In [13]:
print(mlp_out.shape)

torch.Size([5, 1, 256, 256])
