In [1]:
import os
import torch
import torch.nn as nn
import torchvision as tv
import torchio as tio
import numpy as np
import random
import pandas as pd
import matplotlib.pyplot as plt
import _01_dataloader_small as prl_dl
import _02_autoencoder as prl_ae
import _05_predictor as prl_pred

import importlib
importlib.reload(prl_dl)
importlib.reload(prl_ae)
importlib.reload(prl_pred)

def patch_to_tensor(patch):
    patch_list = [patch.get(key)["data"] for key in keys]
    #for i in range(len(patch_list)):
     #   patch_list[i] = patch_list[i][None, :]
    patch_tensor = torch.cat(patch_list, dim=0)
    return patch_tensor[None, :]

def plot_3d_tensor(subject, autoencoder, coord = [10, 10, 10]):
    patch = prl_dl.train_loader.dataset[subject]
    
    tensor1 = patch_to_tensor(patch)
    tensor2 = autoencoder(tensor1)
    
    print(patch["name"])
    print(patch["location"][0:3] + torch.tensor([12, 12, 12]))
    
    for image in range(4):
        # Plot first tensor
        tensor1_tmp = tensor1[0, image, :, :, :]

        # Extract the slice at the given coordinate
        slice_x = tensor1_tmp[coord[0], :, :].detach().numpy()
        slice_y = tensor1_tmp[:, coord[1], :].detach().numpy()
        slice_z = tensor1_tmp[:, :, coord[2]].detach().numpy()

        # Create subplots for the three axes
        fig, axes = plt.subplots(1, 3, figsize=(12, 4))

        # Plot slices in each axis
        axes[0].imshow(slice_x, cmap='gray')
        axes[0].set_title('X-axis Slice')

        axes[1].imshow(slice_y, cmap='gray')
        axes[1].set_title('Y-axis Slice')

        axes[2].imshow(slice_z, cmap='gray')
        axes[2].set_title('Z-axis Slice')

        # Adjust spacing between subplots
        plt.tight_layout()

        # Display the plot
        plt.show()
        plt.clf()

        # Plot second tensor
        tensor2_tmp = tensor2[0, image, :, :, :]
        
        # Extract the slice at the given coordinate
        slice_x = tensor2_tmp[coord[0], :, :].detach().numpy()
        slice_y = tensor2_tmp[:, coord[1], :].detach().numpy()
        slice_z = tensor2_tmp[:, :, coord[2]].detach().numpy()

        # Create subplots for the three axes
        fig, axes = plt.subplots(1, 3, figsize=(12, 4))

        # Plot slices in each axis
        axes[0].imshow(slice_x, cmap='gray')
        axes[0].set_title('X-axis Slice')

        axes[1].imshow(slice_y, cmap='gray')
        axes[1].set_title('Y-axis Slice')

        axes[2].imshow(slice_z, cmap='gray')
        axes[2].set_title('Z-axis Slice')

        # Adjust spacing between subplots
        plt.tight_layout()

        # Display the plot
        plt.show()
        plt.clf()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

prl_autoencoder = prl_ae.Autoencoder3D()
prl_autoencoder_joint = prl_ae.Autoencoder3D()
prl_predictor = prl_pred.Predictor3D()

keys = ["t1", "flair", "epi", "phase"]

# Load the state_dict from the checkpoint
prl_autoencoder.load_state_dict(torch.load("models/prl_autoencoder_0718_upweight.pt", 
                                           map_location=torch.device('cpu')))
prl_autoencoder.eval()

# Load the state_dict from the checkpoint
prl_autoencoder_joint.load_state_dict(torch.load("models/prl_autoencoder_joint_0719.pt", 
                                           map_location=torch.device('cpu')))
prl_autoencoder_joint.eval()

# Load the state_dict from the checkpoint
prl_predictor.load_state_dict(torch.load("models/prl_predictor_0719.pt", 
                                           map_location=torch.device('cpu')))
prl_predictor.eval()

def process_lesion_type(lesion_id): # Return tensor of [is_lesion, is_PRL, is_CVS]
    target = torch.zeros(3)
    weight = torch.ones(3)
    if (lesion_id == 0):
        return([target, weight]) # Just return the tensor of 0s

    digits = [int(x) for x in str(lesion_id)] 
    # First digit is always 1 for computational convenience
    
    if digits[1] == 1: # non-PRL lesion
        target[0] = 1

    if digits[1] == 2: # PRL lesion
        target[0] = 1
        target[1] = 1

    if digits[2] == 1: # non-CVS lesion
        target[0] = 1

    if digits[2] == 3:
        target[0] = 1
        target[2] = 1
        
    if digits[2] == 2: # possible CVS 
        target[0] = 1
        weight[2] = 0 # Don't count errors in CVS column against the network
    
    if digits[2] == 9: # no CVS data available for this subject
        target[0] = -1
        target[2] = -1
        weight[0] = 0 # Don't count errors in lesion detection or CVS detection lesion coverage in PRL dataset is poor
        weight[2] = 0 

    return([target, weight])

def get_coords(candidate_id, num_coords, 
               lesion_mask):
    candidate_coords = torch.nonzero(lesion_mask == candidate_id)
    max_coords = min(num_coords, candidate_coords.size()[0])
    random_inds = random.sample(range(candidate_coords.size()[0]), max_coords)
    return(candidate_coords[random_inds, :])

def isolate_lesion(lesion_mask_patch, candidate_id):
    isolation_mask = (lesion_mask_patch == 0) + (lesion_mask_patch == candidate_id)

    if (isolation_mask == False).any():
        return([True, isolation_mask])
    else:
        return([False, isolation_mask])

def extract_patch(coord, candidate_id, 
                  t1, flair, epi, phase, 
                  lesion_mask, lesion_type):
    start_ends = [coord[1] - 12, coord[1] + 12, 
                  coord[2] - 12, coord[2] + 12,
                  coord[3] - 12, coord[3] + 12]
    x_start = max(start_ends[0], 0)
    x_end = min(start_ends[1], t1.size()[1] - 1)
    y_start = max(start_ends[2], 0)
    y_end = min(start_ends[3], t1.size()[2] - 1)
    z_start = max(start_ends[4], 0)
    z_end = min(start_ends[5], t1.size()[3] - 1)

    t1_patch = t1[:, x_start:x_end, y_start:y_end, z_start:z_end]
    flair_patch = flair[:, x_start:x_end, y_start:y_end, z_start:z_end]
    epi_patch = epi[:, x_start:x_end, y_start:y_end, z_start:z_end]
    phase_patch = phase[:, x_start:x_end, y_start:y_end, z_start:z_end]

    patch = torch.cat((t1_patch, flair_patch, epi_patch, phase_patch))
    lesion_mask_patch = lesion_mask[:, x_start:x_end, y_start:y_end, z_start:z_end]
    lesion_type_patch = lesion_type[:, x_start:x_end, y_start:y_end, z_start:z_end]
    if tuple(patch.size()) != (4, 24, 24, 24):
        patch, lesion_mask_patch, lesion_type_patch = pad_patches(patch, lesion_mask_patch, 
                                                                  lesion_type_patch, start_ends)
    is_multiple, isolation_mask = isolate_lesion(lesion_mask_patch, candidate_id)

    if is_multiple:
        patch = patch * (isolation_mask.repeat(4, 1, 1, 1))
        lesion_id = (lesion_type_patch * isolation_mask).unique()
        lesion_id = lesion_id[lesion_id != 0]
    else:
        lesion_id = lesion_type_patch.unique()
        lesion_id = lesion_id[lesion_id != 0]

    return([patch, lesion_id])

def pad_patches(patch, lesion_mask_patch, lesion_type_patch, 
                start_ends):
    patch_pad_tensor = torch.zeros(4, 24, 24, 24)
    mask_pad_tensor = torch.zeros(1, 24, 24, 24)
    type_pad_tensor = torch.zeros(1, 24, 24, 24)
    starts = [start_ends[i] for i in [0, 2, 4]]
    start_patch = [0 - start if start < 0 else 0 for start in starts]
    ends = [start_ends[i] for i in [1, 3, 5]]
    end_patch = [23 - (ends[i] - t1.size()[i + 1]) if ends[i] >= t1.size()[i + 1] else 24 for i in range(len(ends))]

    patch_pad_tensor[:, 
                     start_patch[0]:end_patch[0], 
                     start_patch[1]:end_patch[1], 
                     start_patch[2]:end_patch[2]] = patch
    mask_pad_tensor[:, 
                     start_patch[0]:end_patch[0], 
                     start_patch[1]:end_patch[1], 
                     start_patch[2]:end_patch[2]] = lesion_mask_patch
    type_pad_tensor[:, 
                     start_patch[0]:end_patch[0], 
                     start_patch[1]:end_patch[1], 
                     start_patch[2]:end_patch[2]] = lesion_type_patch
    
    return(patch_pad_tensor, mask_pad_tensor, type_pad_tensor)

def get_predictions(dataset, subject_id, num_coords):
    subject = dataset[subject_id]
    print(subject["name"])

    lesion_mask = subject["lesion_mask"]["data"]
    lesion_type = subject["lesion_type"]["data"]

    t1 = subject["t1"]["data"]
    flair = subject["flair"]["data"]
    epi = subject["epi"]["data"]
    phase = subject["phase"]["data"]

    output_tensor = torch.zeros(int(lesion_mask.max()), 3)
    target_tensor = torch.zeros(int(lesion_mask.max()), 3)

    for candidate_id in range(1, int(lesion_mask.max()) + 1):
        coords = get_coords(candidate_id, num_coords, lesion_mask)
        tmp_coord = coords[0, :]
        target_tensor[candidate_id - 1, :] = process_lesion_type(int(lesion_type[tmp_coord[0], tmp_coord[1], tmp_coord[2], tmp_coord[3]]))[0]
        prediction = torch.zeros(num_coords, 3)

        for i in range(coords.size()[0]):
            coord = coords[i, :]
            patch, lesion_id = extract_patch(coord, candidate_id, 
                                             t1, flair, epi, phase,
                                            lesion_mask, lesion_type)
            patch = patch[None, :, :, :, :]
            output = prl_predictor(prl_autoencoder_joint.encoder(patch))
            prediction[i, :] = output

        if num_coords > 1:
            output_tensor[candidate_id - 1, :] = torch.mean(prediction, dim=0)
        else:
            output_tensor[candidate_id - 1, :] = prediction

        #print("Mean prediction: " + str(output_tensor[candidate_id - 1, :]))
    
    return([subject["name"], output_tensor, target_tensor])

cpu


In [None]:
train_df = pd.DataFrame()
for i in range(len(prl_dl.train_dataset)):
    tmp_name, tmp_output, tmp_target = get_predictions(prl_dl.train_dataset, i, 20)
    tmp_df = pd.DataFrame(torch.cat([tmp_output, tmp_target], dim=1).detach().numpy())
    tmp_df["subject"] = tmp_name
    train_df = pd.concat([train_df, tmp_df])
train_df.to_csv("data/train_preds.csv")

In [2]:
prl_test = prl_dl.test_dataset

In [8]:
prl_test[9]["name"]

'/home/fengling/Documents/prl/data/processed/01-008'

In [None]:
test_df = pd.DataFrame()
for i in range(len(prl_dl.test_dataset)):
    tmp_name, tmp_output, tmp_target = get_predictions(prl_dl.train_dataset, i, 20)
    tmp_df = pd.DataFrame(torch.cat([tmp_output, tmp_target], dim=1).detach().numpy())
    tmp_df["subject"] = tmp_name
    test_df = pd.concat([test_df, tmp_df])
test_df.to_csv("data/test_preds.csv")

/home/fengling/Documents/prl/data/processed/05-001
/home/fengling/Documents/prl/data/processed/07-006
/home/fengling/Documents/prl/data/processed/07-012
/home/fengling/Documents/prl/data/processed/02-009
/home/fengling/Documents/prl/data/processed/08-002
