# Photoacoustic Tomography (PAT) Reconstruction Demo
This notebook demonstrates a basic reconstruction pipeline for Photoacoustic Tomography using placeholder operators and reconstructors. The actual forward model and back-projection in `PhotoacousticOperator` are currently placeholders and would need to be replaced with accurate implementations (e.g., using k-Wave or analytical models) for meaningful results.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

# Adjust path to import from reconlib (if running from outside the root directory)
import sys
# Example: sys.path.append('../../../') # Adjust based on your notebook's location

from reconlib.modalities.photoacoustic.operators import PhotoacousticOperator
from reconlib.modalities.photoacoustic.reconstructors import tv_reconstruction_pat
from reconlib.modalities.photoacoustic.utils import generate_pat_phantom, plot_pat_results

print(f"PyTorch version: {torch.__version__}")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. Setup Parameters and Phantom

In [None]:
image_shape_pat = (128, 128)  # (Ny, Nx), or (Nz, Ny, Nx) for 3D
num_sensors_pat = 64          # Number of acoustic sensors
sound_speed_mps = 1500        # Speed of sound in m/s
num_time_samples = 256        # Intended samples per sensor (not currently passed to PhotoacousticOperator)

# Generate a simple phantom (initial pressure distribution)
true_initial_pressure = generate_pat_phantom(image_shape_pat, num_circles=5, device=device)

# Define sensor geometry (e.g., circular array around the phantom)
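# torch.linspace has no `endpoint` argument: take num_sensors + 1 points and drop the last (2*pi duplicates 0)
angles = torch.linspace(0, 2 * np.pi, num_sensors_pat + 1, device=device)[:-1]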
radius_m = 1.1 * (max(image_shape_pat) / 2) * 0.001  # Assumes 1 mm pixels; ring radius 10% larger than half the FOV
sensor_positions_pat = torch.stack([
    radius_m * torch.cos(angles),
    radius_m * torch.sin(angles) 
], dim=1)

if len(image_shape_pat) == 3:  # Example for 3D: sensors in the XY plane
    # The XY positions above are centred on the origin; with the same convention,
    # the Z mid-plane is z = 0 rather than Nz * pixel_size / 2
    z_pos = torch.zeros(num_sensors_pat, 1, device=device)
    sensor_positions_pat = torch.cat((sensor_positions_pat, z_pos), dim=1)

plt.figure(figsize=(6,6))
plt.imshow(true_initial_pressure.cpu().numpy())
plt.title('True Initial Pressure Phantom')
plt.xlabel('X (pixels)')
plt.ylabel('Y (pixels)')
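if sensor_positions_pat.shape[1] == 2:
    # Convert sensor positions (metres, origin at image centre) to pixel coordinates (1 mm pixels)
    sensor_px = sensor_positions_pat.cpu().numpy() / 0.001
    plt.scatter(sensor_px[:, 0] + image_shape_pat[1] / 2, sensor_px[:, 1] + image_shape_pat[0] / 2,
                c='red', marker='x', label='Sensors')
    plt.legend()  # Only add a legend when a labelled artist exists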
plt.show()
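The recorded time window must be long enough for the wave from the farthest point in the field of view to reach the farthest sensor. The notebook does not specify a sampling interval, so the sketch below only shows the constraint implied by `num_time_samples`. `required_record_time` is a hypothetical helper; it assumes 1 mm pixels, a 2D FOV centred on the sensor ring, and uses a 64 mm ring radius (half the 128 mm FOV) as an example value; substitute `radius_m` from the cell above.

```python
import math

def required_record_time(image_shape, pixel_size_m, ring_radius_m, sound_speed_mps):
    """Time (s) for a wave from the farthest FOV point to reach a sensor on the ring.

    Assumes the FOV is centred on the ring centre; the worst case is a FOV corner
    diametrically opposite a sensor, at distance R + half-diagonal.
    """
    ny, nx = image_shape[-2], image_shape[-1]
    half_diag_m = 0.5 * pixel_size_m * math.hypot(ny, nx)
    return (ring_radius_m + half_diag_m) / sound_speed_mps

t_max = required_record_time((128, 128), 1e-3, 0.064, 1500.0)
# Largest sample spacing that still spans t_max with 256 samples
dt_max = t_max / (256 - 1)
print(f"Record at least {t_max * 1e6:.1f} us; with 256 samples dt <= {dt_max * 1e9:.0f} ns "
      f"(fs >= {1 / dt_max / 1e6:.2f} MHz)")
```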

## 2. Initialize Operator and Simulate Data

In [None]:
pat_operator = PhotoacousticOperator(
    image_shape=image_shape_pat,
    sensor_positions=sensor_positions_pat,
    sound_speed=sound_speed_mps,
    device=device
)

# Simulate sensor data using the forward operator
# Note: This uses the placeholder forward model!
y_sensor_data = pat_operator.op(true_initial_pressure)

print(f"Simulated sensor data shape: {y_sensor_data.shape}")

# Visualize sensor data (sinogram-like representation)
plt.figure(figsize=(8,5))
plt.imshow(y_sensor_data.cpu().numpy(), aspect='auto', cmap='viridis')
plt.title('Simulated Sensor Data (Placeholder)')
plt.xlabel('Time Samples')
plt.ylabel('Sensor Index')
plt.colorbar(label='Signal Amplitude')
plt.show()
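The simulated data above are noise-free. One way to probe robustness (item 5 in Section 5) is to add white Gaussian noise at a chosen signal-to-noise ratio. This is a minimal sketch: `add_gaussian_noise` is a hypothetical helper, and defining SNR against the mean power of the whole sinogram is an assumption.

```python
import torch

def add_gaussian_noise(signal: torch.Tensor, snr_db: float, generator=None) -> torch.Tensor:
    """Return signal + white Gaussian noise at snr_db (relative to mean signal power)."""
    signal_power = signal.pow(2).mean()
    noise_power = signal_power / (10.0 ** (snr_db / 10.0))
    # Draw noise with the same shape, dtype and device as the signal
    noise = torch.randn(signal.shape, generator=generator, dtype=signal.dtype, device=signal.device)
    return signal + noise * noise_power.sqrt()
```

For example, `y_noisy = add_gaussian_noise(y_sensor_data, snr_db=30)` could then be passed to the reconstruction in place of `y_sensor_data`.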

## 3. Perform Reconstruction
We will use Total Variation (TV) regularization with the Proximal Gradient algorithm.
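Assuming `tv_reconstruction_pat` follows the standard formulation (its internals are not shown in this notebook), the problem and the proximal gradient update can be written as below, with $A$ = `pat_operator.op`, $A^{*}$ = `pat_operator.op_adj`, $y$ = `y_sensor_data`, $\lambda$ = `lambda_tv_pat` and step size $\eta$ = `step_size_pat`. The isotropic TV shown is one common choice; the library may use an anisotropic variant.

$$
\hat{p}_0 = \arg\min_{p_0}\; \tfrac{1}{2}\,\lVert A p_0 - y \rVert_2^2 + \lambda\,\mathrm{TV}(p_0),
\qquad
\mathrm{TV}(p_0) = \sum_{i,j} \sqrt{(\nabla_x p_0)_{ij}^2 + (\nabla_y p_0)_{ij}^2}
$$

$$
p_0^{(k+1)} = \operatorname{prox}_{\eta\lambda\,\mathrm{TV}}\!\Big(p_0^{(k)} - \eta\, A^{*}\big(A p_0^{(k)} - y\big)\Big)
$$

The TV proximal operator has no closed form, which is presumably why it is solved iteratively here (`tv_prox_iters` inner iterations).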

In [None]:
lambda_tv_pat = 0.01      # TV regularization strength
iterations_pat = 20       # Number of proximal gradient iterations (kept low for the demo)
step_size_pat = 0.005     # Gradient step size; for a 0.5*||Ax - y||^2 data term, typically <= 1 / ||A||^2
tv_prox_iters = 5         # Inner iterations for the TV proximal sub-problem

# Perform reconstruction
# Note: This uses the placeholder adjoint model!
reconstructed_pressure = tv_reconstruction_pat(
    y_sensor_data=y_sensor_data,
    pat_operator=pat_operator,
    lambda_tv=lambda_tv_pat,
    iterations=iterations_pat,
    step_size=step_size_pat,
    tv_prox_iterations=tv_prox_iters,
    is_3d_tv=len(image_shape_pat)==3,
    verbose=True
)

print(f"Reconstructed pressure map shape: {reconstructed_pressure.shape}")
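Rather than guessing `step_size_pat`, the Lipschitz constant $\lVert A\rVert^2$ of the data-term gradient can be estimated by power iteration on $A^{*}A$, and the step set to about its inverse. A minimal sketch; `estimate_operator_norm_sq` is a hypothetical helper that only needs `op`/`op_adj` callables, and the estimate is only meaningful once `op_adj` is the true adjoint of `op` (not the case for the current placeholders).

```python
import torch

def estimate_operator_norm_sq(op, op_adj, x_shape, iters=50, device="cpu", seed=0):
    """Estimate ||A||^2, the largest eigenvalue of A^* A, by power iteration."""
    g = torch.Generator(device=device).manual_seed(seed)
    x = torch.randn(x_shape, generator=g, device=device)
    x /= x.norm()
    norm_sq = torch.tensor(0.0, device=device)
    for _ in range(iters):
        y = op_adj(op(x))   # apply A^* A
        norm_sq = y.norm()  # converges to the top eigenvalue since ||x|| = 1
        x = y / norm_sq
    return norm_sq.item()
```

With the notebook's objects this would be called as `L = estimate_operator_norm_sq(pat_operator.op, pat_operator.op_adj, image_shape_pat, device=device)`, followed by `step_size_pat = 1.0 / L`.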

## 4. Display Results

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

im1 = axes[0].imshow(true_initial_pressure.cpu().numpy(), cmap='viridis')
axes[0].set_title('Ground Truth Initial Pressure')
axes[0].set_xlabel('X (pixels)')
axes[0].set_ylabel('Y (pixels)')
fig.colorbar(im1, ax=axes[0], fraction=0.046, pad=0.04)

im2 = axes[1].imshow(reconstructed_pressure.cpu().numpy(), cmap='viridis')
axes[1].set_title(f'Reconstructed Pressure (TV, {iterations_pat} iters - Placeholders)')
axes[1].set_xlabel('X (pixels)')
axes[1].set_ylabel('Y (pixels)')
fig.colorbar(im2, ax=axes[1], fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()

# Using the utility plot function (currently a placeholder itself)
plot_pat_results(
    initial_pressure_map=true_initial_pressure,
    reconstructed_map=reconstructed_pressure,
    sensor_data=y_sensor_data,
    sensor_positions=sensor_positions_pat
)
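A number can complement the visual comparison. Because the placeholder operator is not calibrated, the reconstruction may differ from the phantom by a global scale, so this sketch fits the least-squares scale $\alpha$ before computing the relative $\ell_2$ error. `scaled_relative_error` is a hypothetical helper, not part of reconlib.

```python
import torch

def scaled_relative_error(reference: torch.Tensor, estimate: torch.Tensor) -> float:
    """Relative L2 error ||ref - alpha*est|| / ||ref|| after fitting a global scale alpha.

    alpha = <ref, est> / ||est||^2 is the least-squares optimal scale;
    the estimate must not be identically zero.
    """
    ref = reference.detach().cpu().flatten().double()
    est = estimate.detach().cpu().flatten().double()
    alpha = torch.dot(ref, est) / torch.dot(est, est)
    return ((ref - alpha * est).norm() / ref.norm()).item()
```

Usage: `scaled_relative_error(true_initial_pressure, reconstructed_pressure)`; values near 0 indicate agreement up to scale.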

## 5. Further Considerations (Beyond Placeholders)
To make this notebook fully functional for PAT:
1. **Implement `PhotoacousticOperator.op`**: Replace the placeholder with an accurate physical model of photoacoustic wave generation and propagation (e.g., using k-Wave, analytical solutions for simple geometries, or finite difference/element methods).
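2. **Implement `PhotoacousticOperator.op_adj`**: Implement the exact adjoint of the forward operator. It is a back-projection-type operator, closely related to (but not identical with) approximate inverses such as universal back-projection or time reversal. Verify it with the dot-product test: $\langle A x, y\rangle = \langle x, A^{*} y\rangle$ should hold to numerical precision for random $x$ and $y$.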
3. **Refine `tv_reconstruction_pat`**: Adjust parameters, potentially use more advanced regularizers (e.g., L1 wavelet), or explore different optimization algorithms if needed.
4. **Realistic Phantom and Parameters**: Use a more realistic phantom and sensor geometry. Ensure acoustic parameters (sound speed, time sampling, sensor locations) are consistent and physically meaningful.
5. **Noise Handling**: Add noise to the simulated sensor data (`y_sensor_data`) to test robustness.
6. **Update `plot_pat_results`**: Implement actual plotting in `utils.py`.
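Items 1 and 2 can be prototyped before a full acoustic model is available. The sketch below builds a deliberately crude 2D "time-of-flight binning" operator (each pixel value is added to the time bin matching its travel time to each sensor) together with its exact adjoint, and runs the dot-product test on it. It is not a physical PAT model: it ignores the time derivative of the spherical mean, $1/r$ spreading, and the sensor response. `make_tof_binning_operator` and `dot_product_test` are hypothetical names; only `op`/`op_adj` callables are assumed.

```python
import math
import torch

def make_tof_binning_operator(image_shape, pixel_size_m, sensor_xy_m, sound_speed, dt, num_t):
    """Toy linear operator (image -> sensors x time) and its exact adjoint."""
    ny, nx = image_shape
    ys = (torch.arange(ny, dtype=torch.float64) - (ny - 1) / 2) * pixel_size_m
    xs = (torch.arange(nx, dtype=torch.float64) - (nx - 1) / 2) * pixel_size_m
    yy, xx = torch.meshgrid(ys, xs, indexing="ij")
    # Sensor-to-pixel distances, shape (num_sensors, ny * nx)
    dist = torch.hypot(sensor_xy_m[:, :1] - xx.reshape(1, -1),
                       sensor_xy_m[:, 1:] - yy.reshape(1, -1))
    # Time-bin index per (sensor, pixel); late arrivals are clamped to the last bin
    t_idx = torch.round(dist / (sound_speed * dt)).long().clamp(0, num_t - 1)
    num_s = sensor_xy_m.shape[0]
    # Flat index into a (num_sensors * num_t) vector, sensor-major
    flat_idx = (torch.arange(num_s).unsqueeze(1) * num_t + t_idx).reshape(-1)

    def op(x):
        # Scatter-add every pixel value into its time bin, once per sensor
        vals = x.reshape(1, -1).expand(num_s, -1).reshape(-1)
        y = torch.zeros(num_s * num_t, dtype=x.dtype)
        return y.index_add_(0, flat_idx, vals).reshape(num_s, num_t)

    def op_adj(y):
        # Gather the matching bin for each (sensor, pixel) pair and sum over sensors
        return y.reshape(-1)[flat_idx].reshape(num_s, -1).sum(dim=0).reshape(ny, nx)

    return op, op_adj

def dot_product_test(op, op_adj, x_shape, y_shape, dtype=torch.float64, device="cpu", seed=0):
    """Relative mismatch |<A x, y> - <x, A* y>| / max(|<A x, y>|, |<x, A* y>|)."""
    g = torch.Generator().manual_seed(seed)
    x = torch.randn(x_shape, generator=g, dtype=dtype).to(device)
    y = torch.randn(y_shape, generator=g, dtype=dtype).to(device)
    lhs = torch.vdot(op(x).flatten(), y.flatten())
    rhs = torch.vdot(x.flatten(), op_adj(y).flatten())
    return ((lhs - rhs).abs() / torch.maximum(lhs.abs(), rhs.abs())).item()

# 16 sensors on a 70 mm ring around a 32 x 32 image of 1 mm pixels, 0.1 us sampling
angles = 2 * math.pi * torch.arange(16, dtype=torch.float64) / 16
sensors = 0.07 * torch.stack([torch.cos(angles), torch.sin(angles)], dim=1)
op, op_adj = make_tof_binning_operator((32, 32), 1e-3, sensors, 1500.0, 1e-7, 1024)
rel = dot_product_test(op, op_adj, (32, 32), (16, 1024))
print("relative dot-product mismatch:", rel)
```

The same check applies to the real operator, e.g. `dot_product_test(pat_operator.op, pat_operator.op_adj, image_shape_pat, tuple(y_sensor_data.shape), dtype=true_initial_pressure.dtype, device=device)`; a mismatch near the working precision indicates a consistent adjoint. Complex-valued sensor data would need complex random inputs.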