In [None]:
# # For nebula torch installation for A100
!pip3 install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113
#!pip install -r ./requirements.txt
#!pip install tensorboard

In [None]:
%config Completer.use_jedi = False
%load_ext autoreload
import re, time, os, shutil, json
import sys
from abc import ABC, abstractmethod
from pathlib import Path
from easydict import EasyDict as edict
import matplotlib.pyplot as plt
import numpy as np
import nibabel as nib
from PIL import Image
from monai.data import list_data_collate
import tempfile
import monai
from monai.data import DataLoader, Dataset, CacheDataset
from monai.apps import CrossValidation
from monai.transforms.intensity.array import ScaleIntensity
from monai.transforms import (
    LoadImage, EnsureChannelFirst, Spacing,
    RandFlip, Resize, EnsureType,
    LoadImaged, EnsureChannelFirstd,
    Resized, EnsureTyped, Compose, ScaleIntensityd, 
    AddChanneld, MapTransform, AsChannelFirstd, EnsureType, 
    Activations, AsDiscrete, RandCropByPosNegLabeld, 
    RandRotate90d, LabelToMaskd, RandFlipd, RandRotated, Spacingd, RandAffined,
    RandShiftIntensityd, Lambdad, MaskIntensityd
)
from monai.transforms import RandGaussianNoised, RandZoomd
from utils import get_label, to_numpy

from dataset import (
    setup_dataloaders,
    create_datafile,
    setup_datafiles,
    setup_transformations
    )

import configdot
import torch
from monai.config import print_config
from IPython.core.debugger import set_trace
import tqdm
import tqdm.auto

def get_label(path):
    '''
    Return the subject label encoded in a file path: the last '/'-separated
    component, truncated at its first dot, e.g.:
    '/workspace/RawData/Features/preprocessed_data/label_bernaskoni/n16.nii.gz' -> 'n16'
    '''
    filename = path.rsplit('/', 1)[-1]
    # Everything before the first '.' (drops multi-part extensions such as '.nii.gz').
    stem, _, _ = filename.partition('.')
    return stem

%autoreload 2

In [None]:
# Create the local MONAI scratch directory (no error if it already exists).
# Pure Python instead of a shell call so the cell also works without a POSIX shell.
os.makedirs("./MONAI_TMP", exist_ok=True)

In [None]:
# Point MONAI at the local scratch directory and read the setting back;
# a temporary directory is only used if the variable turns out to be unset.
os.environ['MONAI_DATA_DIRECTORY'] = "./MONAI_TMP"
directory = os.environ.get("MONAI_DATA_DIRECTORY")
if directory is None:
    root_dir = tempfile.mkdtemp()
else:
    root_dir = directory
print(root_dir)

# Create dataset

In [None]:
# Parse the experiment configuration from the project's INI file (path is relative to the notebook's working directory).
config = configdot.parse_config('configs/config.ini')

In [None]:
# Report CUDA availability, then select the device named in the config
# (falling back to index 0 when `config.opt` has no `device` attribute).
print(torch.cuda.is_available())
DEVICE = getattr(config.opt, "device", 0)
device = torch.device(DEVICE)
torch.cuda.set_device(DEVICE)

print('Setting GPU#:', DEVICE)
print('Using GPU#:', torch.cuda.current_device())

In [None]:
metadata_path = config.dataset.metadata_path

# Select the intensity-scaling strategy. `scaling_dict` ends up as:
#   'torchio'      when scaling_method == 'torchio'
#   loaded object  when scaling_method == 'scale_metadata'
#   None           otherwise (a warning is printed)
# Equality is used on purpose: the previous `in` test was a substring check,
# so values such as '' or 'tor' matched and None raised a TypeError.
scaling_dict = None
if config.dataset.scaling_method == 'torchio':
    scaling_dict = 'torchio'
elif config.dataset.scaling_method == 'scale_metadata':
    scaling_data_path = config.dataset.scaling_metadata_path
    # NOTE(review): allow_pickle=True unpickles arbitrary objects; only load trusted files.
    scaling_dict = np.load(scaling_data_path, allow_pickle=True).item()
else:
    print('Warning! no SCALING METADATA used! Applying naive independent MinMax...')


# Train/validation subject lists are stored under the 'train' and 'test' keys.
split_dict = np.load(metadata_path, allow_pickle=True).item()
train_list = split_dict['train']
val_list = split_dict['test']

images_list = []
feat_params = config.dataset.features

# Flag to add mask as additional sequence to Subset
add_mask = config.dataset.trim_background

train_files, train_missing_files = create_datafile(train_list, feat_params, mask=add_mask)
val_files, val_missing_files = create_datafile(val_list, feat_params, mask=add_mask)

In [None]:
from dataset import *

from torchio.transforms.preprocessing.intensity import histogram_standardization
from torchio.transforms.preprocessing.intensity import z_normalization

def scaling_as_torchio(data_dict, features, scaling_dict, landmarks_dir='/workspace/FCDNet/landmarks'):
    '''
    Apply torchio-style histogram standardization followed by z-normalization
    to every feature channel of `data_dict["image"]`, in place.

    data_dict - sample dict; reads "mask" and reads/writes "image"
                (channel `i` is assumed to correspond to `features[i]` - TODO confirm)
    features - iterable of feature names, used to locate the landmark files
    scaling_dict - unused here; kept so the signature stays unchanged
    landmarks_dir - directory containing the `*_landmarks.npy` files

    Returns the same `data_dict` object. A channel is left untouched when
    `znorm` returns None.
    '''
    # Features whose landmark files carry the `_False` infix.
    false_suffix_features = ('blurring-t1', 'blurring-t2', 'blurring-Flair',
                             'cr-t2', 'cr-Flair', 'variance', 'entropy')
    landmarks_root = Path(landmarks_dir)

    # Voxels with a positive mask value are the ones used for normalization.
    mask_bool = data_dict["mask"] > 0.

    for i, feature in enumerate(features):
        # Some features (curv, sulc, thickness) may not strictly need scaling, but it can still be applied.
        if feature not in false_suffix_features:
            landmarks_path = landmarks_root / f'{feature}_landmarks.npy'
        else:
            landmarks_path = landmarks_root / f'{feature}_False_landmarks.npy'
        landmark = np.load(landmarks_path)

        channel = data_dict["image"][i]
        standardized = histogram_standardization._normalize(channel, landmark, mask_bool)
        tensor = z_normalization.ZNormalization.znorm(standardized, mask_bool)
        if tensor is not None:
            data_dict["image"][i] = tensor
    return data_dict

In [None]:
# Minimal preview pipeline: load, move channels first, apply the mask, cast to float.
keys = ["image", "seg", "mask"]
train_transf = Compose([
    LoadImaged(keys=keys),
    EnsureChannelFirstd(keys=keys),
    mask_transform,
    EnsureTyped(keys=keys, dtype=torch.float),
])

In [None]:
# Non-cached dataset over the training file list, using the preview transform chain.
train_ds = monai.data.Dataset(data=train_files, transform=train_transf)

In [None]:
# Feature names taken from the config; used below to look up per-feature landmark files.
features = config.dataset.features

In [None]:
from matplotlib.colors import LogNorm

# Preview histogram standardization on the first training sample:
# left panel = raw slice, right panel = standardized slice (fixed index 60 on the last axis).
data_dict = train_ds[0]
mask_bool = data_dict["mask"] > 0.
features_ = features


for i, feature in enumerate(features_):
    fig, ax = plt.subplots(1,2, figsize=(20, 10))
    # Some features (curv, sulc, thickness) may not strictly need scaling, but it can still be applied.
    if feature not in ['blurring-t1', 'blurring-t2', 'blurring-Flair', 'cr-t2', 'cr-Flair', 'variance', 'entropy']:
        landmarks_path = Path(f'/workspace/FCDNet/landmarks/{feature}_landmarks.npy')
    else:
        landmarks_path = Path(f'/workspace/FCDNet/landmarks/{feature}_False_landmarks.npy')
    landmark =  np.load(landmarks_path)
    d = data_dict["image"][i]

    a = ax[0].imshow(d[:,:,60].numpy())
    plt.colorbar(a, ax=ax[0])

    # Standardize once (the call used to be issued twice with identical arguments).
    d_n = histogram_standardization._normalize(d, landmark, mask_bool)
    tensor = d_n
    #tensor = z_normalization.ZNormalization.znorm(d_n, mask_bool)
    print(tensor is not None)
    if tensor is not None:
        data_dict["image"][i] = tensor
        # Labels now match the values (min/max were printed swapped before).
        print(f'{feature} normalized: \n Min Value: {data_dict["image"][i].min()} \n Max Value: {data_dict["image"][i].max()}')

        # Only draw the right panel when there is a result to show.
        b = ax[1].imshow(tensor[:,:,60].numpy())
        plt.colorbar(b, ax=ax[1])

    plt.show()

In [None]:
# Show the installed torchio version; the cells above call private torchio helpers
# (`_normalize`), which may change between releases.
from importlib.metadata import version 
version('torchio')

In [None]:
from dataset import *

# Build the validation (no augmentation) and training transform chains from the config.
interpolate = config.default.interpolate
if interpolate:
    spatial_size_conf = tuple(config.default.interpolation_size)
features = config.dataset.features

assert config.dataset.trim_background
keys=["image", "seg", "mask"]
sep_k=["seg", "mask"]

# `scaling_dict` is either the string 'torchio', a loaded metadata object, or None.
# The previous `scaling_dict in '<name>'` tests raised TypeError for the last two
# cases, because `in` on a str only accepts a str on the left.
if isinstance(scaling_dict, str) and scaling_dict == 'torchio':
    scaler = scaling_torchio_wrapper(features, scaling_dict)
elif scaling_dict is not None:
    scaler = scaling_specified_wrapper(features, scaling_dict)
else:
    scaler = ScaleIntensityd(keys=["image"], minv=0.0, maxv=1.0, channel_wise=True)

# Optional resize step, shared by both pipelines.
resize_step = [Resized(keys=keys, spatial_size=spatial_size_conf)] if interpolate else []

# no-augmentation transformation
val_transf = Compose(
                    [
                        LoadImaged(keys=keys),
                        EnsureChannelFirstd(keys=keys)
                    ] + resize_step + \
                    [
                        Spacingd(keys=sep_k, pixdim=1.0),
                        mask_transform, # zero the non-mask values
                        binarize_target,
                        scaler,
                        EnsureTyped(keys=sep_k, dtype=torch.float),
                    ]
                    )

if config.opt.augmentation:

    rand_affine_prob = config.opt.rand_affine_prob
    rot_range = config.opt.rotation_range
    shear_range = config.opt.shear_range
    scale_range = config.opt.scale_range
    translate_range = config.opt.translate_range

    noise_std = config.opt.noise_std
    flip_prob = config.opt.flip_prob
    rand_zoom_prob = config.opt.rand_zoom_prob

    # basic operations
    transforms = [LoadImaged(keys=keys),
                  EnsureChannelFirstd(keys=keys),
                 ] + resize_step + \
                 [mask_transform, scaler, Spacingd(keys=sep_k, pixdim=1.0)]

    # Plain rotation is only used when the affine augmentation is disabled.
    if rand_affine_prob == 0 and rot_range > 0:
        transforms.append(RandRotated(keys=keys, # apply to all!
                            range_x=rot_range,
                            range_y=rot_range,
                            range_z=rot_range,
                            prob=0.5)
                         )
    if flip_prob > 0:
        transforms.append(RandFlipd(keys=keys, # apply to all!
                                    prob=flip_prob,
                                    spatial_axis=0))

    if rand_affine_prob > 0:
        transforms.append(RandAffined(prob=rand_affine_prob,
                                     rotate_range=[rot_range, rot_range, rot_range],
                                     shear_range=[shear_range, shear_range, shear_range],
                                     translate_range=[translate_range, translate_range, translate_range],
                                     scale_range=[scale_range, scale_range, scale_range],
                                     padding_mode='zeros',
                                     keys=keys # apply to all!
                                    )
                         )

    if noise_std > 0:
        transforms.append(RandGaussianNoised(prob=0.5,
                                            mean=0.0,
                                            std=noise_std,
                                            keys=["image"]
                                           )
                         )

    # NOTE(review): `rand_zoom_prob` only switches zoom on; the applied probability is fixed at 0.5 - confirm intent.
    if rand_zoom_prob > 0:
        transforms.append(RandZoomd(prob=0.5, min_zoom=0.9, max_zoom=1.1, keys=keys))

    # add the rest
    transforms.extend([
                        binarize_target,
                        EnsureTyped(keys=sep_k, dtype=torch.float),
                     ]
                    )

    train_transf = Compose(transforms)
else:

    train_transf = val_transf

### Loading train-test split


In [None]:
# import numpy as np
# subjects_list = np.load('./metadata/metadata_fcd_nG.npy', allow_pickle=True).item()

In [None]:
# subjects_list

In [None]:
# train_list = subjects_list.get('train')
# val_list = subjects_list.get('test')

# feat_params = config.dataset.features

# print(len(train_list), len(val_list))

In [None]:
# feat_params

In [None]:
# Re-read the config (picks up edits made to the INI file since the first load)
# and build the train/validation loaders with the project helper.
config = configdot.parse_config('configs/config.ini')
train_loader, val_loader = setup_dataloaders(config, pin_memory=False)

### Transformation and Augmentation

In [None]:
# Visual sanity check of the training loader: for every sample show the first
# image channel at the slice with the largest label area, next to its intensity histogram.
for i, check_data_sample in enumerate(train_loader):

    brain_tensor, label_tensor, mask_tensor = (check_data_sample['image'],
                                               check_data_sample['seg'],
                                               check_data_sample['mask']
                                               )

    batch_size = brain_tensor.shape[0]
    for k in range(batch_size):

        image = brain_tensor[k]
        seg = label_tensor[k]
        mask = mask_tensor[k]

        # choose z-coord where there is a label maximum over other axes
        label_pos = (seg > 0).sum(dim=(0,1,2)).argmax().item()

        fig = plt.figure("image", (10, 5), dpi=100)

        ax1 = plt.subplot(1, 2, 1)
        ax1.imshow(image[0,:,:,label_pos], cmap='gray')

        ax2 = plt.subplot(1, 2, 2)
        ax2.hist(image[0,:,:,label_pos].cpu().numpy().flatten(), bins=100)

        plt.show()
        # Release the figure so that iterating over the whole loader does not accumulate open figures.
        plt.close(fig)

    # if i > 6:
    #     break


In [None]:
# assert config.default.interpolate
# spatial_size_conf = tuple(config.default.interpolation_size)
# #masked = config.dataset.trim_background
# masked = True

# def masked_transform(data_dict):
#     data_dict["image"] = data_dict["image"] * data_dict["mask"]
#     return data_dict

# if masked:
#     keys=["image", "seg", "mask"]
#     sep_k=["seg", "mask"]
# else:
#     keys=["image", "seg"]
#     sep_k=["seg"]

# if config.opt.augmentation:
#     rot_range = 0.5 

#     train_transf = Compose(
#         [
#             LoadImaged(keys=keys),
#             EnsureChannelFirstd(keys=keys),
#             RandRotated(keys=keys, 
#                         range_x=rot_range, 
#                         # range_y=rot_range, 
#                         # range_z=rot_range, 
#                         prob=1),
#             RandFlipd(keys=keys, prob=0.5, spatial_axis=0),
#             Spacingd(keys=sep_k, pixdim=1.0),
#             Resized(keys=keys, spatial_size=spatial_size_conf),
#             ScaleIntensityd(keys=["image"], minv=0.0, maxv=1.0, channel_wise=True),
#             # masked_transform,
#             EnsureTyped(keys=keys, dtype=torch.float)
#         ]
#     )

#     val_transf = Compose(
#         [
#                 LoadImaged(keys=keys),
#                 EnsureChannelFirstd(keys=keys),
#                 Spacingd(keys=sep_k, pixdim=1.0),
#                 Resized(keys=keys, spatial_size=spatial_size_conf),
#                 ScaleIntensityd(keys=["image"], minv=0.0, maxv=1.0, channel_wise=True),
#                 # masked_transform,
#                 EnsureTyped(keys=keys, dtype=torch.float),
#             ]
#     )
    
# else:
#     raise NotImplementedError

In [None]:
# config.opt.augmentation

### Visualization

In [None]:
# check_batch_size = 2
# check_dataset = Dataset(data=train_files, transform=train_transf)
# check_loader = DataLoader(check_dataset, 
#                           batch_size=check_batch_size, 
#                           num_workers=0, 
#                           collate_fn=list_data_collate, 
#                           pin_memory=torch.cuda.is_available(),
#                           shuffle=True
#                           )

# check_data = monai.utils.misc.first(check_loader)
# # check_data = monai.utils.misc.first(train_loader)

In [None]:
# for check_data_sample in check_dataset:
#     break

In [None]:
# check_data['image'][:,1,...]

In [None]:
# check_data['image'][:,1,...]

In [None]:
# for k in range(check_batch_size):
    
#     image = check_data['image'][k]
#     seg = check_data['seg'][k]
#     mask = check_data['mask'][k]
#     label = get_label(check_dataset.data[k]['seg'])
    
#     print(f"image shape: {image.shape}")
    
#     num_of_channels = len(feat_params)
#     # choose z-coord where there is a label maximum over other axes
#     label_pos = (seg > 0).sum(dim=(0,1,2)).argmax().item()
    
#     #mask = image[:1,...] <= 0 # `background mask
    
#     torch.sum(mask * image, dim=(-1,-2,-3)).type(torch.int) > 1

#     fig = plt.figure("image", (10, 5), dpi=200)
#     ax1 = plt.subplot(1, 2, 1)
#     #plt.title(f"{feat_params[i]}")
#     ax1.imshow(image[1,:,:,label_pos], cmap='gray')
#     ax2 = plt.subplot(1, 2, 2)
#     ax2.imshow(mask[0,:,:,label_pos], alpha=0.5)
    
#     # plt.colorbar()
#     plt.xticks([])
#     plt.yticks([])
        
#     fig.suptitle(label, fontsize=20, color='blue')
#     # plt.tight_layout()
#     plt.show()
#     if k > 2:
#         break

In [None]:
# np.array(feat_params)[torch.sum(mask * image, dim=(-1,-2,-3)).type(torch.int) > 1]

In [None]:
# for k in range(check_batch_size):
    
#     image = check_data['image'][k]
#     seg = check_data['seg'][k]
#     label = get_label(check_dataset.data[k]['seg'])
    
#     print(f"image shape: {image.shape}")
    
#     num_of_channels = len(feat_params)
#     # choose z-coord where there is a label maximum over other axes
#     label_pos = (seg > 0).sum(dim=(0,1,2)).argmax().item()
    
#     mask = image[:1,...] <= 0 # `background mask
#     torch.sum(mask * image, dim=(-1,-2,-3)).type(torch.int) > 1

#     fig = plt.figure("image", (5, 5), dpi=200)
    
#     plt.subplot(1, 1, 1)
#     image_bin = image[-2,:,:,label_pos] > 0
#     plt.imshow(image_bin, cmap='gray')
#     plt.colorbar()
#     #plt.title(f"{feat_params[i]}, {image_bin.sum()}")
#     # plt.imshow(seg[0,:,:,label_pos], interpolation='none', cmap='Reds', alpha=0.3)
#     # plt.imshow(seg[0,:,:,label_pos], interpolation='none', cmap='Reds', alpha=0.3)
#     plt.xticks([])
#     plt.yticks([])
        
#     fig.suptitle(label, fontsize=20, color='blue')
#     # plt.tight_layout()
#     plt.show()

In [None]:
# figures_per_row = 6 # for visualization
# for k in range(check_batch_size):
    
#     image = check_data['image'][k]
#     seg = check_data['seg'][k]
#     label = get_label(check_dataset.data[k]['seg'])
    
#     print(f"image shape: {image.shape}")
    
#     num_of_channels = len(feat_params)
#     # choose z-coord where there is a label maximum over other axes
#     label_pos = (seg > 0).sum(dim=(0,1,2)).argmax().item()
    
#     mask = image[:1,...] <= 0 # `background mask
#     torch.sum(mask * image, dim=(-1,-2,-3)).type(torch.int) > 1

#     fig = plt.figure("image", (len(feat_params), 5), dpi=200)
#     for i in range(num_of_channels):
#         nrows = int(np.ceil(num_of_channels/figures_per_row))
#         cols = num_of_channels%figures_per_row
#         plt.subplot(nrows, figures_per_row, i+1)
#         plt.title(f"{feat_params[i]}")
#         plt.imshow(image[i,:,:,label_pos] > 0, cmap="gray")
#         # plt.imshow(seg[0,:,:,label_pos], interpolation='none', cmap='Reds', alpha=0.3)
#         # plt.imshow(seg[0,:,:,label_pos], interpolation='none', cmap='Reds', alpha=0.3)
#         plt.xticks([])
#         plt.yticks([])
        
#     fig.suptitle(label, fontsize=20, color='blue')
#     # plt.tight_layout()
#     plt.show()

# Check dataloaders

In [None]:
# Re-read the config before building the loaders by hand in the cells below.
config = configdot.parse_config('configs/config.ini')
# train_loader, val_loader = setup_dataloaders(config)

In [None]:
def mask_transform(data_dict):
    '''
    Binarize `data_dict["mask"]` (values > 0 become 1, everything else 0) and
    zero `data_dict["image"]` wherever the mask is 0.

    The dict is modified in place and returned. "mask" must support `.astype`
    (e.g. a NumPy array).
    '''
    # `np.int` was removed in NumPy 1.24; the builtin `int` is the dtype it aliased.
    data_dict["mask"] = (data_dict["mask"] > 0).astype(int)
    data_dict["image"] = data_dict["image"] * data_dict["mask"]
    return data_dict

def setup_transformations(config):
    '''
    Build the training and validation transform chains from `config`.

    Requires `config.default.interpolate` to be truthy (asserted) and
    `config.opt.augmentation` to be truthy (otherwise NotImplementedError).

    Returns (train_transf, val_transf); the training chain is the validation
    chain with a random rotation and a random flip inserted after loading.
    '''
    assert config.default.interpolate
    spatial_size_conf = tuple(config.default.interpolation_size)

    keys = ["image", "seg", "mask"]
    sep_k = ["seg", "mask"]

    if not config.opt.augmentation:
        raise NotImplementedError

    rot_range = config.opt.rotation_range

    def _loading_steps():
        # Fresh instances on every call so the two pipelines share no transform objects.
        return [
            LoadImaged(keys=keys),
            EnsureChannelFirstd(keys=keys),
        ]

    def _finishing_steps():
        return [
            Resized(keys=keys, spatial_size=spatial_size_conf),
            Spacingd(keys=sep_k, pixdim=1.0),
            ScaleIntensityd(keys=["image"], minv=0.0, maxv=1.0, channel_wise=True),
            mask_transform,
            EnsureTyped(keys=sep_k, dtype=torch.float),
        ]

    train_transf = Compose(
        _loading_steps()
        + [
            RandRotated(keys=keys,
                        range_x=rot_range,
                        range_y=rot_range,
                        range_z=rot_range,
                        prob=0.5),
            RandFlipd(keys=keys, prob=0.5, spatial_axis=0),
        ]
        + _finishing_steps()
    )
    val_transf = Compose(_loading_steps() + _finishing_steps())

    return train_transf, val_transf

In [None]:
# Load the train/validation split and build both loaders by hand.
metadata_path = config.dataset.metadata_path
split_dict = np.load(metadata_path, allow_pickle=True).item()

train_files, val_files = setup_datafiles(split_dict, config)
train_transf, val_transf = setup_transformations(config)

# Settings shared by both loaders.
_loader_kwargs = dict(num_workers=0, collate_fn=list_data_collate)

# training dataset
train_ds = monai.data.Dataset(data=train_files, transform=train_transf)
train_loader = DataLoader(
    train_ds,
    batch_size=config.opt.train_batch_size,
    shuffle=config.dataset.shuffle_train,
    pin_memory=torch.cuda.is_available(),
    **_loader_kwargs,
)

# validation dataset
val_ds = monai.data.Dataset(data=val_files, transform=val_transf)
val_loader = DataLoader(
    val_ds,
    batch_size=config.opt.val_batch_size,
    shuffle=False,  # important not to shuffle, to ensure label correspondence
    **_loader_kwargs,
)

In [None]:
# Validation files pushed through the *training* (augmenting) transforms,
# so augmented and non-augmented versions of the same subjects can be compared.
_val_ds_aug = monai.data.Dataset(data=val_files, transform=train_transf)
val_loader_aug = DataLoader(
    _val_ds_aug,
    batch_size=config.opt.val_batch_size,
    num_workers=0,
    collate_fn=list_data_collate,
    shuffle=False,  # important not to shuffle, to ensure label correspondence
)

In [None]:
# Take the first augmented validation batch and the mask volume of its first sample (channel 0).
data = next(iter(val_loader_aug))
mask = data['mask'][0,0]

In [None]:
# Intensity histograms over [0, 1] of image channels 0 and 1, restricted to voxels inside the mask.
nbins = 50
_inside = mask > 0
h1, h2 = (
    torch.histc(data['image'][0, channel][_inside], min=0, max=1, bins=nbins).numpy()
    for channel in (0, 1)
)

In [None]:
# Overlay both histograms on a shared x axis (one point per bin).
_bin_axis = np.linspace(1e-10, 1, nbins)
plt.plot(_bin_axis, h1, alpha=0.5) # , align='edge'
plt.plot(_bin_axis, h2, alpha=0.2)

In [None]:
# image = data['image'][0,0]
# curv = data['image'][0,1]

# plt.hist(image[mask > 0] ,bins=50, alpha=0.5)
# plt.hist(curv[mask > 0] ,bins=50, alpha=0.5)
# plt.show()

In [None]:
def show_prediction_slice(b_ind=0, c_ind=0):

    '''
    Plot one slice of a batch sample in three panels: image + ground-truth label,
    image + predicted label, and the mask.

    b_ind - batch index
    c_ind - channel index into `brain_tensor`

    Reads the notebook globals `brain_tensor`, `label_tensor`,
    `label_tensor_predicted` and `mask_tensor`, each indexed here as
    [batch, channel, :, :, slice]. The slice shown is the last-axis index
    with the largest number of positive label voxels.
    '''

    label_volume = label_tensor[b_ind, 0]
    label_pos = (label_volume > 0).sum(dim=(0, 1)).argmax().item()

    fig = plt.figure("image", (3*5, 5), dpi=100)

    brain_slice = brain_tensor[b_ind, c_ind, :, :, label_pos]

    # Panel 1: ground truth over the image.
    ax1 = plt.subplot(1, 3, 1)
    ax1.imshow(to_numpy(brain_slice), cmap='gray')
    ax1.imshow(to_numpy(label_tensor[b_ind, 0, :, :, label_pos]), alpha=0.6, cmap='Reds')

    # Panel 2: prediction over the image.
    ax2 = plt.subplot(1, 3, 2)
    ax2.imshow(to_numpy(brain_slice), cmap='gray')
    ax2.imshow(to_numpy(label_tensor_predicted[b_ind, 0, :, :, label_pos]), alpha=0.6, cmap='Reds')

    # Panel 3: mask only.
    ax3 = plt.subplot(1, 3, 3)
    ax3.imshow(to_numpy(mask_tensor[b_ind, 0, :, :, label_pos]), cmap='jet')

    # Tick removal acts on the current axes, i.e. the third panel.
    plt.xticks([])
    plt.yticks([])

    plt.show()

In [None]:
# Launch the training script in a subprocess; it does not share state with this kernel.
# NOTE(review): this makes "Run All" as long as a full training run - confirm it belongs here.
!python train_seg.py

In [None]:
# Iterate the whole augmented validation loader; after the loop the three tensors
# hold the LAST batch, moved to `device`.
# NOTE(review): if only one batch is needed, the full pass is wasted work - confirm intent.
for data_tensors in val_loader_aug:
    brain_tensor, label_tensor, mask_tensor = (
                                                data_tensors['image'].to(device),
                                                data_tensors['seg'].to(device),
                                                data_tensors['mask'].to(device)
                                                )
# Placeholder "prediction": the ground truth itself, so the plotting helper can be exercised.
label_tensor_predicted = label_tensor

show_prediction_slice()

In [None]:
# Side-by-side comparison of the same validation subject without and with augmentation.
# Both loaders are unshuffled over `val_files`, so batch k of each refers to the same subject.
for k,(data, data_aug) in enumerate(zip(val_loader, val_loader_aug)):

    # NOTE(review): the plain image uses channel 1 but the augmented one uses channel 0 - confirm this is intended.
    image = data['image'][0,1] 
    image_aug = data_aug['image'][0,0]
    
    seg = data['seg'][0,0]
    seg_aug = data_aug['seg'][0,0]
    
    mask = data['mask'][0,0]
    mask_aug = data_aug['mask'][0,0]
    
    label = get_label(val_loader.dataset.data[k]['seg'])
    
    # choose z-coord where there is a label maximum over other axes
    # (computed from the non-augmented segmentation and reused for the augmented panel)
    label_pos = (seg > 0).sum(dim=(0,1)).argmax().item()
    
    fig = plt.figure("image", (2*5, 5), dpi=200)

    ax1 = plt.subplot(1, 2, 1)
    ax1.imshow(image[:,:,label_pos]) # , cmap='gray'
    ax1.imshow(mask[:,:,label_pos], alpha=0.2, cmap='jet')
    
    ax2 = plt.subplot(1, 2, 2)
    ax2.imshow(image_aug[:,:,label_pos]) # , cmap='gray'
    ax2.imshow(mask_aug[:,:,label_pos], alpha=0.2, cmap='jet')
    
    # plt.colorbar()
    # tick removal applies to the current axes only (the right-hand panel)
    plt.xticks([])
    plt.yticks([])
        
    fig.suptitle(label, fontsize=20, color='blue')
    # plt.tight_layout()
    plt.show()
    
    # stop after the first three batches (k = 0, 1, 2)
    if k > 1:
        break

In [None]:
# data['image'][0,0]

In [None]:
# # uncomment to check non-zero backgrounds in dataloader

# features_cumsum = torch.zeros(len(config.dataset.features))
# for train_batch in tqdm(train_loader):
#     features_cumsum += train_batch['image'].sum(0)[:,:5,:5,:5].sum(dim=(-1,-2,-3))
# for val_batch in tqdm(val_loader):
#     features_cumsum += val_batch['image'].sum(0)[:,:5,:5,:5].sum(dim=(-1,-2,-3))
# np.array(config.dataset.features)[features_cumsum > 0]

In [None]:
# `check_data` was only ever defined in commented-out code, so this cell failed on a fresh kernel.
# Take one batch from the training loader instead.
# NOTE(review): the commented-out original used a dedicated loader with batch_size=2 - confirm train_loader is an acceptable substitute.
check_data = next(iter(train_loader))
n_features = check_data['image'].shape[1]

In [None]:
# List the distinct values present in the batch mask.
mask = check_data['mask']
mask.unique()

In [None]:
# Plot every feature channel of (up to) the first two samples of `check_data`,
# with the mask overlaid, at the slice holding the largest label area.
# The bound avoids an IndexError when the batch holds fewer than two samples.
for k in range(min(2, check_data['image'].shape[0])):

    image = check_data['image'][k]
    seg = check_data['seg'][k]
    mask = check_data['mask'][k]
    # NOTE(review): assumes sample k of the batch is entry k of train_ds.data; this does not hold for a shuffled loader.
    label = get_label(train_ds.data[k]['seg'])

    print(f"image shape: {image.shape}")

    # choose z-coord where there is a label maximum over other axes
    label_pos = (seg > 0).sum(dim=(0,1,2)).argmax().item()

    fig = plt.figure("image", (n_features*5, 5), dpi=200)
    for i in range(n_features):
        ax1 = plt.subplot(1, n_features, i+1)
        ax1.imshow(image[i,:,:,label_pos], cmap='gray')
        ax1.imshow(mask[0,:,:,label_pos], alpha=0.2, cmap='jet')

    # tick removal applies to the current axes only (the last feature panel)
    plt.xticks([])
    plt.yticks([])

    fig.suptitle(label, fontsize=20, color='blue')
    plt.show()

### Cross-Validation

In [None]:
# Build per-fold train/validation loaders plus a test loader.
# NOTE(review): `setup_dataloaders_cv` is not imported explicitly anywhere visible; presumably it comes from `from dataset import *` - confirm.
train_loaders, val_loaders, test_loader = setup_dataloaders_cv(config)

In [None]:
class CVDataset(ABC, CacheDataset):
    """
    Base class to generate cross validation datasets.

    The incoming data list is first passed through `_split_datalist`, which a
    subclass must implement, and the result is handed to `CacheDataset`.
    """
    def __init__(
        self,
        data,
        transform,
        cache_num=sys.maxsize,
        cache_rate=1.0,
        num_workers=None,
    ) -> None:
        # Let the subclass select its share of the data before caching starts.
        data = self._split_datalist(datalist=data)
        CacheDataset.__init__(
            self, data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers
        )

    # `abstractmethod` belongs on the method that subclasses must override,
    # not on the class itself (where it only set a stray attribute).
    @abstractmethod
    def _split_datalist(self, datalist):
        """Return the subset of `datalist` this dataset should use."""
        raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")

In [None]:
def mask_transform(data_dict):
    '''
    Binarize `data_dict["mask"]` (values > 0 become 1, everything else 0) and
    zero `data_dict["image"]` wherever the mask is 0.

    The dict is modified in place and returned. "mask" must support `.astype`
    (e.g. a NumPy array).
    '''
    # `np.int` was removed in NumPy 1.24; the builtin `int` is the dtype it aliased.
    data_dict["mask"] = (data_dict["mask"] > 0).astype(int)
    data_dict["image"] = data_dict["image"] * data_dict["mask"]
    return data_dict
    
# Fixed settings for the cross-validation experiment.
spatial_size_conf = ([128, 128, 128])

keys=["image", "seg", "mask"]
sep_k=["seg", "mask"]

rot_range = 0.15


train_transf = Compose(
            [
                LoadImaged(keys=keys),
                EnsureChannelFirstd(keys=keys),
                RandRotated(keys=keys,
                            range_x=rot_range,
                            range_y=rot_range,
                            range_z=rot_range,
                            prob=0.5),
                RandFlipd(keys=keys, prob=0.5, spatial_axis=0),
                Resized(keys=keys, spatial_size=spatial_size_conf),
                Spacingd(keys=sep_k, pixdim=1.0),
                # TODO: replace with a per-channel scaler using fixed min/max values, sketched earlier as
                #   ScalePerfectly(keys=['image'], minv=[None, None, -4, 10, ...], maxv=[None, None, -4, 10, ...], channelwise=True)
                # That sketch was not valid code (undefined class, `...` placeholders, missing comma),
                # so the same min-max scaling as in `val_transf` is used until it exists.
                ScaleIntensityd(keys=["image"], minv=0.0, maxv=1.0, channel_wise=True),
                mask_transform,
                EnsureTyped(keys=sep_k, dtype=torch.float),
            ]
        )

val_transf = Compose(
            [
                LoadImaged(keys=keys),
                EnsureChannelFirstd(keys=keys),
                Resized(keys=keys, spatial_size=spatial_size_conf),
                Spacingd(keys=sep_k, pixdim=1.0),
                ScaleIntensityd(keys=["image"], minv=0.0, maxv=1.0, channel_wise=True),
                mask_transform,
                EnsureTyped(keys=sep_k, dtype=torch.float),
            ]
        )

In [None]:
# Number of cross-validation folds and their indices.
num = 8
folds = list(range(num))

# NOTE(review): `dataset_filepaths` is not defined in any visible cell - it must exist in the kernel before this runs.
# The first 80 entries feed cross-validation; the remainder is used as the test set further down.
cvdataset = CrossValidation(
    dataset_cls=CVDataset,
    data=dataset_filepaths[0][:80],
    nfolds=num,  # kept in sync with `folds` instead of repeating the literal 8
    seed=42,
    transform=train_transf,
)

In [None]:
# Display the fold indices.
folds

In [None]:
# For every held-out fold: train on all remaining folds, validate on the held-out one.
train_dss = [
    cvdataset.get_dataset(folds=[fold for fold in folds if fold != held_out])
    for held_out in folds
]
val_dss = [
    cvdataset.get_dataset(folds=held_out, transform=val_transf)
    for held_out in folds
]

In [None]:
# One loader per fold dataset: shuffled batches of 2 for training, single samples for validation.
train_loaders = [
    DataLoader(fold_ds, batch_size=2, shuffle=True, num_workers=0)
    for fold_ds in train_dss
]
val_loaders = [
    DataLoader(fold_ds, batch_size=1, num_workers=0)
    for fold_ds in val_dss
]

In [None]:
# Hold-out test set: every entry after the first 80, run through the non-augmenting transforms.
test_ds = CacheDataset(data=dataset_filepaths[0][80:], transform=val_transf, num_workers=None)
test_loader = DataLoader(test_ds, batch_size=1, num_workers=0)

In [None]:
# Display the list of per-fold training datasets.
train_dss

In [None]:
# Print an identifier for each sample of the first training fold.
# The identifier is the text after the last '-' in the sixth '/'-separated path component,
# so it depends on the directory depth of the data.
for sample in train_dss[0]:
    file_path = sample['image_meta_dict']['filename_or_obj']
    path_component = file_path.split('/')[5]
    index = path_component.split('-')[-1]
    print(index)

In [None]:
from models.v2v import V2VModel

# Only the V2V architecture is handled here; fail fast if the config names another model.
assert config.model.name == "v2v"
# Instantiate the model from the config and move it to the selected device.
best_model = V2VModel(config).to(device)