In [None]:
import msprime as msp
import demes
import demesdraw

# Build the demography: one ancestral deme plus five derived demes (P0..P4),
# every deme with size 5000.
demo = msp.Demography()
demo.add_population(initial_size = 5000, name = "anc")
demo.add_population(initial_size = 5000, name = "P0")
demo.add_population(initial_size = 5000, name = "P1")
demo.add_population(initial_size = 5000, name = "P2")
demo.add_population(initial_size = 5000, name = "P3")
demo.add_population(initial_size = 5000, name = "P4")
# Symmetric migration only between neighbouring demes (P0-P1, P1-P2, P2-P3, P3-P4).
demo.set_symmetric_migration_rate(populations=("P0", "P1"), rate=0.001)
demo.set_symmetric_migration_rate(populations=("P1", "P2"), rate=0.001)
demo.set_symmetric_migration_rate(populations=("P2", "P3"), rate=0.001)
demo.set_symmetric_migration_rate(populations=("P3", "P4"), rate=0.001)
# All five derived demes split from "anc" at time 1000.
tmp = [f"P{i}" for i in range(5)]
demo.add_population_split(time = 1000, derived=tmp, ancestral="anc")
# demes-graph form of the model; later cells pass `g` to IICRCurve / plot_likelihood.
g = demo.to_demes()
# Simulate ancestry and mutations with fixed seeds so the cell is reproducible.
sample_size = 10
samples = {f"P{i}": sample_size for i in range(5)}
anc = msp.sim_ancestry(samples=samples, demography=demo, recombination_rate=1e-8, sequence_length=1e7, random_seed = 12)
ts = msp.sim_mutations(anc, rate=1e-8, random_seed = 12)

# Draw the demography as a sanity check of the model defined above.
demesdraw.tubes(g)

## The next chunk is just the fit function pasted in a code chunk

In [None]:
from __future__ import annotations

from typing import Any, Dict, List, Mapping, Optional, Sequence, Set, Tuple

import jax
import jax.numpy as jnp
import msprime as msp
from scipy.optimize import LinearConstraint, minimize
import jax.random as jr
from jax import vmap, lax 

from demesinfer.coal_rate import PiecewiseConstant
from demesinfer.constr import EventTree, constraints_for
from demesinfer.iicr import IICRCurve
from demesinfer.loglik.arg import loglik

Path = Tuple[Any, ...]
Var = Path | Set[Path]
Params = Mapping[Var, float]

def _dict_to_vec(d: Params, keys: Sequence[Var]) -> jnp.ndarray:
    return jnp.asarray([d[k] for k in keys], dtype=jnp.float64)

def _vec_to_dict_jax(v: jnp.ndarray, keys: Sequence[Var]) -> Dict[Var, jnp.ndarray]:
    return {k: v[i] for i, k in enumerate(keys)}

def _vec_to_dict(v: jnp.ndarray, keys: Sequence[Var]) -> Dict[Var, float]:
    return {k: float(v[i]) for i, k in enumerate(keys)}

def compile(ts, subkey):
    """Draw one random pair of sample nodes and summarise their TMRCA along the genome.

    Parameters
    ----------
    ts
        Tree sequence whose populations carry a ``"name"`` entry in their metadata.
    subkey
        JAX PRNG key used to draw the pair (without replacement).

    Returns
    -------
    data : jnp.ndarray
        One row ``[tmrca, span]`` per run of adjacent trees sharing the same TMRCA
        for the chosen pair; ``span`` is the summed interval length of the run.
    pop_cfg : dict
        ``{population name: number of chosen nodes in it}`` over every population
        that contains at least one sample node. The counts sum to 2.
    """
    sample_nodes = ts.samples()

    def population_name(node_id):
        return ts.population(ts.node(node_id).population).metadata["name"]

    # Keys are inserted in order of first appearance among the samples, so the
    # deme order is the same in every kernel (a set of strings is not, because
    # string hashing is randomised per process).
    pop_cfg = {}
    for node_id in sample_nodes:
        pop_cfg.setdefault(population_name(node_id), 0)

    # The draw yields *positions* in the sample list; map them to node IDs so
    # this also works when the sample nodes are not numbered 0..num_samples-1.
    picks = jax.random.choice(subkey, ts.num_samples, shape=(2,), replace=False)
    a = int(sample_nodes[int(picks[0])])
    b = int(sample_nodes[int(picks[1])])

    spans = []
    curr_t = None
    curr_L = 0.0
    for tree in ts.trees():
        L = tree.interval.right - tree.interval.left
        t = tree.tmrca(a, b)
        if curr_t is None or t != curr_t:
            # TMRCA changed: close the previous run (if any) and start a new one.
            if curr_t is not None:
                spans.append([curr_t, curr_L])
            curr_t = t
            curr_L = L
        else:
            curr_L += L
    spans.append([curr_t, curr_L])  # close the final run
    data = jnp.asarray(spans, dtype=jnp.float64)

    pop_cfg[population_name(a)] += 1
    pop_cfg[population_name(b)] += 1
    return data, pop_cfg

def get_tmrca_data(ts, key, num_samples):
    """Collect TMRCA span data for ``num_samples`` random sample pairs.

    Parameters
    ----------
    ts
        Tree sequence passed through to ``compile``.
    key
        JAX PRNG key; split once per pair.
    num_samples
        Number of random pairs to draw (must be at least 1).

    Returns
    -------
    data_pad : jnp.ndarray, shape (num_samples, Lmax, 2)
        Per-pair ``[tmrca, span]`` rows, padded at the end with ``[1.0, 0.0]``
        up to the longest pair.
    cfg_mat : jnp.ndarray, shape (num_samples, D), int32
        Per-pair sample counts, one column per deme in ``deme_names`` order.
    deme_names
        Keys view of the first pair's configuration dict.
    max_indices : jnp.ndarray
        Index of the last real (non-padding) row for each pair.

    Raises
    ------
    ValueError
        If ``num_samples`` is smaller than 1.
    """
    # Local import: the surrounding cell does not import NumPy at module level.
    import numpy as np

    if num_samples < 1:
        raise ValueError("num_samples must be at least 1")

    data_list = []
    cfg_list = []
    for _ in range(num_samples):
        key, subkey = jr.split(key)
        data, cfg = compile(ts, subkey)
        data_list.append(np.asarray(data))
        cfg_list.append(cfg)

    lens = [d.shape[0] for d in data_list]
    Lmax = max(lens)

    # Fill the padded buffer on the host and convert once; repeated
    # ``.at[...].set()`` would copy the whole array for every pair.
    padded = np.empty((num_samples, Lmax, 2), dtype=np.float64)
    padded[..., 0] = 1.0
    padded[..., 1] = 0.0
    for i, d in enumerate(data_list):
        padded[i, : d.shape[0], :] = d
    data_pad = jnp.asarray(padded, dtype=jnp.float64)

    deme_names = cfg_list[0].keys()
    cfg_rows = [[cfg.get(name, 0) for name in deme_names] for cfg in cfg_list]
    cfg_mat = jnp.asarray(np.array(cfg_rows, dtype=np.int32), dtype=jnp.int32)

    max_indices = jnp.array([n - 1 for n in lens])
    return data_pad, cfg_mat, deme_names, max_indices

def plot_likelihood(demo, ts, paths, vec_values, recombination_rate=1e-8, seed=1, num_samples=20, t_min=1e-8, num_t=1000, k=2):
    """Evaluate and plot the mean negative log-likelihood over a grid of values.

    Parameters
    ----------
    demo
        Demography handed to ``IICRCurve``.
    ts
        Tree sequence from which ``num_samples`` random sample pairs are drawn.
    paths
        Mapping whose keys are the parameters to vary; only the keys (and their
        order) are used here.
    vec_values
        Values to evaluate; each element is promoted to a 1-D parameter vector,
        so this is only meaningful when ``paths`` has a single key.
    recombination_rate, seed, num_samples, t_min, num_t, k
        Recombination rate passed to ``loglik``, PRNG seed, number of pairs,
        first time breakpoint, number of breakpoints, and ``k`` for ``IICRCurve``.

    Returns
    -------
    One negative log-likelihood (averaged over pairs) per entry of ``vec_values``.
    """
    import matplotlib.pyplot as plt

    key = jr.PRNGKey(seed)
    path_order: List[Var] = list(paths)
    data_pad, cfg_mat, deme_names, max_indices = get_tmrca_data(ts, key, num_samples)
    first_columns = data_pad[:, :, 0]
    # Largest TMRCA over all pairs (padding rows hold 1.0 in this column).
    global_max = jnp.max(first_columns)
    # Time grid runs from t_min to twice the largest observed TMRCA.
    t_breaks = jnp.linspace(t_min, global_max * 2, num_t)
    rho = recombination_rate
    iicr = IICRCurve(demo=demo, k=k)
    iicr_call = jax.jit(iicr.__call__)

    def compute_loglik(vec, sample_config, data, max_index):
        # Convert sample_config (array) to a {deme name: sample count} dictionary
        ns = {name: sample_config[i] for i, name in enumerate(deme_names)}
        
        # Same parameter values are used for every pair
        params = _vec_to_dict_jax(vec, path_order)
        
        # Compute IICR and log-likelihood.
        # NOTE(review): the IICR is recomputed for every pair even when pairs
        # share the same sample configuration.
        c = iicr_call(params=params, t=t_breaks, num_samples=ns)["c"]
        eta = PiecewiseConstant(c=c, t=t_breaks)
        return loglik(eta, rho, data, max_index)
    
    def evaluate_at_vec(vec):
        vec_array = jnp.atleast_1d(vec)
        # Batched over the rows of cfg_mat, data_pad and max_indices (one per pair)
        batched_loglik = vmap(compute_loglik, in_axes=(None, 0, 0, 0))(vec_array, cfg_mat, data_pad, max_indices)
        return -jnp.sum(batched_loglik) / num_samples  # mean negative log-likelihood

    # Alternative kept for reference: vmap across vec_values instead of lax.map.
    # batched_neg_loglik = vmap(evaluate_at_vec)  # in_axes=0 is default

    # # 3. Compute all values (runs on GPU/TPU if available)
    # results = batched_neg_loglik(vec_values) 
    # lax.map applies evaluate_at_vec to each entry of vec_values in turn.
    results = lax.map(evaluate_at_vec, vec_values)

    # 4. Plot
    plt.figure(figsize=(10, 6))
    plt.plot(vec_values, results, 'r-', linewidth=2)
    plt.xlabel("vec value")
    plt.ylabel("Negative Log-Likelihood")
    plt.title("Likelihood Landscape")
    plt.grid(True)
    plt.show()

    return results

def fit(
    demo,
    paths: Params,
    ts,
    *,
    k: int = 2,
    n_samples: int = 10,
    t_min: float = 1e-8,
    # t_max: float,
    num_t: int = 1000,
    method: str = "trust-constr",
    options: Optional[dict] = None,
    recombination_rate: float = 1e-8,
    sequence_length: float = 1e7,
    mutation_rate: float = 1e-8,
    seed: int = 1,
    num_samples = 20,
):
    """Fit the parameters named in ``paths`` by minimising the mean negative log-likelihood.

    Parameters
    ----------
    demo
        Demography handed to ``EventTree`` and ``IICRCurve``.
    paths
        Mapping ``{parameter path(s): starting value}``; keys define the
        parameter order, values the initial point.
    ts
        Tree sequence from which ``num_samples`` random sample pairs are drawn.
    k
        Forwarded to ``IICRCurve``.
    t_min, num_t
        First breakpoint and number of breakpoints of the time grid; the grid
        ends at twice the largest TMRCA in the sampled data.
    method, options
        Forwarded to ``scipy.optimize.minimize``.
    recombination_rate
        Forwarded to ``loglik``.
    seed, num_samples
        PRNG seed and number of random pairs.
    n_samples, sequence_length, mutation_rate
        Unused; kept so existing call sites keep working.

    Returns
    -------
    dict
        ``{parameter path(s): fitted value}`` as Python floats.
    """
    # Local import: the surrounding cell does not import NumPy at module level.
    import numpy as np

    key = jr.PRNGKey(seed)
    data_pad, cfg_mat, deme_names, max_indices = get_tmrca_data(ts, key, num_samples)

    path_order: List[Var] = list(paths)
    x0 = _dict_to_vec(paths, path_order)
    et = EventTree(demo)

    # Linear constraints on the chosen parameters, as produced by constraints_for.
    cons = constraints_for(et, *path_order)
    linear_constraints: list[LinearConstraint] = []

    Aeq, beq = cons["eq"]
    if Aeq.size:
        linear_constraints.append(LinearConstraint(Aeq, beq, beq))

    G, h = cons["ineq"]
    if G.size:
        lower = -jnp.inf * jnp.ones_like(h)
        linear_constraints.append(LinearConstraint(G, lower, h))

    # Time grid runs from t_min to twice the largest TMRCA in the data.
    global_max = jnp.max(data_pad[:, :, 0])
    t_breaks = jnp.linspace(t_min, global_max * 2, num_t)
    rho = recombination_rate
    iicr = IICRCurve(demo=demo, k=k)
    iicr_call = jax.jit(iicr.__call__)

    def compute_loglik(vec, sample_config, data, max_index):
        # {deme name: number of sampled nodes} for this pair
        ns = {name: sample_config[i] for i, name in enumerate(deme_names)}
        params = _vec_to_dict_jax(vec, path_order)
        c = iicr_call(params=params, t=t_breaks, num_samples=ns)["c"]
        eta = PiecewiseConstant(c=c, t=t_breaks)
        # max_index is passed exactly as in plot_likelihood, so both functions
        # evaluate the same objective on the padded data.
        # NOTE(review): confirm how loglik treats rows beyond max_index.
        return loglik(eta, rho, data, max_index)

    @jax.value_and_grad
    def neg_loglik(vec):
        batched_loglik = vmap(compute_loglik, in_axes=(None, 0, 0, 0))(
            vec, cfg_mat, data_pad, max_indices
        )
        return -jnp.sum(batched_loglik) / num_samples

    def objective(x):
        # One pass yields both value and gradient; with jac=True SciPy reuses
        # the pair instead of evaluating the likelihood twice per step.
        value, grad = neg_loglik(x)
        return float(value), np.asarray(grad, dtype=float)

    res = minimize(
        fun=objective,
        x0=np.asarray(x0, dtype=float),
        jac=True,
        method=method,
        constraints=linear_constraints,
        options=options,
    )

    return _vec_to_dict(jnp.asarray(res.x), path_order)


## Here's how I would call on plot_likelihood over a vector of values to plot the likelihood

In [None]:
import jax.numpy as jnp
# One free parameter: the frozenset ties start_size and end_size of deme 1's
# epoch 0 to a single value. 4000. is the value stored under that key;
# plot_likelihood itself only uses the keys.
paths = {
    frozenset({
        ('demes', 1, 'epochs', 0, 'end_size'),
        ('demes', 1, 'epochs', 0, 'start_size'),
    }): 4000.,
}
# Ten evenly spaced values between 4000 and 7000 at which to evaluate the likelihood.
vec_values = jnp.linspace(4000, 7000, 10)
result = plot_likelihood(g, ts, paths, vec_values, num_samples = 150)

## Here's how I know it takes way too long to evaluate the likelihood for a *single* parameter value while averaging over 150 samples with 1000 time discretization

In [None]:
# Same setup as inside plot_likelihood, unrolled at top level so that a single
# likelihood evaluation can be run (and timed) on its own.
key = jr.PRNGKey(1)
num_samples=150
path_order: List[Var] = list(paths)
data_pad, cfg_mat, deme_names, max_indices = get_tmrca_data(ts, key, num_samples=num_samples)
first_columns = data_pad[:, :, 0]
# Largest TMRCA over all pairs (padding rows hold 1.0 in this column)
global_max = jnp.max(first_columns)
# 1000 breakpoints from 1e-8 to twice the largest TMRCA
t_breaks = jnp.linspace(1e-8, global_max * 2, 1000)
rho = 1e-8
iicr = IICRCurve(demo=g, k=2)
iicr_call = jax.jit(iicr.__call__)

def compute_loglik(vec, sample_config, data, max_index):
    # Convert sample_config (array) to a {deme name: sample count} dictionary
    ns = {name: sample_config[i] for i, name in enumerate(deme_names)}
    
    # Same parameter values are used for every pair
    params = _vec_to_dict_jax(vec, path_order)
    
    # Compute IICR and log-likelihood (the IICR is recomputed for every pair)
    c = iicr_call(params=params, t=t_breaks, num_samples=ns)["c"]
    eta = PiecewiseConstant(c=c, t=t_breaks)
    return loglik(eta, rho, data, max_index)

def evaluate_at_vec(vec):
    vec_array = jnp.atleast_1d(vec)
    # Batched over the rows of cfg_mat, data_pad and max_indices (one per pair)
    batched_loglik = vmap(compute_loglik, in_axes=(None, 0, 0, 0))(vec_array, cfg_mat, data_pad, max_indices)
    return -jnp.sum(batched_loglik) / num_samples  # mean negative log-likelihood

# Alternative kept for reference: vmap across a whole vector of values.
# batched_neg_loglik = vmap(evaluate_at_vec)  # in_axes=0 is default

# # 3. Compute all values (runs on GPU/TPU if available)
# results = batched_neg_loglik(vec_values)
# vec_values = jnp.linspace(4000, 7000, 10) 
# Single evaluation at one parameter value; this is the call reported as too slow.
results = evaluate_at_vec(4000.)

## NaNs in eta bug

In [None]:
import msprime as msp
import demes
import demesdraw

# Create demography object
demo = msp.Demography()

# Add populations: one ancestral deme (size 3000) and two derived demes (size 1000)
demo.add_population(initial_size=3000, name="anc")
demo.add_population(initial_size=1000, name="P0")
demo.add_population(initial_size=1000, name="P1")

# Set initial (time 0) symmetric migration rate between P0 and P1
demo.set_symmetric_migration_rate(populations=("P0", "P1"), rate=0.0001)

# Size change at time 500: P0 and P1 are set to size 3000 from that time on
# (their initial size above is 1000).
demo.add_population_parameters_change(
    time=500,
    initial_size=3000,  # size of P0 from time 500 on
    population="P0"
)
demo.add_population_parameters_change(
    time=500,
    initial_size=3000,  # size of P1 from time 500 on
    population="P1"
)

# Migration rate changes to 0.001 in both directions at time 500 (going into the past)
demo.add_migration_rate_change(
    time=500,
    rate=0.001, 
    source="P0",
    dest="P1"
)
demo.add_migration_rate_change(
    time=500,
    rate=0.001, 
    source="P1",
    dest="P0"
)

# THEN add the older event: P0 and P1 split from "anc" at time 1000
demo.add_population_split(time=1000, derived=["P0", "P1"], ancestral="anc")

# Visualize the demography
g = demo.to_demes()
demesdraw.tubes(g)

# Simulate ancestry and mutations with fixed seeds; sequence length here is 1e8.
sample_size = 10
samples = {f"P{i}": sample_size for i in range(2)}
anc = msp.sim_ancestry(samples=samples, demography=demo, recombination_rate=1e-8, sequence_length=1e8, random_seed = 12)
ts = msp.sim_mutations(anc, rate=1e-8, random_seed = 12)

## copy pasting fit function into next code chunk

In [None]:
# Example implementation of a fit function for parameter inference.
# This is intended for tutorial use only. We do not take responsibility for any bugs or issues in this code.

from __future__ import annotations

from typing import Any, Dict, List, Mapping, Optional, Sequence, Set, Tuple

import jax
import jax.numpy as jnp
import msprime as msp
from scipy.optimize import LinearConstraint, minimize
import jax.random as jr
from jax import vmap, lax 

from demesinfer.coal_rate import PiecewiseConstant
from demesinfer.constr import EventTree, constraints_for
from demesinfer.iicr import IICRCurve
from demesinfer.loglik.arg import loglik

Path = Tuple[Any, ...]
Var = Path | Set[Path]
Params = Mapping[Var, float]

def _dict_to_vec(d: Params, keys: Sequence[Var]) -> jnp.ndarray:
    return jnp.asarray([d[k] for k in keys], dtype=jnp.float64)

def _vec_to_dict_jax(v: jnp.ndarray, keys: Sequence[Var]) -> Dict[Var, jnp.ndarray]:
    return {k: v[i] for i, k in enumerate(keys)}

def _vec_to_dict(v: jnp.ndarray, keys: Sequence[Var]) -> Dict[Var, float]:
    return {k: float(v[i]) for i, k in enumerate(keys)}

def compile(ts, subkey):
    """Draw one random pair of sample nodes and summarise their TMRCA along the genome.

    Parameters
    ----------
    ts
        Tree sequence whose populations carry a ``"name"`` entry in their metadata.
    subkey
        JAX PRNG key used to draw the pair (without replacement).

    Returns
    -------
    data : jnp.ndarray
        One row ``[tmrca, span]`` per run of adjacent trees sharing the same TMRCA
        for the chosen pair; ``span`` is the summed interval length of the run.
    pop_cfg : dict
        ``{population name: number of chosen nodes in it}`` over every population
        that contains at least one sample node. The counts sum to 2.
    """
    sample_nodes = ts.samples()

    def population_name(node_id):
        return ts.population(ts.node(node_id).population).metadata["name"]

    # Keys are inserted in order of first appearance among the samples, so the
    # deme order is the same in every kernel (a set of strings is not, because
    # string hashing is randomised per process).
    pop_cfg = {}
    for node_id in sample_nodes:
        pop_cfg.setdefault(population_name(node_id), 0)

    # The draw yields *positions* in the sample list; map them to node IDs so
    # this also works when the sample nodes are not numbered 0..num_samples-1.
    picks = jax.random.choice(subkey, ts.num_samples, shape=(2,), replace=False)
    a = int(sample_nodes[int(picks[0])])
    b = int(sample_nodes[int(picks[1])])

    spans = []
    curr_t = None
    curr_L = 0.0
    for tree in ts.trees():
        L = tree.interval.right - tree.interval.left
        t = tree.tmrca(a, b)
        if curr_t is None or t != curr_t:
            # TMRCA changed: close the previous run (if any) and start a new one.
            if curr_t is not None:
                spans.append([curr_t, curr_L])
            curr_t = t
            curr_L = L
        else:
            curr_L += L
    spans.append([curr_t, curr_L])  # close the final run
    data = jnp.asarray(spans, dtype=jnp.float64)

    pop_cfg[population_name(a)] += 1
    pop_cfg[population_name(b)] += 1
    return data, pop_cfg

def get_tmrca_data(ts, key, num_samples):
    """Collect TMRCA span data for ``num_samples`` random sample pairs.

    Parameters
    ----------
    ts
        Tree sequence passed through to ``compile``.
    key
        JAX PRNG key; split once per pair.
    num_samples
        Number of random pairs to draw (must be at least 1).

    Returns
    -------
    data_pad : jnp.ndarray, shape (num_samples, Lmax, 2)
        Per-pair ``[tmrca, span]`` rows, padded at the end with ``[1.0, 0.0]``
        up to the longest pair.
    cfg_mat : jnp.ndarray, shape (num_samples, D), int32
        Per-pair sample counts, one column per deme in ``deme_names`` order.
    deme_names
        Keys view of the first pair's configuration dict.
    max_indices : jnp.ndarray
        Index of the last real (non-padding) row for each pair.
    unique_cfg : jnp.ndarray
        The distinct rows of ``cfg_mat`` in sorted order.
    matching_indices : jnp.ndarray
        For every pair, the row of ``unique_cfg`` equal to its configuration.

    Raises
    ------
    ValueError
        If ``num_samples`` is smaller than 1.
    """
    # Local import: the surrounding cell does not import NumPy at module level.
    import numpy as np

    if num_samples < 1:
        raise ValueError("num_samples must be at least 1")

    data_list = []
    cfg_list = []
    for _ in range(num_samples):
        key, subkey = jr.split(key)
        data, cfg = compile(ts, subkey)
        data_list.append(np.asarray(data))
        cfg_list.append(cfg)

    lens = [d.shape[0] for d in data_list]
    Lmax = max(lens)

    # Fill the padded buffer on the host and convert once; repeated
    # ``.at[...].set()`` would copy the whole array for every pair.
    padded = np.empty((num_samples, Lmax, 2), dtype=np.float64)
    padded[..., 0] = 1.0
    padded[..., 1] = 0.0
    for i, d in enumerate(data_list):
        padded[i, : d.shape[0], :] = d
    data_pad = jnp.asarray(padded, dtype=jnp.float64)

    deme_names = cfg_list[0].keys()
    cfg_rows = [[cfg.get(name, 0) for name in deme_names] for cfg in cfg_list]
    cfg_np = np.array(cfg_rows, dtype=np.int32)
    cfg_mat = jnp.asarray(cfg_np, dtype=jnp.int32)

    # return_inverse gives, for every row, its position among the unique rows,
    # which replaces the per-row search loop.
    unique_np, inverse = np.unique(cfg_np, axis=0, return_inverse=True)
    unique_cfg = jnp.asarray(unique_np, dtype=jnp.int32)
    matching_indices = jnp.asarray(np.reshape(inverse, -1), dtype=jnp.int32)

    max_indices = jnp.array([n - 1 for n in lens])
    return data_pad, cfg_mat, deme_names, max_indices, unique_cfg, matching_indices

def plot_likelihood(demo, ts, paths, vec_values, recombination_rate=1e-8, seed=1, num_samples=20, t_min=1e-8, num_t=1000, k=2):
    """Evaluate and plot the mean negative log-likelihood over a grid of values.

    Unlike a per-pair evaluation, the IICR is computed once per *distinct*
    sample configuration (``unique_cfg``) and looked up for each pair through
    ``matching_indices``.

    Parameters
    ----------
    demo
        Demography handed to ``IICRCurve``.
    ts
        Tree sequence from which ``num_samples`` random sample pairs are drawn.
    paths
        Mapping whose keys are the parameters to vary; only the keys (and their
        order) are used here.
    vec_values
        Values to evaluate; each element is promoted to a 1-D parameter vector,
        so this is only meaningful when ``paths`` has a single key.
    recombination_rate, seed, num_samples, t_min, num_t, k
        Recombination rate passed to ``loglik``, PRNG seed, number of pairs,
        first time breakpoint, number of breakpoints, and ``k`` for ``IICRCurve``.

    Returns
    -------
    One negative log-likelihood (averaged over pairs) per entry of ``vec_values``.
    """
    import matplotlib.pyplot as plt

    key = jr.PRNGKey(seed)
    path_order: List[Var] = list(paths)
    data_pad, cfg_mat, deme_names, max_indices, unique_cfg, matching_indices = get_tmrca_data(ts, key, num_samples)
    first_columns = data_pad[:, :, 0]
    # Largest TMRCA over all pairs (padding rows hold 1.0 in this column)
    global_max = jnp.max(first_columns)
    # Debug output of the value that fixes the upper end of the time grid
    print(global_max)
    t_breaks = jnp.linspace(t_min, global_max * 2, num_t)
    rho = recombination_rate
    iicr = IICRCurve(demo=demo, k=k)
    iicr_call = jax.jit(iicr.__call__)

    def compute_loglik(c_map, c_index, data, max_index):
        # Look up the precomputed curve for this pair's sample configuration
        c = c_map[c_index]
        eta = PiecewiseConstant(c=c, t=t_breaks)
        return loglik(eta, rho, data, max_index)
    
    def evaluate_at_vec(vec):
        vec_array = jnp.atleast_1d(vec)
        params = _vec_to_dict_jax(vec_array, path_order)

        def compute_c(sample_config):
            # Convert sample_config (array) to a {deme name: sample count} dictionary
            ns = {name: sample_config[i] for i, name in enumerate(deme_names)}
            
            # Compute the IICR "c" values for this configuration
            c = iicr_call(params=params, t=t_breaks, num_samples=ns)["c"]
            return c
        # One curve per distinct sample configuration
        c_map = vmap(compute_c, in_axes=(0))(unique_cfg)
        # c_map = jax.vmap(lambda cfg: iicr_call(params=params, t=t_breaks, num_samples=dict(zip(deme_names, cfg)))["c"])(
        #     jnp.array(unique_cfg)
        # )
        
        # Batched over matching_indices, data_pad and max_indices (one entry per pair);
        # c_map is shared by all pairs
        batched_loglik = vmap(compute_loglik, in_axes=(None, 0, 0, 0))(c_map, matching_indices, data_pad, max_indices)
        return -jnp.sum(batched_loglik) / num_samples  # mean negative log-likelihood

    # Outer vmap: vectorise across vec_values
    batched_neg_loglik = vmap(evaluate_at_vec)  # in_axes=0 is default

    # 3. Compute all values (runs on GPU/TPU if available)
    results = batched_neg_loglik(vec_values) 
    # results = lax.map(evaluate_at_vec, vec_values)

    # 4. Plot
    plt.figure(figsize=(10, 6))
    plt.plot(vec_values, results, 'r-', linewidth=2)
    plt.xlabel("vec value")
    plt.ylabel("Negative Log-Likelihood")
    plt.title("Likelihood Landscape")
    plt.grid(True)
    plt.show()

    return results

def fit(
    demo,
    paths: Params,
    ts,
    *,
    k: int = 2,
    n_samples: int = 10,
    t_min: float = 1e-8,
    # t_max: float,
    num_t: int = 1000,
    method: str = "trust-constr",
    options: Optional[dict] = None,
    recombination_rate: float = 1e-8,
    sequence_length: float = 1e7,
    mutation_rate: float = 1e-8,
    seed: int = 1,
    num_samples = 20,
):
    """Fit the parameters named in ``paths`` by minimising the mean negative log-likelihood.

    The IICR is computed once per distinct sample configuration and looked up
    per pair, mirroring ``plot_likelihood`` in this cell.

    Parameters
    ----------
    demo
        Demography handed to ``EventTree`` and ``IICRCurve``.
    paths
        Mapping ``{parameter path(s): starting value}``; keys define the
        parameter order, values the initial point.
    ts
        Tree sequence from which ``num_samples`` random sample pairs are drawn.
    k
        Forwarded to ``IICRCurve``.
    t_min, num_t
        First breakpoint and number of breakpoints of the time grid; the grid
        ends at twice the largest TMRCA in the sampled data.
    method, options
        Forwarded to ``scipy.optimize.minimize``.
    recombination_rate
        Forwarded to ``loglik``.
    seed, num_samples
        PRNG seed and number of random pairs.
    n_samples, sequence_length, mutation_rate
        Unused; kept so existing call sites keep working.

    Returns
    -------
    dict
        ``{parameter path(s): fitted value}`` as Python floats.
    """
    # Local import: the surrounding cell does not import NumPy at module level.
    import numpy as np

    key = jr.PRNGKey(seed)
    # get_tmrca_data in this cell returns six values; unpacking only four
    # raised ValueError before any fitting could start.
    (
        data_pad,
        _cfg_mat,
        deme_names,
        max_indices,
        unique_cfg,
        matching_indices,
    ) = get_tmrca_data(ts, key, num_samples)

    path_order: List[Var] = list(paths)
    x0 = _dict_to_vec(paths, path_order)
    et = EventTree(demo)

    # Linear constraints on the chosen parameters, as produced by constraints_for.
    cons = constraints_for(et, *path_order)
    linear_constraints: list[LinearConstraint] = []

    Aeq, beq = cons["eq"]
    if Aeq.size:
        linear_constraints.append(LinearConstraint(Aeq, beq, beq))

    G, h = cons["ineq"]
    if G.size:
        lower = -jnp.inf * jnp.ones_like(h)
        linear_constraints.append(LinearConstraint(G, lower, h))

    # Time grid runs from t_min to twice the largest TMRCA in the data.
    global_max = jnp.max(data_pad[:, :, 0])
    t_breaks = jnp.linspace(t_min, global_max * 2, num_t)
    rho = recombination_rate
    iicr = IICRCurve(demo=demo, k=k)
    iicr_call = jax.jit(iicr.__call__)

    def compute_loglik(c_map, c_index, data, max_index):
        # Look up the precomputed curve for this pair's sample configuration.
        eta = PiecewiseConstant(c=c_map[c_index], t=t_breaks)
        # max_index is passed exactly as in plot_likelihood.
        # NOTE(review): confirm how loglik treats rows beyond max_index.
        return loglik(eta, rho, data, max_index)

    @jax.value_and_grad
    def neg_loglik(vec):
        params = _vec_to_dict_jax(vec, path_order)

        def compute_c(sample_config):
            # {deme name: number of sampled nodes} for this configuration
            ns = {name: sample_config[i] for i, name in enumerate(deme_names)}
            return iicr_call(params=params, t=t_breaks, num_samples=ns)["c"]

        # One curve per distinct sample configuration, shared by all pairs.
        c_map = vmap(compute_c)(unique_cfg)
        batched_loglik = vmap(compute_loglik, in_axes=(None, 0, 0, 0))(
            c_map, matching_indices, data_pad, max_indices
        )
        return -jnp.sum(batched_loglik) / num_samples

    def objective(x):
        # One pass yields both value and gradient; with jac=True SciPy reuses
        # the pair instead of evaluating the likelihood twice per step.
        value, grad = neg_loglik(x)
        return float(value), np.asarray(grad, dtype=float)

    res = minimize(
        fun=objective,
        x0=np.asarray(x0, dtype=float),
        jac=True,
        method=method,
        constraints=linear_constraints,
        options=options,
    )

    return _vec_to_dict(jnp.asarray(res.x), path_order)


## Here is the NaN error

In [None]:
import jax.numpy as jnp
# One free parameter: the frozenset ties together the start times of demes 1
# and 2, the end time of deme 0's epoch 0, and the start times of both
# migrations, so they all move as a single value.
paths = {
    frozenset({('demes', 2, 'start_time'), ('demes', 0, 'epochs', 0, 'end_time'), ('demes', 1, 'start_time'), ('migrations', 0, 'start_time'), ('migrations', 1, 'start_time')}): 2000.,
}
vec_values = jnp.linspace(500, 1500, 20)

key = jr.PRNGKey(1)
path_order: List[Var] = list(paths)
data_pad, cfg_mat, deme_names, max_indices, unique_cfg, matching_indices = get_tmrca_data(ts, key, num_samples=5)
first_columns = data_pad[:, :, 0]
# Largest TMRCA over all pairs (padding rows hold 1.0 in this column)
global_max = jnp.max(first_columns)
print(global_max)
t_breaks = jnp.linspace(1e-8, global_max * 2, 1000)
rho = 1e-8
iicr = IICRCurve(demo=g, k=2)
iicr_call = jax.jit(iicr.__call__)
# Evaluate at the first grid value only (vec_values[0] == 500), not at the 2000. stored in `paths`.
# NOTE(review): 500 equals the time of the size/migration changes in the
# demography cell above; if this parameter is the split time, the epoch between
# them would have zero length — possibly related to the NaNs. Confirm.
params = _vec_to_dict_jax(jnp.array([vec_values[0]]), path_order)
# Use the first sampled pair only
i = 0
data = data_pad[i]
sample_config = cfg_mat[i]
max_index = max_indices[i]
# Convert sample_config (array) to a {deme name: sample count} dictionary
ns = {name: sample_config[i] for i, name in enumerate(deme_names)}

# Compute IICR and log-likelihood for this single pair
c = iicr_call(params=params, t=t_breaks, num_samples=ns)["c"]
eta = PiecewiseConstant(c=c, t=t_breaks)
loglik(eta, rho, data, max_index)

## Here is the params type error: `IICRCurve` can be called with a `frozenset` key in the params dictionary, but `ExpectedSFS` cannot

In [None]:
import msprime as msp
import demes
import demesdraw
import demes
from demesinfer.sfs import ExpectedSFS

# Two-deme model: P0 and P1 (size 5000 each, symmetric migration 0.001) split
# from "anc" at time 1000.
demo = msp.Demography()
demo.add_population(initial_size = 5000, name = "anc")
demo.add_population(initial_size = 5000, name = "P0")
demo.add_population(initial_size = 5000, name = "P1")
demo.set_symmetric_migration_rate(populations=("P0", "P1"), rate=0.001)
tmp = [f"P{i}" for i in range(2)]
demo.add_population_split(time = 1000, derived=tmp, ancestral="anc")
g = demo.to_demes()
sample_size = 10
samples = {f"P{i}": sample_size for i in range(2)}
anc = msp.sim_ancestry(samples=samples, demography=demo, recombination_rate=1e-8, sequence_length=1e7, random_seed = 12)
ts = msp.sim_mutations(anc, rate=1e-8, random_seed = 12)

# `samples` is rebound here: 20 per deme for the expected SFS (the simulation above used 10).
samples = {f"P{i}": 20 for i in range(2)}
esfs = ExpectedSFS(g, num_samples=samples)
# A plain tuple key (single path) is the commented-out variant:
# e1 = esfs(params = {('migrations', 0, 'rate'): 0.005} )
# A frozenset key (two tied paths) is the call reported as raising the type error.
e1 = esfs(params = {frozenset({
        ('demes', 0, 'epochs', 0, 'end_size'),
        ('demes', 0, 'epochs', 0, 'start_size'),
    }): 7000.} )
# Observed joint allele frequency spectrum from the simulated tree sequence, not span-normalised.
afs = ts.allele_frequency_spectrum(sample_sets=[ts.samples([1]), ts.samples([2])], span_normalise=False)

## Phlashlib Error where the loglik does not evaluate for the case where sequence length > warmup length
# Please note: many of these variables are related to the actual simulation, I just threw in random numbers for theta and rho to get loglik running

The first two code chunks construct the demography and define the necessary data-processing functions.

In [None]:
import msprime as msp
import demes
import demesdraw
import numpy as np

# Create demography object
demo = msp.Demography()

# Add populations. The growth rates are chosen as -log(3000 / N0) / 500 so that,
# going back in time, P0 (N0 = 500) and P1 (N0 = 100) each reach size 3000 at time 500.
demo.add_population(initial_size=3000, name="anc")
demo.add_population(initial_size=500, name="P0", growth_rate=-np.log(3000 / 500)/500)
demo.add_population(initial_size=100, name="P1", growth_rate=-np.log(3000 / 100)/500)

# Earlier draft of the growth-rate calculation, kept for reference:
# # Add exponential decline for P0 starting 1000 generations ago
# decline_time = 1000  # When decline starts (generations ago)
# initial_size = 5000  # Size at decline_time
# final_size = 1000    # Size at present (time 0) - smaller population

# # Calculate negative growth rate for decline
# growth_rate = np.log(final_size / initial_size) / decline_time
# print(f"Growth rate: {growth_rate:.6f}")  # Will be negative

# Set initial (time 0) symmetric migration rate between P0 and P1
demo.set_symmetric_migration_rate(populations=("P0", "P1"), rate=0.0001)

# At time 500 the exponential phase ends: size is fixed at 3000 with growth rate 0
demo.add_population_parameters_change(
    time=500,
    initial_size=3000,  # constant size of P0 from time 500 on
    population="P0",
    growth_rate=0
)
demo.add_population_parameters_change(
    time=500,
    initial_size=3000,  # constant size of P1 from time 500 on
    population="P1",
    growth_rate=0
)

# Migration rate changes to 0.001 in both directions at time 500 (going into the past)
demo.add_migration_rate_change(
    time=500,
    rate=0.001, 
    source="P0",
    dest="P1"
)
demo.add_migration_rate_change(
    time=500,
    rate=0.001, 
    source="P1",
    dest="P0"
)

# THEN add the older event: P0 and P1 split from "anc" at time 1000
demo.add_population_split(time=1000, derived=["P0", "P1"], ancestral="anc")

# Visualize the demography
g = demo.to_demes()
demesdraw.tubes(g)

# Simulate ancestry and mutations with fixed seeds; sequence length here is 1e7.
sample_size = 10
samples = {f"P{i}": sample_size for i in range(2)}
anc = msp.sim_ancestry(samples=samples, demography=demo, recombination_rate=1e-8, sequence_length=1e7, random_seed = 12)
ts = msp.sim_mutations(anc, rate=1e-8, random_seed = 12)

In [None]:
from __future__ import annotations

from typing import Any, Dict, List, Mapping, Optional, Sequence, Set, Tuple

import jax
import jax.numpy as jnp
import numpy as np
import msprime as msp
from scipy.optimize import LinearConstraint, minimize
import jax.random as jr
from jax import vmap, lax 

from demesinfer.coal_rate import PiecewiseConstant
from demesinfer.constr import EventTree, constraints_for
from demesinfer.iicr import IICRCurve
from demesinfer.loglik.arg import loglik
from jax.scipy.special import xlogy
from demesinfer.sfs import ExpectedSFS

Path = Tuple[Any, ...]
Var = Path | Set[Path]
Params = Mapping[Var, float]

def _dict_to_vec(d: Params, keys: Sequence[Var]) -> jnp.ndarray:
    return jnp.asarray([d[k] for k in keys], dtype=jnp.float64)

def _vec_to_dict_jax(v: jnp.ndarray, keys: Sequence[Var]) -> Dict[Var, jnp.ndarray]:
    return {k: v[i] for i, k in enumerate(keys)}

def _vec_to_dict(v: jnp.ndarray, keys: Sequence[Var]) -> Dict[Var, float]:
    return {k: float(v[i]) for i, k in enumerate(keys)}

from intervaltree import IntervalTree
import tqdm.auto as tqdm
import tskit

def _read_ts(
    ts: tskit.TreeSequence,
    nodes: list[tuple[int, int]],
    window_size: int,
    progress: bool = False,
) -> np.ndarray:
    """Count, per node pair and per window, the sites at which the pair's genotypes differ.

    Returns an int8 array with one row per entry of ``nodes`` and one column
    per window of ``window_size`` positions along the sequence.
    """
    unique_nodes = list({node for pair in nodes for node in pair})
    # Position of each node within `unique_nodes`, i.e. within variant.genotypes.
    column_of = {node: col for col, node in enumerate(unique_nodes)}
    pair_columns = np.array([[column_of[node] for node in pair] for pair in nodes])

    num_windows = int(np.ceil(ts.get_sequence_length() / window_size))
    het_counts = np.zeros([len(nodes), num_windows], dtype=np.int8)

    variants = ts.variants(samples=unique_nodes, copy=False)
    with tqdm.tqdm(variants, total=ts.num_sites, disable=not progress) as pbar:
        pbar.set_description("Reading tree sequence")
        for variant in pbar:
            genotypes = variant.genotypes[pair_columns]
            window = int(variant.position / window_size)
            # Add 1 for every pair whose two genotypes differ at this site.
            het_counts[:, window] += genotypes[:, 0] != genotypes[:, 1]
    return het_counts

def get_data(ts, nodes=None, window_size=100, mask=None):
    """Compute the allele frequency spectrum and windowed het matrix for node pairs.

    Parameters
    ----------
    ts
        Tree sequence to read.
    nodes
        Iterable of node pairs; defaults to the nodes of every individual in ``ts``.
    window_size
        Width of the windows used for the het matrix.
    mask
        Optional iterable of ``(begin, end)`` intervals to exclude.

    Returns
    -------
    dict
        ``afs``: spectrum summed over the unmasked windows, with the first and
        last entries dropped; ``het_matrix``: per-pair, per-window counts from
        ``_read_ts``, with masked columns set to ``-1``.

    Raises
    ------
    ValueError
        If ``mask`` covers the whole sequence.
    """
    # form interval tree for masking
    L = int(ts.get_sequence_length())
    mask = mask or []

    if nodes is None:
        nodes = [tuple(i.nodes) for i in ts.individuals()]

    tr = IntervalTree.from_tuples([(0, L)])
    for a, b in mask:
        tr.chop(a, b)
    if not tr:
        raise ValueError("mask covers the entire sequence; nothing left to read")
    # compute breakpoints; IntervalTree iterates in arbitrary order, so sort the
    # intervals first or the sortedness check below can fail for non-empty masks
    bp = np.array([x for i in sorted(tr) for x in [i.begin, i.end]])
    assert len(set(bp)) == len(bp)
    assert (bp == np.sort(bp)).all()
    if bp[0] != 0.0:
        bp = np.insert(bp, 0, 0.0)
    if bp[-1] != L:
        bp = np.append(bp, L)
    # a window is unmasked when its midpoint lies in a surviving interval
    mid = (bp[:-1] + bp[1:]) / 2.0
    unmasked = [bool(tr[m]) for m in mid]
    nodes_flat = list({x for t in nodes for x in t})
    afs = ts.allele_frequency_spectrum(
        sample_sets=[nodes_flat], windows=bp, polarised=True, span_normalise=False
    )[unmasked].sum(0)[1:-1]
    het_matrix = _read_ts(ts, nodes, window_size)
    # now mask out columns of the het matrix based on interval
    # overlap
    tr = IntervalTree.from_tuples(mask)
    column_mask = [
        bool(tr[a : a + window_size]) for a in range(0, L, window_size)
    ]
    assert len(column_mask) == het_matrix.shape[1]
    # mark the masked columns
    het_matrix[:, column_mask] = -1
    return dict(afs=afs, het_matrix=het_matrix)

def compile(ts, subkey, a=None, b=None):
    """Build the per-deme sample configuration for one pair of sample nodes.

    Parameters
    ----------
    ts
        Tree sequence whose populations carry a ``"name"`` entry in their metadata.
    subkey
        JAX PRNG key; only used when the pair is drawn at random.
    a, b
        Node IDs of the pair. When both are ``None`` a pair is drawn at random
        (without replacement) from the sample nodes.

    Returns
    -------
    pop_cfg : dict
        ``{population name: number of chosen nodes in it}`` over every population
        that contains at least one sample node.
    (a, b) : tuple
        The node IDs of the pair that was used.
    """
    sample_nodes = ts.samples()

    def population_name(node_id):
        return ts.population(ts.node(node_id).population).metadata["name"]

    # Keys are inserted in order of first appearance among the samples, so the
    # deme order is the same in every kernel (a set of strings is not, because
    # string hashing is randomised per process).
    pop_cfg = {}
    for node_id in sample_nodes:
        pop_cfg.setdefault(population_name(node_id), 0)

    if a is None and b is None:
        # The draw yields *positions* in the sample list; map them to node IDs so
        # this also works when sample nodes are not numbered 0..num_samples-1.
        picks = jax.random.choice(subkey, ts.num_samples, shape=(2,), replace=False)
        a = int(sample_nodes[int(picks[0])])
        b = int(sample_nodes[int(picks[1])])

    pop_cfg[population_name(a)] += 1
    pop_cfg[population_name(b)] += 1
    return pop_cfg, (a, b)

def get_het_data(ts, key=jax.random.PRNGKey(2), num_samples=200, option="random", window_size=100, mask=None):
    """Select node pairs, then compute their het matrix, AFS and sample configurations.

    Parameters
    ----------
    ts
        Tree sequence to read.
    key
        JAX PRNG key used when pairs are drawn at random.
    num_samples
        Number of random pairs (only used when ``option == "random"``).
    option
        ``"random"``: draw ``num_samples`` random pairs; ``"all"``: every pair of
        sample nodes; ``"unphased"``: consecutive sample nodes paired up two by two.
    window_size, mask
        Forwarded to ``get_data``.

    Returns
    -------
    result : dict
        Output of ``get_data`` for the selected pairs.
    cfg_list : list of dict
        One ``{population name: count}`` configuration per selected pair.

    Raises
    ------
    ValueError
        If ``option`` is not one of the three supported values.
    """
    if option not in ("random", "all", "unphased"):
        raise ValueError(
            f"unknown option {option!r}; expected 'random', 'all' or 'unphased'"
        )

    cfg_list = []
    all_config = []
    key, subkey = jr.split(key)
    if option == "random":
        for _ in range(num_samples):
            cfg, pair = compile(ts, subkey)
            cfg_list.append(cfg)
            all_config.append(pair)
            key, subkey = jr.split(key)
    else:
        if option == "all":
            from itertools import combinations
            pairs = combinations(ts.samples(), 2)
        else:  # "unphased"
            pairs = ts.samples().reshape(-1, 2)
        all_config = [(int(a), int(b)) for a, b in pairs]
        for a, b in all_config:
            # compile returns (cfg, pair); keep only the configuration dict so
            # cfg_list has the same element type as in the "random" branch.
            cfg, _pair = compile(ts, subkey, a, b)
            cfg_list.append(cfg)

    result = get_data(ts, all_config, window_size, mask)
    return result, cfg_list

def process_data(cfg_list):
    """Turn a list of per-pair sample configurations into index arrays.

    Parameters
    ----------
    cfg_list
        Non-empty list of ``{deme name: count}`` dicts, one per pair. The deme
        order is taken from the first dict; names missing from a later dict
        count as 0.

    Returns
    -------
    cfg_mat : jnp.ndarray, shape (len(cfg_list), D), int32
        Counts, one row per pair and one column per deme.
    deme_names
        Keys view of the first configuration dict.
    unique_cfg : jnp.ndarray
        The distinct rows of ``cfg_mat`` in sorted order.
    matching_indices : jnp.ndarray
        For every pair, the row of ``unique_cfg`` equal to its configuration.
    """
    deme_names = cfg_list[0].keys()

    # Build the matrix on the host in one go; repeated ``.at[...].set()`` would
    # copy the whole array once per element.
    cfg_np = np.array(
        [[cfg.get(name, 0) for name in deme_names] for cfg in cfg_list],
        dtype=np.int32,
    )
    cfg_mat = jnp.asarray(cfg_np, dtype=jnp.int32)

    # return_inverse gives, for every row, its position among the unique rows,
    # which replaces the per-row search loop.
    unique_np, inverse = np.unique(cfg_np, axis=0, return_inverse=True)
    unique_cfg = jnp.asarray(unique_np, dtype=jnp.int32)
    matching_indices = jnp.asarray(np.reshape(inverse, -1), dtype=jnp.int32)

    return cfg_mat, deme_names, unique_cfg, matching_indices

Here is my chunked `het_matrix`:

In [None]:
# Default settings: 200 random pairs, window_size=100, no mask.
result, cfg_list = get_het_data(ts)
# One row per sampled pair, one column per window.
het_matrix = result["het_matrix"]
print(het_matrix)

The code below only works because the length of the data passed in is less than the warmup length.

In [None]:
from phlashlib.iicr import PiecewiseConstant
from phlashlib.loglik import loglik
import jax.numpy as jnp
# Four time points and three rate values for the piecewise-constant curve.
t = jnp.array([0.0, 1.0, 2.0, 3.0])
# NOTE(review): the last value here is 0.011, whereas the full-length cell below uses 0.001 — confirm this is intended.
c = jnp.array([0.001, 0.001, 0.011])
# The curve is built from all but the last time point, so t and c have equal length.
iicr = PiecewiseConstant(t=t[:-1], c=c)
# Only columns 1..99 of row 1 are passed in (99 windows).
ll = loglik(jnp.array(het_matrix[1][1:100]), iicr, t, theta=1.0, rho=1.0)
print(ll)

Now the error happens when the sequence length (1e7) is greater than the warmup length.

In [None]:
from phlashlib.iicr import PiecewiseConstant
from phlashlib.loglik import loglik
import jax.numpy as jnp
# Same time grid as the short-input cell above in this notebook.
t = jnp.array([0.0, 1.0, 2.0, 3.0])
c = jnp.array([0.001, 0.001, 0.001])
# The curve is built from all but the last time point, so t and c have equal length.
iicr = PiecewiseConstant(t=t[:-1], c=c)
# The whole of row 1 is passed in this time (no column slicing).
ll = loglik(jnp.array(het_matrix[1]), iicr, t, theta=1.0, rho=1.0)
print(ll)

## Phlashlib lax.map Error
The first chunk runs the simulation and loads the helper functions.

In [None]:
import msprime as msp
import demes
import demesdraw
import numpy as np

# Create demography object
demo = msp.Demography()

# Add populations. P0 and P1 use growth rates of the form -log(3000 / N0) / 500.
# Under msprime's N(t) = N0 * exp(-growth_rate * t) (t in generations ago), each
# deme's size rises from its present-day N0 to 3000 at t = 500 going back in time.
demo.add_population(initial_size=3000, name="anc")
demo.add_population(initial_size=500, name="P0", growth_rate=-np.log(3000 / 500)/500)
demo.add_population(initial_size=100, name="P1", growth_rate=-np.log(3000 / 100)/500)

# NOTE: the commented-out lines below are an unused earlier draft, kept for reference.
# # Add exponential decline for P0 starting 1000 generations ago
# decline_time = 1000  # When decline starts (generations ago)
# initial_size = 5000  # Size at decline_time
# final_size = 1000    # Size at present (time 0) - smaller population

# # Calculate negative growth rate for decline
# growth_rate = np.log(final_size / initial_size) / decline_time
# print(f"Growth rate: {growth_rate:.6f}")  # Will be negative

# Set initial migration rate
demo.set_symmetric_migration_rate(populations=("P0", "P1"), rate=0.0001)

# At 500 generations ago, switch both demes to a constant size of 3000 (growth_rate=0)
demo.add_population_parameters_change(
    time=500,
    initial_size=3000,  # size at time 500, held constant because growth_rate=0
    population="P0",
    growth_rate=0
)
demo.add_population_parameters_change(
    time=500,
    initial_size=3000,  # size at time 500, held constant because growth_rate=0
    population="P1",
    growth_rate=0
)

# At 500 generations ago (going into the past) the migration rate becomes 0.001 in both directions
demo.add_migration_rate_change(
    time=500,
    rate=0.001, 
    source="P0",
    dest="P1"
)
demo.add_migration_rate_change(
    time=500,
    rate=0.001, 
    source="P1",
    dest="P0"
)

# THEN add the older events (population split at 1000)
demo.add_population_split(time=1000, derived=["P0", "P1"], ancestral="anc")

# Visualize the demography
g = demo.to_demes()
demesdraw.tubes(g)

# Simulate 10 individuals from each of P0 and P1 on a 1e7-long sequence,
# then overlay mutations; both steps use a fixed seed.
sample_size = 10
samples = {f"P{i}": sample_size for i in range(2)}
anc = msp.sim_ancestry(samples=samples, demography=demo, recombination_rate=1e-8, sequence_length=1e7, random_seed = 12)
ts = msp.sim_mutations(anc, rate=1e-8, random_seed = 12)

from __future__ import annotations

from typing import Any, Dict, List, Mapping, Optional, Sequence, Set, Tuple

import jax
import jax.numpy as jnp
import numpy as np
import msprime as msp
from scipy.optimize import LinearConstraint, minimize
import jax.random as jr
from jax import vmap, lax 

from demesinfer.constr import EventTree, constraints_for
from demesinfer.iicr import IICRCurve
from phlashlib.loglik import loglik
from phlashlib.iicr import PiecewiseConstant
from jax.scipy.special import xlogy
from demesinfer.sfs import ExpectedSFS

# A parameter path: a tuple of arbitrary components (the cells below use
# tuples such as ('demes', 0, 'epochs', 0, 'end_size')).
# NOTE(review): this name would shadow pathlib.Path if that were imported too.
Path = Tuple[Any, ...]
# A variable is either a single path or a set of paths.
Var = Path | Set[Path]
# Mapping from each variable to its numeric value.
Params = Mapping[Var, float]

def _dict_to_vec(d: Params, keys: Sequence[Var]) -> jnp.ndarray:
    return jnp.asarray([d[k] for k in keys], dtype=jnp.float64)

def _vec_to_dict_jax(v: jnp.ndarray, keys: Sequence[Var]) -> Dict[Var, jnp.ndarray]:
    return {k: v[i] for i, k in enumerate(keys)}

def _vec_to_dict(v: jnp.ndarray, keys: Sequence[Var]) -> Dict[Var, float]:
    return {k: float(v[i]) for i, k in enumerate(keys)}

from intervaltree import IntervalTree
import tqdm.auto as tqdm
import tskit

def _read_ts(
    ts: tskit.TreeSequence,
    nodes: list[tuple[int, int]],
    window_size: int,
    progress: bool = False,
) -> np.ndarray:
    """Count, per window, the sites at which the two nodes of each pair differ.

    Args:
        ts: Tree sequence to read variants from.
        nodes: Pairs of node IDs; one output row per pair.
        window_size: Window width, in the same units as site positions.
        progress: Show a tqdm progress bar while iterating over variants.

    Returns:
        ``int8`` array of shape ``(len(nodes), ceil(sequence_length / window_size))``
        where entry ``[i, w]`` is the number of sites in window ``w`` at which
        the genotypes of pair ``i`` differ.
    """
    # Distinct node IDs, in the order they are requested from ts.variants.
    nodes_flat = list({x for t in nodes for x in t})
    # Constant-time lookup of each node's genotype column (replaces the
    # quadratic repeated list.index scan).
    column_of = {node: col for col, node in enumerate(nodes_flat)}
    node_inds = np.array([[column_of[x] for x in t] for t in nodes])
    N = len(nodes)
    L = int(np.ceil(ts.sequence_length / window_size))
    # NOTE(review): int8 counts would wrap past 127 differing sites in one
    # window — presumably unreachable for the window sizes used here; confirm.
    G = np.zeros([N, L], dtype=np.int8)
    with tqdm.tqdm(
        ts.variants(samples=nodes_flat, copy=False),
        total=ts.num_sites,
        disable=not progress,
    ) as pbar:
        pbar.set_description("Reading tree sequence")
        for v in pbar:
            # copy=False reuses the variant object, so consume genotypes now.
            g = v.genotypes[node_inds]
            ell = int(v.position / window_size)
            G[:, ell] += g[:, 0] != g[:, 1]
    return G

def get_data(ts, nodes=None, window_size=100, mask=None):
    """Compute the allele frequency spectrum and windowed het matrix for node pairs.

    Args:
        ts: Tree sequence to analyse.
        nodes: Pairs of node IDs. Defaults to the node tuple of every
            individual in ``ts``.
        window_size: Width of the het-matrix windows.
        mask: Optional iterable of ``(begin, end)`` intervals to exclude.

    Returns:
        Dict with ``"afs"`` (spectrum summed over the unmasked segments, with
        the first and last entries dropped) and ``"het_matrix"`` (output of
        ``_read_ts`` with every column that overlaps a masked interval set
        to -1).

    Raises:
        ValueError: If ``mask`` removes the entire sequence.
    """
    # form interval tree for masking
    L = int(ts.sequence_length)
    mask = mask or []

    if nodes is None:
        nodes = [tuple(i.nodes) for i in ts.individuals()]

    tr = IntervalTree.from_tuples([(0, L)])
    for a, b in mask:
        tr.chop(a, b)
    # compute breakpoints; an IntervalTree iterates in arbitrary order, so sort
    # the surviving intervals before flattening them.
    bp = np.array([x for i in sorted(tr) for x in [i.begin, i.end]])
    if bp.size == 0:
        raise ValueError("mask covers the entire sequence; nothing left to analyse")
    assert len(set(bp)) == len(bp)
    assert (bp == np.sort(bp)).all()
    if bp[0] != 0.0:
        bp = np.insert(bp, 0, 0.0)
    if bp[-1] != L:
        bp = np.append(bp, L)
    # A segment is unmasked when its midpoint still lies inside the chopped tree.
    mid = (bp[:-1] + bp[1:]) / 2.0
    unmasked = [bool(tr[m]) for m in mid]
    nodes_flat = list({x for t in nodes for x in t})
    afs = ts.allele_frequency_spectrum(
        sample_sets=[nodes_flat], windows=bp, polarised=True, span_normalise=False
    )[unmasked].sum(0)[1:-1]
    het_matrix = _read_ts(ts, nodes, window_size)
    # now mask out columns of the het matrix based on interval
    # overlap
    tr = IntervalTree.from_tuples(mask)
    column_mask = [
        bool(tr[a : a + window_size]) for a in range(0, L, window_size)
    ]
    assert len(column_mask) == het_matrix.shape[1]
    # set mask out these columns
    het_matrix[:, column_mask] = -1
    return dict(afs=afs, het_matrix=het_matrix)

def compile(ts, subkey, a=None, b=None):
    """Count how many of two sample nodes fall in each population.

    NOTE: this function shadows Python's built-in ``compile`` in the notebook
    namespace.

    Args:
        ts: Tree sequence whose populations carry a ``"name"`` metadata entry.
        subkey: JAX PRNG key; only used when ``a`` and ``b`` are omitted.
        a: First node ID, or ``None`` to draw the pair at random.
        b: Second node ID, or ``None`` to draw the pair at random.

    Returns:
        ``(pop_cfg, (a, b))`` where ``pop_cfg`` maps every population name
        that contains a sample node to the number of chosen nodes in it
        (0, 1 or 2), and ``(a, b)`` is the pair used.

    Raises:
        ValueError: If exactly one of ``a`` and ``b`` is given.
    """

    def deme_of(node_id):
        return ts.population(ts.node(node_id).population).metadata["name"]

    sample_ids = ts.samples()
    # Keys appear in first-seen sample order, so the deme order does not
    # depend on string hashing as it would when built from a set.
    pop_cfg = {deme_of(n): 0 for n in sample_ids}

    if (a is None) != (b is None):
        raise ValueError("a and b must either both be given or both be omitted")

    if a is None:
        # Draw two distinct positions and map them to real sample node IDs,
        # which need not be 0..num_samples-1 in every tree sequence.
        picks = jax.random.choice(subkey, ts.num_samples, shape=(2,), replace=False)
        a = int(sample_ids[int(picks[0])])
        b = int(sample_ids[int(picks[1])])

    pop_cfg[deme_of(a)] += 1
    pop_cfg[deme_of(b)] += 1
    return pop_cfg, (a, b)

def get_het_data(ts, key=jax.random.PRNGKey(2), num_samples=10, option="random", window_size=100, mask=None):
    """Pick pairs of sample nodes and build their windowed heterozygosity data.

    Args:
        ts: Tree sequence to read.
        key: JAX PRNG key used by the ``"random"`` option. The default key is
            created once, when the function is defined, so repeated calls that
            rely on the default draw the same pairs.
        num_samples: Number of pairs drawn when ``option == "random"``.
        option: ``"random"`` draws ``num_samples`` pairs via ``compile``;
            ``"all"`` uses every 2-combination of ``ts.samples()``;
            ``"unphased"`` pairs consecutive entries of ``ts.samples()``
            (``reshape(-1, 2)``).
        window_size: Forwarded to ``get_data``.
        mask: Forwarded to ``get_data``.

    Returns:
        ``(result, cfg_list)`` where ``result`` is whatever ``get_data``
        returns and ``cfg_list`` holds one deme-count dict per pair.

    Raises:
        ValueError: If ``option`` is not one of the three supported values.
    """
    from itertools import combinations

    cfg_list = []
    all_config = []
    key, subkey = jr.split(key)
    if option == "random":
        for _ in range(num_samples):
            cfg, pair = compile(ts, subkey)
            cfg_list.append(cfg)
            all_config.append(pair)
            key, subkey = jr.split(key)
    elif option == "all":
        all_config = list(combinations(ts.samples(), 2))
    elif option == "unphased":
        all_config = ts.samples().reshape(-1, 2)
    else:
        raise ValueError(
            f"unknown option {option!r}; expected 'random', 'all' or 'unphased'"
        )

    if option != "random":
        for a, b in all_config:
            # compile returns (cfg, (a, b)); only the deme-count dict belongs
            # in cfg_list, otherwise process_data cannot call .keys() on it.
            cfg, _ = compile(ts, subkey, a, b)
            cfg_list.append(cfg)

    result = get_data(ts, all_config, window_size, mask)
    return result, cfg_list

def process_data(cfg_list):
    """Turn per-pair deme-count dicts into arrays.

    Args:
        cfg_list: Non-empty list of dicts mapping deme name -> number of
            sampled nodes in that deme. The deme order is taken from the
            keys of the first dict; names missing from later dicts count as 0.

    Returns:
        Tuple ``(cfg_mat, deme_names, unique_cfg, matching_indices)``:
        ``cfg_mat`` is an int32 array with one row per entry of ``cfg_list``
        and one column per deme, ``deme_names`` is the key view of the first
        dict, ``unique_cfg`` holds the distinct rows of ``cfg_mat`` and
        ``matching_indices[i]`` is the row of ``unique_cfg`` equal to
        ``cfg_mat[i]``.

    Raises:
        ValueError: If ``cfg_list`` is empty.
    """
    if len(cfg_list) == 0:
        raise ValueError("cfg_list must contain at least one sample configuration")

    deme_names = cfg_list[0].keys()
    # Build the matrix in one shot; updating element by element with
    # .at[].set copies the whole array on every call.
    cfg_mat = jnp.asarray(
        [[cfg.get(name, 0) for name in deme_names] for cfg in cfg_list],
        dtype=jnp.int32,
    )

    unique_cfg = jnp.unique(cfg_mat, axis=0)

    # Compare every row with every unique row at once; argmax returns the
    # first (and only) matching unique row for each input row.
    row_matches = jnp.all(cfg_mat[:, None, :] == unique_cfg[None, :, :], axis=-1)
    matching_indices = jnp.argmax(row_matches, axis=1)

    return cfg_mat, deme_names, unique_cfg, matching_indices

def plot_iicr_likelihood(demo, data, cfg_list, paths, vec_values, recombination_rate=1e-8, theta=1e-8, t_min=1e-8, t_max=1e4, num_t=2000, k=2):
    """Evaluate and plot the mean negative log-likelihood over a grid of parameter values.

    Args:
        demo: Demography handed to ``IICRCurve``.
        data: Dict with a ``"het_matrix"`` entry (one row per sampled pair).
        cfg_list: Per-pair deme-count dicts, converted with ``process_data``.
        paths: Iterable of parameter variables; its iteration order fixes the
            order of the entries of each parameter vector.
        vec_values: Parameter values to evaluate; mapped over with ``lax.map``
            and used as the x-axis of the plot.
        recombination_rate: Passed to ``loglik`` as ``rho``.
        theta: Passed to ``loglik`` as ``theta``.
        t_min: Smallest non-zero point of the geometric time grids.
        t_max: Largest point of the geometric time grids.
        num_t: Number of geometric points in the IICR time grid ``t_iicr``.
        k: Passed to ``IICRCurve``.

    Returns:
        The array produced by ``lax.map`` with one value per entry of
        ``vec_values``.
    """
    import matplotlib.pyplot as plt

    het_matrix = data["het_matrix"]
    path_order: List[Var] = list(paths)
    cfg_mat, deme_names, unique_cfg, matching_indices = process_data(cfg_list)
    num_samples = len(cfg_mat)
    # Both grids are geometric with 0.0 prepended.
    # NOTE(review): t_breaks always uses 2000 points, independent of num_t — confirm this is intended.
    t_breaks = jnp.insert(jnp.geomspace(t_min, t_max, 2000), 0, 0.0)
    t_iicr = jnp.insert(jnp.geomspace(t_min, t_max, num_t), 0, 0.0)
    rho = recombination_rate
    iicr = IICRCurve(demo=demo, k=k)
    iicr_call = jax.jit(iicr.__call__)

    def compute_loglik(c_map, c_index, data):
        # Pick the rate curve of this row's sample configuration.
        c = c_map[c_index]
        eta = PiecewiseConstant(c=c, t=t_iicr)
        return loglik(data, eta, t_breaks, theta, rho)
    
    def evaluate_at_vec(vec):
        vec_array = jnp.atleast_1d(vec)
        params = _vec_to_dict_jax(vec_array, path_order)

        # One IICR evaluation per distinct sample configuration.
        c_map = jax.vmap(lambda cfg: iicr_call(params=params, t=t_iicr, num_samples=dict(zip(deme_names, cfg)))["c"])(
            jnp.array(unique_cfg)
        )
        
        # Batched over matching_indices and the rows of het_matrix; c_map is shared
        batched_loglik = vmap(compute_loglik, in_axes=(None, 0, 0))(c_map, matching_indices, het_matrix)
        return -jnp.sum(batched_loglik) / num_samples  # Same as original neg_loglik

    # Alternative kept for reference: vmap across vec_values instead of lax.map
    # batched_neg_loglik = vmap(evaluate_at_vec)  # in_axes=0 is default

    # 3. Compute all values (runs on GPU/TPU if available)
    # results = batched_neg_loglik(vec_values) 
    # Evaluate the entries of vec_values with lax.map
    results = lax.map(evaluate_at_vec, vec_values)

    # 4. Plot
    plt.figure(figsize=(10, 6))
    plt.plot(vec_values, results, 'r-', linewidth=2)
    plt.xlabel("vec value")
    plt.ylabel("Negative Log-Likelihood")
    plt.title("IICR Likelihood Landscape")
    plt.grid(True)
    plt.show()

    return results

In [None]:
# Build het data with get_het_data's defaults; `data` holds "afs" and "het_matrix".
data, cfg_list = get_het_data(ts)

# Everything runs smoothly with `vmap`

In [None]:
# Per-window rates: the per-base rate 1e-8 scaled by the window width.
window_size = 100
recombination_rate=1e-8 * window_size
theta=1e-8 * window_size
# Bounds and resolution of the geometric time grids built below.
t_min=1e-8
t_max=1e4
num_t=2000
k=2
# A single variable: start_size and end_size of deme 0, epoch 0, tied together in one frozenset.
paths = {
    frozenset({
        ('demes', 0, 'epochs', 0, 'end_size'),
        ('demes', 0, 'epochs', 0, 'start_size'),
    }): 4000.,
}
# `g` is the demes graph produced by demo.to_demes() in the simulation cell.
demo = g
het_matrix = data["het_matrix"]
path_order: List[Var] = list(paths)
cfg_mat, deme_names, unique_cfg, matching_indices = process_data(cfg_list)
num_samples = len(cfg_mat)
# Both grids are geometric with 0.0 prepended; t_breaks uses 100 points, t_iicr uses num_t.
t_breaks = jnp.insert(jnp.geomspace(t_min, t_max, 100), 0, 0.0)
t_iicr = jnp.insert(jnp.geomspace(t_min, t_max, num_t), 0, 0.0)
rho = recombination_rate
iicr = IICRCurve(demo=demo, k=k)
iicr_call = jax.jit(iicr.__call__)

def compute_loglik(c_map, c_index, data_row):
    """Log-likelihood of one het-matrix row under the rate curve ``c_map[c_index]``.

    Reads ``t_iicr``, ``t_breaks``, ``theta`` and ``rho`` from the notebook's
    global namespace.
    """
    return loglik(
        data_row,
        PiecewiseConstant(c=c_map[c_index], t=t_iicr),
        t_breaks,
        theta,
        rho,
    )
    
def evaluate_at_vec(vec):
    """Mean negative log-likelihood of all het-matrix rows at parameter value ``vec``.

    Relies on the notebook globals ``path_order``, ``deme_names``,
    ``unique_cfg``, ``matching_indices``, ``het_matrix``, ``num_samples``,
    ``t_iicr`` and ``iicr_call``.
    """
    params = _vec_to_dict_jax(jnp.atleast_1d(vec), path_order)

    def rates_for(cfg):
        # Label the entries of one configuration row with their deme names.
        sample_counts = dict(zip(deme_names, cfg))
        return iicr_call(params=params, t=t_iicr, num_samples=sample_counts)["c"]

    # One rate curve per distinct sample configuration.
    c_map = jax.vmap(rates_for)(jnp.array(unique_cfg))

    # Batch over (configuration index, het-matrix row) pairs; c_map is shared.
    per_row_loglik = vmap(compute_loglik, in_axes=(None, 0, 0))(c_map, matching_indices, het_matrix)
    return -jnp.sum(per_row_loglik) / num_samples

# Outer vmap: evaluate all 10 grid values (4000..7000) in one batched call
vec_values = jnp.linspace(4000, 7000, 10)
batched_neg_loglik = vmap(evaluate_at_vec)  # in_axes=0 is default
results = batched_neg_loglik(vec_values) 

# Not so smoothly with `lax.map`

# Per-window rates: the per-base rate 1e-8 scaled by the window width.
window_size = 100
recombination_rate=1e-8 * window_size
theta=1e-8 * window_size
# Bounds and resolution of the geometric time grids built below.
t_min=1e-8
t_max=1e4
num_t=2000
k=2
# A single variable: start_size and end_size of deme 0, epoch 0, tied together in one frozenset.
paths = {
    frozenset({
        ('demes', 0, 'epochs', 0, 'end_size'),
        ('demes', 0, 'epochs', 0, 'start_size'),
    }): 4000.,
}
# `g` is the demes graph produced by demo.to_demes() in the simulation cell.
demo = g
het_matrix = data["het_matrix"]
path_order: List[Var] = list(paths)
cfg_mat, deme_names, unique_cfg, matching_indices = process_data(cfg_list)
num_samples = len(cfg_mat)
# Both grids are geometric with 0.0 prepended; t_breaks uses 100 points, t_iicr uses num_t.
t_breaks = jnp.insert(jnp.geomspace(t_min, t_max, 100), 0, 0.0)
t_iicr = jnp.insert(jnp.geomspace(t_min, t_max, num_t), 0, 0.0)
rho = recombination_rate
iicr = IICRCurve(demo=demo, k=k)
iicr_call = jax.jit(iicr.__call__)

def compute_loglik(c_map, c_index, data_row):
    """Log-likelihood of one het-matrix row under the rate curve ``c_map[c_index]``.

    Reads ``t_iicr``, ``t_breaks``, ``theta`` and ``rho`` from the notebook's
    global namespace.
    """
    return loglik(
        data_row,
        PiecewiseConstant(c=c_map[c_index], t=t_iicr),
        t_breaks,
        theta,
        rho,
    )
    
def evaluate_at_vec(vec):
    """Mean negative log-likelihood of all het-matrix rows at parameter value ``vec``.

    Relies on the notebook globals ``path_order``, ``deme_names``,
    ``unique_cfg``, ``matching_indices``, ``het_matrix``, ``num_samples``,
    ``t_iicr`` and ``iicr_call``.
    """
    params = _vec_to_dict_jax(jnp.atleast_1d(vec), path_order)

    def rates_for(cfg):
        # Label the entries of one configuration row with their deme names.
        sample_counts = dict(zip(deme_names, cfg))
        return iicr_call(params=params, t=t_iicr, num_samples=sample_counts)["c"]

    # One rate curve per distinct sample configuration.
    c_map = jax.vmap(rates_for)(jnp.array(unique_cfg))

    # Batch over (configuration index, het-matrix row) pairs; c_map is shared.
    per_row_loglik = vmap(compute_loglik, in_axes=(None, 0, 0))(c_map, matching_indices, het_matrix)
    return -jnp.sum(per_row_loglik) / num_samples

# Same 10 grid values (4000..7000) as the vmap cell, but mapped with lax.map instead of an outer vmap
vec_values = jnp.linspace(4000, 7000, 10)
# vmap variant, kept commented out for comparison:
# batched_neg_loglik = vmap(evaluate_at_vec)  # in_axes=0 is default
# results = batched_neg_loglik(vec_values) 
results = lax.map(evaluate_at_vec, vec_values)