# Synthetic Aperture Radar (SAR) Reconstruction Example
This notebook demonstrates basic SAR image reconstruction from simulated k-space (visibility) data using a Fourier-based forward model with physically derived k-space sampling and Total Variation (TV) regularized reconstruction.

## 1. Imports and Setup

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import sys
import os

# Make the reconlib package importable when running from inside the repo.
# A notebook has no usable __file__, so the search path is resolved relative
# to the current working directory: three levels up when the working
# directory path does not mention 'reconlib', one level up otherwise.
_levels_up = '../../..' if 'reconlib' not in os.getcwd() else '..'
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), _levels_up)))

from reconlib.modalities.sar.operators import SARForwardOperator
from reconlib.modalities.sar.reconstructors import tv_reconstruction_sar
# tv_reconstruction_sar uses UltrasoundTVCustomRegularizer by default

%matplotlib inline

# Prefer a CUDA GPU when PyTorch reports one as available; otherwise use the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(f"Using device: {device}")

## 2. Define SAR Parameters
We will define image parameters, physical parameters for the SAR acquisition (which determine k-space sampling), and reconstruction algorithm parameters.

In [None]:
# Image Parameters
Ny, Nx = 128, 128  # Image height (Ny) and width (Nx), in pixels
image_shape_sar = (Ny, Nx)

# SAR Specific Physical Parameters (for uv-coordinate generation)
num_angles = 256  # Number of azimuth angles for k-space sampling
# Sample the full circle on the half-open interval [0, 2*pi).
# torch.linspace includes both endpoints, and 0 and 2*pi are the same azimuth,
# so one extra point is generated and the last one dropped. This keeps exactly
# num_angles uniformly spaced angles without a duplicated look direction.
sensor_azimuth_angles = torch.linspace(0, 2 * np.pi, num_angles + 1, device=device)[:-1]
wavelength = 0.03  # Example: 0.03m (X-band, 10 GHz)
# FOV in meters (Ny corresponds to fov_y, Nx to fov_x)
fov = (20.0, 20.0)  # e.g., 20m x 20m scene

print(f"Defined physical parameters for {num_angles} k-space samples.")

# Reconstruction Parameters (forwarded to tv_reconstruction_sar in section 6.2)
lambda_tv_sar = 0.001     # passed as lambda_tv (TV regularization weight)
sar_pg_iterations = 50    # passed as iterations
sar_pg_step_size = 0.05   # passed as step_size
sar_tv_prox_iters = 5     # passed as tv_prox_iterations
sar_tv_prox_step = 0.01   # passed as tv_prox_step_size

## 3. Create SAR Target Reflectivity Phantom

In [None]:
def generate_sar_phantom(shape, device='cpu'):
    """Create a synthetic complex-valued SAR reflectivity scene.

    The scene is zero everywhere except for three single-pixel point targets
    and one rectangular region of unit reflectivity. All positions are
    expressed as fractions of the image size, so the layout scales with
    ``shape``.

    Args:
        shape: Two-element ``(Ny, Nx)`` image size in pixels.
        device: Torch device (or device string) on which to allocate the scene.

    Returns:
        A ``torch.complex64`` tensor of size ``shape``.
    """
    scene = torch.zeros(shape, dtype=torch.complex64, device=device)
    n_rows, n_cols = shape

    # Isolated point scatterers: (row, column, complex reflectivity).
    point_targets = (
        (n_rows // 4, n_cols // 4, 2.0 + 1j),
        (n_rows // 2, n_cols // 2 + n_cols // 8, 1.5 - 0.5j),
        (int(n_rows * 0.7), int(n_cols * 0.6), 2.5),
    )
    for row, col, reflectivity in point_targets:
        scene[row, col] = reflectivity

    # Extended rectangular scatterer, written after the points so it takes
    # precedence wherever the two would overlap.
    row_span = slice(int(n_rows * 0.2), int(n_rows * 0.3))
    col_span = slice(int(n_cols * 0.6), int(n_cols * 0.8))
    scene[row_span, col_span] = 1.0
    return scene

# Build the ground-truth reflectivity scene on the selected device.
sar_phantom = generate_sar_phantom(image_shape_sar, device=device)

# Show the magnitude of the complex scene; origin='lower' puts row 0 at the bottom.
fig, ax = plt.subplots(figsize=(6, 6))
phantom_magnitude = torch.abs(sar_phantom).cpu().numpy()
phantom_im = ax.imshow(phantom_magnitude, cmap='gray', origin='lower')
ax.set_title('Original SAR Target Phantom (Reflectivity Magnitude)')
ax.set_xlabel('X-pixels (Range/Azimuth)')
ax.set_ylabel('Y-pixels (Range/Azimuth)')
fig.colorbar(phantom_im, ax=ax, label='Reflectivity')
plt.show()

## 4. Initialize SAR Forward Operator
We initialize the `SARForwardOperator` using the defined image and physical parameters. We'll explicitly use the FFT-based mode.

In [None]:
# Assemble the forward-model configuration from the image grid and the
# acquisition geometry defined above. use_nufft=False is passed explicitly
# to request the FFT-based operator.
sar_operator_kwargs = dict(
    image_shape=image_shape_sar,
    wavelength=wavelength,
    sensor_azimuth_angles=sensor_azimuth_angles,
    fov=fov,
    device=device,
    use_nufft=False,
)
sar_operator_inst = SARForwardOperator(**sar_operator_kwargs)
print("SARForwardOperator initialized (using FFT-based mode).")

## 5. Simulate SAR Raw Data (Visibilities)

In [None]:
print("Simulating SAR raw data (visibilities)...")
# Apply the forward operator to the phantom to obtain noise-free k-space data.
visibilities_clean = sar_operator_inst.op(sar_phantom)
print(f"Simulated clean visibilities shape: {visibilities_clean.shape}")

# Seed PyTorch's random number generator so the noise realization (and hence
# every downstream reconstruction) is identical on each Restart & Run All.
noise_seed_sar = 0
torch.manual_seed(noise_seed_sar)

# Add complex Gaussian noise
signal_power_sar = torch.mean(torch.abs(visibilities_clean)**2)
noise_power_ratio_sar = 0.05 # 5% noise relative to signal power
# The noise power is split evenly between the real and imaginary parts, so each
# part gets half of the target variance (hence the division by 2).
noise_std_sar = torch.sqrt(signal_power_sar * noise_power_ratio_sar / 2) # Factor of 2 for complex
noise_sar = noise_std_sar * (torch.randn_like(visibilities_clean.real) + 1j * torch.randn_like(visibilities_clean.imag))
visibilities_noisy = visibilities_clean + noise_sar
print(f"Added complex Gaussian noise. Noise STD: {noise_std_sar.item()}")

# Optional diagnostic: scatter the sampled k-space locations, coloured by the
# magnitude of the noisy visibilities. Columns 0 and 1 of raw_uv_coordinates
# are plotted as u and v (assumes an (N, 2) layout -- TODO confirm against
# SARForwardOperator). The colour scale is capped at the 95th percentile so a
# few very strong samples do not wash out the rest.
fig, ax = plt.subplots(figsize=(6, 6))
uv_cpu = sar_operator_inst.raw_uv_coordinates.cpu().numpy()
vis_magnitude = torch.abs(visibilities_noisy)
magnitude_cap = torch.quantile(vis_magnitude, 0.95).cpu()
samples = ax.scatter(
    uv_cpu[:, 0],
    uv_cpu[:, 1],
    c=vis_magnitude.cpu().numpy(),
    cmap='viridis',
    s=5,
    vmax=magnitude_cap,
)
ax.set_title('Noisy k-space Samples (Visibilities Magnitude)')
ax.set_xlabel('u (FFT-scaled kx-related)')
ax.set_ylabel('v (FFT-scaled ky-related)')
fig.colorbar(samples, ax=ax)
# Equal-length axes; the limits themselves are left to matplotlib's auto-scaling.
ax.axis('square')
plt.show()

## 6. Image Reconstruction

### 6.1 Adjoint Reconstruction ('Dirty Image')

In [None]:
print("Performing Adjoint (Dirty Image) reconstruction...")
# Apply the adjoint of the forward operator to the noisy data.
sar_dirty_image = sar_operator_inst.op_adj(visibilities_noisy)
print(f"Dirty image shape: {sar_dirty_image.shape}")

# Display the magnitude of the adjoint reconstruction.
fig, ax = plt.subplots(figsize=(6, 6))
dirty_magnitude = torch.abs(sar_dirty_image).cpu().numpy()
dirty_im = ax.imshow(dirty_magnitude, cmap='gray', origin='lower')
ax.set_title('SAR Adjoint Recon (Dirty Image)')
ax.set_xlabel('X-pixels')
ax.set_ylabel('Y-pixels')
fig.colorbar(dirty_im, ax=ax, label='Magnitude')
plt.show()

### 6.2 TV Regularized Reconstruction

In [None]:
print(f"Performing TV Regularized SAR Reconstruction (lambda_TV={lambda_tv_sar})...This may take some time.")

# Solver hyperparameters come from the parameter cell; the data and the
# forward operator are passed alongside them.
tv_solver_settings = dict(
    lambda_tv=lambda_tv_sar,
    iterations=sar_pg_iterations,
    step_size=sar_pg_step_size,
    tv_prox_iterations=sar_tv_prox_iters,
    tv_prox_step_size=sar_tv_prox_step,
    verbose=True,
)
sar_tv_recon_image = tv_reconstruction_sar(
    y_sar_data=visibilities_noisy,
    sar_operator=sar_operator_inst,
    **tv_solver_settings,
)
print(f"TV Reconstructed SAR image shape: {sar_tv_recon_image.shape}")

# Display the magnitude of the TV-regularized reconstruction.
fig, ax = plt.subplots(figsize=(6, 6))
tv_magnitude = torch.abs(sar_tv_recon_image).cpu().numpy()
tv_im = ax.imshow(tv_magnitude, cmap='gray', origin='lower')
ax.set_title(f'TV Regularized SAR Recon (lambda={lambda_tv_sar}, {sar_pg_iterations} iters)')
ax.set_xlabel('X-pixels')
ax.set_ylabel('Y-pixels')
fig.colorbar(tv_im, ax=ax, label='Magnitude')
plt.show()

## 7. Comparison of Results

In [None]:
# Side-by-side magnitude view of the ground truth and both reconstructions.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
fig.suptitle('SAR Reconstruction Comparison', fontsize=16)

# One (title, complex image) entry per panel, in left-to-right order.
comparison_panels = (
    ('Original SAR Phantom', sar_phantom),
    ('Adjoint (Dirty Image) Recon', sar_dirty_image),
    (f'TV Regularized SAR Recon (lambda={lambda_tv_sar})', sar_tv_recon_image),
)

for ax, (panel_title, complex_image) in zip(axes, comparison_panels):
    panel_im = ax.imshow(torch.abs(complex_image).cpu().numpy(), cmap='gray', origin='lower')
    ax.set_title(panel_title)
    ax.set_xlabel('X-pixels')
    ax.set_ylabel('Y-pixels')
    # Each panel gets its own colorbar, so the three colour scales are independent.
    fig.colorbar(panel_im, ax=ax, shrink=0.8)

# Reserve the top 4% of the figure for the suptitle.
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()