# Terahertz (THz) Imaging Reconstruction Demo (Fourier Sampling Model)
This notebook demonstrates a basic reconstruction pipeline for terahertz (THz) imaging based on a Fourier sampling model.
The `TerahertzOperator` simulates acquiring samples of the image's 2D Fourier transform at the specified `k_space_locations`.
This is a simplified stand-in for modalities such as THz holography or certain k-space scanning systems.
The image is then reconstructed from these samples with Total Variation (TV) regularization.
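
To make the Fourier sampling model concrete, here is a minimal sketch of one way such a forward operator could work. It assumes integer-valued (kx, ky) locations and an unnormalized FFT with the DC term at index `[0, 0]`; the actual `TerahertzOperator` may use a different normalization or centering. The helper names `fourier_sample` and `direct_dft_sample` are hypothetical and not part of `reconlib`.

```python
import math
import torch

def fourier_sample(image, k_locations):
    """Sample the unnormalized 2D DFT of `image` at integer (kx, ky) locations.

    Hypothetical helper: assumes k_locations has shape (M, 2) with columns (kx, ky).
    """
    ny, nx = image.shape
    spectrum = torch.fft.fft2(image)       # unshifted spectrum, DC at [0, 0]
    kx = k_locations[:, 0].long() % nx     # negative frequencies wrap to the upper indices
    ky = k_locations[:, 1].long() % ny
    return spectrum[ky, kx]

def direct_dft_sample(image, k_locations):
    """Evaluate the same samples directly from the DFT definition (slow, for checking)."""
    ny, nx = image.shape
    rows = torch.arange(ny, dtype=image.dtype).view(1, ny, 1)
    cols = torch.arange(nx, dtype=image.dtype).view(1, 1, nx)
    kx = k_locations[:, 0].to(image.dtype).view(-1, 1, 1)
    ky = k_locations[:, 1].to(image.dtype).view(-1, 1, 1)
    phase = -2.0 * math.pi * (ky * rows / ny + kx * cols / nx)
    kernel = torch.polar(torch.ones_like(phase), phase)  # exp(i * phase)
    return (kernel * image).sum(dim=(1, 2))

torch.manual_seed(0)
demo_image = torch.rand(8, 8, dtype=torch.float64)
demo_k = torch.tensor([[0.0, 0.0], [-4.0, 3.0], [2.0, -1.0]])  # (kx, ky) pairs
y_fast = fourier_sample(demo_image, demo_k)
y_slow = direct_dft_sample(demo_image, demo_k)
print("FFT indexing matches the direct DFT:", torch.allclose(y_fast, y_slow))
```

The modulo step is what maps frequencies in `[-N/2, N/2-1]` onto FFT output indices, so no `fftshift` is needed in this sketch.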

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

# If reconlib is not installed as a package, add the repository root to sys.path
import sys
# sys.path.append('../../../')  # Adjust the relative path as needed

from reconlib.modalities.terahertz.operators import TerahertzOperator
from reconlib.modalities.terahertz.reconstructors import tv_reconstruction_thz
from reconlib.modalities.terahertz.utils import generate_thz_phantom, plot_thz_results

print(f"PyTorch version: {torch.__version__}")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. Setup Parameters and Phantom

In [None]:
image_shape_thz = (64, 64)  # (Ny, Nx) for the image (e.g., material property map)
Ny, Nx = image_shape_thz

# Generate a simple real-valued phantom 
true_thz_image = generate_thz_phantom(image_shape_thz, num_shapes=2, shape_type='rect', device=device)

plt.figure(figsize=(6,6))
plt.imshow(true_thz_image.cpu().numpy(), cmap='magma')
plt.title('True THz Phantom Image (Real-valued)')
plt.xlabel('X (pixels)')
plt.ylabel('Y (pixels)')
plt.colorbar(label='Property Value')
plt.show()

## 2. Define k-space Sampling and Initialize Operator
The `TerahertzOperator` (Fourier sampling) requires `k_space_locations`, a tensor of shape `(num_measurements, 2)` whose columns hold the (kx, ky) coordinates.
kx should lie in `[-Nx/2, Nx/2-1]` and ky in `[-Ny/2, Ny/2-1]`.
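
Independent draws of kx and ky with `torch.randint` can repeat the same location, so the number of distinct k-space points may be smaller than the number of measurements. If distinct locations are wanted, one option is to draw grid indices without replacement. The sketch below assumes even `Ny` and `Nx` and integer-valued locations; `sample_unique_k_locations` is a hypothetical helper, not part of `reconlib`.

```python
import torch

def sample_unique_k_locations(ny, nx, num_samples, generator=None):
    """Draw `num_samples` distinct (kx, ky) grid points, centered on zero frequency.

    Hypothetical helper: returns a float tensor of shape (num_samples, 2).
    """
    if num_samples > ny * nx:
        raise ValueError("Cannot draw more unique samples than grid points.")
    # A random permutation of flat grid indices guarantees no repeats.
    flat = torch.randperm(ny * nx, generator=generator)[:num_samples]
    ky = torch.div(flat, nx, rounding_mode="floor") - ny // 2  # row index -> ky
    kx = flat % nx - nx // 2                                   # column index -> kx
    return torch.stack([kx, ky], dim=1).float()

gen = torch.Generator().manual_seed(0)
unique_locs = sample_unique_k_locations(64, 64, (64 * 64) // 3, generator=gen)
print("Requested samples:", unique_locs.shape[0])
print("Distinct locations:", torch.unique(unique_locs, dim=0).shape[0])
```

The result can be moved to the target device with `.to(device)` and passed as `k_space_locations` in place of the stacked `torch.randint` draws.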

In [None]:
# Define k-space sampling locations (kx, ky)
num_measurements_thz = (Ny * Nx) // 3  # Example: sample count equal to 1/3 of the k-space grid

# Random sparse k-space sampling (drawn with replacement, so some locations may repeat)
kx_samples = torch.randint(-Nx // 2, Nx // 2, (num_measurements_thz,), device=device).float()
ky_samples = torch.randint(-Ny // 2, Ny // 2, (num_measurements_thz,), device=device).float()
k_space_sampling_locations = torch.stack([kx_samples, ky_samples], dim=1)

print(f"Number of k-space samples: {k_space_sampling_locations.shape[0]}")

# Visualize k-space sampling pattern
plt.figure(figsize=(5,5))
plt.scatter(k_space_sampling_locations[:,0].cpu().numpy(), 
            k_space_sampling_locations[:,1].cpu().numpy(), s=5, alpha=0.7)
plt.title('k-space Sampling Pattern (kx, ky)')
plt.xlabel('kx'); plt.ylabel('ky')
plt.xlim([-Nx//2 -1, Nx//2 + 1]); plt.ylim([-Ny//2 -1, Ny//2 + 1])
plt.axhline(0, color='black', lw=0.5); plt.axvline(0, color='black', lw=0.5)
plt.grid(True, linestyle=':', alpha=0.5)
plt.gca().set_aspect('equal', adjustable='box')
plt.show()

thz_operator = TerahertzOperator(
    image_shape=image_shape_thz,
    k_space_locations=k_space_sampling_locations,
    device=device
)

# Simulate THz k-space data: y = FFT(x)[k_locations]
y_kspace_measurements = thz_operator.op(true_thz_image) # Output is complex

# Add some noise
snr_db_thz = 20 # SNR in dB
signal_power_thz = torch.mean(torch.abs(y_kspace_measurements)**2)
noise_power_thz = signal_power_thz / (10**(snr_db_thz / 10))
noise_std_thz = torch.sqrt(noise_power_thz / 2) # For complex Gaussian (half power in real, half in imag)
noise_thz = torch.complex(torch.randn_like(y_kspace_measurements.real) * noise_std_thz, 
                          torch.randn_like(y_kspace_measurements.imag) * noise_std_thz)
y_kspace_measurements_noisy = y_kspace_measurements + noise_thz

print(f"Simulated k-space measurement data shape: {y_kspace_measurements_noisy.shape}, dtype: {y_kspace_measurements_noisy.dtype}")

# Visualize the magnitude and phase of the first 200 noisy k-space samples
plt.figure(figsize=(10,4))
plt.subplot(1,2,1)
plt.plot(torch.abs(y_kspace_measurements_noisy).cpu().numpy()[:200])
plt.title('Magnitude of First 200 k-space Samples (Noisy)')
plt.xlabel('Measurement Index'); plt.ylabel('Magnitude')
plt.subplot(1,2,2)
plt.plot(torch.angle(y_kspace_measurements_noisy).cpu().numpy()[:200])
plt.title('Phase of First 200 k-space Samples (Noisy)')
plt.xlabel('Measurement Index'); plt.ylabel('Phase (rad)')
plt.tight_layout()
plt.show()

## 3. Perform Reconstruction
Reconstruction uses Total Variation (TV) regularization, solved with a proximal gradient method. The image to reconstruct is real-valued, while the measurements are complex k-space samples.
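
As a sketch of what this step solves, assume the common least-squares-plus-TV formulation. The exact scaling and update rule inside `tv_reconstruction_thz` are not shown in this notebook, so the form below is an assumption. Here $A$ denotes the Fourier sampling operator, $y$ the measured k-space samples, $\lambda$ corresponds to `lambda_tv`, and $\alpha$ to `step_size`:

$$
\hat{x} = \arg\min_{x \in \mathbb{R}^{N_y \times N_x}} \; \tfrac{1}{2}\,\lVert A x - y \rVert_2^2 + \lambda\,\mathrm{TV}(x)
$$

Under that assumption, each proximal gradient iteration takes a gradient step on the data term and then applies the TV proximal operator:

$$
x^{(k+1)} = \mathrm{prox}_{\alpha\lambda\,\mathrm{TV}}\!\left( x^{(k)} - \alpha\,\mathrm{Re}\!\left\{ A^{H}\!\left(A x^{(k)} - y\right) \right\} \right)
$$

The real part appears because the image is constrained to be real-valued. The TV proximal operator has no closed form and is itself computed iteratively, which is what `tv_prox_iterations` controls.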

In [None]:
lambda_tv_thz = 0.001      # TV regularization strength (CRITICAL - needs tuning)
iterations_thz = 75       # Number of proximal gradient iterations
step_size_thz = 5e-3      # Step size for proximal gradient (CRITICAL - adjust based on data scaling)
tv_prox_iters_thz = 5     # Iterations for TV prox

reconstructed_thz_img = tv_reconstruction_thz(
    y_thz_data=y_kspace_measurements_noisy, # These are complex k-space samples
    thz_operator=thz_operator,
    lambda_tv=lambda_tv_thz,
    iterations=iterations_thz,
    step_size=step_size_thz,
    tv_prox_iterations=tv_prox_iters_thz,
    is_3d_tv=False, # Current operator is 2D
    verbose=True
)

print(f"Reconstructed THz image shape: {reconstructed_thz_img.shape}, dtype: {reconstructed_thz_img.dtype}")

## 4. Display Results

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

im1 = axes[0].imshow(true_thz_image.cpu().numpy(), cmap='magma')
axes[0].set_title('Ground Truth THz Image')
axes[0].set_xlabel('X (pixels)')
axes[0].set_ylabel('Y (pixels)')
fig.colorbar(im1, ax=axes[0], fraction=0.046, pad=0.04)

img_display = reconstructed_thz_img.cpu().numpy()
im2 = axes[1].imshow(img_display, cmap='magma')
axes[1].set_title(f'Reconstructed THz Image (TV, {iterations_thz} iters - Fourier Sampling)')
axes[1].set_xlabel('X (pixels)')
axes[1].set_ylabel('Y (pixels)')
fig.colorbar(im2, ax=axes[1], fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()

# Use the utility plotting function (currently a placeholder)
plot_thz_results(
    true_image=true_thz_image,
    reconstructed_image=reconstructed_thz_img,
    measurement_data=y_kspace_measurements_noisy # Pass k-space data
)

## 5. Further Considerations
1.  **k-space Sampling Strategy**: The choice of `k_space_locations` is critical. Random sparse sampling is one option. Other strategies (e.g., radial, spiral, specific patterns for THz holography) would depend on the actual THz system being modeled.
2.  **Data Scaling**: The magnitude of k-space data can vary significantly. Normalization or careful choice of `step_size` and `lambda_tv` is important for stable reconstruction.
3.  **Phase Information**: This model uses complex k-space data. The quality of phase information in real THz measurements can be a challenge and might affect reconstruction.
4.  **Regularization**: TV favors piecewise-constant images with sharp edges. Other regularizers (e.g., L1-wavelet for sparsity) might be more appropriate depending on the expected image content.
5.  **Computational Cost**: For large images or many k-space samples, FFTs and iterative reconstruction can be demanding. The current operator is suitable for 2D. For 3D THz imaging (e.g., tomographic reconstruction of a volume), a different operator (like Radon transform based) would be needed.
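
For the data scaling point above, one common way to pick a safe step size is to estimate the largest eigenvalue L of A^H A by power iteration and use a step of about 1/L. The sketch below does this for a stand-in masked-FFT operator that assumes integer (kx, ky) locations and an unnormalized FFT. `make_masked_fft_ops` and `estimate_lipschitz` are hypothetical helpers, and the resulting value applies to `TerahertzOperator` only if it uses the same scaling.

```python
import torch

def make_masked_fft_ops(shape, k_locations):
    """Build forward/adjoint callables for an unnormalized FFT sampled at (kx, ky)."""
    ny, nx = shape
    kx = k_locations[:, 0].long() % nx
    ky = k_locations[:, 1].long() % ny
    flat = ky * nx + kx  # flat grid index of every sample

    def forward(x):
        return torch.fft.fft2(x)[ky, kx]

    def adjoint(y):
        # Scatter-add samples onto the grid (repeated locations accumulate),
        # working on the real/imag view so only real-valued index_add_ is used.
        grid_ri = torch.zeros(ny * nx, 2, dtype=y.real.dtype)
        grid_ri.index_add_(0, flat, torch.view_as_real(y))
        grid = torch.view_as_complex(grid_ri).view(ny, nx)
        # The adjoint of the unnormalized fft2 is (ny * nx) * ifft2.
        return torch.fft.ifft2(grid) * (ny * nx)

    return forward, adjoint

def estimate_lipschitz(forward, adjoint, shape, num_iters=30, seed=0):
    """Power iteration for the largest eigenvalue of A^H A (over complex images)."""
    gen = torch.Generator().manual_seed(seed)
    x = torch.randn(shape, generator=gen, dtype=torch.float64).to(torch.complex128)
    x = x / torch.linalg.norm(x)
    estimate = 0.0
    for _ in range(num_iters):
        z = adjoint(forward(x))
        estimate = torch.linalg.norm(z).item()
        x = z / estimate
    return estimate

torch.manual_seed(0)
demo_shape = (64, 64)
num_demo = (64 * 64) // 3
k_demo = torch.stack([torch.randint(-32, 32, (num_demo,)),
                      torch.randint(-32, 32, (num_demo,))], dim=1).float()
forward, adjoint = make_masked_fft_ops(demo_shape, k_demo)
lipschitz = estimate_lipschitz(forward, adjoint, demo_shape)
print(f"Estimated Lipschitz constant: {lipschitz:.1f}")
print(f"Conservative step size (1 / L): {1.0 / lipschitz:.2e}")
```

The estimate is taken over complex images, so it is an upper bound for the real-valued case and gives a conservative step. In this stand-in model, repeated k-space locations raise L in proportion to the largest repeat count, which is one reason the sampling pattern and the step size interact.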