# Astronomical Interferometry Reconstruction Example
This notebook demonstrates basic image reconstruction for astronomical interferometry (e.g., radio astronomy). It simulates noisy visibility data from a synthetic sky brightness model, then reconstructs the sky image in two ways: by applying the adjoint of the measurement operator (which yields the "dirty image") and by a total-variation (TV) regularized reconstruction.

## 1. Imports and Setup

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import sys
import os

# Make the reconlib package importable when the notebook is run from inside the repository.
# os.path.dirname('__file__') receives a string literal (not the __file__ variable), so it
# evaluates to '' and the relative path below is resolved against the current working directory:
# one level up when the working directory path contains 'reconlib', three levels up otherwise.
sys.path.insert(
    0,
    os.path.abspath(
        os.path.join(
            os.path.dirname('__file__'),
            '..' if 'reconlib' in os.getcwd() else '../../..',
        )
    ),
)

from reconlib.modalities.astronomical.operators import AstronomicalInterferometryOperator
from reconlib.modalities.astronomical.reconstructors import tv_reconstruction_astro
# tv_reconstruction_astro uses UltrasoundTVCustomRegularizer by default

%matplotlib inline

# Run on the GPU when PyTorch reports CUDA as available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

## 2. Define Astronomical Observation Parameters

In [None]:
# Reproducibility: seed PyTorch's global random number generator so that the random
# (u,v) samples drawn in this cell are the same on every Restart & Run All.
# Any later draw from the global generator (torch.rand / torch.randn) is affected too.
torch.manual_seed(0)

# Sky Image Parameters
Ny, Nx = 128, 128  # Image height (pixels, e.g., Dec) and width (pixels, e.g., RA)
image_shape_astro = (Ny, Nx)

# (u,v) Coverage / Baselines
num_visibilities = 3000  # Number of visibility samples

# Simulate some (u,v) coordinates. In reality, these come from telescope array configuration.
# For this example, we generate random points in a circular region of k-space.
# Radius and angle are each drawn uniformly, so the points are denser near the origin
# (short baselines) than they would be under uniform-in-area sampling of the disc.
max_uv = min(Nx, Ny) / 2.5 # Max spatial frequency to sample (controls resolution)
uv_radius = torch.rand(num_visibilities, device=device) * max_uv
uv_angle = torch.rand(num_visibilities, device=device) * 2 * np.pi
u_coords = uv_radius * torch.cos(uv_angle)
v_coords = uv_radius * torch.sin(uv_angle)
# Stack into a (num_visibilities, 2) tensor whose columns are (u, v).
uv_coordinates = torch.stack((u_coords, v_coords), dim=1)

print(f"Generated {uv_coordinates.shape[0]} (u,v) sample points.")
# Scatter plot of the sampled (u,v) points. Tensors are copied to the CPU and converted
# to NumPy arrays because matplotlib cannot consume torch tensors that live on a GPU.
fig_uv, ax_uv = plt.subplots(figsize=(5, 5))
ax_uv.scatter(u_coords.cpu().numpy(), v_coords.cpu().numpy(), s=1, alpha=0.5)
ax_uv.set_title('Simulated (u,v) Coverage')
ax_uv.set_xlabel('u (spatial frequency)')
ax_uv.set_ylabel('v (spatial frequency)')
ax_uv.axis('square')  # equal scaling on both axes so the sampled disc is not distorted
plt.show()

# Reconstruction Parameters

# Weight of the total-variation penalty (needs careful tuning!)
lambda_tv_astro = 0.0005

# Outer proximal-gradient loop: number of iterations and step size
astro_pg_iterations = 75
astro_pg_step_size = 0.02

# Inner loop of the custom TV proximal operator: number of iterations and step size
astro_tv_prox_iters = 5
astro_tv_prox_step = 0.01

## 3. Create Sky Brightness Phantom

In [None]:
def generate_sky_phantom(shape, device='cpu'):
    """Build a synthetic sky-brightness image: one extended source plus three point sources.

    Args:
        shape: (Ny, Nx) image size in pixels.
        device: torch device (or device string) on which the image is created.

    Returns:
        A complex64 tensor of the given shape. The brightness is stored in the
        real part; the imaginary part is zero.

    Note:
        The point sources use fixed pixel offsets from the image centre, so the
        function is meant for reasonably large images: a height below 21 raises
        IndexError, and a width below 40 makes one column index negative, which
        silently wraps around to the other side of the image.
    """
    h, w = shape
    phantom = torch.zeros(shape, dtype=torch.float32, device=device)

    # Extended source (galaxy/nebula): a filled ellipse of uniform brightness.
    # It is painted FIRST so that it cannot erase a point source lying inside it;
    # for a 128x128 image the star at (row 74, col 44) falls within this ellipse.
    center_y, center_x = h // 1.5, w // 3
    # Clamp the radii to at least one pixel so tiny images do not divide by zero.
    radius_y, radius_x = max(h // 10, 1), max(w // 12, 1)
    Y, X = torch.meshgrid(torch.arange(h, device=device),
                          torch.arange(w, device=device), indexing='ij')
    mask_ellipse = ((X - center_x) / radius_x) ** 2 + ((Y - center_y) / radius_y) ** 2 < 1
    phantom[mask_ellipse] = 0.8

    # Point sources (stars), written last so that each keeps its own brightness.
    phantom[h // 4, w // 4] = 2.0
    phantom[h // 2 + 10, w // 2 - 20] = 1.5
    phantom[int(h * 0.7), int(w * 0.65)] = 1.8

    return phantom.to(torch.complex64)  # Operator expects complex

# Ground-truth sky image, created directly on the selected device.
sky_phantom = generate_sky_phantom(image_shape_astro, device=device)

# Display the magnitude of the complex image; origin='lower' places row 0 at the bottom.
fig_sky, ax_sky = plt.subplots(figsize=(6, 6))
im_sky = ax_sky.imshow(sky_phantom.abs().cpu().numpy(), cmap='hot', origin='lower')
ax_sky.set_title('Original Sky Phantom (Brightness)')
ax_sky.set_xlabel('RA direction (pixels)')
ax_sky.set_ylabel('Dec direction (pixels)')
fig_sky.colorbar(im_sky, ax=ax_sky, label='Brightness')
plt.show()

## 4. Initialize Astronomical Interferometry Operator

In [None]:
# Measurement operator built from the image grid size and the sampled (u,v) points.
# Its op / op_adj methods are used in the following sections to simulate visibilities
# and to compute the adjoint ("dirty image") reconstruction.
astro_operator_inst = AstronomicalInterferometryOperator(
    uv_coordinates=uv_coordinates, image_shape=image_shape_astro, device=device
)
print("AstronomicalInterferometryOperator initialized.")

## 5. Simulate Visibility Data

In [None]:
print("Simulating visibility data...")
# Forward model: apply the measurement operator to the ground-truth sky.
visibilities_clean = astro_operator_inst.op(sky_phantom)
print(f"Simulated clean visibilities shape: {visibilities_clean.shape}")

# Add complex Gaussian noise (thermal noise in receivers), scaled to a fixed
# fraction of the mean signal power.
signal_power_astro = visibilities_clean.abs().pow(2).mean()
noise_power_ratio_astro = 0.1  # noise power is 10% of the signal power
# The noise power is shared equally by the real and imaginary parts, hence the / 2.
noise_std_astro = (signal_power_astro * noise_power_ratio_astro / 2).sqrt()
noise_astro = noise_std_astro * torch.complex(
    torch.randn_like(visibilities_clean.real),  # real part is drawn first
    torch.randn_like(visibilities_clean.imag),  # imaginary part is drawn second
)
visibilities_noisy = visibilities_clean + noise_astro
print(f"Added complex Gaussian noise to visibilities. Noise STD: {noise_std_astro.item()}")

## 6. Sky Image Reconstruction

### 6.1 Adjoint Reconstruction ('Dirty Image')

In [None]:
print("Performing Adjoint ('Dirty Image') reconstruction...")
# Apply the adjoint of the measurement operator to the noisy visibilities.
astro_dirty_image = astro_operator_inst.op_adj(visibilities_noisy)
print(f"Dirty image shape: {astro_dirty_image.shape}")

# Magnitude of the dirty image, passed through fftshift (which swaps the half-spaces
# along every axis) before display.
# NOTE(review): the original phantom is plotted WITHOUT fftshift. Confirm the origin
# convention of op_adj's output so that source positions here line up with the phantom.
dirty_image_display = torch.fft.fftshift(astro_dirty_image.abs()).cpu().numpy()

fig_dirty, ax_dirty = plt.subplots(figsize=(7, 6))
im_dirty = ax_dirty.imshow(
    dirty_image_display,
    cmap='hot',
    origin='lower',
    vmax=np.percentile(dirty_image_display, 99.5),  # clip extreme values for visualization
)
ax_dirty.set_title('Astronomical Adjoint Recon (Dirty Image Magnitude)')
ax_dirty.set_xlabel('RA direction (pixels)')
ax_dirty.set_ylabel('Dec direction (pixels)')
fig_dirty.colorbar(im_dirty, ax=ax_dirty, label='Amplitude')
plt.show()

### 6.2 TV Regularized Reconstruction

In [None]:
print(f"Performing TV Regularized Astronomical Reconstruction (lambda_TV={lambda_tv_astro})...This may take some time.")

# Run the TV-regularized reconstruction with the hyperparameters chosen in Section 2.
# All arguments are passed by keyword, so their order here is irrelevant.
astro_tv_recon_image = tv_reconstruction_astro(
    astro_operator=astro_operator_inst,
    y_visibilities=visibilities_noisy,
    # regularization weight
    lambda_tv=lambda_tv_astro,
    # outer proximal-gradient loop
    step_size=astro_pg_step_size,
    iterations=astro_pg_iterations,
    # inner TV proximal step
    tv_prox_step_size=astro_tv_prox_step,
    tv_prox_iterations=astro_tv_prox_iters,
    verbose=True,
)
print(f"TV Reconstructed Astronomical Image shape: {astro_tv_recon_image.shape}")

# Magnitude of the TV reconstruction, fftshift-ed for display in the same way as the dirty image.
# NOTE(review): the original phantom is plotted WITHOUT fftshift. Confirm that this
# shift matches the image origin convention of the reconstruction.
tv_recon_display = torch.fft.fftshift(astro_tv_recon_image.abs()).cpu().numpy()

fig_tv, ax_tv = plt.subplots(figsize=(7, 6))
im_tv = ax_tv.imshow(
    tv_recon_display,
    cmap='hot',
    origin='lower',
    vmax=np.percentile(tv_recon_display, 99.8),  # clip for better visualization
)
ax_tv.set_title(f'TV Regularized Astro Recon (lambda={lambda_tv_astro}, {astro_pg_iterations} iters)')
ax_tv.set_xlabel('RA direction (pixels)')
ax_tv.set_ylabel('Dec direction (pixels)')
fig_tv.colorbar(im_tv, ax=ax_tv, label='Brightness')
plt.show()

## 7. Comparison of Results

In [None]:
# Display arrays for the two reconstructions (magnitude, fftshift-ed, on the CPU).
abs_dirty_shifted = torch.fft.fftshift(astro_dirty_image.abs()).cpu().numpy()
abs_tv_shifted = torch.fft.fftshift(astro_tv_recon_image.abs()).cpu().numpy()

plot_kwargs_astro = dict(cmap='hot', origin='lower')

# One entry per panel: (image, extra imshow kwargs, title, colorbar label).
# The phantom uses the full colour range; the reconstructions are clipped at a high percentile.
comparison_panels = [
    (sky_phantom.abs().cpu().numpy(),
     {},
     'Original Sky Phantom',
     'Brightness'),
    (abs_dirty_shifted,
     {'vmax': np.percentile(abs_dirty_shifted, 99.5)},
     'Adjoint (Dirty Image) Recon',
     'Amplitude'),
    (abs_tv_shifted,
     {'vmax': np.percentile(abs_tv_shifted, 99.8)},
     f'TV Regularized Recon (lambda={lambda_tv_astro})',
     'Brightness'),
]

fig, axes = plt.subplots(1, 3, figsize=(18, 6))
fig.suptitle('Astronomical Imaging Reconstruction Comparison', fontsize=16)

panel_images = []
for ax, (panel_data, extra_kwargs, panel_title, cbar_label) in zip(axes, comparison_panels):
    panel_im = ax.imshow(panel_data, **plot_kwargs_astro, **extra_kwargs)
    ax.set_title(panel_title)
    ax.set_xlabel('RA (pix)')
    ax.set_ylabel('Dec (pix)')
    fig.colorbar(panel_im, ax=ax, shrink=0.8, label=cbar_label)
    panel_images.append(panel_im)
im0, im1, im2 = panel_images  # keep the per-panel image handles under their original names

# Leave the top 5% of the figure free for the suptitle.
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()