In [None]:
!pip install -q segmentation_models_pytorch
!pip install -qU wandb
!pip install -q scikit-learn==1.0
!pip install -q segmentation-mask-overlay

# utils.py

In [None]:
import random
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from segmentation_mask_overlay import overlay_masks

def load_img(img_path: str) -> np.array:
    '''Load an image array from a .npy file.

    Appends a trailing channel axis, casts to float32 and divides by the
    array's maximum value (no scaling when that maximum is 0).
    '''
    arr = np.load(img_path)[..., np.newaxis].astype('float32')
    peak = np.max(arr)
    if peak:
        arr /= peak
    return arr

def load_msk(msk_path: str) -> np.array:
    '''Load a mask array from a .npy file and return a copy with a trailing
    channel axis; the stored dtype is kept as-is.'''
    raw = np.load(msk_path)
    return np.expand_dims(raw, axis=-1).copy()

def set_seed(seed: int = 42) -> None:
    '''Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN,
    so repeated runs produce the same results.

    Note: assigning PYTHONHASHSEED at runtime does not change the hash seed of
    the already-running interpreter; it only affects Python processes started
    afterwards.
    '''
    seeders = (np.random.seed, random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(seed)
    # Trade cuDNN autotuning speed for bit-for-bit reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)

def get_mask_path(image_path: str) -> str:
    '''Map an image path to its mask path by replacing the parent directory
    name with "masks", e.g. ".../images/x.npy" -> ".../masks/x.npy".'''
    segments = image_path.split("/")
    segments[-2] = "masks"
    mask_path = "/".join(segments)
    return mask_path

def mask_empty(mask_path: str) -> bool:
    '''Return True when the .npy mask at mask_path has no non-zero element.'''
    nonzero_count = np.count_nonzero(np.load(mask_path))
    return nonzero_count == 0

def get_case(image_path: str) -> str:
    '''Return the file name with its last "-"-separated part removed
    (e.g. ".../case12-day3-7.npy" -> "case12-day3"); "" if there is no "-".'''
    fname = image_path.rsplit("/", 1)[-1]
    case, _sep, _tail = fname.rpartition("-")
    return case

def get_id(image_path: str) -> str:
    '''Return the file-name component of image_path (extension included),
    used as the per-slice identifier.'''
    return image_path.rsplit("/", 1)[-1]

def save_overlay(image: np.array, 
                 mask: np.array, 
                 predict: np.array = None, 
                 out_path: str = "./sample.png") -> None:
    '''Render image with the ground-truth mask (and optionally a prediction)
    overlaid, and write the figure to out_path at 300 dpi.

    Values >= 0.5 (or NaN) count as foreground in both mask and predict.
    '''
    def _binarize(arr):
        return np.array(np.where(arr < 0.5, 0, 1), dtype=bool)

    layers = [_binarize(mask)]
    layer_labels = ["mask"]
    if isinstance(predict, np.ndarray):
        layers.append(_binarize(predict))
        layer_labels.append("predict")
    # RGBA rows: first layer (mask) blue, second layer (predict) red.
    colors = np.array([[0., 0., 1., 1.], [1., 0., 0., 1.]])
    fig = overlay_masks(image, layers, labels=layer_labels, colors=colors, mask_alpha=0.5)
    fig.savefig(out_path, bbox_inches="tight", dpi=300)
    plt.close(fig)


In [None]:
# Scratch: preview the first two RGBA colours of the tab20 colormap.
palette_idx = np.arange(2)
cmap = plt.cm.tab20(palette_idx)
cmap

# loss.py

In [None]:
import torch
import segmentation_models_pytorch as smp

# Shared binary Dice loss instance; it receives the model's raw outputs
# (build_model uses activation=None).
DiceLoss = smp.losses.DiceLoss(mode='binary')

def criterion(y_pred: torch.tensor, y_true: torch.tensor) -> torch.tensor:
    '''Binary Dice loss between model outputs and target masks.'''
    dice = DiceLoss(y_pred, y_true)
    return dice

# dataset.py

In [None]:
import torch
import cv2
import albumentations as A
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
# from utils import load_msk, load_img

class BuildDataset(torch.utils.data.Dataset):
    '''2D segmentation dataset backed by per-slice .npy files.

    Args:
        dataset_df: frame with an ``image_path`` column and, when ``label`` is
            True, a ``mask_path`` column.
        label: when True, ``__getitem__`` returns ``(image, mask)``; otherwise
            only the image (inference mode, no mask column required).
        transforms: optional albumentations-style callable invoked as
            ``transforms(image=..., mask=...)`` returning a dict.

    Items are torch tensors in channels-first layout (transposed from HWC).
    '''

    def __init__(self,
                 dataset_df: pd.DataFrame,
                 label: bool = True,
                 transforms: object = None):

        self.dataset_df = dataset_df
        self.label = label
        self.img_paths = dataset_df['image_path'].tolist()
        # Unlabelled (inference) frames need not carry a mask_path column.
        if 'mask_path' in dataset_df:
            self.msk_paths = dataset_df['mask_path'].tolist()
        else:
            self.msk_paths = []
        self.transforms = transforms

    def __len__(self):
        return len(self.dataset_df)

    def __getitem__(self, index):
        img = load_img(self.img_paths[index])

        if not self.label:
            if self.transforms:
                img = self.transforms(image=img)['image']
            return torch.tensor(np.transpose(img, (2, 0, 1)))

        msk = load_msk(self.msk_paths[index])
        if self.transforms:
            # Image and mask go through one call so spatial augmentations match.
            data = self.transforms(image=img, mask=msk)
            img, msk = data['image'], data['mask']
        img = np.transpose(img, (2, 0, 1))
        msk = np.transpose(msk, (2, 0, 1))
        return torch.tensor(img), torch.tensor(msk)


def get_transforms(cfg: object) -> dict:
    '''Build the albumentations pipelines for each split.

    Returns a dict with keys "train" (resize + flip / shift-scale-rotate /
    grid-or-elastic distortion / coarse dropout) and "valid" (resize only).
    Both resize to cfg.img_size with nearest-neighbour interpolation.
    '''
    size = cfg.img_size

    def _resize():
        # A fresh Resize for each pipeline rather than one shared instance.
        return A.Resize(*size, interpolation=cv2.INTER_NEAREST)

    distortion = A.OneOf([
        A.GridDistortion(num_steps=5, distort_limit=0.05, p=1.0),
        A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=1.0),
    ], p=0.25)
    dropout = A.CoarseDropout(max_holes=8,
                              max_height=size[0] // 20,
                              max_width=size[1] // 20,
                              min_holes=5,
                              fill_value=0,
                              mask_fill_value=0,
                              p=0.5)
    train_pipeline = A.Compose([
        _resize(),
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.05,
                           rotate_limit=10, p=0.5),
        distortion,
        dropout,
    ], p=1.0)

    valid_pipeline = A.Compose([_resize()], p=1.0)

    return {"train": train_pipeline, "valid": valid_pipeline}


def prepare_loaders(dataset_df: pd.DataFrame, fold: int, cfg: object) -> tuple:
    '''Split dataset_df on its "fold" column and wrap both halves in DataLoaders.

    Rows whose fold equals `fold` form the validation set; all others train.
    In debug mode only a small head of each split (non-empty masks only) is
    used, with a batch size of 20.
    Returns (train_loader, valid_loader).
    '''
    transforms = get_transforms(cfg)
    train_df = dataset_df.query("fold!=@fold").reset_index(drop=True)
    valid_df = dataset_df.query("fold==@fold").reset_index(drop=True)

    if cfg.debug:
        train_df = train_df.head(32*5).query("empty==0")
        valid_df = valid_df.head(32*3).query("empty==0")
        train_bs = valid_bs = 20
    else:
        train_bs, valid_bs = cfg.train_bs, cfg.valid_bs

    train_dataset = BuildDataset(train_df, transforms=transforms['train'])
    valid_dataset = BuildDataset(valid_df, transforms=transforms['valid'])

    common = dict(num_workers=4, pin_memory=True)
    train_loader = DataLoader(train_dataset, batch_size=train_bs,
                              shuffle=True, drop_last=False, **common)
    valid_loader = DataLoader(valid_dataset, batch_size=valid_bs,
                              shuffle=False, **common)
    return train_loader, valid_loader


# model.py

In [None]:
import torch
import segmentation_models_pytorch as smp

def build_model(cfg: object) -> object:
    '''Create a single-channel smp U-Net and move it to cfg.device.

    Uses ImageNet encoder weights, cfg.backbone as the encoder and
    cfg.num_classes output channels with no final activation (raw logits).
    NOTE(review): cfg.model_name is not consulted; the architecture is always Unet.
    '''
    unet = smp.Unet(
        encoder_name=cfg.backbone,
        encoder_weights="imagenet",
        in_channels=1,
        classes=cfg.num_classes,
        activation=None,
    )
    return unet.to(cfg.device)

def load_model(model_path: str, cfg: object) -> object:
    '''Rebuild the architecture from cfg and load a saved state_dict.

    The checkpoint tensors are mapped onto cfg.device, so weights saved on a
    GPU also load on a CPU-only machine (and vice versa).
    Returns the model in eval mode.
    '''
    model = build_model(cfg)
    # torch.load unpickles the file: only load checkpoints from a trusted source.
    state_dict = torch.load(model_path, map_location=cfg.device)
    model.load_state_dict(state_dict)
    model.eval()
    return model


# inference.py

In [None]:
@torch.no_grad()
def predict_image(model: object,
                  image: np.array,
                  device: torch.cuda.device,
                  img_size: tuple = (224, 224),
                  threshold: float = 0.01) -> np.ndarray:
    '''Predict a binary mask for a single 2D image.

    Preprocessing mirrors training: like ``load_img``, the image is cast to
    float32 and divided by its maximum (if non-zero); like the "valid"
    transform, it is resized with nearest-neighbour interpolation.

    Args:
        model: network that returns logits shaped (N, C, H, W).
        image: 2D array (a trailing single channel is dropped by cv2.resize).
        device: device the input tensor is moved to.
        img_size: (height, width) the image is resized to before inference.
        threshold: probability cut-off after the sigmoid.
            NOTE(review): the validation cell thresholds at 0.5; 0.01 is kept
            as the default for backward compatibility — confirm intent.

    Returns:
        float64 array (H, W) of 0/1 for the first sample and first class.
    '''
    # Copy so the caller's array is never modified by the in-place scaling.
    image = np.array(image, dtype=np.float32)
    max_num = np.max(image)
    if max_num:
        image /= max_num
    height, width = img_size
    # cv2 takes the target size as (width, height).
    image = cv2.resize(image, (width, height), interpolation=cv2.INTER_NEAREST)
    image_tensor = torch.tensor(image[None, None]).to(device)  # -> (1, 1, H, W)
    y_pred = model(image_tensor)
    y_pred = (torch.sigmoid(y_pred) > threshold).double()
    return y_pred.cpu().numpy()[0][0]

# trainer.py

In [None]:
import gc
import time
import copy
from collections import defaultdict
import numpy as np
import torch
from torch import nn
from torch.cuda import amp
from torch.optim import lr_scheduler
# from loss import criterion

def fetch_scheduler(optimizer: object, cfg: object) -> object:
    '''Create the LR scheduler named by cfg.scheduler for `optimizer`.

    Supported names: 'CosineAnnealingLR' (cfg.T_max, cfg.min_lr),
    'CosineAnnealingWarmRestarts' (cfg.T_0, cfg.min_lr), 'ReduceLROnPlateau'
    (cfg.min_lr) and 'ExponentialLR'. Returns None when cfg.scheduler is None.

    Raises:
        ValueError: for any other scheduler name (previously a misspelt
            attribute / unbound local error).
    '''
    name = cfg.scheduler
    if name is None:
        return None
    if name == 'CosineAnnealingLR':
        return lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.T_max,
                                              eta_min=cfg.min_lr)
    if name == 'CosineAnnealingWarmRestarts':
        return lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=cfg.T_0,
                                                        eta_min=cfg.min_lr)
    if name == 'ReduceLROnPlateau':
        return lr_scheduler.ReduceLROnPlateau(optimizer,
                                              mode='min',
                                              factor=0.1,
                                              patience=4,
                                              threshold=0.0001,
                                              min_lr=cfg.min_lr)
    if name == 'ExponentialLR':
        return lr_scheduler.ExponentialLR(optimizer, gamma=0.85)
    raise ValueError(f"Unknown scheduler: {name!r}")


def train_one_epoch(model: object,
                    optimizer: object,
                    dataloader: object,
                    cfg: object) -> float:
    '''Run one training epoch with mixed precision and gradient accumulation.

    Gradients of cfg.n_accumulate consecutive batches are summed (each loss is
    divided by n_accumulate) before an optimizer step. Any partially
    accumulated gradients left at the end of the epoch are applied, so they
    do not leak into the next epoch.

    Returns:
        Batch-size-weighted mean of the per-batch loss. This is the real loss,
        not the value divided by n_accumulate.

    Raises:
        ValueError: if the dataloader yields no batches.
    '''
    model.train()
    scaler = amp.GradScaler()
    optimizer.zero_grad()
    dataset_size = 0
    running_loss = 0.0
    pending_grads = False  # backward() done but optimizer not yet stepped

    for step, (images, masks) in enumerate(dataloader, start=1):
        images = images.to(cfg.device, dtype=torch.float)
        masks  = masks.to(cfg.device, dtype=torch.float)
        batch_size = images.size(0)

        with amp.autocast(enabled=True):
            y_pred = model(images)
            batch_loss = criterion(y_pred, masks)

        # Divide only the backpropagated value so the accumulated gradient is a mean.
        scaler.scale(batch_loss / cfg.n_accumulate).backward()
        pending_grads = True

        if step % cfg.n_accumulate == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            pending_grads = False

        running_loss += batch_loss.item() * batch_size
        dataset_size += batch_size

    if pending_grads:
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    torch.cuda.empty_cache()
    gc.collect()
    if dataset_size == 0:
        raise ValueError("train dataloader yielded no batches")
    return running_loss / dataset_size

@torch.no_grad()
def valid_one_epoch(model: object,
                    dataloader: object,
                    scheduler: object,
                    device: torch.cuda.device):
    '''Evaluate the model on the validation loader, then step the LR scheduler.

    ReduceLROnPlateau is stepped with the validation loss as its metric.
    Every other scheduler is stepped without arguments, once per epoch. (Passing
    the loss to those would be interpreted as the *epoch number*.)
    NOTE(review): CFG.T_max is sized like an iteration count
    (30000/train_bs*epochs), but stepping happens once per epoch — confirm
    the intended schedule.

    Returns:
        Batch-size-weighted mean validation loss.

    Raises:
        ValueError: if the dataloader yields no batches.
    '''
    model.eval()
    dataset_size = 0
    running_loss = 0.0

    for images, masks in dataloader:
        images = images.to(device, dtype=torch.float)
        masks  = masks.to(device, dtype=torch.float)
        batch_size = images.size(0)
        loss = criterion(model(images), masks)
        running_loss += loss.item() * batch_size
        dataset_size += batch_size

    if dataset_size == 0:
        raise ValueError("validation dataloader yielded no batches")
    epoch_loss = running_loss / dataset_size

    if scheduler is not None:
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(epoch_loss)
        else:
            scheduler.step()

    torch.cuda.empty_cache()
    gc.collect()

    return epoch_loss

def run_training(model: object,
                 train_loader: torch.utils.data.DataLoader,
                 valid_loader: torch.utils.data.DataLoader,
                 optimizer: object,
                 scheduler: object,
                 cfg: object,
                 fold: int = None) -> tuple:
    '''Train for cfg.epochs epochs, keeping the weights with the lowest valid loss.

    Writes ``best_epoch-{fold:02d}.bin`` whenever validation loss improves and
    ``last_epoch-{fold:02d}.bin`` after every epoch, both in the working dir.
    All W&B calls happen only when cfg.use_wandb is truthy. They assume the
    caller has already called ``wandb.init``.

    Args:
        fold: fold index used in checkpoint file names. When omitted, the
            notebook-global ``fold`` is used (the original behaviour), or 0.

    Returns:
        (model with best weights loaded, history dict with "Train Loss" /
        "Valid Loss" lists, best validation loss)
    '''
    if fold is None:
        # Backward compatibility: callers historically relied on a global `fold`.
        fold = globals().get("fold", 0)
    if cfg.use_wandb:
        import wandb

    start = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = np.inf
    best_epoch = -1
    history = defaultdict(list)

    for epoch in range(1, cfg.epochs + 1):
        gc.collect()
        print(f'Epoch {epoch}/{cfg.epochs}')

        train_loss = train_one_epoch(model, optimizer, train_loader, cfg)
        val_loss = valid_one_epoch(model, valid_loader, scheduler, device=cfg.device)

        history['Train Loss'].append(train_loss)
        history['Valid Loss'].append(val_loss)

        if cfg.use_wandb:
            # Read the LR from the optimizer: works for every scheduler, and for none.
            wandb.log({"Train Loss": train_loss,
                       "Valid Loss": val_loss,
                       "LR": optimizer.param_groups[0]['lr']})

        if val_loss < best_loss:
            best_loss = val_loss
            best_epoch = epoch
            best_model_wts = copy.deepcopy(model.state_dict())
            best_path = f"best_epoch-{fold:02d}.bin"
            torch.save(best_model_wts, best_path)
            if cfg.use_wandb:
                wandb.run.summary["Best Epoch"] = best_epoch
                wandb.save(best_path)

        torch.save(model.state_dict(), f"last_epoch-{fold:02d}.bin")

    time_elapsed = time.time() - start
    print('Training complete in {:.0f}h {:.0f}m {:.0f}s'.format(
        time_elapsed // 3600, (time_elapsed % 3600) // 60, (time_elapsed % 3600) % 60))

    model.load_state_dict(best_model_wts)
    return model, history, best_loss

# train.py

In [None]:
import sys
import glob
import gc
import json
import configparser
import torch
import wandb
import pandas as pd
from sklearn.model_selection import StratifiedGroupKFold
import matplotlib.pyplot as plt
# import utils
# from model import build_model
# from trainer import fetch_scheduler, run_training
# from dataset import get_transforms, prepare_loaders

def parse_config(config_path: str) -> object:
    '''Return the training/logging configuration as a class with attributes.

    NOTE(review): ``config_path`` is currently ignored. Every value is
    hard-coded below; the earlier .ini-parsing variant was dead commented-out
    code. The parameter is kept for interface compatibility.

    The W&B API key is taken from the ``WANDB_API_KEY`` environment variable
    (None when unset) instead of being committed in source. The key that was
    previously hard-coded here is public and should be revoked.
    '''
    import os

    wandb_key = os.environ.get("WANDB_API_KEY")

    class CFG:
        seed          = 101
        debug         = False # set debug=False for Full Training
        exp_name      = 'Baselinev2'
        comment       = 'unet-efficientnet_b1-224x224-aug2-split2'
        model_name    = 'Unet'
        backbone      = 'efficientnet-b1'
        train_bs      = 128
        valid_bs      = train_bs*2
        img_size      = [224, 224]
        epochs        = 15
        lr            = 2e-3
        scheduler     = 'CosineAnnealingLR'
        min_lr        = 1e-6
        # Roughly "iterations over all epochs" assuming ~30k training slices.
        T_max         = int(30000/train_bs*epochs)+50
        T_0           = 25
        warmup_epochs = 0
        wd            = 1e-6
        # Accumulate gradients up to an effective batch of at least 32.
        n_accumulate  = max(1, 32//train_bs)
        n_fold        = 5
        num_classes   = 1
        device        = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        use_wandb     = True
        wandb_secret  = wandb_key

    return CFG

# Entry point. Inside a Jupyter kernel __name__ is also "__main__", so this
# runs both in the notebook and as train.py.
if __name__ == "__main__":
    import os
    import numpy as np

    # CFG_PATH = sys.argv[1]
    CFG_PATH = "./default_params.ini"
    CFG = parse_config(CFG_PATH)

    set_seed(CFG.seed)
    gc.collect()

    BASE_PATH = '/kaggle/input/pleural-effusion'
    # glob order is filesystem-dependent; sort so the seeded fold split is reproducible.
    image_paths = sorted(glob.glob(f"{BASE_PATH}/images/*.npy"))

    mask_paths = [get_mask_path(x) for x in image_paths]
    is_empty = [mask_empty(x) for x in mask_paths]

    df = pd.DataFrame({"image_path": image_paths,
                       "mask_path": mask_paths,
                       "empty": is_empty})

    df["case"] = df["image_path"].apply(get_case)
    df["id"] = df["image_path"].apply(get_id)

    # K-fold split: stratified on empty/non-empty masks and grouped by case,
    # so all slices of a case land in the same fold. Pre-fill to keep an int dtype.
    df["fold"] = -1
    skf = StratifiedGroupKFold(n_splits=CFG.n_fold, shuffle=True, random_state=CFG.seed)
    for fold, (_, val_idx) in enumerate(
            skf.split(df, df['empty'], groups=df["case"])):
        df.loc[val_idx, 'fold'] = fold

    if CFG.use_wandb:
        wandb.login(key=CFG.wandb_secret)

    # run_training returns a Dice *loss* (lower is better), so track a minimum.
    best_val_loss = np.inf
    best_fold = -1
    best_losslist = []
    os.makedirs("/kaggle/working/artefacts", exist_ok=True)

    for fold in range(CFG.n_fold):
        print(f'### Fold: {fold+1} of {CFG.n_fold}')

        if CFG.use_wandb:
            run = wandb.init(project='pleural-effusion-2d-seg',
                             config={k: v for k, v in dict(vars(CFG)).items() if '__' not in k},
                             name=f"fold-{fold}|dim-{CFG.img_size[0]}x{CFG.img_size[1]}|model-{CFG.model_name}",
                             group=CFG.comment,
                             )

        train_loader, valid_loader = prepare_loaders(df, fold, CFG)
        model = build_model(CFG)
        optimizer = torch.optim.Adam(model.parameters(), lr=CFG.lr, weight_decay=CFG.wd)
        scheduler = fetch_scheduler(optimizer, CFG)
        # run_training reads the global `fold` for its checkpoint names.
        model, history, fold_best_loss = run_training(model,
                                                      train_loader,
                                                      valid_loader,
                                                      optimizer,
                                                      scheduler,
                                                      CFG)
        if fold_best_loss < best_val_loss:
            print(f"New best fold: {fold}")
            best_fold = fold
            PATH = "best_epoch-all-folds.bin"
            torch.save(model.state_dict(), PATH)
            if CFG.use_wandb:
                wandb.save(PATH)
            best_val_loss = fold_best_loss
            best_losslist = history["Valid Loss"]

        # Validation curve of the best fold so far (overwritten each fold).
        fig, ax = plt.subplots(nrows=1, ncols=1)
        ax.plot(best_losslist)
        ax.title.set_text('Dice loss')
        ax.set_xlabel("Epochs")
        ax.set_ylabel("1 - Dice")
        fig.savefig("/kaggle/working/artefacts/dice_plot.png")
        plt.close(fig)

        if CFG.use_wandb:
            run.finish()

# Inspect a saved fold checkpoint's predictions on its validation split (scratch)

In [None]:
# Evaluate one fold's best checkpoint on that same fold's validation split.
# The fold index must match the checkpoint. The loop variable `fold` left over
# from training is the last fold, whose data the fold-3 model was trained on.
EVAL_FOLD = 3
best_model = load_model(f"best_epoch-{EVAL_FOLD:02d}.bin", CFG)
_, valid_loader = prepare_loaders(df, EVAL_FOLD, CFG)
preds = []
masks = []
imgs = []
with torch.no_grad():
    for img_batch, mask_batch in valid_loader:
        img_batch = img_batch.to(CFG.device)
        pred_batch = (torch.sigmoid(best_model(img_batch)) > 0.5).double()
        preds.append(pred_batch.cpu())
        masks.append(mask_batch)
        imgs.append(img_batch.cpu())
# Concatenate along the batch axis: one entry per validation slice. Stacking
# and averaging batches would blend unrelated images, and it fails when the
# last batch is smaller than the others.
imgs = torch.cat(imgs, dim=0).numpy()
preds = torch.cat(preds, dim=0).numpy()
masks = torch.cat(masks, dim=0).numpy()

In [None]:
# Output directory for the per-slice overlay PNGs.
os.makedirs("/kaggle/working/artifacts/", exist_ok=True)

In [None]:
# One PNG per validation slice: ground truth in blue, prediction in red.
for idx in range(len(imgs)):
    save_overlay(imgs[idx][0], masks[idx][0], preds[idx][0],
                 f"/kaggle/working/artifacts/{idx}.png")

In [None]:
# Validation Dice-loss curve of the last fold trained (`history` from the loop).
valid_curve = history["Valid Loss"]
fig, ax = plt.subplots(nrows=1, ncols=1)
ax.plot(valid_curve)
ax.title.set_text("Dice loss")
ax.set_xlabel('Epochs')
ax.set_ylabel('1 - Dice')
fig.savefig("/kaggle/working/dice_plot.png")
plt.close(fig)

In [None]:
# import matplotlib.pyplot as plt
# plt.imshow(masks[1][0])

In [None]:
# plt.imshow(preds[1][0])

In [None]:
# i=1
# img = imgs[i][0]
# pred = preds[i][0]
# mask = masks[i][0]
# out_path = f"/kaggle/working/artifacts/{i}.png"
# save_overlay(img, mask, pred, out_path)

In [None]:
# pred.shape

In [None]:
# mask = np.where(mask<0.5, 0, 1)

In [None]:
# plt.imshow(mask)

In [None]:

# bool_mask = np.array(mask, dtype=bool)

In [None]:
# plt.imshow(bool_mask)

In [None]:
# mask