In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from mri_dataset import MRIDataset
from pytorch_resnet import PytorchResNet3D
from torch.utils.data import DataLoader
import torch

def wl_to_lh(window, level):
    """Convert a display (window, level) pair into (low, high) intensity limits.

    The returned interval is centred on ``level`` and is ``window`` units wide.
    """
    half_width = window / 2
    return level - half_width, level + half_width

def display_image(img, phys_size, x=None, y=None, z=None, window=None, level=None, existing_ax=None):
    """Show three orthogonal slices of a 3-D volume with crosshairs at the slice position.

    Parameters
    ----------
    img : numpy array indexed ``img[z, y, x]``.
    phys_size : sequence of three numbers ``(width, height, depth)`` giving the
        physical extent of the volume along x, y and z (axes are labelled in
        these units).
    x, y, z : voxel indices of the slices along each axis; each defaults to the
        centre of its axis.
    window, level : display window width and centre; by default the window
        spans the full intensity range of ``img``.
    existing_ax : optional sequence of three Axes to draw into. When omitted, a
        new 1x3 figure is created and shown immediately.

    Raises
    ------
    ValueError
        If ``phys_size`` does not have exactly three entries.
    """
    phys_size = np.asarray(phys_size, dtype=float)
    if phys_size.shape != (3,):
        raise ValueError(f'phys_size must have 3 entries (width, height, depth), got shape {phys_size.shape}')
    width, height, depth = phys_size

    # img.shape is (z, y, x); flip it so size and spacing are ordered (x, y, z) like phys_size.
    size = np.flip(img.shape)
    spacing = phys_size / size

    if x is None:
        x = size[0] // 2
    if y is None:
        y = size[1] // 2
    if z is None:
        z = size[2] // 2

    if window is None:
        window = np.max(img) - np.min(img)
    if level is None:
        level = window / 2 + np.min(img)
    low, high = wl_to_lh(window, level)

    if existing_ax is None:
        _, axes = plt.subplots(1, 3, figsize=(14, 8))
    else:
        axes = existing_ax

    clim = (low, high)
    axes[0].imshow(img[z, :, :], cmap='gray', clim=clim, extent=(0, width, height, 0))
    axes[1].imshow(img[:, y, :], origin='lower', cmap='gray', clim=clim, extent=(0, width, 0, depth))
    axes[2].imshow(img[:, :, x], origin='lower', cmap='gray', clim=clim, extent=(0, height, 0, depth))

    # With these extents voxel i covers [i, i + 1) * spacing, so the crosshair
    # must pass through the voxel centre (i + 0.5) * spacing, not its edge.
    cx, cy, cz = (np.array([x, y, z], dtype=float) + 0.5) * spacing

    axes[0].axhline(cy, lw=1)
    axes[0].axvline(cx, lw=1)

    axes[1].axhline(cz, lw=1)
    axes[1].axvline(cx, lw=1)

    axes[2].axhline(cz, lw=1)
    axes[2].axvline(cy, lw=1)

    if existing_ax is None:
        plt.show()

def display_patient_torch(d_set, i, box_size):
    """Display every channel (modality) of sample ``i`` of ``d_set``.

    ``d_set[i][0]`` is taken as the input volume; each entry along its first
    axis is converted with ``.numpy()`` and shown with ``display_image``
    (which indexes it as ``[z, y, x]``). Works for any number of channels, so
    single-modality datasets no longer raise IndexError.

    Parameters
    ----------
    d_set : indexable whose items are (input, ...) pairs; input presumably a
        torch tensor of shape (C, D, H, W) — TODO confirm against MRIDataset.
    i : index of the sample to display.
    box_size : physical (width, height, depth) passed to ``display_image``.
    """
    sample = d_set[i][0]
    for channel in sample:
        display_image(channel.numpy(), box_size)

In [None]:

# Physical extent (width, height, depth) of the localised and generic crops;
# passed to display_image as its phys_size.
localised_box_size = np.array([80, 80, 112])
generalised_box_size = np.array([0.289, 0.307483, 0.4804149]) * 200

# Root directory can be overridden without editing the notebook.
base = os.environ.get('CROHNS_BASE', '/vol/bitbucket/mb4617')
data_path = f'{base}/MRI_Crohns/numpy_datasets'
models_path = f'{base}/CrohnsDisease/trained_models'
suffix = 'all_data'
input_size = [87, 87, 87]

# Experiment switches.
all_modalities = True
localisation = True
attention = True
fold = 0

input_features = [1, 1, 1] if all_modalities else [1, 0, 0]
folder = 'ti_imb' if localisation else 'ti_imb_generic'

dataset_path = f'{data_path}/{folder}/{suffix}_test_fold{fold}.npz'
train_dataset_path = f'{data_path}/{folder}/{suffix}_train_fold{fold}.npz'

# NOTE(review): curr_model_path is not used below; the checkpoint is loaded from best_model/.
curr_model_path = f'{models_path}/original_dataset_mode{int(all_modalities)}loc{int(localisation)}att{int(attention)}/fold{fold}'

# NOTE(review): this is the *training* split of the fold; use dataset_path to evaluate on held-out data.
dataset = MRIDataset(train_dataset_path, False, input_size, input_features)

# Later cells call `model`, so it must be built here for Restart & Run All to work.
# map_location='cpu' lets a GPU-saved checkpoint load on any machine; the
# evaluation cell moves the model to its target device.
model = PytorchResNet3D(input_size, attention, 0.5, sum(input_features))
model.load_state_dict(torch.load(f'{models_path}/best_model/fold{fold}', map_location='cpu'))
model.eval()

In [None]:
# Evaluate `model` on the whole of `dataset` as a single batch.
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device=device)
model.eval()  # inference mode for layers such as dropout / batch-norm
print('Device: ', device)

loader = DataLoader(dataset, len(dataset), False)

for x, y in loader:
    x = x.to(device=device)
    # Any non-zero label is treated as the positive class.
    binary_y = torch.where(y == 0, 0, 1).to(device=device)
    print('labels: ', binary_y)

    with torch.no_grad():
        out = model(x)

    preds = out.argmax(dim=1).float()
    print('predictions: ', preds)
    print('accuracy: ', (preds == binary_y).float().mean())

In [None]:
# Display every sample the model got wrong.
correct_predictions = preds == binary_y
# Physical box size matching the data variant selected by `localisation`.
box_size = localised_box_size if localisation else generalised_box_size

for i in range(len(dataset)):
    if correct_predictions[i]:
        continue

    print('Index: ', i)
    print('Label: ', binary_y[i])
    print('Pred:  ', preds[i])
    display_patient_torch(dataset, i, box_size)

In [None]:
from sklearn.metrics import confusion_matrix, f1_score

# Hand-entered alternative prediction vector (one entry per sample),
# scored against the same labels as the model's predictions.
improve = torch.Tensor([
    1., 0., 1., 1., 0., 0., 0., 1., 0., 0.,
    0., 1., 1., 1., 1., 1., 0., 1., 0., 1.,
    1., 0., 1., 0., 1., 0., 0., 0., 0., 0.,
    0., 1., 1., 1., 0., 0., 1., 1., 0., 0.,
    0., 0., 0.,
])

labels_cpu = binary_y.cpu()
print(confusion_matrix(labels_cpu, preds.cpu()))
for candidate in (preds.cpu(), improve):
    print(f1_score(labels_cpu, candidate, zero_division=0, average='weighted'))

In [None]:
from mri_dataset import _random_rotate, _random_crop_gen
import torchvision.transforms as T
from scipy.ndimage import rotate

def _random_rotate_test(x):
    angle = np.random.normal(loc=0, scale=4)
    rotated_np = rotate(x, angle, axes=(2, 3), reshape=False, order=1, mode='nearest')
    return torch.from_numpy(rotated_np)

# Test fold with the second MRIDataset flag set (presumably augmentation — confirm
# in MRIDataset), using the non-fast rotation path.
standard_dataset = MRIDataset(
    dataset_path, True, input_size, input_features,
    fast_rotate=False,
)
# Same test fold with the flag off, transformed by the scipy rotation defined above.
dataset_rot = MRIDataset(
    dataset_path, False, input_size, input_features,
    transforms=T.Lambda(_random_rotate_test),
)


In [None]:
from time import time
import cProfile

# Profile 48 draws of channel 0 of the first sample of standard_dataset.
#
# Timing notes:
#   full augmentation: ~37
#   ~2s for nearest padding instead of constant
#   single worker: 5 - 35s, 3 - 18s, 2 - 15s, 1 - 8.269
with cProfile.Profile() as pr:
    rotated_images = [
        standard_dataset[0][0][0]
        for _draw in range(48)
    ]

pr.print_stats()
    

In [None]:

# Time one pass of the whole test fold through the model with 4 loader workers.
#
# Multimodal tests, by number of workers (s):
#   not given 34.026 | 1 36.214 | 2 36.395 | 4 36.405 | 6 37.477 | 8 37.023
# Full model tests, 4 workers:
#   norm 36.66/38.22 | fast 7.48/ 9.03
worker_loader = DataLoader(
    standard_dataset,
    len(standard_dataset),
    False,
    num_workers=4,
)

start_t = time()

for x, y in worker_loader:
    x = x.to(device=device)
    binary_y = torch.where(y == 0, 0, 1).to(device=device)

    # Elapsed time once the batch has been loaded and transferred.
    print(time() - start_t)

    with torch.no_grad():
        out = model(x)
    preds = out.argmax(dim=1).float()

    print((preds == binary_y).float().mean())

# Total elapsed time including inference.
print(time() - start_t)


In [None]:
# Re-profile 48 draws of channel 0 of the first sample of standard_dataset.
with cProfile.Profile() as pr:
    rotated_images = [standard_dataset[0][0][0] for _draw in range(48)]
pr.print_stats()

In [None]:
# Torchvision's built-in random rotation (up to 8 degrees, bilinear) for comparison.
interp = T.InterpolationMode.BILINEAR

dataset_rot_2 = MRIDataset(
    dataset_path,
    False,
    input_size,
    input_features,
    transforms=T.RandomRotation(8, interpolation=interp),
)


In [None]:

# Profile 48 draws of channel 0 of the first sample using torchvision's RandomRotation.
with cProfile.Profile() as pr:
    rotated_images = [
        dataset_rot_2[0][0][0]
        for _draw in range(48)
    ]

pr.print_stats()

In [None]:

# 0.9s for batch of 48
class MyRotationTransform:
    """Rotate a 4-D tensor in its last two axes by a random angle.

    The angle (degrees) is drawn from N(0, std). The image is first edge-padded
    by enough pixels that the rotated content still covers the original
    H x W window, then rotated and centre-cropped back to H x W. The crop should
    therefore not contain rotate's fill value in its corners.
    """

    def __init__(self, std, interpolation=None):
        self.std = std
        # None means bilinear; it is resolved when rotating so this class no longer
        # depends on a global `interp` defined in another cell.
        self.interpolation = interpolation

    def __call__(self, x):
        angle = np.random.normal(loc=0, scale=self.std)
        return self._rotate_by(x, angle)

    @staticmethod
    def _pad_amount(height, width, angle):
        """Return the per-side edge padding (pixels) needed to rotate an H x W image by ``angle`` degrees."""
        rad_angle = np.radians(np.abs(angle))
        cos_a, sin_a = np.cos(rad_angle), np.sin(rad_angle)
        # Half-extents, along each axis, of the H x W box rotated by the angle.
        half_h = 0.5 * (height * cos_a + width * sin_a)
        half_w = 0.5 * (width * cos_a + height * sin_a)
        # The same padding is applied on all sides, so it must satisfy the larger
        # of the two axes (the original only considered the vertical one).
        # `// 2` keeps the original, slightly conservative rounding.
        needed = max(half_h - height // 2, half_w - width // 2, 0)
        return int(np.ceil(needed))

    def _rotate_by(self, x, angle):
        """Rotate ``x`` (shape (_, _, H, W)) by ``angle`` degrees, keeping its shape."""
        _, _, height, width = x.shape
        pad_amount = self._pad_amount(height, width, angle)

        interpolation = self.interpolation
        if interpolation is None:
            interpolation = T.InterpolationMode.BILINEAR

        x = T.functional.pad(x, pad_amount, padding_mode='edge')
        x = T.functional.rotate(x, angle, interpolation=interpolation)
        return T.functional.center_crop(x, (height, width))
        
    
dataset_rot_3 = MRIDataset(
    dataset_path, False, input_size, input_features,
    transforms=MyRotationTransform(4.0),
)

# Profile 48 draws of channel 0 of the first sample with the custom rotation.
with cProfile.Profile() as pr:
    rotated_images = [dataset_rot_3[0][0][0] for _draw in range(48)]

pr.print_stats()

# Show the first five draws using the box size that matches the loaded data variant.
box_size = localised_box_size if localisation else generalised_box_size
for image in rotated_images[:5]:
    display_image(image.numpy(), box_size)

In [None]:
def _random_rotate_test2(x, angle):
    rotated_np = rotate(x, angle, axes=(2, 3), reshape=False, order=5, mode='nearest')
    return torch.from_numpy(rotated_np)

test_sample = dataset.data[0][0]
print(test_sample.shape)

# Physical box size matching the data variant selected by `localisation`.
box_size = localised_box_size if localisation else generalised_box_size
# _rotate_by ignores std, so one instance serves every angle.
preview_rotator = MyRotationTransform(1)

# Compare scipy order-5 rotation with the pad-rotate-crop transform at the same angles.
for ang in range(-8, 9, 2):
    print(ang)
    display_image(_random_rotate_test2(test_sample, ang)[0].numpy(), box_size)
    display_image(preview_rotator._rotate_by(test_sample, ang)[0].numpy(), box_size)

In [None]:


class RandomModalityShift:
    """Crop a (C, D, H, W) volume to ``out_shape``, shifting every channel but the first.

    Channel 0 (named ``axial`` below) is centre-cropped; each remaining channel is
    cropped at its own uniformly random offset (via ``torch.randint``). Inputs
    with a single channel are returned unchanged, so their spatial size is NOT
    reduced to ``out_shape``.
    """

    def __init__(self, out_shape):
        # (depth, height, width) of the output. Each entry must be <= the matching
        # input dimension; keeping (input - output) even makes the channel-0 crop
        # exactly centred.
        self.out_shape = out_shape

    def _shift_single_channel(self, channel):
        """Crop a (D, H, W) channel to ``out_shape`` at a uniformly random offset."""
        depth, height, width = self.out_shape
        d, h, w = channel.shape

        z_diff = d - depth
        y_diff = h - height
        x_diff = w - width

        k = torch.randint(0, z_diff + 1, size=(1, )).item()
        j = torch.randint(0, y_diff + 1, size=(1, )).item()
        i = torch.randint(0, x_diff + 1, size=(1, )).item()

        return channel[k: k + depth, j: j + height, i: i + width]

    def __call__(self, x):
        """Apply the crop/shift to ``x`` of shape (C, D, H, W).

        Raises
        ------
        ValueError
            If ``x`` has more than one channel and ``out_shape`` exceeds its
            spatial shape in any dimension.
        """
        c, d, h, w = x.shape
        depth, height, width = self.out_shape

        # Do nothing if only single modality
        if c == 1:
            return x

        if depth > d or height > h or width > w:
            requested = tuple(int(s) for s in self.out_shape)
            raise ValueError(f'out_shape {requested} exceeds input spatial shape {(d, h, w)}')

        axial = x[0]

        z_d = (d - depth) // 2
        y_d = (h - height) // 2
        x_d = (w - width) // 2

        axial_cropped = axial[z_d: z_d + depth, y_d: y_d + height, x_d: x_d + width]

        return torch.stack([axial_cropped, *[self._shift_single_channel(ch) for ch in x[1:]]])
        
        
        

In [None]:
# Show every modality of raw sample 14, using the box size that matches `localisation`.
box_size = localised_box_size if localisation else generalised_box_size
display_patient_torch(dataset.data, 14, box_size)

In [None]:
in_dims = [99, 99, 99]
out_dims = [87, 87, 87]

# Intermediate crop shape halfway between input and output, rounded up.
test_shape = np.ceil(np.add(in_dims, out_dims) / 2).astype(int)
print(test_shape)

test_shift = RandomModalityShift(test_shape)

In [None]:
# Channel 2 of sample 14 before shifting, then three independent random shifts.
box_size = localised_box_size if localisation else generalised_box_size
shift_sample = dataset.data[14][0]

unshifted = shift_sample[2].numpy()
display_image(unshifted, box_size)

for _draw in range(3):
    shifted = test_shift(shift_sample)[2].numpy()
    # The crop is smaller than the source, so shrink the physical box to keep the
    # voxel spacing equal; otherwise the crop is drawn stretched (magnified).
    # Shapes are (z, y, x) while box sizes are (x, y, z), hence the flips.
    shifted_box = box_size * np.flip(shifted.shape) / np.flip(unshifted.shape)
    display_image(shifted, shifted_box)



In [None]:
# Scratch cell: emits a single blank line.
print('')