# Hyperspectral Imaging (HSI) Reconstruction Demo
This notebook demonstrates a basic pipeline for reconstructing a hyperspectral data cube from compressed measurements, as acquired by compressive HSI systems such as CASSI (Coded Aperture Snapshot Spectral Imaging).
The `HyperspectralImagingOperator` uses a **sensing matrix `H`** to model the acquisition `y = Hx`, where `x` is the flattened HSI cube. Reconstruction uses 3D Total Variation (TV) regularization, which promotes sparse gradients (i.e., piecewise smoothness) along both the spatial and spectral dimensions of the HSI cube.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

# Adjust path to import from reconlib 
import sys
# sys.path.append('../../../') # Adjust as needed

from reconlib.modalities.hyperspectral.operators import HyperspectralImagingOperator
from reconlib.modalities.hyperspectral.reconstructors import tv_reconstruction_hsi
from reconlib.modalities.hyperspectral.utils import generate_hsi_phantom, plot_hsi_results

# Pick the compute device (GPU when CUDA is available) and report the setup.
_cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_available else "cpu")
print(f"PyTorch version: {torch.__version__}")
print(f"Using device: {device}")

## 1. Setup Parameters and HSI Phantom

In [None]:
hsi_cube_shape = (32, 32, 16)  # (Ny, Nx, N_bands) - small for the demo
# Total number of voxels in the cube. np.prod returns a NumPy scalar; cast it to
# a plain Python int because it is later used as a tensor size (torch.randn for
# the sensing matrix), and older PyTorch versions may reject NumPy integers there.
num_elements_hsi = int(np.prod(hsi_cube_shape))

# Generate a phantom HSI cube with some spectral features
true_hsi_cube = generate_hsi_phantom(
    image_shape=hsi_cube_shape,
    num_features=3,
    device=device,
)

print(f"Generated HSI cube of shape: {true_hsi_cube.shape}")

# Band indices used as the (R, G, B) channels for false-colour display:
# last band -> R, middle band -> G, first band -> B.
rgb_display_bands = (hsi_cube_shape[2] - 1, hsi_cube_shape[2] // 2, 0)

# Utility function to create an RGB image from an HSI cube for plotting
def get_rgb_from_hsi(hsi_data, bands):
    """Build a displayable false-colour RGB image from a hyperspectral cube.

    Parameters
    ----------
    hsi_data : torch.Tensor or None
        Cube whose last axis indexes spectral bands, e.g. (Ny, Nx, N_bands).
        ``None`` is passed straight through so optional results can be plotted.
    bands : sequence of int
        Band indices used for the R, G and B channels (the first three entries).

    Returns
    -------
    numpy.ndarray or None
        Array of shape ``hsi_data.shape[:-1] + (3,)``. Each channel is clamped
        to [0, 1] and then min-max stretched independently. Returns ``None``
        when ``hsi_data`` is ``None``.
    """
    if hsi_data is None:
        return None
    # Detach so tensors that still track gradients (e.g. the output of an
    # iterative reconstruction) can be converted to NumPy below.
    hsi_data = hsi_data.detach()
    channels = []
    for band in (bands[0], bands[1], bands[2]):
        # Clamp to the nominal [0, 1] range, then stretch each channel on its
        # own so that faint bands remain visible; 1e-6 avoids division by
        # zero for constant channels.
        channel = torch.clamp(hsi_data[..., band], 0, 1)
        channel = (channel - channel.min()) / (channel.max() - channel.min() + 1e-6)
        channels.append(channel)
    return torch.stack(channels, dim=-1).cpu().numpy()

# Quick look at the ground-truth cube rendered as a false-colour RGB image.
true_rgb_preview = get_rgb_from_hsi(true_hsi_cube, rgb_display_bands)
plt.figure(figsize=(6, 6))
plt.imshow(true_rgb_preview)
plt.title(f'True HSI Cube (RGB: bands {rgb_display_bands})')
plt.xlabel('X (pixels)')
plt.ylabel('Y (pixels)')
plt.show()

## 2. Initialize Operator and Simulate Compressed Measurements
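The `HyperspectralImagingOperator` uses a sensing matrix `H`. For compressed HSI, the number of measurements is smaller than the total number of elements in the HSI cube. The system `y = Hx` is therefore underdetermined, and a regularizer is needed to select a plausible solution.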

In [None]:
# Number of compressed measurements (e.g. from a 2D detector in CASSI).
# With compression_ratio = 4 we keep 1/4 as many measurements as voxels.
compression_ratio = 4
num_measurements_hsi = num_elements_hsi // compression_ratio
print(f"Total elements in HSI cube: {num_elements_hsi}")
print(f"Number of compressed measurements: {num_measurements_hsi}")

# Dense Gaussian placeholder for H, scaled by 0.1, with shape
# (num_measurements, num_voxels). A real CASSI H is sparse and structured.
h_shape = (num_measurements_hsi, num_elements_hsi)
sensing_matrix_H = 0.1 * torch.randn(*h_shape, dtype=torch.float32, device=device)
# Optional (may improve conditioning): normalize the rows of H, e.g.
# sensing_matrix_H = sensing_matrix_H / torch.norm(sensing_matrix_H, dim=1, keepdim=True)

hsi_operator = HyperspectralImagingOperator(
    sensing_matrix=sensing_matrix_H,
    image_shape=hsi_cube_shape,
    device=device,
)

# Noise-free forward model: y = H * x_flattened
y_compressed_measurements = hsi_operator.op(true_hsi_cube)

# Add white Gaussian noise scaled to the requested SNR (in dB):
# noise_power = signal_power / 10^(SNR / 10)
snr_db_hsi = 25
signal_power_hsi = y_compressed_measurements.pow(2).mean()
noise_power_hsi = signal_power_hsi / 10 ** (snr_db_hsi / 10)
noise_hsi = noise_power_hsi.sqrt() * torch.randn_like(y_compressed_measurements)
y_compressed_measurements_noisy = y_compressed_measurements + noise_hsi

print(f"Simulated measurement data shape: {y_compressed_measurements_noisy.shape}")
# Only the first 200 measurements are plotted to keep the figure readable.
measurement_preview = y_compressed_measurements_noisy.cpu().numpy()[:200]
plt.figure(figsize=(7, 4))
plt.plot(measurement_preview)
plt.title('Sample of Noisy Compressed Measurements')
plt.xlabel('Measurement Index')
plt.ylabel('Value')
plt.show()

## 3. Perform Reconstruction
We reconstruct the HSI cube using 3D Total Variation (TV) regularization, solved with a proximal gradient method.

In [None]:
# Reconstruction hyperparameters. lambda_tv and step_size are both sensitive
# and usually need tuning for a given dataset / sensing matrix.
lambda_tv_hsi = 0.001     # TV regularization strength
iterations_hsi = 100      # outer iterations (HSI may need more)
step_size_hsi = 1e-3      # step size
tv_prox_iters_hsi = 5     # iterations used inside the TV prox

recon_kwargs_hsi = dict(
    hsi_operator=hsi_operator,
    lambda_tv=lambda_tv_hsi,
    iterations=iterations_hsi,
    step_size=step_size_hsi,
    tv_prox_iterations=tv_prox_iters_hsi,
    verbose=True,
)
reconstructed_hsi = tv_reconstruction_hsi(
    y_sensor_measurements=y_compressed_measurements_noisy,
    **recon_kwargs_hsi,
)

print(f"Reconstructed HSI cube shape: {reconstructed_hsi.shape}")

## 4. Display Results

In [None]:
# Side-by-side false-colour RGB views: ground truth (left), reconstruction (right).
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
rgb_panels = (
    (true_hsi_cube, f'True HSI Cube (RGB: {rgb_display_bands})'),
    (reconstructed_hsi, f'Recon. HSI (3D TV, {iterations_hsi} iters)'),
)
for panel_ax, (panel_cube, panel_title) in zip(axes, rgb_panels):
    panel_ax.imshow(get_rgb_from_hsi(panel_cube, rgb_display_bands))
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('X')
    panel_ax.set_ylabel('Y')

plt.tight_layout()
plt.show()

# Plot a few reconstructed bands (bottom row) against the true bands (top row).
num_bands_to_show = min(3, hsi_cube_shape[2])
band_stride = hsi_cube_shape[2] // num_bands_to_show
# squeeze=False always yields a 2D array of Axes, so the single-band case needs
# no special-casing when indexing axes[row, col].
fig, axes = plt.subplots(
    2, num_bands_to_show, figsize=(num_bands_to_show * 4, 7), squeeze=False
)
for i in range(num_bands_to_show):
    band_idx = i * band_stride
    for row, (cube, label) in enumerate(((true_hsi_cube, 'True'), (reconstructed_hsi, 'Recon'))):
        ax = axes[row, i]
        # detach() lets this work even if the reconstruction still tracks gradients.
        im = ax.imshow(cube[..., band_idx].detach().cpu().numpy(), cmap='gray')
        ax.set_title(f'{label} - Band {band_idx}')
        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
plt.suptitle('Comparison of Selected Spectral Bands')
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()

# Also hand the results to reconlib's plotting helper.
# NOTE(review): this helper may still be a placeholder; what it draws depends
# on its implementation in reconlib.modalities.hyperspectral.utils.
plot_hsi_results(
    measurement_data=y_compressed_measurements_noisy,
    rgb_bands=rgb_display_bands,
    true_hsi_cube=true_hsi_cube,
    recon_hsi_cube=reconstructed_hsi,
)

## 5. Further Considerations for HSI Reconstruction
1.  **Realistic Sensing Matrix `H`**: For systems like CASSI, `H` is sparse and structured, determined by the coded apertures and disperser. A random matrix is a simplification.
2.  **Regularization**: 3D TV is a good baseline. More advanced regularizers exploiting spatial-spectral correlations (e.g., dictionary learning, low-rank tensor methods, total generalized variation) can significantly improve HSI reconstruction.
3.  **Parameter Tuning**: `lambda_tv` and `step_size` are highly sensitive and crucial for good results. They often require careful tuning for specific datasets or sensing matrices.
4.  **Computational Cost**: Reconstructing large HSI cubes, especially with many iterations or complex regularizers, can be computationally intensive.