## Summary

* 2.5D segmentation
    * segmentation_models_pytorch
    * Unet
* uses only 3 slices
* sliding-window inference
* rotation TTA (test-time augmentation)

In [None]:
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
import sys
import time
import torch as tc
import random
from torch.utils.data import DataLoader, Dataset
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np

from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.optim import AdamW

from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import cv2,gc
import os,warnings
import pandas as pd


In [None]:
# Make the locally attached package sources importable before importing
# segmentation_models_pytorch; entries are appended in this exact order.
sys.path.extend([
    '/kaggle/input/pretrainedmodels/pretrainedmodels-0.7.4',
    '/kaggle/input/efficientnet-pytorch/EfficientNet-PyTorch-master',
    '/kaggle/input/timm-pytorch-image-models/pytorch-image-models-master',
    '/kaggle/input/d/vad13irt/segmentation-models-pytorch',
])

import segmentation_models_pytorch as smp

## config

In [None]:
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2

class CFG:
    """Global configuration shared by every cell of this notebook.

    Attributes are read directly off the class; it is never instantiated.
    Many fields (optimizer, scheduler and other training settings, and
    ``train_aug_list``) appear to be carried over from a training notebook:
    the inference code in this notebook does not read them.
    """
    # ============== comp exp name =============
    comp_name = 'vesuvius'

    # comp_dir_path = './'
    comp_dir_path = '/kaggle/input/'
    comp_folder_name = 'vesuvius-challenge-ink-detection'
    # comp_dataset_path = f'{comp_dir_path}datasets/{comp_folder_name}/'
    # Root of the competition data. Other paths are built by plain string
    # concatenation onto this value, so it must keep its trailing '/'.
    comp_dataset_path = f'{comp_dir_path}{comp_folder_name}/'
    
    exp_name = 'vesuvius_2d_slide_exp012'

    # ============== pred target =============
    # Number of output channels of the segmentation head (passed as `classes`).
    target_size = 1
    # NOTE(review): this flag is not read anywhere in this notebook; the main
    # loop always calls TTA(). Either wire it up or drop it.
    TTA=True
    
    # ============== model cfg =============
    model_name = 'Unet'
#     backbone = 'efficientnet-b0'
    # Default encoder name; EnsembleModel overwrites it for each member.
    backbone = 'mit_b3'

    # Input channels = number of stacked surface-volume slices. This must match
    # the slice range loaded by read_image (3 slices); the commented 65 is the
    # full slice count.
    in_chans = 3 # 65
    # ============== training cfg =============
    size = 224
    # Side length of the square sliding-window tiles cut at inference.
    tile_size = 224
    # Window step: 224 // 4 = 56, so adjacent tiles overlap by 3/4 of a tile.
    stride = tile_size // 4

    # Tiles per DataLoader batch at inference.
    batch_size = 64 # 32
    use_amp = True

    scheduler = 'GradualWarmupSchedulerV2'
    # scheduler = 'CosineAnnealingLR'
    epochs = 15

    warmup_factor = 10
    lr = 1e-4 / warmup_factor

    # ============== fold =============
    valid_id = 2

    objective_cv = 'binary'  # 'binary', 'multiclass', 'regression'
    metric_direction = 'maximize'  # maximize, 'minimize'
    # metrics = 'dice_coef'

    # ============== fixed =============
    pretrained = True
    inf_weight = 'best'  # 'best'

    min_lr = 1e-6
    weight_decay = 1e-6
    max_grad_norm = 1000

    print_freq = 50
    # DataLoader worker processes used by make_test_dataset.
    num_workers = 4

    seed = 42

    # ============== augmentation =============
    # Training-time augmentations. Only get_transforms(data='valid') is called
    # in this notebook, so this list appears unused at inference.
    train_aug_list = [
        # A.RandomResizedCrop(
        #     size, size, scale=(0.85, 1.0)),
        A.Resize(size, size),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.75),
        A.ShiftScaleRotate(p=0.75),
        A.OneOf([
                A.GaussNoise(var_limit=[10, 50]),
                A.GaussianBlur(),
                A.MotionBlur(),
                ], p=0.4),
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.5),
        A.CoarseDropout(max_holes=1, max_width=int(size * 0.3), max_height=int(size * 0.3), 
                        mask_fill_value=0, p=0.5),
        # A.Cutout(max_h_size=int(size * 0.6),
        #          max_w_size=int(size * 0.6), num_holes=1, p=1.0),
        A.Normalize(
            mean= [0] * in_chans,
            std= [1] * in_chans
        ),
        ToTensorV2(transpose_mask=True),
    ]

    # Inference pipeline applied to every tile: resize to `size`, normalize,
    # convert to a tensor. With mean 0 / std 1 the Normalize step presumably
    # only divides by albumentations' max_pixel_value (255 by default);
    # confirm against the installed albumentations version.
    valid_aug_list = [
        A.Resize(size, size),
        A.Normalize(
            mean= [0] * in_chans,
            std= [1] * in_chans
        ),
        ToTensorV2(transpose_mask=True),
    ]


In [None]:
# Debug runs score a training fragment; normal runs score the hidden test set.
IS_DEBUG = False
mode = 'test' if not IS_DEBUG else 'train'
# Binarisation threshold applied to the averaged prediction map.
TH = 0.5

In [None]:
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## helper

In [None]:
# ref.: https://www.kaggle.com/stainsby/fast-tested-rle
def rle(img):
    '''
    Run-length encode a mask in row-major (C) order.

    img: array-like mask; any value > 0 counts as foreground
         (1 - mask, 0 - background).
    Returns the run lengths as a space-separated string of alternating
    1-based run starts and run lengths, e.g. "2 2 6 1". An all-background
    (or empty) mask gives "".
    '''
    # Binarise first: without this, adjacent distinct non-zero values (e.g. 1
    # next to 2) would register as extra transitions and break the
    # start/length pairing.
    pixels = np.asarray(img).ravel() > 0

    # Pad with background on both ends so runs touching an edge are closed.
    pixels = np.concatenate([[False], pixels, [False]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    # Turn every second transition index into a length (end - start).
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)

## dataset

In [None]:
def read_image(fragment_id, start=28, end=31):
    """Load a contiguous range of surface-volume slices for one fragment.

    Slices ``start`` .. ``end - 1`` are read with ``cv2.imread(..., 0)``
    (grayscale). Each is zero-padded on the bottom/right and stacked along a
    new last axis.

    Args:
        fragment_id: fragment directory name under ``{mode}/``.
        start: first slice index (inclusive). Defaults to 28.
        end: last slice index (exclusive). Defaults to 31, i.e. 3 slices;
            ``end - start`` should equal ``CFG.in_chans``.

    Returns:
        np.ndarray of shape (H_padded, W_padded, end - start).

    Raises:
        FileNotFoundError: if a slice image cannot be read.
    """
    images = []
    for i in tqdm(range(start, end)):
        path = CFG.comp_dataset_path + f"{mode}/{fragment_id}/surface_volume/{i:02}.tif"
        image = cv2.imread(path, 0)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"could not read slice image: {path}")

        # NOTE: when a side is already a multiple of tile_size this still adds
        # a full extra tile. The main loop pads the fragment mask with the same
        # formula, so the two must be changed together.
        pad0 = CFG.tile_size - image.shape[0] % CFG.tile_size
        pad1 = CFG.tile_size - image.shape[1] % CFG.tile_size
        images.append(np.pad(image, [(0, pad0), (0, pad1)], constant_values=0))

    return np.stack(images, axis=2)

In [None]:
def get_transforms(data, cfg):
    """Build the albumentations pipeline for a dataset split.

    Args:
        data: split name, either 'train' or 'valid'.
        cfg: config object providing ``train_aug_list`` / ``valid_aug_list``.

    Returns:
        ``A.Compose`` over the matching augmentation list.

    Raises:
        ValueError: if ``data`` is not one of the known splits.
    """
    if data == 'train':
        return A.Compose(cfg.train_aug_list)
    if data == 'valid':
        return A.Compose(cfg.valid_aug_list)
    raise ValueError(f"unknown transform split {data!r}; expected 'train' or 'valid'")

class CustomDataset(Dataset):
    """Map-style dataset over a list of pre-cut image tiles.

    Args:
        images: indexable sequence of tiles.
        cfg: config object; stored but not read by this class.
        labels: accepted for interface compatibility; currently ignored.
        transform: albumentations-style callable, invoked as
            ``transform(image=tile)``, that returns a dict with key 'image'.
            When None, tiles are returned unchanged.
    """

    def __init__(self, images, cfg, labels=None, transform=None):
        self.images = images
        self.cfg = cfg
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        if self.transform is None:
            # The default transform=None must not crash with "NoneType is not callable".
            return image
        return self.transform(image=image)['image']


In [None]:
def _window_starts(length, tile_size, stride):
    """Return window start offsets along one axis, ending with a window flush to the edge."""
    starts = list(range(0, length - tile_size + 1, stride))
    last = length - tile_size
    # If the stride does not land exactly on the last valid offset, add it so
    # the trailing pixels are still covered by some window.
    if starts and starts[-1] != last:
        starts.append(last)
    return starts


def make_test_dataset(fragment_id):
    """Cut one fragment into overlapping tiles and wrap them in a DataLoader.

    Windows of ``CFG.tile_size`` are placed every ``CFG.stride`` pixels. An
    extra window flush with the right/bottom edge is added whenever the stride
    does not land on it exactly, so every pixel is covered at least once.

    Returns:
        (test_loader, xyxys):
            test_loader: unshuffled loader yielding transformed tiles.
            xyxys: int array of shape (num_tiles, 4) holding each tile's
                (x1, y1, x2, y2), in the same order as the loader.
    """
    test_images = read_image(fragment_id)
    tile = CFG.tile_size

    x1_list = _window_starts(test_images.shape[1], tile, CFG.stride)
    y1_list = _window_starts(test_images.shape[0], tile, CFG.stride)

    test_images_list = []
    xyxys = []
    for y1 in y1_list:
        for x1 in x1_list:
            y2 = y1 + tile
            x2 = x1 + tile
            test_images_list.append(test_images[y1:y2, x1:x2])
            xyxys.append((x1, y1, x2, y2))
    # reshape keeps a (0, 4) array if no window fits (np.stack would raise).
    xyxys = np.array(xyxys, dtype=np.int64).reshape(-1, 4)

    test_dataset = CustomDataset(test_images_list, CFG, transform=get_transforms(data='valid', cfg=CFG))

    test_loader = DataLoader(test_dataset,
                             batch_size=CFG.batch_size,
                             shuffle=False,
                             num_workers=CFG.num_workers, pin_memory=True, drop_last=False)

    return test_loader, xyxys

## model

In [None]:
class CustomModel(nn.Module):
    """Thin wrapper around ``smp.Unet`` configured from ``cfg``.

    The whole U-Net (encoder and decoder) is stored under the attribute name
    ``encoder``. Checkpoint state-dict keys derive from that name, so it must
    not be renamed.
    """

    def __init__(self, cfg, weight=None):
        super().__init__()
        self.cfg = cfg
        unet_kwargs = dict(
            encoder_name=cfg.backbone,
            encoder_weights=weight,
            in_channels=cfg.in_chans,
            classes=cfg.target_size,
            activation=None,  # raw logits; no sigmoid inside the model
        )
        self.encoder = smp.Unet(**unet_kwargs)

    def forward(self, image):
        # NOTE(review): squeeze(-1) only removes the last (width) axis when it
        # has size 1, so for normal tiles it is a no-op and the output keeps
        # the Unet's shape (presumably (B, classes, H, W)). Kept as in the
        # original.
        return self.encoder(image).squeeze(-1)

def build_model(cfg, weight="imagenet"):
    """Print the chosen architecture and return a ``CustomModel`` for ``cfg``.

    ``weight`` is forwarded as the encoder's pretrained-weights setting. The
    ensemble passes ``None`` here and then loads a checkpoint.
    """
    for key in ('model_name', 'backbone'):
        print(key, getattr(cfg, key))
    return CustomModel(cfg, weight)


In [None]:
class EnsembleModel(nn.Module):
    """Average the outputs of several fine-tuned ``CustomModel`` checkpoints.

    ``CHECKPOINTS`` maps an ensemble-member id to ``(backbone, checkpoint_path)``.
    ``folds`` selects and orders the members to load. An id without a
    registered checkpoint raises ``KeyError`` rather than silently reusing
    the previous member's backbone and path.

    Args:
        folds: iterable of member ids to load. Defaults to ``DEFAULT_FOLDS``.
    """

    CHECKPOINTS = {
        1: ('mit_b2', '/kaggle/input/b2-mit/outputs/vesuvius/mitb2/vesuvius-models/Unet_fold1_best.pth'),
        2: ('mit_b3', '/kaggle/input/b3-models/Unet_fold1_best.pth'),
        3: ('mit_b4', '/kaggle/input/fork-of-using-mitb4-fold3/outputs/vesuvius/vesuvius_2d_slide_exp016/vesuvius-models/Unet_fold3_best.pth'),
        4: ('mit_b5', '/kaggle/input/using-mitb5-fold3/outputs/vesuvius/vesuvius_2d_slide_exp016/vesuvius-models/Unet_fold3_best.pth'),
        5: ('mit_b2', '/kaggle/input/b2-mit/outputs/vesuvius/mitb2/vesuvius-models/Unet_fold3_best.pth'),
        6: ('mit_b3', '/kaggle/input/b3-models/Unet_fold3_best.pth'),
        7: ('mit_b2', '/kaggle/input/b2-mit/outputs/vesuvius/mitb2/vesuvius-models/Unet_fold2_best.pth'),
        8: ('mit_b3', '/kaggle/input/b3-models/Unet_fold2_best.pth'),
        9: ('mit_b5', '/kaggle/input/using-mit-b5-fold1/outputs/vesuvius/vesuvius_2d_slide_exp0015/vesuvius-models/Unet_fold1_best.pth'),
        10: ('mit_b4', '/kaggle/input/using-mit-b4-fold1/outputs/vesuvius/vesuvius_2d_slide_exp0015/vesuvius-models/Unet_fold1_best.pth'),
        11: ('mit_b4', '/kaggle/input/mitb45/Unet_fold1_best.pth'),
        12: ('mit_b5', '/kaggle/input/mitb45/Unet_fold1_best-b5.pth'),
        13: ('resnet50', '/kaggle/input/resent50/outputs/vesuvius/resnet50/vesuvius-models/Unet_fold1_best.pth'),
        14: ('resnet50', '/kaggle/input/resent50/outputs/vesuvius/resnet50/vesuvius-models/Unet_fold2_best.pth'),
        15: ('resnet50', '/kaggle/input/resent50/outputs/vesuvius/resnet50/vesuvius-models/Unet_fold3_best.pth'),
        16: ('tu-regnety_064', '/kaggle/input/tu-regnety-064/outputs/vesuvius/tu-regnety_064/vesuvius-models/Unet_fold1_best.pth'),
        17: ('tu-regnety_064', '/kaggle/input/tu-regnety-064/outputs/vesuvius/tu-regnety_064/vesuvius-models/Unet_fold2_best.pth'),
        18: ('tu-regnety_064', '/kaggle/input/tu-regnety-064/outputs/vesuvius/tu-regnety_064/vesuvius-models/Unet_fold3_best.pth'),
        19: ('tu-resnest50d_4s2x40d', '/kaggle/input/tu-resnest50d-4s2x40d/outputs/vesuvius/tu-resnest50d_4s2x40d/vesuvius-models/Unet_fold1_best.pth'),
        20: ('tu-resnest50d_4s2x40d', '/kaggle/input/tu-resnest50d-4s2x40d/outputs/vesuvius/tu-resnest50d_4s2x40d/vesuvius-models/Unet_fold2_best.pth'),
        21: ('tu-resnest50d_4s2x40d', '/kaggle/input/tu-resnest50d-4s2x40d/outputs/vesuvius/tu-resnest50d_4s2x40d/vesuvius-models/Unet_fold3_best.pth'),
        22: ('resnet34', '/kaggle/input/resnet34/outputs/vesuvius/resnet34/vesuvius-models/Unet_fold1_best.pth'),
        23: ('resnet34', '/kaggle/input/resnet34/outputs/vesuvius/resnet34/vesuvius-models/Unet_fold2_best.pth'),
        24: ('resnet34', '/kaggle/input/resnet34/outputs/vesuvius/resnet34/vesuvius-models/Unet_fold3_best.pth'),
    }

    # Members loaded by default, in load order (19-21 are registered but unused).
    DEFAULT_FOLDS = (1, 2, 5, 6, 7, 8, 3, 4, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 22, 23, 24)

    def __init__(self, folds=None):
        super().__init__()
        self.model = nn.ModuleList()
        folds = self.DEFAULT_FOLDS if folds is None else folds

        # build_model reads the backbone from the shared CFG class, so it is set
        # per member here and restored afterwards instead of leaking the last one.
        original_backbone = CFG.backbone
        try:
            for fold in folds:
                if fold not in self.CHECKPOINTS:
                    raise KeyError(f'no checkpoint registered for ensemble member {fold}')
                backbone, model_path = self.CHECKPOINTS[fold]
                CFG.backbone = backbone

                _model = build_model(CFG, weight=None)
                _model.to(device)

                # map_location lets checkpoints saved on GPU load on the current device.
                state = torch.load(model_path, map_location=device)['model']
                _model.load_state_dict(state)
                _model.eval()

                self.model.append(_model)
        finally:
            CFG.backbone = original_backbone

    def forward(self, x):
        """Return the element-wise mean of every member's output for ``x``."""
        if len(self.model) == 0:
            raise RuntimeError('EnsembleModel has no members')
        # Running sum avoids holding every member's output at once.
        total = None
        for member in self.model:
            out = member(x)
            total = out if total is None else total + out
        return total / len(self.model)
        
    

In [None]:
def TTA(x: torch.Tensor, model: nn.Module, rot=(1, 3)) -> torch.Tensor:
    """Rotation test-time augmentation.

    The batch is rotated by ``k * 90`` degrees for every ``k`` in ``rot``, and
    all views are scored in a single batched forward pass. Each prediction is
    rotated back, and the views are averaged.

    Args:
        x: input batch, assumed (batch, channels, H, W).
        model: network returning a map whose last two axes are spatial.
        rot: quarter-turn counts to apply. For non-square inputs every entry
            must share parity so the rotated views can be concatenated.

    Returns:
        Averaged prediction in the original orientation, with shape
        (batch, *model_output.shape[1:]).
    """
    batch = x.shape[0]
    views = torch.cat([torch.rot90(x, k=k, dims=(-2, -1)) for k in rot], dim=0)
    preds = model(views)
    # Split the batched output into one chunk per rotation, using the model's
    # own output shape (rotated views of non-square inputs have swapped H/W).
    preds = preds.reshape(len(rot), batch, *preds.shape[1:])
    restored = [torch.rot90(preds[i], k=-k, dims=(-2, -1)) for i, k in enumerate(rot)]
    return torch.stack(restored, dim=0).mean(0)

# # #     return x.max(0).values#[:, 0, :, :]

def TTA2(x: torch.Tensor, model: nn.Module, rot=(1, 3)) -> torch.Tensor:
    """Vertical-flip plus rotation test-time augmentation.

    Steps:
        1. Flip the batch along H.
        2. Rotate by ``k * 90`` degrees for every ``k`` in ``rot``.
        3. Score all views in one batched forward pass.
        4. Map each prediction back: undo the rotation, then undo the flip.
        5. Average the views.

    Args:
        x: input batch, assumed (batch, channels, H, W).
        model: network returning a map whose last two axes are spatial.
        rot: quarter-turn counts to apply. For non-square inputs every entry
            must share parity so the rotated views can be concatenated.

    Returns:
        Averaged prediction in the original orientation, with shape
        (batch, *model_output.shape[1:]).
    """
    batch = x.shape[0]
    flipped = torch.flip(x, dims=[-2])
    views = torch.cat([torch.rot90(flipped, k=k, dims=(-2, -1)) for k in rot], dim=0)
    # model returns a Tensor; it must not be passed through torch.from_numpy.
    preds = model(views)
    preds = preds.reshape(len(rot), batch, *preds.shape[1:])
    restored = [torch.rot90(preds[i], k=-k, dims=(-2, -1)) for i, k in enumerate(rot)]
    stacked = torch.stack(restored, dim=0)
    # Undo the input flip on the spatial H axis (-2). On this 5-D tensor, dim 2
    # is the channel axis.
    return torch.flip(stacked, dims=[-2]).mean(0)

In [None]:
# Test mode scores every fragment directory under test/ (sorted by name);
# debug mode scores training fragment 3 only.
fragment_ids = sorted(os.listdir(CFG.comp_dataset_path + mode)) if mode == 'test' else [3]
model = EnsembleModel()

In [None]:
# Spread inference across all visible GPUs when there is more than one;
# otherwise run on the single selected device (GPU or CPU). Hard-coding
# device_ids=[0, 1] and calling .cuda() crashes on machines with fewer GPUs.
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to(device)
model.eval()

In [None]:
# model.eval()

## main

In [None]:
# For each fragment: tile it, predict every tile with TTA, average the
# overlapping predictions, threshold them, mask them, and RLE-encode the result.
results = []
for fragment_id in fragment_ids:

    test_loader, xyxys = make_test_dataset(fragment_id)

    mask_path = CFG.comp_dataset_path + f"{mode}/{fragment_id}/mask.png"
    binary_mask = cv2.imread(mask_path, 0)
    if binary_mask is None:
        # cv2.imread returns None instead of raising on a missing/unreadable file.
        raise FileNotFoundError(f"could not read fragment mask: {mask_path}")
    binary_mask = (binary_mask / 255).astype(int)

    ori_h, ori_w = binary_mask.shape

    # Same padding formula as read_image, so the mask and tile grid line up.
    pad0 = (CFG.tile_size - ori_h % CFG.tile_size)
    pad1 = (CFG.tile_size - ori_w % CFG.tile_size)
    binary_mask = np.pad(binary_mask, [(0, pad0), (0, pad1)], constant_values=0)

    # Accumulated predictions and per-pixel window counts for averaging.
    mask_pred = np.zeros(binary_mask.shape)
    mask_count = np.zeros(binary_mask.shape)

    for step, images in tqdm(enumerate(test_loader), total=len(test_loader)):
        images = images.to(device)
        batch_size = images.size(0)

        with torch.no_grad():
            y_preds = TTA(images, model).cpu().numpy()

        # The loader is unshuffled, so batch `step` covers these xyxys rows.
        start_idx = step * CFG.batch_size
        end_idx = start_idx + batch_size
        for i, (x1, y1, x2, y2) in enumerate(xyxys[start_idx:end_idx]):
            mask_pred[y1:y2, x1:x2] += y_preds[i].reshape(mask_pred[y1:y2, x1:x2].shape)
            mask_count[y1:y2, x1:x2] += 1

    print(f'mask_count_min: {mask_count.min()}')
    # Pixels no window touched keep prediction 0 instead of becoming NaN.
    mask_pred /= np.maximum(mask_count, 1)

    fig, axes = plt.subplots(1, 3, figsize=(15, 8))
    axes[0].imshow(mask_count)
    axes[1].imshow(mask_pred.copy())

    # Drop the padding before thresholding and encoding.
    mask_pred = mask_pred[:ori_h, :ori_w]
    binary_mask = binary_mask[:ori_h, :ori_w]

    # NOTE(review): no sigmoid is applied anywhere in this pipeline
    # (CustomModel uses activation=None), so TH is compared against averaged
    # logits. Confirm TH was tuned for that.
    mask_pred = (mask_pred >= TH).astype(int)
    mask_pred *= binary_mask
    axes[2].imshow(mask_pred)
    plt.show()
    plt.close(fig)

    inklabels_rle = rle(mask_pred)

    results.append((fragment_id, inklabels_rle))

    del mask_pred, mask_count
    del test_loader

    gc.collect()
    torch.cuda.empty_cache()

## submission

In [None]:
# One row per scored fragment: its id and the RLE string of predicted ink.
submission_columns = ['Id', 'Predicted']
sub = pd.DataFrame(data=results, columns=submission_columns)

In [None]:
# Display the per-fragment predictions (the notebook renders a cell's last expression).
sub

In [None]:
# Reorder predictions to the official Id list via a left join; Ids without
# a prediction end up with NaN in 'Predicted'.
sample_sub = pd.read_csv(CFG.comp_dataset_path + 'sample_submission.csv')
sample_sub = sample_sub[['Id']].merge(sub, how='left', on='Id')

In [None]:
# Display the final submission table before it is written to disk.
sample_sub

In [None]:
# Write the submission file without the DataFrame index column.
sample_sub.to_csv(path_or_buf="submission.csv", index=False)