In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from mri_dataset import MRIDataset
from pytorch_resnet import PytorchResNet3D
import torch
import torch.nn.functional as F

def wl_to_lh(window, level):
    """Convert a display window/level pair into (low, high) intensity limits.

    The returned interval is centred on ``level`` and spans ``window`` units.
    """
    half_width = window / 2
    return level - half_width, level + half_width

def display_image(img, phys_size, x=None, y=None, z=None, window=None, level=None, existing_ax=None):
    """Plot three orthogonal slices (constant-z, constant-y, constant-x) of a 3D volume.

    Args:
        img: 3D array indexed as ``img[z, y, x]`` (the slicing below fixes this order).
        phys_size: ``(width, height, depth)`` physical extent, used only for the
            ``extent`` of each image so axes show physical coordinates.
        x, y, z: slice indices; each defaults to the middle of its axis.
        window, level: display window/level. ``window`` defaults to the full
            intensity range; ``level`` defaults to ``min(img) + window / 2``.
        existing_ax: sequence of three Axes to draw into. When omitted, a new
            1x3 figure is created and ``plt.show()`` is called.

    Returns:
        None.
    """
    width, height, depth = phys_size

    # img.shape is (z, y, x); flipping gives the sizes in (x, y, z) order.
    size_x, size_y, size_z = np.flip(img.shape)

    # Default to the central slice along each axis.
    if x is None:
        x = size_x // 2
    if y is None:
        y = size_y // 2
    if z is None:
        z = size_z // 2

    if window is None:
        window = np.max(img) - np.min(img)

    if level is None:
        # NOTE(review): when a custom window is given without a level, the window
        # is anchored at the image minimum rather than centred on the data range.
        level = window / 2 + np.min(img)

    low, high = wl_to_lh(window, level)

    if existing_ax is None:
        # Create a fresh 1x3 figure for the orthogonal slices.
        _, axes = plt.subplots(1, 3, figsize=(14, 8))
    else:
        axes = existing_ax

    axes[0].imshow(img[z, :, :], cmap='gray', clim=(low, high), extent=(0, width, height, 0))
    axes[1].imshow(img[:, y, :], origin='lower', cmap='gray', clim=(low, high), extent=(0, width, 0, depth))
    axes[2].imshow(img[:, :, x], origin='lower', cmap='gray', clim=(low, high), extent=(0, height, 0, depth))

    # Only show when we own the figure; callers passing axes manage display themselves.
    if existing_ax is None:
        plt.show()

def display_patient_torch(d_set, i, box_size):
    """Plot orthogonal slices of every channel of sample ``i`` from ``d_set``.

    ``d_set[i]`` is expected to return ``(data, label)``, where ``data`` is a
    tensor whose first dimension indexes channels (presumably one per MRI
    modality; TODO confirm against MRIDataset). One 1x3 figure is drawn per channel.

    Args:
        d_set: indexable dataset of ``(data, label)`` pairs.
        i: index of the sample to display.
        box_size: physical ``(width, height, depth)`` forwarded to ``display_image``.
    """
    sample = d_set[i][0]
    # Iterate over however many channels are present. The previous code indexed
    # channels 0..2 explicitly and raised IndexError for single-modality samples.
    for channel in sample:
        display_image(channel.numpy(), box_size)

In [None]:

# Physical extent of the cropped volume for each preprocessing variant; passed to
# display_image as phys_size (units not visible here — presumably mm, TODO confirm).
localised_box_size = np.array([80, 80, 112])
generalised_box_size = np.array([0.289, 0.307483, 0.4804149]) * 200

# Root of the data/model tree. Set the CROHNS_BASE environment variable to run
# on a machine other than the one the original absolute path refers to.
base = os.environ.get('CROHNS_BASE', '/vol/bitbucket/mb4617')
data_path = f'{base}/MRI_Crohns_Extended/numpy_datasets'
models_path = f'{base}/CrohnsDisease/trained_models'
suffix = 'all_data'

# Experiment selection: these flags choose both the dataset files and the model checkpoint.
all_modalities = True
localisation = True
attention = True
fold = 0
even_res = False
experiment_round = 4

# Input / output volume dims handed to MRIDataset; the first axis is smaller for the
# low-axial-resolution variant (exact meaning defined by MRIDataset — TODO confirm).
in_dims = [99 if even_res else 37, 99, 99]
out_dims = [87 if even_res else 31, 87, 87]
# Per-modality on/off mask; its sum is also used as the model's input channel count.
input_features = [1, 1, 1] if all_modalities else [1, 0, 0]
folder = 'ti' if localisation else 'ti_generic_larger'
res_str = 'even' if even_res else 'low_axial'
box_size = localised_box_size if localisation else generalised_box_size

train_dataset_path = f'{data_path}/{folder}/{suffix}_{res_str}_res_train_fold{fold}.npz'
test_dataset_path = f'{data_path}/{folder}/{suffix}_{res_str}_res_test_fold{fold}.npz'

print(train_dataset_path)
print(test_dataset_path)

# The positional False is an MRIDataset option; see MRIDataset for its meaning.
train_dataset = MRIDataset(train_dataset_path, False, in_dims, out_dims, input_features)
test_dataset = MRIDataset(test_dataset_path, False, in_dims, out_dims, input_features)

In [None]:
curr_model_path = f'{models_path}/{experiment_round}/axial_only_extended_dataset_mode{int(all_modalities)}loc{int(localisation)}att{int(attention)}/fold{fold}'

print(curr_model_path)

# Pick the device before loading so the checkpoint can be mapped straight onto it.
# Falls back to CPU so the notebook still runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 0.5 is a PytorchResNet3D hyperparameter (presumably dropout — TODO confirm).
model = PytorchResNet3D(out_dims, attention, 0.5, sum(input_features))

# map_location lets a GPU-saved checkpoint load on a CPU-only host.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (on torch >= 1.13 consider weights_only=True).
model.load_state_dict(torch.load(curr_model_path, map_location=device))
model.eval()

model.to(device=device)
print('Device: ', device)

In [None]:
from sklearn.metrics import classification_report

def dataset_stats(dataset):
    """Evaluate the module-level ``model`` on every sample of ``dataset`` and print metrics.

    Labels are binarised (0 -> healthy, anything else -> abnormal). The function
    prints the raw labels, binary labels, predictions, mean cross-entropy loss and
    an sklearn classification report. It relies on the module-level ``model`` and
    ``device``.

    Args:
        dataset: indexable dataset of ``(data, label)`` pairs with ``len()``.

    Returns:
        None; results are printed.
    """
    labels, binary_labels, losses, preds = [], [], [], []

    for idx in range(len(dataset)):
        data, label = dataset[idx]
        binary_label = torch.tensor(0 if label == 0 else 1)
        batched_data = data.unsqueeze(0).to(device=device)

        with torch.no_grad():
            out = model(batched_data).cpu()

        labels.append(label.item())
        binary_labels.append(binary_label.item())
        losses.append(F.cross_entropy(out, binary_label.unsqueeze(0)).item())
        preds.append(out.argmax(dim=1).squeeze().item())

    if not losses:
        # Avoid a NaN mean and an sklearn error on an empty split.
        print('Dataset is empty; no statistics to report.')
        return

    print(np.array(labels))
    print(np.array(binary_labels))
    print(preds)
    print('Average loss: ', float(np.mean(losses)))
    # labels=[0, 1] keeps the report valid when a split contains only one class
    # (in both truth and predictions). Without it, sklearn raises because
    # target_names has two entries.
    print(classification_report(binary_labels, preds, labels=[0, 1],
                                target_names=['healthy', 'abnormal'], zero_division=0))

In [None]:
# Report metrics for the training split first, then the held-out test split.
for split in (train_dataset, test_dataset):
    dataset_stats(split)

In [None]:

def test_example(index, dataset):
    """Run the module-level ``model`` on one sample.

    Args:
        index: sample index into ``dataset``.
        dataset: indexable dataset of ``(data, label)`` pairs.

    Returns:
        ``(data, label, prob)``: the unbatched input tensor, the label as a
        Python number, and the softmax probability of output class 1.
    """
    sample, target = dataset[index]
    model_input = sample.unsqueeze(0).to(device=device)

    with torch.no_grad():
        logits = model(model_input).cpu()
        class_probs = F.softmax(logits, dim=-1).squeeze()
        positive_prob = class_probs[1].item()

    return sample, target.item(), positive_prob

def test_and_print(index, dataset, channel=0):
    """Print the label and model prediction for one sample and plot one channel.

    Uses ``test_example`` for inference and the module-level ``box_size`` as the
    physical extent for plotting.

    Args:
        index: sample index into ``dataset``.
        dataset: indexable dataset of ``(data, label)`` pairs.
        channel: which channel of the sample to display (default 0, the original behaviour).
    """
    data, label, pred = test_example(index, dataset)

    print('Index: ', index)
    print('Label: ', label)
    print('Unhealthy pred: ', pred)

    # Create the figure only after inference succeeds, so a failure doesn't leave an empty figure.
    fig, axes = plt.subplots(1, 3, figsize=(14, 8))
    display_image(data[channel].numpy(), box_size, existing_ax=axes)
    plt.show()
    # This is called in loops; close explicitly so figures don't accumulate
    # under non-inline backends.
    plt.close(fig)
    

In [None]:
# Show up to the first 5 training samples (fewer if the split is smaller).
for i in range(min(5, len(train_dataset))):
    test_and_print(i, train_dataset)

In [None]:
# Show up to the first 5 test samples (test folds may hold fewer than 5).
for i in range(min(5, len(test_dataset))):
    test_and_print(i, test_dataset)