In [1]:
# Here we take care of paths.
# Make sure the root project directory is named 'VESUVIUS_Challenge' for this to work.

from pathlib import Path
import os

PROJECT_ROOT_NAME = 'VESUVIUS_Challenge'


def find_project_root(start: Path, root_name: str = PROJECT_ROOT_NAME) -> Path:
    """Return the nearest directory named ``root_name`` at or above ``start``.

    The directory *name* is compared exactly, rather than checking whether the
    path string merely ends with ``root_name``.
    If no ancestor matches, fall back to ``start``'s parent (the notebook's
    historical behaviour of stepping one level up, e.g. out of ``notebooks/``);
    at the filesystem root this is the root itself instead of an IndexError.
    """
    start = start.resolve()
    for candidate in (start, *start.parents):
        if candidate.name == root_name:
            return candidate
    return start.parent


print('Starting path:' + os.getcwd())
project_root = find_project_root(Path.cwd())
if Path.cwd().resolve() != project_root:
    PATH = project_root
    os.chdir(PATH)

# make sure you are in the root folder of the project
print('Current path:' + os.getcwd())

In [2]:
import torch
import monai
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
#import tempfile
import shutil
import os
import glob
import cv2
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from typing import Tuple, List
import albumentations as A
from albumentations.pytorch import ToTensorV2
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from Data_Modules.Vesuvius_Dataset import Vesuvius_Tile_Datamodule
from lit_models.Vesuvius_Lit_Model import Lit_Model
from pytorch_lightning.callbacks import ModelCheckpoint
import torch.nn as nn
import torch.nn as nn
from functools import partial
import torchvision
import torch.nn.functional as F
from lit_models.scratch_models import FPNDecoder
from Models.PreBackbone_3D import PreBackbone_3D


2023-05-21 22:34:51,155 - Created a temporary directory at /tmp/tmphxo__17v
2023-05-21 22:34:51,157 - Writing /tmp/tmphxo__17v/_remote_module_non_scriptable.py


In [3]:
PATCH_SIZE = 512  # side length (pixels) of the square tiles the transforms resize to
Z_DIM = 48        # number of z-slices stacked as input channels
COMPETITION_DATA_DIR_str = "kaggle/input/vesuvius-challenge-ink-detection/"

# Use CUDA when available, otherwise fall back to CPU.
# NOTE: Apple-silicon (MPS) is not selected here; to use it, replace the line with e.g.
#   torch.device("mps" if torch.backends.mps.is_available() else "cpu")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class PVT_w_FPN(nn.Module):
    """PyramidVisionTransformerV2 encoder followed by an ``FPNDecoder``.

    NOTE(review): ``PyramidVisionTransformerV2`` is neither defined nor imported
    in the visible cells; it must come from a cell/module not shown here
    (hidden kernel state) — confirm before a Restart & Run All.

    Args:
        in_channels: input channels (z-slices) for both the PVT patch embedding
            and the FPN decoder. Previously ignored in favour of the global
            ``Z_DIM``; the default keeps that behaviour.
        embed_dims: per-stage PVT embedding widths, also passed to the FPN as
            ``encoder_channels``. Defaults to ``[64, 128, 256, 512]``.
        n_classes: kept for interface compatibility; it is not forwarded to the
            decoder, whose output channels are not configured in this class.
    """

    def __init__(self, in_channels=Z_DIM, embed_dims=None, n_classes=1, ):
        super().__init__()

        # Fresh list per instance: avoids the shared mutable-default pitfall.
        embed_dims = [64, 128, 256, 512] if embed_dims is None else list(embed_dims)
        self.embed_dims = embed_dims

        self.pvt = PyramidVisionTransformerV2(
            img_size=PATCH_SIZE,
            patch_size=4,
            in_chans=in_channels,
            num_classes=1,
            embed_dims=embed_dims,
            num_heads=[1, 2, 4, 8],
            mlp_ratios=[8, 8, 4, 4],
            qkv_bias=True,
            qk_scale=None,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.1,
            norm_layer=partial(nn.LayerNorm, eps=1e-3),
            depths=[2, 2, 2, 2],
            sr_ratios=[4, 4, 2, 1],
        ).to(DEVICE)

        self.FPN = FPNDecoder(
            in_channels=in_channels,
            encoder_channels=embed_dims,
            encoder_depth=5,
            pyramid_channels=256,
            segmentation_channels=128,
            dropout=0.,
            merge_policy="cat",
        ).to(DEVICE)

    def forward(self, x):
        """Run the PVT encoder and feed all of its outputs to the FPN decoder.

        ``self.pvt`` returns a sequence of feature maps, which is unpacked
        positionally into ``self.FPN``; the decoder's output is returned as-is.
        """
        pvt_outs = self.pvt(x)
        return self.FPN(*pvt_outs)












In [4]:
class Model_3d_w_Swin(nn.Module):
    """``PreBackbone_3D`` followed by a 2D ``smp.FPN`` segmentation network.

    Despite the class name, the 2D encoder is ``'mit_b3'`` (the MixTransformer
    encoder in segmentation_models_pytorch), not a Swin transformer; the name is
    kept because other cells and saved runs refer to it.

    The 2D model starts from ImageNet weights (``encoder_weights='imagenet'``),
    emits one output channel with no activation (i.e. raw logits) and upsamples
    its output by 4.
    """

    def __init__(self ):
        super().__init__()

        # First stage; its output is the input of the 2D network.
        # NOTE(review): the 2D model is built with in_channels=3, so this
        # assumes PreBackbone_3D produces 3 channels — confirm in Models/PreBackbone_3D.
        self.model_3d = PreBackbone_3D().to(DEVICE)

        fpn_kwargs = dict(
            encoder_name='mit_b3',
            encoder_depth=5,
            encoder_weights='imagenet',
            decoder_pyramid_channels=256,
            decoder_segmentation_channels=128,
            decoder_merge_policy='cat',
            decoder_dropout=0.,
            in_channels=3,
            classes=1,
            activation=None,
            upsampling=4,
            aux_params=None,
        )
        self.model_2d = smp.FPN(**fpn_kwargs).to(DEVICE)

    def forward(self, x):
        """Return the 2D model's logits for the 3D pre-backbone's output on ``x``."""
        return self.model_2d(self.model_3d(x))












In [5]:


class CFG:
    """Experiment configuration shared by the datamodule, LightningModule and trainer.

    Used as a namespace: later cells pass the class itself (``cfg=CFG``).
    NOTE: ``model`` is constructed when this class statement executes, so
    re-running this cell builds a brand-new model (with fresh pretrained weights).
    """

    device = DEVICE

    THRESHOLD = 0.4
    use_wandb = True

    ######### Dataset #########

    # stage: 'train' or 'test'
    stage = 'train'

    # location of the competition data
    competition_data_dir = COMPETITION_DATA_DIR_str

    # Number of slices in z-dim: 1<z_dim<65
    z_dim = Z_DIM

    # fragments to use for training (available: [1, 2, 3])
    train_fragment_id = [2, 3]

    # fragments to use for validation
    val_fragment_id = [1]

    batch_size = 4

    # Size of the patch and stride for feeding the model.
    # stride = half a patch, so (presumably) neighbouring patches overlap by 50%.
    patch_size = PATCH_SIZE
    stride = patch_size // 2

    num_workers = 8
    on_gpu = True

    ######## Model and Lightning Model parameters ############

    # MODEL
    model = Model_3d_w_Swin().to(DEVICE)

    checkpoint = None
    save_directory = None

    # Gradients are accumulated so that batch_size * accumulate_grad_batches
    # approximates this effective batch size.
    # NOTE(review): an earlier comment claimed an effective size of 192 was
    # optimal, but 160 is what the code has used — confirm which is intended.
    effective_batch_size = 160
    accumulate_grad_batches = effective_batch_size // batch_size
    learning_rate = 0.00002
    eta_min = 1e-8
    t_max = 50
    max_epochs = 120
    weight_decay = 0.001
    # NOTE: not applied by the trainer cell (its `precision=` argument is commented out).
    precision = 16

    # checkpointing
    save_top_k = 5

    monitor = "FBETA"
    mode = "max"

    ####### Augmentations ###############
    # Normalize with mean=0, std=1 reduces to img / max_pixel_value
    # (255 by default in albumentations), i.e. a plain rescale of 8-bit input.

    # Training Aug
    train_transforms = [
        A.Resize(patch_size, patch_size),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.75),
        A.ShiftScaleRotate(p=0.75),
        A.OneOf([
                A.GaussNoise(var_limit=[10, 50]),
                A.GaussianBlur(),
                A.MotionBlur(),
                ], p=0.4),
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.5),
        A.CoarseDropout(max_holes=1, max_width=int(patch_size * 0.3), max_height=int(patch_size * 0.3),
                        mask_fill_value=0, p=0.5),
        A.Normalize(
            mean=[0] * z_dim,
            std=[1] * z_dim
        ),
        ToTensorV2(transpose_mask=True),
    ]

    # Validation Aug
    val_transforms = [
        A.Resize(patch_size, patch_size),
        A.Normalize(
            mean=[0] * z_dim,
            std=[1] * z_dim
        ),
        ToTensorV2(transpose_mask=True),
    ]

    # Test Aug
    test_transforms = [
        A.Resize(patch_size, patch_size),
        A.Normalize(
            mean=[0] * z_dim,
            std=[1] * z_dim
        ),
        ToTensorV2(transpose_mask=True),
    ]
        
    
    

# Smoke test: push one random batch through the PVT encoder and FPN decoder to
# check the two halves are wired together. This model is not the one trained below.
# NOTE(review): the dummy is 256x256 while the PVT is built with img_size=PATCH_SIZE.
model = PVT_w_FPN(in_channels = Z_DIM)
dummy = torch.randn(4, Z_DIM, 256, 256, device=DEVICE)
# Shape check only: skip building an autograd graph that would keep every
# intermediate activation alive on the device.
with torch.no_grad():
    pvt_outs = model.pvt(dummy)
    fpn_outs = model.FPN(*pvt_outs)


In [6]:
# Lightning DataModule configured through CFG; it is handed to `trainer.fit`
# as `datamodule=`. (A DataModule is not moved to a device; the Trainer
# transfers each batch.)
dataset = Vesuvius_Tile_Datamodule(
    cfg=CFG,
)

  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

  0%|          | 0/48 [00:00<?, ?it/s]

In [7]:
# Build the LightningModule, optionally restoring it from an earlier run.
CHECKPOINT_PATH = 'logs/Model_48_3d_nonorm_w_mitb3_bce50_05tver60/last.ckpt'

Checkpoint = True
if Checkpoint:
    # `load_from_checkpoint` is a classmethod that constructs and returns a NEW
    # module, so call it on the class. Calling it on a freshly built instance
    # (as before) built the model twice and threw the first one away; newer
    # Lightning releases reject that call form outright.
    # Stored hyper-parameters can be overridden with keyword arguments,
    # e.g. learning_rate=7e-6, t_max=70, eta_min=1e-8, weight_decay=0.0001.
    lit_model = Lit_Model.load_from_checkpoint(CHECKPOINT_PATH)
else:
    lit_model = Lit_Model(cfg=CFG,).to(DEVICE)


[34m[1mwandb[0m: Currently logged in as: [33mgmarus[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01667010900000605, max=1.0)…

In [8]:
RUN_NAME = 'Model_48_3d_nonorm_w_mitb3_bce50_05tver60_cont512'
SAVE_DIR = f'logs/{RUN_NAME}'

# Autograd anomaly detection slows every backward pass considerably; switch it
# on only while hunting NaN/Inf gradients. Passed to the Trainer so it is
# scoped to fitting instead of being flipped globally for the whole kernel.
DETECT_ANOMALY = False

# Keep the best CFG.save_top_k checkpoints by CFG.monitor, plus `last.ckpt`.
checkpoint_callback = ModelCheckpoint(
    save_top_k=CFG.save_top_k,
    monitor=CFG.monitor,
    mode=CFG.mode,
    dirpath=SAVE_DIR,
    filename=RUN_NAME + "{epoch:02d}{FBETA:.2f}{val_loss:.2f}{fbeta_4:.2f}{recall:.2f}{precision:.2f}",
    save_last=True,
)

trainer = pl.Trainer(
    # Follow DEVICE so the cell also runs on a CPU-only machine.
    accelerator='gpu' if DEVICE.type == 'cuda' else 'cpu',
    max_epochs=CFG.max_epochs,
    check_val_every_n_epoch=1,
    devices=1,
    logger=pl.loggers.CSVLogger(save_dir=SAVE_DIR),
    log_every_n_steps=1,
    default_root_dir=SAVE_DIR,
    # NOTE: CFG.precision is intentionally not applied; add
    # `precision=CFG.precision` to train in mixed precision.
    accumulate_grad_batches=CFG.accumulate_grad_batches,
    callbacks=[checkpoint_callback],
    detect_anomaly=DETECT_ANOMALY,
)

trainer.fit(lit_model, datamodule=dataset)

2023-05-21 22:37:22,397 - GPU available: True (cuda), used: True
2023-05-21 22:37:22,398 - TPU available: False, using: 0 TPU cores
2023-05-21 22:37:22,399 - IPU available: False, using: 0 IPUs
2023-05-21 22:37:22,400 - HPU available: False, using: 0 HPUs
2023-05-21 22:37:22,585 - LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Adjusting learning rate of group 0 to 2.0000e-05.
2023-05-21 22:37:22,604 - 
  | Name         | Type                  | Params
-------------------------------------------------------
0 | metrics      | ModuleDict            | 0     
1 | model        | Model_3d_w_Swin       | 46.0 M
2 | loss_tversky | TverskyLoss           | 0     
3 | loss_bce     | SoftBCEWithLogitsLoss | 0     
-------------------------------------------------------
46.0 M    Trainable params
0         Non-trainable params
46.0 M    Total params
183.848   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Adjusting learning rate of group 0 to 1.9980e-05.


Validation: 0it [00:00, ?it/s]

Adjusting learning rate of group 0 to 1.9921e-05.


Validation: 0it [00:00, ?it/s]

Adjusting learning rate of group 0 to 1.9823e-05.


Validation: 0it [00:00, ?it/s]

Adjusting learning rate of group 0 to 1.9686e-05.


Validation: 0it [00:00, ?it/s]

Adjusting learning rate of group 0 to 1.9511e-05.


Validation: 0it [00:00, ?it/s]

Adjusting learning rate of group 0 to 1.9298e-05.


Validation: 0it [00:00, ?it/s]

Adjusting learning rate of group 0 to 1.9049e-05.


Validation: 0it [00:00, ?it/s]

Adjusting learning rate of group 0 to 1.8764e-05.


Validation: 0it [00:00, ?it/s]

Adjusting learning rate of group 0 to 1.8444e-05.


Validation: 0it [00:00, ?it/s]

Adjusting learning rate of group 0 to 1.8091e-05.


Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


# OLD CLUNKY VERSION

class PVT(nn.Module):
    """Earlier PVT variant: PyramidVisionTransformerV2 encoder + ``SegmentationHead``.

    NOTE(review): ``PyramidVisionTransformerV2`` is not defined in the visible
    cells — confirm where it comes from.

    Args:
        img_size: input side length; also the size the head resizes to.
        in_channels: input channels for the PVT and the first head branch.
        embed_dim: stored as ``self.embed_dim`` only; not used by any layer.
    """

    def __init__(self, img_size = 256,in_channels =16, embed_dim =96  ):
        super().__init__()

        self.embed_dim = embed_dim
        embed_dims = [64, 128, 256, 512]
        self.pvt = PyramidVisionTransformerV2(
            img_size=img_size,
            patch_size=4,
            in_chans=in_channels,
            num_classes=1,
            embed_dims=embed_dims,
            num_heads=[1, 2, 4, 8],
            mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True,
            qk_scale=None,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.1,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            depths=[2, 2, 2, 2],
            sr_ratios=[8, 4, 2, 1],
        )

        # The PVT yields the original input first, then one map per stage, so the
        # head's channel list must start with in_channels (was hard-coded to 16).
        self.head = SegmentationHead(
            in_channels_list=[in_channels, *embed_dims],
            out_channels=1,
            output_size=(img_size, img_size),
        )

    def forward(self, x):
        """Encode ``x`` with the PVT (5 outputs) and fuse them with the head."""
        pvt_outs = self.pvt(x)
        return self.head(pvt_outs)


class SegmentationHead(nn.Module):
    """Fuse multi-scale feature maps into one segmentation map.

    Each input map is bilinearly resized to ``output_size``, projected to
    ``out_channels`` by a 1x1 Conv-BN-ReLU branch, concatenated along channels
    and fused by a final 1x1 Conv-BN.

    Args:
        in_channels_list: channel count of each feature map, in the order they
            are passed to ``forward``.
        out_channels: channels of every branch projection and of the output.
        output_size: ``(H, W)`` to resize every map to. ``None`` uses the
            spatial size of the first feature map. The default keeps the
            previously hard-coded 256x256.

    Raises:
        ValueError: from ``forward`` when the number of feature maps differs
            from ``len(in_channels_list)``.
    """

    def __init__(self, in_channels_list, out_channels, output_size=(256, 256)):
        super().__init__()
        self.output_size = None if output_size is None else tuple(output_size)
        self.convs = nn.ModuleList([nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        ) for in_channels in in_channels_list])
        # No ReLU after the fusion conv: the output is presumably consumed as
        # logits (the training loss in use is BCE-with-logits), and a ReLU would
        # clamp them to >= 0, i.e. sigmoid >= 0.5 everywhere. ReLU has no
        # parameters, so state_dict keys are unchanged.
        self.final_conv = nn.Sequential(
            nn.Conv2d(len(in_channels_list) * out_channels, out_channels, 1),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, features):
        features = list(features)
        if len(features) != len(self.convs):
            raise ValueError(
                f"expected {len(self.convs)} feature maps, got {len(features)}"
            )
        size = self.output_size if self.output_size is not None else tuple(features[0].shape[-2:])
        conv_features = [
            branch(F.interpolate(feature, size=size, mode='bilinear', align_corners=False))
            for branch, feature in zip(self.convs, features)
        ]
        return self.final_conv(torch.cat(conv_features, dim=1))


class PVT_w_UNet(nn.Module):
    """PyramidVisionTransformerV2 encoder with a UNet-style ``Up`` decoder.

    NOTE(review): ``PyramidVisionTransformerV2``, ``Up`` and ``OutConv`` are not
    defined in the visible cells — confirm where they come from.

    Args:
        in_channels: input channels. Used for the PVT patch embedding, the last
            ``Up`` stage and the output conv (the PVT previously used the global
            ``Z_DIM`` regardless of this argument).
        embed_dims: four per-stage PVT widths. Defaults to ``[64, 128, 256, 512]``.
        n_classes: output channels of the final ``OutConv``.
    """

    def __init__(self, in_channels,  embed_dims=None, n_classes=1, ):
        super().__init__()

        # Fresh list per instance: avoids the shared mutable-default pitfall.
        embed_dims = [64, 128, 256, 512] if embed_dims is None else list(embed_dims)
        self.embed_dims = embed_dims

        # NOTE(review): sr_ratios of 1 mean no spatial reduction of keys/values
        # (if this follows the PVTv2 reference), so stage-1 attention at
        # PATCH_SIZE=512 with patch_size=4 spans 128*128 tokens — very memory-heavy.
        self.pvt = PyramidVisionTransformerV2(
            img_size=PATCH_SIZE,
            patch_size=4,
            in_chans=in_channels,
            num_classes=1,
            embed_dims=embed_dims,
            num_heads=[1, 2, 4, 8],
            mlp_ratios=[8, 8, 4, 4],
            qkv_bias=True,
            qk_scale=None,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.1,
            norm_layer=partial(nn.LayerNorm, eps=1e-3),
            depths=[2, 2, 2, 2],
            sr_ratios=[1, 1, 1, 1],
        ).to(DEVICE)

        # Decoder: each Up stage maps from a deeper width to the next shallower one,
        # ending at in_channels.
        self.up1 = Up(self.embed_dims[-1], self.embed_dims[-2])
        self.up2 = Up(self.embed_dims[-2], self.embed_dims[-3])
        self.up3 = Up(self.embed_dims[-3], self.embed_dims[-4])
        self.up4 = Up(self.embed_dims[-4], in_channels, last_layer = True)

        self.out_conv = OutConv(in_channels, n_classes)

    def forward(self, x):
        """Encode ``x`` and decode with skip connections; returns ``out_conv``'s output."""
        # Presumably ordered shallowest (x1) to deepest (x5), given up1 consumes x5.
        x1, x2, x3, x4, x5 = self.pvt(x)

        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

        return self.out_conv(x)










