# PET Reconstruction with OSEM Example

This notebook demonstrates a basic PET reconstruction workflow using the Ordered Subsets Expectation Maximization (OSEM) algorithm. We will:
1. Generate a simple 2D phantom.
2. Define scanner geometry for PET.
3. Simulate projection data (sinogram) using a System Matrix, optionally adding Poisson noise.
4. Reconstruct the image from the sinogram using the OSEM optimizer.
5. Visualize the original phantom, initial guess, and the reconstructed image.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

# Ensure reconlib is in the Python path (e.g., if running from examples folder)
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '..')))

from reconlib.pet_ct_simulation import PhantomGenerator, simulate_projection_data
from reconlib.geometry import ScannerGeometry, SystemMatrix
from reconlib.optimizers import OrderedSubsetsExpectationMaximization
from reconlib.plotting import plot_projection_data, visualize_reconstruction
from reconlib.regularizers.common import NonnegativityConstraint

%matplotlib inline

## Set Up the Compute Device

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 1. Phantom Generation

In [None]:
img_size = (128, 128)
phantom_gen = PhantomGenerator(device=device)

# Prefer the 'circles_pet' phantom, then 'shepp-logan-pet'. If PhantomGenerator
# implements neither, build a simple uniform-disk phantom by hand.
try:
    phantom = phantom_gen.generate(size=img_size, phantom_type='circles_pet')
    print(f"Generated 'circles_pet' phantom of shape: {phantom.shape}")
except NotImplementedError:
    print("'circles_pet' not implemented, trying 'shepp-logan-pet'")
    try:
        phantom = phantom_gen.generate(size=img_size, phantom_type='shepp-logan-pet')
        print(f"Generated 'shepp-logan-pet' phantom of shape: {phantom.shape}")
    except NotImplementedError as e:
        print(f"Error: {e}. Creating a dummy phantom.")
        # Uniform disk of value 1.0 with radius min(H, W) // 3, centred in a
        # (1, 1, H, W) image. PyTorch has no `ogrid` (that is a NumPy helper), so
        # centred integer coordinates are built with `arange`. Broadcasting a
        # column vector against a row vector then yields the full (H, W) grid.
        phantom = torch.zeros(1, 1, *img_size, device=device)
        y_coords = (torch.arange(img_size[0], device=device) - img_size[0] // 2).view(-1, 1)
        x_coords = (torch.arange(img_size[1], device=device) - img_size[1] // 2).view(1, -1)
        disk_radius = min(img_size[0], img_size[1]) // 3
        disk_mask = x_coords * x_coords + y_coords * y_coords <= disk_radius ** 2
        phantom[0, 0][disk_mask] = 1.0

# Visualize phantom (squeezing batch and channel dimensions for 2D plot)
visualize_reconstruction(phantom.squeeze().cpu().numpy(), main_title="Original Phantom")
plt.show()

## 2. Scanner Geometry Definition

In [None]:
# Sinogram sampling: evenly spaced projection angles over the half-turn [0, pi),
# with the pi endpoint excluded.
num_detectors = 180  # detector bins per projection angle
num_angles = 180     # projection angles
angles = np.linspace(0, np.pi, num_angles, endpoint=False)

# detector_positions is not passed: for 'cylindrical_pet' the ring is described
# by detector_radius instead. The sizes below are illustrative values in mm.
geometry_kwargs = dict(
    geometry_type='cylindrical_pet',
    angles=angles,
    n_detector_pixels=num_detectors,
    detector_size=np.array([4.0]),  # detector element size, e.g. 4 mm
    detector_radius=350.0,          # ring radius, e.g. 350 mm
)
scanner_geo = ScannerGeometry(**geometry_kwargs)

print(f"Scanner geometry type: {scanner_geo.geometry_type}")
print(f"Number of angles: {len(scanner_geo.angles)}")
print(f"Number of detector pixels: {scanner_geo.n_detector_pixels}")

# Optional: `scanner_geo.visualize_geometry()` followed by `plt.show()` draws the
# geometry. It may raise NotImplementedError for PET geometry types.

## 3. System Matrix and Data Simulation

In [None]:
sys_matrix = SystemMatrix(scanner_geometry=scanner_geo, img_size=img_size, device=device)

# The projector expects a 4-D (B, C, H, W) tensor, so promote lower-rank phantoms.
if phantom.ndim == 2:
    # (H, W), e.g. a manually created phantom: add batch and channel dims.
    phantom_for_proj = phantom.unsqueeze(0).unsqueeze(0).to(device)
elif phantom.ndim == 3:
    # Add only a leading batch dim, assuming a single channel.
    # NOTE(review): confirm whether a 3-D phantom here is (C, H, W) or (D, H, W).
    phantom_for_proj = phantom.unsqueeze(0).to(device)
else:
    # Assumed to already be (B, C, H, W) or another projector-compatible shape.
    phantom_for_proj = phantom.to(device)

print(f"Phantom shape for projection: {phantom_for_proj.shape}")

# Noise-free sinogram from the system matrix's forward projector.
projections_clean = sys_matrix.forward_project(phantom_for_proj)
print(f"Clean projection data shape: {projections_clean.shape}")

# Seed right before the stochastic step. The Poisson noise realisation, and
# therefore the OSEM result below, is then identical on every Restart & Run All.
# It is not visible here which RNG `simulate_projection_data` draws from, so
# both torch and NumPy are seeded.
NOISE_SEED = 0
torch.manual_seed(NOISE_SEED)
np.random.seed(NOISE_SEED)

# Noisy sinogram. A higher intensity_scale gives less noise.
projections_noisy = simulate_projection_data(
    phantom_for_proj, sys_matrix, noise_model='poisson', intensity_scale=20000
)
print(f"Noisy projection data shape: {projections_noisy.shape}")

# Visualize the noisy sinogram: detach from any autograd graph and squeeze
# singleton batch/channel dims so it plots as 2-D.
plot_projection_data(projections_noisy.detach().squeeze().cpu().numpy(), title="Noisy Sinogram (Simulated)")
plt.show()

## 4. OSEM Reconstruction

In [None]:
# Flat starting estimate with the same batch/channel layout as phantom_for_proj.
# Every voxel starts at the phantom's mean intensity, which is a better start
# than all ones. The estimate is floored at 1e-9 so it is strictly positive:
# EM-type updates are multiplicative, so a zero voxel would typically never recover.
phantom_mean = torch.mean(phantom_for_proj)
initial_image = torch.ones_like(phantom_for_proj, dtype=torch.float32, device=device) * phantom_mean
initial_image = initial_image.clamp_min(1e-9)

num_iterations = 20  # full iterations (passes over all subsets)
num_subsets = 10     # number of OSEM subsets

osem_reconstructor = OrderedSubsetsExpectationMaximization(
    system_matrix=sys_matrix,
    num_subsets=num_subsets,
    num_iterations=num_iterations,
    device=device,
    verbose=True,
)

# `solve` uses the generic Optimizer interface, so the sinogram is passed as
# `k_space_data`. No regularizer is passed (plain OSEM). `forward_op` is supplied
# for interface compatibility; OSEM uses the system matrix it was built with.
reconstructed_image_osem = osem_reconstructor.solve(
    k_space_data=projections_noisy,
    forward_op=sys_matrix,
    initial_guess=initial_image,
)

print(f"Reconstructed image shape: {reconstructed_image_osem.shape}")

## 5. Visualization of Results

In [None]:
def _to_plot_array(tensor):
    """Return a detached CPU NumPy copy of `tensor` with singleton dims squeezed out.

    `.detach()` makes this safe even if the tensor is still part of an autograd graph,
    where calling `.numpy()` directly would raise.
    """
    return tensor.detach().squeeze().cpu().numpy()


# (tensor, title) for each panel, left to right.
panels = [
    (phantom_for_proj, 'Original Phantom'),
    (initial_image, 'Initial Image'),
    (reconstructed_image_osem, f'OSEM Recon ({num_iterations} iter, {num_subsets} subsets)'),
]

fig, axes = plt.subplots(1, len(panels), figsize=(18, 6))
for ax, (tensor, title) in zip(axes, panels):
    image = _to_plot_array(tensor)
    # Each panel is displayed on its own [0, max] grey-scale range. A plain float
    # is passed as vmax rather than a (possibly GPU) torch tensor.
    im = ax.imshow(image, cmap='gray', vmin=0, vmax=float(image.max()))
    ax.set_title(title)
    ax.axis('off')

# After the loop `im` is the OSEM reconstruction (last panel), which gets the colorbar.
fig.colorbar(im, ax=axes[-1], fraction=0.046, pad=0.04)
plt.tight_layout()
plt.show()

# For single-image plots, the library's `visualize_reconstruction` helper
# (used for the phantom earlier) is an alternative.