In [None]:
!pip install -r ./requirements.txt

In [None]:
# For nebula torch installation for A100
!pip3 install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113

In [None]:
%load_ext autoreload
import re, time, os, shutil, json
from easydict import EasyDict as edict
import matplotlib.pyplot as plt
import numpy as np
import nibabel as nib
from PIL import Image
from monai.data import list_data_collate
import tempfile
import monai
from monai.data import DataLoader, Dataset 
from monai.transforms.intensity.array import ScaleIntensity
from monai.transforms import (
    LoadImage, EnsureChannelFirst, Spacing,
    RandFlip, Resize, EnsureType,
    LoadImaged, EnsureChannelFirstd,
    Resized, EnsureTyped, Compose, ScaleIntensityd, 
    AddChanneld, MapTransform, AsChannelFirstd, EnsureType, 
    Activations, AsDiscrete, RandCropByPosNegLabeld, 
    RandRotate90d, LabelToMaskd, RandFlipd, RandRotated, Spacingd, RandAffined,
    RandShiftIntensityd, Lambdad, MaskIntensityd
)
from utils import get_label

from dataset import setup_dataloaders, setup_datafiles

import configdot
import torch
from monai.config import print_config
from IPython.core.debugger import set_trace

%autoreload 2

In [None]:
# Create the MONAI scratch directory (no error if it already exists).
# Pure Python instead of a shell call, so the cell also works where `mkdir -p` is unavailable.
os.makedirs("./MONAI_TMP", exist_ok=True)

In [None]:
# Point MONAI at the local scratch directory and resolve the root directory from it.
os.environ['MONAI_DATA_DIRECTORY'] = "./MONAI_TMP"
directory = os.environ.get("MONAI_DATA_DIRECTORY")

# The variable was set just above, so the temporary-directory fallback is only
# taken if that assignment is removed.
if directory is None:
    root_dir = tempfile.mkdtemp()
else:
    root_dir = directory
print(root_dir)

In [None]:
# Log directory to inspect; the next cell expects every entry in it to contain
# a `configs/config.ini` file.
logdir = '/workspace/RawData/FCDNet/logs/features_comparison/t1_all'

# Show the entries found in the log directory.
os.listdir(logdir)

In [None]:
# Parse the configuration stored with the first entry of the log directory.
logdir = '/workspace/RawData/FCDNet/logs/features_comparison/t1_all'
log_entries = os.listdir(logdir)
if not log_entries:
    # Fail here with a clear message instead of a NameError on `logpath`/`config`
    # in the cells that follow.
    raise FileNotFoundError(f"No log entries found in {logdir}")

logpath = log_entries[0]
confpath = os.path.join(logdir, logpath, 'configs/config.ini')
config = configdot.parse_config(confpath)

In [None]:
# Name of the log entry whose config was parsed in the previous cell.
logpath

In [None]:
# Inspect the configuration parsed from the log directory.
config

# Create dataset

In [None]:
# Load the working configuration; this replaces the `config` parsed from the log directory above.
config = configdot.parse_config('configs/config.ini')

In [None]:
# Select the compute device: the one named in the config when present,
# otherwise GPU index 0; fall back to the CPU when CUDA is not available.
cuda_available = torch.cuda.is_available()
print(cuda_available)
DEVICE = getattr(config.opt, "device", 0)
#DEVICE=0
if cuda_available:
    device = torch.device(DEVICE)
    torch.cuda.set_device(DEVICE)

    print('Setting GPU#:', DEVICE)
    print('Using GPU#:', torch.cuda.current_device())
else:
    # torch.cuda.set_device / torch.cuda.current_device raise without CUDA,
    # so run on the CPU instead of crashing the notebook.
    DEVICE = 'cpu'
    device = torch.device(DEVICE)
    print('CUDA is not available, using CPU')

### Loading train-test split


In [None]:
# Load the train/test subject split (numpy is already imported in the imports cell).
# NOTE: allow_pickle=True unpickles arbitrary Python objects - only load metadata
# files that come from a trusted source.
# `.item()` unwraps the 0-d object array into the stored Python object.
subjects_list = np.load('./metadata/metadata_fcd_nG.npy', allow_pickle=True).item()

In [None]:
# Inspect the loaded split (the cell below reads its 'train' and 'test' entries).
subjects_list

In [None]:
# The 'test' entry of the split serves as the validation set in this notebook.
train_list, val_list = (subjects_list.get(split) for split in ('train', 'test'))

# Feature list configured for the dataset.
feat_params = config.dataset.features

print(len(train_list), len(val_list))

In [None]:
# Features listed in config.dataset.features.
feat_params

In [None]:
# train_loader, val_loader = setup_dataloaders(config)
# Build the train/validation file lists with the project helper; the entries are
# later passed to monai Datasets and indexed with the 'seg' key.
train_files, val_files = setup_datafiles(subjects_list, config)

### Transformation and Augmentation

In [None]:
# The pipelines below always resize, so stop early if the config disables interpolation.
assert config.default.interpolate
# Target spatial size handed to `Resized` in the transform pipelines below.
spatial_size_conf = tuple(config.default.interpolation_size)
#masked = config.dataset.trim_background
# The "mask" volume is always loaded here; the config switch above is bypassed.
masked = True

def masked_transform(data_dict):
    """Multiply the image by the mask, element-wise.

    Args:
        data_dict: mapping with an "image" entry and a "mask" entry that can be
            multiplied together.

    Returns:
        A shallow copy of ``data_dict`` whose "image" entry is ``image * mask``.
        The mapping passed in is left unmodified, so the same sample can be
        transformed again safely.
    """
    masked_dict = dict(data_dict)
    masked_dict["image"] = data_dict["image"] * data_dict["mask"]
    return masked_dict

# `sep_k`: the non-image volumes (segmentation, plus the mask when enabled);
# `keys`: every volume handled by the dictionary transforms.
sep_k = ["seg", "mask"] if masked else ["seg"]
keys = ["image", *sep_k]

if not config.opt.augmentation:
    # Only the augmented training pipeline is implemented.
    raise NotImplementedError

rot_range = 0.5 #config.opt.rotation_range

# NOTE(review): Spacingd, Resized and RandRotated run with their default
# interpolation modes on the "seg"/"mask" keys as well - confirm that
# non-nearest interpolation of these volumes is intended.


def _loading_steps():
    """Fresh transforms that read every volume and put its channel axis first."""
    return [
        LoadImaged(keys=keys),
        EnsureChannelFirstd(keys=keys),
    ]


def _preprocessing_steps():
    """Fresh deterministic transforms shared by the train and validation pipelines."""
    return [
        Spacingd(keys=sep_k, pixdim=1.0),
        Resized(keys=keys, spatial_size=spatial_size_conf),
        ScaleIntensityd(keys=["image"], minv=0.0, maxv=1.0, channel_wise=True),
        # masked_transform,
        EnsureTyped(keys=keys, dtype=torch.float),
    ]


# Random augmentation, applied to all volumes of a sample together:
# an always-on rotation (range_x only) and a flip of spatial axis 0 with probability 0.5.
_augmentation_steps = [
    RandRotated(keys=keys, range_x=rot_range, prob=1),
    RandFlipd(keys=keys, prob=0.5, spatial_axis=0),
]

train_transf = Compose(_loading_steps() + _augmentation_steps + _preprocessing_steps())

# Validation uses the same pipeline without the random steps.
val_transf = Compose(_loading_steps() + _preprocessing_steps())

In [None]:
# Augmentation switch read by the transform cell above.
config.opt.augmentation

### Visualization

In [None]:
def get_label(path):
    '''
    Extracts label from path (the file name up to its first dot), e.g.:
    '/workspace/RawData/Features/preprocessed_data/label_bernaskoni/n16.nii.gz' -> 'n16'

    `path` may be a str or an os.PathLike object; the directory part is removed
    with os.path.basename, so the platform's path separators are handled.

    NOTE: this definition shadows the `get_label` imported from `utils`
    in the imports cell.
    '''
    return os.path.basename(os.fspath(path)).split('.')[0]

In [None]:
# Small loader used only to eyeball the output of the training transforms.
check_batch_size = 2
check_dataset = Dataset(data=train_files, transform=train_transf)
# Do not shuffle: the plotting cells look up each sample's label through
# check_dataset.data[k], which matches batch item k only when the loader
# keeps the dataset order.
check_loader = DataLoader(
    check_dataset,
    batch_size=check_batch_size,
    num_workers=0,
    collate_fn=list_data_collate,
    pin_memory=torch.cuda.is_available(),
    shuffle=False,
)

# First batch produced by the loader.
check_data = monai.utils.misc.first(check_loader)
# check_data = monai.utils.misc.first(train_loader)

In [None]:
# for check_data_sample in check_dataset:
#     break

In [None]:
# check_data['image'][:,1,...]

In [None]:
# Sanity check: batch shapes of image / seg / mask and the number of configured features.
print(check_data["image"].shape, check_data["seg"].shape, check_data["mask"].shape, len(feat_params))

In [None]:
# Configured features, shown next to the batch shapes printed above.
feat_params

In [None]:
# check_data['image'][:,1,...]

In [None]:
# For each sample of the check batch, show one slice of the image (channel 1)
# next to the same slice of the mask.
for k in range(check_batch_size):

    image = check_data['image'][k]
    seg = check_data['seg'][k]
    mask = check_data['mask'][k]
    # NOTE(review): this lookup by position matches batch item k only if
    # check_loader yields samples in dataset order (i.e. does not shuffle).
    label = get_label(check_dataset.data[k]['seg'])

    print(f"image shape: {image.shape}")

    # choose z-coord where there is a label maximum over other axes
    label_pos = (seg > 0).sum(dim=(0,1,2)).argmax().item()

    fig = plt.figure("image", (10, 5), dpi=200)
    ax1 = plt.subplot(1, 2, 1)
    # Channel index 1 is hard-coded, so the image needs at least two channels.
    ax1.imshow(image[1,:,:,label_pos], cmap='gray')
    ax2 = plt.subplot(1, 2, 2)
    ax2.imshow(mask[0,:,:,label_pos], alpha=0.5)

    # Hide the ticks on both panels (pyplot's xticks/yticks only affect the last axes).
    for ax in (ax1, ax2):
        ax.set_xticks([])
        ax.set_yticks([])

    fig.suptitle(label, fontsize=20, color='blue')
    plt.show()
    # Never draw more than four figures, whatever the batch size.
    if k > 2:
        break

In [None]:
# np.array(feat_params)[torch.sum(mask * image, dim=(-1,-2,-3)).type(torch.int) > 1]

In [None]:
# Configured features (channel order reference for the plots in this section).
feat_params

In [None]:
# Raw values of channel 1 of `image`, i.e. of the last sample handled by the loop above.
image[1,:,:,]

In [None]:
# For each sample of the check batch, show where the second-to-last image
# channel is positive, on the slice that contains the most labelled voxels.
for k in range(check_batch_size):

    image = check_data['image'][k]
    seg = check_data['seg'][k]
    # NOTE(review): this lookup by position matches batch item k only if
    # check_loader yields samples in dataset order (i.e. does not shuffle).
    label = get_label(check_dataset.data[k]['seg'])

    print(f"image shape: {image.shape}")

    # choose z-coord where there is a label maximum over other axes
    label_pos = (seg > 0).sum(dim=(0,1,2)).argmax().item()

    fig = plt.figure("image", (5, 5), dpi=200)

    plt.subplot(1, 1, 1)
    # Channel index -2 is hard-coded, so the image needs at least two channels.
    image_bin = image[-2,:,:,label_pos] > 0
    plt.imshow(image_bin, cmap='gray')
    plt.colorbar()
    plt.xticks([])
    plt.yticks([])

    fig.suptitle(label, fontsize=20, color='blue')
    plt.show()

In [None]:
# figures_per_row = 6 # for visualization
# for k in range(check_batch_size):
    
#     image = check_data['image'][k]
#     seg = check_data['seg'][k]
#     label = get_label(check_dataset.data[k]['seg'])
    
#     print(f"image shape: {image.shape}")
    
#     num_of_channels = len(feat_params)
#     # choose z-coord where there is a label maximum over other axes
#     label_pos = (seg > 0).sum(dim=(0,1,2)).argmax().item()
    
#     mask = image[:1,...] <= 0 # `background mask
#     torch.sum(mask * image, dim=(-1,-2,-3)).type(torch.int) > 1

#     fig = plt.figure("image", (len(feat_params), 5), dpi=200)
#     for i in range(num_of_channels):
#         nrows = int(np.ceil(num_of_channels/figures_per_row))
#         cols = num_of_channels%figures_per_row
#         plt.subplot(nrows, figures_per_row, i+1)
#         plt.title(f"{feat_params[i]}")
#         plt.imshow(image[i,:,:,label_pos] > 0, cmap="gray")
#         # plt.imshow(seg[0,:,:,label_pos], interpolation='none', cmap='Reds', alpha=0.3)
#         # plt.imshow(seg[0,:,:,label_pos], interpolation='none', cmap='Reds', alpha=0.3)
#         plt.xticks([])
#         plt.yticks([])
        
#     fig.suptitle(label, fontsize=20, color='blue')
#     # plt.tight_layout()
#     plt.show()

# Setup dataloaders

In [None]:
# Options shared by both loaders: load in the main process (no workers) and
# collate samples with MONAI's list_data_collate.
_shared_loader_args = dict(num_workers=0, collate_fn=list_data_collate)

# Training data: augmented transforms, reshuffled by the loader.
train_ds = monai.data.Dataset(data=train_files, transform=train_transf)
train_loader = DataLoader(
    train_ds,
    batch_size=config.opt.train_batch_size,
    shuffle=True,
    pin_memory=torch.cuda.is_available(),
    **_shared_loader_args,
)

# Validation data: deterministic transforms, fixed order.
val_ds = monai.data.Dataset(data=val_files, transform=val_transf)
val_loader = DataLoader(
    val_ds,
    batch_size=config.opt.val_batch_size,
    shuffle=False,  # important not to shuffle, to ensure label correspondence
    **_shared_loader_args,
)

### Before augmentation

In [None]:
# from nilearn.plotting import plot_img
# ind = 0
# plot_img = plot_img(train_files[ind]['seg'],
#          bg_img=train_files[ind]['image'][0],
#          threshold=0.1, alpha=0.5, display_mode='z')
# plot_img
# print(plot_img.cut_coords) # get coordinate of z where lesion center mass

### After augmentation

In [None]:
# Draw several views of the same training subject side by side. Each access to
# the dataset re-applies train_transf, so every panel shows a newly drawn
# random augmentation with the segmentation and mask overlaid.
ind = 0
n_views = 7  # number of panels; also sets the number of grid columns so no slots stay empty
fig, axes = plt.subplots(1, n_views, figsize=(30, 30))
for ax in axes:
    item = check_loader.dataset[ind]
    image, seg, mask = (item["image"], item["seg"], item["mask"])
    # Slice (last axis) that contains the most labelled voxels.
    label_pos = (seg > 0).sum(dim=(0,1,2)).argmax().item()
    ax.imshow(np.rot90(image[0,:, :, label_pos]), cmap='gray')
    ax.imshow(np.rot90(seg[0,:, :, label_pos]), cmap="Reds", alpha=0.4)
    ax.imshow(np.rot90(mask[0,:, :, label_pos]), cmap="Greens", alpha=0.6)
    ax.set_title("seg overlay")
plt.show()