# Enhanced Ultrasound Reconstruction Example
This notebook demonstrates B-mode-like ultrasound image reconstruction using a more realistic forward model (finite pulse bandwidth, beam spread, and frequency-dependent attenuation) and Total Variation (TV) regularization.

## 1. Imports and Setup

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import sys
import os

# Ensure reconlib is in the Python path (adjust if notebook is moved)
if 'reconlib' not in os.getcwd():
    # Assuming notebook is in reconlib/modalities/ultrasound/
    sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname('__file__'), '../../..')))
else:
    # If notebook is in root/examples or similar, and reconlib is a sibling dir
    sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname('__file__'), '..')))

from reconlib.modalities.ultrasound.operators import UltrasoundForwardOperator
from reconlib.modalities.ultrasound.regularizers import UltrasoundTVCustomRegularizer
from reconlib.reconstructors.proximal_gradient_reconstructor import ProximalGradientReconstructor

%matplotlib inline

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 2. Define Simulation Parameters
Parameters are chosen to reflect a typical linear-array, soft-tissue scenario: a 5 MHz pulse, a speed of sound of 1540 m/s, and a 64-element array with 0.3 mm pitch imaging a 128 × 128 grid of 0.2 mm pixels.

In [None]:
# Image and Grid Parameters
N = 128  # Image size (N x N pixels)
pixel_size_m = 0.0002  # Pixel size (0.2 mm)
image_shape = (N, N)
image_spacing_m = (pixel_size_m, pixel_size_m)
img_depth_m = N * pixel_size_m
img_width_m = N * pixel_size_m

# Ultrasound Physics Parameters
sound_speed_mps = 1540.0  # Speed of sound in m/s
center_frequency_hz = 5e6  # Center frequency (5 MHz)
pulse_bandwidth_fractional = 0.6  # Fractional bandwidth (e.g., 60% of center_freq)
sampling_rate_hz = 4 * center_frequency_hz  # Sampling frequency (e.g., 20 MHz for a 5 MHz pulse)

# Sanity check: for fractional bandwidth B the highest significant pulse
# frequency is roughly f_c * (1 + B/2); Nyquist requires fs > 2 * f_max.
max_pulse_frequency_hz = center_frequency_hz * (1 + pulse_bandwidth_fractional / 2)
assert sampling_rate_hz > 2 * max_pulse_frequency_hz, (
    f"sampling_rate_hz={sampling_rate_hz:.3g} Hz is below Nyquist for "
    f"f_max~{max_pulse_frequency_hz:.3g} Hz"
)

# Transducer Parameters
num_elements = 64
element_pitch_m = 0.0003  # Element pitch (0.3 mm)
# Aperture spanned by the element centres.
array_width_m = (num_elements - 1) * element_pitch_m
element_x_coords = torch.linspace(-array_width_m / 2, array_width_m / 2, num_elements, device=device)
# Position elements slightly above the image region (e.g., at y = -2mm relative to image top)
element_y_pos_m = -0.002
element_positions = torch.stack(
    (element_x_coords, torch.full_like(element_x_coords, element_y_pos_m)), dim=1
)
beam_sigma_rad = 0.05  # Beam width (radians) - adjust for desired focus/spread

# Attenuation
attenuation_coeff_db_cm_mhz = 0.5  # Typical for soft tissue

# RF Data Simulation Parameters
# Number of time samples for RF data: needs to cover round trip to max depth.
# NOTE(review): this window ignores the 2 mm element standoff and the lateral
# element-to-pixel offset, so echoes from deep, laterally distant pixels may
# fall outside it depending on the operator's time-of-flight model — TODO confirm.
max_time_s = 2 * img_depth_m / sound_speed_mps * 1.2  # Add 20% margin
num_samples_rf = int(np.ceil(max_time_s * sampling_rate_hz))
print(f"Calculated num_samples_rf: {num_samples_rf}")

# Reconstruction Parameters
lambda_tv_overall = 0.005  # Regularization strength for TV
pg_iterations = 50
pg_step_size = 0.01  # Initial step size for ProximalGradientReconstructor

## 3. Create Phantom Image

In [None]:
def generate_ultrasound_phantom(shape, spacing_m, device='cpu'):
    """Build a simple 2-D reflectivity phantom for the ultrasound demo.

    On a zero background the phantom contains:
      * a large central disc of reflectivity 1.0,
      * a smaller off-centre disc of reflectivity 0.3 (weaker than the central
        inclusion; it is painted second, so it wins where the two overlap),
      * two single-pixel point scatterers with values 1.5 and 1.2.

    Args:
        shape: ``(h, w)`` image size in pixels.
        spacing_m: ``(dy, dx)`` pixel size in metres along rows / columns.
        device: torch device on which the phantom is allocated.

    Returns:
        A complex64 tensor of shape ``(h, w)``. The reflectivity is real-valued
        but stored as complex because the forward operator in this notebook is
        fed complex images.
    """
    h, w = shape
    dy, dx = spacing_m
    h_m, w_m = h * dy, w * dx

    # Physical coordinates of the pixel *centres*, consistent with spacing_m.
    # (torch.linspace(0, h_m, h) would use a step of h_m / (h - 1) and put the
    # last sample at h_m, outside the image.)
    y_m = (torch.arange(h, device=device, dtype=torch.float32) + 0.5) * dy
    x_m = (torch.arange(w, device=device, dtype=torch.float32) + 0.5) * dx
    Y, X = torch.meshgrid(y_m, x_m, indexing='ij')

    phantom = torch.zeros(shape, dtype=torch.float32, device=device)

    # Central circular inclusion
    center_y, center_x = h_m / 2, w_m / 2
    radius1 = min(h_m, w_m) / 4
    phantom[(X - center_x)**2 + (Y - center_y)**2 < radius1**2] = 1.0

    # Smaller, off-centre inclusion with lower reflectivity than the central one
    center_y2, center_x2 = h_m * 0.25, w_m * 0.75
    radius2 = min(h_m, w_m) / 8
    phantom[(X - center_x2)**2 + (Y - center_y2)**2 < radius2**2] = 0.3

    # Point scatterers (written last so they are never overwritten)
    phantom[int(h * 0.7), int(w * 0.3)] = 1.5
    phantom[int(h * 0.6), int(w * 0.6)] = 1.2

    return phantom.to(torch.complex64)  # Operator expects complex

phantom_image = generate_ultrasound_phantom(image_shape, image_spacing_m, device=device)

# Display the magnitude of the (complex-typed) phantom using the explicit Axes API.
fig, ax = plt.subplots(figsize=(6, 6))
phantom_im = ax.imshow(phantom_image.abs().cpu().numpy(), cmap='gray')
ax.set_title('Original Phantom (Magnitude)')
fig.colorbar(phantom_im, ax=ax)
plt.show()

## 4. Initialize Ultrasound Forward Operator

In [None]:
# Collect the operator configuration in one place so it is easy to inspect or tweak.
us_operator_config = dict(
    image_shape=image_shape,
    sound_speed=sound_speed_mps,
    num_elements=num_elements,
    element_positions=element_positions,
    sampling_rate=sampling_rate_hz,
    num_samples=num_samples_rf,
    image_spacing=image_spacing_m,
    center_frequency=center_frequency_hz,
    pulse_bandwidth_fractional=pulse_bandwidth_fractional,
    beam_sigma_rad=beam_sigma_rad,
    attenuation_coeff_db_cm_mhz=attenuation_coeff_db_cm_mhz,
    device=device,
)
us_operator = UltrasoundForwardOperator(**us_operator_config)
print("UltrasoundForwardOperator initialized.")

## 5. Simulate RF Data (Forward Projection)

In [None]:
print("Simulating RF data... This might take a moment.")
rf_data_clean = us_operator.op(phantom_image)
print(f"Simulated clean RF data shape: {rf_data_clean.shape}")

# Add complex white Gaussian noise at a fixed noise-to-signal power ratio.
# A dedicated, seeded generator makes the noise reproducible: re-running this
# cell (or Restart & Run All) gives identical noisy data and reconstructions.
noise_seed = 0
noise_power_ratio = 0.05  # e.g., 5% noise relative to signal power
signal_power = torch.mean(torch.abs(rf_data_clean)**2)
noise_std = torch.sqrt(signal_power * noise_power_ratio / 2)  # Factor of 2 for complex noise (real+imag)

noise_generator = torch.Generator(device=rf_data_clean.device)
noise_generator.manual_seed(noise_seed)
# `.real.dtype` is the matching real dtype (e.g. float32 for complex64 data).
noise_kwargs = dict(generator=noise_generator, dtype=rf_data_clean.real.dtype, device=rf_data_clean.device)
noise_real = torch.randn(rf_data_clean.shape, **noise_kwargs)
noise_imag = torch.randn(rf_data_clean.shape, **noise_kwargs)
noise = noise_std * (noise_real + 1j * noise_imag)
rf_data_noisy = rf_data_clean + noise
print(f"Added complex Gaussian noise. Noise STD: {noise_std.item()}")

# Side-by-side magnitude plots of the clean and noisy RF data.
# Axis labels assume RF data laid out as (element, time sample) — TODO confirm
# against UltrasoundForwardOperator.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
rf_panels = (
    (rf_data_clean, 'Clean RF Data (Magnitude)'),
    (rf_data_noisy, 'Noisy RF Data (Magnitude)'),
)
for ax, (rf_data, title) in zip(axes, rf_panels):
    rf_im = ax.imshow(torch.abs(rf_data).cpu().numpy(), aspect='auto', cmap='viridis')
    ax.set_title(title)
    ax.set_xlabel('Time Samples')
    ax.set_ylabel('Transducer Element')
    fig.colorbar(rf_im, ax=ax)
fig.tight_layout()
plt.show()

## 6. Image Reconstruction

### 6.1 Adjoint Reconstruction (Delay-and-Sum-like)

In [None]:
print("Performing Adjoint (DAS-like) reconstruction...")
adjoint_recon = us_operator.op_adj(rf_data_noisy)
print(f"Adjoint reconstructed image shape: {adjoint_recon.shape}")

# Show the magnitude of the back-projected image.
adjoint_magnitude = adjoint_recon.abs().cpu().numpy()
fig, ax = plt.subplots(figsize=(6, 6))
adjoint_im = ax.imshow(adjoint_magnitude, cmap='gray')
ax.set_title('Adjoint (DAS-like) Reconstruction')
fig.colorbar(adjoint_im, ax=ax)
plt.show()

### 6.2 TV Regularized Reconstruction using Proximal Gradient

In [None]:
print(f"Performing TV Regularized Reconstruction (lambda_TV={lambda_tv_overall})...This may take time.")
tv_regularizer = UltrasoundTVCustomRegularizer(
    lambda_reg=lambda_tv_overall,
    prox_iterations=10,   # iterations of the inner TV proximal solver
    is_3d=False,
    prox_step_size=0.01,  # gradient step used inside that inner solver
)


def log_pg_progress(iter_num, current_image, change, grad_norm):
    """Print a progress line on every 10th iteration and on the final one."""
    on_checkpoint = iter_num % 10 == 0
    on_last_iter = iter_num == pg_iterations - 1
    if on_checkpoint or on_last_iter:
        print(f"Iter {iter_num+1}: Change={change:.2e}, GradNorm={grad_norm:.2e}")
    return None


# Adapters to the reconstructor's callback signatures. The ultrasound model has
# no coil sensitivities, so `smaps` is accepted and ignored.
def us_forward_fn(img, smaps):
    return us_operator.op(img)


def us_adjoint_fn(data, smaps):
    return us_operator.op_adj(data)


def tv_prox_fn(img, step):
    return tv_regularizer.proximal_operator(img, step)


pg_reconstructor_tv = ProximalGradientReconstructor(
    iterations=pg_iterations,
    step_size=pg_step_size,
    verbose=True,
    log_fn=log_pg_progress,
)

# Warm-start from the adjoint image (alternative: torch.zeros_like(phantom_image)).
initial_estimate_tv = adjoint_recon.clone()

tv_recon_image = pg_reconstructor_tv.reconstruct(
    kspace_data=rf_data_noisy,  # the measured RF data plays the role of 'y'
    forward_op_fn=us_forward_fn,
    adjoint_op_fn=us_adjoint_fn,
    regularizer_prox_fn=tv_prox_fn,
    sensitivity_maps=None,
    x_init=initial_estimate_tv,
)
print(f"TV Reconstructed image shape: {tv_recon_image.shape}")

fig, ax = plt.subplots(figsize=(6, 6))
tv_im = ax.imshow(torch.abs(tv_recon_image).cpu().numpy(), cmap='gray')
ax.set_title(f'TV Regularized Reconstruction (lambda={lambda_tv_overall}, {pg_iterations} iters)')
fig.colorbar(tv_im, ax=ax)
plt.show()

## 7. Comparison of Results

In [None]:
# One panel per image: ground truth, adjoint, and TV-regularized reconstruction.
comparison_panels = [
    (phantom_image, 'Original Phantom'),
    (adjoint_recon, 'Adjoint (DAS-like) Recon'),
    (tv_recon_image, f'TV Recon (lambda={lambda_tv_overall})'),
]
fig, axes = plt.subplots(1, len(comparison_panels), figsize=(18, 6))
for ax, (panel_image, panel_title) in zip(axes, comparison_panels):
    ax.imshow(torch.abs(panel_image).cpu().numpy(), cmap='gray')
    ax.set_title(panel_title)
    ax.axis('off')

plt.tight_layout()
plt.show()