# SPECT Reconstruction Examples (FBP & OSEM)

This notebook demonstrates basic reconstruction workflows for Single Photon Emission Computed Tomography (SPECT) using Filtered Back-Projection (FBP) and Ordered Subsets Expectation Maximization (OSEM) algorithms available in `reconlib`.

**Workflow:**
1. Setup simulation parameters and create a simple activity phantom.
2. Instantiate the `SPECTProjectorOperator`, which can optionally model attenuation and the detector point-spread function (PSF); both are disabled in this example.
3. Simulate projection data (sinogram) with the projector, then add Poisson noise to mimic photon counting.
4. Reconstruct the activity map from the sinogram using:
    - `SPECTFBPReconstructor` (Filtered Back-Projection), and
    - `SPECTOSEMReconstructor` (Ordered Subsets Expectation Maximization).
5. Visualize the original phantom and the reconstructed images.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

# Ensure reconlib is in the Python path (e.g., if running from examples folder)
import sys
import os
if os.path.abspath(os.path.join(os.getcwd(), '..')) not in sys.path:
    sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '..')))

try:
    from reconlib.modalities.spect import SPECTProjectorOperator, SPECTFBPReconstructor, SPECTOSEMReconstructor
    # simple_radon_transform might be needed if not using SPECTProjectorOperator for data gen, but we are.
    # from reconlib.modalities.pcct.operators import simple_radon_transform 
except ImportError as e:
    print(f"Import Error: {e}. Make sure reconlib is installed or PYTHONPATH is set correctly.")
    print("You might need to run 'export PYTHONPATH=/path/to/your/reconlib/parent/directory:$PYTHONPATH'")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 1. Simulation Parameters and Phantom

In [None]:
# Image and projection geometry
img_s_spect = (64, 64)  # Image size (Ny, Nx)
n_angles_spect = 60     # Number of projection angles (keep divisible by the OSEM num_subsets so subsets are equal-sized)

# Number of detector bins: enough to span the image diagonal at every angle.
# For a square N x N image this equals floor(N * sqrt(2)) + 1.
n_dets_spect = int(np.floor(np.hypot(img_s_spect[0], img_s_spect[1]))) + 1
if n_dets_spect % 2 == 0:
    n_dets_spect += 1  # Ensure odd (central bin) - presumably for centering in radon/bp; confirm against reconlib

# Angles uniformly covering [0, pi), as used for the parallel-beam style sinogram here.
angles_spect_np = np.linspace(0, np.pi, n_angles_spect, endpoint=False)
angles_spect = torch.tensor(angles_spect_np, device=device, dtype=torch.float32)

# Create a simple activity phantom: a uniform disk plus a smaller, hotter square.
activity_phantom = torch.zeros(img_s_spect, device=device, dtype=torch.float32)
center_y, center_x = img_s_spect[0] // 2, img_s_spect[1] // 2
radius = min(img_s_spect) // 4  # uses the shorter side so the disk fits non-square images
y_coords, x_coords = torch.meshgrid(torch.arange(img_s_spect[0], device=device),
                                    torch.arange(img_s_spect[1], device=device), indexing='ij')
disk_mask = (y_coords - center_y)**2 + (x_coords - center_x)**2 < radius**2
activity_phantom[disk_mask] = 1.0
# Add another smaller, hotter spot (5x5 pixels) inside the disk
activity_phantom[center_y + radius // 2 : center_y + radius // 2 + 5,
                 center_x - radius // 2 : center_x - radius // 2 + 5] = 2.0

fig, ax = plt.subplots(figsize=(5, 5))
im = ax.imshow(activity_phantom.cpu().numpy(), cmap='hot')
ax.set_title("Original Activity Phantom")
ax.set_xlabel("X-pixel")
ax.set_ylabel("Y-pixel")
fig.colorbar(im, ax=ax, label="Activity")
plt.show()

## 2. Setup SPECT Projector
For this initial demonstration, we use a simplified projector with attenuation and PSF blurring both disabled (set to `None`), so that the focus stays on the reconstruction algorithms.

In [None]:
# Physics options for this demo: attenuation and PSF modelling are both disabled
# so the comparison focuses purely on the reconstruction algorithms.
attenuation_map_spect = None        # no attenuation map
geometric_psf_fwhm_spect_mm = None  # no PSF blurring
pixel_size_spect_mm = 1.0

# Collect the projector configuration in one place, then build the operator from it.
spect_projector_config = dict(
    image_shape=img_s_spect,
    angles=angles_spect,
    detector_pixels=n_dets_spect,
    attenuation_map=attenuation_map_spect,
    geometric_psf_fwhm_mm=geometric_psf_fwhm_spect_mm,
    pixel_size_mm=pixel_size_spect_mm,
    device=device,
)
spect_projector = SPECTProjectorOperator(**spect_projector_config)
print("SPECTProjectorOperator instantiated.")

## 3. Simulate Projections

In [None]:
# Ideal (noise-free) forward projection of the phantom.
projections_spect = spect_projector.op(activity_phantom)
print(f"Simulated projections shape: {projections_spect.shape}")

fig, ax = plt.subplots(figsize=(7, 5))
im = ax.imshow(projections_spect.cpu().numpy(), cmap='hot', aspect='auto')
ax.set_title("Simulated SPECT Projections (Sinogram - Ideal)")
ax.set_xlabel("Detector Bin")
ax.set_ylabel("Angle Index")
fig.colorbar(im, ax=ax, label="Projected Activity")
plt.show()

# Optional Poisson noise simulation (set ADD_NOISE = False to reconstruct the ideal data).
ADD_NOISE = True
NOISE_SEED = 0             # fixed seed so Restart & Run All reproduces the same noisy sinogram
total_counts_target = 5e5  # approximate total number of counts in the noisy sinogram

current_total_counts = torch.sum(projections_spect).item()
if not ADD_NOISE:
    print("Noise simulation disabled; reconstructing the ideal projections.")
    projections_to_reconstruct = projections_spect
elif current_total_counts > 1e-9:  # Avoid division by zero if ideal projections are all zero
    # Rescale so the expected total matches the target count level, then draw Poisson samples.
    scaling_factor = total_counts_target / current_total_counts
    projections_for_noise = projections_spect * scaling_factor
    torch.manual_seed(NOISE_SEED)  # seeds the CPU and all CUDA generators
    noisy_projections_spect = torch.poisson(torch.relu(projections_for_noise))  # relu: Poisson mean must be non-negative
    print(f"Generated noisy projections. Original sum: {current_total_counts:.2e}, Target sum for noise: {total_counts_target:.2e}, Noisy sum: {torch.sum(noisy_projections_spect).item():.2e}")

    fig, ax = plt.subplots(figsize=(7, 5))
    im = ax.imshow(noisy_projections_spect.cpu().numpy(), cmap='hot', aspect='auto')
    ax.set_title("Noisy SPECT Projections (Poisson)")
    ax.set_xlabel("Detector Bin")
    ax.set_ylabel("Angle Index")
    fig.colorbar(im, ax=ax, label="Counts")
    plt.show()
    projections_to_reconstruct = noisy_projections_spect
else:
    print("Skipping noise simulation as ideal projections are zero or near zero.")
    projections_to_reconstruct = projections_spect  # Use ideal if it was all zeros


## 4. FBP Reconstruction
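Filtered Back-Projection (FBP) is a common analytical (non-iterative) reconstruction method. Its ramp filter amplifies high-frequency noise, so the filter is usually apodized with a window (here Hann) and cut off below Nyquist, trading some resolution for noise suppression.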

In [None]:
# FBP filter settings; defined once so the plot title always matches the parameters used.
FBP_FILTER_CUTOFF = 0.8    # fraction of the Nyquist frequency
FBP_FILTER_WINDOW = 'hann'

# `device` is the torch device chosen in the setup cell (the original referenced an undefined `dev`).
fbp_reconstructor = SPECTFBPReconstructor(image_shape=img_s_spect, device=device)

reconstructed_fbp = fbp_reconstructor.reconstruct(
    projections_to_reconstruct,  # Use the (potentially noisy) projections
    angles_spect,
    filter_cutoff=FBP_FILTER_CUTOFF,
    filter_window=FBP_FILTER_WINDOW,
)
print(f"FBP reconstructed image shape: {reconstructed_fbp.shape}")

fig, ax = plt.subplots(figsize=(6, 6))
im = ax.imshow(reconstructed_fbp.cpu().numpy(), cmap='hot')
ax.set_title(f"SPECT FBP Reconstruction (Cutoff: {FBP_FILTER_CUTOFF}, Window: {FBP_FILTER_WINDOW.capitalize()})")
ax.set_xlabel("X-pixel")
ax.set_ylabel("Y-pixel")
fig.colorbar(im, ax=ax, label="Reconstructed Activity")
plt.show()

## 5. OSEM Reconstruction
Ordered Subsets Expectation Maximization (OSEM) is an iterative algorithm often used in SPECT and PET.

In [None]:
num_osem_iterations = 20
num_osem_subsets = 6  # 60 angles / 6 subsets = 10 angles per subset
assert n_angles_spect % num_osem_subsets == 0, (
    f"num_osem_subsets ({num_osem_subsets}) should divide n_angles_spect ({n_angles_spect})"
)

# Initial estimate for OSEM: simple (unfiltered) backprojection of the data, rescaled
# to the phantom's peak value. The zero check runs on the RAW backprojection - clamping
# first (as before) made the fallback branch unreachable.
backprojection_osem = spect_projector.op_adj(projections_to_reconstruct)
backprojection_max = torch.max(backprojection_osem).item()
if backprojection_max > 1e-9:
    initial_estimate_osem = backprojection_osem / backprojection_max * torch.max(activity_phantom).item()
else:  # Adjoint is (near) zero, e.g. a zero sinogram: start from a uniform image
    initial_estimate_osem = torch.ones(img_s_spect, device=device, dtype=torch.float32)
# OSEM updates are multiplicative, so keep every pixel strictly positive.
initial_estimate_osem = torch.clamp(initial_estimate_osem, min=1e-6)

osem_reconstructor = SPECTOSEMReconstructor(
    image_shape=img_s_spect,
    iterations=num_osem_iterations,
    num_subsets=num_osem_subsets,
    # initial_estimate can be set here, or overridden in reconstruct method
    positivity_constraint=True,
    device=device,
    verbose=True,
)

reconstructed_osem = osem_reconstructor.reconstruct(
    projections_to_reconstruct,  # Use the (potentially noisy) projections
    spect_projector,  # The same projector used for simulation
    initial_estimate_override=initial_estimate_osem.clone(),
)
print(f"OSEM reconstructed image shape: {reconstructed_osem.shape}")

fig, ax = plt.subplots(figsize=(6, 6))
im = ax.imshow(reconstructed_osem.cpu().numpy(), cmap='hot')
ax.set_title(f"SPECT OSEM Reconstruction (Iter: {num_osem_iterations}, Subsets: {num_osem_subsets})")
ax.set_xlabel("X-pixel")
ax.set_ylabel("Y-pixel")
fig.colorbar(im, ax=ax, label="Reconstructed Activity")
plt.show()

## 6. Comparison of Results

In [None]:
def _display_vmax(image: np.ndarray) -> float:
    """Return the upper colour limit for an image.

    This is the image maximum, or 1.0 when the image is (near) empty, to avoid vmin == vmax.
    """
    peak = float(np.max(image))
    return peak if peak > 1e-9 else 1.0


comparison_panels = [
    ("Original Phantom", activity_phantom.cpu().numpy()),
    ("FBP Reconstruction", reconstructed_fbp.cpu().numpy()),
    ("OSEM Reconstruction", reconstructed_osem.cpu().numpy()),
]

fig, axes = plt.subplots(1, len(comparison_panels), figsize=(18, 6))
for ax, (panel_title, panel_image) in zip(axes, comparison_panels):
    # vmin=0 hides negative FBP undershoot; each panel uses its own peak as vmax.
    im = ax.imshow(panel_image, cmap='hot', vmin=0, vmax=_display_vmax(panel_image))
    ax.set_title(panel_title)
    ax.axis('off')
    # Colorbar on every panel (the phantom's was missing) so widths and scales are comparable.
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()