In [None]:
from brainmri.dataset.stacker import MriStacker
from brainmri.dataset.dataset import *
from brainmri.models.arch import FPN
from brainmri.runner.train import train_model

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datetime import date

import albumentations as A
from albumentations.pytorch import ToTensor
import torch
import torch.nn as nn

import segmentation_models_pytorch as smp

In [None]:
# Experiment configuration: data locations, 2.5D slice stacking, and training hyper-parameters.
config = dict(
    make_stacks=True,   # when True, stacker.process_patients() runs before the split
    stack_size=3,
    data_dir='./data/lgg-mri-segmentation/kaggle_3m/',
    out_dir='./data/lgg-mri-segmentation/2.5D/StackSize=3',
    model_out_pth='./models/fpn-scratch_aug_{date}.pth',  # `{date}` is filled via str.format when saving
    augmentations=True,
    epochs=300,
    batch_size=64,
    lr=1e-04,
    optimizer='adam',   # NOTE(review): not read by any visible cell; Adam is hard-coded in the training cell
    device='cuda',
    num_classes=1,      # NOTE(review): not read by any visible cell
)

In [None]:
# Record the versions of the libraries whose APIs this notebook relies on
# (albumentations' ToTensor / IAAEmboss and smp.utils may differ across releases).
{
    "albumentations": A.__version__,
    "torch": torch.__version__,
    "segmentation_models_pytorch": getattr(smp, "__version__", "unknown"),
}

In [None]:
def get_augmentations(is_train, apply_transforms=False):
    """Build the albumentations pipeline for one dataset split.

    Random augmentations are used only when both ``is_train`` and
    ``apply_transforms`` are truthy. Otherwise the pipeline only converts the
    sample to a tensor, the same ``A.Compose([ToTensor()])`` pipeline this
    notebook uses for its validation and test splits.

    Args:
        is_train: Whether the pipeline is for the training split.
        apply_transforms: Whether to enable random augmentations for training.

    Returns:
        An ``A.Compose`` pipeline that ends with ``ToTensor()``.
    """
    if not (is_train and apply_transforms):
        if is_train:
            print("apply_transforms is False. Augmentations not applied")
        # Honour the flag: no random cropping/augmentation, tensor conversion only.
        return A.Compose([ToTensor()])

    return A.Compose([
        # Geometry / cropping
        A.RandomCrop(width = 128, height = 128, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.Transpose(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.01, scale_limit=0.04, rotate_limit=0, p=0.25),

        # Pixel intensity
        A.RandomBrightnessContrast(p=0.5),
        A.RandomGamma(p=0.25),
        A.IAAEmboss(p=0.25),
        A.Blur(p=0.01, blur_limit = 3),

        # Non-rigid distortions: with probability 0.8 one of these is picked
        # (inside OneOf the individual p values act as selection weights).
        A.OneOf([
            A.ElasticTransform(p=0.5, alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03),
            A.GridDistortion(p=0.5),
            A.OpticalDistortion(p=1, distort_limit=2, shift_limit=0.5)
        ], p=0.8),
        ToTensor()
    ])

In [None]:
# Build the stacked inputs (optionally) and the train/valid/test split; the split
# is exposed as stacker.train_df / valid_df / test_df.
# NOTE(review): whether the split is patient-level is decided inside MriStacker — confirm.
stacker = MriStacker(root_dir=config.get("data_dir"),
                     out_dir=config.get("out_dir"),
                     stack_size=config.get("stack_size"))

if config.get("make_stacks"):
    stacker.process_patients()

stacker.gen_train_val_test_split()

# Build the training pipeline once and hand that same object to the dataset,
# so the pipeline (and any warning it prints) is created exactly once.
augs = get_augmentations(is_train=True, apply_transforms=config.get("augmentations"))
train_ds = BrainMriSegmentation(stacker.train_df, config.get("stack_size"),
                                transforms=augs,
                                preprocessing=None)
# Validation: tensor conversion only, no random augmentation.
valid_ds = BrainMriSegmentation(stacker.valid_df, config.get("stack_size"),
                                transforms=A.Compose([ToTensor()]), preprocessing=None)

train_dl = get_dataloader(train_ds, bs=config.get("batch_size"))
valid_dl = get_dataloader(valid_ds, bs=config.get("batch_size"))


# Model + optimiser. NOTE(review): FPN's positional argument is defined in
# brainmri.models.arch (possibly input channels == stack_size) — confirm.
model = FPN(3)
optimizer = torch.optim.Adam(params=model.parameters(), lr=config.get("lr"))

# Wrap in DataParallel only when CUDA is requested and several GPUs are visible.
gpu_count = torch.cuda.device_count()
if gpu_count > 1 and config.get("device") == "cuda":
    print("Let's use", gpu_count, "GPUs!")
    model = nn.DataParallel(model).cuda()

# Dice loss for optimisation; IoU and F-score (thresholded at 0.5) for monitoring.
loss = smp.utils.losses.DiceLoss()
metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
    smp.utils.metrics.Fscore(threshold=0.5),
]

# Arguments shared by the training and validation runners.
epoch_kwargs = dict(loss=loss, metrics=metrics, device=config.get("device"), verbose=True)
train_epoch = smp.utils.train.TrainEpoch(model, optimizer=optimizer, **epoch_kwargs)
valid_epoch = smp.utils.train.ValidEpoch(model, **epoch_kwargs)

# Per-epoch history, consumed by the plotting cell.
train_loss, valid_loss, train_fscore, valid_fscore, train_iou, valid_iou = [], [], [], [], [], []

# Resolve the checkpoint path once so every "best so far" save overwrites the
# same file, even if training runs past midnight.
model_out_path = config.get("model_out_pth").format(date=str(date.today()))

# Start below any attainable F-score so the first epoch always writes a checkpoint.
max_score = float("-inf")
for i in range(config.get("epochs")):
    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_dl)
    valid_logs = valid_epoch.run(valid_dl)

    # Keep only the checkpoint with the best validation F-score so far.
    if max_score < valid_logs["fscore"]:
        max_score = valid_logs["fscore"]
        torch.save(model, model_out_path)
        print("Model saved!")

    train_loss.append(train_logs["dice_loss"])
    valid_loss.append(valid_logs["dice_loss"])
    train_fscore.append(train_logs["fscore"])
    valid_fscore.append(valid_logs["fscore"])
    train_iou.append(train_logs["iou_score"])
    valid_iou.append(valid_logs["iou_score"])



In [None]:
# Learning curves: Dice loss, F-score (labelled Dice) and IoU per epoch, train vs validation.
curves = {
    "Train Loss": train_loss,
    "Train Dice": train_fscore,
    "Train IoU": train_iou,
    "Val Loss": valid_loss,
    "Val Dice": valid_fscore,
    "Val IoU": valid_iou,
}

fig, ax = plt.subplots()
for label, values in curves.items():
    ax.plot(values, label=label)
ax.set(title="Training history", xlabel="Epoch", ylabel="Value")
ax.legend()
plt.show()

In [None]:
# Load the best checkpoint from a specific past training run; change the date to
# pick a different run. torch.load unpickles a full model object, so only load
# checkpoints you produced yourself.
checkpoint_date = "2021-08-26"
best_model = torch.load(config.get("model_out_pth").format(date=checkpoint_date))

In [None]:
# Held-out test split: tensor conversion only, no random augmentation.
test_transforms = A.Compose([ToTensor()])
test_ds = BrainMriSegmentation(
    stacker.test_df,
    config.get("stack_size"),
    transforms=test_transforms,
)
test_dl = get_dataloader(test_ds, bs=config.get("batch_size"))

In [None]:
# Evaluation runner for the loaded checkpoint, using the same loss and metrics as training.
loss = smp.utils.losses.DiceLoss()
metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
    smp.utils.metrics.Fscore(threshold=0.5),
]

test_epoch = smp.utils.train.ValidEpoch(
    model=best_model,
    loss=loss,
    metrics=metrics,
    device=config.get("device"),
    verbose=True,
)

In [None]:
# Evaluate on the test split; the returned log dict is displayed and kept for later use.
test_logs = test_epoch.run(test_dl)
test_logs