# Fluorescence Microscopy Deconvolution Demo
This notebook demonstrates TV-regularized deconvolution for a simulated fluorescence microscopy image. Deconvolution aims to reverse the blurring effect caused by the microscope's Point Spread Function (PSF).
The `FluorescenceMicroscopyOperator` models the blurring as a convolution with a known PSF. The `tv_deconvolution_fm` reconstructor then attempts to recover a sharper image.
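The notebook does not show how `FluorescenceMicroscopyOperator` implements the convolution. Here is a minimal sketch that assumes periodic (circular) boundary conditions, which the library may not use. Under that assumption, both the forward model (true image convolved with the PSF) and its adjoint can be applied with FFTs. The helpers `psf_to_otf`, `blur`, and `blur_adjoint` are hypothetical names. They are written in NumPy rather than PyTorch so the example stands alone.

```python
import numpy as np

def psf_to_otf(psf, image_shape):
    """Zero-pad the PSF to image_shape, move its centre to index 0,
    and return its FFT (the optical transfer function, OTF).
    Assumes circular boundaries (hypothetical helper, not reconlib's)."""
    padded = np.zeros(image_shape)
    padded[tuple(slice(0, s) for s in psf.shape)] = psf
    # Shift the PSF centre to the origin so the blur introduces no translation
    shifts = tuple(-(s // 2) for s in psf.shape)
    padded = np.roll(padded, shifts, axis=tuple(range(psf.ndim)))
    return np.fft.fftn(padded)

def blur(x, otf):
    """Forward model A x: circular convolution of x with the PSF."""
    return np.real(np.fft.ifftn(np.fft.fftn(x) * otf))

def blur_adjoint(y, otf):
    """Adjoint A^T y: circular correlation with the PSF (conjugate OTF)."""
    return np.real(np.fft.ifftn(np.fft.fftn(y) * np.conj(otf)))

rng = np.random.default_rng(0)
psf = rng.random((5, 5))
psf /= psf.sum()
otf = psf_to_otf(psf, (32, 32))

delta = np.zeros((32, 32))
delta[16, 16] = 1.0
delta_response = blur(delta, otf)   # a copy of the PSF centred on the delta

# Dot-product test: <A x, y> should equal <x, A^T y>
x, y = rng.random((32, 32)), rng.random((32, 32))
print(np.isclose(np.vdot(blur(x, otf), y), np.vdot(x, blur_adjoint(y, otf))))  # True
```

Blurring a single bright pixel should reproduce the PSF. The dot-product test is a quick way to confirm that an adjoint matches its forward operator, and gradient-based reconstructors depend on that match.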

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

# If reconlib is not installed as a package, add its root directory to the path
import sys
# sys.path.append('../../../')  # adjust to the location of reconlib

from reconlib.modalities.fluorescence_microscopy.operators import FluorescenceMicroscopyOperator, generate_gaussian_psf
from reconlib.modalities.fluorescence_microscopy.reconstructors import tv_deconvolution_fm
from reconlib.modalities.fluorescence_microscopy.utils import generate_fluorescence_phantom, plot_fm_results

print(f"PyTorch version: {torch.__version__}")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. Set Up Parameters, Phantom, and PSF
The example runs in 2D by default; set `is_3d_example = True` in the cell below to run the same pipeline on a 3D volume.

In [None]:
image_shape_fm = (128, 128)      # (Ny, Nx) for the fluorescence image
is_3d_example = False # Set to True for a 3D deconvolution example
if is_3d_example:
    image_shape_fm = (64, 64, 32) # (Nz, Ny, Nx)

# Generate a phantom (true fluorescence distribution)
phantom_structure = 'cells' # 'cells' or 'filaments'
true_fluorescence_map = generate_fluorescence_phantom(
    image_shape=image_shape_fm, 
    num_structures=5 if not is_3d_example else 10,
    structure_type=phantom_structure,
    device=device
)

# Generate a Point Spread Function (PSF) - e.g., Gaussian
psf_shape = (11, 11) if not is_3d_example else (7, 7, 7) # Must be smaller than image
psf_sigma = 1.5 if not is_3d_example else (1.2, 1.2, 1.8) # Sigma for Gaussian PSF
psf = generate_gaussian_psf(shape=psf_shape, sigma=psf_sigma, device=device)

print(f"True map shape: {true_fluorescence_map.shape}, PSF shape: {psf.shape}")

# Display phantom and PSF (center slice if 3D)
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
true_display = true_fluorescence_map.cpu().numpy()
psf_display = psf.cpu().numpy()
if is_3d_example:
    true_display = true_display[true_display.shape[0]//2, :,:]
    psf_display = psf_display[psf_display.shape[0]//2, :,:]
    axes[0].set_title(f'True Fluorescence (Z-slice {true_fluorescence_map.shape[0]//2})')
    axes[1].set_title(f'PSF (Z-slice {psf.shape[0]//2})')
else:
    axes[0].set_title('True Fluorescence Map')
    axes[1].set_title('Point Spread Function (PSF)')

im1 = axes[0].imshow(true_display, cmap='viridis')
fig.colorbar(im1, ax=axes[0], fraction=0.046, pad=0.04)
im2 = axes[1].imshow(psf_display, cmap='viridis')
fig.colorbar(im2, ax=axes[1], fraction=0.046, pad=0.04)
plt.show()
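The normalization used by `generate_gaussian_psf` is not documented in this notebook. Below is a minimal NumPy sketch of a sampled Gaussian PSF. It assumes the kernel should sum to one so that blurring preserves total fluorescence, and that per-axis sigmas follow the axis order of `shape`. `gaussian_psf_sketch` is a hypothetical helper, not the library function.

```python
import numpy as np

def gaussian_psf_sketch(shape, sigma):
    """Sampled, unit-sum Gaussian PSF (hypothetical helper, not reconlib's).
    shape: kernel size per axis (odd sizes keep the peak on a pixel);
    sigma: one value for all axes, or one value per axis in the same order as shape."""
    sigmas = np.broadcast_to(np.asarray(sigma, dtype=float), (len(shape),))
    # Coordinates measured from the centre of each axis
    coords = [np.arange(n) - (n - 1) / 2 for n in shape]
    grids = np.meshgrid(*coords, indexing="ij")
    exponent = sum(g**2 / (2 * s**2) for g, s in zip(grids, sigmas))
    psf = np.exp(-exponent)
    return psf / psf.sum()   # unit sum: blurring preserves total fluorescence

psf_2d = gaussian_psf_sketch((11, 11), 1.5)
psf_3d = gaussian_psf_sketch((7, 7, 7), (1.2, 1.2, 1.8))
print(psf_2d.shape, psf_3d.shape)  # (11, 11) (7, 7, 7)
```

With a per-axis sigma, the kernel is wider along the axis with the larger value. This is how an elongated axial PSF is often approximated.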

## 2. Initialize Operator and Simulate Observed (Blurred) Image

In [None]:
fm_operator = FluorescenceMicroscopyOperator(
    image_shape=image_shape_fm,
    psf=psf,
    device=device
)

# Simulate the observed (blurred) image: y = PSF ⊛ x, where ⊛ denotes convolution
observed_blurred_image = fm_operator.op(true_fluorescence_map)

# Add noise. Real data are closer to mixed Poisson-Gaussian; additive Gaussian noise is used here for simplicity
snr_db_fm = 30 # Signal-to-Noise Ratio in dB
signal_power_fm = torch.mean(observed_blurred_image**2)
noise_power_fm = signal_power_fm / (10**(snr_db_fm / 10))
noise_fm = torch.randn_like(observed_blurred_image) * torch.sqrt(noise_power_fm)
observed_image_noisy = observed_blurred_image + noise_fm
# Ensure non-negativity, common for fluorescence images
observed_image_noisy = torch.clamp(observed_image_noisy, min=0.0)

print(f"Observed (blurred & noisy) image shape: {observed_image_noisy.shape}")

plt.figure(figsize=(6,6))
obs_display = observed_image_noisy.cpu().numpy()
if is_3d_example:
    obs_display = obs_display[obs_display.shape[0]//2, :,:]
    plt.title(f'Observed Image (Z-slice, {snr_db_fm}dB SNR)')
else:
    plt.title(f'Observed Image ({snr_db_fm}dB SNR)')
plt.imshow(obs_display, cmap='viridis')
plt.colorbar(label='Intensity')
plt.show()
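The cell above adds Gaussian noise scaled to a target SNR. For comparison, here is a minimal sketch of the mixed Poisson-Gaussian model mentioned in the comment. It relies on two assumptions: a hypothetical conversion factor `photons_per_unit` from intensity to expected photon counts, and a camera read-noise level `read_noise_std`. Both values are illustrative, not calibrated.

```python
import numpy as np

def add_poisson_gaussian_noise(blurred, photons_per_unit=100.0, read_noise_std=2.0, rng=None):
    """Mixed Poisson-Gaussian noise (hypothetical helper; parameter values are illustrative).
    blurred: noise-free, non-negative image in arbitrary intensity units.
    photons_per_unit: assumed expected photon count per unit of intensity.
    read_noise_std: assumed camera read-noise standard deviation, in photons."""
    rng = np.random.default_rng() if rng is None else rng
    expected_counts = np.clip(blurred, 0.0, None) * photons_per_unit
    counts = rng.poisson(expected_counts).astype(float)           # signal-dependent shot noise
    counts += rng.normal(0.0, read_noise_std, size=counts.shape)  # signal-independent read noise
    return counts / photons_per_unit                               # back to intensity units

flat = np.full((256, 256), 0.5)
noisy = add_poisson_gaussian_noise(flat, photons_per_unit=100.0, read_noise_std=2.0,
                                   rng=np.random.default_rng(0))
# Expected variance in intensity units: (0.5 * 100 + 2.0**2) / 100**2 = 0.0054
print(noisy.mean(), noisy.var())
```

Shot-noise variance grows with the signal. As a result, bright structures are noisier in absolute terms but less noisy relative to their intensity. Lowering `photons_per_unit` mimics a dimmer, lower-SNR acquisition.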

## 3. Perform TV-Regularized Deconvolution

In [None]:
lambda_tv_fm = 0.002     # TV regularization strength (CRITICAL - needs tuning!)
iterations_fm = 50       # Number of proximal gradient iterations 
step_size_fm = 0.01      # Step size for proximal gradient (CRITICAL!)
tv_prox_iters_fm = 5     # Iterations for TV prox

deconvolved_map = tv_deconvolution_fm(
    y_observed_image=observed_image_noisy,
    fm_operator=fm_operator,
    lambda_tv=lambda_tv_fm,
    iterations=iterations_fm,
    step_size=step_size_fm,
    tv_prox_iterations=tv_prox_iters_fm,
    verbose=True
)

print(f"Deconvolved map shape: {deconvolved_map.shape}")
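The internals of `tv_deconvolution_fm` are not shown in this notebook. The sketch below illustrates the general technique under three assumptions: the objective is ½‖Ax − y‖² + λ·TV(x), TV is isotropic, and the convolution is circular. Each proximal gradient iteration takes a gradient step on the data term and then applies the TV proximal operator, approximated here with Chambolle's (2004) dual projection iterations. For this objective the gradient is Lipschitz with constant L = ‖A‖² = max|OTF|², and a step size of at most 1/L is a standard safe choice. If the PSF is non-negative and sums to one, ‖A‖ ≤ 1, so under these assumptions steps up to 1 would be safe. All function names are hypothetical, and this is not reconlib's code.

```python
import numpy as np

def psf_to_otf(psf, image_shape):
    """Zero-pad the PSF, move its centre to index 0, and return its FFT (the OTF)."""
    padded = np.zeros(image_shape)
    padded[tuple(slice(0, s) for s in psf.shape)] = psf
    shifts = tuple(-(s // 2) for s in psf.shape)
    return np.fft.fftn(np.roll(padded, shifts, axis=tuple(range(psf.ndim))))

def blur(x, otf):                      # A x: circular convolution with the PSF
    return np.real(np.fft.ifftn(np.fft.fftn(x) * otf))

def blur_adjoint(y, otf):              # A^T y: circular correlation with the PSF
    return np.real(np.fft.ifftn(np.fft.fftn(y) * np.conj(otf)))

def grad(u):
    """Forward differences along each axis; zero at the last index (Neumann boundary)."""
    g = np.zeros((u.ndim,) + u.shape)
    for ax in range(u.ndim):
        sl = [slice(None)] * u.ndim
        sl[ax] = slice(0, -1)
        g[ax][tuple(sl)] = np.diff(u, axis=ax)
    return g

def div(p):
    """Discrete divergence, defined so that div = -grad^T."""
    d = np.zeros(p.shape[1:])
    for ax in range(p.shape[0]):
        pa = np.moveaxis(p[ax], ax, 0)
        out = np.zeros_like(pa)
        out[0] = pa[0]
        out[1:-1] = pa[1:-1] - pa[:-2]
        out[-1] = -pa[-2]
        d += np.moveaxis(out, 0, ax)
    return d

def prox_tv(f, weight, n_iter=20):
    """Approximate argmin_u 0.5*||u - f||^2 + weight*TV(u) via Chambolle's dual projection."""
    if weight <= 0:
        return f.copy()
    tau = 1.0 / (4 * f.ndim)           # dual step: 1/8 in 2-D, as in Chambolle (2004)
    p = np.zeros((f.ndim,) + f.shape)
    for _ in range(n_iter):
        g = grad(div(p) - f / weight)
        p = (p + tau * g) / (1.0 + tau * np.sqrt(np.sum(g**2, axis=0)))
    return f - weight * div(p)

def tv_pgd_sketch(y, otf, lam, n_iter=50, step=None, prox_iter=20):
    """Proximal gradient for 0.5*||A x - y||^2 + lam*TV(x) (hypothetical helper)."""
    lipschitz = np.max(np.abs(otf)) ** 2     # ||A||^2 for circular convolution
    step = 1.0 / lipschitz if step is None else step
    x = y.copy()
    for _ in range(n_iter):
        z = x - step * blur_adjoint(blur(x, otf) - y, otf)   # gradient step on the data term
        x = prox_tv(z, step * lam, prox_iter)                # TV proximal step
    return x

# Usage on a piecewise-constant 2-D phantom
rng = np.random.default_rng(0)
x_true = np.zeros((64, 64))
x_true[20:44, 20:44] = 1.0
r = np.arange(-5, 6)
psf = np.exp(-(r[:, None] ** 2 + r[None, :] ** 2) / (2 * 1.5**2))
psf /= psf.sum()
otf = psf_to_otf(psf, x_true.shape)
y = blur(x_true, otf) + 0.005 * rng.standard_normal(x_true.shape)
x_hat = tv_pgd_sketch(y, otf, lam=0.01, n_iter=100)
print("observed error:", np.linalg.norm(y - x_true), "deconvolved error:", np.linalg.norm(x_hat - x_true))

# Dot-product check that div is the negative adjoint of grad (should be ~0)
u_chk, p_chk = rng.random((8, 9)), rng.random((2, 8, 9))
adjoint_gap = np.sum(grad(u_chk) * p_chk) + np.sum(u_chk * div(p_chk))
```

The sketch omits a non-negativity constraint. Clamping after the TV prox is a common heuristic, but it is not the exact proximal operator of TV combined with a non-negativity constraint.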

## 4. Display Results

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
titles = ['True Fluorescence', 'Observed (Blurred & Noisy)', f'Deconvolved (TV, {iterations_fm} iters)']
maps_to_display = [true_fluorescence_map, observed_image_noisy, deconvolved_map]

vmax = true_fluorescence_map.max().item()  # shared colour scale across all panels
for i, data_map in enumerate(maps_to_display):
    display_slice = data_map.detach().cpu().numpy()  # detach in case the reconstruction tracks gradients
    slice_label = ''
    if is_3d_example:
        slice_num = data_map.shape[0] // 2
        display_slice = display_slice[slice_num, :, :]
        slice_label = f' (Z-slice {slice_num})'

    im = axes[i].imshow(display_slice, cmap='viridis', vmin=0, vmax=vmax)
    axes[i].set_title(titles[i] + slice_label)
    axes[i].set_xlabel('X'); axes[i].set_ylabel('Y')
    fig.colorbar(im, ax=axes[i], fraction=0.046, pad=0.04)

plt.tight_layout()
plt.show()

# Using the utility plot function (currently a placeholder for detailed plotting)
plot_fm_results(
    true_map=true_fluorescence_map,
    observed_map=observed_image_noisy,
    deconvolved_map=deconvolved_map,
    slice_idx=image_shape_fm[0]//2 if is_3d_example else None
)
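Simple error metrics against the known phantom can complement the visual comparison. The minimal NumPy sketch below computes PSNR and normalized RMSE. The helper names are hypothetical, and normalizing by the reference's dynamic range is only one of several conventions.

```python
import numpy as np

def psnr(estimate, reference, data_range=None):
    """Peak signal-to-noise ratio in dB; data_range defaults to the reference's max - min."""
    if data_range is None:
        data_range = reference.max() - reference.min()
    mse = np.mean((estimate - reference) ** 2)
    return np.inf if mse == 0 else 10.0 * np.log10(data_range**2 / mse)

def nrmse(estimate, reference):
    """Root-mean-square error divided by the reference's dynamic range (one common convention)."""
    rmse = np.sqrt(np.mean((estimate - reference) ** 2))
    return rmse / (reference.max() - reference.min())

reference = np.linspace(0.0, 1.0, 100)
estimate = reference + 0.1                 # uniform 0.1 offset
print(psnr(estimate, reference), nrmse(estimate, reference))   # ~20 dB, ~0.1
```

In this notebook, convert the tensors first with `.detach().cpu().numpy()`. scikit-image also provides `skimage.metrics.peak_signal_noise_ratio`, which takes an explicit `data_range` argument.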

## 5. Further Considerations for Fluorescence Deconvolution
1.  **PSF Estimation**: Accurate knowledge of the PSF is crucial. If the PSF is unknown or varies spatially, blind deconvolution or PSF estimation techniques might be needed.
2.  **Noise Model**: Fluorescence microscopy data often follow Poisson or mixed Poisson-Gaussian statistics, especially in low-light conditions. Replacing the least-squares (L2) data-fidelity term implicit in this proximal gradient scheme with one matched to the noise model can improve results; Richardson-Lucy, for example, maximizes a Poisson likelihood.
3.  **Regularization**: TV is a common choice. Other regularizers (e.g., Tikhonov, sparsity in a wavelet domain) or more advanced methods such as deep-learning-based deconvolution can also be used.
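4.  **Parameter Tuning**: Results are highly sensitive to `lambda_tv` and `step_size`. Too large a `lambda_tv` over-smooths fine structures, while too small a value leaves noise; too large a step size can make the iterations diverge, and too small a step size slows convergence.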
5.  **Non-Negativity**: Fluorescence intensity is non-negative, but neither the least-squares data term nor TV regularization enforces this on its own. If the reconstructor does not already handle it, an explicit constraint (for example, projecting onto non-negative values after each iteration) can be beneficial.
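Item 2 mentions Richardson-Lucy. The minimal sketch below assumes circular boundaries and a non-negative PSF. Its multiplicative update x ← x · Aᵀ(y / Ax) maximizes the Poisson likelihood. The update also keeps the estimate non-negative whenever it starts positive, which addresses item 5. `richardson_lucy_sketch` is a hypothetical helper, not a reconlib function.

```python
import numpy as np

def psf_to_otf(psf, image_shape):
    """Zero-pad the PSF, move its centre to index 0, and return its FFT (the OTF)."""
    padded = np.zeros(image_shape)
    padded[tuple(slice(0, s) for s in psf.shape)] = psf
    shifts = tuple(-(s // 2) for s in psf.shape)
    return np.fft.fftn(np.roll(padded, shifts, axis=tuple(range(psf.ndim))))

def conv(x, otf):
    """Circular convolution via FFT (pass np.conj(otf) for the adjoint/correlation)."""
    return np.real(np.fft.ifftn(np.fft.fftn(x) * otf))

def richardson_lucy_sketch(y, psf, n_iter=30, eps=1e-12):
    """Richardson-Lucy deconvolution with circular boundaries (hypothetical helper)."""
    otf = psf_to_otf(psf / psf.sum(), y.shape)       # unit-sum PSF, so A^T 1 = 1
    x = np.full(y.shape, max(float(y.mean()), eps))  # flat, strictly positive start
    for _ in range(n_iter):
        ratio = y / np.maximum(conv(x, otf), eps)    # y / (A x), guarded against division by 0
        # Multiplicative update x <- x * A^T(y / A x); the clamp only removes FFT round-off
        x = np.maximum(x * conv(ratio, np.conj(otf)), 0.0)
    return x

# Usage: blurred square with Poisson noise and a small background
r = np.arange(-5, 6)
psf = np.exp(-(r[:, None] ** 2 + r[None, :] ** 2) / (2 * 1.5**2))
x_true = np.zeros((64, 64))
x_true[20:44, 20:44] = 1.0
blurred = conv(x_true, psf_to_otf(psf / psf.sum(), x_true.shape))
y = np.random.default_rng(0).poisson(100.0 * (blurred + 0.01)) / 100.0
x_rl = richardson_lucy_sketch(y, psf)
print(x_rl.min() >= 0.0, np.isclose(x_rl.sum(), y.sum()))  # True True
```

For a unit-sum PSF, Aᵀ1 = 1, so each iteration preserves total flux: the sum of the estimate equals the sum of the data. scikit-image ships a maintained implementation as `skimage.restoration.richardson_lucy`.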