In [22]:
import sys
sys.path.insert(0, '../')
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import torch
from matplotlib.patches import Patch
from utils.data_utils import BratsDataset3D
from utils.predict import ModelPredict

In [23]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_path = r'C:\Users\johns\OneDrive\Desktop\Models\ECE542_Project_Brain_Segmentation\20K_model_10_epochs.pth'

# The checkpoint is a full pickled nn.Module (it is called directly later), not a state_dict.
# map_location lets a checkpoint saved on a GPU be opened on a CPU-only machine.
# NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects
# full pickled modules; on those versions pass weights_only=False (trusted files only).
model = torch.load(model_path, map_location=device)
# Keep the model on the same device as the inputs, and switch layers such as
# dropout / batch-norm to inference behaviour. Assignment avoids echoing the module repr.
model = model.to(device).eval()

In [24]:
# Location of the BraTS 2020 training data read by BratsDataset3D.
scan_dir = r'C:\Users\johns\OneDrive\Desktop\Datasets\ECE-542\brain-tumor-segmentation(nii)\BraTS2020_TrainingData\MICCAI_BraTS2020_TrainingData'
dataset = BratsDataset3D(scan_dir)

# Which scan of the dataset to run through the model and visualise.
scan_index = 0

In [27]:
# Fetch one (input, label) pair; plain indexing dispatches to Dataset.__getitem__.
input_tensor, label_tensor = dataset[scan_index]

# Move the input to the compute device and prepend a batch dimension of size 1.
input_tensor = input_tensor.to(device).unsqueeze(0)

In [None]:
# Inference only: no_grad skips building the autograd graph, saving memory.
with torch.no_grad():
    model_output = model(input_tensor)

# Builtin max() iterates over the first (batch) dimension and, with a batch of one,
# returns the entire output volume. Tensor.max() reduces over every element instead.
print(model_output.max().item())

In [None]:
# Convert tensors to NumPy for plotting:
#   .detach()  drops autograd history (.numpy() refuses tensors that require grad),
#   .cpu()     is required because .numpy() only works on host-memory tensors,
#   .squeeze() removes size-1 dimensions such as the batch dimension added earlier.
scan_np = input_tensor.detach().cpu().squeeze().numpy()
pred_np = model_output.detach().cpu().squeeze().numpy()
label_np = label_tensor.detach().cpu().squeeze().numpy()

# NOTE(review): pred_np is the raw model output. If the model emits probabilities or
# logits rather than hard class labels, threshold (or argmax over a class channel) it
# before the equality comparisons in the plotting cell — TODO confirm output format.

In [None]:
num_slices = 5
middle_slice = 75  # first slice displayed; adjust as needed
slice_step = 3     # spacing between consecutive displayed slices
num_rows = 2       # row 0: tissue overlay, row 1: tumor overlay
num_cols = num_slices

tissue_class_value = 0
tumor_class_value = 1

# Guard against silent broadcasting in the element-wise comparisons below and
# against requesting slices past the end of the volume.
assert pred_np.shape == label_np.shape, (
    f'prediction shape {pred_np.shape} does not match label shape {label_np.shape}'
)
last_slice = middle_slice + slice_step * (num_slices - 1)
assert last_slice < label_np.shape[2], (
    f'slice {last_slice} is out of range for depth {label_np.shape[2]}'
)

fig, axes = plt.subplots(num_rows, num_cols, figsize=(4 * num_cols, 4 * num_rows), squeeze=False)

for i in range(num_slices):
    current_slice = middle_slice + slice_step * i
    # NOTE(review): assumes scan_np is a single-channel (H, W, D) volume; if the input
    # has several modalities the channel axis must be selected first — TODO confirm.
    scan_slice = scan_np[:, :, current_slice]
    pred_slice = pred_np[:, :, current_slice]
    truth_slice = label_np[:, :, current_slice]

    # Per-pixel agreement masks, split by the ground-truth class.
    correct_tissue = (pred_slice == truth_slice) & (truth_slice == tissue_class_value)
    incorrect_tissue = (pred_slice != truth_slice) & (truth_slice == tissue_class_value)

    correct_tumor = (pred_slice == truth_slice) & (truth_slice == tumor_class_value)
    incorrect_tumor = (pred_slice != truth_slice) & (truth_slice == tumor_class_value)

    # create_overlay presumably renders the correct/incorrect masks over the scan slice.
    overlay_tissue = dataset.create_overlay(scan_slice, correct_tissue, incorrect_tissue)
    overlay_tumor = dataset.create_overlay(scan_slice, correct_tumor, incorrect_tumor)

    # Top row: tissue overlay.
    ax_tissue = axes[0, i]
    ax_tissue.imshow(overlay_tissue)
    ax_tissue.set_title(f'Tissue: Slice {current_slice}')
    ax_tissue.axis('off')

    # Bottom row: tumor overlay.
    ax_tumor = axes[1, i]
    ax_tumor.imshow(overlay_tumor)
    ax_tumor.set_title(f'Tumor: Slice {current_slice}')
    ax_tumor.axis('off')

fig.tight_layout()
plt.show()