# Import & Settings

In [None]:
import torch
import numpy as np
import os
import segmentation_models_pytorch as smp
import utils

import warnings
warnings.filterwarnings('ignore')

In [None]:
# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# DataLoader configuration (NUM_WORKERS = 0 means batches are loaded in the
# main process, without worker subprocesses).
BATCH_SIZE = 16
NUM_WORKERS = 0

# Network input size in pixels; 512 x 384 is the alternative used previously.
WIDTH, HEIGHT = 336, 224

In [None]:
# Run identifier: model_name (currently "unet_1") is passed to the trainer and
# used as the prefix of the saved figure files.
model_id = 1
title = 'unet'
model_name = f"{title}_{model_id}"

# NOTE(review): the run is titled 'unet' but the network built in the
# Model Design section is ARL-UNet — confirm this naming is intended.
log = utils.Logger(verbose=True, title=os.path.join('seg', title))
log.logger.info(f"{model_name}")

# Data Preparation

In [None]:
# ISIC 2017 dataset locations, grouped by split. Each split has an image
# directory, a ground-truth mask directory (Part1) and an annotation CSV (Part3).

# Training split. NOTE(review): images come from 'Aug_Training_Data'
# (presumably an augmented copy) while masks come from the original Part1
# folder — confirm utils.SegData pairs them correctly.
pth_train_img = 'Data/ISIC2017/Aug_Training_Data'
pth_train_mask = 'Data/ISIC2017/ISIC-2017_Training_Part1_GroundTruth'
ann_train = 'Data/ISIC2017/ISIC-2017_Training_Part3_GroundTruth.csv'

# Validation split
pth_valid_img = 'Data/ISIC2017/ISIC-2017_Validation_Data'
pth_valid_mask = 'Data/ISIC2017/ISIC-2017_Validation_Part1_GroundTruth'
ann_valid = 'Data/ISIC2017/ISIC-2017_Validation_Part3_GroundTruth.csv'

# Test split (v2 ground truth)
pth_test_img = 'Data/ISIC2017/ISIC-2017_Test_Data'
pth_test_mask = 'Data/ISIC2017/ISIC-2017_Test_v2_Part1_GroundTruth'
ann_test = 'Data/ISIC2017/ISIC-2017_Test_v2_Part3_GroundTruth.csv'

In [None]:
from torch.utils import data
import albumentations as A
from torchvision import transforms

# Training augmentation for image/mask pairs (albumentations spatial-level
# transforms: https://github.com/albumentations-team/albumentations#spatial-level-transforms).
# Previously tried and disabled: ElasticTransform, Sharpen, ColorJitter.
# NOTE(review): A.Flip and the width/height form of RandomResizedCrop may be
# deprecated in newer albumentations releases — pin the version or migrate.
trans_train = A.Compose(
    [
        A.RandomResizedCrop(
            width=WIDTH,
            height=HEIGHT,
            scale=(0.6, 1.3),
            ratio=(0.75, 1.3333333333333333),
        ),
        A.Flip(p=0.5),
        A.Rotate(limit=180),
        A.GaussNoise(),
    ]
)

# Deterministic evaluation pipeline: enlarge by 10 %, then centre-crop back
# to the network input size.
trans_test = A.Compose(
    [
        A.Resize(height=int(HEIGHT * 1.1), width=int(WIDTH * 1.1)),
        A.CenterCrop(height=HEIGHT, width=WIDTH),
    ]
)

# Image-only post-processing: convert to a tensor, then normalise each channel
# with the standard ImageNet mean / std values.
trans_img = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


def trans_mask(mask, threshold=127):
    """Convert a ground-truth mask into an integer class-index tensor.

    Pixels strictly greater than ``threshold`` become class 1 (foreground);
    all other pixels become class 0 (background). For a clean 0/255 mask this
    gives the same result as the previous ``mask / 255`` cast. Unlike that
    cast, it does not truncate intermediate grey values (e.g. 200 -> 0) that
    interpolation or lossy encoding can introduce.

    Args:
        mask: Array-like mask (a NumPy array or anything ``np.asarray``
            accepts, such as a PIL image). Presumably single-channel 0/255 —
            TODO confirm against utils.SegData.
        threshold: Grey level above which a pixel counts as foreground.

    Returns:
        A ``torch.int64`` tensor with the same shape as ``mask``.
    """
    foreground = np.asarray(mask) > threshold
    return torch.from_numpy(foreground.astype(np.int64))

In [None]:
# Datasets: training uses the random augmentation pipeline, validation the
# deterministic resize + centre-crop one.
train_data = utils.SegData(ann_train, pth_train_img, pth_train_mask, trans_train, trans_img, trans_mask)
valid_data = utils.SegData(ann_valid, pth_valid_img, pth_valid_mask, trans_test, trans_img, trans_mask)

# Training batches are shuffled and a trailing partial batch is dropped.
train_loader = data.DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    drop_last=True,
)
valid_loader = data.DataLoader(
    valid_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

In [None]:
from utils.visualize import show_seg_samples

# Visual sanity check: show the first validation batch with its masks.
preview_batch = next(iter(valid_loader))
x, y = preview_batch
show_seg_samples(x, y, title="Segmentation Examples")

# Model Design

## Dense-UNet / Res-UNet

## UNet

## ARL-UNet

In [None]:
from nets import arlunet

# Build ARL-UNet. The meaning of pretrained='arl18' is defined in
# nets.arlunet (TODO confirm: presumably selects an ARL-18 backbone).
model = arlunet(pretrained='arl18')
log.logger.info(f"ARL-UNet | Size: ({WIDTH}, {HEIGHT})")

In [None]:
# Move the model's parameters and buffers to `device`. nn.Module.to works in
# place and also returns the module, which the notebook echoes as cell output.
model.to(device)

# Training

In [None]:
# Optimisation hyper-parameters
init_lr = 1e-4
weight_decay = 1e-4
max_epoch = 150
test_period = 1        # passed to trainer.fit — presumably validation frequency in epochs (TODO confirm)
early_threshold = 45   # passed to trainer.fit — presumably early-stopping patience (TODO confirm)

criterion = utils.DiceCE()
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=init_lr,
    weight_decay=weight_decay,
)
# Cosine-anneal the learning rate from init_lr down to 0 over max_epoch scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=max_epoch,
    eta_min=0,
)
log.logger.info(f"Criterion: {criterion}\nOptimizer: {optimizer}\nScheduler: {scheduler}")

# NOTE(review): the meaning of the positional 0, 0, None arguments is defined
# in utils.SegTrain — consider naming them there.
trainer = utils.SegTrain(device, log, model_name, optimizer, scheduler, 0, 0, None)

# Baseline metrics on the validation set before any training.
acc, iou = trainer.eval(model, valid_loader)
log.logger.info(f"Initial Performance on Valid Set: Acc: {acc}, IoU: {iou}")

history = trainer.fit(model, train_loader, valid_loader, criterion, max_epoch, test_period, early_threshold)

In [None]:
import matplotlib.pyplot as plt

# Directory for the saved training-curve figures. exist_ok=True replaces the
# separate exists-check and makes re-running this cell harmless.
os.makedirs('fig', exist_ok=True)


def _finish_plot(title, suffix, legend=False):
    """Label, save, show and close the current matplotlib figure.

    The figure is written to ``fig/<model_name>_<suffix>.png``, where
    ``model_name`` is the notebook-level run identifier. Closing after ``show``
    keeps figures from accumulating when the cell is re-run or executed with
    a non-interactive backend.
    """
    plt.title(title)
    plt.xlabel('Epoch')
    if legend:
        plt.legend()
    plt.grid(axis='y')
    plt.savefig('fig/{}_{}.png'.format(model_name, suffix))
    plt.show()
    plt.close()


def plot_loss(history):
    """Plot ``history['costs']`` and save it as ``fig/<model_name>_loss.png``."""
    plt.figure(dpi=100)
    plt.plot(history['costs'])
    _finish_plot('Loss', 'loss')


def plot_score(history):
    """Plot train vs. validation mean IoU (``history['train_ious']`` / ``['val_ious']``)."""
    plt.figure(dpi=100)
    plt.plot(history['train_ious'], label='train_mIoU')
    plt.plot(history['val_ious'], label='val_mIoU')
    _finish_plot('Mean IoU', 'iou', legend=True)


def plot_acc(history):
    """Plot train vs. validation pixel accuracy (``history['train_accs']`` / ``['val_accs']``)."""
    plt.figure(dpi=100)
    plt.plot(history['train_accs'], label='train_accuracy')
    plt.plot(history['val_accs'], label='val_accuracy')
    _finish_plot('Pixel Accuracy', 'acc', legend=True)

In [None]:
# Render and save each training curve in turn (loss, mean IoU, pixel accuracy).
for plot_fn in (plot_loss, plot_score, plot_acc):
    plot_fn(history)

# Evaluation

In [None]:
from utils.evaluation import (
    mDSC,
    mIoU,
    mTJI,
    pixel_accuracy,
    pixel_sensitivity,
    pixel_specificity,
    seg_predict,
)

# The training/validation loaders are not used past this point; drop them.
del train_loader, valid_loader

# The test split uses the same deterministic preprocessing as validation.
test_data = utils.SegData(ann_test, pth_test_img, pth_test_mask, trans_test, trans_img, trans_mask)
test_loader = data.DataLoader(
    test_data,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

In [None]:
# Optional: uncomment to use a model obtained from utils.load_model instead of
# the one trained above. NOTE(review): load_model's behaviour is not visible
# here, and "dense_unet_1.pkl" does not match the current model_name — confirm
# the file name before use.
# model = utils.load_model(device, name="dense_unet_1.pkl")
# model

In [None]:
# Run the model over the full test set. NOTE(review): the unpacking assumes
# seg_predict returns (ground truth, prediction) — confirm in utils.evaluation.
test_outputs = seg_predict(model, test_loader)
mask, pred_mask = test_outputs

In [None]:
# Test-set metrics; each one is logged immediately after it is computed.

# Pixel-level scores
pixel_acc = pixel_accuracy(pred_mask, mask)
log.logger.info(f"Pixel Accuracy: {pixel_acc}")
pixel_se = pixel_sensitivity(pred_mask, mask)
log.logger.info(f"Pixel Sensitivity: {pixel_se}")
pixel_sp = pixel_specificity(pred_mask, mask)
log.logger.info(f"Pixel Specificity: {pixel_sp}")

# Overlap-based scores
iou_score = mIoU(pred_mask, mask)
log.logger.info(f"Mean IoU: {iou_score}")
dsc_score = mDSC(pred_mask, mask)
log.logger.info(f"Mean DSC: {dsc_score}")
tji_score = mTJI(pred_mask, mask)
log.logger.info(f"Mean TJI: {tji_score}")

In [None]:
# Visualise the model's predictions for the first test batch.
x, y = next(iter(test_loader))
# Put dropout / batch-norm layers into inference behaviour and skip building
# an autograd graph; neither was guaranteed by the original cell.
model.eval()
with torch.no_grad():
    pred = model(x.to(device))
# Per-pixel class index; assumes logits are laid out (N, C, H, W) — TODO confirm.
pred = torch.argmax(pred, dim=1).to('cpu')
show_seg_samples(x, pred, title="Predictions")

In [None]:
# Ground-truth masks for the same test batch, for comparison with the predictions.
show_seg_samples(
    x,
    y,
    title="GroundTruths",
)