# Imports

In [38]:
%matplotlib widget
from scipy.io import loadmat
from ipywidgets import interact, IntSlider, fixed
import matplotlib.pyplot as plt
import torch
import numpy as np

# Data Loading and Pre-processing

In [39]:
# Load the three MATLAB reconstructions, which all live in the same directory:
# the full-view FDK volume (used as ground truth), the ROI FDK volume and the
# ROI PL volume, in that order.
gt_mat, fdk_mat, pl_mat = (
    loadmat("Data/3D_recon/FF/" + fname)
    for fname in (
        "recon_p06.FF01.u_FDK_ROI_fullView.mat",
        "recon_p06.FF01.u_FDK_ROI.mat",
        "recon_p06.FF01.u_PL_ROI.b1.mat",
    )
)

# List the variables stored in each file, one dictionary per line
print(*(mat.keys() for mat in (gt_mat, fdk_mat, pl_mat)), sep="\n")

dict_keys(['__header__', '__version__', '__globals__', 'u_FDK_ROI_fullView'])
dict_keys(['__header__', '__version__', '__globals__', 'u_FDK_ROI'])
dict_keys(['__header__', '__version__', '__globals__', 'u_PL_ROI'])


In [40]:
# Extract the scan data
gt = gt_mat['u_FDK_ROI_fullView']
fdk = fdk_mat['u_FDK_ROI']
pl = pl_mat['u_PL_ROI']

# Check their shapes
print("Initial shapes:")
print("GT:", gt.shape)
print("FDK:", fdk.shape)
print("PL:", pl.shape)

# Remove the first and last 20 slices
gt = gt[..., 20:-20]
fdk = fdk[..., 20:-20]
pl = pl[..., 20:-20]

# Crop
gt = gt[128:-128, 128:-128]
fdk = fdk[128:-128, 128:-128]
pl = pl[128:-128, 128:-128]

# Check the new shapes (for verification)
print("\nShapes after removing slices:")
print("GT:", gt.shape)
print("FDK:", fdk.shape)
print("PL:", pl.shape)

Initial shapes:
GT: (512, 512, 200)
FDK: (512, 512, 200)
PL: (512, 512, 200)

Shapes after removing slices:
GT: (256, 256, 160)
FDK: (256, 256, 160)
PL: (256, 256, 160)


In [41]:
# NOTE: This is a numpy array, not a torch tensor.
# NOTE(review): weights_only=False unpickles arbitrary objects, and a malicious
# file can execute code on load. Only load .pt files from a trusted source.
# Switching to weights_only=True would likely require allow-listing the numpy
# globals (torch.serialization.add_safe_globals). TODO confirm for the torch
# version in use.
ddCNN = torch.load("Data/3D_recon/FF/p06.FF01_IResNet_MK6_DS14.2_run1_3D.pt", weights_only=False)

# The loaded array is expected to already match the trimmed/cropped volumes
# above. Check this now instead of hitting an IndexError later in the slice viewer.
assert ddCNN.shape == gt.shape, (
    f"ddCNN volume has shape {tuple(ddCNN.shape)}, expected {gt.shape} "
    "to match the pre-processed GT volume"
)

ddCNN.shape

(256, 256, 160)

# Plotting

In [None]:
def show_volumes(volumes, titles, slice_idx):
    """
    Display a list of 3D volumes side-by-side at a given slice index.

    This is meant to be driven by an ``ipywidgets.interact`` slider, so it is
    called once per slider move. pyplot keeps every figure it creates
    registered until it is closed. To avoid piling up one open figure per call,
    the figure from the previous call (identified by a fixed label) is closed
    before a new one is drawn.

    Each panel is auto-scaled independently by ``imshow``. Grey levels are
    therefore not directly comparable between panels.

    volumes  : list of ndarray, each indexed as vol[:, :, slice_idx]
    titles   : list of str, same length as volumes
    slice_idx: int, which slice (index along the last axis) to display

    Raises ValueError if ``volumes`` is empty or its length differs from ``titles``.
    """
    n = len(volumes)
    if n == 0:
        raise ValueError("show_volumes needs at least one volume")
    if len(titles) != n:
        # zip() below would otherwise silently drop the unmatched panels
        raise ValueError(
            f"got {n} volumes but {len(titles)} titles; they must match one-to-one"
        )

    # Close the figure left over from the previous call (a no-op if none exists).
    fig_label = "show_volumes"
    plt.close(fig_label)

    # squeeze=False always returns a 2-D array of Axes, so a single volume
    # needs no special-casing.
    fig, axes = plt.subplots(1, n, figsize=(4 * n, 5), squeeze=False, num=fig_label)

    for ax, vol, title in zip(axes[0], volumes, titles):
        ax.imshow(vol[:, :, slice_idx], cmap='gray')
        ax.set_title(title)
        ax.axis('off')

    fig.tight_layout()
    plt.show()

In [43]:
# Panels to compare, in display order. Keeping each title next to its volume
# stops the two lists handed to show_volumes from drifting out of sync.
panels = [
    ('Ground Truth', gt),
    ('FDK Reconstruction', fdk),
    ('PL Reconstruction', pl),
    ('DDCNN Reconstruction', ddCNN),
]
titles = [title for title, _ in panels]
vols = [vol for _, vol in panels]

# Only offer slice indices that exist in every volume, so no panel can raise
# an IndexError at the top of the slider range.
n_slices = min(vol.shape[2] for vol in vols)

# Hook show_volumes up to a slice slider. The trailing ';' suppresses the
# function repr that interact() returns.
interact(
    show_volumes,
    volumes=fixed(vols),
    titles=fixed(titles),
    slice_idx=IntSlider(
        min=0,
        max=n_slices - 1,
        step=1,
        value=0,
        description='Slice'
    )
);

interactive(children=(IntSlider(value=0, description='Slice', max=159), Output()), _dom_classes=('widget-inter…

<function __main__.show_volumes(volumes, titles, slice_idx)>