# Walkthrough: Geometric Scattering Analysis (`run_scat_analysis.py`)

This notebook provides a step-by-step visual demonstration of the FLITS **Geometric Scattering Pipeline**. 

We will interactively walk through the main steps that `run_scat_analysis.py` performs under the hood (some in simplified form):
1.  **Load Data**: Read filterbank data for a burst.
2.  **Preprocessing**: Dedispersion and downsampling.
3.  **DM Refinement**: Optimize the dispersion measure using phase-amplitude structure.
4.  **Modeling**: Construct the Pulse Broadening Function (PBF) model.
5.  **Fitting**: Run the MCMC sampler to fit scattering parameters ($\tau$, $\alpha$).
6.  **Diagnostics**: Visualize the residuals and corner plots.

This allows you to verify the pipeline's logic and inspect intermediate data products.

In [None]:
import os
import yaml
import numpy as np
import matplotlib.pyplot as plt
import astropy.units as u

# Import FLITS Scattering modules
from scattering.scat_analysis import load_data, preprocess
from scattering.scat_analysis.models import PulseBroadeningModel
from scattering.scat_analysis.likelihood import GaussianLikelihood
from scattering.scat_analysis.dm_refinement import maximize_structure_dm

# Set plotting style
# NOTE(review): the 'seaborn-v0_8-*' style names appear to exist only in newer
# matplotlib releases (older ones call this 'seaborn-darkgrid') -- confirm
# against the matplotlib version this notebook is run with.
plt.style.use('seaborn-v0_8-darkgrid')
# IPython magic (not plain Python): render figures inline in the notebook.
%matplotlib inline

## 1. Configuration & Data Loading

We start by loading a standard configuration file used by the pipeline. For this demo, we'll use **Freya** (or another available burst) as an example.

In [None]:
# Pipeline configuration for the demo burst (Freya), from the `dsa` config directory.
config_path = 'scattering/configs/dsa/freya_dsa.yaml'

# safe_load builds only plain Python objects (dicts, lists, scalars) from the YAML.
with open(config_path) as cfg_file:
    cfg = yaml.safe_load(cfg_file)

# Echo the settings that the later cells depend on.
for label, key in (("Target Burst", 'burst_name'),
                   ("Data Path", 'data_path'),
                   ("Initial DM", 'dm_init')):
    print(f"{label}: {cfg[key]}")

In [None]:
# Load the raw filterbank data.
# (Note: this might take a few seconds depending on file size.)
# t_start falls back to 0 and t_duration to None when the config omits them;
# how None is interpreted is up to load_filterbank.
raw_data, metadata = load_data.load_filterbank(
    cfg['data_path'],
    t_start=cfg.get('t_start', 0),
    t_duration=cfg.get('t_duration', None),
)

# Centre frequencies of the first and last channel. Channel k sits at
# fch1 + k * df, so the last channel is (nchans - 1) widths from fch1
# (fch1 + nchans * df would be one channel past the end of the band).
# If df is negative (descending band) the range prints high-to-low.
f_first = metadata['fch1']
f_last = metadata['fch1'] + (metadata['nchans'] - 1) * metadata['df']

print(f"Raw Data Shape: {raw_data.shape} (Time x Freq)")
print(f"Frequency Range: {f_first} - {f_last} MHz")

# Visualize raw data (waterfall): transpose so time runs along x and
# channel index along y.
plt.figure(figsize=(10, 6))
plt.imshow(raw_data.T, aspect='auto', origin='lower', cmap='viridis', interpolation='nearest')
plt.colorbar(label='Intensity')
plt.title("Raw Filterbank Data (Dispersed)")
plt.xlabel("Time Samples")
plt.ylabel("Frequency Channels")
plt.show()

## 2. Dedispersion & Preprocessing

To analyze the pulse profile, we must first remove the frequency-dependent dispersive delay imposed by free electrons along the line of sight (e.g. in the interstellar medium, ISM). We use the initial DM from the config.

In [None]:
# Dedisperse at the starting DM taken from the config.
dm_init = cfg['dm_init']
dedispersed_data = preprocess.dedisperse(raw_data, metadata, dm_init)

# Zoom in on the pulse: keep a window of samples centred on the peak of the
# frequency-averaged time series, clipped at the edges of the array.
window_size = 1000  # samples
half_window = window_size // 2
peak_idx = np.argmax(np.mean(dedispersed_data, axis=1))
start = max(0, peak_idx - half_window)
end = min(dedispersed_data.shape[0], peak_idx + half_window)

cropped_data = dedispersed_data[start:end, :]
# Time since the start of the window, in ms (assumes metadata['dt'] is in seconds).
time_axis = np.arange(cropped_data.shape[0]) * metadata['dt'] * 1e3

# Two stacked panels: dedispersed waterfall on top, averaged profile below.
# NOTE(review): with origin='lower' the bottom row is channel 0 and is labelled
# metadata['fmin']; if channel 0 is actually the highest frequency (negative
# channel width), the frequency axis is drawn upside down -- confirm the
# loader's channel ordering.
fig, (ax_wf, ax_prof) = plt.subplots(2, 1, figsize=(10, 8))
ax_wf.imshow(cropped_data.T, aspect='auto', extent=[time_axis[0], time_axis[-1], metadata['fmin'], metadata['fmax']], origin='lower')
ax_wf.set_title(f"Dedispersed Data (DM={dm_init})")
ax_wf.set_ylabel("Freq (MHz)")

profile = np.mean(cropped_data, axis=1)
ax_prof.plot(time_axis, profile)
ax_prof.set_title("Frequency-Averaged Profile")
ax_prof.set_xlabel("Time (ms)")
fig.tight_layout()
plt.show()

## 3. DM Refinement (Optional but Critical)

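Errors in DM can mimic scattering: a residual DM error leaves a frequency-dependent delay that smears the frequency-averaged pulse in time. The pipeline therefore optionally refines the DM by maximizing the pulse "structure" (sharpness); in this demo the structure metric is the sum of squares of the frequency-averaged profile.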

In [None]:
# Whether the pipeline config asks for DM refinement.
should_refine = cfg.get('refine_dm', False)
# Demo override: run the scan even when the config leaves refinement off, so
# the diagnostic plot is always shown. Set to False to follow the config.
show_scan_anyway = True

# Trial-DM grid: n_dm_trials points spanning dm_init +/- dm_half_width.
dm_half_width = 1.0
n_dm_trials = 50

if should_refine or show_scan_anyway:
    print("Running DM Refinement Scan...")
    dm_grid = np.linspace(dm_init - dm_half_width, dm_init + dm_half_width, n_dm_trials)

    # Structure metric: sum of squares of the frequency-averaged profile.
    # A sharper (less smeared) pulse gives a larger value, so the scan keeps
    # the trial DM that maximizes it.
    structure_metrics = []
    for dm_trial in dm_grid:
        prof = np.mean(preprocess.dedisperse(raw_data, metadata, dm_trial), axis=1)
        structure_metrics.append(np.sum(prof**2))

    best_dm_idx = np.argmax(structure_metrics)
    best_dm = dm_grid[best_dm_idx]

    plt.figure()
    plt.plot(dm_grid, structure_metrics, '-o')
    plt.axvline(best_dm, color='r', linestyle='--', label=f'Best DM={best_dm:.3f}')
    plt.xlabel("DM")
    plt.ylabel("Structure Metric")
    plt.legend()
    plt.title("DM Refinement")
    plt.show()

    # Re-dedisperse at the best DM and crop to the same window as before.
    final_data = preprocess.dedisperse(raw_data, metadata, best_dm)[start:end, :]
else:
    # Refinement disabled: keep the initial DM and the already-cropped data
    # so that later cells still find best_dm and final_data defined.
    best_dm = dm_init
    final_data = cropped_data

## 4. Modeling: The Pulse Broadening Function (PBF)

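The core physics is the convolution model:
$$ I(t, \nu) = \left(S(t) \ast \text{PBF}(t, \nu)\right) + N(t, \nu) $$

where $S(t)$ is the intrinsic pulse shape, $\ast$ denotes convolution in time, and $N(t, \nu)$ is the noise. The PBF is a one-sided exponential decay,
$$ \text{PBF}(t, \nu) = \frac{1}{\tau(\nu)} \exp\!\left(-\frac{t}{\tau(\nu)}\right) \Theta(t), $$
with $\Theta$ the Heaviside step function, whose timescale $\tau$ depends on frequency as a power law:
$$ \tau(\nu) = \tau_{\text{ref}} \left( \frac{\nu}{\nu_{\text{ref}}} \right)^{-\alpha} $$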

Let's visualize this model for specific parameters.

In [None]:
# Channel centre frequencies: channel k sits at fch1 + k * df. (A linspace from
# fch1 to fch1 + nchans*df with nchans points would space channels by
# nchans*df/(nchans-1) rather than df, mislabelling every channel but the first.)
channel_freqs = metadata['fch1'] + metadata['df'] * np.arange(metadata['nchans'])
if channel_freqs.size != final_data.shape[1]:
    # Fail loudly rather than build a model whose frequency axis does not
    # line up with the data it will be compared against.
    raise ValueError(
        f"Frequency axis mismatch: metadata has {channel_freqs.size} channels "
        f"but final_data has {final_data.shape[1]}"
    )

# Initialize Model
model = PulseBroadeningModel(
    n_freq=final_data.shape[1],
    n_time=final_data.shape[0],
    freqs=channel_freqs,
    dt=metadata['dt'],
)

# Define parameters for visualization.
# NOTE(review): these times are in ms, while dt is passed straight from
# metadata['dt'] -- confirm which unit PulseBroadeningModel expects.
params = {
    'tau_ref': 2.0,          # ms at reference freq
    'alpha': 4.0,            # scattering index: tau scales as nu**(-alpha)
    'intrinsic_width': 0.5,  # ms
    't0': 5.0,               # arrival time offset (ms)
    'amplitude': 10.0,
}

# Generate the model waterfall; presumably (time, freq) like final_data,
# hence the transpose so time runs along x.
model_waterfall = model.generate(params)

plt.figure(figsize=(10, 6))
plt.imshow(model_waterfall.T, aspect='auto', origin='lower',
           extent=[time_axis[0], time_axis[-1], metadata['fmin'], metadata['fmax']],
           cmap='inferno')
plt.colorbar(label='Model Intensity')
plt.title(f"Scattering Model (τ={params['tau_ref']}ms, α={params['alpha']})")
plt.xlabel("Time (ms)")
plt.ylabel("Frequency (MHz)")
plt.show()

## 5. Fitting (Interactive MCMC Demo)

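The full script uses `emcee` to sample the posterior. Here, we'll demonstrate the likelihood calculation with a simple one-dimensional scan: we evaluate the log-probability on a grid of $\tau$ values with $\alpha$ held fixed and report the best grid point.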

In [None]:
# Likelihood Function: compares final_data against the PBF model.
likelihood = GaussianLikelihood(data=final_data, model=model)

def ln_prob(theta):
    """Return the log-probability of a ``(tau_ref, alpha)`` pair.

    All other model parameters are taken from the module-level ``params``
    dict, which is copied rather than modified. Returns ``-np.inf`` if
    ``likelihood.evaluate`` raises ``ValueError``.
    """
    tau, alpha = theta
    p = params.copy()
    p['tau_ref'] = tau
    p['alpha'] = alpha

    try:
        return likelihood.evaluate(p)
    except ValueError:
        return -np.inf

# "Fit" by scanning tau_ref on a grid with alpha held fixed: a 1-D slice
# through the likelihood surface, standing in for the full MCMC.
alpha_fixed = 4.0
tau_scan = np.linspace(0.1, 5.0, 20)
log_probs = np.array([ln_prob([tau, alpha_fixed]) for tau in tau_scan])

# Report the best grid point among finite values (ln_prob returns -inf for
# parameters the likelihood rejects).
finite = np.isfinite(log_probs)
if finite.any():
    best_tau = tau_scan[finite][np.argmax(log_probs[finite])]
    print(f"Best tau_ref on grid: {best_tau:.3f} ms (alpha fixed at {alpha_fixed})")
else:
    best_tau = None
    print("No finite log-probability on the tau grid.")

plt.figure()
plt.plot(tau_scan, log_probs)
if best_tau is not None:
    plt.axvline(best_tau, color='r', linestyle='--', label=f'Best tau={best_tau:.3f} ms')
    plt.legend()
plt.xlabel("Tau (ms)")
plt.ylabel("Log Probability")
plt.title(f"Likelihood Slice (Assuming Alpha={alpha_fixed})")
plt.show()

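## 6. Examining Results (Mock of a Saved Run)

Finally, `run_scat_analysis.py` saves the full MCMC chains and results, which can be loaded from a results database to inspect the fits. Here we mock that step with hard-coded best-fit values and plot the residuals of the corresponding model.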

In [None]:
# (Mockup of loading results: these values are hard-coded, not read from disk.)
best_fit_params = {
    'tau_ref': 2.8,
    'alpha': 3.9,
    'dm': best_dm,
}

print("Final Best Fit Parameters:")
for k, v in best_fit_params.items():
    print(f"  {k}: {v}")

# Residual plot. 'dm' is a preprocessing quantity, not one of the model
# parameters defined in `params`, so keep it out of the dict given to generate().
model_params = {k: v for k, v in best_fit_params.items() if k != 'dm'}
best_model = model.generate({**params, **model_params})
residuals = final_data - best_model

# Centre the diverging colormap on zero so positive and negative residuals
# share one colour scale (skipped if the residuals are all zero or non-finite).
vmax = np.nanmax(np.abs(residuals))
color_limits = {'vmin': -vmax, 'vmax': vmax} if np.isfinite(vmax) and vmax > 0 else {}

plt.figure(figsize=(10, 4))
plt.imshow(residuals.T, aspect='auto', origin='lower', cmap='RdBu_r',
           extent=[time_axis[0], time_axis[-1], metadata['fmin'], metadata['fmax']],
           **color_limits)
plt.colorbar(label='Residuals')
plt.title("Best Fit Residuals")
plt.xlabel("Time (ms)")
plt.ylabel("Frequency (MHz)")
plt.show()