In [1]:
import sys
sys.path.insert(0, '../')
import matplotlib.pyplot as plt
import torch
import numpy as np
import os
from torch.utils.data import DataLoader, TensorDataset
from unet.model import GeneralUNet
from utils.data_utils import BratsDataset3D
import plotly.graph_objects as go



In [82]:
# Notebook parameters: every value below is a placeholder and must be set
# before running the cells that follow.
model_dir = None # Path to the saved model state dict file (passed to torch.load)
data_dir = None # Data location passed to BratsDataset3D
data_idx = None # Index of the dataset item to run the model on and visualise
threshold = None # Binarisation cutoff, applied as a fraction of the prediction's max value

In [83]:
# MODEL STRUCTURE HERE, ADJUST AS REQUIRED.
# The constructor arguments must describe the same architecture the checkpoint
# was saved from, otherwise load_state_dict() below will reject the weights.
model = GeneralUNet(in_channels=4,
                    conv_kernel_size=3,
                    pool_kernel_size=2,
                    up_kernel_size=2,
                    dropout=0.1,
                    conv_stride=1,
                    conv_padding=1,
                    conv3d=True,
                    size=4,
                    complex=8)

# map_location='cpu': a checkpoint saved from a GPU would otherwise fail to load
# on a machine without CUDA. The rest of this notebook never moves tensors to a
# device, so the model has to live on the CPU anyway.
# NOTE(security): torch.load unpickles the file; only open checkpoints from a
# trusted source (on recent PyTorch versions consider weights_only=True).
state_dict = torch.load(model_dir, map_location='cpu')
model.load_state_dict(state_dict)

# Switch to inference mode (presumably GeneralUNet is an nn.Module, in which
# case this turns dropout off).
model.eval()

In [85]:
# Create the dataset and fetch the sample of interest.
dataset = BratsDataset3D(data_dir)
# item[0] is the model input and item[1] the label, going by how later cells use them.
item = dataset[data_idx]

# Add a leading dimension of size 1 before calling the model
# (presumably the batch dimension - confirm against GeneralUNet).
unsq_item = item[0].unsqueeze(0)

# Inference only: no_grad() stops autograd from recording the forward pass,
# which otherwise keeps every intermediate activation of the volume in memory.
with torch.no_grad():
    pred = model(unsq_item)

# Drop the leading dimension that was added above.
pred = pred.squeeze(0)

In [126]:
# Label volume as a NumPy array; squeeze(0) drops the leading dimension if it has size 1.
label_3d = item[1].squeeze(0).detach().numpy()

# Binarise the prediction relative to its own maximum.
# The result is stored under a NEW name so that `pred` keeps the raw model
# output: overwriting `pred` made this cell non-idempotent (re-running it after
# changing `threshold` would threshold the already-binarised 0/1 mask instead
# of the original prediction).
pred_mask = (pred >= pred.max() * threshold).int()
pred_3d = pred_mask.squeeze(0).detach().numpy()

# Foreground masks, computed once and reused below.
pred_fg = pred_3d > 0
label_fg = label_3d > 0

# Calculate overlap and unique areas
overlap = pred_fg & label_fg    # voxels in both prediction and label
unique1 = pred_fg & ~label_fg   # voxels only in the prediction
unique2 = label_fg & ~pred_fg   # voxels only in the label

# Channel index 3 of the input volume (presumably the T1 modality, given the
# name - confirm the channel order in BratsDataset3D).
t1 = item[0][3].detach().numpy()

In [None]:
def _voxel_trace(mask, color, name):
    """Build a Scatter3d marker trace from the nonzero voxels of a 3D mask.

    Args:
        mask: 3D array; one marker is drawn at the index of every nonzero entry.
        color: Marker colour for this trace.
        name: Legend entry for this trace.

    Returns:
        A plotly `go.Scatter3d` trace.
    """
    x, y, z = np.nonzero(mask)
    return go.Scatter3d(
        x=x, y=y, z=z,
        mode='markers',
        marker=dict(size=3, color=color, opacity=0.5),
        name=name
    )


# One trace per region. The legend names say "only" because unique1/unique2
# exclude the voxels shared by prediction and label, which get their own trace.
trace1 = _voxel_trace(unique1, 'blue', 'Prediction only')
trace2 = _voxel_trace(unique2, 'green', 'Label only')
trace_overlap = _voxel_trace(overlap, 'red', 'Overlap')

fig = go.Figure(data=[trace1, trace2, trace_overlap])
fig.update_layout(
    scene=dict(
        xaxis_title='X Axis',
        yaxis_title='Y Axis',
        zaxis_title='Z Axis'
    ),
    title="3D Visualization of Prediction vs. Label Overlap"
)

fig.show()