In [None]:
!pip install -r ./requirements.txt

In [None]:
!python3 -c "import monai" || pip install -q "monai-weekly[nibabel, tqdm]"
!pip install -q "monai-weekly[nibabel, tqdm, einops]"
!python3 -c "import matplotlib" || pip install -q matplotlib
%matplotlib inline

In [None]:
import re, time, os, shutil, json
from easydict import EasyDict as edict
import matplotlib.pyplot as plt
import SimpleITK as sitk  # noqa: N813
import numpy as np
import nibabel as nib
from PIL import Image
from monai.data import create_test_image_3d, list_data_collate, decollate_batch, pad_list_data_collate
import tempfile
import monai
from monai.inferers import sliding_window_inference
from monai.data import ITKReader, PILReader, ImageDataset, DataLoader, Dataset, PersistentDataset, CacheDataset, ArrayDataset
from monai.networks.layers import Norm
from monai.transforms import (
    LoadImage, EnsureChannelFirst, Spacing,
    RandFlip, Resize, EnsureType,
    LoadImaged, EnsureChannelFirstd,
    Resized, EnsureTyped, Compose, ScaleIntensityd, AddChanneld, MapTransform, AsChannelFirstd, EnsureType, Activations, AsDiscrete,
    RandCropByPosNegLabeld, RandRotate90d, LabelToMaskd, RandFlipd, RandRotated, Spacingd, RandAffined
)
#from monai.networks.nets import UNETR
#from torch.utils.tensorboard import SummaryWriter
from monai.transforms.intensity.array import ScaleIntensity
from monai.metrics import DiceMetric
import configdot
import torch
from monai.config import print_config
#from monai.engines import create_multigpu_supervised_trainer

print_config()

In [None]:
!mkdir -p ./MONAI_TMP

In [None]:
# Load the experiment settings; later cells read config.dataset, config.default,
# config.opt and config.model from this object.
config = configdot.parse_config('configs/config.ini')

In [None]:
# Point MONAI at the local scratch directory and resolve the working root.
# The environment variable is always set on the first line, so the temporary
# directory fallback is only reached if the variable is somehow missing.
os.environ['MONAI_DATA_DIRECTORY'] = "./MONAI_TMP"
directory = os.environ.get("MONAI_DATA_DIRECTORY")
if directory is None:
    root_dir = tempfile.mkdtemp()
else:
    root_dir = directory
print(root_dir)

In [None]:
# Root of the feature tree: assign_feature_maps() and the segmentation-label
# lookup build their paths under this directory.
BASE_DIR = '/workspace/RawData/Features'
# Directory whose entries are scanned for subject IDs when building `subject_list`.
OUTPUT_DIR = '/workspace/RawData/Features/BIDS'
# Scratch location; not read by any of the visible cells.
TMP_DIR = '/workspace/Features/tmp'

In [None]:
!ls /workspace/RawData/Features/prep_wf | wc

In [None]:
def assign_feature_maps(sub, feature, base_dir=None):
    """Return the expected on-disk path of one feature map for one subject.

    Args:
        sub: Subject identifier; interpolated into the ``sub-{sub}`` path parts.
        feature: Name of the feature map. Supported names: 'image', 't2',
            'flair', 'blurring-t1', 'blurring-Flair', 'cr-t2', 'cr-Flair',
            'thickness', 'curv', 'sulc', 'variance', 'mask'.
        base_dir: Root directory of the feature tree. When omitted, the
            notebook-level ``BASE_DIR`` is used.

    Returns:
        The path as a string. The file is NOT checked for existence; callers
        do that themselves.

    Raises:
        ValueError: if ``feature`` is not one of the supported names
            (previously this surfaced as an UnboundLocalError).
    """
    if base_dir is None:
        base_dir = BASE_DIR

    # Most maps live in the subject's own preprocessing folder.
    subject_dir = ('prep_wf', f'sub-{sub}')
    relative_parts = {
        'image': subject_dir + (f'sub-{sub}_t1_brain-final.nii.gz',),
        't2': subject_dir + (f'sub-{sub}_t2_brain-final.nii.gz',),
        'flair': subject_dir + (f'sub-{sub}_fl_brain-final.nii.gz',),
        'blurring-t1': subject_dir + ('Blurring_T1.nii.gz',),
        'blurring-Flair': subject_dir + ('Blurring_Flair.nii.gz',),
        'cr-t2': subject_dir + ('CR_T2.nii',),
        'cr-Flair': subject_dir + ('CR_Flair.nii',),
        'thickness': subject_dir + ('thickness_mni.nii',),
        'curv': subject_dir + ('curv_mni.nii',),
        'sulc': subject_dir + ('sulc_mni.nii',),
        # The variance map is the only one stored outside the subject folder.
        'variance': ('preprocessed_data', 'var', f'sub-{sub}_var.nii.gz'),
        'mask': subject_dir + (f'sub-{sub}_t1_brain-final_mask.nii.gz',),
    }

    try:
        parts = relative_parts[feature]
    except KeyError:
        raise ValueError(
            f"Unknown feature {feature!r}; expected one of {sorted(relative_parts)}"
        ) from None
    return os.path.join(base_dir, *parts)

In [None]:
import numpy as np
# The .npy file holds a pickled Python object (hence allow_pickle=True).
# Unpickling can execute arbitrary code, so only load files from a trusted source.
subjects_list = np.load('./metadata/metadata_fcd_nG.npy', allow_pickle=True)
# .item() unwraps the stored object, which is used as a dict here; show its 'train' entry.
subjects_list.item().get('train')

In [None]:
# Use the predefined split from the metadata object: its 'train' entry feeds
# training and its 'test' entry is used as the validation set.
_subject_split = subjects_list.item()
train_list, val_list = _subject_split.get('train'), _subject_split.get('test')

### Build the per-subject file lists: keep only subjects that have every feature map and a label, and report how many subjects remain

In [None]:
from sklearn.model_selection import train_test_split

images_list = []
subject_list = []
feat_params = config.dataset.features
# Example of one resulting entry:
# {'image': ['/workspace/Features/preprocessed_data/new_pipeline/sub-8/sub-8_acq-T1Mprage_space-MNI152NLint2_seq-T1w_brain.nii.gz',
#   '/workspace/Features/preprocessed_data/thickness/norm-8.nii.gz',
#   '/workspace/Features/preprocessed_data/curv/norm-8.nii.gz',
#           ]
# }

# Subject IDs found on disk. They are only needed for the optional random
# split below; the active split comes from train_list / val_list.
for i in os.listdir(OUTPUT_DIR):
    sub_ind = re.findall('-(.[a-zA-Z0-9]*|[0-9])', str(i))
    #if sub_ind and not any(x in sub_ind[0] for x in matches): # subjects with 'n', 'G', 'NS', 'C' won't be included
    if sub_ind:
        subject_list.append(sub_ind[0])

#random_seed = 666
#train_list, val_list = train_test_split(subject_list, shuffle=False, train_size=0.80, random_state=random_seed)


def collect_subject_files(subjects, features, split_name):
    """Build the MONAI data dicts for one split, keeping only complete subjects.

    A subject is kept when every feature map in ``features`` exists on disk
    and its segmentation label exists too. Every missing file is reported
    with a printed message so dropped subjects are visible for both splits.

    Args:
        subjects: Iterable of subject identifiers.
        features: Sequence of feature names understood by assign_feature_maps().
        split_name: Name used in the log messages ('train' or 'val').

    Returns:
        (kept_subjects, files): the identifiers that were kept and, in the
        same order, one ``{'image': [paths...], 'seg': path}`` dict each.
    """
    kept_subjects, files = [], []
    for sub in subjects:
        feature_paths = []
        for feat in features:
            map_path = assign_feature_maps(sub, feat)
            if os.path.isfile(map_path):
                feature_paths.append(map_path)
            else:
                print(f'No feature {feat} for sub {sub} in {split_name} data')
        if len(feature_paths) != len(features):
            continue

        seg_path = os.path.join(BASE_DIR, 'preprocessed_data/label_bernaskoni', f'{sub}.nii.gz')
        if not os.path.isfile(seg_path):
            print(f'No {seg_path} for sub {sub}')
            continue

        kept_subjects.append(sub)
        files.append({'image': feature_paths, 'seg': seg_path})
    return kept_subjects, files


train_subs_indcs, train_files = collect_subject_files(train_list, feat_params, 'train')
val_subs_indcs, val_files = collect_subject_files(val_list, feat_params, 'val')

print(f"Train set length: {len(train_files)}\nTest set length: {len(val_files)}")

## Transformation and Augmentation

In [None]:
# Target spatial size of every volume after resizing, taken from the config.
spatial_size_conf = tuple(config.default.interpolation_size)

# Training pipeline: load -> channel-first -> random rotation/flip -> resize ->
# label resampling -> per-channel intensity scaling to [0, 1] -> float tensors.
train_transf = Compose(
    [
        LoadImaged(keys=["image", "seg"]),
        EnsureChannelFirstd(keys=["image", "seg"]),
        # Interpolate the image bilinearly but the label with nearest-neighbour:
        # the default (bilinear for every key) blurs the mask into fractional values.
        RandRotated(keys=["image", "seg"], range_x=0.25, range_y=0.25, range_z=0.25, prob=0.9,
                    mode=("bilinear", "nearest")),
        RandFlipd(keys=["image", "seg"], prob=0.5, spatial_axis=0),
        Resized(keys=["image", "seg"], spatial_size=spatial_size_conf, mode=('area', 'nearest')),
        # Label-only resampling; nearest-neighbour keeps the label values discrete.
        Spacingd(keys=['seg'], pixdim=1.0, mode="nearest"),
        ScaleIntensityd(keys=["image"], minv=0.0, maxv=1.0, channel_wise=True),
        # Disabled augmentations kept for reference:
        #RandFlipd(keys=["image", "seg"], prob=0.5, spatial_axis=1),
        #RandFlipd(keys=["image", "seg"], prob=0.5, spatial_axis=2),
        #RandSpatialCropd(keys=["image", "label"], roi_size=[224, 224, 144], random_size=False),
        #RandRotated(keys=["image", "seg"], range_x=0.0, range_y=0.0, range_z=0.75, prob=1),
        EnsureTyped(keys=["image", "seg"], dtype=torch.float),
    ]
)
# Validation pipeline: the same steps without the random augmentations.
val_transf = Compose(
    [
        LoadImaged(keys=["image", "seg"]),
        EnsureChannelFirstd(keys=["image", "seg"]),
        Resized(keys=["image", "seg"], spatial_size=spatial_size_conf, mode=('area', 'nearest')),
        Spacingd(keys=['seg'], pixdim=1.0, mode="nearest"),
        ScaleIntensityd(keys=["image"], minv=0.0, maxv=1.0, channel_wise=True),
        EnsureTyped(keys=["image", "seg"], dtype=torch.float),
    ]
)

In [None]:
# Dataset over the training files with the training transforms; the following
# cells use it to sanity-check shapes and to plot a sample.
check_dataset = Dataset(data=train_files, transform=train_transf)

In [None]:
# Shape of the first sample's segmentation after the training transforms.
check_dataset[0]['seg'].shape

In [None]:
# Shape of the first sample's stacked feature maps after the training transforms.
check_dataset[0]['image'].shape

In [None]:
# Index of the GPU to make current for this notebook.
CUDA_DEVICE_INDEX = 2

cuda_ok = torch.cuda.is_available()
print(cuda_ok)
# Only switch devices when the requested GPU exists; calling set_device()
# unconditionally raises on CPU-only machines or hosts with fewer GPUs.
if cuda_ok and CUDA_DEVICE_INDEX < torch.cuda.device_count():
    torch.cuda.set_device(CUDA_DEVICE_INDEX)
    print(torch.cuda.current_device())
else:
    print(f"CUDA device {CUDA_DEVICE_INDEX} is not available; leaving the current device unchanged")

In [None]:
# Loader used only for sanity checks: batches of 4, loaded in the main process
# (num_workers=0), with pinned memory when CUDA is available.
# Previous variant with batch_size=10, kept for reference:
#check_loader = DataLoader(check_dataset, batch_size=10, num_workers=0, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available())
check_loader = DataLoader(check_dataset, batch_size=4, num_workers=0, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available())

In [None]:
# Pull the first batch from the check loader for inspection.
check_data = monai.utils.misc.first(check_loader)

### Check the batch size in check_loader and the shapes of the image and segmentation tensors

In [None]:
# Shapes of the collated image and segmentation batches.
print(check_data["image"].shape, check_data["seg"].shape)

In [None]:
# One transformed training sample, reused by the plotting cell below.
train_data_example = check_dataset[0]

In [None]:
print(f"image shape: {train_data_example['image'].shape}")
num_of_channels = len(feat_params)

# Choose the slice to show: the middle entry among the last-axis indices of
# all positive label voxels.
lesion_last_axis = np.where(train_data_example["seg"]>0)[3]
label_ind = round(lesion_last_axis.shape[0] / 2)
label_pos = lesion_last_axis[label_ind]

# One panel per input channel, each with the label overlaid in red.
plt.figure("image", (24, 6))
for i, feature_name in enumerate(feat_params):
    plt.subplot(1, num_of_channels, i + 1)
    plt.title(f"image channel {feature_name}")
    channel_slice = train_data_example['image'][i, :, :, label_pos]
    label_slice = train_data_example["seg"][0,:, :, label_pos]
    plt.imshow(channel_slice, cmap="gray")
    plt.imshow(label_slice,interpolation='none', cmap='Reds', alpha=0.3)
plt.show()

# Show the label slice on its own (first channel of the segmentation).
print(f"segmentaion shape: {train_data_example['seg'].shape}")
plt.figure("seg", (4, 6))
plt.imshow(train_data_example["seg"][0,:, :, label_pos], cmap="gray")
plt.show()

In [None]:
# Datasets: raw file dicts paired with the matching transform pipeline.
train_ds = monai.data.Dataset(data=train_files, transform=train_transf)
val_ds = monai.data.Dataset(data=val_files, transform=val_transf)

# Settings shared by both loaders.
_common_loader_kwargs = dict(batch_size=2, num_workers=0, collate_fn=list_data_collate)

# NOTE(review): the training loader does not shuffle — confirm this is intended.
train_loader = DataLoader(
    train_ds,
    shuffle=False,
    pin_memory=torch.cuda.is_available(),
    **_common_loader_kwargs,
)
val_loader = DataLoader(val_ds, **_common_loader_kwargs)

In [None]:
# Index of the training subject used in the augmentation examples below.
ind = 1
# Show that subject's feature-map paths and label path.
train_files[ind]

## Transform random rotate augmentation example

### Before augmentation

In [None]:
from nilearn.plotting import plot_img

# Overlay the label on the subject's first feature map, cut along z.
# The result gets its own name: rebinding `plot_img` to the returned display
# would shadow the imported function and break the cell when it is re-run.
lesion_display = plot_img(train_files[ind]['seg'],
                          bg_img=train_files[ind]['image'][0],
                          threshold=0.1, alpha=0.5, display_mode='z')
# z coordinates of the displayed cuts (presumably chosen around the lesion's
# centre of mass — confirm against the nilearn documentation).
print(lesion_display.cut_coords)

### Augmentation

In [None]:
# Fetch the same subject seven times and plot the slice through its label,
# so the effect of the random transforms can be compared side by side.
plt.figure(figsize=(30,30))
for i in range(7):
    item = train_loader.dataset[ind]
    image, segme = item["image"], item["seg"]

    # Middle entry among the last-axis indices of the positive label voxels.
    lesion_last_axis = np.where(segme>0)[3]
    lab_loc = round(lesion_last_axis.shape[0] / 2)
    lab_pos = lesion_last_axis[lab_loc]

    plt.subplot(1, 10, i+1)
    background = np.rot90(image[0,:, :, lab_pos])
    overlay = np.rot90(segme[0,:, :, lab_pos])
    plt.imshow(background, cmap='gray')
    plt.imshow(overlay, cmap="Reds", alpha=0.4)
    plt.title("seg overlay")
plt.show()

In [None]:
# Subject label = name of the segmentation file up to its first dot.
# Taking the basename (instead of a fixed '/'-split position) keeps this
# correct when BASE_DIR is moved to a different directory depth.
a = os.path.basename(train_loader.dataset.data[0]['seg']).split('.')[0]
a

In [None]:
# Path of the third feature map (index 2) of the first training subject.
train_loader.dataset.data[0]['image'][2]

In [None]:
def one_epoch(model,
              criterion,
              opt,
              config,
              dataloader,
              device,
              writer,
              epoch,
              metric_dict_epoch,
              n_iters_total=0,
              augmentation=None,
              is_train=True):
    """Run one training or validation pass over ``dataloader``.

    Args:
        model: Network mapping the image batch to a prediction batch.
        criterion: Callable ``criterion(prediction, target)`` returning the loss.
        opt: Optimizer stepped after every training batch.
        config: Experiment config. Reads ``opt.use_scaler``, ``opt.criterion``,
            optional ``opt.grad_clip``, ``model.target_metric_name``,
            ``dataset.save_best_val_predictions`` and ``dataset.val_preds_path``.
        dataloader: Yields dicts with 'image' and 'seg' tensors. For validation
            its ``dataset.data[k]['seg']`` path names sample ``k``, so the
            loader must iterate in dataset order (no shuffling).
        device: Device the batch tensors are moved to.
        writer: TensorBoard-style writer, or None to disable logging.
        epoch: Epoch index used for the per-epoch scalars.
        metric_dict_epoch: ``defaultdict(list)`` collecting per-epoch means
            under ``'<phase>_<metric>'`` keys; updated in place.
        n_iters_total: Running iteration counter for this phase.
        augmentation: Unused; augmentation lives in the dataset transforms.
        is_train: True for a training pass, False for validation.

    Returns:
        ``(n_iters_total, target_metric)``: the updated iteration counter and
        this pass's mean of ``config.model.target_metric_name`` (0 when that
        metric was not recorded).
    """
    # use amp to accelerate training
    if config.opt.use_scaler:
        scaler = torch.cuda.amp.GradScaler()

    phase_name = 'train' if is_train else 'val'
    loss_name = config.opt.criterion
    metric_dict = defaultdict(list)
    target_metric_name = config.model.target_metric_name

    if is_train:
        model.train()
    else:
        model.eval()

    # used to turn on/off gradients
    grad_context = torch.autograd.enable_grad if is_train else torch.no_grad
    with grad_context():
        val_predictions = {}
        # Number of samples consumed so far; maps a position inside the
        # current batch back to its index in dataloader.dataset.data.
        n_samples_seen = 0
        for iter_i, data_tensors in enumerate(dataloader):

            brain_tensor, label_tensor = data_tensors['image'], data_tensors['seg']
            t1 = time.time()

            brain_tensor = brain_tensor.to(device)
            label_tensor = label_tensor.to(device)

            # forward pass
            with autocast(enabled=config.opt.use_scaler):
                label_tensor_predicted = model(brain_tensor) # -> [bs,1,ps,ps,ps]
                loss = criterion(label_tensor_predicted, label_tensor)

            if is_train:
                opt.zero_grad()

                if config.opt.use_scaler:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

                if hasattr(config.opt, "grad_clip"):
                    # With AMP the gradients must be unscaled before clipping.
                    # The clipping itself applies with or without the scaler
                    # (it used to be skipped when the scaler was off).
                    if config.opt.use_scaler:
                        scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   config.opt.grad_clip)

                metric_dict['grad_norm'].append(calc_gradient_norm(filter(lambda x: x[1].requires_grad,
                                                model.named_parameters())))

                if config.opt.use_scaler:
                    scaler.step(opt)
                    scaler.update()
                else:
                    opt.step()

            t2 = time.time()
            dt = t2-t1 # time spent on this batch (forward + optional backward/step)

            metric_dict['batch_time'].append(dt)
            metric_dict[f'{loss_name}'].append(loss.item())
            dice_score = DiceScoreBinary(label_tensor_predicted, label_tensor)
            coverage = (label_tensor_predicted*label_tensor).sum() / label_tensor.sum()

            if not is_train:
                # Store each sample under its own subject label. Indexing the
                # dataset by batch number mislabels predictions and drops
                # samples as soon as batch_size > 1.
                predictions_np = label_tensor_predicted.detach().cpu().numpy()
                for offset in range(predictions_np.shape[0]):
                    seg_path = dataloader.dataset.data[n_samples_seen + offset]['seg']
                    label = os.path.basename(seg_path).split('.')[0]
                    # Keep the leading batch axis (size 1) so stored arrays
                    # can still be indexed as pred[0][0].
                    val_predictions[label] = predictions_np[offset:offset + 1]
            n_samples_seen += brain_tensor.shape[0]

            metric_dict['coverage'].append(coverage.item())
            metric_dict['dice_score'].append(dice_score.item())

            #########
            # PRINT #
            #########
            message = f'For {phase_name}, iter: {iter_i},'
            for title, value in metric_dict.items():
                if title == 'grad_norm':
                    v = np.round(value[-1],6)
                else:
                    v = np.round(value[-1],3)
                message+=f' {title}:{v}'
            print(message)

            if is_train and writer is not None:
                for title, value in metric_dict.items():
                    writer.add_scalar(f"{phase_name}_{title}", value[-1], n_iters_total)

            n_iters_total += 1

    target_metric = 0
    for title, value in metric_dict.items():
        m = np.mean(value)
        metric_dict_epoch[phase_name + '_' + title].append(m)
        if title == target_metric_name:
            target_metric = m
        if writer is not None:
            writer.add_scalar(f"{phase_name}_{title}_epoch", m, epoch)

    #####################
    # SAVING BEST PREDS #
    #####################
    if not is_train and config.dataset.save_best_val_predictions:
        target_metrics_epoch = metric_dict_epoch[f'val_{target_metric_name}']
        previous_epochs = target_metrics_epoch[:-1]
        # Greedy saving: overwrite only when this epoch is at least as good as
        # the best earlier epoch. Comparing with the immediately preceding
        # epoch alone could replace the best predictions with worse ones.
        if not previous_epochs or target_metrics_epoch[-1] >= max(previous_epochs):
            for label, pred in val_predictions.items():
                torch.save(pred, os.path.join(config.dataset.val_preds_path, f'{label}'))

    return n_iters_total, target_metric

In [None]:
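# Disabled debugging variant of one_epoch, kept as a string literal so it is never executed: it plots one lesion slice per batch instead of training.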
"""
def one_epoch(model, 
                criterion, 
                opt, 
                config, 
                dataloader, 
                device, 
                writer, 
                epoch, 
                metric_dict_epoch, 
                n_iters_total=0,
                augmentation=None, 
                is_train=True):

    plt.figure(figsize=(30,30))
    # use amp to accelerate training
    if config.opt.use_scaler:
        scaler = torch.cuda.amp.GradScaler()

    phase_name = 'train' if is_train else 'val'
    loss_name = config.opt.criterion
    metric_dict = defaultdict(list)
    target_metric_name = config.model.target_metric_name 

    if not is_train:
        model.eval()
    else:
        model.train()

    # used to turn on/off gradients
    grad_context = torch.autograd.enable_grad if is_train else torch.no_grad
    with grad_context():
        iterator = enumerate(dataloader)
        val_predictions = {}
        for iter_i, data_tensors in iterator:
            image, segme = data_tensors['image'], data_tensors['seg']
            
            plt.subplot(1, 11, iter_i+1)
            lab_loc = round(np.where(segme>0)[4].shape[0] / 2)
            lab_pos = np.where(segme>0)[4][lab_loc]
            plt.imshow(image[0, 0,:, :, lab_pos], cmap='gray')
            plt.imshow(segme[0, 0,:, :, lab_pos], cmap="Reds", alpha=0.4)
            plt.imshow(image[0, 2,:, :, lab_pos], cmap="Greens", alpha=0.2)
            #torch.Size([2, 3, 128, 128, 128]) torch.Size([2, 1, 128, 128, 128])
            if iter_i > 10:
                break
"""

In [None]:
# Name of the loss selected in the config; it is the key used to pick `criterion` below.
print(config.opt.criterion)

In [None]:
from datetime import datetime
from tensorboardX import SummaryWriter
from losses import DiceScoreBinary,\
                   DiceLossBinary,\
                   symmetric_focal_loss,\
                sym_unified_focal_loss,\
                symmetric_focal_tversky_loss,\
                DiceSFL,\
                tversky_loss
import torch.optim as optim
from models.v2v import V2VModel
from utils import save, parse_args, get_capacity, calc_gradient_norm
from collections import defaultdict
from IPython.core.debugger import set_trace
import traceback
import torch.nn.functional as F
from torch.cuda.amp import autocast



##########
# LOGDIR #
##########
MAKE_LOGS = config.default.make_logs
SAVE_MODEL = config.opt.save_model if hasattr(config.opt, "save_model") else True
DEVICE = config.opt.device if hasattr(config.opt, "device") else 1
device = torch.device(DEVICE)
print(device)


# The timestamp has hour resolution, so runs with the same comment started
# within the same hour share a directory (and the older one is removed below).
experiment_name = '{}@{}'.format(config.default.experiment_comment, datetime.now().strftime("%d.%m.%Y-%H"))
print("Experiment name: {}".format(experiment_name))

writer = None
if MAKE_LOGS:
    experiment_dir = os.path.join(config.default.log_dir, experiment_name)
    if os.path.isdir(experiment_dir):
        shutil.rmtree(experiment_dir)
    os.makedirs(experiment_dir)
    shutil.copy('configs/config.ini', os.path.join(experiment_dir, "config.ini"))

    # write .json dataset log
    ds_split_log = {"train": train_subs_indcs,
                    "val": val_subs_indcs}

    # The `with` block closes the file; no explicit close() is needed.
    with open(os.path.join(experiment_dir, 'train_test_split.json'), 'w') as f:
        json.dump(ds_split_log, f)

    if config.dataset.save_best_val_predictions:
        val_preds_path = os.path.join(experiment_dir, 'best_val_preds')
        # one_epoch() reads this attribute when saving validation predictions.
        config.dataset.val_preds_path = val_preds_path
        os.makedirs(val_preds_path)
    writer = SummaryWriter(os.path.join(experiment_dir, "tb"))

#########
# MODEL #
#########
if config.model.name == "v2v":
    model = V2VModel(config).to(device)
elif config.model.name == "unet3d":
    # NOTE(review): UnetModel is not imported in any visible cell, so this
    # branch raises NameError until the import is added.
    model = UnetModel(config).to(device)
else:
    raise ValueError(f"Unknown config.model.name: {config.model.name!r} (expected 'v2v' or 'unet3d')")
capacity = get_capacity(model)

print(f'Model created! Capacity: {capacity}')

if hasattr(config.model, 'weights'):
    # map_location lets a checkpoint saved on a different device load here.
    model_dict = torch.load(os.path.join(config.model.weights, 'checkpoints/weights.pth'),
                            map_location=device)
    print(f'LOADING from {config.model.weights} \n epoch:', model_dict['epoch'])
    model.load_state_dict(model_dict['model_state'])


################
# CREATE OPTIM #
################
# Each entry is a zero-argument factory, so only the selected loss is built
# and only the hyper-parameters that loss needs are read from the config.
criterion_factories = {
    # BCELoss must be instantiated: the bare class is not a loss callable.
    "BCE": lambda: torch.nn.BCELoss(), # [probabilities, target]
    # Passed through uncalled, as before — assumes DiceLossBinary is itself
    # callable as loss(prediction, target); TODO confirm in losses.py.
    "Dice": lambda: DiceLossBinary,
    "DiceSFL": lambda: DiceSFL(delta=config.opt.delta, gamma=config.opt.gamma),
    "TL": lambda: tversky_loss(delta=config.opt.delta),
    "FTL": lambda: symmetric_focal_tversky_loss(delta=config.opt.delta, gamma=config.opt.gamma),
    "SFL": lambda: symmetric_focal_loss(delta=config.opt.delta, gamma=config.opt.gamma),
    "USFL": lambda: sym_unified_focal_loss(weight=config.opt.weight, # 0.5
                                           delta=config.opt.delta,  # 0.6
                                           gamma=config.opt.gamma), # 0.5
}
if config.opt.criterion not in criterion_factories:
    # "DiceBCE" used to map to None, which only failed later inside training.
    raise ValueError(
        f"Unsupported config.opt.criterion: {config.opt.criterion!r}; "
        f"choose one of {sorted(criterion_factories)}"
    )
criterion = criterion_factories[config.opt.criterion]()
opt = optim.Adam(model.parameters(), lr=config.opt.lr)

#####################
# ASSIGN DATALOADER #
#####################
train_dataloader = train_loader
val_dataloader = val_loader

#item = train_loader.dataset[ind]
#image, segme = item["image"], item["seg"]

In [None]:
# Display the model's repr.
model

In [None]:
print('Start training!')
metric_dict_epoch = defaultdict(list)
n_iters_total_train = 0
n_iters_total_val = 0
target_metric = 0
target_metric_prev = -1
try:
    for epoch in range(config.opt.start_epoch, config.opt.n_epochs):
        print (f'TRAIN EPOCH: {epoch} ... ')
        n_iters_total_train, _  = one_epoch(model,
                                        criterion,
                                        opt,
                                        config,
                                        train_dataloader,
                                        device,
                                        writer,
                                        epoch,
                                        metric_dict_epoch,
                                        n_iters_total_train,
                                        augmentation=None, # augmentation None because compose in dataloader
                                        is_train=True)

        print (f'VAL EPOCH: {epoch} ... ')
        n_iters_total_val, target_metric = one_epoch(model,
                                        criterion,
                                        opt,
                                        config,
                                        val_dataloader,
                                        device,
                                        writer,
                                        epoch,
                                        metric_dict_epoch,
                                        n_iters_total_val,
                                        augmentation=None,
                                        is_train=False)

        if SAVE_MODEL and MAKE_LOGS:
            if not config.model.use_greedy_saving:
                print(f'SAVING...')
                save(experiment_dir, model, opt, epoch)
            # use greedy-saving: save only if the target metric is improved
            elif target_metric > target_metric_prev:
                print(f'target_metric = {target_metric}, SAVING...')
                save(experiment_dir, model, opt, epoch)
                target_metric_prev = target_metric
except KeyboardInterrupt:
    # KeyboardInterrupt is not a subclass of Exception, so it needs its own
    # handler for a manual stop to end the run cleanly.
    print('Training interrupted by user.')
except Exception:
    # Best effort: report the failure but keep the notebook usable.
    print(traceback.format_exc())
    #set_trace()
finally:
    # Persist the per-epoch metrics however the loop ended (finished,
    # interrupted or failed), not only after an exception.
    if MAKE_LOGS:
        np.save(os.path.join(experiment_dir, 'metric_dict_epoch'), metric_dict_epoch)


In [None]:
# Validation subjects that passed the file-availability checks.
val_subs_indcs

In [None]:
# Load the saved best validation predictions of one specific past run.
# Earlier version that read from the current run's val_preds_path:
# best_val_preds = {}
# for label in os.listdir(val_preds_path):
#     val_preds_label_path = os.path.join(val_preds_path, label)
#     best_val_preds[label] = torch.load(val_preds_label_path)[0,0]
experiment_name = 'v2v-AUG_YARKIN_onesite-subs_autocast_DICE-loss_lr-1e-3_t1+flair+thick+blurt1@22.06.2022-19'
val_preds_path = os.path.join('./logs',experiment_name, 'best_val_preds')
# One entry per file in the directory, keyed by file name.
best_val_preds = {
    label: torch.load(os.path.join(val_preds_path, label))
    for label in os.listdir(val_preds_path)
}

In [None]:
# Labels (file names) for which a saved prediction was loaded.
best_val_preds.keys()

In [None]:
#
import seaborn as sns
import warnings
warnings.filterwarnings("ignore")


ind_pred = str(29) # index for best val score tensor
ind = [i for i,x in enumerate(val_subs_indcs) if x == ind_pred][0]
item = val_loader.dataset[ind]
image, segme = item["image"], item["seg"]
#label_tensor_predicted = val_loader.dataset[ind]['seg'].clone().detach()
new_arr = best_val_preds[ind_pred][0][0]
#new_arr = np.where(best_val_preds[ind_pred][0][0] > 0.1, best_val_preds[ind_pred][0][0], 0)
lab_loc = round(np.where(segme>0)[3].shape[0] / 2)
lab_pos = np.where(segme>0)[3][lab_loc]

plt.figure(figsize=(30,7))
for i in range(20):
    plt.subplot(2, 10, i+1)
    plt.imshow(np.rot90(image[0,:, :, i*6]), cmap='gray')
    plt.imshow(np.rot90(new_arr[:, :, i*6].astype(np.float)), cmap="jet", alpha=0.5)
    plt.imshow(np.rot90(segme[0,:, :, i*6]), cmap="Greens", alpha=0.4)
    plt.title(f"Axial Image Slice # {i*12}")
plt.show()

plt.figure(figsize=(30,7))
for i in range(20):
    plt.subplot(2, 10, i+1)
    plt.imshow(np.rot90(image[0,:, i*6, :]), cmap='gray')
    plt.imshow(np.rot90(new_arr[:, i*6, :].astype(np.float)), cmap="jet", alpha=0.5)
    plt.imshow(np.rot90(segme[0,:, i*6, :]), cmap="Greens", alpha=0.4)
    plt.title(f"Sagital Image Slice # {i*12}")
plt.show()

plt.figure(figsize=(30,7))
for i in range(20):
    plt.subplot(2, 10, i+1)
    plt.imshow(np.rot90(image[0,i*6, :, :]), cmap='gray')
    plt.imshow(np.rot90(new_arr[i*6, :, :].astype(np.float)), cmap="jet", alpha=0.5)
    plt.imshow(np.rot90(segme[0,i*6, :, :]), cmap="Greens", alpha=0.4)
    plt.title(f"Sagital Image Slice # {i*12}")
plt.show()


dice_metric = DiceMetric(include_background=False)
dice_metric(y_pred=pred_tensor, y=segme[0])
pred_tensor = torch.Tensor(new_arr)
metric = dice_metric.aggregate().item()
print(f"Dice {metric}")

#plt.figure(figsize=(20,7))
#sns.distplot(best_val_preds[ind_pred][0][0])