In [1]:
# Path setup: switch the working directory to the project root.
# The root project directory must be named 'VESUVIUS_Challenge' for this to work.

from pathlib import Path
import os
# Jupyter starts the kernel in the notebook's own folder; switch to the project
# root so that relative paths (data, logs, local packages) resolve correctly.
PROJECT_ROOT_NAME = 'VESUVIUS_Challenge'

print('Starting path:' + os.getcwd())
_start_dir = Path().resolve()
# Compare the folder *name* instead of a fixed-length string suffix, so a
# directory such as 'old_VESUVIUS_Challenge' is not mistaken for the root.
if _start_dir.name != PROJECT_ROOT_NAME:
    # Prefer the nearest ancestor that really is the project root; if none is
    # found, keep the original behaviour of stepping up exactly one level.
    PATH = next(
        (parent for parent in _start_dir.parents if parent.name == PROJECT_ROOT_NAME),
        _start_dir.parent,
    )
    os.chdir(PATH)

# make sure you are in the root folder of the project
print('Current path:' + os.getcwd())

Starting path:/Users/gregory/PROJECT_ML/VESUVIUS_Challenge/jupyter notebooks
Current path:/Users/gregory/PROJECT_ML/VESUVIUS_Challenge


In [2]:
import torch
import monai
from monai.visualize import matshow3d
import segmentation_models_pytorch as smp
import matplotlib.pyplot as plt
import tempfile
import shutil
import os
import glob
import cv2
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from typing import Tuple, List
import albumentations as A
from albumentations.pytorch import ToTensorV2
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from Data_Modules.Vesuvius_Dataset import Vesuvius_Tile_Datamodule
from lit_models.Vesuvius_Lit_Model import Lit_Model
from pytorch_lightning.callbacks import ModelCheckpoint
import torch.nn as nn
from Models.PVT2 import PyramidVisionTransformerV2, Up, OutConv
import torch.nn as nn
from functools import partial
import torchvision
import torch.nn.functional as F
from Models.Swin import SwinTransformer, SwinTransformerBlockV2, PatchMergingV2


2023-05-17 12:58:08,869 - Created a temporary directory at /var/folders/wc/60y8v25x3ns_jgsx6clbdb180000gn/T/tmp8554pxa0
2023-05-17 12:58:08,870 - Writing /var/folders/wc/60y8v25x3ns_jgsx6clbdb180000gn/T/tmp8554pxa0/_remote_module_non_scriptable.py


In [3]:
# Side length of the square patches fed to the model.
PATCH_SIZE = 256
# Number of z-slices used as input channels.
Z_DIM = 16
# Competition data location, relative to the project root.
COMPETITION_DATA_DIR_str =  "kaggle/input/vesuvius-challenge-ink-detection/"

# Pick the best available backend: CUDA first, then Apple's MPS, then CPU.
# `torch.backends.mps` is looked up defensively because older torch builds
# do not ship it.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif _mps_backend is not None and _mps_backend.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")


In [4]:
class PVT_w_UNet(nn.Module):
    """Segmentation network: PyramidVisionTransformerV2 encoder + U-Net style decoder.

    The encoder returns five feature maps; the decoder merges them from the
    deepest to the shallowest through four ``Up`` blocks and a final
    ``OutConv`` produces the logits.

    Args:
        in_channels: Number of input channels. It sizes the encoder input,
            the output width of the last ``Up`` block and the input of
            ``OutConv``.
        embed_dims: Four encoder stage widths, shallowest first.
        n_classes: Number of output channels of ``OutConv``.

    Note:
        The encoder image size is taken from the module-level ``PATCH_SIZE``.
    """

    def __init__(self, in_channels, embed_dims=(64, 128, 256, 512), n_classes=1):
        super().__init__()

        # Keep a private list copy: the default is an immutable tuple, so
        # instances no longer share (and cannot corrupt) one default list.
        self.embed_dims = list(embed_dims)

        self.pvt = PyramidVisionTransformerV2(
            img_size=PATCH_SIZE,
            patch_size=4,
            # Use the constructor argument rather than the global Z_DIM so
            # that encoder and decoder always agree on the channel count.
            # NOTE(review): assumes the first encoder output (x1) carries
            # `in_channels` channels -- confirm against Models.PVT2.
            in_chans=in_channels,
            num_classes=1,
            embed_dims=self.embed_dims,
            num_heads=[1, 2, 4, 8],
            mlp_ratios=[8, 8, 4, 4],
            qkv_bias=True,
            qk_scale=None,
            drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.1,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            depths=[3, 4, 6, 3],
            sr_ratios=[8, 4, 2, 1],
        )

        # Decoder: walk the encoder widths back from deepest to shallowest.
        self.up1 = Up(self.embed_dims[-1], self.embed_dims[-2])
        self.up2 = Up(self.embed_dims[-2], self.embed_dims[-3])
        self.up3 = Up(self.embed_dims[-3], self.embed_dims[-4])
        self.up4 = Up(self.embed_dims[-4], in_channels, last_layer=True)

        self.out_conv = OutConv(in_channels, n_classes)

    def forward(self, x):
        """Return raw (un-activated) logits for the input batch ``x``."""
        x1, x2, x3, x4, x5 = self.pvt(x)

        # Each Up block combines the running decoder state with the
        # next-shallower encoder output.
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

        logits = self.out_conv(x)
        return logits












In [15]:
# Smoke-test instance. The channel count comes from the notebook-wide Z_DIM
# constant instead of a repeated literal, so it cannot drift from the config.
test_model = PVT_w_UNet(in_channels=Z_DIM)
loss_bce = smp.losses.SoftBCEWithLogitsLoss()

In [28]:
# Random input batch (batch, channels, height, width) for the smoke test.
dummy = torch.randn((5, 16, 256, 256))
# All-zero single-channel target with the same batch and spatial size.
dummy_y = torch.zeros((dummy.shape[0], 1, *dummy.shape[2:]))
# Optional: mark a horizontal band as positive to get a non-empty target.
# dummy_y[:, :, 10:20, :] = 1

In [29]:
# Forward pass of the smoke-test model on the random batch.
final_outs = test_model(dummy)

In [20]:
# Display the shape of the model output.
final_outs.shape


torch.Size([5, 1, 256, 256])

In [30]:
# SoftBCEWithLogitsLoss of the model output against the all-zero dummy target.
loss = loss_bce(final_outs, dummy_y)
print(loss)

tensor(0.8460, grad_fn=<MeanBackward0>)


In [5]:


class CFG:
    """Configuration namespace for this notebook.

    The class itself (not an instance) is passed as ``cfg=CFG`` to
    ``Vesuvius_Tile_Datamodule`` and ``Lit_Model``. Every setting is a plain
    class attribute, evaluated once from top to bottom when this cell runs;
    later attributes reuse earlier ones (e.g. ``stride`` reads ``patch_size``).
    """
    
    # torch device selected in the constants cell
    device = DEVICE
    
    # presumably the probability cut-off used to binarise predictions -- TODO confirm in Lit_Model
    THRESHOLD = 0.4
    use_wandb = True
    
    ######### Dataset #########
    
    # stage: 'train' or 'test'
    stage = 'train' 
    
    # location of the competition data
    competition_data_dir = COMPETITION_DATA_DIR_str
    
    # Number of slices in z-dim: 1 < z_dim < 65
    z_dim = Z_DIM
    
    # fragments to use for training; available fragments are [1, 2, 3]
    train_fragment_id=[2,3]
    
    # fragments to use for validation
    val_fragment_id=[1]
    
    
    

    
    
    batch_size = 16
    
    # Patch size and stride used for feeding the model (stride is half a patch)
    patch_size = PATCH_SIZE
    stride = patch_size // 2
    
    
    num_workers = 0
    on_gpu = True
    
    
    ######## Model and Lightning Model parameters ############
    
    # MODEL -- instantiated once, when this class body executes
    model = PVT_w_UNet(in_channels = z_dim)
    
    
    
    
    checkpoint = None
    save_directory = None
    
    
    # NOTE(review): this yields batch_size * accumulate_grad_batches = 128,
    # while the original note stated that 192 was found optimal in experiments
    # -- confirm which effective batch size is intended.
    accumulate_grad_batches = 128 // batch_size
    learning_rate = 0.00005
    # presumably learning-rate scheduler settings (minimum LR and period) -- TODO confirm in Lit_Model
    eta_min = 1e-8
    t_max = 80
    max_epochs = 120
    weight_decay =  0.00001
    precision =16
    
    # checkpointing
    save_top_k=5
    
    monitor="FBETA"
    mode="max"
    
    
    ####### Augmentations ###############
    
    # Training augmentations (albumentations transforms, applied in list order)
    train_transforms = [
        # A.RandomResizedCrop(
        #     size, size, scale=(0.85, 1.0)),
        A.Resize(patch_size, patch_size),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.75),
        A.ShiftScaleRotate(p=0.75),
        # with probability 0.4, apply exactly one of noise / blur / motion blur
        A.OneOf([
                A.GaussNoise(var_limit=[10, 50]),
                A.GaussianBlur(),
                A.MotionBlur(),
                ], p=0.4),
        
       
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.5),
        # a single dropped rectangle of at most 30% of the patch side; the mask is zeroed there too
        A.CoarseDropout(max_holes=1, max_width=int(patch_size * 0.3), max_height=int(patch_size * 0.3), 
                        mask_fill_value=0, p=0.5),
        # A.Cutout(max_h_size=int(size * 0.6),
        #          max_w_size=int(size * 0.6), num_holes=1, p=1.0),
        # per-channel mean 0 / std 1, one entry per z-slice
        A.Normalize(
            mean= [0] * z_dim,
            std= [1] * z_dim
        ),
        ToTensorV2(transpose_mask=True),
    ]
    

    
    # Validation augmentations: resize, normalise, convert to tensor
    val_transforms = [
        A.Resize(patch_size, patch_size),
        A.Normalize(
            mean= [0] * z_dim,
            std= [1] * z_dim
        ),
        ToTensorV2(transpose_mask=True),
    ]
    
    # Test augmentations: same steps as validation
    test_transforms = [
        A.Resize(patch_size, patch_size),
        A.Normalize(
            mean=[0] * z_dim,
            std=[1] * z_dim
        ),

        ToTensorV2(transpose_mask=True),
    ]
        
    
    

In [6]:
# Build the tile datamodule from the notebook configuration (the CFG class itself is passed).
dataset = Vesuvius_Tile_Datamodule(cfg=CFG)

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

In [7]:
# Fresh Lightning module built from the notebook configuration.
lit_model = Lit_Model(cfg=CFG,)

# Set to True to start from the saved checkpoint instead of fresh weights.
Checkpoint = False
if Checkpoint:
    # `load_from_checkpoint` is a classmethod that returns a *new* module:
    # call it on the class (newer Lightning releases reject calls made on an
    # instance) and rebind `lit_model` to the result.
    lit_model = Lit_Model.load_from_checkpoint(
        'logs/gcp_checkpoints/MoUB4_Bce015_Tver_alpha085epoch_64.ckpt',
        # Keyword overrides that the original cell kept commented out:
        # learning_rate=7e-6,
        # t_max=70,
        # eta_min=1e-8,
        # weight_decay=0.0001,
    )


[34m[1mwandb[0m: Currently logged in as: [33mgmarus[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016751448600552978, max=1.0…

In [8]:
# Output folder for checkpoints and CSV logs of this run.
SAVE_DIR = 'logs/PVT_Unet'

# Checkpointing policy is read from CFG (save_top_k=5, monitor="FBETA",
# mode="max") instead of repeating the same literals here, so the
# configuration class stays the single source of truth.
checkpoint_callback = ModelCheckpoint(
    save_top_k=CFG.save_top_k,
    monitor=CFG.monitor,
    mode=CFG.mode,
    dirpath=SAVE_DIR,
    filename="PVT_Unet{epoch:02d}{FBETA:.2f}{val_loss:.2f}{fbeta_4:.2f}{recall:.2f}{precision:.2f}",
    # additionally keep the most recent checkpoint
    save_last=True,
)


# Derive the Lightning accelerator from the configured torch device instead of
# hard-coding 'mps'. Lightning names the CUDA accelerator "gpu"; the other
# device types used here ("mps", "cpu") are valid accelerator names as-is.
_accelerator = {"cuda": "gpu"}.get(CFG.device.type, CFG.device.type)

trainer = pl.Trainer(
    accelerator=_accelerator,
    devices=1,
    max_epochs=CFG.max_epochs,
    check_val_every_n_epoch=1,
    logger=pl.loggers.CSVLogger(save_dir=SAVE_DIR),
    log_every_n_steps=1,
    default_root_dir=SAVE_DIR,
    accumulate_grad_batches=CFG.accumulate_grad_batches,
    callbacks=[checkpoint_callback],
    gradient_clip_val=1,
    # Left disabled, as in the original cell:
    # precision=CFG.precision,
    # Debugging switches: benchmark=True, fast_dev_run=..., overfit_batches=1
)





# Launch training with the Lightning module and the tile datamodule.
# To resume a run, additionally pass
#   ckpt_path='logs/gcp_checkpoints/MoUB4_Bce015_Tver_alpha085epoch_64.ckpt'
trainer.fit(model=lit_model, datamodule=dataset)

2023-05-17 12:58:39,406 - GPU available: True (mps), used: True
2023-05-17 12:58:39,407 - TPU available: False, using: 0 TPU cores
2023-05-17 12:58:39,408 - IPU available: False, using: 0 IPUs
2023-05-17 12:58:39,408 - HPU available: False, using: 0 HPUs
Adjusting learning rate of group 0 to 5.0000e-05.
2023-05-17 12:58:39,632 - 
  | Name                  | Type                  | Params
----------------------------------------------------------------
0 | metrics               | ModuleDict            | 0     
1 | model                 | PVT_w_UNet            | 24.0 M
2 | loss_dice             | DiceLoss              | 0     
3 | loss_tversky          | TverskyLoss           | 0     
4 | loss_focal            | FocalLoss             | 0     
5 | loss_bce              | SoftBCEWithLogitsLoss | 0     
6 | loss_monai_focal_dice | DiceFocalLoss         | 0     
----------------------------------------------------------------
24.0 M    Trainable params
0         Non-trainable params
24.0 M  

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  tp = (output * target).sum(2)
  if ignore_index is None and target.min() < 0:
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
