Load all helper functions.

In [None]:
from __future__ import annotations

from typing import Any, Dict, List, Mapping, Optional, Sequence, Set, Tuple

import diffrax as dfx
import jax
import jax.numpy as jnp
import jax.random as jr
import msprime as msp
import numpy as np
from jax import lax, vmap
from jax.scipy.special import xlog1py, xlogy
from jaxtyping import Array, Float, Scalar, ScalarLike
from scipy.optimize import LinearConstraint, minimize

from demesinfer.coal_rate import CoalRate, PiecewiseConstant
from demesinfer.constr import EventTree, constraints_for
from demesinfer.iicr import IICRCurve
from demesinfer.loglik.arg import loglik
from demesinfer.sfs import ExpectedSFS

# Type aliases for parameter bookkeeping.
# Path: a tuple key such as ('migrations', 0, 'rate') (see the `paths` dict in the
# exploration cell); presumably it addresses one parameter of the demes model — TODO confirm.
Path = Tuple[Any, ...]
# Var: a single Path or a set of Paths (presumably parameters tied together — TODO confirm).
# NOTE(review): a plain `set` is unhashable, so when a Var is used as a Mapping key it must
# in practice be a frozenset.
Var = Path | Set[Path]
# Params: mapping from a parameter key to its numeric value.
Params = Mapping[Var, float]

def _dict_to_vec(d: Params, keys: Sequence[Var]) -> jnp.ndarray:
    return jnp.asarray([d[k] for k in keys], dtype=jnp.float64)

def _vec_to_dict_jax(v: jnp.ndarray, keys: Sequence[Var]) -> Dict[Var, jnp.ndarray]:
    return {k: v[i] for i, k in enumerate(keys)}

def _vec_to_dict(v: jnp.ndarray, keys: Sequence[Var]) -> Dict[Var, float]:
    return {k: float(v[i]) for i, k in enumerate(keys)}

def compile(ts, subkey, a=None, b=None):
    """Extract the TMRCA track of one pair of sample nodes from a tree sequence.

    Consecutive trees in which the pair has the same TMRCA are merged into one span.
    NOTE: the name shadows the builtin ``compile``; it is kept for existing callers.

    Args:
        ts: tree sequence (uses ``samples``, ``num_samples``, ``node``, ``population``
            and ``trees``).
        subkey: JAX PRNG key, used only when ``a`` and ``b`` are both None.
        a, b: sample node IDs. If both are None, two distinct sample nodes are drawn
            at random from ``ts.samples()``.

    Returns:
        data: array of ``[tmrca, span_length]`` rows in genome order, shape (num_spans, 2).
        pop_cfg: dict mapping every population name that holds sample nodes to how many
            of the two chosen nodes are in it (the values sum to 2).

    Raises:
        ValueError: if exactly one of ``a`` and ``b`` is given.
    """
    # Every population that contains sample nodes, in first-seen sample order, counting from 0.
    # (A dict comprehension gives a deterministic order; a set's order depends on string hashing.)
    pop_cfg = {
        ts.population(ts.node(n).population).metadata["name"]: 0
        for n in ts.samples()
    }

    if (a is None) != (b is None):
        raise ValueError("Provide both sample nodes `a` and `b`, or neither.")
    if a is None:
        # choice() yields positions in [0, num_samples); map them to actual sample node IDs.
        sample_nodes = ts.samples()
        picked = jax.random.choice(subkey, ts.num_samples, shape=(2,), replace=False)
        a = int(sample_nodes[int(picked[0])])
        b = int(sample_nodes[int(picked[1])])

    spans = []
    for tree in ts.trees():
        length = tree.interval.right - tree.interval.left
        t = tree.tmrca(a, b)
        if spans and spans[-1][0] == t:
            # Same TMRCA as the previous tree: extend the current span.
            spans[-1][1] += length
        else:
            spans.append([t, length])
    data = jnp.asarray(spans, dtype=jnp.float64)

    pop_cfg[ts.population(ts.node(a).population).metadata["name"]] += 1
    pop_cfg[ts.population(ts.node(b).population).metadata["name"]] += 1
    return data, pop_cfg

def get_tmrca_data(ts, key=jax.random.PRNGKey(2), num_samples=200, option="random"):
    """Collect TMRCA tracks (via ``compile``) for many pairs of sample nodes.

    Args:
        ts: tree sequence.
        key: JAX PRNG key; only consumed when ``option == "random"``.
        num_samples: number of pairs to draw when ``option == "random"``.
        option:
            ``"random"``: ``num_samples`` independent random pairs (the same pair may be
            drawn more than once);
            ``"all"``: every unordered pair of sample nodes;
            ``"unphased"``: consecutive sample nodes paired as (s0, s1), (s2, s3), ...
            (presumably the two genomes of each diploid individual — confirm).

    Returns:
        (data_list, cfg_list): per-pair outputs of ``compile``, in matching order.

    Raises:
        ValueError: for an unknown ``option``. Previously this silently returned empty lists.
    """
    if option not in ("random", "all", "unphased"):
        raise ValueError(
            f"option must be 'random', 'all' or 'unphased', got {option!r}"
        )

    data_list = []
    cfg_list = []
    key, subkey = jr.split(key)

    if option == "random":
        for _ in range(num_samples):
            data, cfg = compile(ts, subkey)
            data_list.append(data)
            cfg_list.append(cfg)
            # Fresh subkey per draw so that the pairs differ.
            key, subkey = jr.split(key)
        return data_list, cfg_list

    if option == "all":
        from itertools import combinations

        pairs = combinations(ts.samples(), 2)
    else:  # "unphased"
        pairs = ts.samples().reshape(-1, 2)

    for a, b in pairs:
        data, cfg = compile(ts, subkey, a, b)
        data_list.append(data)
        cfg_list.append(cfg)
    return data_list, cfg_list

def process_data(data_list, cfg_list):
    """Pad per-pair TMRCA arrays into one batch and index their sample configurations.

    Args:
        data_list: per-pair arrays of ``[tmrca, span]`` rows (shape (n_i, 2)).
        cfg_list: per-pair dicts mapping population name -> count, all sharing the keys
            of ``cfg_list[0]`` (missing names count as 0).

    Returns:
        data_pad: (num_pairs, Lmax, 2) float64 array. Rows past a pair's own length are
            filled with ``[1.0, 0.0]`` (time 1.0, span 0).
        deme_names: ``cfg_list[0].keys()``; this fixes the column order of the config matrix.
        max_indices: (num_pairs,) index of the last real row of each pair.
        unique_cfg: (U, D) int32 distinct configuration rows, as ordered by ``jnp.unique``.
        matching_indices: (num_pairs,) row of ``unique_cfg`` that each pair uses.

    Raises:
        ValueError: if ``data_list`` is empty or its length differs from ``cfg_list``.
    """
    if not data_list:
        raise ValueError("data_list is empty")
    if len(data_list) != len(cfg_list):
        raise ValueError(
            f"data_list and cfg_list differ in length ({len(data_list)} vs {len(cfg_list)})"
        )

    lengths = np.array([d.shape[0] for d in data_list])
    max_indices = jnp.asarray(lengths - 1)
    num_pairs = len(data_list)
    max_len = int(lengths.max())

    # Build the padded batch in numpy. One functional `.at[].set` per pair would copy the
    # whole batch every time.
    padded = np.empty((num_pairs, max_len, 2))
    padded[...] = (1.0, 0.0)
    for row, d in enumerate(data_list):
        padded[row, : d.shape[0]] = np.asarray(d)
    data_pad = jnp.asarray(padded, dtype=jnp.float64)

    deme_names = cfg_list[0].keys()
    cfg_mat = np.array(
        [[cfg.get(name, 0) for name in deme_names] for cfg in cfg_list],
        dtype=np.int32,
    )

    # The inverse indices give, for each pair, its row in unique_cfg.
    # reshape(-1) guards against version-dependent shapes of the inverse array.
    unique_cfg, inverse = jnp.unique(jnp.asarray(cfg_mat), axis=0, return_inverse=True)
    matching_indices = jnp.asarray(inverse).reshape(-1)

    return data_pad, deme_names, max_indices, unique_cfg, matching_indices

def reformat_data(data_pad, matching_indices, max_indices):
    """Regroup pairs by sample configuration and index each group's distinct TMRCAs.

    Pairs are reordered so that all pairs of configuration group 0 come first, then
    group 1, and so on. Within a group, the original pair order is kept.

    Args:
        data_pad: (num_pairs, Lmax, 2) padded ``[tmrca, span]`` batch.
        matching_indices: (num_pairs,) configuration group of each pair.
        max_indices: (num_pairs,) last real row index of each pair.

    Returns:
        padded_unique_times: (num_groups, T) numpy array. Row g holds group g's sorted
            distinct time values (padding values included), right-padded with zeros.
        rearranged_data: ``data_pad`` rows reordered by group (jnp array).
        new_matching_indices: group id of each reordered pair (jnp array).
        new_max_indices: ``max_indices`` reordered the same way (jnp array).
        associated_indices: per group, for every entry of that group's flattened time
            column (pair-major order), its position in that group's unique-times row.
        unique_groups: sorted distinct group ids (numpy array).
    """
    data_np = np.asarray(data_pad)
    match_np = np.asarray(matching_indices)
    max_np = np.asarray(max_indices)
    unique_groups = np.unique(match_np)

    group_unique_times = []
    rearranged_data = []
    new_matching_indices = []
    new_max_indices = []
    associated_indices = []

    for group in unique_groups:
        group_mask = match_np == group
        group_data = data_np[group_mask]
        # The time column of every pair in the group, concatenated pair by pair.
        all_first_col = group_data[:, :, 0].reshape(-1)
        # return_inverse maps each time to its index among the sorted unique times.
        # This replaces a per-element Python dict lookup.
        unique_values, inverse = np.unique(all_first_col, return_inverse=True)

        group_unique_times.append(unique_values)
        associated_indices.append(inverse.reshape(-1))
        rearranged_data.append(group_data)
        new_matching_indices.append(match_np[group_mask])
        new_max_indices.append(max_np[group_mask])

    # Right-pad every group's unique times with zeros to a common length so they stack.
    max_length = max(len(times) for times in group_unique_times)
    padded_unique_times = np.array(
        [np.pad(times, (0, max_length - len(times)), mode="constant", constant_values=0)
         for times in group_unique_times]
    )

    rearranged_data = jnp.array(np.concatenate(rearranged_data, axis=0))
    new_matching_indices = jnp.asarray(np.concatenate(new_matching_indices))
    new_max_indices = jnp.concatenate([jnp.array(arr) for arr in new_max_indices])
    return padded_unique_times, rearranged_data, new_matching_indices, new_max_indices, associated_indices, unique_groups

def loglik_ode(eta: CoalRate, r: ScalarLike, times: Float[Array, "intervals"]) -> Float[Array, "intervals 3"]:
    """Solve a three-state ODE driven by ``eta`` and ``r`` and return it at ``times``.

    Despite the name, this does NOT return a log-likelihood. It returns state
    probabilities, which callers then combine into one.

    The state ``y = (y0, y1, y2)`` starts at ``(1, 0, 0)`` at ``t = 0`` and evolves as
    ``dy/dt = A(t)^T y`` with::

        A(t) = [[-r,     r,       0   ],
                [c(t),  -2 c(t),  c(t)],
                [0,      0,       0   ]],    c(t) = eta(t)

    so state 0 moves to state 1 at rate ``r``, state 1 moves to state 0 and to state 2
    at rate ``c(t)`` each, and state 2 is absorbing. Callers in this file unpack the
    three columns as ``(p_nr, p_float, p_coal)``.

    Args:
        eta: coalescence rate. It is called as ``eta(t)`` and must expose ``jumps``, which
            are passed to the step-size controller as jump times.
        r: recombination rate.
        times: 1-D array of times >= 0, in any order. The solve runs from 0 to max(times).

    Returns:
        Array of shape ``(len(times), 3)``; row ``j`` is the state at ``times[j]``.
    """
    # SaveAt(ts=...) requires non-decreasing times, so solve on a sorted copy.
    order = jnp.argsort(times)
    sorted_times = times[order]

    def vector_field(t, y, args):
        c = eta(t)
        A = jnp.array([[-r, r, 0.0], [c, -2 * c, c], [0.0, 0.0, 0.0]])
        return A.T @ y

    sol = dfx.diffeqsolve(
        dfx.ODETerm(vector_field),
        dfx.Tsit5(),
        0.0,
        times.max(),
        dt0=0.001,
        y0=jnp.array([1.0, 0.0, 0.0]),
        stepsize_controller=dfx.PIDController(rtol=1e-6, atol=1e-6, jump_ts=eta.jumps),
        saveat=dfx.SaveAt(ts=sorted_times),
    )

    # Undo the sort so that row j corresponds to times[j].
    return sol.ys[jnp.argsort(order)]

def plot_iicr_likelihood(demo, data_list, cfg_list, paths, vec_values, recombination_rate=1e-8, t_min=1e-8, num_t=2000, k=2):
    """Evaluate and plot the TMRCA-track negative log-likelihood over a parameter grid.

    Args:
        demo: demography passed to ``IICRCurve``.
        data_list, cfg_list: per-pair TMRCA spans and sample configurations (as returned
            by ``get_tmrca_data``).
        paths: parameter keys to vary. Only the keys are used; the values are ignored.
        vec_values: grid of parameter values. Each entry becomes
            ``_vec_to_dict_jax(jnp.atleast_1d(entry), list(paths))``.
        recombination_rate: recombination rate ``rho`` used in the ODE and likelihood.
        t_min, num_t: the coalescence-rate grid is ``[0] + geomspace(t_min, max_time, num_t)``.
        k: forwarded to ``IICRCurve``.

    Returns:
        Negative log-likelihood (summed over pairs, divided by the number of pairs) for
        each entry of ``vec_values``.
    """
    import matplotlib.pyplot as plt

    path_order: List[Var] = list(paths)
    data_pad, deme_names, max_indices, unique_cfg, matching_indices = process_data(data_list, cfg_list)
    # Group pairs by sample configuration so that the ODE is solved once per configuration,
    # on that configuration's distinct time values.
    unique_times, data_pad, matching_indices, max_indices, associated_indices, unique_groups = reformat_data(data_pad, matching_indices, max_indices)

    num_samples = len(max_indices)
    first_columns = data_pad[:, :, 0]
    # Largest time value in the batch (includes the 1.0 padding value).
    global_max = jnp.max(first_columns)
    print(global_max)
    t_breaks = jnp.insert(jnp.geomspace(t_min, global_max, num_t), 0, 0.0)
    rho = recombination_rate
    iicr = IICRCurve(demo=demo, k=k)
    iicr_call = jax.jit(iicr.__call__)
    chunking_length = data_pad.shape[1]

    def compute_ode(c, times):
        eta = PiecewiseConstant(c=c, t=t_breaks)
        return loglik_ode(eta, rho, times)

    def compute_loglik(data, cscs, max_index, c_map, c_index):
        c = c_map[c_index]
        eta = PiecewiseConstant(c=c, t=t_breaks)
        times, spans = data.T

        @vmap
        def p(t0, csc0, t1, csc1, span):
            p_nr_t0, p_float_t0, p_coal_t0 = csc0
            p_nr_t1, p_float_t1, p_coal_t1 = csc1
            # no recomb for first span - 1 positions
            r1 = xlogy(span - 1, p_nr_t0)
            # coalescence at t1
            r2 = jnp.log(eta(t1))
            # back-coalescence process up to t1, depends on t0 >< t1
            r3 = jnp.where(
                t0 < t1, jnp.log(p_float_t0) - eta.R(t0, t1), jnp.log(p_float_t1)
            )
            return r1 + r2 + r3

        ll = p(times[:-1], cscs[:-1], times[1:], cscs[1:], spans[:-1])
        # Keep only transitions before the last real row. Use `where` rather than
        # dotting with a 0/1 mask: padded transitions can be -inf (e.g. log of a zero
        # rate), and 0 * -inf = NaN would poison the whole sum.
        valid = jnp.arange(ll.shape[0]) < max_index
        ll = jnp.sum(jnp.where(valid, ll, 0.0))

        # for the last position, we only know span was at least as long
        ll += xlogy(spans[max_index], cscs[max_index, 0])
        return ll

    def evaluate_at_vec(vec):
        vec_array = jnp.atleast_1d(vec)
        params = _vec_to_dict_jax(vec_array, path_order)

        # One coalescence-rate curve per distinct sample configuration.
        c_map = jax.vmap(lambda cfg: iicr_call(params=params, t=t_breaks, num_samples=dict(zip(deme_names, cfg)))["c"])(
            jnp.array(unique_cfg)
        )

        all_cscs = jax.vmap(compute_ode, in_axes=(0, 0))(c_map, unique_times)

        # Scatter each group's ODE solution back to the per-pair (Lmax, 3) layout.
        # Group ids are 0..U-1, so the id doubles as the position in associated_indices.
        final_cscs = []
        for i in unique_groups:
            final_cscs.append(all_cscs[i][associated_indices[i]].reshape(int(associated_indices[i].shape[0]/chunking_length), chunking_length, 3))

        # Batched over pairs
        batched_loglik = vmap(compute_loglik, in_axes=(0, 0, 0, None, 0))(data_pad, jnp.vstack(final_cscs), max_indices, c_map, matching_indices)
        return -jnp.sum(batched_loglik) / num_samples

    # Sequential map over the grid (lax.map rather than vmap to bound memory use).
    results = lax.map(evaluate_at_vec, vec_values)

    plt.figure(figsize=(10, 6))
    plt.plot(vec_values, results, 'r-', linewidth=2)
    plt.xlabel("vec value")
    plt.ylabel("Negative Log-Likelihood")
    plt.title("IICR Likelihood Landscape")
    plt.grid(True)
    plt.show()

    return results


def plot_sfs_likelihood(demo, paths, vec_values, afs, afs_samples, theta=None, sequence_length=None):
    """Evaluate and plot the SFS negative log-likelihood over a parameter grid.

    Args:
        demo: demography passed to ``ExpectedSFS``.
        paths: parameter keys to vary. Only the keys are used.
        vec_values: grid of parameter values. Each entry becomes
            ``_vec_to_dict_jax(jnp.atleast_1d(entry), list(paths))``.
        afs: observed allele-frequency spectrum.
        afs_samples: per-population sample counts passed to ``ExpectedSFS``.
        theta, sequence_length: if ``theta`` is given, a Poisson likelihood with mean
            ``esfs * sequence_length * theta`` is used. Otherwise a multinomial likelihood
            on the normalised expected SFS is used.

    Returns:
        Negative log-likelihood for each entry of ``vec_values``.

    Raises:
        ValueError: if ``theta`` is given without ``sequence_length``.
    """
    import matplotlib.pyplot as plt

    # Validate eagerly with a real exception. `assert` is stripped under -O and would
    # otherwise fire only during tracing.
    if theta is not None and sequence_length is None:
        raise ValueError("sequence_length is required when theta is given")

    path_order: List[Var] = list(paths)
    esfs = ExpectedSFS(demo, num_samples=afs_samples)

    def sfs_loglik(afs, esfs, sequence_length, theta):
        # Drop the first and last cells of the flattened spectra (the corner cells).
        afs = afs.flatten()[1:-1]
        esfs = esfs.flatten()[1:-1]

        if theta is not None:
            # Poisson log-likelihood (up to a data-only constant).
            tmp = esfs * sequence_length * theta
            return jnp.sum(-tmp + xlogy(afs, tmp))
        # Multinomial log-likelihood (up to a constant) on normalised expected counts.
        return jnp.sum(xlogy(afs, esfs/esfs.sum()))

    def evaluate_at_vec(vec):
        vec_array = jnp.atleast_1d(vec)
        params = _vec_to_dict_jax(vec_array, path_order)
        e1 = esfs(params)
        return -sfs_loglik(afs, e1, sequence_length, theta)

    # Sequential map over the grid.
    results = lax.map(evaluate_at_vec, vec_values)

    plt.figure(figsize=(10, 6))
    plt.plot(vec_values, results, 'r-', linewidth=2)
    plt.xlabel("vec value")
    plt.ylabel("Negative Log-Likelihood")
    plt.title("SFS Likelihood Landscape")
    plt.grid(True)
    plt.show()

    return results

def plot_likelihood(demo, data_list, cfg_list, paths, vec_values, afs, afs_samples, theta=None, sequence_length=None, recombination_rate=1e-8, t_min=1e-8, num_t=2000, k=2):
    """Evaluate and plot the combined TMRCA + SFS negative log-likelihood over a grid.

    The objective is ``-(sum of per-pair loglik) / num_pairs - sfs_loglik``, where the
    per-pair term comes from ``demesinfer.loglik.arg.loglik``.

    Args:
        demo: demography passed to ``IICRCurve`` and ``ExpectedSFS``.
        data_list, cfg_list: per-pair TMRCA spans and sample configurations.
        paths: parameter keys to vary. Only the keys are used.
        vec_values: grid of parameter values.
        afs, afs_samples: observed SFS and per-population sample counts.
        theta, sequence_length: select the Poisson (theta given) or multinomial SFS term.
        recombination_rate, t_min, num_t, k: as in ``plot_iicr_likelihood``.

    Returns:
        Combined negative log-likelihood for each entry of ``vec_values``.

    Raises:
        ValueError: if ``theta`` is given without ``sequence_length``.
    """
    import matplotlib.pyplot as plt

    if theta is not None and sequence_length is None:
        raise ValueError("sequence_length is required when theta is given")

    path_order: List[Var] = list(paths)
    # process_data returns exactly five values. The old six-name unpack (with a
    # phantom `cfg_mat`) raised ValueError on every call.
    data_pad, deme_names, max_indices, unique_cfg, matching_indices = process_data(data_list, cfg_list)
    num_samples = len(max_indices)
    first_columns = data_pad[:, :, 0]
    # Largest time value in the batch (includes the 1.0 padding value).
    global_max = jnp.max(first_columns)
    print(global_max)
    t_breaks = jnp.insert(jnp.geomspace(t_min, global_max, num_t), 0, 0.0)
    rho = recombination_rate
    iicr = IICRCurve(demo=demo, k=k)
    iicr_call = jax.jit(iicr.__call__)
    esfs = ExpectedSFS(demo, num_samples=afs_samples)

    def compute_loglik(c_map, c_index, data, max_index):
        c = c_map[c_index]
        eta = PiecewiseConstant(c=c, t=t_breaks)
        return loglik(eta, rho, data, max_index)

    def sfs_loglik(afs, esfs, sequence_length, theta):
        # Drop the first and last cells of the flattened spectra (the corner cells).
        afs = afs.flatten()[1:-1]
        esfs = esfs.flatten()[1:-1]

        if theta is not None:
            # Poisson log-likelihood (up to a data-only constant).
            tmp = esfs * sequence_length * theta
            return jnp.sum(-tmp + xlogy(afs, tmp))
        # Multinomial log-likelihood (up to a constant) on normalised expected counts.
        return jnp.sum(xlogy(afs, esfs/esfs.sum()))

    def evaluate_at_vec(vec):
        vec_array = jnp.atleast_1d(vec)
        params = _vec_to_dict_jax(vec_array, path_order)

        # One coalescence-rate curve per distinct sample configuration.
        c_map = jax.vmap(lambda cfg: iicr_call(params=params, t=t_breaks, num_samples=dict(zip(deme_names, cfg)))["c"])(
            jnp.array(unique_cfg)
        )

        # Batched over pairs
        batched_loglik = vmap(compute_loglik, in_axes=(None, 0, 0, 0))(c_map, matching_indices, data_pad, max_indices)

        e1 = esfs(params)
        return (-jnp.sum(batched_loglik) / num_samples) + -sfs_loglik(afs, e1, sequence_length, theta)

    # Sequential map over the grid.
    results = lax.map(evaluate_at_vec, vec_values)

    plt.figure(figsize=(10, 6))
    plt.plot(vec_values, results, 'r-', linewidth=2)
    plt.xlabel("vec value")
    plt.ylabel("Negative Log-Likelihood")
    plt.title("SFS and IICR Likelihood Landscape")
    plt.grid(True)
    plt.show()

    return results

def fit(
    demo,
    data_list,
    cfg_list,
    paths: Params,
    afs,
    afs_samples,
    *,
    k: int = 2,
    t_min: float = 1e-8,
    # t_max: float,
    num_t: int = 2000,
    method: str = "trust-constr",
    options: Optional[dict] = None,
    recombination_rate: float = 1e-8,
    sequence_length: float = None,
    theta: float = None,
):
    """Fit demographic parameters by minimising the combined TMRCA + SFS objective.

    The objective is ``-(sum of per-pair loglik) / num_pairs - sfs_loglik``, subject to
    the linear constraints that ``constraints_for`` derives from the demography.

    Args:
        demo: demography passed to ``EventTree``, ``IICRCurve`` and ``ExpectedSFS``.
        data_list, cfg_list: per-pair TMRCA spans and sample configurations.
        paths: parameter key -> starting value; the keys fix the optimisation order.
        afs, afs_samples: observed SFS and per-population sample counts.
        k: forwarded to ``IICRCurve``.
        t_min, num_t: the coalescence-rate grid is ``[0] + geomspace(t_min, max_time, num_t)``.
        method: ``scipy.optimize.minimize`` method.
        options: forwarded to ``scipy.optimize.minimize``. It was previously ignored.
        recombination_rate: ``rho`` passed to the TMRCA likelihood.
        sequence_length, theta: select the Poisson (theta given) or multinomial SFS term.

    Returns:
        dict mapping each key of ``paths`` to its fitted value (Python float).

    Raises:
        ValueError: if ``theta`` is given without ``sequence_length``.
    """
    if theta is not None and sequence_length is None:
        raise ValueError("sequence_length is required when theta is given")

    # process_data returns exactly five values. The old six-name unpack (with a
    # phantom `cfg_mat`) raised ValueError on every call.
    data_pad, deme_names, max_indices, unique_cfg, matching_indices = process_data(data_list, cfg_list)
    num_samples = len(max_indices)

    path_order: List[Var] = list(paths)
    x0 = _dict_to_vec(paths, path_order)
    et = EventTree(demo)

    cons = constraints_for(et, *path_order)
    linear_constraints: list[LinearConstraint] = []

    Aeq, beq = cons["eq"]
    if Aeq.size:
        linear_constraints.append(LinearConstraint(Aeq, beq, beq))

    G, h = cons["ineq"]
    if G.size:
        lower = -jnp.inf * jnp.ones_like(h)
        linear_constraints.append(LinearConstraint(G, lower, h))

    first_columns = data_pad[:, :, 0]
    # Largest time value in the batch (includes the 1.0 padding value).
    global_max = jnp.max(first_columns)
    t_breaks = jnp.insert(jnp.geomspace(t_min, global_max, num_t), 0, 0.0)
    rho = recombination_rate
    iicr = IICRCurve(demo=demo, k=k)
    iicr_call = jax.jit(iicr.__call__)
    esfs = ExpectedSFS(demo, num_samples=afs_samples)

    def compute_loglik(c_map, c_index, data, max_index):
        c = c_map[c_index]
        eta = PiecewiseConstant(c=c, t=t_breaks)
        return loglik(eta, rho, data, max_index)

    def sfs_loglik(afs, esfs, sequence_length, theta):
        # Drop the first and last cells of the flattened spectra (the corner cells).
        afs = afs.flatten()[1:-1]
        esfs = esfs.flatten()[1:-1]

        if theta is not None:
            # Poisson log-likelihood (up to a data-only constant).
            tmp = esfs * sequence_length * theta
            return jnp.sum(-tmp + xlogy(afs, tmp))
        # Multinomial log-likelihood (up to a constant) on normalised expected counts.
        return jnp.sum(xlogy(afs, esfs/esfs.sum()))

    # jit so that repeated optimiser calls reuse one compiled value-and-gradient function.
    @jax.jit
    @jax.value_and_grad
    def neg_loglik(vec):
        params = _vec_to_dict_jax(vec, path_order)
        # One coalescence-rate curve per distinct sample configuration.
        c_map = jax.vmap(lambda cfg: iicr_call(params=params, t=t_breaks, num_samples=dict(zip(deme_names, cfg)))["c"])(
            jnp.array(unique_cfg)
        )

        # Batched over pairs
        batched_loglik = vmap(compute_loglik, in_axes=(None, 0, 0, 0))(c_map, matching_indices, data_pad, max_indices)

        likelihood = jnp.sum(batched_loglik)
        e1 = esfs(params)

        return (-likelihood / num_samples) + -sfs_loglik(afs, e1, sequence_length, theta)

    def objective(x):
        # Return (f, grad) together (jac=True) so each step evaluates neg_loglik once,
        # not once for fun and again for jac.
        value, grad = neg_loglik(jnp.asarray(x))
        return float(value), np.asarray(grad, dtype=float)

    res = minimize(
        fun=objective,
        x0=np.asarray(x0),
        jac=True,
        method=method,
        constraints=linear_constraints,
        options=options,
    )

    return _vec_to_dict(jnp.asarray(res.x), path_order)


Run the "snake" model simulation (five demes in a line, with migration only between neighbouring demes). Note: this simulation takes about 6–7 minutes.

In [None]:
import msprime as msp
import demes
import demesdraw
import numpy as np

# Demography: an ancestral population "anc" from which five demes P0..P4 split at T_SPLIT.
# The demes sit in a line, with migration only between neighbours
# (P0-P1, P1-P2, P2-P3, P3-P4). Times are in generations before the present.
ANC_SIZE = 4000
OLD_SIZE = 3000          # size of every deme from T_CHANGE back to the split
T_CHANGE = 66            # time of the size and migration-rate change
T_SPLIT = 5000           # time at which P0..P4 split from "anc"
RECENT_MIGRATION = 0.0001
OLD_MIGRATION = 0.0005
present_sizes = {"P0": 500, "P1": 500, "P2": 100, "P3": 800, "P4": 500}
deme_chain = list(present_sizes)

# Create demography object
demo = msp.Demography()

# Add populations ("anc" first, so P0..P4 get population ids 1..5).
demo.add_population(initial_size=ANC_SIZE, name="anc")
# Between T_CHANGE and the present, each deme changes size exponentially: OLD_SIZE at
# T_CHANGE, its present-day size at time 0.
for name, size in present_sizes.items():
    demo.add_population(initial_size=size, name=name, growth_rate=-np.log(OLD_SIZE / size) / T_CHANGE)

# Recent (t < T_CHANGE) symmetric migration between neighbouring demes.
for left, right in zip(deme_chain, deme_chain[1:]):
    demo.set_symmetric_migration_rate(populations=(left, right), rate=RECENT_MIGRATION)

# At T_CHANGE every deme switches to a constant size of OLD_SIZE (growth stops) ...
for name in deme_chain:
    demo.add_population_parameters_change(
        time=T_CHANGE,
        initial_size=OLD_SIZE,
        population=name,
        growth_rate=0,
    )

# ... and migration between neighbouring demes (both directions) changes to OLD_MIGRATION.
for left, right in zip(deme_chain, deme_chain[1:]):
    demo.add_migration_rate_change(time=T_CHANGE, rate=OLD_MIGRATION, source=left, dest=right)
    demo.add_migration_rate_change(time=T_CHANGE, rate=OLD_MIGRATION, source=right, dest=left)

# Oldest event: P0..P4 are derived from "anc" at T_SPLIT.
demo.add_population_split(time=T_SPLIT, derived=deme_chain, ancestral="anc")

# Visualize the demography
g = demo.to_demes()
demesdraw.tubes(g, log_time=True)

sample_size = 10
samples = {f"P{i}": sample_size for i in range(5)}
anc = msp.sim_ancestry(samples=samples, demography=demo, recombination_rate=3.94 * 1e-8, sequence_length=1e8, random_seed=12)
ts = msp.sim_mutations(anc, rate=2.54 * 1e-8, random_seed=12)

# sim_ancestry samples diploid individuals by default, so each deme contributes
# 2 * sample_size genomes.
afs_samples = {f"P{i}": sample_size*2 for i in range(5)}
# One sample set per deme: population ids 1..5 are P0..P4.
# NOTE(review): allele_frequency_spectrum defaults to polarised=False (a folded spectrum);
# confirm that this matches what ExpectedSFS produces.
afs = ts.allele_frequency_spectrum(
    sample_sets=[ts.samples(population=pop_id) for pop_id in range(1, 6)],
    span_normalise=False,
)

Extract TMRCA data for 400 random sample pairs. This takes about 10 min 40 s; decrease `num_samples` if you don't want to wait that long.

In [None]:
# Extract TMRCA spans for 400 random sample pairs (slow; see the note above this cell).
%time data_list, cfg_list = get_tmrca_data(ts, key=jax.random.PRNGKey(42), num_samples=400, option="random")

The next cell sets up everything needed to experiment with the likelihood computation interactively.

In [None]:
# Parameter(s) to vary. Only the keys are used below; the value is just a placeholder.
paths = {
    ('migrations', 0, 'rate'): 4000.,
}

path_order: List[Var] = list(paths)
data_pad, deme_names, max_indices, unique_cfg, matching_indices = process_data(data_list, cfg_list)
# Group pairs by sample configuration so that the ODE is solved once per configuration.
unique_times, data_pad, matching_indices, max_indices, associated_indices, unique_groups = reformat_data(data_pad, matching_indices, max_indices)

num_samples = len(max_indices)
first_columns = data_pad[:, :, 0]
# Largest time value in the batch (includes the 1.0 padding value).
global_max = jnp.max(first_columns)
print(global_max)
t_breaks = jnp.insert(jnp.geomspace(1e-8, global_max, 2000), 0, 0.0)
# Must match the recombination_rate passed to msp.sim_ancestry in the simulation cell.
# (This was 1e-8, which evaluated the likelihood at the wrong recombination rate.)
rho = 3.94 * 1e-8
iicr = IICRCurve(demo=g, k=2)
iicr_call = jax.jit(iicr.__call__)
chunking_length = data_pad.shape[1]

vec = [0.1]
vec_array = jnp.atleast_1d(vec)
params = _vec_to_dict_jax(vec_array, path_order)


def compute_loglik(data, cscs, max_index, c_map, c_index):
    """Log-likelihood of one padded TMRCA track, given precomputed ODE states ``cscs``."""
    c = c_map[c_index]
    eta = PiecewiseConstant(c=c, t=t_breaks)
    times, spans = data.T

    @vmap
    def p(t0, csc0, t1, csc1, span):
        p_nr_t0, p_float_t0, p_coal_t0 = csc0
        p_nr_t1, p_float_t1, p_coal_t1 = csc1
        # no recomb for first span - 1 positions
        r1 = xlogy(span - 1, p_nr_t0)
        # coalescence at t1
        r2 = jnp.log(eta(t1))
        # back-coalescence process up to t1, depends on t0 >< t1
        r3 = jnp.where(
            t0 < t1, jnp.log(p_float_t0) - eta.R(t0, t1), jnp.log(p_float_t1)
        )
        return r1 + r2 + r3

    ll = p(times[:-1], cscs[:-1], times[1:], cscs[1:], spans[:-1])
    # Keep only transitions before the last real row. Use `where` rather than dotting
    # with a 0/1 mask: padded transitions can be -inf, and 0 * -inf = NaN.
    valid = jnp.arange(ll.shape[0]) < max_index
    ll = jnp.sum(jnp.where(valid, ll, 0.0))

    # for the last position, we only know span was at least as long
    ll += xlogy(spans[max_index], cscs[max_index, 0])
    return ll


def evaluate_at_vec(vec):
    """Negative mean TMRCA-track log-likelihood at parameter value(s) ``vec``."""
    vec_array = jnp.atleast_1d(vec)
    params = _vec_to_dict_jax(vec_array, path_order)

    def compute_c(sample_config, times):
        # Convert sample_config (array) to a dict of per-deme sample counts.
        ns = {name: sample_config[i] for i, name in enumerate(deme_names)}

        # Coalescence-rate curve for this configuration, then the ODE states at its times.
        c = iicr_call(params=params, t=t_breaks, num_samples=ns)["c"]
        eta = PiecewiseConstant(c=c, t=t_breaks)
        cscs = loglik_ode(eta, rho, times)
        return c, cscs

    c_map, all_cscs = vmap(compute_c, in_axes=(0, 0))(unique_cfg, unique_times)

    # Scatter each group's ODE solution back to the per-pair (Lmax, 3) layout.
    # Group ids are 0..U-1, so the id doubles as the position in associated_indices.
    final_cscs = []
    for i in unique_groups:
        final_cscs.append(all_cscs[i][associated_indices[i]].reshape(int(associated_indices[i].shape[0]/chunking_length), chunking_length, 3))

    # Batched over pairs
    batched_loglik = vmap(compute_loglik, in_axes=(0, 0, 0, None, 0))(data_pad, jnp.vstack(final_cscs), max_indices, c_map, matching_indices)
    return -jnp.sum(batched_loglik) / num_samples


vec_values = jnp.linspace(0.0001, 0.0010, 10)

Time a single likelihood evaluation.

In [None]:
# One likelihood evaluation at parameter value 5e-4. The first run may also include
# compiling the jitted IICR function.
%time evaluate_at_vec([0.0005])

Time 10 likelihood evaluations with `lax.map`.

In [None]:
# Evaluate the likelihood at each of the 10 grid points in vec_values, sequentially via lax.map.
%time results = lax.map(evaluate_at_vec, vec_values)