In [None]:
import gc
import os
import sys
import cv2
import glob
import json
import shutil
import random
import pydicom
import numpy as np
import pandas as pd
import seaborn as sns
from PIL import Image
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

from sklearn.metrics import f1_score
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import StratifiedKFold, StratifiedGroupKFold

sys.path.append('../input/timm-pytorch-image-models/pytorch-image-models-master/')
import timm
from timm.data import create_transform
from timm import create_model, list_models

import torch
from torch import nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import Sampler 
from torch.nn.functional import binary_cross_entropy_with_logits, cross_entropy

import albumentations as A
from albumentations.pytorch import ToTensorV2

from transformers import Mask2FormerConfig, Mask2FormerImageProcessor, Mask2FormerForUniversalSegmentation
from transformers import OneFormerConfig, OneFormerImageProcessor, OneFormerForUniversalSegmentation

In [None]:
class CFG:
    """Central bag of hyper-parameters read by the cells below as ``CFG.<name>``."""

    # --- reproducibility ---------------------------------------------------
    seed = 42

    # --- data & training loop ----------------------------------------------
    img_size = (512, 512)   # (height, width) used for masks and resizing
    batch_size = 1          # the validation loader uses twice this value
    epochs = 10
    use_fp16 = True         # mixed-precision flag

    # --- cross-validation --------------------------------------------------
    n_folds = 5
    train_folds = [0, 1, 2, 3, 4]   # subset of folds that are actually trained

    # --- optimisation ------------------------------------------------------
    weight_decay = 0.024
    one_cycle_max_lr = 4e-4         # peak LR of OneCycleLR (alternative noted: 8e-4)
    tta = False                     # test-time augmentation toggle; not read by any cell here

# os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def set_seed(seed=42):
    '''Seed every random number generator the notebook uses so runs are repeatable.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch
    (CPU plus every visible CUDA device), and puts cuDNN in deterministic mode.

    Args:
        seed: integer seed applied to every generator.
    '''
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every visible GPU, not only the current one.
    torch.cuda.manual_seed_all(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # The hash seed is read at interpreter start-up, so this only influences
    # Python processes launched after this call, not the running kernel.
    os.environ['PYTHONHASHSEED'] = str(seed)
    print('> SEEDING DONE')
    
set_seed(seed=CFG.seed)

In [None]:
COMP_PATH = "/kaggle/input/hubmap-hacking-the-human-vasculature"
# Per-tile and per-slide metadata tables, read in that order.
tile_meta, wsi_meta = (pd.read_csv(f"{COMP_PATH}/{table}.csv") for table in ("tile_meta", "wsi_meta"))

In [None]:
# polygons.jsonl: one JSON object per line, each carrying an 'id' and its 'annotations'.
with open(f"{COMP_PATH}/polygons.jsonl", "r") as json_file:
    json_list = list(json_file)

# Map tile id -> list of annotation dicts (a later duplicate id overwrites an earlier one).
tiles_data = {
    record['id']: record['annotations']
    for record in map(json.loads, tqdm(json_list, total=len(json_list)))
}

In [None]:
def get_no_of_blood_vessel(img_id, annotations_by_id=None):
    """Count the blood-vessel polygon rings annotated on one tile.

    Every annotation whose ``type`` is ``"blood_vessel"`` contributes one count
    per entry of its ``coordinates`` list; other annotation types are ignored.

    Args:
        img_id: tile id, used as the key into the annotations mapping.
        annotations_by_id: optional mapping ``id -> list of annotation dicts``.
            Defaults to the notebook-level ``tiles_data`` loaded from polygons.jsonl.

    Returns:
        int: number of blood-vessel coordinate rings for ``img_id``.
    """
    if annotations_by_id is None:
        annotations_by_id = tiles_data
    # Only the count is needed, so no mask is rasterised here; this also means
    # vertices lying outside the mask bounds cannot raise an IndexError.
    return sum(
        len(annot['coordinates'])
        for annot in annotations_by_id[img_id]
        if annot['type'] == "blood_vessel"
    )

In [None]:
# Keep tiles from datasets 1 and 2 that contain at least one blood vessel.
df = (
    tile_meta.query('dataset < 3')
    .reset_index(drop=True)
    .assign(no_of_bv=lambda frame: frame['id'].apply(get_no_of_blood_vessel))
    .query('no_of_bv > 0')
    .reset_index(drop=True)
)

# Stratify on the source slide so each fold holds tiles from every WSI.
# NOTE(review): tiles of one WSI land in both train and valid folds; if nearby
# tiles leak information, a grouped split (StratifiedGroupKFold is imported) may fit better.
skf = StratifiedKFold(n_splits=CFG.n_folds, random_state=CFG.seed, shuffle=True)
fold_assignment = np.full(len(df), -1, dtype=np.int64)
for fold, (_, test_idx) in enumerate(skf.split(df, df['source_wsi'])):
    fold_assignment[test_idx] = fold
df['fold'] = fold_assignment

df.groupby('fold')['source_wsi'].value_counts()

In [None]:
class HubmapDataset(Dataset):
    """Tile dataset yielding Mask2Former training inputs for blood-vessel masks.

    Each item loads ``{COMP_PATH}/train/{id}.tif``. It rasterises the tile's
    ``blood_vessel`` polygons from ``tiles_data`` into a filled {0, 1} mask and
    applies the optional albumentations pipeline. The pair is then encoded with
    ``Mask2FormerImageProcessor.encode_inputs``.

    Args:
        df: frame with an ``id`` column listing the tiles to serve.
        config: model config; stored for reference only (the image processor
            is not configured from it).
        transforms: optional albumentations pipeline. It is assumed to end in
            ``ToTensorV2`` so that ``mask`` comes back as a tensor (TODO confirm
            for new pipelines).
    """

    def __init__(self, df, config, transforms=None):
        self.df = df
        self.img_ids = self.df.id.values
        self.transforms = transforms
        self.config = config
        # The first positional parameter of Mask2FormerImageProcessor is
        # ``do_resize``, so the model config must not be passed positionally.
        self.processor = Mask2FormerImageProcessor(ignore_index=0)

    def __len__(self):
        return len(self.img_ids)

    @staticmethod
    def _build_mask(annotations, shape):
        """Fill every blood-vessel polygon ring into a float32 {0, 1} mask of ``shape``."""
        mask = np.zeros(shape, dtype=np.uint8)
        for annot in annotations:
            if annot['type'] != "blood_vessel":
                continue
            for ring in annot['coordinates']:
                # Vertices are stored as [x, y] pairs, the point order cv2 expects.
                polygon = np.asarray(ring, dtype=np.int32).reshape(-1, 1, 2)
                cv2.fillPoly(mask, [polygon], 1)
        return mask.astype(np.float32)

    def __getitem__(self, index):
        img_id = self.img_ids[index]
        img_path = f"{COMP_PATH}/train/{img_id}.tif"

        try:
            image = np.asarray(Image.open(img_path).convert('RGB'))
        except Exception as ex:
            # Returning None would only fail later, opaquely, in the DataLoader collate.
            raise RuntimeError(f"Could not read {img_path}") from ex

        mask = self._build_mask(tiles_data[img_id], image.shape[:2])

        if self.transforms is not None:
            transformed = self.transforms(image=image, mask=mask)
            image, mask = transformed['image'], torch.unsqueeze(transformed['mask'], 0)
        else:
            # encode_inputs iterates over segmentation_maps, so wrap the single map.
            mask = mask[None]

        try:
            encoded = self.processor.encode_inputs(
                pixel_values_list=[image],
                segmentation_maps=mask,
                ignore_index=0,
                return_tensors='pt',
            )
        except Exception as err:
            raise RuntimeError(
                f"Failed to encode tile {img_id}: image {tuple(image.shape)}, mask {tuple(mask.shape)}"
            ) from err

        # encode_inputs works on a batch holding this one image. Drop that leading
        # entry from every field so the DataLoader's default collate builds the batch.
        return {key: value[0] for key, value in encoded.items()}

In [None]:
def transformer(stage):
    """Build the albumentations pipeline for a data split.

    Both pipelines resize to ``CFG.img_size`` so train and valid samples share
    one spatial size. They then normalise and convert image and mask to tensors.

    Args:
        stage: ``"train"`` adds a random horizontal flip and a random rotation
            of at most 5 degrees; any other value returns the deterministic
            evaluation pipeline.

    Returns:
        albumentations.Compose
    """
    if stage == "train":
        return A.Compose([
            A.Resize(CFG.img_size[0], CFG.img_size[1]),
            A.HorizontalFlip(p=0.5),
            A.Rotate(limit=5),
            A.Normalize(),
            ToTensorV2(),
        ])
    return A.Compose([
        A.Resize(CFG.img_size[0], CFG.img_size[1]),
        A.Normalize(),
        ToTensorV2(),
    ])

In [None]:
def plot_df(df):
    """Plot the per-epoch train and validation loss curves.

    Args:
        df: frame with ``train_loss`` and ``valid_loss`` columns, one row per epoch.

    Returns:
        matplotlib Axes holding the plot.
    """
    fig, ax = plt.subplots(1, 1, figsize=(15, 5))
    # Labels are required for ax.legend(); without them the legend is empty.
    ax.plot(df['train_loss'], label='train_loss')
    ax.plot(df['valid_loss'], label='valid_loss')
    ax.set_xlabel('epoch (row index)')
    ax.set_ylabel('loss')
    ax.legend()
    ax.set_title('Loss')
    return ax

In [None]:
class HubmapModel(nn.Module):
    """Thin ``nn.Module`` wrapper around ``Mask2FormerForUniversalSegmentation``.

    ``forward`` receives the whole mapping of encoded inputs and unpacks it as
    keyword arguments for the wrapped model.
    """

    def __init__(self, config):
        super().__init__()
        self.model = Mask2FormerForUniversalSegmentation(config)

    def forward(self, inputs):
        return self.model(**inputs)

In [None]:
# ``feature_size`` and ``mask_feature_size`` are channel widths of the pixel
# decoder / mask features (Mask2FormerConfig documents 256 for both). They are
# not image sizes, so they are set explicitly rather than derived from CFG.img_size.
config = Mask2FormerConfig(feature_size=256, mask_feature_size=256)
config.save_pretrained('./config')

In [None]:
def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
    """Split trainable parameters into two optimizer param groups.

    A parameter is exempt from decay when it is 1-D (typically biases and
    normalisation weights) or when its lower-cased name contains any entry of
    ``skip_list``. Parameters with ``requires_grad=False`` are left out entirely.

    Returns:
        ``[no_decay_group, decay_group]`` where the first has ``weight_decay=0.``
        and the second uses ``weight_decay``.
    """
    buckets = {True: [], False: []}  # key: "exempt from weight decay?"
    for param_name, tensor in model.named_parameters():
        if tensor.requires_grad:
            lowered = param_name.lower()
            exempt = tensor.ndim == 1 or any(token in lowered for token in skip_list)
            buckets[exempt].append(tensor)
    return [
        {'params': buckets[True], 'weight_decay': 0.},
        {'params': buckets[False], 'weight_decay': weight_decay},
    ]

In [None]:
def get_optimizer_and_scheduler(model, dataloader, optim="adam"):
    """Create the optimizer and a per-batch OneCycle learning-rate schedule.

    Args:
        model: module whose parameters are optimised.
        dataloader: training loader; its length is used as ``steps_per_epoch``.
        optim: ``"adamw"`` selects AdamW whose 1-D and bias parameters are exempt
            from weight decay. Any other value selects plain Adam with default settings.

    Returns:
        ``(optimizer, scheduler)``
    """
    if optim != "adamw":
        optimizer = torch.optim.Adam(model.parameters())
    else:
        param_groups = add_weight_decay(model, weight_decay=CFG.weight_decay, skip_list=['bias'])
        optimizer = torch.optim.AdamW(
            param_groups,
            lr=CFG.one_cycle_max_lr,
            betas=(0.9, 0.999),
            weight_decay=CFG.weight_decay,
        )

    # The schedule is sized for CFG.epochs * len(dataloader) scheduler steps.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer=optimizer,
        max_lr=CFG.one_cycle_max_lr,
        epochs=CFG.epochs,
        steps_per_epoch=len(dataloader),
        pct_start=0.1,
    )
    return optimizer, scheduler

In [None]:
def train_one_epoch(dataloader, model, scheduler, optimizer, scaler, epoch):
    """Run one training epoch, with the forward pass under autocast when CFG.use_fp16.

    Args:
        dataloader: yields mappings whose values are tensors or lists of tensors.
        model: called as ``model(inputs_dict)``; its output must expose ``.loss``.
        scheduler: per-step LR scheduler, or ``None``.
        optimizer: optimiser stepped through ``scaler``.
        scaler: ``GradScaler`` used for loss scaling.
        epoch: 0-based epoch index, used only for the progress-bar label.

    Returns:
        float: mean loss over the batches that produced a finite loss
        (0.0 if none did). Skipped batches are not counted in the mean.
    """
    model.train()

    total_loss = 0.0
    n_steps = 0
    pbar = tqdm(dataloader, desc=f"Train: Epoch {epoch + 1}", total=len(dataloader), mininterval=5)

    for inputs in pbar:
        optimizer.zero_grad()

        # Tensors go to the device as-is; lists (e.g. per-image labels) stay lists,
        # with each element moved to the device.
        inps = {
            k: (v.to(device) if isinstance(v, torch.Tensor) else [t.to(device) for t in v])
            for k, v in inputs.items()
        }

        # Mixed precision covers the forward pass only; backward runs outside autocast.
        with autocast(enabled=CFG.use_fp16):
            outputs = model(inps)
            loss = outputs.loss

        if not torch.isfinite(loss).all():
            print('Bad loss, skipping the batch')
            del loss, outputs
            gc.collect()
            continue

        # scaler is needed to prevent "gradient underflow"
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scale = scaler.get_scale()
        scaler.update()
        # A reduced scale means the optimiser step was skipped (inf/nan grads),
        # so the LR schedule is not advanced for this batch either.
        skip_lr_scheduler = scale > scaler.get_scale()
        if scheduler is not None and not skip_lr_scheduler:
            scheduler.step()

        lr = scheduler.get_last_lr()[0] if scheduler is not None else CFG.one_cycle_max_lr
        loss_value = loss.item()

        pbar.set_postfix({"loss": loss_value, "lr": lr})
        total_loss += loss_value
        n_steps += 1

    total_loss /= max(n_steps, 1)
    gc.collect()
    torch.cuda.empty_cache()
    return total_loss

In [None]:
def valid_one_epoch(dataloader, model, epoch):
    """Compute the mean validation loss in full precision without gradients.

    Args:
        dataloader: yields mappings whose values are tensors or lists of tensors.
        model: called as ``model(inputs_dict)``; its output must expose ``.loss``.
        epoch: 0-based epoch index, used only for the progress-bar label.

    Returns:
        float: mean ``outputs.loss`` over all batches (0.0 for an empty loader).
    """
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        pbar = tqdm(dataloader, desc=f'Eval: {epoch + 1}', total=len(dataloader), mininterval=5)

        for inputs in pbar:
            # Move every list element to the device, not just the first one.
            inputs = {
                k: (v.to(device) if isinstance(v, torch.Tensor) else [t.to(device) for t in v])
                for k, v in inputs.items()
            }
            with autocast(enabled=False):
                outputs = model(inputs)

            # The model returns an output object; the scalar loss is its ``loss`` field.
            loss = outputs.loss.item()
            pbar.set_postfix({"loss": loss})
            total_loss += loss

    total_loss /= max(len(dataloader), 1)
    gc.collect()
    torch.cuda.empty_cache()
    return total_loss

In [None]:
def train_fnc(train_dataloader, valid_dataloader, model, fold, optimizer, scheduler):
    """Train one CV fold for ``CFG.epochs`` epochs, checkpointing on best valid loss.

    Whenever the validation loss improves on the best so far, ``model.state_dict()``
    is written to ``fold{fold}_best_loss.pth`` in the working directory. At the
    end, the loss history is displayed as a table and plotted.

    Returns:
        pandas.DataFrame with ``train_loss`` and ``valid_loss`` columns, one row per epoch.
    """
    train_losses = []
    valid_losses = []

    scaler = GradScaler()
    best_loss = float('inf')
    for epoch in range(CFG.epochs):
        train_loss = train_one_epoch(train_dataloader, model, scheduler, optimizer, scaler, epoch)
        valid_loss = valid_one_epoch(valid_dataloader, model, epoch)

        train_losses.append(train_loss)
        valid_losses.append(valid_loss)

        print(f"-------- Epoch {epoch + 1} --------")
        print("Train Loss: ", train_loss)
        print("Valid Loss: ", valid_loss)

        if valid_loss < best_loss:
            best_loss = valid_loss
            # Persist the weights only (state_dict); reload with model.load_state_dict.
            torch.save(model.state_dict(), f"fold{fold}_best_loss.pth")
            print("New Best Loss")
        print()

    history = pd.DataFrame({'train_loss': train_losses, 'valid_loss': valid_losses})
    display(history)
    plot_df(history)
    return history

In [None]:
# Train each selected fold in ascending order, starting from a fresh model per fold.
for fold in range(CFG.n_folds):
    if fold not in CFG.train_folds:
        continue

    train_df = df[df['fold'] != fold]
    valid_df = df[df['fold'] == fold]

    train_dataset = HubmapDataset(train_df, config, transformer("train"))
    valid_dataset = HubmapDataset(valid_df, config, transformer("valid"))

    train_dataloader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True)
    valid_dataloader = DataLoader(valid_dataset, batch_size=2 * CFG.batch_size, shuffle=False)

    model = HubmapModel(config).to(device)
    optimizer, scheduler = get_optimizer_and_scheduler(model, train_dataloader, "adamw")

    train_fnc(train_dataloader, valid_dataloader, model, fold, optimizer, scheduler)

    # Free this fold's model and optimizer state before the next fold allocates
    # its own; otherwise both would be alive at once during construction.
    # The best weights are already on disk.
    del model, optimizer, scheduler, train_dataloader, valid_dataloader
    gc.collect()
    torch.cuda.empty_cache()