In [None]:
import os
import SimpleITK as sitk
import numpy as np
import matplotlib.pyplot as plt

# Root directory of the MRI Crohn's data and the sub-folder whose volumes are loaded below.
data_path, folder = '/vol/bitbucket/mb4617/MRI_Crohns', 'A'
# File suffixes tried by load_image(), in this order.
extensions = ['nii', 'nii.gz']

def load_image(end):
    """Read ``<data_path>/<folder>/<end>.<ext>`` as a SimpleITK image.

    Each suffix in ``extensions`` is tried in order; the first existing file is read.
    Returns None when no candidate file exists.
    """
    stem = os.path.join(data_path, folder, end)
    candidates = (f'{stem}.{suffix}' for suffix in extensions)
    found = next((path for path in candidates if os.path.isfile(path)), None)
    return None if found is None else sitk.ReadImage(found)

def wl_to_lh(window, level):
    """Convert a display window width and level (centre) into ``(low, high)`` intensity limits."""
    half_width = window / 2
    return level - half_width, level + half_width

def display_image(img, x=None, y=None, z=None, window=None, level=None, existing_ax=None):
    """Show orthogonal axial / coronal / sagittal slices of a 3-D SimpleITK image.

    Args:
        img: 3-D SimpleITK image. Axes are laid out in physical units using its spacing.
        x, y, z: voxel indices of the sagittal, coronal and axial slices to show
            (default: the middle voxel along each axis). Each view carries a crosshair
            through the other two indices.
        window, level: display window width and centre. The default window spans the full
            intensity range; the default level puts the bottom of the window at the image
            minimum.
        existing_ax: optional sequence of three matplotlib Axes to draw into. If omitted, a
            new 1x3 figure is created and shown.
    """
    # GetArrayFromImage indexes voxels as [z, y, x].
    img_array = sitk.GetArrayFromImage(img)

    # Physical extent of the volume along each axis (spacing units, presumably mm).
    size = img.GetSize()
    spacing = img.GetSpacing()
    width = size[0] * spacing[0]
    height = size[1] * spacing[1]
    depth = size[2] * spacing[2]

    # Default to the central voxel along each axis.
    if x is None:
        x = size[0] // 2
    if y is None:
        y = size[1] // 2
    if z is None:
        z = size[2] // 2

    if window is None:
        window = np.max(img_array) - np.min(img_array)
    if level is None:
        level = window / 2 + np.min(img_array)
    low, high = wl_to_lh(window, level)

    if existing_ax is None:
        _, axes = plt.subplots(1, 3, figsize=(14, 8))
    else:
        axes = existing_ax

    axes[0].imshow(img_array[z, :, :], cmap='gray', clim=(low, high), extent=(0, width, height, 0))
    axes[1].imshow(img_array[:, y, :], origin='lower', cmap='gray', clim=(low, high), extent=(0, width, 0, depth))
    axes[2].imshow(img_array[:, :, x], origin='lower', cmap='gray', clim=(low, high), extent=(0, height, 0, depth))

    # With these extents voxel i occupies [i * s, (i + 1) * s], so the crosshair goes through
    # its centre at (i + 0.5) * s rather than along its lower edge.
    x_pos = (x + 0.5) * spacing[0]
    y_pos = (y + 0.5) * spacing[1]
    z_pos = (z + 0.5) * spacing[2]

    axes[0].axhline(y_pos, lw=1)
    axes[0].axvline(x_pos, lw=1)

    axes[1].axhline(z_pos, lw=1)
    axes[1].axvline(x_pos, lw=1)

    axes[2].axhline(z_pos, lw=1)
    axes[2].axvline(y_pos, lw=1)

    if existing_ax is None:
        plt.show()

def print_image_stats(paths, images):
    """Print the geometry of each image next to its path, then display every image.

    ``images[i]`` is the image loaded from ``paths[i]``.
    """
    for idx, path in enumerate(paths):
        print(path)
        img = images[idx]
        geometry = (
            ("Size: ", img.GetSize),
            ("Spacing: ", img.GetSpacing),
            ("Origin: ", img.GetOrigin),
            ("Direction: ", img.GetDirection),
        )
        for label, getter in geometry:
            print(label, getter())
        print()
    for img in images:
        display_image(img)
        
def display_patient(p):
    """Plot patient ``p``'s axial, coronal and post-contrast axial volumes, one per row."""
    volumes = (p.axial_image, p.coronal_image, p.axial_postcon_image)
    _, axes = plt.subplots(3, 3, figsize=(14, 24))

    for row_axes, volume in zip(axes, volumes):
        display_image(volume, existing_ax=row_axes)
    plt.show()


In [None]:
img_paths = ['A1 Axial Postcon']

images = [load_image(e) for e in img_paths]

# load_image returns None when no file matches. Fail here with a clear message instead of
# an AttributeError deep inside print_image_stats.
missing = [path for path, img in zip(img_paths, images) if img is None]
if missing:
    raise FileNotFoundError(f'No file with extensions {extensions} found for: {missing}')


In [None]:
# Print size/spacing/origin/direction for each loaded volume, then plot them.
print_image_stats(paths=img_paths, images=images)


In [None]:
from preprocess import Preprocessor
from metadata import Metadata
from np_generator import NumpyGenerator

# Reverse-engineer the network input size from the desired global-average-pooling size.
# Undoing one downsampling layer is taken to map an edge length s to 2 * s + 1;
# that step is applied once per downsampling layer.
pool_size = [10, 10, 10]
n_downsampling_layers = 3
input_size = list(pool_size)
for _ in range(n_downsampling_layers):
    input_size = [2 * s + 1 for s in input_size]

# Extra voxels added to the input size to get the constant volume size used by the
# Preprocessor (constant_volume_size below).
volume_padding = [12, 12, 12]
reference_size = [s + pad for s, pad in zip(input_size, volume_padding)]

# NOTE(review): k (presumably a fold count) and test_proportion are not used in this
# section of the notebook — confirm they are still needed.
k = 4
test_proportion = 0.25
print('input_size', input_size)
print('reference_size', reference_size)

# Path setting
data_path = '/vol/bitbucket/mb4617/MRI_Crohns'
label_path = os.path.join(data_path, 'labels')

# Passed straight to Metadata; presumably the case indices to load — TODO confirm.
abnormal_cases = [0]
healthy_cases = [0]

In [None]:
# Build the dataset metadata for the selected abnormal / healthy cases.
metadata = Metadata(
    data_path,
    label_path,
    abnormal_cases,
    healthy_cases,
    dataset_tag='',
)

In [None]:
# Sanity check: show the record parsed for the first patient.
first_patient = metadata.patients[0]
print(first_patient)

In [None]:
# Load every patient's image volumes. The three positional flags are all enabled; their
# meaning is defined by the patient's load_image_data method.
for p in metadata.patients:
    p.load_image_data(True, True, True)

# Preprocess to the constant reference volume size, with ileum cropping enabled and the
# region-grow / statistical-region crops disabled.
preprocessor = Preprocessor(constant_volume_size=reference_size)
metadata.patients = preprocessor.process(
    metadata.patients,
    ileum_crop=True,
    region_grow_crop=False,
    statistical_region_crop=False,
)

for p in metadata.patients:
    print(p.get_id())
    display_patient(p)

In [None]:
# For each patient, check whether the upper and lower post-contrast axial stacks share
# direction cosines and voxel spacing, and tally how many patients match on each.
same_dir_count = 0
same_space_count = 0

for p in metadata.patients:
    print(p.get_id())
    upper_img = p.axial_postcon_upper_image
    lower_img = p.axial_postcon_lower_image

    same_direction = np.allclose(upper_img.GetDirection(), lower_img.GetDirection())
    print(f'Same direction - {same_direction}')
    if same_direction:
        same_dir_count += 1
    else:
        # Show both direction matrices; the (3, 3) reshape assumes 3-D images.
        print(np.reshape(np.array(upper_img.GetDirection()), (3, 3)))
        print(np.reshape(np.array(lower_img.GetDirection()), (3, 3)))
    print()

    upper_spacing = np.array(upper_img.GetSpacing())
    lower_spacing = np.array(lower_img.GetSpacing())
    same_spacing = np.allclose(upper_spacing, lower_spacing)
    print(f'Same spacing - {same_spacing}')
    if same_spacing:
        same_space_count += 1
    else:
        print(upper_spacing)
        print(lower_spacing)
    print()

# Report the tallies, which were previously computed but never shown.
n_patients = len(metadata.patients)
print(f'Same direction: {same_dir_count}/{n_patients}')
print(f'Same spacing: {same_space_count}/{n_patients}')
    

In [None]:
# Tell the preprocessor the image dimensionality, taken from the first patient's axial scan.
first_axial = metadata.patients[0].axial_image
preprocessor.dimension = first_axial.GetDimension()

from time import time

def image_contains_box(img, box_center, box_size):
    """Return True if an axis-aligned physical-space box lies entirely inside ``img``.

    Args:
        img: image providing ``GetSize()`` and ``TransformPhysicalPointToContinuousIndex()``
            (e.g. a SimpleITK image).
        box_center: centre of the box in physical coordinates.
        box_size: full edge lengths of the box along each physical axis.

    Every corner of the box is mapped into continuous index space. Voxel centres sit at
    integer indices, so voxel i covers [i - 0.5, i + 0.5) and an axis of n voxels covers
    [-0.5, n - 0.5). Checking the corners suffices because both the box and the image
    region are convex.
    """
    import itertools

    size = img.GetSize()
    ndim = len(size)
    center = np.asarray(box_center, dtype=float)
    half_size = np.asarray(box_size, dtype=float) / 2

    # product() varies its first element slowest; reversing each tuple makes the last axis
    # (z) vary slowest, so the top corners (+ along z) are checked first, as they are the
    # most likely to be out of bounds.
    for signs in itertools.product((1.0, -1.0), repeat=ndim):
        corner = center + half_size * np.array(signs[::-1])
        index = img.TransformPhysicalPointToContinuousIndex(tuple(float(c) for c in corner))
        for i in range(ndim):
            if index[i] < -0.5 or index[i] >= size[i] - 0.5:
                return False
    return True
                
def set_titles(axes, category):
    """Title the first three axes as '<category> Axial', '... Coronal', '... Sagittal'."""
    for idx, suffix in enumerate((' Axial', ' Coronal', ' Sagittal')):
        axes[idx].set_title(category + suffix)

def _find_transition_range(lower_img, upper_img):
    """Find the inclusive axial band ``(box_bottom, box_top)`` where the two stacks meet.

    Both images must share a voxel grid. Only the four corner (x, y) columns are scanned,
    walking down from the top slice:
      * ``box_top`` is the highest z at which ``lower_img`` is non-zero (max over columns);
      * below that point, ``box_bottom`` is the lowest z above the first slice where
        ``upper_img`` is zero (min over columns).
    If nothing is found the defaults are ``box_top = 0`` and ``box_bottom = n_slices``.
    """
    sag_sz, cor_sz, ax_sz = lower_img.GetSize()
    box_top = 0
    box_bottom = ax_sz

    for x, y in [(0, 0), (0, cor_sz - 1), (sag_sz - 1, 0), (sag_sz - 1, cor_sz - 1)]:
        at_top = True
        for z in reversed(range(ax_sz)):
            if at_top:
                if lower_img.GetPixel((x, y, z)) != 0:
                    at_top = False
                    box_top = max(z, box_top)
            elif upper_img.GetPixel((x, y, z)) == 0:
                box_bottom = min(z + 1, box_bottom)
                break
    return box_bottom, box_top


def _blend_transition(lower_img, upper_img, box_bottom, box_top, interp_length):
    """Return a copy of ``lower_img`` with the transition band completed from ``upper_img``.

    For each (x, y) column, walking down from ``box_top`` to ``box_bottom``:
      * voxels where ``lower_img`` is zero are copied from ``upper_img``;
      * the first non-zero ``lower_img`` voxel is ``interp_top``. Up to ``interp_length``
        voxels ending at ``interp_top`` (never below ``box_bottom``) are replaced by a
        floor-divided weighted average, with the upper stack's weight increasing towards
        ``interp_top``.
    Columns with no non-zero ``lower_img`` voxel in the band are filled but not blended.
    ``interp_length == 0`` disables blending.
    """
    new_img = sitk.Image(lower_img)
    sag_sz, cor_sz, _ = lower_img.GetSize()
    interp_scale = interp_length + 1

    for x in range(sag_sz):
        for y in range(cor_sz):
            interp_top = None
            for z in reversed(range(box_bottom, box_top + 1)):
                coords = (x, y, z)
                if lower_img.GetPixel(coords) == 0:
                    new_img.SetPixel(coords, upper_img.GetPixel(coords))
                else:
                    interp_top = z
                    break
            if interp_top is None:
                # Nothing from the lower stack in this column: there is no seam to blend.
                continue

            # Blend only up to interp_top, so voxels just filled from the upper stack are
            # not mixed with the zeros of the lower stack.
            interp_bottom = max(box_bottom, interp_top - interp_length + 1)
            for z in range(interp_bottom, interp_top + 1):
                coords = (x, y, z)
                upper_factor = interp_length - (interp_top - z)
                lower_factor = interp_scale - upper_factor
                l_pix = lower_img.GetPixel(coords)
                u_pix = upper_img.GetPixel(coords)
                new_img.SetPixel(coords, (l_pix * lower_factor + u_pix * upper_factor) // interp_scale)
    return new_img


start_t = time()

for p in metadata.patients:
    print(p.get_id())

    # The first two components of p.ileum are swapped before use as a continuous index;
    # presumably p.ileum is stored as (row, col, slice) — TODO confirm.
    ileum = np.array([p.ileum[1], p.ileum[0], p.ileum[2]])
    p.ileum_physical = p.axial_image.TransformContinuousIndexToPhysicalPoint(ileum * 1.0)
    # Physical extent of the region around the ileum that must be covered.
    p.ileum_box_size = np.array([80, 80, 112])

    # Resample both post-contrast stacks onto one reference grid so they can be compared
    # voxel by voxel.
    ref_img = preprocessor.generate_reference_volume(p)
    re_upper = sitk.Resample(p.axial_postcon_upper_image, ref_img)
    re_lower = sitk.Resample(p.axial_postcon_lower_image, ref_img)

    if image_contains_box(p.axial_postcon_lower_image, p.ileum_physical, p.ileum_box_size):
        print('Lower image fully contains box')

        fig, axes = plt.subplots(2, 3, figsize=(14, 16))
        display_image(re_upper, existing_ax=axes[0])
        set_titles(axes[0], 'Upper Resampled')
        display_image(re_lower, existing_ax=axes[1])
        set_titles(axes[1], 'Lower Resampled')
        fig.set_facecolor('w')
        plt.show()
        plt.close(fig)
        continue

    fig, axes = plt.subplots(5, 3, figsize=(14, 40))
    fig.set_facecolor('w')
    # Draw the raw resampled stacks before re_lower is modified below.
    display_image(re_upper, existing_ax=axes[0])
    set_titles(axes[0], 'Upper Resampled')
    display_image(re_lower, existing_ax=axes[1])
    set_titles(axes[1], 'Lower Resampled')

    box_bottom, box_top = _find_transition_range(re_lower, re_upper)

    # Above the transition band, take voxels straight from the upper stack. Skip the paste
    # when the band already reaches the last slice, since the slice would be empty.
    ax_sz = re_lower.GetSize()[2]
    if box_top + 1 < ax_sz:
        re_lower[:, :, box_top + 1:] = re_upper[:, :, box_top + 1:]

    # Compare no blending with 5- and 10-voxel blends. The blended views are shown 4 slices
    # below box_top, clamped at 0 so a negative index does not wrap to the top slice.
    panels = [
        (0, box_top, 'No Interpolation'),
        (5, box_top - 4, '5 Pixels Interpolation'),
        (10, box_top - 4, '10 Pixels Interpolation'),
    ]
    for row, (interp_length, view_z, label) in enumerate(panels, start=2):
        blended = _blend_transition(re_lower, re_upper, box_bottom, box_top, interp_length)
        display_image(blended, z=max(view_z, 0), existing_ax=axes[row])
        set_titles(axes[row], label)

    plt.show()
    plt.close(fig)

end_t = time()
print(end_t - start_t)
    

In [None]:
# Compare the physical origins of the last patient's upper and lower post-contrast stacks.
last_patient = metadata.patients[-1]
for stack in (last_patient.axial_postcon_upper_image, last_patient.axial_postcon_lower_image):
    print(stack.GetOrigin())