In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt

def wl_to_lh(window, level):
    """Convert a window/level pair into a ``(low, high)`` intensity range.

    The range is centred on ``level`` and spans ``window`` in total, i.e.
    ``low = level - window / 2`` and ``high = level + window / 2``.
    """
    half_width = window / 2
    return level - half_width, level + half_width

def display_image(img, phys_size, x=None, y=None, z=None, window=None, level=None, existing_ax=None):
    """Show three orthogonal slices of a 3-D volume with crosshairs.

    Parameters
    ----------
    img : 3-D array
        Volume indexed as ``img[z, y, x]``.
    phys_size : sequence of three numbers
        ``(width, height, depth)``: the extent given to the plots along the
        x, y and z axes respectively.
    x, y, z : int, optional
        Voxel indices of the slices to show. Each defaults to the middle
        index (``size // 2``) along its axis.
    window, level : number, optional
        Width and centre of the displayed intensity range. ``window``
        defaults to ``max(img) - min(img)``; ``level`` defaults to
        ``min(img) + window / 2``.
    existing_ax : sequence of three Axes, optional
        Axes to draw on. When omitted, a new 1x3 figure is created and
        shown with ``plt.show()``.

    Notes
    -----
    Panel 0 shows ``img[z, :, :]``, panel 1 shows ``img[:, y, :]`` and
    panel 2 shows ``img[:, :, x]``. Crosshairs mark the positions of the
    other two slices and are drawn through the centre of the selected voxel.
    """
    width, height, depth = phys_size

    # img.shape is ordered (z, y, x); reverse it so size and spacing are (x, y, z).
    size = np.array(img.shape[::-1])
    spacing = np.asarray(phys_size, dtype=float) / size

    # Default to the middle slice along every axis.
    if x is None:
        x = int(size[0] // 2)
    if y is None:
        y = int(size[1] // 2)
    if z is None:
        z = int(size[2] // 2)

    # Default to the full intensity range of the volume.
    if window is None:
        window = np.max(img) - np.min(img)
    if level is None:
        level = window / 2 + np.min(img)

    low, high = wl_to_lh(window, level)

    own_figure = existing_ax is None
    if own_figure:
        _, axes = plt.subplots(1, 3, figsize=(14, 8))
    else:
        axes = existing_ax

    # Display the orthogonal slices
    axes[0].imshow(img[z, :, :], cmap='gray', clim=(low, high), extent=(0, width, height, 0))
    axes[1].imshow(img[:, y, :], origin='lower', cmap='gray', clim=(low, high), extent=(0, width, 0, depth))
    axes[2].imshow(img[:, :, x], origin='lower', cmap='gray', clim=(low, high), extent=(0, height, 0, depth))

    # `extent` maps onto the outer pixel edges, so voxel i spans
    # [i * spacing, (i + 1) * spacing]; the crosshair goes through its centre.
    cross_x, cross_y, cross_z = (np.array([x, y, z]) + 0.5) * spacing

    axes[0].axhline(cross_y, lw=1)
    axes[0].axvline(cross_x, lw=1)

    axes[1].axhline(cross_z, lw=1)
    axes[1].axvline(cross_x, lw=1)

    axes[2].axhline(cross_z, lw=1)
    axes[2].axvline(cross_y, lw=1)

    if own_figure:
        plt.show()

In [None]:
# Box sizes: localised_box_size is what the display helpers in later cells pass
# to display_image as phys_size; generalised_box_size is not referenced in the
# visible cells.
localised_box_size = np.array([80, 80, 112])
generalised_box_size = 200 * np.array([0.289, 0.307483, 0.4804149])

# Components of the dataset file name.
data_path = '/vol/bitbucket/mb4617/MRI_Crohns/numpy_datasets'
folder = 'ti_imb'
suffix = 'all_data'
fold = 0
train_mode = True
if train_mode:
    mode_str = 'train'
else:
    mode_str = 'test'

# e.g. <data_path>/ti_imb/all_data_train_fold0.npz
dataset_path = f'{data_path}/{folder}/{suffix}_{mode_str}_fold{fold}.npz'

In [None]:
# Open the .npz archive; np_dataset is also read later for its 'index' and
# 'label' entries, so it is kept around rather than closed here.
np_dataset = np.load(dataset_path)

# Stack the three arrays along a new axis 1, so data[i][0], data[i][1] and
# data[i][2] are entry i of 'axial_t2', 'coronal_t2' and 'axial_pc' in that order.
data = np.stack(
    [np_dataset[key] for key in ('axial_t2', 'coronal_t2', 'axial_pc')],
    axis=1,
)

In [None]:
def display_patient(i):
    """Print the stored 'index' and 'label' of entry ``i`` and plot its volumes.

    The three sub-arrays ``data[i][0]``, ``data[i][1]`` and ``data[i][2]`` are
    each shown with ``display_image`` using ``localised_box_size`` as the
    physical size.
    """
    for field in ('index', 'label'):
        print(np_dataset[field][i])

    volumes = data[i]
    for channel in range(3):
        display_image(volumes[channel], localised_box_size)

In [None]:
# Print index/label and plot the three volumes for entry 14 of the loaded arrays.
display_patient(14)

In [None]:
from mri_dataset import MRIDataset

In [None]:
# Wrap the same .npz file in the project's MRIDataset.
# NOTE(review): the meaning of the positional arguments (True, [87, 87, 87],
# [1, 1, 1]) is defined in mri_dataset.py, which is not visible here — confirm
# against that module before changing them.
torch_dataset = MRIDataset(dataset_path, True, [87, 87, 87], [1, 1, 1])

def display_patient_torch(i):
    """Plot the three sub-volumes of the first element of ``torch_dataset[i]``.

    ``torch_dataset[i]`` is fetched once; sub-arrays 0, 1 and 2 of its first
    element are converted with ``.numpy()`` and shown with ``display_image``
    using ``localised_box_size`` as the physical size.
    """
    sample = torch_dataset[i][0]
    for channel in (0, 1, 2):
        display_image(sample[channel].numpy(), localised_box_size)

In [None]:
# Print the shape of the first element of item 14, then plot its three sub-volumes.
# Note that item 14 is fetched twice: once here and once inside display_patient_torch.
print(torch_dataset[14][0].shape)
display_patient_torch(14)

In [None]:
# Same two calls as the previous cell.
# NOTE(review): presumably repeated to compare successive fetches of item 14
# (e.g. if MRIDataset applies random transforms) — confirm; otherwise this cell
# is redundant.
print(torch_dataset[14][0].shape)
display_patient_torch(14)