# Seismic Imaging Reconstruction Example
This notebook demonstrates basic 2D seismic imaging reconstruction from simulated seismic traces. It uses a ray-based forward model that incorporates a source wavelet and geometrical spreading, together with Total Variation (TV)-regularized reconstruction, to estimate a subsurface reflectivity map.

## 1. Imports and Setup

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import sys
import os

# Ensure reconlib is in the Python path.
# A notebook has no __file__, so the package root is resolved relative to the
# current working directory: one level up when the cwd path already contains
# 'reconlib', three levels up otherwise.
_reconlib_root = os.path.abspath('..' if 'reconlib' in os.getcwd() else '../../..')
# Only prepend when the root is not already first, so re-running this cell
# does not pile up duplicate sys.path entries.
if sys.path[:1] != [_reconlib_root]:
    sys.path.insert(0, _reconlib_root)

from reconlib.modalities.seismic.operators import SeismicForwardOperator
from reconlib.modalities.seismic.reconstructors import tv_reconstruction_seismic

%matplotlib inline

# Run on the GPU when CUDA is usable; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 2. Define Seismic Survey Parameters
We define parameters for the subsurface map, seismic wave, source/receiver geometry, source wavelet, and reconstruction.

In [None]:
# --- Subsurface reflectivity map ---
reflectivity_map_shape = (64, 128)  # (Nz, Nx): depth pixels, horizontal pixels
Nz, Nx = reflectivity_map_shape
pixel_spacing_val = 1.0  # metres per pixel, same along depth and horizontal axes
max_depth_m = Nz * pixel_spacing_val
survey_width_m = Nx * pixel_spacing_val

# --- Seismic wave and recording ---
wave_speed_mps = 2000.0  # average seismic wave speed (m/s)
dt_s_param = 0.002  # sampling interval of the seismic traces (2 ms)
# Down-and-up travel time over the full depth, stretched by a 1.8x margin.
two_way_depth_time_s = 2 * max_depth_m / wave_speed_mps
max_record_time_s = two_way_depth_time_s * 1.8
# Round up so the last (partial) sampling interval is still recorded.
num_time_samples_val = int(np.ceil(max_record_time_s / dt_s_param))
print(f"Max recording time: {max_record_time_s:.2f} s, Number of time samples: {num_time_samples_val}")

# Source Wavelet Parameters
def ricker_wavelet_nb(peak_freq, dt, num_samples, device='cpu'):
    """Sample a Ricker wavelet centred on the middle sample.

    Args:
        peak_freq: Peak frequency of the wavelet (Hz when ``dt`` is in seconds).
        dt: Sampling interval.
        num_samples: Requested number of samples. An even value is bumped to
            the next odd number so the wavelet is symmetric about its centre.
        device: Torch device on which the returned tensor is allocated.

    Returns:
        1-D ``torch.float32`` tensor holding
        ``(1 - 2*a**2) * exp(-a**2)`` with ``a = pi * peak_freq * t``, where
        ``t`` runs over the sample times centred on zero.
    """
    # Force an odd sample count so there is a single central sample at t = 0.
    length = num_samples + 1 if num_samples % 2 == 0 else num_samples
    sample_times = (np.arange(length) - length // 2) * dt
    arg_squared = (sample_times * peak_freq * np.pi) ** 2
    amplitudes = (1.0 - 2.0 * arg_squared) * np.exp(-arg_squared)
    return torch.tensor(amplitudes, dtype=torch.float32, device=device)

wavelet_len_param = 63 # Requested number of samples in the wavelet
wavelet_peak_freq_param = 25.0 # Peak frequency of the Ricker wavelet in Hz
sim_wavelet = ricker_wavelet_nb(wavelet_peak_freq_param, dt_s_param, wavelet_len_param, device=device)

# ricker_wavelet_nb pads an even request to the next odd length, so the peak
# position and the plot axis are derived from the wavelet actually returned
# rather than from the requested length.
wavelet_num_samples = sim_wavelet.numel()
wavelet_center_idx = wavelet_num_samples // 2
sim_wavelet_offset_s = wavelet_center_idx * dt_s_param # Time offset for wavelet peak
use_spreading_param = True # Apply geometrical spreading

# Time axis centred on the wavelet peak (t = 0 at the central sample).
wavelet_time_axis_s = (np.arange(wavelet_num_samples) - wavelet_center_idx) * dt_s_param
plt.figure(figsize=(5,2))
plt.plot(wavelet_time_axis_s, sim_wavelet.cpu().numpy())
plt.title(f'Source Wavelet ({wavelet_peak_freq_param} Hz Ricker)')
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.grid(True)
plt.show()

# Source and Receiver Geometry
# One source at the surface (z = 0), in the middle of the survey line.
source_x_m_param = 0.5 * survey_width_m
source_z_m_param = 0.0
source_pos_m_param = (source_x_m_param, source_z_m_param)

# Receivers are evenly spaced along the surface between 5% and 95% of the
# survey width.
num_receivers_param = 64
receiver_x_coords_m_param = torch.linspace(
    survey_width_m * 0.05,
    survey_width_m * 0.95,
    num_receivers_param,
    device=device,
)
receiver_z_coords_m_param = torch.zeros_like(receiver_x_coords_m_param)  # all at z = 0
# One row per receiver, columns ordered (x, z).
receiver_pos_m_param = torch.stack(
    [receiver_x_coords_m_param, receiver_z_coords_m_param], dim=1
)

# Reconstruction Parameters
# Each value is forwarded to tv_reconstruction_seismic under the keyword
# named in its comment.
lambda_tv_seismic = 1e-2       # lambda_tv
seismic_pg_iterations = 30     # iterations
seismic_pg_step_size = 5e-3    # step_size
seismic_tv_prox_iters = 5      # tv_prox_iterations
seismic_tv_prox_step = 1e-2    # tv_prox_step_size

## 3. Create Subsurface Reflectivity Phantom

In [None]:
def generate_seismic_phantom(shape, device='cpu'):
    """Build a synthetic reflectivity map with flat layers and a dipping reflector.

    Args:
        shape: ``(Nz, Nx)`` size of the map (depth rows, horizontal columns).
        device: Torch device on which the phantom is allocated.

    Returns:
        ``torch.float32`` tensor of the given shape. It is zero everywhere
        except on three horizontal reflectors (0.5, -0.3 and a partial-width
        0.8) and on a two-pixel-thick dipping reflector of value 0.6, which is
        written last and therefore overwrites any layer it crosses.
    """
    nz, nx = shape
    phantom = torch.zeros(shape, dtype=torch.float32, device=device)

    # Horizontal layers as (row, column span, reflectivity), applied in order.
    horizontal_layers = (
        (nz // 3, slice(None), 0.5),
        (nz // 2, slice(None), -0.3),  # negative reflectivity contrast
        (nz * 3 // 4, slice(int(nx * 0.1), int(nx * 0.7)), 0.8),  # discontinuous layer
    )
    for row, col_span, value in horizontal_layers:
        phantom[row, col_span] = value

    # Dipping reflector (simple fault-like structure): starts at 20% depth and
    # descends half a pixel per column between 40% and 80% of the width.
    dip_first_col, dip_end_col = int(nx * 0.4), int(nx * 0.8)
    for offset, col in enumerate(range(dip_first_col, dip_end_col)):
        row = int(nz * 0.2 + offset * 0.5)
        if not (0 <= row < nz):
            continue
        # Two rows thick, clipped at the bottom edge of the map.
        phantom[row:min(row + 2, nz), col] = 0.6
    return phantom

seismic_phantom_reflectivity = generate_seismic_phantom(reflectivity_map_shape, device=device)

# Show the ground-truth map in physical units with the acquisition geometry
# overlaid; the inverted y-extent puts depth 0 at the top.
fig_phantom, ax_phantom = plt.subplots(figsize=(10, 5))
phantom_image = ax_phantom.imshow(
    seismic_phantom_reflectivity.cpu().numpy(),
    cmap='Greys',
    aspect='auto',
    extent=[0, survey_width_m, max_depth_m, 0],
)
ax_phantom.set_title('Original Subsurface Reflectivity Phantom')
ax_phantom.set_xlabel('Horizontal Distance (m)')
ax_phantom.set_ylabel('Depth (m)')
fig_phantom.colorbar(phantom_image, ax=ax_phantom, label='Reflectivity Contrast')

source_x_m, source_z_m = source_pos_m_param
ax_phantom.scatter([source_x_m], [source_z_m], c='red', marker='*', s=100, label='Source')
ax_phantom.scatter(
    receiver_pos_m_param[:, 0].cpu(),
    receiver_pos_m_param[:, 1].cpu(),
    c='blue', marker='v', s=30, label='Receivers',
)
ax_phantom.legend()
plt.show()

## 4. Initialize Seismic Forward Operator
The operator is initialized with the defined survey geometry, wavelet, and spreading options.

In [None]:
# Gather the survey description in one place, then build the forward operator.
seismic_operator_config = {
    'reflectivity_map_shape': reflectivity_map_shape,
    'wave_speed_mps': wave_speed_mps,
    'time_sampling_dt_s': dt_s_param,
    'num_time_samples': num_time_samples_val,
    'source_pos_m': source_pos_m_param,
    'receiver_pos_m': receiver_pos_m_param,
    'pixel_spacing_m': pixel_spacing_val,
    'source_wavelet': sim_wavelet,
    'wavelet_time_offset_s': sim_wavelet_offset_s,
    'apply_geometrical_spreading': use_spreading_param,
    'device': device,
}
seismic_operator_inst = SeismicForwardOperator(**seismic_operator_config)
print("SeismicForwardOperator initialized with wavelet and spreading.")

## 5. Simulate Seismic Traces (Seismogram)

In [None]:
print("Simulating seismic traces... This might take a moment.")
# Forward-model the phantom to obtain noise-free traces.
seismic_traces_clean = seismic_operator_inst.op(seismic_phantom_reflectivity)
print(f"Simulated clean seismic traces shape: {seismic_traces_clean.shape}")

# Add Gaussian noise
# Seed the RNG right before the draw so the noisy traces, and every result
# derived from them, are identical on each Restart & Run All.
noise_seed_seismic = 0
torch.manual_seed(noise_seed_seismic)
signal_mean_abs_seismic = torch.mean(torch.abs(seismic_traces_clean))
noise_level_seismic = 0.1 # 10% noise relative to mean signal magnitude
noise_std_seismic = noise_level_seismic * signal_mean_abs_seismic
noise_seismic = noise_std_seismic * torch.randn_like(seismic_traces_clean)
seismic_traces_noisy = seismic_traces_clean + noise_seismic
print(f"Added Gaussian noise. Noise STD: {noise_std_seismic.item() if noise_std_seismic > 0 else 0.0}")

# Symmetric colour limits at the 95th percentile of |amplitude| so a few
# strong samples do not wash out the rest of the seismogram.
seismogram_clip = torch.quantile(torch.abs(seismic_traces_noisy), 0.95).item()

fig_traces, ax_traces = plt.subplots(figsize=(10, 6))
traces_image = ax_traces.imshow(
    seismic_traces_noisy.cpu().numpy(),
    aspect='auto',
    cmap='seismic',
    extent=[0, max_record_time_s, num_receivers_param, 0],
    vmin=-seismogram_clip,
    vmax=seismogram_clip,
)
ax_traces.set_title('Noisy Seismic Traces (Seismogram)')
ax_traces.set_xlabel('Time (s)')
ax_traces.set_ylabel('Receiver Index')
fig_traces.colorbar(traces_image, ax=ax_traces, label='Amplitude')
plt.show()

## 6. Subsurface Image Reconstruction

### 6.1 Adjoint Reconstruction (Migration)

In [None]:
print("Performing Adjoint (Migration) reconstruction...")
# Apply the operator's adjoint to the noisy traces to get the migrated image.
seismic_migrated_image = seismic_operator_inst.op_adj(seismic_traces_noisy)
print(f"Migrated image shape: {seismic_migrated_image.shape}")

fig_migrated, ax_migrated = plt.subplots(figsize=(10, 5))
migrated_image_artist = ax_migrated.imshow(
    seismic_migrated_image.cpu().numpy(),
    cmap='Greys',
    aspect='auto',
    extent=[0, survey_width_m, max_depth_m, 0],
)
ax_migrated.set_title('Adjoint Seismic Reconstruction (Migrated Image)')
ax_migrated.set_xlabel('Horizontal Distance (m)')
ax_migrated.set_ylabel('Depth (m)')
fig_migrated.colorbar(migrated_image_artist, ax=ax_migrated, label='Amplitude')
plt.show()

### 6.2 TV Regularized Reconstruction

In [None]:
print(f"Performing TV Regularized Seismic Reconstruction (lambda_TV={lambda_tv_seismic})...This may take some time.")

# Solver settings collected from the parameter cell and passed by keyword.
tv_recon_settings = dict(
    lambda_tv=lambda_tv_seismic,
    iterations=seismic_pg_iterations,
    step_size=seismic_pg_step_size,
    tv_prox_iterations=seismic_tv_prox_iters,
    tv_prox_step_size=seismic_tv_prox_step,
    verbose=True,
)
seismic_tv_recon_map = tv_reconstruction_seismic(
    y_seismic_traces=seismic_traces_noisy,
    seismic_operator=seismic_operator_inst,
    **tv_recon_settings,
)
print(f"TV Reconstructed Seismic Map shape: {seismic_tv_recon_map.shape}")

fig_tv, ax_tv = plt.subplots(figsize=(10, 5))
tv_image_artist = ax_tv.imshow(
    seismic_tv_recon_map.cpu().numpy(),
    cmap='Greys',
    aspect='auto',
    extent=[0, survey_width_m, max_depth_m, 0],
)
ax_tv.set_title(f'TV Regularized Seismic Recon (lambda={lambda_tv_seismic}, {seismic_pg_iterations} iters)')
ax_tv.set_xlabel('Horizontal Distance (m)')
ax_tv.set_ylabel('Depth (m)')
fig_tv.colorbar(tv_image_artist, ax=ax_tv, label='Reflectivity')
plt.show()

## 7. Comparison of Results

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(20, 6))
fig.suptitle('Seismic Reconstruction Comparison', fontsize=16)
plot_kwargs = dict(cmap='Greys', aspect='auto', extent=[0, survey_width_m, max_depth_m, 0])

# One entry per panel, left to right: (image tensor, panel title, colorbar label).
comparison_panels = [
    (seismic_phantom_reflectivity, 'Original Reflectivity Phantom', 'Reflectivity'),
    (seismic_migrated_image, 'Adjoint Recon (Migrated)', 'Amplitude'),
    (seismic_tv_recon_map, f'TV Regularized Recon (lambda={lambda_tv_seismic})', 'Reflectivity'),
]

panel_images = []
for panel_ax, (panel_tensor, panel_title, panel_cbar_label) in zip(axes, comparison_panels):
    panel_im = panel_ax.imshow(panel_tensor.cpu().numpy(), **plot_kwargs)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Distance (m)')
    panel_ax.set_ylabel('Depth (m)')
    fig.colorbar(panel_im, ax=panel_ax, shrink=0.8, label=panel_cbar_label)
    panel_images.append(panel_im)
im0, im1, im2 = panel_images

# Leave the top 5% of the figure free for the suptitle.
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()