## Vesuvius Challenge - Ink Detection Training Notebook

Summary:
- Model training starts from pre-trained weights for the 3D ResNet encoder
- Training on fragments 2 & 3, validation on fragment 1

### Setup

In [None]:
%%capture
# Environment setup. %%capture hides the output of every command in this cell,
# so a failing install or download will NOT be visible here.
!pip install segmentation_models_pytorch

# Pretrained weights
# ref - https://github.com/kenshohara/3D-ResNets-PyTorch
# NOTE(review): presumably this id resolves to "r3d18_K_200ep.pth", which the
# model cell loads from the working directory - confirm.
!pip install gdown
!gdown 1Nb4abvIkkp_ydPFA9sNPT1WakoVKA8Fa

# Utility packages for reading and visualizing volumes
!pip install zarr imageio-ffmpeg

# save model checkpoints
# (the training loop writes its .pt files into this directory)
!mkdir ./ckpts

In [None]:
import glob
import os
import gc
import sys
import zarr
import random
import imageio
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
from IPython.display import Video

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda import amp
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2
from albumentations import ImageOnlyTransform

import segmentation_models_pytorch as smp

sys.path.append("/kaggle/input/resnet3d")
from resnet3d import generate_model

def seed_everything(seed=42):
    """Seed every random number generator this notebook draws from.

    Seeding NumPy alone is not enough: the training dataset jitters crops with
    the stdlib ``random`` module, and weight initialisation / DataLoader
    shuffling use torch's generator.

    Args:
        seed: Integer seed applied to ``random``, NumPy and torch.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


seed_everything(42)

### Config

In [None]:
# Fragment ids used for training, and the held-out fragment used for validation.
TRAIN_FRAGMENTS = ["2", "3"]
TEST_FRAGMENT = "1"
    
class ModelConfig:
    """Static hyper-parameters shared by the datasets, model and training loop.

    The class is used as a plain namespace: attributes are read directly from
    the class object and it is never instantiated in this notebook.
    """
    # model
    # Multiplier applied to the 256-pixel base crop edge.
    crop_size_scaling=2
    # Edge length (in pixels) of the square crops fed to the model.
    crop_size = 256*crop_size_scaling
    # The datasets read slices [z_start, z_start + z_dims) along the volume's last axis.
    z_start = 24
    z_dims = 16
    
    # training
    init_lr = 1e-4
    # Batch size shrinks with the crop area (32 at scaling 1, 8 at scaling 2).
    batch_size = int(32/(crop_size_scaling**2))
    epochs = 50
    
    # augmentation
    # Albumentations pipeline for training crops; ToTensorV2 must stay last.
    train_aug_list = [
        #A.PadIfNeeded(min_height=crop_size, min_width=crop_size, position="top_left"),
        A.ToFloat(),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ShiftScaleRotate(p=0.75),
        # A single dropout hole, at most 10% of the crop edge in each direction.
        A.CoarseDropout(max_holes=1, max_width=int(crop_size * 0.1), max_height=int(crop_size * 0.1)),
        A.ChannelDropout(p=0.2),
        ToTensorV2()
    ]

    # Pipeline for validation crops: conversion only, no random augmentation.
    valid_aug_list = [
        #A.PadIfNeeded(min_height=crop_size, min_width=crop_size, position="top_left"),
        A.ToFloat(),
        ToTensorV2(),
    ]
    
    #inference
    # Test-time augmentation mode; the validation code checks for "rotate" and "flip".
    tta = None # rotate, flip

def get_transforms(aug_list):
    """Compose a list of albumentations transforms into a single pipeline.

    Args:
        aug_list: Sequence of albumentations transforms, applied in order.

    Returns:
        The ``A.Compose`` object wrapping ``aug_list``.
    """
    return A.Compose(aug_list)

### Load data

In [None]:
# Zarr stores for the three training fragments, opened read-only and keyed by
# fragment id. The datasets below slice `surface_volume` and `truth` from them.
FRAGMENTS_ZARR = {
    "1" : zarr.open("/kaggle/input/vesuvius-zarr-files/train-1.zarr", mode="r"),
    "2" : zarr.open("/kaggle/input/vesuvius-zarr-files/train-2.zarr", mode="r"),
    "3" : zarr.open("/kaggle/input/vesuvius-zarr-files/train-3.zarr", mode="r")
}

# Shape of each fragment's `mask` array; unpacked downstream as (H, W).
FRAGMENTS_SHAPE = {k : v.mask.shape for k, v in FRAGMENTS_ZARR.items()}

### Dataloaders

In [None]:
class VesuviusTrain(Dataset):
    """Training dataset yielding randomly jittered square crops.

    Every fragment is tiled on a regular grid of ``cfg.crop_size`` cells. Each
    item shifts its grid cell by a random offset of up to 32 pixels per axis,
    reads the matching volume/label crop and returns normalised tensors.
    """

    def __init__(self, fragments, cfg):
        self.fragments = fragments
        self.cfg = cfg
        self.transform = get_transforms(cfg.train_aug_list)
        self.xys = []

        step = cfg.crop_size
        for frag_id in fragments:
            H, W = FRAGMENTS_SHAPE[frag_id]
            # Row-major grid of top-left corners; partial edge tiles are skipped.
            self.xys.extend(
                (frag_id, left, top, W, H)
                for top in range(0, H - step + 1, step)
                for left in range(0, W - step + 1, step)
            )

    def __len__(self):
        return len(self.xys)

    def __getitem__(self, i):
        frag_id, left, top, W, H = self.xys[i]
        size = self.cfg.crop_size
        z_lo = self.cfg.z_start
        z_hi = z_lo + self.cfg.z_dims

        # Random jitter (x drawn first, then y). Cells touching the top/left
        # border may only move inwards so the crop never starts below zero.
        dx = random.randint(0 if left == 0 else -32, 32)
        dy = random.randint(0 if top == 0 else -32, 32)

        # Discard the jitter on an axis where it would push the crop past the
        # fragment's far edge.
        if left + dx + size > W:
            dx = 0
        if top + dy + size > H:
            dy = 0

        left, top = left + dx, top + dy
        right, bottom = left + size, top + size

        volume = FRAGMENTS_ZARR[frag_id].surface_volume[top:bottom, left:right, z_lo:z_hi]
        truth = FRAGMENTS_ZARR[frag_id].truth[top:bottom, left:right]

        if self.transform is None:
            # Fallback augmentation: random horizontal flip only.
            if random.random() > 0.5:
                volume = np.flip(volume, axis=1).copy()
                truth = np.flip(truth, axis=1).copy()

            volume_t = torch.from_numpy(volume.astype(np.float32))
            volume_t = volume_t.unsqueeze(0).permute(0, 3, 1, 2)
            volume_t = volume_t / 65535.0
            volume_t = (volume_t - 0.45) / 0.225

            truth_t = torch.from_numpy(truth.astype(np.float32)).unsqueeze(0)
            return volume_t, truth_t

        augmented = self.transform(
            image=volume.astype(np.float32),
            mask=truth.astype(np.float32),
        )
        image = augmented['image'] / 65535.0
        image = (image - 0.45) / 0.225
        # Add the leading single-channel axis expected by the 3D encoder.
        return torch.unsqueeze(image, 0), torch.unsqueeze(augmented['mask'], 0)

In [None]:
class VesuviusVal(Dataset):
    """Validation dataset yielding non-overlapping grid crops of one fragment.

    Each item is ``(image, label, box)`` where ``box`` holds the crop's
    ``[x1, y1, x2, y2]`` pixel coordinates so predictions can be pasted back
    into a full-size mask. Partial tiles at the right/bottom edge are skipped.
    """

    def __init__(self, fragment, cfg):
        self.fragment = FRAGMENTS_ZARR[fragment]
        self.xys = []
        # Augmentation pipeline is disabled; enable with
        # get_transforms(cfg.valid_aug_list).
        self.transform = None
        self.cfg = cfg

        H, W = FRAGMENTS_SHAPE[fragment]
        for y in range(0, H - cfg.crop_size + 1, cfg.crop_size):
            for x in range(0, W - cfg.crop_size + 1, cfg.crop_size):
                self.xys.append((x, y))

    def __getitem__(self, i):
        x1, y1 = self.xys[i]
        x2, y2 = x1 + self.cfg.crop_size, y1 + self.cfg.crop_size
        z1, z2 = self.cfg.z_start, self.cfg.z_start + self.cfg.z_dims

        # Cast to float32 up front, exactly as VesuviusTrain does. Handing the
        # raw crop to the transform would let A.ToFloat() rescale it by the
        # integer dtype's maximum, and the division by 65535 below would then
        # be applied a second time.
        frag_crop = self.fragment.surface_volume[y1:y2, x1:x2, z1:z2].astype(np.float32)
        mask_crop = self.fragment.truth[y1:y2, x1:x2].astype(np.float32)
        box = torch.tensor([x1, y1, x2, y2], dtype=torch.int32)

        if self.transform is not None:
            data = self.transform(image=frag_crop, mask=mask_crop)
            image, label = data['image'], data['mask']
        else:
            # (H, W, Z) -> (Z, H, W)
            image = torch.from_numpy(frag_crop).permute(2, 0, 1)
            label = torch.from_numpy(mask_crop)

        # Same normalisation for both branches.
        image = image / 65535.0
        image = (image - 0.45) / 0.225
        # Add the leading single-channel axis expected by the 3D encoder.
        return torch.unsqueeze(image, 0), torch.unsqueeze(label, 0), box

    def __len__(self):
        return len(self.xys)

In [None]:
# Create data loaders
# Training: shuffled, and the last incomplete batch is dropped.
dataset_train = VesuviusTrain(fragments=TRAIN_FRAGMENTS, cfg=ModelConfig)
dataloader_train = DataLoader(
    dataset_train,
    batch_size=ModelConfig.batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=2,
    pin_memory=True,
)
n_train = len(dataloader_train)

# Validation: fixed order, and every tile is kept.
dataset_valid = VesuviusVal(fragment=TEST_FRAGMENT, cfg=ModelConfig)
dataloader_valid = DataLoader(
    dataset_valid,
    batch_size=ModelConfig.batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=2,
    pin_memory=True,
)
n_valid = len(dataloader_valid)

### Model
* Encoder is a 3D ResNet model with 18 layers. The architecture has been modified to remove temporal downsampling between blocks.
* A 2D decoder is used for predicting the segmentation map.
* The encoder feature maps are average-pooled over the Z dimension before being passed to the decoder.

In [None]:
class Decoder(nn.Module):
    """2D decoder that fuses encoder features from coarsest to finest.

    Args:
        encoder_dims: Channel counts of the encoder feature maps, ordered from
            the finest (highest resolution) to the coarsest level.
        upscale: Factor by which the single-channel logit map is bilinearly
            upsampled at the end.
    """

    def __init__(self, encoder_dims, upscale):
        super().__init__()
        # convs[i-1] fuses level i (upsampled) with level i-1 and outputs
        # encoder_dims[i-1] channels.
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(encoder_dims[i]+encoder_dims[i-1], encoder_dims[i-1], 3, 1, 1, bias=False),
                nn.BatchNorm2d(encoder_dims[i-1]),
                nn.ReLU(inplace=True)
            ) for i in range(1, len(encoder_dims))])

        self.logit = nn.Conv2d(encoder_dims[0], 1, 1, 1, 0)
        self.up = nn.Upsample(scale_factor=upscale, mode="bilinear")

    def forward(self, feature_maps):
        """Decode ``feature_maps`` into a single-channel logit mask.

        Args:
            feature_maps: Sequence of 4D feature tensors ordered finest to
                coarsest; each level is expected to have half the spatial size
                of the previous one. The sequence is not modified.

        Returns:
            The upsampled logit map.
        """
        # Carry the running fused feature in a local variable rather than
        # writing it back into the caller's list.
        fused = feature_maps[-1]
        for i in range(len(feature_maps)-1, 0, -1):
            f_up = F.interpolate(fused, scale_factor=2, mode="bilinear")
            f = torch.cat([feature_maps[i-1], f_up], dim=1)
            fused = self.convs[i-1](f)

        x = self.logit(fused)
        mask = self.up(x)
        return mask


class SegModel(nn.Module):
    """Segmentation model: 3D ResNet-18 encoder followed by a 2D decoder."""

    def __init__(self):
        super().__init__()
        self.encoder = generate_model(model_depth=18, n_input_channels=1)
        self.decoder = Decoder(encoder_dims=[64, 128, 256, 512], upscale=4)
        
    def forward(self, x):
        """Return the predicted logit mask for a batch of volume crops."""
        feat_maps = self.encoder(x)
        # Collapse axis 2 of every encoder feature map by averaging so the
        # decoder receives 2D feature maps.
        feat_maps_pooled = [torch.mean(f, dim=2) for f in feat_maps]
        pred_mask = self.decoder(feat_maps_pooled)
        return pred_mask
    
    def load_pretrained_weights(self, state_dict):
        """Load pretrained encoder weights, adapting the stem to one channel.

        Args:
            state_dict: Encoder state dict containing ``'conv1.weight'``.
                It is left untouched; a shallow copy is patched instead.

        The result of ``load_state_dict`` (missing / unexpected keys) is
        printed so mismatches are visible.
        """
        import copy

        # Shallow copy: the caller's mapping keeps its original conv1 weights.
        state_dict = copy.copy(state_dict)
        # Convert 3 channel weights to single channel
        # ref - https://timm.fast.ai/models#Case-1:-When-the-number-of-input-channels-is-1
        conv1_weight = state_dict['conv1.weight']
        state_dict['conv1.weight'] = conv1_weight.sum(dim=1, keepdim=True)
        print(self.encoder.load_state_dict(state_dict, strict=False))

In [None]:
model = SegModel()
# Deserialize onto the CPU first so loading does not depend on the device the
# checkpoint was saved from; the whole model is moved to the GPU afterwards.
model.load_pretrained_weights(
    torch.load("r3d18_K_200ep.pth", map_location="cpu")["state_dict"])
# Without explicit device_ids DataParallel uses every visible GPU, so this also
# runs on a single-GPU machine. Later cells rely on the `.module` wrapper.
model = nn.DataParallel(model)
model = model.cuda()

### Competition metric (F0.5 Score)

In [None]:
# ref - https://www.kaggle.com/competitions/vesuvius-challenge-ink-detection/discussion/397288
def fbeta_score(preds, targets, threshold, beta=0.5, smooth=1e-5):
    """Compute a smoothed F-beta score for thresholded predictions.

    Args:
        preds: Tensor of prediction scores.
        targets: Tensor of the same shape holding 0/1 ground-truth labels.
        threshold: Scores strictly above this value count as positive.
        beta: Weight of recall relative to precision.
        smooth: Small constant added to the denominators to avoid division
            by zero.

    Returns:
        A 0-dim tensor with the F-beta score.
    """
    binarized = (preds > threshold).float()

    true_pos = binarized[targets == 1].sum()
    false_pos = binarized[targets == 0].sum()
    n_positive = targets.sum()

    precision = true_pos / (true_pos + false_pos + smooth)
    recall = true_pos / (n_positive + smooth)

    b2 = beta * beta
    return (1 + b2) * (precision * recall) / (b2 * precision + recall + smooth)

### Training

In [None]:
# Define loss, optimizer and scheduler
DiceLoss = smp.losses.DiceLoss(mode='binary')
BCELoss = smp.losses.SoftBCEWithLogitsLoss()

# Tversky loss is defined for experimentation; `criterion` below does not use it.
alpha = 0.5
beta = 1 - alpha
TverskyLoss = smp.losses.TverskyLoss(
    mode='binary', log_loss=False, alpha=alpha, beta=beta)

def criterion(y_pred, y_true):
    """Return the training loss: an equal-weight mix of BCE and Dice.

    Args:
        y_pred: Raw logits predicted by the model.
        y_true: Ground-truth mask with the same shape as ``y_pred``.
    """
    # Alternatives tried: BCELoss(y_pred, y_true) or DiceLoss(y_pred, y_true) alone.
    return 0.5 * BCELoss(y_pred, y_true) + 0.5 * DiceLoss(y_pred, y_true)

scaler = amp.GradScaler()
optimizer = torch.optim.AdamW(model.parameters(), lr=ModelConfig.init_lr)
# The training loop calls scheduler.step() once per epoch, so the cycle must
# span exactly `epochs` steps. The previous steps_per_epoch=10, epochs=epochs//10
# form only matched when `epochs` was a multiple of 10; for other values
# OneCycleLR raises once it is stepped past its total.
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=ModelConfig.init_lr,
                                                total_steps=ModelConfig.epochs,
                                                pct_start=0.1)

In [None]:
# Earlier variant that built the ground-truth mask through the validation
# transform pipeline; kept for reference.
#transform = get_transforms(ModelConfig.valid_aug_list)
#mask = np.array(FRAGMENTS_ZARR[TEST_FRAGMENT].truth.astype(np.float32))
#data = transform(image = np.zeros_like(mask), mask=mask)
#gt_mask = data['mask']

# Full ground-truth mask of the validation fragment as a float tensor on the GPU;
# the F-beta metric compares the stitched predictions against it.
gt_mask = torch.from_numpy(np.asarray(FRAGMENTS_ZARR[TEST_FRAGMENT].truth)).float().cuda()
# Shape of the validation fragment, used to allocate the stitched prediction mask.
gt_shape = FRAGMENTS_SHAPE[TEST_FRAGMENT]

In [None]:
def _tta_views(x, tta):
    """Return the four augmented copies of ``x`` for the given TTA mode."""
    if tta == "flip":
        return [x, torch.flip(x, (-1,)), torch.flip(x, (-2,)), torch.flip(x, (-2, -1))]
    return [torch.rot90(x, k=k, dims=(-2, -1)) for k in range(4)]


def _tta_restore(preds, tta):
    """Undo, per view, the augmentation applied by ``_tta_views``."""
    if tta == "flip":
        return [preds[0], torch.flip(preds[1], (-1,)), torch.flip(preds[2], (-2,)),
                torch.flip(preds[3], (-2, -1))]
    return [torch.rot90(preds[k], k=-k, dims=(-2, -1)) for k in range(4)]


def validate(model, dataloader_valid, tta=None):
    """Run the model over the validation tiles and stitch the predictions.

    Args:
        model: Network returning logits for a batch of crops.
        dataloader_valid: Loader yielding ``(fragments, masks, xys)`` batches,
            where each row of ``xys`` is ``[x1, y1, x2, y2]``.
        tta: Test-time augmentation mode: ``None`` (off), ``"flip"`` or
            ``"rotate"``. With TTA each crop is predicted under four views
            and the re-aligned probabilities are averaged.

    Returns:
        ``(final_pred_mask, mloss_val)``: the stitched probability mask with
        shape ``gt_shape`` on the GPU, and the validation loss summed over
        batches. The loss is only accumulated when TTA is off and stays 0.0
        otherwise.

    Raises:
        ValueError: If ``tta`` is set to an unsupported mode.
    """
    if tta and tta not in ("flip", "rotate"):
        raise ValueError(f"Unsupported tta mode: {tta!r} (expected None, 'flip' or 'rotate')")

    mloss_val = 0.0
    model.eval()
    pbar_val = enumerate(dataloader_valid)
    pbar_val = tqdm(pbar_val, total=len(dataloader_valid), bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}")
    # Pixels not covered by any tile keep a probability of zero.
    final_pred_mask = torch.zeros(gt_shape, dtype=torch.float32, device='cuda')

    for i, (fragments, masks, xys) in pbar_val:
        fragments, masks = fragments.cuda(), masks.cuda()
        with torch.no_grad():
            if tta:
                preds = []
                for fragment in fragments:
                    # The four views of one crop form a mini-batch.
                    views = torch.stack(_tta_views(fragment, tta), dim=0)
                    probs = torch.sigmoid(model(views))
                    restored = torch.stack(_tta_restore(probs, tta), dim=0)
                    preds.append(restored.mean(0))
                pred_masks = torch.stack(preds, dim=0)
            else:
                pred_masks = model(fragments)
                mloss_val += criterion(pred_masks, masks).item()
                pred_masks = torch.sigmoid(pred_masks)

        # Paste every tile's prediction back at its original location.
        for j, xy in enumerate(xys):
            final_pred_mask[xy[1]:xy[3], xy[0]:xy[2]] = pred_masks[j, 0]

        pbar_val.set_description(("%10s") % (f"Val Loss: {mloss_val / (i+1):.4f}"))
    return final_pred_mask, mloss_val

In [None]:
# Training loop: one pass over the training loader per epoch, followed by a
# full validation pass, a threshold sweep for F0.5, and a checkpoint whenever
# the best F0.5 improves.
best_fbeta = 0.0
train_losses = []   # summed training loss per epoch
val_losses = []     # summed validation loss per epoch
fbetas = []         # best F0.5 over the threshold sweep, per epoch

for epoch in range(1, ModelConfig.epochs+1):
    model.train()
    cur_lr = f"LR : {scheduler.get_last_lr()[0]:.2E}"
    pbar_train = enumerate(dataloader_train)
    pbar_train = tqdm(pbar_train, total=n_train, bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}")
    mloss_train = 0.0

    for i, (fragments, masks) in pbar_train:
        fragments, masks = fragments.cuda(), masks.cuda()
        optimizer.zero_grad()
        # Only the forward pass and the loss run under autocast; the backward
        # pass and the optimizer step are done outside of it.
        with amp.autocast():
            pred_masks = model(fragments)
            loss = criterion(pred_masks, masks)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        mloss_train += loss.item()

        gpu_mem = f"Mem : {torch.cuda.memory_reserved() / 1E9:.3g}GB"
        pbar_train.set_description(("%10s  " * 3 + "%10s") % (f"Epoch {epoch}/{ModelConfig.epochs}", gpu_mem, cur_lr,
                                                              f"Loss: {mloss_train / (i + 1):.4f}"))
        
    # The scheduler is stepped once per epoch.
    scheduler.step()
    final_pred_mask, mloss_val = validate(model, dataloader_valid)
    plt.imshow(final_pred_mask.cpu().numpy())
    plt.show()
    
    # Sweep thresholds and keep the best score as a plain float.
    fbeta_ = 0.0
    for threshold in np.arange(0.2, 0.85, 0.05):
        fbeta = fbeta_score(final_pred_mask, gt_mask, threshold).item()
        fbeta_ = max(fbeta, fbeta_)
        print(f"Threshold : {threshold:.2f}\tFBeta : {fbeta:.6f}")
        
    # save losses and metrics
    train_losses.append(mloss_train)
    val_losses.append(mloss_val)
    fbetas.append(fbeta_)
    
    if fbeta_ > best_fbeta:
        best_fbeta = fbeta_
        # Save the unwrapped module so the weights load without DataParallel.
        torch.save(model.module.state_dict(), f"./ckpts/resnet18_3d_seg_{epoch}_{best_fbeta:.2f}.pt")

In [None]:
# Plot training curve
# Losses are stored as per-epoch sums, so divide by the number of batches
# to plot the mean loss per batch.
curves = (
    (np.array(train_losses) / len(dataloader_train), "blue", "train_loss"),
    (np.array(val_losses) / len(dataloader_valid), "orange", "val_loss"),
    (fbetas, "cyan", "fbeta"),
)
for series, colour, curve_label in curves:
    plt.plot(series, color=colour, label=curve_label)
plt.legend()
plt.show()

In [None]:
# Load best model
# Checkpoints are written only when the validation F0.5 improves, so the most
# recently modified file is the best one. Look in the same relative directory
# the training loop saves to ("./ckpts").
checkpoints = glob.glob("./ckpts/*.pt")
if not checkpoints:
    raise FileNotFoundError("No checkpoints found in ./ckpts - run the training cell first.")
checkpoints.sort(key=os.path.getmtime)
model_ckpt = checkpoints[-1]
print(model_ckpt)
#model_ckpt  = "/kaggle/working/ckpts/resnet18_3d_seg_12_0.56.pt"
# Deserialize on the CPU; load_state_dict copies the values into the model's
# own (GPU) parameters.
checkpoint = torch.load(model_ckpt, map_location="cpu")
model.module.load_state_dict(checkpoint)

In [None]:
# Get optimal threshold
final_pred_mask, _ = validate(model, dataloader_valid, ModelConfig.tta)
# Start from a 0-dim tensor (not an int) so `opt_f.item()` in the plotting cell
# works even when no threshold scores above zero.
opt_f, opt_t = gt_mask.new_zeros(()), 0
for threshold in np.arange(0.2, 1.0, 0.05):
    fbeta = fbeta_score(final_pred_mask, gt_mask, threshold)
    if fbeta > opt_f:
        opt_f, opt_t = fbeta, threshold
    print(f"{threshold:.2f}: {fbeta.item():.2f}")
final_pred_mask = final_pred_mask.cpu().numpy()
# Binarize with the same strict comparison the metric uses (pred > threshold).
thresholded = (final_pred_mask > opt_t).astype(final_pred_mask.dtype)

In [None]:
np_gt_mask = gt_mask.cpu().numpy()

# Plot final model predictions
# Panels, left to right: raw probabilities, thresholded mask, ground truth.
fig, axes = plt.subplots(1, 3)
ax0, ax1, ax2 = axes
for panel, picture in zip((ax0, ax1, ax2), (final_pred_mask, thresholded, np_gt_mask)):
    panel.imshow(picture)
fig.suptitle(f'T: {opt_t:.2f} F0.5: {opt_f.item():.2f}')

plt.show()

In [None]:
# Uncomment to remove checkpoints if file persistence is on.
# !rm /kaggle/working/ckpts/*