In [4]:
# --- Celda 1: imports y helpers de IO ---
import os
import pickle
import numpy as np
from cluster_analysis import ClusterAnalyzer
from tqdm.notebook import tqdm
import pytensor.tensor as pt
from types import MethodType

# helpers
def save_npz(path, **arrays):
    """Save named arrays to a compressed ``.npz`` archive, creating parent dirs.

    Parameters
    ----------
    path : str or path-like
        Destination file (``np.savez_compressed`` appends ``.npz`` if missing).
    **arrays
        Arrays to store, keyed by the name they will have inside the archive.
    """
    parent = os.path.dirname(path)
    # A bare filename has no parent directory; os.makedirs("") would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)
    np.savez_compressed(path, **arrays)

def save_pkl(path, obj):
    """Pickle ``obj`` to ``path`` (highest protocol), creating parent dirs.

    Parameters
    ----------
    path : str or path-like
        Destination file.
    obj : any picklable object
    """
    parent = os.path.dirname(path)
    # A bare filename has no parent directory; os.makedirs("") would raise.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)

def load_pkl(path):
    """Load and return the object pickled in ``path``.

    Unpickling can execute arbitrary code: only use this on trusted files
    (e.g. caches written by ``save_pkl``).
    """
    with open(path, mode="rb") as handle:
        loaded = pickle.load(handle)
    return loaded

# chequeos rápidos de reproducibilidad
def fingerprint(arr, k=5):
    """Cheap reproducibility fingerprint of an array (not a cryptographic hash).

    Parameters
    ----------
    arr : array_like
        Array to summarise.
    k : int
        Number of leading elements, in row-major flattened order, to keep.

    Returns
    -------
    dict
        ``shape``: the original shape of ``arr`` (before flattening);
        ``dtype``: the dtype as a string;
        ``head``: a copy of the first ``k`` flattened values.
    """
    a = np.asarray(arr)
    # Copy the head: a view would silently change if the source array were
    # later modified in place, defeating the purpose of a fingerprint.
    head = a.ravel()[:k].copy()
    return dict(shape=a.shape, dtype=str(a.dtype), head=head)

In [3]:
# --- Cell 2: random generator ---
# Fixed seed so every pool drawn from `rng` later on is reproducible.
rng = np.random.default_rng(seed=42)

In [4]:
# Observed catalogue (absolute local path: adjust on another machine).
file_path = os.path.join(
    "/Users/notluquis/Library/Mobile Documents/com~apple~CloudDocs/Investigación/COSMIC/COSMIC",
    "data.ecsv",
)

ca = ClusterAnalyzer(file_path)

DataLoader initialized with file path: /Users/notluquis/Library/Mobile Documents/com~apple~CloudDocs/Investigación/COSMIC/COSMIC/data.dill
Masked data handled and replaced with NaN.


In [6]:
# Per-cluster summary table (count, membership-probability stats, proper-motion centroid, ...).
ca.clusters_summary(include_noise=True)

Unnamed: 0,cluster,count,fraction,persistence,mean_prob,median_prob,min_prob,max_prob,iqr_prob,centroid_pmra,centroid_pmdec,mean_dist2centroid,std_dist2centroid,pmra_range,pmdec_range
0,12,331,1.0,,0.780547,0.807082,0.500444,1.0,0.317599,2.514274,-1.699468,0.221924,0.102786,0.903669,0.733876


In [6]:
import asteca

# MIST isochrones in Gaia EDR3 photometry: G magnitude and BP-RP colour.
# A second colour (RP-J) can be enabled by adding
# color2=("Gaia_RP_EDR3", "2MASS_J") and color2_effl=(7825.1, 12375.60).
isoch_settings = dict(
    isochs_path="/Users/notluquis/Library/Mobile Documents/com~apple~CloudDocs/Investigación/NGC6383/MIST",
    model='MIST',
    magnitude="Gaia_G_EDR3",
    magnitude_effl=6390.7,
    color=("Gaia_BP_EDR3", "Gaia_RP_EDR3"),
    color_effl=(5182.6, 7825.1),
)
isochs = asteca.Isochrones(**isoch_settings)


Instantiating isochrones



KeyboardInterrupt


KeyboardInterrupt



In [None]:
# Alias the analyzer's catalogue table (no explicit copy) for building the asteca.Cluster.
df = ca.data

In [None]:
# Observed cluster for ASteCA: astrometry, photometry and their uncertainties.
# (A second colour, RP-J, could be passed via color2 / e_color2.)
bp_rp = df["G_BPmag"] - df["G_RPmag"]
my_cluster = asteca.Cluster(
    ra=df["ra"], dec=df["dec"],
    pmra=df["pmra"], pmde=df["pmdec"], plx=df["parallax"],
    e_pmra=df["pmra_error"], e_pmde=df["pmdec_error"], e_plx=df["parallax_error"],
    magnitude=df["Gmag"], e_mag=df["e_Gmag"],
    color=bp_rp, e_color=df["e_BP_RP"],
)

In [None]:
# Synthetic-cluster generator built on the loaded isochrone set.
synthcl = asteca.Synthetic(isochs)

# Calibrate against the observed cluster. generate_2 (below) reads
# max_mag_syn_obs / N_stars_obs / err_dist_obs and falls back to defaults when
# they are missing; presumably calibrate() is what sets them (confirm in asteca).
synthcl.calibrate(my_cluster)

In [None]:
# --- shift isochrone by the distance modulus (PyTensor) ---
def move_isochrone_pt(iso_np, m_ini_idx, dm, binar_flag=True):
    """
    Add the distance modulus to the magnitude row(s) of an isochrone.

    Parameters
    ----------
    iso_np : np.ndarray, shape (Nd, Ni)
        Isochrone, rows = features. Not modified.
    m_ini_idx : int-like
        Row of the initial mass; the binary magnitude sits at ``m_ini_idx + 2``.
    dm : float
        Distance modulus.
    binar_flag : bool
        Also shift the binary magnitude row, when that row exists.

    Returns
    -------
    np.ndarray (float64) with row 0 (and optionally the binary row) shifted.
    """
    bin_row = int(m_ini_idx) + 2  # force a Python int for indexing
    tensor = pt.as_tensor_variable(np.asarray(iso_np, dtype=np.float64))
    shift = pt.as_tensor_variable(float(dm), dtype="float64")

    rows_to_shift = [0]
    if binar_flag and bin_row < iso_np.shape[0]:
        rows_to_shift.append(bin_row)

    for row in rows_to_shift:
        tensor = pt.set_subtensor(tensor[row, :], tensor[row, :] + shift)

    # Evaluate the symbolic graph back to NumPy.
    return np.asarray(tensor.eval(), dtype=np.float64)

def cut_max_mag_pt(iso_np: np.ndarray, max_mag_syn: float) -> np.ndarray:
    """
    Keep only the isochrone columns whose magnitude (row 0) is < ``max_mag_syn``.

    The comparison is built symbolically in PyTensor and evaluated. The final
    column selection is done in NumPy, to avoid symbolic boolean indexing.
    """
    below_limit = pt.lt(pt.as_tensor_variable(iso_np)[0, :], max_mag_syn)
    keep = np.asarray(below_limit.eval(), dtype=bool)
    return iso_np[:, keep]

def extinction_pt(
    ext_law: str,
    ext_coefs: list,
    rand_norm: np.ndarray,
    rand_unif: np.ndarray,
    DR_distribution: str,
    m_ini_idx: int,
    binar_flag: bool,
    Av: float,
    dr: float,
    Rv: float,
    iso_np: np.ndarray,
) -> np.ndarray:
    """
    Apply extinction (optionally with differential reddening) to an isochrone.

    Despite the ``_pt`` suffix this is plain NumPy. It mirrors the steps of
    ``asteca.modules.synth_cluster_priv.extinction``, but works on a copy, so
    ``iso_np`` is never modified.

    Parameters
    ----------
    ext_law : {"CCMO", "GAIADR3"}
        "CCMO" uses the ``a + b / Rv`` coefficients in ``ext_coefs``.
        "GAIADR3" gets colour-dependent coefficients from ``scp.dustapprox``.
    ext_coefs : list
        CCMO coefficients ``[[a, b], [[a1, b1], [a2, b2]]]``, plus an optional
        third entry for colour 2 with the same layout as the second.
        Unused for GAIADR3.
    rand_norm, rand_unif : np.ndarray
        Pre-drawn random pools; their first ``Ns`` values feed the DR draw.
    DR_distribution : {"uniform", "normal"}
        Differential-reddening distribution; only checked when ``dr > 0``.
    m_ini_idx : int
        Row of the initial mass. Binary rows are ``m_ini_idx + 2`` (mag),
        ``+ 3`` (colour 1) and ``+ 4`` (colour 2). They are shifted only when
        ``binar_flag`` is true and the row exists.
    binar_flag : bool
    Av, dr, Rv : float
        Extinction, DR amplitude, and the R_V parameter of the law.
    iso_np : np.ndarray
        Isochrone, rows = features, columns = points.

    Returns
    -------
    np.ndarray
        A new array with extinction applied.

    Raises
    ------
    ValueError
        Unknown ``ext_law``, or unknown ``DR_distribution`` when ``dr > 0``.
    """
    # Validate up front so the no-extinction shortcut below cannot hide a typo.
    if ext_law not in ("CCMO", "GAIADR3"):
        raise ValueError(f"Unknown extinction law: {ext_law}")

    # Indices always as Python ints (callers may pass NumPy integers).
    m_ini_idx = int(m_ini_idx)
    idx_bin_mag = m_ini_idx + 2
    idx_bin_c1 = m_ini_idx + 3
    idx_bin_c2 = m_ini_idx + 4

    # Work on a copy (the reference implementation modifies in place).
    iso = np.array(iso_np, copy=True)
    Ns = iso.shape[-1]

    # No extinction: return the copy untouched. Running the formulas with
    # Av_dr == 0 would add ec * 0, and 0 * NaN (e.g. the GAIADR3 polynomial
    # evaluated on a NaN colour) would turn valid magnitudes into NaN.
    if Av == 0.0 and dr <= 0.0:
        return iso

    # ----- Differential reddening -----
    if dr > 0.0:
        if DR_distribution == "uniform":
            # (2*U - 1) * dr, i.e. uniform in [-dr, dr)
            dr_arr = (2.0 * rand_unif[:Ns] - 1.0) * dr
        elif DR_distribution == "normal":
            dr_arr = rand_norm[:Ns] * dr
        else:
            raise ValueError(f"Unknown DR_distribution: {DR_distribution}")
        Av_dr = np.clip(Av + dr_arr, a_min=0.0, a_max=np.inf)  # no negative extinction
    else:
        Av_dr = Av  # scalar

    # ----- Extinction coefficients -----
    if ext_law == "CCMO":
        ec_mag = ext_coefs[0][0] + ext_coefs[0][1] / Rv
        ec_col1 = (ext_coefs[1][0][0] + ext_coefs[1][0][1] / Rv) - \
                  (ext_coefs[1][1][0] + ext_coefs[1][1][1] / Rv)
        has_c2 = len(ext_coefs) > 2
        if has_c2:
            ec_col2 = (ext_coefs[2][0][0] + ext_coefs[2][0][1] / Rv) - \
                      (ext_coefs[2][1][0] + ext_coefs[2][1][1] / Rv)
    else:  # "GAIADR3"
        # Imported lazily: only this law needs asteca.
        from asteca.modules import synth_cluster_priv as scp
        # Coefficients depend on the first colour (row 1) and on Av_dr.
        ec_mag, ec_col1 = scp.dustapprox(iso[1], Av_dr)
        has_c2 = False  # colour 2 is not handled for this law here

    Ax = ec_mag * Av_dr
    Ex1 = ec_col1 * Av_dr

    # Broadcast scalars to one value per point
    if np.ndim(Ax) == 0:
        Ax = np.full(Ns, Ax, dtype=float)
    if np.ndim(Ex1) == 0:
        Ex1 = np.full(Ns, Ex1, dtype=float)

    iso[0] += Ax
    iso[1] += Ex1

    if binar_flag:
        if idx_bin_mag < iso.shape[0]:
            iso[idx_bin_mag] += Ax
        if idx_bin_c1 < iso.shape[0]:
            iso[idx_bin_c1] += Ex1

    # Colour 2: only for CCMO with coefficients for it
    if ext_law == "CCMO" and has_c2 and iso.shape[0] > 2:
        Ex2 = ec_col2 * Av_dr
        if np.ndim(Ex2) == 0:
            Ex2 = np.full(Ns, Ex2, dtype=float)
        iso[2] += Ex2
        if binar_flag and idx_bin_c2 < iso.shape[0]:
            iso[idx_bin_c2] += Ex2

    return iso

In [None]:
def generate_2(
    self,
    params: dict,
    N_stars: int = 100,
    *,
    _move_isochrone=move_isochrone_pt,
    _extinction=extinction_pt,
    _cut_max_mag=cut_max_mag_pt,
) -> np.ndarray:
    """
    Variant of ``Synthetic.generate`` that uses the notebook's ``*_pt`` helpers.

    The distance, extinction and magnitude-cut steps go through those helpers,
    so the output can be compared with ``generate``.

    Parameters
    ----------
    params : dict
        Fit parameters (e.g. ``met``, ``loga``, ``dm``, ``Av``). They are
        resolved against ``self.def_params`` by ``scp.properModel``.
    N_stars : int
        Number of stars used when the object lacks the calibrated ``*_obs``
        attributes. ``-1`` returns the magnitude-cut isochrone instead of a
        synthetic cluster.
    _move_isochrone, _extinction, _cut_max_mag : callable, keyword-only
        Helper implementations. The defaults are bound when this cell runs, so
        a later cell that redefines a helper name (the torch cell further down
        redefines ``move_isochrone_pt``) cannot silently change this method.
        Re-run this cell to pick up intentional edits to the helpers.

    Returns
    -------
    np.ndarray
        One of: the synthetic cluster (features x stars); the cut isochrone
        when ``N_stars == -1``; or an empty array when nothing survives.
    """
    from asteca.modules import synth_cluster_priv as scp

    try:
        # Values stored on the calibrated object
        max_mag_syn = self.max_mag_syn_obs
        N_synth_stars = self.N_stars_obs
        err_dist_synth = self.err_dist_obs
    except AttributeError:
        # Not calibrated: no magnitude cut, N_stars stars, empty error distribution
        max_mag_syn = np.inf
        N_synth_stars = int(N_stars)
        err_dist_synth = []

    # Effective parameters plus the (met, loga) indices/weights used for averaging
    met, loga, alpha, beta, av, dr, rv, dm, ml, mh, al, ah = scp.properModel(
        self.met_age_dict, self.def_params, params
    )

    # Single grid cell, or a weighted average of tracks
    if ml == al == mh == ah == 0:
        isochrone = np.array(self.theor_tracks[0][0])
    else:
        isochrone = scp.zaWAverage(
            self.theor_tracks, self.met_age_dict, self.m_ini_idx,
            met, loga, ml, mh, al, ah
        )

    # Binarity is off only when both alpha and beta are zero
    binar_flag = not (alpha == 0.0 and beta == 0.0)

    # (1) distance modulus
    isoch_moved = _move_isochrone(
        isochrone, int(self.m_ini_idx), float(dm), binar_flag
    )

    # (2) extinction (CCMO / GAIADR3, with or without DR)
    isoch_extin = _extinction(
        ext_law=self.ext_law,
        ext_coefs=self.ext_coefs,
        rand_norm=self.rand_floats["norm"][0],
        rand_unif=self.rand_floats["unif"][0],
        DR_distribution=self.DR_distribution,
        m_ini_idx=int(self.m_ini_idx),
        binar_flag=binar_flag,
        Av=float(av),
        dr=float(dr),
        Rv=float(rv),
        iso_np=isoch_moved,
    )

    # (3) maximum-magnitude cut
    isoch_cut = _cut_max_mag(isoch_extin, float(max_mag_syn))
    if not isoch_cut.any():  # same emptiness test as the reference generate
        return np.array([])

    # Shortcut: return the cut isochrone
    if N_stars == -1:
        return isoch_cut

    # (4) sample masses from the IMF and interpolate the isochrone
    isoch_mass = scp.mass_interp(
        isoch_cut,
        int(self.m_ini_idx),
        self.st_dist_mass[ml][al],
        int(N_synth_stars),
        binar_flag,
    )
    if not isoch_mass.any():
        return np.array([])

    # (5) binarity
    isoch_binar = scp.binarity(
        float(alpha), float(beta), binar_flag, int(self.m_ini_idx),
        self.rand_floats["unif"][1], isoch_mass
    )

    # (6) photometric errors
    synth_clust = scp.add_errors(isoch_binar, err_dist_synth)

    return synth_clust

# bind to the object
synthcl.generate_2 = MethodType(generate_2, synthcl)

In [None]:
# Regular (met, log-age) evaluation grid, spanning the first-to-last tabulated values.
_met_axis = isochs.met_age_dict['met']
_loga_axis = isochs.met_age_dict['loga']
met_min, met_max = float(_met_axis[0]), float(_met_axis[-1])
loga_min, loga_max = float(_loga_axis[0]), float(_loga_axis[-1])

M_MET = M_LOGA = 10
met_grid = np.linspace(met_min, met_max, num=M_MET, dtype=float)
loga_grid = np.linspace(loga_min, loga_max, num=M_LOGA, dtype=float)

In [None]:
def compare_clusters(syn: np.ndarray,
                     syn2: np.ndarray,
                     rtol: float = 0.0,
                     atol: float = 0.0,
                     topk: int = 10,
                     name1="generate",
                     name2="generate_2",
                     verbose=True):
    """
    Compare two synthetic clusters element by element.

    Both inputs are 2-D arrays laid out as (features, stars). The checks are,
    in order:

    1. The shapes must be identical.
    2. The NaN patterns must be identical.
    3. Finite entries must satisfy ``|a - b| <= atol + rtol * |b|``
       (``syn2`` is the reference for ``rtol``).

    Parameters
    ----------
    syn, syn2 : np.ndarray
    rtol, atol : float
        Tolerances. The defaults (0, 0) demand exact equality.
    topk : int
        Number of largest absolute differences to report (0 disables).
    name1, name2 : str
        Labels used in the top-k records and printouts.
    verbose : bool
        Print a diagnostic summary.

    Returns
    -------
    dict
        ``ok``, ``reason`` ("ok", "violations", "shape_mismatch",
        "nan_pattern_mismatch", "all_nan_or_empty"), shapes, counts,
        difference statistics, ``violations`` (count), ``viol_idx`` (a (k, 2)
        array of violating (feature, star) positions) and ``topk`` records.
    """
    out = {
        "ok": False, "reason": None,
        "shape1": None, "shape2": None,
        "n": 0, "n_nan_equal": 0, "n_checked": 0,
        "max_abs": np.nan, "max_rel": np.nan,
        "p95_abs": np.nan, "mean_abs": np.nan,
        "violations": 0, "viol_idx": None,
        "topk": []
    }

    # 0) shape
    out["shape1"], out["shape2"] = syn.shape, syn2.shape
    if syn.shape != syn2.shape:
        out["reason"] = "shape_mismatch"
        if verbose:
            print(f"[X] Shape mismatch: {name1}{syn.shape} vs {name2}{syn2.shape}")
        return out

    # 1) NaN pattern
    n1 = np.isnan(syn)
    n2 = np.isnan(syn2)
    if not np.array_equal(n1, n2):
        out["reason"] = "nan_pattern_mismatch"
        if verbose:
            diff_mask = n1 ^ n2
            f_idx, s_idx = np.where(diff_mask)
            print(f"[X] NaN pattern mismatch en {diff_mask.sum()} posiciones.")
            print("  Ejemplo (feature, star):", list(zip(f_idx[:10], s_idx[:10])))
        return out

    # 2) compare only where both are finite (NaN patterns are equal here)
    fin = (~n1) & (~n2)
    out["n"] = int(syn.size)
    out["n_nan_equal"] = int(np.count_nonzero(n1))
    out["n_checked"] = int(fin.sum())
    if out["n_checked"] == 0:
        out["ok"] = True
        out["reason"] = "all_nan_or_empty"
        if verbose:
            print("[=] No hay elementos finitos que comparar (todo NaN/vacío).")
        return out

    a = syn[fin].astype(np.float64, copy=False)
    b = syn2[fin].astype(np.float64, copy=False)

    absdiff = np.abs(a - b)
    with np.errstate(divide='ignore', invalid='ignore'):
        # rel = |a-b| / max(|b|, tiny)
        denom = np.maximum(np.abs(b), np.finfo(np.float64).tiny)
        reldiff = absdiff / denom

    # Statistics. absdiff is non-empty here, so a plain max() is safe; the
    # previous max(initial=np.nan) always returned NaN because NaN propagates.
    out["max_abs"] = float(absdiff.max())
    out["p95_abs"] = float(np.percentile(absdiff, 95))
    out["mean_abs"] = float(absdiff.mean())
    out["max_rel"] = float(np.nanmax(reldiff))

    # Violations of the given rtol/atol
    viol_mask = absdiff > (atol + rtol * np.abs(b))
    out["violations"] = int(viol_mask.sum())

    # (feature, star) of every checked element, in the same row-major order
    # as syn[fin] above. Violation positions are reported regardless of topk.
    feat_idx, star_idx = np.where(fin)
    out["viol_idx"] = np.column_stack([feat_idx[viol_mask], star_idx[viol_mask]])

    # top-k absolute differences
    if topk > 0:
        for idx in np.argsort(-absdiff)[:topk]:
            f = int(feat_idx[idx])
            s = int(star_idx[idx])
            out["topk"].append({
                "feat": f, "star": s,
                f"{name1}": float(syn[f, s]),
                f"{name2}": float(syn2[f, s]),
                "absdiff": float(absdiff[idx]),
                "reldiff": float(reldiff[idx])
            })

    out["ok"] = (out["violations"] == 0)
    out["reason"] = "ok" if out["ok"] else "violations"

    if verbose:
        print(f"[=] Comparación {name1} vs {name2}:")
        print(f"    shape: {syn.shape}, finitos comparados: {out['n_checked']}, NaNs coincidentes: {out['n_nan_equal']}")
        print(f"    max|Δ|={out['max_abs']:.3e}, p95|Δ|={out['p95_abs']:.3e}, mean|Δ|={out['mean_abs']:.3e}, max rel={out['max_rel']:.3e}")
        print(f"    tolerancias: rtol={rtol}, atol={atol} → violaciones={out['violations']}")
        if out["topk"]:
            print("    Top-k diferencias:")
            for k, rec in enumerate(out["topk"], 1):
                print(f"      #{k} (feat={rec['feat']}, star={rec['star']}): "
                      f"{name1}={rec[name1]:.15g}, {name2}={rec[name2]:.15g}, "
                      f"|Δ|={rec['absdiff']:.3e}, rel={rec['reldiff']:.3e}")
    return out

In [None]:
# Rebuild the (met, log-age) grid, identical to the earlier grid cell.
met_values = isochs.met_age_dict['met']
loga_values = isochs.met_age_dict['loga']
met_min = float(met_values[0])
met_max = float(met_values[-1])
loga_min = float(loga_values[0])
loga_max = float(loga_values[-1])

M_MET, M_LOGA = 10, 10
met_grid = np.linspace(start=met_min, stop=met_max, num=M_MET, dtype=float)
loga_grid = np.linspace(start=loga_min, stop=loga_max, num=M_LOGA, dtype=float)

In [None]:
# --- Grid-wide equivalence check: generate vs generate_2 ---
rtol, atol = 0.0, 0.0  # bit-for-bit float64 equality (holds if no operation is reordered)
mismatches = []  # (i, j, report dict); every report has at least a "reason" key


def _grid_params(met, loga):
    """Parameters for one grid cell: no distance shift and no extinction."""
    return {"met": float(met), "loga": float(loga), "dm": 0.0, "Av": 0.0}


for i, met in enumerate(tqdm(met_grid, desc="met grid")):
    for j, loga in enumerate(loga_grid):
        pars = _grid_params(met, loga)
        syn = synthcl.generate(pars)
        syn2 = synthcl.generate_2(pars)

        # both empty: nothing to compare
        if syn.size == 0 and syn2.size == 0:
            continue
        if syn.size == 0 or syn2.size == 0:
            print(f"[{i},{j}] uno vacío: shapes {syn.shape} vs {syn2.shape}")
            # Store a dict (not a bare string) so the summary can read rep["reason"].
            mismatches.append((i, j, {"ok": False, "reason": "empty",
                                      "shape1": syn.shape, "shape2": syn2.shape}))
            continue

        rep = compare_clusters(syn, syn2, rtol=rtol, atol=atol, topk=5, verbose=False)
        if not rep["ok"]:
            mismatches.append((i, j, rep))

print(f"Total celdas: {len(met_grid)*len(loga_grid)}")
print(f"No equivalentes: {len(mismatches)}")
if mismatches:
    i, j, rep = mismatches[0]
    print(f"Primer mismatch en (i={i}, j={j}): motivo={rep['reason']}")
    # Re-run the first failing cell with a verbose report.
    first_pars = _grid_params(met_grid[i], loga_grid[j])
    _ = compare_clusters(synthcl.generate(first_pars),
                         synthcl.generate_2(first_pars),
                         rtol=rtol, atol=atol, topk=10, verbose=True)

In [None]:
# Observed photometry taken directly from the analyzer's table (QTable).
T = ca.data

mag_obs = np.array(T["Gmag"])                       # G magnitude
color1_obs = np.array(T["G_BPmag"] - T["G_RPmag"])  # colour 1 = BP - RP
color2_obs = None                                   # no second colour

# Uncertainties
e_mag_obs = np.array(T["e_Gmag"])
if "e_BP_RP" in T.colnames:
    e_col1_obs = np.array(T["e_BP_RP"])
else:
    # No tabulated colour error: combine the BP and RP errors in quadrature.
    e_col1_obs = np.hypot(np.array(T["e_G_BPmag"]), np.array(T["e_G_RPmag"]))
e_col2_obs = None

# Size check
n_color2 = None if color2_obs is None else len(color2_obs)
print("Obs sizes:", len(mag_obs), len(color1_obs), n_color2)

In [None]:
import asteca

# Re-create the isochrone set with the same settings as the earlier cell
# (that cell's run shows a KeyboardInterrupt in its output).
MIST_PATH = "/Users/notluquis/Library/Mobile Documents/com~apple~CloudDocs/Investigación/NGC6383/MIST"
isochs = asteca.Isochrones(
    isochs_path=MIST_PATH,
    model='MIST',
    magnitude="Gaia_G_EDR3",
    magnitude_effl=6390.7,
    color=("Gaia_BP_EDR3", "Gaia_RP_EDR3"),
    color_effl=(5182.6, 7825.1),
)

In [None]:
# Pull the isochrone grid and its metadata out of the Isochrones object.
theor_tracks = isochs.theor_tracks
color_filters = isochs.color_filters
met_age_dict = isochs.met_age_dict

In [None]:
def get_sample_imf_cached_pkl(cache_path: str,
                              rng_seed: int,
                              IMF_name: str,
                              max_mass: float,
                              Nmets: int,
                              Nages: int):
    """
    Load cached IMF mass samples from a pickle and check their metadata.

    Returns ``st_dist_mass`` and, if present, ``st_dist_mass_ordered``.
    The samples are NOT recomputed here. If the stored metadata is missing,
    or disagrees with the requested configuration, a ``UserWarning`` is
    emitted and the cached data are still returned, so the caller can decide.

    Parameters
    ----------
    cache_path : str
        Pickle holding a dict with ``"st_dist_mass"`` and, optionally,
        ``"st_dist_mass_ordered"`` and ``"meta"``. Only load trusted files:
        unpickling can execute code.
    rng_seed, IMF_name, max_mass, Nmets, Nages
        Expected configuration, compared key by key with ``obj["meta"]``.

    Returns
    -------
    tuple
        ``(st_dist_mass, st_dist_mass_ordered or None, stored meta dict)``.

    Raises
    ------
    FileNotFoundError
        If the cache does not exist.
    KeyError
        If the cache lacks ``"st_dist_mass"``.
    """
    import warnings

    meta_new = dict(
        rng_seed=int(rng_seed),
        IMF_name=str(IMF_name),
        max_mass=float(max_mass),
        Nmets=int(Nmets),
        Nages=int(Nages),
    )

    obj = load_pkl(cache_path)
    st = obj["st_dist_mass"]
    sto = obj.get("st_dist_mass_ordered", None)
    meta_old = obj.get("meta", {})

    if not meta_old:
        warnings.warn(
            f"IMF cache {cache_path!r} has no metadata; cannot verify it "
            f"matches {meta_new}.",
            stacklevel=2,
        )
    else:
        # Compare only the keys the cache actually recorded.
        diffs = {key: (meta_old.get(key), want) for key, want in meta_new.items()
                 if key in meta_old and meta_old.get(key) != want}
        if diffs:
            warnings.warn(
                f"IMF cache {cache_path!r} metadata differs from the request "
                f"(key: (cached, requested)): {diffs}",
                stacklevel=2,
            )
    return st, sto, meta_old

In [None]:
# =========================
# Usage with the current grid
# =========================
Nmets = len(met_age_dict["met"])
Nages = len(met_age_dict["loga"])
rng_seed = 42
IMF_name = "kroupa_2001"
Max_mass = 2.5e5

cache_path = '/Users/notluquis/Library/Mobile Documents/com~apple~CloudDocs/Investigación/COSMIC/COSMIC/precomp/st_imf.pkl'

# The helper returns THREE values; unpacking into two raised ValueError.
st_dist_mass, st_dist_mass_ordered, meta = get_sample_imf_cached_pkl(
    cache_path=cache_path,
    rng_seed=rng_seed,
    IMF_name=IMF_name,
    max_mass=Max_mass,
    Nmets=Nmets,
    Nages=Nages
)

In [None]:
# The only visible assignment to `obj` is local to get_sample_imf_cached_pkl,
# so it is not available here. Fetch the three values through the helper.
st_dist_mass, st_dist_mass_ordered, meta = get_sample_imf_cached_pkl(
    cache_path=cache_path,
    rng_seed=rng_seed,
    IMF_name=IMF_name,
    max_mass=Max_mass,
    Nmets=Nmets,
    Nages=Nages,
)

In [None]:
# Quick sanity check of the nested (met, age) structure of the IMF samples.
print("len(st_dist_mass) =", len(st_dist_mass))
print("len(st_dist_mass[0]) =", len(st_dist_mass[0]))
first_cell = st_dist_mass[0][0]
print("Fingerprint ejemplo:",
      len(first_cell), np.nanmin(first_cell), np.nanmax(first_cell))

In [None]:
from asteca.modules import synth_cluster_priv as scp

# --- Cell 5: frozen random pools for tests/validation ---
# Not used inside NUTS (everything there is deterministic); they only serve to
# validate equivalence with the original pipeline.

# Pool size: at least as large as the biggest synthetic cluster to be drawn.
N_max = max(len(mag_obs), 10_000)

# Two channels each; rows are drawn in order, so the RNG stream is consumed
# as normal, normal, uniform, uniform.
rand_norm_vals = np.stack([rng.normal(0.0, 1.0, N_max) for _ in range(2)])
rand_unif_vals = np.stack([rng.uniform(0.0, 1.0, N_max) for _ in range(2)])

# Frozen error distribution, for validation only.
err_dist = scp.error_distribution(
    mag=mag_obs,
    e_mag=e_mag_obs,
    e_color=e_col1_obs,
    e_color2=None,
    rand_norm_vals=rand_norm_vals[0],
)
err_sizes = [len(component) for component in err_dist]
print("err_dist sizes:", err_sizes)

In [None]:
import numpy as np
from scipy.sparse import csr_matrix

def collapse_duplicates_strict(x, y, keep="last"):
    """
    Sort ``(x, y)`` by ``x`` and drop duplicated ``x`` values.

    The returned grid is strictly increasing, as ``linear_interp_matrix``
    requires.

    Parameters
    ----------
    x, y : array_like
        Paired samples (converted to float).
    keep : {"first", "last"}
        Which sample survives among equal ``x`` values. The sort is stable,
        so "first"/"last" refer to the original order.

    Returns
    -------
    xs, ys : np.ndarray

    Raises
    ------
    ValueError
        Invalid ``keep``.
    RuntimeError
        ``xs`` is still not strictly increasing (e.g. NaNs in ``x``).
    """
    x_arr = np.asarray(x, float)
    y_arr = np.asarray(y, float)
    perm = np.argsort(x_arr, kind="mergesort")  # stable sort
    x_sorted = x_arr[perm]
    y_sorted = y_arr[perm]

    if keep == "first":
        _, picks = np.unique(x_sorted, return_index=True)
    elif keep == "last":
        # The first hit in the reversed array is the last hit in the forward one.
        _, rev_hits = np.unique(x_sorted[::-1], return_index=True)
        picks = np.sort((x_sorted.size - 1) - rev_hits)
    else:
        raise ValueError("keep must be 'first' or 'last'")

    xs = x_sorted[picks]
    ys = y_sorted[picks]
    if np.all(np.diff(xs) > 0):
        return xs, ys
    raise RuntimeError("xs is not strictly increasing after collapsing duplicates.")

def linear_interp_matrix(xq, x):
    """
    Build a sparse (m, n) matrix ``W`` that reproduces ``np.interp``.

    For any ``y`` sampled on the strictly increasing grid ``x``, ``W @ y``
    equals ``np.interp(xq, x, y)`` up to floating-point rounding.

    Each row holds the two linear-interpolation weights of one query point,
    so every row sums to 1. Queries outside ``[x[0], x[-1]]`` take the end
    values, as ``np.interp`` does with its default ``left``/``right``; they
    are not linearly extrapolated.

    Parameters
    ----------
    xq : array_like, shape (m,)
        Query points.
    x : array_like, shape (n,)
        Finite, strictly increasing grid with n >= 1.

    Returns
    -------
    scipy.sparse.csr_matrix, shape (m, n)

    Raises
    ------
    ValueError
        Raised for any of:

        - non-finite values;
        - a grid that is not strictly increasing;
        - an empty grid.
    """
    x = np.asarray(x, float)
    xq = np.asarray(xq, float)
    if np.any(~np.isfinite(x)) or np.any(~np.isfinite(xq)):
        raise ValueError("Non-finite values in x/xq are not supported.")
    if not np.all(np.diff(x) > 0):
        raise ValueError("x must be strictly increasing (no duplicates).")

    n = x.size
    m = xq.size
    if n == 0:
        raise ValueError("x must contain at least one point.")
    if n == 1:
        # np.interp with a single sample returns that sample everywhere.
        return csr_matrix((np.ones(m), (np.arange(m), np.zeros(m, dtype=int))),
                          shape=(m, 1))

    # Right neighbour of each query, kept in [1, n-1] so (j-1, j) is a valid pair.
    j = np.clip(np.searchsorted(x, xq, side='right'), 1, n - 1)
    i0 = j - 1

    x0 = x[i0]
    x1 = x[j]
    # Clamp the weight so out-of-range queries take the end value (as np.interp does).
    w1 = np.clip((xq - x0) / (x1 - x0), 0.0, 1.0)
    w0 = 1.0 - w1

    # Rows 0..m-1, once for the left neighbour and once for the right one.
    rows = np.concatenate([np.arange(m), np.arange(m)])
    cols = np.concatenate([i0, j])
    data = np.concatenate([w0, w1])

    return csr_matrix((data, (rows, cols)), shape=(m, n))

In [None]:
# =============================
# Check W against np.interp on one (Z, age) cell
# =============================
Nmets, Nages = theor_tracks.shape[:2]
Zidx, Aidx = 0, 0  # any cell of the grid

# Raw tracks: the initial mass is the last row, the magnitude is row 0.
m_ini_idx = theor_tracks.shape[2] - 1
track = theor_tracks[Zidx, Aidx]
Mini_k_raw = np.asarray(track[m_ini_idx], float)
mag_k_raw = np.asarray(track[0], float)

# 1) sort + collapse duplicate masses (keep the LAST one)
xs, ys = collapse_duplicates_strict(Mini_k_raw, mag_k_raw, keep="last")

# 2) IMF-sampled masses, clipped to the track's mass range
xq = np.clip(np.asarray(st_dist_mass[Zidx][Aidx], float), xs[0], xs[-1])

# 3) W @ ys must reproduce np.interp on the same (xs, ys)
W = linear_interp_matrix(xq, xs)
lhs, rhs = W @ ys, np.interp(xq, xs, ys)

max_abs_err = float(np.abs(lhs - rhs).max())
row_sums = np.asarray(W.sum(axis=1)).ravel()

print("interp check (∞-norm):", max_abs_err)
print("row sums min/max:", row_sums.min(), row_sums.max())

assert max_abs_err < 1e-8, "Interpolation mismatch exceeds tolerance"
assert np.allclose(row_sums, 1.0, atol=1e-12), "Each row of W must sum to 1"

In [None]:
# Repeat the W-vs-np.interp check on every (Z, age) cell of the grid.
for Zi in range(theor_tracks.shape[0]):
    for Ai in range(theor_tracks.shape[1]):
        xs, ys = collapse_duplicates_strict(theor_tracks[Zi, Ai, m_ini_idx],
                                            theor_tracks[Zi, Ai, 0],
                                            keep="last")
        xq = np.clip(np.asarray(st_dist_mass[Zi][Ai], float), xs[0], xs[-1])
        if xq.size == 0:
            continue  # no IMF samples in this cell; np.max would fail on empty input
        W = linear_interp_matrix(xq, xs)
        lhs = W @ ys
        rhs = np.interp(xq, xs, ys)
        err = float(np.max(np.abs(lhs - rhs)))
        rows = np.asarray(W.sum(axis=1)).ravel()
        if not (err < 1e-8 and np.allclose(rows, 1.0, atol=1e-12)):
            raise AssertionError(f"Falla en (Z={Zi}, A={Ai}): err={err}, "
                                 f"row sums in [{rows.min()}, {rows.max()}]")
print("OK en toda la malla.")

In [None]:
# Inspect the raw isochrone grid (tip: `theor_tracks.shape` is a lighter-weight check).
theor_tracks

In [None]:
# ============================================
# Paso 2: check apply_extinction NumPy vs Torch
# ============================================
import numpy as np
from asteca.modules import synth_cluster_priv as scp

# --- Pick one grid cell and fixed parameters (no differential reddening) ---
Zidx = 0
Aidx = 0
fit_params = dict(
    met=float(met_age_dict["met"][Zidx]),
    loga=float(met_age_dict["loga"][Aidx]),
    alpha=0.1,
    beta=1.0,
    Av=0.3,
    DR=0.0,
    Rv=3.1,
    dm=10.0,
)

In [None]:
isoch_base = np.array(theor_tracks[Zidx, Aidx])  # (Nd, Ni) copy of one grid cell
Nd = isoch_base.shape[0]

# Raw tracks carry the initial mass in the LAST row:
#   (mag, color1, mass_ini)          -> m_ini_idx = 2
#   (mag, color1, color2, mass_ini)  -> m_ini_idx = 3
m_ini_idx = Nd - 1

# Binary rows would live after mass_ini (m_ini_idx + 2, + 3, ...). Since
# m_ini_idx is by construction the last row, the test `Nd >= m_ini_idx + 3`
# (i.e. Nd >= Nd + 2) can never hold. This check therefore never exercises
# the binary branch.
has_binary_cols = False
binar_flag = has_binary_cols

# --- Shift by the distance modulus (no photometric errors) ---
isoch_moved = scp.move_isochrone(
    isochrone=isoch_base.copy(),
    binar_flag=binar_flag,  # True only if binary rows exist
    m_ini_idx=m_ini_idx,
    dm=fit_params["dm"],
)

In [None]:
# --- Extinction law and coefficients ---
ext_law = "CCMO"   # or "GAIADR3"
# CCMO needs precomputed coefficients; GAIADR3 computes its own, so an empty list.
ext_coefs = (
    scp.ccmo_ext_coeffs(
        magnitude_effl=isochs.magnitude_effl,
        color_effl=isochs.color_effl,
        color2_effl=isochs.color2_effl,
    )
    if ext_law == "CCMO"
    else []
)

In [None]:
# --- 2) NumPy reference: call scp.extinction directly ---
def apply_extinction_np(isoch_arr: np.ndarray) -> np.ndarray:
    """
    Reference extinction via ``scp.extinction``.

    Uses the notebook-level settings (ext_law, ext_coefs, frozen random pools,
    m_ini_idx, binar_flag, fit_params) with a uniform DR distribution. A copy
    of ``isoch_arr`` is passed, so the argument itself is not modified.
    """
    settings = dict(
        ext_law=ext_law,
        ext_coefs=ext_coefs,
        rand_norm=rand_norm_vals[0],   # frozen pools
        rand_unif=rand_unif_vals[0],
        DR_distribution="uniform",
        m_ini_idx=m_ini_idx,
        binar_flag=binar_flag,
        Av=fit_params["Av"],
        dr=fit_params["DR"],
        Rv=fit_params["Rv"],
    )
    return scp.extinction(isochrone=isoch_arr.copy(), **settings)


# --- 3) versión Torch equivalente ---
def _dustapprox_torch(X_, Av_dr, torch):
    # Coefs idénticos a scp.dustapprox (EDR3)
    coeffs = {
        "G":  (0.995969721536602,-0.159726460302015,0.0122380738156057,0.00090726555099859,
               -0.0377160263914123,0.00151347495244888,-2.52364537395142e-05,
               0.0114522658102451,-0.000936914989014318,-0.000260296774134201),
        "BP": (1.15363197483424,-0.0814012991657388,-0.036013023976704,0.0192143585568966,
               -0.022397548243016,0.000840562680547171,-1.31018008013549e-05,
               0.00660124080271006,-0.000882247501989453,-0.000111215755291684),
        "RP": (0.66320787941067,-0.0179847164933981,0.000493769449961458,-0.00267994405695751,
               -0.00651422146709376,3.30179903473159e-05,1.57894227641527e-06,
               -7.9800898337247e-05,0.000255679812110045,1.10476584967393e-05),
    }
    X2, X3 = X_*X_, X_*X_*X_
    A2, A3 = Av_dr*Av_dr, Av_dr*Av_dr*Av_dr
    def ext_coeff(k):
        c = coeffs[k]
        ay = (torch.full_like(X_, c[0]) +
              c[1]*X_ + c[2]*X2 + c[3]*X3 +
              c[4]*Av_dr + c[5]*A2 + c[6]*A3 +
              c[7]*X_*Av_dr + c[9]*X_*A2 + c[8]*X2*Av_dr)
        return ay
    ec_G    = ext_coeff("G")
    ec_BPRP = ext_coeff("BP") - ext_coeff("RP")
    return ec_G, ec_BPRP

def apply_extinction_pt(
    isoch_arr: np.ndarray,
    *,
    ext_law: str,
    ext_coefs: list,
    rand_norm_vals: np.ndarray,
    rand_unif_vals: np.ndarray,
    DR_distribution: str,
    m_ini_idx: int,
    binar_flag: bool,
    Av: float,
    DR: float,
    Rv: float,
) -> np.ndarray:
    """
    Torch re-implementation of the extinction step, to check ``apply_extinction_np``.

    Parameters
    ----------
    isoch_arr : np.ndarray
        Isochrone (rows = features). Copied; not modified.
    ext_law : {"CCMO", "GAIADR3"}
    ext_coefs : list
        CCMO coefficients ``[[a, b], [[a1, b1], [a2, b2]], (optional colour-2 pair)]``.
    rand_norm_vals, rand_unif_vals : np.ndarray
        1-D random pools; the first ``Ns`` values feed the DR draw.
    DR_distribution : {"uniform", "normal"}
        Only checked when ``DR > 0``.
    m_ini_idx : int
        Initial-mass row. Binary rows sit at ``+2/+3/+4`` and are touched
        only when they exist.
    binar_flag : bool
    Av, DR, Rv : float

    Returns
    -------
    np.ndarray (float64), or ``None`` if torch cannot be imported.

    Raises
    ------
    ValueError
        Unknown ``ext_law``, or unknown ``DR_distribution`` when ``DR > 0``.
    """
    try:
        import torch
    except Exception:
        return None  # torch not available -> caller skips the check

    if ext_law not in ("CCMO", "GAIADR3"):
        raise ValueError(f"Unknown extinction law: {ext_law}")

    t = torch.from_numpy(isoch_arr.copy()).to(torch.float64)

    # No extinction: return untouched. The formulas would add ec * 0, and
    # 0 * NaN (GAIADR3 on a NaN colour) would corrupt valid magnitudes.
    if Av == 0.0 and DR <= 0.0:
        return t.numpy()

    # Av including differential reddening
    if DR > 0.0:
        Ns = t.shape[-1]
        if DR_distribution == "uniform":
            dr_arr = (2.0*torch.from_numpy(rand_unif_vals[:Ns]).to(torch.float64) - 1.0) * DR
        elif DR_distribution == "normal":
            dr_arr = torch.from_numpy(rand_norm_vals[:Ns]).to(torch.float64) * DR
        else:
            raise ValueError(f"DR_distribution desconocida: {DR_distribution}")
        Av_dr = torch.clip(torch.tensor(Av, dtype=torch.float64) + dr_arr, min=0.0)
    else:
        Av_dr = torch.tensor(Av, dtype=torch.float64)

    if ext_law == "CCMO":
        a_mag, b_mag = ext_coefs[0]
        ec_mag = torch.tensor(a_mag, dtype=torch.float64) + torch.tensor(b_mag, dtype=torch.float64)/Rv
        a1, b1 = ext_coefs[1][0]; a2, b2 = ext_coefs[1][1]
        ec_col1 = (torch.tensor(a1, dtype=torch.float64) + torch.tensor(b1, dtype=torch.float64)/Rv) - \
                  (torch.tensor(a2, dtype=torch.float64) + torch.tensor(b2, dtype=torch.float64)/Rv)
        has_col2 = (len(ext_coefs) > 2)
        if has_col2:
            a1b, b1b = ext_coefs[2][0]; a2b, b2b = ext_coefs[2][1]
            ec_col2 = (torch.tensor(a1b, dtype=torch.float64) + torch.tensor(b1b, dtype=torch.float64)/Rv) - \
                      (torch.tensor(a2b, dtype=torch.float64) + torch.tensor(b2b, dtype=torch.float64)/Rv)
    else:  # "GAIADR3"
        # Reuse the module-level helper (same EDR3 coefficients) instead of a
        # duplicated nested copy. A scalar Av_dr broadcasts against the colour row.
        ec_mag, ec_col1 = _dustapprox_torch(t[1], Av_dr, torch)
        has_col2 = False

    Ax = ec_mag * Av_dr
    Ex1 = ec_col1 * Av_dr

    t[0] += Ax
    t[1] += Ex1

    # Only touch binary rows if they really exist
    if binar_flag and (t.shape[0] > m_ini_idx + 3):
        t[m_ini_idx + 2] += Ax
        t[m_ini_idx + 3] += Ex1

    if ext_law == "CCMO" and has_col2 and (t.shape[0] > 2):
        Ex2 = ec_col2 * Av_dr
        t[2] += Ex2
        if binar_flag and (t.shape[0] > m_ini_idx + 4):
            t[m_ini_idx + 4] += Ex2

    return t.numpy()

In [None]:
# --- Run the check: NumPy reference vs torch version ---
out_np = apply_extinction_np(isoch_moved.copy())
torch_kwargs = dict(
    ext_law=ext_law,
    ext_coefs=ext_coefs,
    rand_norm_vals=rand_norm_vals[0],
    rand_unif_vals=rand_unif_vals[0],
    DR_distribution="uniform",
    m_ini_idx=m_ini_idx,
    binar_flag=binar_flag,
    Av=fit_params["Av"],
    DR=fit_params["DR"],
    Rv=fit_params["Rv"],
)
out_pt = apply_extinction_pt(isoch_moved.copy(), **torch_kwargs)

if out_pt is not None:
    diff = float(np.max(np.abs(out_np - out_pt)))
    print("extinction check: max |np - pt| =", diff)
    assert diff < 1e-8, "Extinction mismatch exceeds tolerance"
else:
    print("Torch no disponible → test saltado.")

In [None]:
# ========= utilidades base (torch) =========
import torch
import numpy as np

def to_tensor64(x, device=None):
    """
    Return ``x`` as a float64 torch tensor, optionally on ``device``.

    Tensors go through ``Tensor.to`` (it returns the very same object when it is
    already float64 on the requested device, and stays in the autograd graph).
    Anything else goes through ``torch.as_tensor``, which may share memory with a
    float64 NumPy input instead of copying it. Clone before mutating in place.
    """
    if not isinstance(x, torch.Tensor):
        return torch.as_tensor(x, dtype=torch.float64, device=device)
    return x.to(dtype=torch.float64, device=device)

# --- 1) move_isochrone (torch) ---
def move_isochrone_pt(isoch_pt, m_ini_idx, dm, binar_flag=True):
    """
    Shift an isochrone by the distance modulus ``dm`` (torch version).

    Adds ``dm`` to row 0 (magnitude) and, when ``binar_flag`` is set, to row
    ``m_ini_idx + 2`` (binary-system magnitude, same layout as the NumPy path).
    Rows are reassigned on ``isoch_pt`` itself, so the input tensor is modified;
    the same tensor is returned for convenience.
    """
    shifted_rows = (0, m_ini_idx + 2) if binar_flag else (0,)
    for row in shifted_rows:
        isoch_pt[row] = isoch_pt[row] + dm
    return isoch_pt

# --- 2) extinction CCMO (torch) ---
def extinction_ccmo_pt(isoch_pt, m_ini_idx, Av, Rv, ext_coefs, binar_flag=True):
    """
    Apply CCMO-style extinction to an isochrone tensor (modified in place).

    Assumed row layout (same as the NumPy pipeline): row 0 is the magnitude,
    rows 1..m_ini_idx-1 are the colors (m_ini_idx = 1 + number of colors),
    row m_ini_idx is the initial mass and, with binaries, rows m_ini_idx+2,
    m_ini_idx+3 (and m_ini_idx+4 for a second color) hold the binary magnitude
    and colors.

    ext_coefs = [(a, b), ((a1, b1), (a2, b2)), optional ((a1', b1'), (a2', b2'))]
        Ax  = (a + b/Rv) * Av                         -> magnitude rows
        Ex1 = ((a1 + b1/Rv) - (a2 + b2/Rv)) * Av      -> color-1 rows
        Ex2 = same formula with the third entry        -> color-2 rows
    ``Av`` may be a scalar or a per-star tensor broadcastable against a row.

    Rows are only touched when they exist: the second color is applied only if
    there are coefficients for it AND the tracks have a color-2 row
    (m_ini_idx > 2; otherwise row 2 is the initial mass). Binary rows are only
    shifted when the tensor is tall enough (as in ``apply_extinction_pt``).

    Returns the same (modified) tensor.
    """
    n_rows = isoch_pt.shape[0]

    a_mag, b_mag = ext_coefs[0]
    Ax = (a_mag + b_mag / Rv) * Av

    (a1, b1), (a2, b2) = ext_coefs[1]
    Ex1 = ((a1 + b1 / Rv) - (a2 + b2 / Rv)) * Av

    isoch_pt[0] = isoch_pt[0] + Ax
    isoch_pt[1] = isoch_pt[1] + Ex1

    # Binary magnitude/color-1 rows: only if both really exist.
    if binar_flag and n_rows > m_ini_idx + 3:
        isoch_pt[m_ini_idx + 2] = isoch_pt[m_ini_idx + 2] + Ax
        isoch_pt[m_ini_idx + 3] = isoch_pt[m_ini_idx + 3] + Ex1

    # Second color: needs coefficients AND a color-2 row in the tracks.
    if len(ext_coefs) > 2 and m_ini_idx > 2:
        (a1c2, b1c2), (a2c2, b2c2) = ext_coefs[2]
        Ex2 = ((a1c2 + b1c2 / Rv) - (a2c2 + b2c2 / Rv)) * Av
        isoch_pt[2] = isoch_pt[2] + Ex2
        if binar_flag and n_rows > m_ini_idx + 4:
            isoch_pt[m_ini_idx + 4] = isoch_pt[m_ini_idx + 4] + Ex2
    return isoch_pt

# --- 3) soft histogram 1D (diferenciable) ---
def soft_hist1d_torch(vals, edges):
    """
    Differentiable 1D histogram with a triangular (linear-interpolation) kernel.

    Each finite value inside ``[edges[0], edges[-1]]`` contributes a total
    weight of 1, split linearly between the two bins whose centers bracket it.
    Values between an outer edge and the outermost bin center go entirely to
    that outer bin (the interpolation weight is clamped to [0, 1]), so no bin
    ever receives a negative weight. Non-finite values and values outside the
    edges are ignored, matching ``np.histogram``'s range handling.

    Parameters
    ----------
    vals : array-like or torch.Tensor
        Values to bin (e.g. magnitudes); flattened to 1D.
    edges : array-like or torch.Tensor, shape (B+1,)
        Increasing bin edges.

    Returns
    -------
    torch.Tensor, shape (B,), float64
        Soft counts, differentiable w.r.t. ``vals`` for values lying between
        the outermost bin centers.
    """
    # torch.as_tensor keeps the autograd graph of tensor inputs (same as to_tensor64).
    x = torch.as_tensor(vals, dtype=torch.float64).reshape(-1)
    e = torch.as_tensor(edges, dtype=torch.float64, device=x.device)

    c = 0.5 * (e[1:] + e[:-1])  # bin centers, (B,)
    B = c.numel()
    lam = torch.zeros(B, dtype=torch.float64, device=x.device)
    if B == 0:
        return lam

    # Drop NaN/inf and out-of-range values instead of extrapolating them.
    keep = torch.isfinite(x) & (x >= e[0]) & (x <= e[-1])
    x = x[keep]

    if B == 1:
        # Single bin: every retained value lands in bin 0 with weight 1.
        i0 = torch.zeros(x.shape, dtype=torch.long, device=x.device)
        j = i0
        t = torch.zeros_like(x)
    else:
        # Right-hand neighbouring center; clamped so (i0, j) is always a valid pair.
        j = torch.searchsorted(c, x, right=True).clamp(1, B - 1)
        i0 = j - 1
        # Weight towards the right center; clamping avoids negative weights
        # for values in the outer half-bins.
        t = ((x - c[i0]) / (c[j] - c[i0])).clamp(0.0, 1.0)

    lam = lam.index_add(0, i0, 1.0 - t)  # accumulate on the left center
    lam = lam.index_add(0, j, t)         # accumulate on the right center
    return lam

# --- 4) soft histogram 2D (magnitud, color) ---
def soft_hist2d_torch(vals_x, edges_x, vals_y, edges_y):
    """
    Differentiable 2D histogram (e.g. magnitude vs. color) with a separable
    triangular kernel: each value's weight is the outer product of its linear
    interpolation weights along x and y.

    Only pairs where both coordinates are finite and inside their edge ranges
    contribute (as in ``np.histogram2d``). Per-axis weights are clamped to
    [0, 1], so values in the outer half-bins go entirely to the edge bin and
    no cell receives a negative weight.

    Parameters
    ----------
    vals_x, vals_y : array-like or torch.Tensor
        Coordinates of each point (flattened; must have the same length).
    edges_x : shape (Bx+1,), edges_y : shape (By+1,)
        Increasing bin edges along each axis.

    Returns
    -------
    torch.Tensor, shape (Bx, By), float64
        Soft counts.
    """
    # torch.as_tensor keeps the autograd graph of tensor inputs (same as to_tensor64).
    x = torch.as_tensor(vals_x, dtype=torch.float64).reshape(-1)
    y = torch.as_tensor(vals_y, dtype=torch.float64, device=x.device).reshape(-1)
    ex = torch.as_tensor(edges_x, dtype=torch.float64, device=x.device)
    ey = torch.as_tensor(edges_y, dtype=torch.float64, device=x.device)

    cx = 0.5 * (ex[1:] + ex[:-1])  # (Bx,)
    cy = 0.5 * (ey[1:] + ey[:-1])  # (By,)
    Bx, By = cx.numel(), cy.numel()
    lam = torch.zeros(Bx, By, dtype=torch.float64, device=x.device)
    if Bx == 0 or By == 0:
        return lam

    # Drop NaN/inf and out-of-range points instead of extrapolating them.
    keep = (
        torch.isfinite(x) & torch.isfinite(y)
        & (x >= ex[0]) & (x <= ex[-1])
        & (y >= ey[0]) & (y <= ey[-1])
    )
    x, y = x[keep], y[keep]

    def _neighbours(v, centers):
        """Left/right center indices and clamped right-hand weight for each value."""
        n = centers.numel()
        if n == 1:
            idx = torch.zeros(v.shape, dtype=torch.long, device=v.device)
            return idx, idx, torch.zeros_like(v)
        hi = torch.searchsorted(centers, v, right=True).clamp(1, n - 1)
        lo = hi - 1
        frac = ((v - centers[lo]) / (centers[hi] - centers[lo])).clamp(0.0, 1.0)
        return lo, hi, frac

    ix0, jx, tx = _neighbours(x, cx)
    iy0, jy, ty = _neighbours(y, cy)
    wx0, wx1 = 1.0 - tx, tx
    wy0, wy1 = 1.0 - ty, ty

    # Four cells per point: (ix0,iy0), (ix0,jy), (jx,iy0), (jx,jy); weights are separable products.
    lam.index_put_((ix0, iy0), wx0 * wy0, accumulate=True)
    lam.index_put_((ix0, jy), wx0 * wy1, accumulate=True)
    lam.index_put_((jx, iy0), wx1 * wy0, accumulate=True)
    lam.index_put_((jx, jy), wx1 * wy1, accumulate=True)
    return lam

In [None]:
# ========= Test harness (1D over magnitude; adapt to 2D if needed) =========
# Harness-specific settings use a `chk_` prefix so that running this cell does NOT
# overwrite the pipeline globals (`m_ini_idx`, `binar_flag`, `ext_coefs`, ...) that the
# extinction-check cell above reads.

# 0) Observed histogram (use the same edges as your pipeline)
valid = np.isfinite(mag_obs)
N_obs_np, edges_np = np.histogram(mag_obs[valid], bins=30)
N_obs = torch.as_tensor(N_obs_np, dtype=torch.float64)
edges = torch.as_tensor(edges_np, dtype=torch.float64)

# 1) Isochrone selection
Zidx, Aidx = 0, 0
tracks_np = theor_tracks[Zidx, Aidx]  # assumed (Nd, Ni) numpy — TODO confirm layout
tracks_pt = to_tensor64(tracks_np)    # may share memory with theor_tracks: always .clone() before mutating
Nd = tracks_pt.shape[0]

# m_ini_idx = 1 + number of colors (2 with one color, 3 with two), as in the pipeline.
# The previous `Nd - 1` pointed at the last row, which made the binary check below always False.
CHK_N_COLORS = 1  # TODO: set to the number of colors stored in theor_tracks
chk_m_ini_idx = 1 + CHK_N_COLORS
# Binary magnitude (m_ini_idx+2) and binary color 1 (m_ini_idx+3) must both exist.
chk_binar_flag = Nd > chk_m_ini_idx + 3

# 2) Move + extinguish (CCMO), no DR
chk_dm, chk_Av, chk_Rv = 10.0, 0.3, 3.1

# CCMO coefficients, same layout as the NumPy path: [(a, b), ((a1, b1), (a2, b2)), ...].
# Reuse the pipeline coefficients when the law is CCMO (only as many colors as the tracks
# have); otherwise fall back to zero placeholders, which turn extinction into a no-op —
# enough for the normalisation check below.
if ext_law == "CCMO":
    chk_ext_coefs = list(ext_coefs[: 1 + CHK_N_COLORS])
else:
    chk_ext_coefs = [(0.0, 0.0), ((0.0, 0.0), (0.0, 0.0))]

iso_moved = move_isochrone_pt(tracks_pt.clone(), chk_m_ini_idx, chk_dm, binar_flag=chk_binar_flag)
iso_ext = extinction_ccmo_pt(iso_moved, chk_m_ini_idx, chk_Av, chk_Rv, chk_ext_coefs,
                             binar_flag=chk_binar_flag)

# 3) Model magnitudes (no photometric errors, no DR)
mag_model_pt = iso_ext[0]  # (Ni,)

# 4) Soft histogram (PT)
lam_pt = soft_hist1d_torch(mag_model_pt, edges)  # (B,)

# 5) Normalise to the observed total
sum_obs = float(N_obs.sum())
sum_mod = float(lam_pt.sum())
if sum_mod > 0:
    lam_pt = lam_pt * (sum_obs / sum_mod)
else:
    lam_pt = torch.zeros_like(lam_pt)

# 6) Normalisation check
rel_err = abs(float(lam_pt.sum()) - sum_obs)/max(1.0, sum_obs)
print(f"[PT] lam.sum={float(lam_pt.sum()):.6f}, N_obs.sum={sum_obs:.6f}, rel_err={rel_err:.3e}")
assert rel_err < 1e-8, "Total esperado (lam.sum) debe igualar al observado"

# (Optional) equivalence with NumPy if you have an NP version (soft-hess NP)
# lam_np = ...  # your NumPy path computing the same thing with the same triangular kernel
# diff = np.max(np.abs(lam_np - lam_pt.detach().cpu().numpy()))
# print("soft_hess NP vs PT: max |lam_np - lam_pt| =", float(diff))
# assert diff < 1e-8

In [None]:
## Mini-checklist final
- Usa el mismo `m_ini_idx` que en tu pipeline: `m_ini_idx = 1 + número de colores` (2 si trabajas con 1 color; 3 si usas 2).
- Asegúrate de que `Mini_k` esté ordenado y no contenga NaNs.
- Reutiliza exactamente los mismos `rand_norm_vals` / `rand_unif_vals` en ambas rutas (np/pt).
- En el test del Hessiano, si tu `soft_hess_np` también devuelve penalizaciones o términos de regularización, compara `lam` antes de aplicar esa regularización/normalización.