# TESTING NS FOR LINE FITTING

In [4]:
import numpyro as npro
npro.set_host_device_count(4)
import jax
jax.local_device_count()
import jax.numpy as jnp
import tqdm
import blackjax
import ravel
import pandas as pd
import os
from glob import glob
import matplotlib.pyplot as plt

In [38]:
from ravel import read_spectra, setup_star_directory_and_save_jds, setup_line_dictionary, initialize_fit_variables, rv_shift_wavelength
from ravel import gaussian, lorentzian, pseudo_voigt
import numpy as np
from scipy.interpolate import interp1d
from astropy.constants import c



In [73]:
# RAVEL.PY WORK ON fit_sb2_probmod
from tensorflow_probability.substrates import jax as tfp
import distrax


def fit_sb2_probmod(lines, wavelengths, fluxes, f_errors, lines_dic, Hlines, neblines, path, sigma_prior, K=2, shift_kms=0,
                    wavelength_type='air', rm_epochs=None, profile='Voigt', chi2_plots=False):
    """
    Fit SB2 (double-lined spectroscopic binary) spectral lines using a probabilistic
    model with Numpyro. The function interpolates spectral data onto a common grid,
    constructs a Bayesian model for the line profiles, and samples the posterior via
    MCMC (using NUTS). To improve robustness against switching, a second model is constructed
    that re-samples only the RV posteriors. The best-fitting result between both runs is 
    determined via a χ2 comparison.
    
    Parameters:
    -----------
    lines : list
        List of spectral line identifiers (keys from lines_dic) to be fitted.
    wavelengths : list
        List (per epoch) of wavelength arrays.
    fluxes : list
        List (per epoch) of flux arrays.
    f_errors : list
        List (per epoch) of flux error arrays.
    lines_dic : dict
        Dictionary containing spectral line regions, initial centre guesses, etc.
    Hlines : list
        List of lines (subset of `lines`) that are Hydrogen lines.
    neblines : list
        (Currently unused) List of nebular lines.
    path : str
        Path for storing output plots.
    sigma_prior : float
        Sigma on Gaussian priors for RV fitting in the second MCMC run [km/s].
    K : int, optional
        Number of components (default 2).
    shift_kms : float, optional
        The overall velocity shift in km/s. For example, use 172 km/s for the SMC.
    wavelength_type : str, optional
        Type of wavelength to use ('air' or 'vacuum'). Default is 'air'.
    rm_epochs : list, optional
        The indices (0-based) of the epochs to remove from the fitting.
    profile : str, optional ('Voigt' or 'Gaussian)
        Profile type used for all components in the spectral line fitting.
    chi2_plots : bool
        Enable/disable seperate plotting of first, second, and final (stitched) MCMC results.

    Returns:
    --------
    trace : dict
        The MCMC trace (posterior samples).
    x_waves : array (JAX)
        The interpolated wavelength grid for each line and epoch.
    y_fluxes : array (JAX)
        The interpolated fluxes.
    """
    n_lines = len(lines)
    n_epochs = len(wavelengths)
    print('Number of lines:', n_lines)
    print('Number of epochs:', n_epochs)

    # Determine the key to use based on the chosen wavelength type
    key = 'centre' if wavelength_type == 'vacuum' else 'air'

    # Boolean mask for Hydrogen lines (will use Lorentzian instead of Gaussian)
    is_hline = jnp.array([line in Hlines for line in lines])

    # Interpolate fluxes and errors to a common grid
    x_waves_interp = []
    y_fluxes_interp = []
    y_errors_interp = []
    common_grid_length = 200  # Choose a consistent number of points for interpolation

    for line in lines:
        region_start, region_end = lines_dic[line]['region']
        # Shift the region boundaries by shift_kms
        region_start = rv_shift_wavelength(region_start, shift_kms)
        region_end = rv_shift_wavelength(region_end, shift_kms)

        x_waves_line = []
        y_fluxes_line = []
        y_errors_line = []

        for wave_set, flux_set, error_set in zip(wavelengths, fluxes, f_errors):
            mask = (wave_set > region_start) & (wave_set < region_end)
            wave_masked = wave_set[mask]
            flux_masked = flux_set[mask]
            if error_set is not None:
                error_masked = error_set[mask]
            else:
                f_err = compute_flux_err(wave_set, flux_set)
                error_masked = f_err[mask]

            # Interpolate onto a common wavelength grid for this line and epoch
            common_wavelength_grid = np.linspace(wave_masked.min(), wave_masked.max(), common_grid_length)
            interp_flux = interp1d(wave_masked, flux_masked, bounds_error=False, fill_value="extrapolate")(common_wavelength_grid)
            interp_error = interp1d(wave_masked, error_masked, bounds_error=False, fill_value="extrapolate")(common_wavelength_grid)
            x_waves_line.append(common_wavelength_grid)
            y_fluxes_line.append(interp_flux)
            y_errors_line.append(interp_error)

        x_waves_interp.append(x_waves_line)
        y_fluxes_interp.append(y_fluxes_line)
        y_errors_interp.append(y_errors_line)

    # Convert the interpolated lists to JAX arrays (all dimensions now match)
    x_waves = jnp.array(x_waves_interp)       # Shape: (n_lines, n_epochs, common_grid_length)
    y_fluxes = jnp.array(y_fluxes_interp)       # Shape: (n_lines, n_epochs, common_grid_length)
    y_errors = jnp.array(y_errors_interp)       # Shape: (n_lines, n_epochs, common_grid_length)

    # Remove bad epochs along the second axis (axis=1)
    if rm_epochs is not None:
        x_waves = jnp.delete(x_waves, jnp.array(rm_epochs), axis=1)
        y_fluxes = jnp.delete(y_fluxes, jnp.array(rm_epochs), axis=1)
        y_errors = jnp.delete(y_errors, jnp.array(rm_epochs), axis=1)

    # Initial guess for the rest (central) wavelength from lines_dic
    cen_ini = jnp.array([lines_dic[line][key][0] for line in lines])

    # SET UP JAX PRIORS
    c_kms = c.to('km/s').value  
    nlines, nepochs, ndata = x_waves.shape

    # Set up priors
    tfd = tfp.distributions

    # Continuum
    # logσ_ε_prior = distrax.Uniform(-5,0)
    # σ_ε = jnp.exp(logσ_ε_prior)
    # ε_prior = tfd.TruncatedNormal(loc=1.0, scale=σ_ε, low=0.7, high=1.1)

    logσ_ε_prior = tfd.Uniform(low=-5.0, high=0.0)   # same support as before

    # Conditioned truncated normal for ε | logσ_ε
    def make_eps_prior_tfp(logσ_ε):
        return tfd.TruncatedNormal(loc=1.0, scale=jnp.exp(logσ_ε), low=0.7, high=1.1)

    # Define rest wavelengths as a parameter (one per line)
    λ_rest_prior = distrax.Independent(
    tfd.Normal(loc=cen_ini, scale = jnp.full(nlines, 1e-3)),
    reinterpreted_batch_ndims=1 # Does this make sense?
    )

    # AMPLITUDES (across all lines)
    amp0_prior = tfd.Independent(
    tfd.TruncatedNormal(loc=jnp.full(nlines, 0.18), # array of shape (nlines), defining the same dist.
                        scale=jnp.full(nlines, 0.06),
                        low=jnp.full(nlines, 0.02),
                        high=jnp.full(nlines, 0.40)),
                        1 # Absorb the above dimensions to evaluate as one log-prob (?)
    )

    amp_ratio_prior = tfd.Independent(
    tfd.TruncatedNormal(loc=jnp.full(nlines, 0.60), # array of shape (nlines), defining the same dist.
                        scale=jnp.full(nlines, 0.15),
                        low=jnp.full(nlines, 0.25),
                        high=jnp.full(nlines, 0.95)),
                        1 # Absorb the above dimensions to evaluate as one log-prob (?)
    )

    # WIDTHS (across all lines)
    # Gaussian
    wid_G1_prior = distrax.Independent(distrax.Uniform(low=jnp.full(nlines,0.5), high=jnp.full(nlines, 5.0)), 1)
    wid_L1_prior = distrax.Independent(distrax.Uniform(low=jnp.full(nlines,0.1), high=jnp.full(nlines, 3.0)), 1)
    # Lorentzian
    delta_wid_G_prior = distrax.Independent(distrax.Uniform(low=jnp.full(nlines,0.1), high=jnp.full(nlines, 2.0)), 1)
    delta_wid_L_prior = distrax.Independent(distrax.Uniform(low=jnp.full(nlines,0.05), high=jnp.full(nlines, 1.0)), 1)

    # RADIAL VELOCITIES (for each epcoh + component)
    σ_Δv = 200.
    comp_sep = 200.

    Δv_means  = jnp.array([shift_kms - comp_sep/2, shift_kms + comp_sep/2])   # shape (K,)
    Δv_loc   = jnp.tile(Δv_means[:, None], (1, nepochs))                               # (K,E)
    Δv_τk_prior = distrax.Independent(distrax.Normal(loc=Δv_loc, scale=jnp.full((K, nepochs), 200.0)), 2)

    def prior_sample(seed, sample_shape=()):
        k = jax.random.split(seed, 10)

        # TFP draws (truncated normals and uniform)
        logσ_ε  = logσ_ε_prior.sample(seed=k[0], sample_shape=sample_shape)                  # ()
        ε       = make_eps_prior_tfp(logσ_ε).sample(seed=k[1], sample_shape=sample_shape)        # ()
        amp0      = amp0_prior.sample(seed=k[2], sample_shape=sample_shape)                  # (L,)
        amp_ratio = amp_ratio_prior.sample(seed=k[3], sample_shape=sample_shape)             # (L,)

        # Distrax draws
        λ_rest = λ_rest_prior.sample(seed=k[4], sample_shape=sample_shape)                       # (L,)
        wid_G1 = wid_G1_prior.sample(seed=k[5], sample_shape=sample_shape)                       # (L,)
        wid_L1 = wid_L1_prior.sample(seed=k[6], sample_shape=sample_shape)                       # (L,)
        delta_wid_G = delta_wid_G_prior.sample(seed=k[7], sample_shape=sample_shape)                       # (L,)
        delta_wid_L = delta_wid_L_prior.sample(seed=k[8], sample_shape=sample_shape)                       # (L,)
        Δv_τk  = Δv_τk_prior.sample(seed=k[9], sample_shape=sample_shape)                        # (K,E)

        return {
            "logσ_ε": logσ_ε, "ε": ε, "λ_rest": λ_rest,
            "amp0": amp0, "amp_ratio": amp_ratio,
            "wid_G1": wid_G1, "wid_L1": wid_L1, "delta_wid_G": delta_wid_G, "delta_wid_L": delta_wid_L,
            "Δv_τk": Δv_τk,
        }

    def prior_log_prob(params):
        lp  = logσ_ε_prior.log_prob(params["logσ_ε"])
        # lp += make_eps_prior(params["logσ_ε"]).log_prob(params["ε"])
        lp += make_eps_prior_tfp(params["logσ_ε"]).log_prob(params["ε"])
        lp += amp0_prior.log_prob(params["amp0"])
        lp += amp_ratio_prior.log_prob(params["amp_ratio"])
        lp += λ_rest_prior.log_prob(params["λ_rest"])
        lp += wid1_G_prior.log_prob(params["wid_G1"])
        lp += wid1_L_prior.log_prob(params["wid_L1"])
        lp += dwid_G_prior.log_prob(params["delta_wid_G"])
        lp += dwid_L_prior.log_prob(params["delta_wid_L"])
        lp += Δv_τk_prior.log_prob(params["Δv_τk"])
        return lp


    # Prior over MATRIX of sampled variables, of size K,E - RV for each component, each epochs (across all lines)
    # Prior centered on Δv_means for each component, each epoch

    def flux_pred(params, lam_grid, is_hline, profile="Voigt"):
        """
        params: dict of JAX arrays
        - ε: scalar
        - λ_rest: (L,)
        - amp0, amp_ratio: (L,)
        - wid1_G, wid1_L, dwid_G, dwid_L: (L,)
        - Δv_τk: (K,E)   # component RVs per epoch
        lam_grid: (L,E,N)
        is_hline: (L,)
        returns: f_pred (L,E,N)
        """
        c_kms = c.to('km/s').value  
        L, E, N = lam_grid.shape
        ε = params["ε"]

        # amplitudes for the two components → (K,L,1,1)
        amp0 = params["amp0"]
        amp1 = params["amp_ratio"] * amp0
        A = jnp.stack([amp0, amp1], axis=0)[:, :, None, None]

        # widths → (K,L,1,1)
        wid_G1, wid_L1 = params["wid_G1"], params["wid_L1"]
        delta_wid_G, delta_wid_L = params["delta_wid_G"], params["delta_wid_L"]
        wid_G = jnp.stack([wid_G1, wid_G1+delta_wid_G], axis=0)[:, :, None, None]
        wid_L = jnp.stack([wid_L1, wid_L1+delta_wid_L], axis=0)[:, :, None, None]

        # centers per component & epoch
        lam0 = params["λ_rest"][None, :, None, None]              # (1,L,1,1)
        dv   = params["Δv_τk"][:, None, :, None]                  # (K,1,E,1)
        mu   = lam0 * (1.0 + dv / c_kms)                          # (K,L,E,1)

        lam  = lam_grid[None, :, :, :]                            # (1,L,E,N)
        isH  = is_hline[None, :, None, None]                      # (1,L,1,1)

        # your existing kernels must be JAX-friendly (no Python loops, no side effects)
        G  = gaussian(lam, A, mu, wid_G)                          # (K,L,E,N)
        Lr = lorentzian(lam, A, mu, wid_L)                        # (K,L,E,N)
        V  = pseudo_voigt(lam, A, mu, wid_G, wid_L)               # (K,L,E,N)

        if profile == "Voigt":
            comp = jnp.where(isH, Lr, V)
        elif profile == "Gaussian":
            comp = jnp.where(isH, Lr, G)
        else:
            raise ValueError("profile must be 'Voigt' or 'Gaussian'")

        f_pred = ε + comp.sum(axis=0)                           # (L,E,N)
        return f_pred

    def loglikelihood_fn(params):
        f_pred = flux_pred(params, x_waves, is_hline, profile)     # (L,E,N)
        resid  = (y_fluxes - f_pred) / y_errors
        log_like = -jnp.log(y_errors * jnp.sqrt(2*jnp.pi)) -0.5*resid**2
        return log_like
        
        
    print(f"\nFitting with profile: {profile}, Δv_means: {Δv_means}")

    # Initial population of live points
    rng_key  = jax.random.PRNGKey(0)
    n_live = 1200
    initial_population = prior_sample(rng_key, sample_shape=(n_live,))

    d_guess = (1 + 1) + 3*nlines + 4*nlines + (K*nepochs)

    algo = blackjax.nss(
        logprior_fn      = prior_log_prob,
        loglikelihood_fn = loglikelihood_fn,
        num_delete       = int(n_live*0.1),
        num_inner_steps  = 3 * d_guess
    )

    state = algo.init(initial_population)

    # NESTED SAMPLING LOOP
    from blackjax.ns.base import NSState
    from blackjax.ns.base import NSInfo

    @jax.jit
    def one_step(carry, xs):
        state, k = carry
        k, subk = jax.random.split(k, 2)
        state, dead_point = algo.step(subk, state)
        return (state, k), dead_point

    iteration_count = 0
    live_points_snapshots = []
    dead = []

    with tqdm.tqdm(desc="Dead points", unit=" dead points") as pbar:
        live_points_snapshots.append(jnp.array(state.particles[..., :2]))
        
        while not state.logZ_live - state.logZ < -5:
            iteration_count += 1
            (state, rng_key), dead_info = one_step((state, rng_key), None)
            dead.append(dead_info)
            pbar.update(num_delete)

            live_points_snapshots.append(jnp.array(state.particles[..., :2]))

    live_points_snapshots.append(jnp.array(state.particles[..., :2]))

    # POST-PROCESSING
    from blackjax.ns.utils import log_weights, finalise, sample, ess

    rng_key, weight_key, sample_key = jax.random.split(rng_key,3)
    final_state = finalise(state,dead)
    log_w = log_weights(weight_key, final_state, shape=100)
    samples = sample(sample_key, final_state, shape = n_live)
    ns_ess = ess(sample_key, final_state)
    logzs = jax.scipy.special.logsumexp(log_w, axis=0)

    final_state = finalise(state, dead)
    rng_key, sk = jax.random.split(rng_key)
    posterior   = sample(sk, final_state, shape=num_live)

    return final_state, posterior, live_points_snapshots




    # state = jax.jit(ns.init)(particles)
    # step  = jax.jit(ns.step)

    # dead = []
    # # Fixed-iteration loop for v1 (simple & robust). Improve later with a live-evidence stop rule.
    # for _ in range(1000):
    #     rng_key, subk = jax.random.split(rng_key)
    #     state, dead_info = step(subk, state)
    #     dead.append(dead_info)

    # # Finalise & draw posterior samples (dict; same keys as your params)
    # final_state = finalise(state, dead)
    # rng_key, sk = jax.random.split(rng_key)
    # posterior   = sample(sk, final_state, shape=num_live)

    # # Evaluate model for each posterior sample
    # def _predict_one(p): return flux_pred(p, x_waves, is_hline, profile)
    # f_post  = jax.vmap(_predict_one)(posterior)   # (S,L,E,N), S=num_live
    # Δv_post = posterior["Δv_τk"]                  # (S,K,E)

    # # Build your representative per-epoch model (same logic you already use)
    # S = f_post.shape[0]
    # n_sols = min(200, S)
    # model_result = jnp.empty_like(x_waves)

    # for e in range(E):
    #     rv1 = jnp.asarray(Δv_post[:, 0, e])                 # component-1 RVs at epoch e
    #     center, _, _ = summarize_mode_1d(np.asarray(rv1))   # your helper
    #     idx = np.argsort(np.abs(np.asarray(rv1) - center))[:n_sols]
    #     model_result = model_result.at[:, e, :].set(f_post[idx, :, e, :].mean(axis=0))

    # chi2_final = get_chi2(y_fluxes, model_result, y_errors)

    # return 

In [74]:
sys_path = 'SB2_sys019_snr30/'
spec_files = sorted(glob(os.path.join(sys_path, "*epoch??.txt")))
SB2 = True

out_dir_requested = os.path.join('Sample_fit_ns', sys_path)
os.makedirs(out_dir_requested, exist_ok=True)

Hlines = [4102, 4340, 4861, 6562]
print('*** SB2 set to:', 'True', '***\n')

# Read in spectral data from the provided file list and data path
wavelengths, fluxes, f_errors, names, jds = read_spectra(spec_files, sys_path, 'txt', SB2=True)

# Setup the output directory and save the JD information if available
out_path = setup_star_directory_and_save_jds(names, jds, sys_path, SB2)

# Get the dictionary with line regions and initial parameters
lines = [4026, 4144, 4388, 4471]
lines_dic = setup_line_dictionary()

# Verify that user‑requested lines exist in the dictionary
missing = [ln for ln in lines if ln not in lines_dic]
if missing:
    print("Error: Unknown spectral line identifier(s):", missing)
    print("Available lines are:", sorted(lines_dic.keys()))
    raise ValueError(f"Please choose from the available lines or add your own. Missing: {missing}")

print('\n*** Fitting lines ***')
print('---------------------')
print('Lines to be fitted:', lines)

# Initialize fit variables for each line (lists to hold fit results, uncertainties, etc.)
(cen1, cen1_er, amp1, amp1_er, wid1, wid1_er, 
    cen2, cen2_er, amp2, amp2_er, wid2, wid2_er, 
    dely, sdev, results, comps, delta_cen, chisqr) = initialize_fit_variables(lines)


# SB2 fitting: fit all lines using the probabilistic SB2 model and write results to CSV
# final_state, posterior, live_points_snapshots = fit_sb2_probmod(lines, wavelengths, fluxes, f_errors, lines_dic,
#                                             Hlines, neblines=[], out_path, K=2, shift_kms=0)

final_state, posterior, live_points_snapshots = fit_sb2_probmod(lines, wavelengths, fluxes, f_errors, lines_dic, Hlines, neblines=1, path=1, sigma_prior=1, K=2, shift_kms=0,
                    wavelength_type='air', rm_epochs=None, profile='Voigt', chi2_plots=False)


*** SB2 set to: True ***


*** Fitting lines ***
---------------------
Lines to be fitted: [4026, 4144, 4388, 4471]
Number of lines: 4
Number of epochs: 12

Fitting with profile: Voigt, Δv_means: [-100.  100.]


  return lax_numpy.astype(self, dtype, copy=copy, device=device)
  return lax_numpy.astype(self, dtype, copy=copy, device=device)
  return lax_numpy.astype(self, dtype, copy=copy, device=device)
  return lax_numpy.astype(self, dtype, copy=copy, device=device)
  return lax_numpy.astype(self, dtype, copy=copy, device=device)


ValueError: Incompatible shapes for broadcasting: shapes=[(1200,), (4, 12, 200)]

In [None]:
distrax.Independent