In [None]:
import numpy as np

In [None]:
import os

In [None]:
from pathlib import Path
import sys
import os
# Directory the notebook kernel was started from (informational; not read below).
current_file = Path.cwd()
# Repository root. Path('') renders as '.', so every entry added below is relative
# to the kernel's working directory.
root_dir = Path('')
# Make the vendored StyleGAN2-ADA code, the repo root and InnerEye-Generative
# importable, in that order — presumably needed by the project imports that follow.
sys.path.extend(
    str(entry)
    for entry in (
        root_dir / 'stylegan2-ada-pytorch',
        root_dir,
        root_dir / 'InnerEye-Generative',
    )
)
# print(sys.path)
from azureml.core import Workspace
from datetime import datetime
from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from metrics.get_VGG_model import load_model
from loaders.prostate_loader import Prostate2DSimpleDataLoader
from models.UNet2D_seg_baseline import Model
from helpers.loggers import AzureMLLogger, TensorboardWithImgs, LoggerCollectionWithImgs
from argparse import ArgumentParser

from argparse import ArgumentTypeError
from datetime import timedelta


def _str2bool(value):
    """Parse a command-line boolean such as 'true'/'false', 'yes'/'no' or '1'/'0'.

    ``type=bool`` is an argparse pitfall: ``bool('False')`` is ``True``, so
    ``--debug False`` would silently turn debug mode on.

    Raises:
        ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalised = str(value).strip().lower()
    if normalised in ('1', 'true', 't', 'yes', 'y'):
        return True
    if normalised in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


# Take a single timestamp so date, hour and minute agree with each other.
# The +1 hour shift is kept from the original (presumably a fixed offset from the
# machine clock to local time — TODO confirm); timedelta rolls over past 23:00
# instead of printing hour 24.
_now = datetime.now() + timedelta(hours=1)
print('{:%Y-%m-%d} -- {:%H:%M}'.format(_now, _now), end=' ')
print('-- Starting up')
seed_everything(1234)

# Command-line arguments.
parser = ArgumentParser()
parser.add_argument("--batch_size", default=32, type=int)
parser.add_argument("--local_dataset_path", default='', type=str)
parser.add_argument("--csv_base_name", default='dataset.csv', type=str)
parser.add_argument("--debug", default=False, type=_str2bool)
parser.add_argument("--gpu", default=None, type=int)
parser.add_argument("--submit_to_azureml", '-aml', action='store_true', default=False)

# Extend the same parser with Lightning Trainer options and model-specific options.
parser = Trainer.add_argparse_args(parser)
parser = Model.add_model_specific_args(parser)

In [None]:
# Running inside a notebook: parse an empty argv so every option takes its default.
args = parser.parse_args([])
# Use GPU index 0 (overrides the --gpu default of None).
args.gpu = 0
# Optional: resume from an earlier checkpoint, e.g.
# args.resume_from_checkpoint = 'outputs/wassertein_b256_b128_wm/epoch=340-step=57624.ckpt'

In [None]:
# Placeholder: no ModelCheckpoint callback is configured in this notebook.
checkpoint_callback = None

# Restrict CUDA to the requested GPU so that only a single device is used.
# NOTE(review): this only takes effect if CUDA has not been initialised yet.
if args.gpu is not None and isinstance(args.gpu, int):
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    args.gpus = 1

# Build the segmentation model from the parsed hyper-parameters.
model = Model(**vars(args))

# Single-channel 2D prostate data loader. `args.labels` is not declared by the
# options added in this notebook, so it presumably comes from
# Model.add_model_specific_args — confirm.
loader_gen = Prostate2DSimpleDataLoader(
    args.local_dataset_path,
    args.csv_base_name,
    args.batch_size,
    input_channels=1,
    labels=args.labels,
)

# TensorBoard logger, optionally combined with the AzureML logger.
logger = TensorboardWithImgs('./outputs/UNet')
# The flag declared in this notebook is `--submit_to_azureml` (dest
# `submit_to_azureml`). `args.azureml` only exists if another parser contributor
# defines it, so prefer it when present and otherwise fall back to the local flag
# instead of raising AttributeError.
use_azureml = getattr(args, 'azureml', args.submit_to_azureml)
if use_azureml:
    AMLlogger = AzureMLLogger()
    logger = LoggerCollectionWithImgs([logger, AMLlogger])

In [None]:
import torch

# Map all tensors to the CPU. The model is still on the CPU at this point, and
# mapping avoids failures when the checkpoint was saved on a CUDA device index that
# is no longer visible once CUDA_VISIBLE_DEVICES has been restricted.
# NOTE: torch.load unpickles the file — only load trusted checkpoints.
mdata = torch.load(
    'outputs/UNet/default/version_12/checkpoints/epoch=67-step=22983.ckpt',
    map_location='cpu',
)

In [None]:
# Copy the checkpoint's weights into the freshly constructed model (strict key
# matching by default); the returned key-matching report is displayed as cell output.
model.load_state_dict(mdata['state_dict'])

In [None]:
import monai
# Evaluate the loaded checkpoint on the test split.
# `.eval()` is required: a freshly built nn.Module is in training mode, in which
# dropout / batch-norm layers (if the UNet has any) behave differently and would
# distort the test DICE. `torch.no_grad()` avoids building an autograd graph.
seg_model = model.to('cuda').eval()

ys = []
y_preds = []
with torch.no_grad():
    for batch in loader_gen.test_dataloader():
        batch = [b.to('cuda') for b in batch]
        x, y = seg_model.prepare_batch(batch)
        y_hat = seg_model.net(x)
        # One-hot encode the arg-max class along dim 1, one channel per class, so the
        # prediction has the same channel layout as the network output.
        y_pred = torch.argmax(y_hat, 1, keepdim=True)
        y_pred = torch.cat([y_pred == i for i in range(y_hat.shape[1])], 1)
        ys.append(y)
        y_preds.append(y_pred)

y = torch.cat(ys)
y_pred = torch.cat(y_preds)
# Per-sample, per-class DICE with the first (background) channel excluded.
metrics = monai.metrics.compute_meandice(y_pred, y, include_background=False)
# Average only over samples whose label actually contains the structure;
# y[:, 1:] drops channel 0 to line up with include_background=False.
non_zero_vols = torch.sum(y[:, 1:], (2, 3, 4)) != 0
metrics_means = [metrics[:, i][non_zero_vols[:, i]].mean() for i in range(metrics.shape[1])]
# Last entry: mean over every (sample, class) pair with a non-empty label.
metrics_means.append(metrics[non_zero_vols].mean())

In [None]:
# Display the mean test DICE per foreground class, followed by the overall mean.
metrics_means

In [None]:
from tqdm.notebook import tqdm

def _select_best_checkpoint(ckpt_dir):
    """Return the path of the best checkpoint in `ckpt_dir`, or None if there is none.

    Only ``*.ckpt`` files whose names contain ``epoch=<int>-step`` and
    ``mean_DICE_val=<float>.ckpt`` are considered; anything else is ignored.
    The best checkpoint has the highest validation DICE, with ties broken in
    favour of the later epoch.
    """
    candidates = []
    for name in os.listdir(ckpt_dir):
        if not name.endswith('.ckpt') or 'mean_DICE_val=' not in name or 'epoch=' not in name:
            continue
        try:
            dice = float(name.rsplit('mean_DICE_val=')[1].rsplit('.ckpt')[0])
            epoch = int(name.rsplit('epoch=')[1].rsplit('-step')[0])
        except ValueError:
            # Name contains the markers but not in the expected numeric form.
            continue
        candidates.append(((dice, epoch), name))
    if not candidates:
        return None
    # Select by (dice, epoch) directly. The previous argmax over a masked array
    # returned an index into the *filtered* array, not into the directory listing.
    best_name = max(candidates, key=lambda candidate: candidate[0])[1]
    return os.path.join(ckpt_dir, best_name)


# Evaluate the best checkpoint of every AzureML run found under `path`.
path = 'outputs/UNet/AML/'
dirs = [path + p + '/outputs/UNet/default/version_0/checkpoints/' for p in os.listdir(path)]
ks = []           # `k_shots` hyper-parameter of each evaluated run
metrics_all = []  # per run: mean DICE per foreground class, then the overall mean
for _dir in tqdm(dirs):
    if not os.path.isdir(_dir):
        continue
    chpt = _select_best_checkpoint(_dir)
    if chpt is None:
        # No usable checkpoint: skip the run so `ks` and `metrics_all` stay aligned.
        continue

    # CPU mapping avoids failures on CUDA device indices hidden by CUDA_VISIBLE_DEVICES.
    mdata = torch.load(chpt, map_location='cpu')
    ks.append(mdata['hyper_parameters']['k_shots'])

    model.load_state_dict(mdata['state_dict'])
    # Inference mode: disables dropout / uses running batch-norm stats, if present.
    seg_model = model.to('cuda').eval()
    ys = []
    y_preds = []
    with torch.no_grad():
        for batch in loader_gen.test_dataloader():
            batch = [b.to('cuda') for b in batch]
            x, y = seg_model.prepare_batch(batch)
            y_hat = seg_model.net(x)
            # One-hot encode the arg-max class along dim 1.
            y_pred = torch.argmax(y_hat, 1, keepdim=True)
            y_pred = torch.cat([y_pred == i for i in range(y_hat.shape[1])], 1)
            ys.append(y)
            y_preds.append(y_pred)

    y = torch.cat(ys)
    y_pred = torch.cat(y_preds)
    metrics = monai.metrics.compute_meandice(y_pred, y, include_background=False)
    # Average only over samples whose label contains the structure.
    non_zero_vols = torch.sum(y[:, 1:], (2, 3, 4)) != 0
    metrics_means = [metrics[:, i][non_zero_vols[:, i]].mean() for i in range(metrics.shape[1])]
    metrics_means.append(metrics[non_zero_vols].mean())
    metrics_means = torch.stack(metrics_means).detach().cpu()
    metrics_all.append(metrics_means)

In [None]:
# One extra run whose checkpoint sits directly under outputs/ (not in the default
# Lightning layout), added by hand. It is recorded with k = 35; because it is the
# last entry appended, the plot labels it as the full training set.
# NOTE(review): presumably 35 is the total number of training volumes — confirm.
chpt = 'outputs/UNet/AML/UNet2D_main_1630502521_8ee37684/outputs/epoch=989-step=84149.ckpt'
# CPU mapping avoids failures on CUDA device indices hidden by CUDA_VISIBLE_DEVICES.
mdata = torch.load(chpt, map_location='cpu')
ks.append(35)

model.load_state_dict(mdata['state_dict'])
# Inference mode: disables dropout / uses running batch-norm stats, if present.
seg_model = model.to('cuda').eval()
ys = []
y_preds = []
with torch.no_grad():
    for batch in loader_gen.test_dataloader():
        batch = [b.to('cuda') for b in batch]
        x, y = seg_model.prepare_batch(batch)
        y_hat = seg_model.net(x)
        # One-hot encode the arg-max class along dim 1.
        y_pred = torch.argmax(y_hat, 1, keepdim=True)
        y_pred = torch.cat([y_pred == i for i in range(y_hat.shape[1])], 1)
        ys.append(y)
        y_preds.append(y_pred)

y = torch.cat(ys)
y_pred = torch.cat(y_preds)
metrics = monai.metrics.compute_meandice(y_pred, y, include_background=False)
# Average only over samples whose label contains the structure.
non_zero_vols = torch.sum(y[:, 1:], (2, 3, 4)) != 0
metrics_means = [metrics[:, i][non_zero_vols[:, i]].mean() for i in range(metrics.shape[1])]
metrics_means.append(metrics[non_zero_vols].mean())
metrics_means = torch.stack(metrics_means).detach().cpu()
metrics_all.append(metrics_means)

In [None]:
# Stack the per-run results into an array of shape (n_runs, n_foreground_classes + 1).
# Converting row by row with np.asarray works both for the list of detached CPU
# tensors built above and for an already-stacked array, so re-running this cell
# no longer fails.
metrics_all = np.stack([np.asarray(m) for m in metrics_all])
ks = np.array(ks)

In [None]:
# Display the number of samples in the training split.
# NOTE(review): presumably the source of the "10k" in the plot's 'all (10k)' tick
# label — confirm.
len(loader_gen.train_dataloader().dataset)

In [None]:
import matplotlib.pyplot as plt
# Mean test DICE as a function of the number of training samples (k).
# The run appended last (the hand-added checkpoint) is labelled as the full set.
tick_labels = [str(k) for k in ks]
tick_labels[-1] = 'all (10k)'
# Sort once; ticks, labels and curves all use the same ordering.
order = np.argsort(ks)
ks_sorted = ks[order]
# Column order of `metrics_all`: one column per foreground class, then the overall mean.
structure_names = ['femurs', 'bladder', 'prostate', 'all']

fig, ax = plt.subplots(figsize=(12, 8))
for col, name in enumerate(structure_names):
    # Draw the overall mean with a thicker line so it stands out.
    line_kwargs = {'lw': 2} if name == 'all' else {}
    ax.plot(ks_sorted, metrics_all[order, col], label=name, **line_kwargs)
ax.set_xticks(ks_sorted)
ax.set_xticklabels([tick_labels[i] for i in order])
ax.set_xlabel('Number of training samples')
ax.set_ylabel('Mean DICE')
ax.legend()
ax.set_title('Mean DICE score over test set')
fig.tight_layout()
fig.savefig('mean_DICE_baseline.png')