In [None]:
import random
from pathlib import Path

import matplotlib.colors as colors
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
import wandb
from skimage.io import imread
from torch import nn
from torch.utils import data
from torchsummary import summary

from models.incept_unet import InceptionedUNet
from datasets.blastocyst import SFUDataset, ClinicDataset
from utils.trainer import Trainer
from utils.losses import DicePlusBCELoss

# Use the first GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print(device)

In [None]:
# Data configuration shared by every dataset split below.
IMG_SIZE = 288
# Resize(int) matches the shorter edge to IMG_SIZE; CenterCrop then yields IMG_SIZE x IMG_SIZE.
TRANSFORM = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(IMG_SIZE),
])
ADD_MASKS = True
# Mask-selection modes forwarded to SFUDataset / ClinicDataset.
SFU_MASK, CLINIC_MASK = 'all', 'random'

In [None]:
# Training hyper-parameters (also recorded in the W&B config further down).
batch_size = 30
n_epochs = 350
n_blocks = 6
channels_out = 5   # output channels of the network: one per mask class
lr = 5e-5
start_filters = 32
dice_weight = 0.5  # forwarded to DicePlusBCELoss
smooth = 5         # forwarded to DicePlusBCELoss
seed = 42

# Seed every RNG that shuffling, weight init or augmentations may draw from, so a
# Restart & Run All reproduces the same run.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Clinic data: one dataset root per clinic, /datasets/clinic1 ... /datasets/clinic4.
root_dirs = [Path('/datasets') / f'clinic{i}' for i in range(1, 5)]
clinic1, clinic2, clinic3, clinic4 = root_dirs
# SFU (BlastsOnline) data root; the train/valid/test split folders are picked below.
sfu = Path('/datasets/sfu/BlastsOnline')

In [None]:
# Training split: SFU images with augmentations enabled; DataLoader reshuffles each epoch.
sfu_dataset_train = SFUDataset(
    sfu / 'train',
    use_augmentations=True,
    mask=SFU_MASK,
    add_masks=ADD_MASKS,
    transform=TRANSFORM,
)

# To train on the clinic data instead, replace this list with one
# ClinicDataset(root_dir / 'train', use_augmentations=False, mask=CLINIC_MASK, transform=TRANSFORM)
# for each root_dir in root_dirs.
datasets_train = [sfu_dataset_train]

dataloader_train = data.DataLoader(
    dataset=data.ConcatDataset(datasets_train),
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Validation split: SFU images, no augmentations, fixed order, batches of 10.
sfu_dataset_valid = SFUDataset(
    sfu / 'valid',
    use_augmentations=False,
    mask=SFU_MASK,
    add_masks=ADD_MASKS,
    transform=TRANSFORM,
)

# To validate on the clinic data instead, replace this list with one
# ClinicDataset(root_dir / 'valid', use_augmentations=False, mask=CLINIC_MASK, transform=TRANSFORM)
# for each root_dir in root_dirs.
datasets_valid = [sfu_dataset_valid]

dataloader_valid = data.DataLoader(
    dataset=data.ConcatDataset(datasets_valid),
    batch_size=10,
    shuffle=False,
)

In [None]:
# Test split: SFU images, no augmentations, fixed order, batches of 10.
sfu_dataset_test = SFUDataset(
    sfu / 'test',
    use_augmentations=False,
    mask=SFU_MASK,
    add_masks=ADD_MASKS,
    transform=TRANSFORM,
)

# To test on the clinic data instead, replace this list with one
# ClinicDataset(root_dir / 'test', use_augmentations=False, mask=CLINIC_MASK, transform=TRANSFORM)
# for each root_dir in root_dirs.
datasets_test = [sfu_dataset_test]

dataloader_test = data.DataLoader(
    dataset=data.ConcatDataset(datasets_test),
    batch_size=10,
    shuffle=False,
)

In [None]:
# Overlay colormaps for drawing mask channels on top of each other: each one ramps
# from fully transparent white (value 0) to an opaque colour (value 1), so zero
# pixels of one channel do not hide the channels drawn underneath.
c_white = colors.colorConverter.to_rgba('white', alpha=0)
c_red, c_green, c_blue, c_gray, c_khaki = (
    colors.colorConverter.to_rgba(color_name, alpha=1)
    for color_name in ('red', 'green', 'blue', 'gray', 'khaki')
)


def _fade_in_cmap(rgba_opaque):
    """Return a 512-step colormap from transparent white to `rgba_opaque`."""
    # Every map shares the registry-free name 'rb_cmap'; only the colours differ.
    return colors.LinearSegmentedColormap.from_list('rb_cmap', [c_white, rgba_opaque], 512)


cmap_red, cmap_green, cmap_blue, cmap_gray, cmap_khaki = map(
    _fade_in_cmap, (c_red, c_green, c_blue, c_gray, c_khaki)
)
# Order matches the mask channels: ICM, TE, ZP, Background, Blastocoel.
colormaps = [cmap_red, cmap_green, cmap_blue, cmap_gray, cmap_khaki]

In [None]:
# Sanity-check one training batch: print tensor stats, then show one image next to
# its ground-truth mask.
x, y, name = next(iter(dataloader_train))

print(f'x = shape: {x.shape}; type: {x.dtype}')
print(f'x = min: {x.min()}; max: {x.max()}')
print(f'y = shape: {y.shape}; class: {y.unique()}; type: {y.dtype}')

id_ = 0  # index of the sample within the batch to display
n_mask_channels = y.shape[1]
f, axarr = plt.subplots(1, 2)
axarr[0].imshow(x[id_].permute(1, 2, 0), cmap='gray')
if n_mask_channels == len(colormaps):
    # One transparent-to-opaque colormap per class channel, drawn on top of each other.
    for channel_mask, cmap in zip(y[id_], colormaps):
        axarr[1].imshow(channel_mask, cmap=cmap)
else:
    # Show the same sample as the image panel (previously always showed sample 0).
    axarr[1].imshow(y[id_].permute(1, 2, 0))

axarr[0].set_title('Image')
axarr[1].set_title(f'{n_mask_channels} Channel GT Mask')
plt.show()

In [None]:
# Experiment settings sent to Weights & Biases so a run can be reproduced from its
# dashboard. Keep these in sync with the dataset/model cells.
config = {
    'learning_rate': lr,
    'epochs': n_epochs,
    'batch_size': batch_size,
    'channels_out': channels_out,
    'n_blocks': n_blocks,
    'start_filters': start_filters,
    'dice_weight': dice_weight,
    'smooth': smooth,
    'mask': SFU_MASK,
    'img_size': IMG_SIZE,
    'add_masks': ADD_MASKS,
    'inception': 'fc',
    'loss': 'dice_loss+BCE',
    'activation': 'relu',
    'normalization': 'batch',
}
# TODO: replace the PLACEHOLDER project/entity/group with real values.
# reinit=True allows wandb.init to be called again in the same process (e.g. re-running this cell).
wandb.init(project='PLACEHOLDER', entity='PLACEHOLDER', group='PLACEHOLDER',
           config=config, reinit=True)

In [None]:
# TODO: point this at the real checkpoint file. A relative path resolves against the
# notebook's working directory; it is used both to load and to save the model.
model_save_path = Path('.') / 'PLACE_HOLDER'

In [None]:
# Reload a previously saved model. The save cell pickles the whole module, so this
# unpickles arbitrary code: only load checkpoints you trust.
# map_location places the tensors on the current `device`, so a checkpoint saved
# from a GPU can also be opened on a CPU-only machine.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which cannot unpickle a
# full nn.Module; pass weights_only=False here for a trusted checkpoint on those versions.
# NOTE(review): the model-building cell further down rebinds `incept_unet` and discards
# this checkpoint; skip that cell when you only want to evaluate the loaded model.
incept_unet = torch.load(model_save_path, map_location=device)

In [None]:
# Layer-by-layer summary for a single-channel IMG_SIZE x IMG_SIZE input.
# torchsummary accepts only device='cuda' or 'cpu' (no index) and defaults to 'cuda',
# which fails on CPU-only machines, so derive it from `device` without the ':N' suffix.
summary(incept_unet, (1, IMG_SIZE, IMG_SIZE), device=device.split(':')[0])

In [None]:
# Build a fresh Inception U-Net and move it to `device`. The leading positional 1 is
# presumably the number of input channels (matches the (1, H, W) summary input).
# nn.Module.to returns the module itself, so chaining it is equivalent to calling it
# separately; making the cell an assignment also keeps the module repr from printing.
incept_unet = InceptionedUNet(
    1, channels_out, IMG_SIZE, batch_size,
    n_blocks=n_blocks,
    start_filters=start_filters,
    incept_type='fc',
    fc_size='full',
    final_activation='sigmoid',
    normalization='batch',
).to(device)

In [None]:
# Training objective (Dice + BCE from utils.losses; how dice_weight and smooth are
# applied is defined there) and an Adam optimiser over all model parameters.
loss_fn = DicePlusBCELoss(dice_weight, smooth, rescale=False)
optimizer = torch.optim.Adam(params=incept_unet.parameters(), lr=lr)

In [None]:
# Bundle model, objective, optimiser and the train/valid/test loaders into the
# project Trainer (loader order matters: train, valid, test).
split_loaders = (dataloader_train, dataloader_valid, dataloader_test)
trainer = Trainer(incept_unet, loss_fn, optimizer, n_epochs, device, *split_loaders,
                  notebook=True)

In [None]:
# Run the full training loop (n_epochs passed to Trainer above).
# log=True presumably reports metrics to the active W&B run — confirm in utils.trainer.
trainer.model_train(log=True)

In [None]:
# Evaluate on the test split. With mean_score=False the scores appear to stay per
# channel; the next cell averages them per class.
mean_score = False
trainer.thresh = 0.55  # presumably the cut-off for binarising predictions — see utils.trainer
trainer.metrics = ['IoU', 'DSC', 'VS', 'RVD']
eval_options = dict(data_type='test', plot=True, display_name=True, mean_score=mean_score)
scores = trainer.model_eval(**eval_options)

In [None]:
# Print the evaluation scores and send them to W&B in ONE wandb.log call, so all
# metrics land on the same step (each separate wandb.log call advances the step).
# Mask channel order, as used for the per-class labels below.
class_names = ['ICM', 'TE', 'ZP', 'Background', 'Blastocoel']

metrics_to_log = {}
if mean_score:
    for key, value in scores.items():
        mean_value = np.mean(value)
        print(f'{key}: {mean_value}')
        metrics_to_log[key.lower()] = mean_value
else:
    for key, value in scores.items():
        # Each entry of `value` is indexed by channel; average over entries per class.
        for channel_idx, class_name in enumerate(class_names):
            class_mean = np.mean([v[channel_idx] for v in value])
            print(f'{key} {class_name}: {class_mean}')
            metrics_to_log[f'{key.lower()}_{class_name.lower()}'] = class_mean

if metrics_to_log:
    wandb.log(metrics_to_log)

In [None]:
# Persist the trained model. This pickles the whole module (not just its state_dict),
# matching the earlier torch.load cell. Create the parent directory first so a long
# training run does not end with a failed save.
model_save_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(incept_unet, model_save_path)