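The simulated tree sequence (a tskit object) is saved to disk and is loaded in a later cell. The next cell contains the data-generating code; uncomment it if you want to generate the data yourself (it takes about 1 minute to run). That cell also builds the demes graph `g`, which later cells use.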

In [None]:
# import msprime as msp
# import demes
# import demesdraw
# import numpy as np

# # Create demography object
# demo = msp.Demography()

# # Add populations
# demo.add_population(initial_size=4000, name="anc")
# demo.add_population(initial_size=500, name="P0", growth_rate=-np.log(3000 / 500)/66)
# demo.add_population(initial_size=500, name="P1", growth_rate=-np.log(3000 / 500)/66)

# # Set initial migration rate
# demo.set_symmetric_migration_rate(populations=("P0", "P1"), rate=0.0001)

# # Size change 66 generations ago: older epochs have a constant size of 3000
# demo.add_population_parameters_change(
#     time=66,
#     initial_size=3000,  # constant size 3000 before the recent decline to 500
#     population="P0",
#     growth_rate=0
# )
# demo.add_population_parameters_change(
#     time=66,
#     initial_size=3000,  # constant size 3000 before the recent decline to 500
#     population="P1",
#     growth_rate=0
# )

# # Migration rate changes to 0.0005 beyond 66 generations ago (going into the past)
# demo.add_migration_rate_change(
#     time=66,
#     rate=0.0005,
#     source="P0",
#     dest="P1"
# )
# demo.add_migration_rate_change(
#     time=66,
#     rate=0.0005,
#     source="P1",
#     dest="P0"
# )

# # THEN add the older events (population split 5000 generations ago)
# demo.add_population_split(time=5000, derived=["P0", "P1"], ancestral="anc")

# # Visualize the demography
# g = demo.to_demes()
# demesdraw.tubes(g, log_time=True)

# sample_size = 10
# samples = {f"P{i}": sample_size for i in range(2)}
# anc = msp.sim_ancestry(samples=samples, demography=demo, recombination_rate=3.94 * 1e-8, sequence_length=1e8, random_seed = 12)
# ts = msp.sim_mutations(anc, rate=2.54 * 1e-8, random_seed = 12)


In [None]:
import tskit

# Load the saved tree sequence from the working directory; the file name
# (including its spelling) must match the file on disk exactly.
ts = tskit.load("two_population_with_bottlneck.trees")

The next cell defines all the helper functions used below.

In [None]:
from typing import Any, Callable, Dict, List, Mapping, Sequence, Set, Tuple

import diffrax as dfx
import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
from jax import lax, vmap
from jax.scipy.special import xlogy
from jaxtyping import Array, Float, Scalar, ScalarLike

from demestats.coal_rate import PiecewiseConstant
from demestats.iicr import IICRCurve

# Address of a single field in the demes model, e.g.
# ("demes", 1, "epochs", 0, "end_time") as used in `paths` further down.
Path = Tuple[Any, ...]
# A parameter variable: one path, or a set of paths (the notebook uses a
# frozenset) that presumably all take the same value.
Var = Path | Set[Path]
# Mapping from parameter variables to their values.
Params = Mapping[Var, float]


def _dict_to_vec(d: Params, keys: Sequence[Var]) -> jnp.ndarray:
    return jnp.asarray([d[k] for k in keys], dtype=jnp.float64)


def _vec_to_dict_jax(v: jnp.ndarray, keys: Sequence[Var]) -> Dict[Var, jnp.ndarray]:
    return {k: v[i] for i, k in enumerate(keys)}


def _vec_to_dict(v: jnp.ndarray, keys: Sequence[Var]) -> Dict[Var, float]:
    return {k: float(v[i]) for i, k in enumerate(keys)}


def compile(ts, subkey, a=None, b=None):
    """Extract the run-length-encoded TMRCA sequence of one pair of sample nodes.

    Walks every tree of ``ts`` from left to right, records the TMRCA of nodes
    ``a`` and ``b`` and merges consecutive trees with the same TMRCA into one
    ``(tmrca, span)`` row.

    Args:
        ts: a ``tskit.TreeSequence``.
        subkey: JAX PRNG key, used only when ``a`` and ``b`` are not given.
        a, b: sample node IDs. If both are omitted, two distinct sample nodes
            are drawn at random from ``ts.samples()``.

    Returns:
        data: float array of shape ``(num_runs, 2)``; column 0 is the TMRCA and
            column 1 the summed span of that run.
        pop_cfg: dict mapping every population name that holds a sample node to
            how many of the two chosen nodes belong to it (values sum to 2).
            Keys are sorted so the order is identical in every Python session.

    Raises:
        ValueError: if only one of ``a`` and ``b`` is given.
    """
    # NOTE: the function name shadows the ``compile`` builtin; kept for callers.

    def pop_name(node):
        return ts.population(ts.node(node).population).metadata["name"]

    # Sorted instead of set order: str hashing is randomized per process, so a
    # set's iteration order (and hence the deme column order used downstream)
    # could differ between sessions.
    pop_cfg = {name: 0 for name in sorted({pop_name(n) for n in ts.samples()})}

    if (a is None) != (b is None):
        raise ValueError("Pass both `a` and `b`, or neither of them.")
    if a is None:
        sample_nodes = np.asarray(ts.samples())
        # Draw positions into ts.samples() and map them to node IDs, so this
        # also works when sample nodes are not numbered 0..num_samples-1.
        picks = np.asarray(
            jax.random.choice(subkey, ts.num_samples, shape=(2,), replace=False)
        )
        a, b = int(sample_nodes[picks[0]]), int(sample_nodes[picks[1]])

    spans = []
    for tree in ts.trees():
        length = tree.interval.right - tree.interval.left
        t = tree.tmrca(a, b)
        if spans and spans[-1][0] == t:
            spans[-1][1] += length  # same TMRCA as the previous tree: extend run
        else:
            spans.append([t, length])
    data = jnp.asarray(spans, dtype=jnp.float64)

    pop_cfg[pop_name(a)] += 1
    pop_cfg[pop_name(b)] += 1
    return data, pop_cfg


def get_tmrca_data(ts, key=None, num_samples=200, option="random"):
    """Collect TMRCA sequences and sampling configurations for many sample pairs.

    Args:
        ts: a ``tskit.TreeSequence``.
        key: JAX PRNG key; defaults to ``jax.random.PRNGKey(2)``. Only consumed
            by ``option="random"``.
        num_samples: number of random pairs to draw (``option="random"`` only).
        option: how pairs are chosen:
            - ``"random"``: ``num_samples`` pairs, each drawn with a fresh
              subkey (the same pair may be drawn more than once);
            - ``"all"``: every unordered pair of sample nodes;
            - ``"unphased"``: consecutive sample nodes ``(s0, s1), (s2, s3), ...``
              (presumably the two genomes of each diploid individual).

    Returns:
        ``(data_list, cfg_list)``: the per-pair outputs of ``compile``.

    Raises:
        ValueError: if ``option`` is not one of the values above.
    """
    if option not in ("random", "all", "unphased"):
        raise ValueError(
            f"option must be 'random', 'all' or 'unphased', got {option!r}"
        )
    if key is None:
        # Built lazily rather than as a default argument evaluated at def time.
        key = jax.random.PRNGKey(2)

    data_list = []
    cfg_list = []
    key, subkey = jr.split(key)
    if option == "random":
        for _ in range(num_samples):
            data, cfg = compile(ts, subkey)
            data_list.append(data)
            cfg_list.append(cfg)
            key, subkey = jr.split(key)
    else:
        if option == "all":
            from itertools import combinations

            pairs = combinations(ts.samples(), 2)
        else:  # "unphased"
            pairs = ts.samples().reshape(-1, 2)
        for a, b in pairs:
            data, cfg = compile(ts, subkey, a, b)
            data_list.append(data)
            cfg_list.append(cfg)

    return data_list, cfg_list


def process_data(data_list, cfg_list):
    """Pad per-pair TMRCA data into one array and index each pair's sampling config.

    Args:
        data_list: sequence of ``(n_i, 2)`` arrays of ``(tmrca, span)`` rows,
            one per pair (as produced by ``compile``).
        cfg_list: sequence of dicts (deme name -> count), aligned with
            ``data_list``.

    Returns:
        data_pad: ``(num_pairs, Lmax, 2)`` float array; rows past each pair's
            length are filled with ``(1.0, 0.0)``.
        deme_names: keys of ``cfg_list[0]``; the column order of the configs.
        max_indices: index of the last real row of each pair (``n_i - 1``).
        unique_cfg: ``(num_unique, num_demes)`` int32 array of the distinct
            configurations, rows in lexicographic order.
        matching_indices: for each pair, the row of ``unique_cfg`` it matches.

    Raises:
        ValueError: if ``data_list`` is empty or the two lists differ in length.
    """
    if not data_list:
        raise ValueError("data_list must contain at least one pair")
    if len(data_list) != len(cfg_list):
        raise ValueError(
            "data_list and cfg_list differ in length "
            f"({len(data_list)} vs {len(cfg_list)})"
        )

    arrays = [np.asarray(d) for d in data_list]
    lengths = [arr.shape[0] for arr in arrays]
    max_indices = jnp.array([n - 1 for n in lengths])

    # Fill in NumPy and convert once: a per-pair `.at[].set()` on a JAX array
    # copies the whole padded array every time (quadratic in the pair count).
    padded = np.empty((len(arrays), max(lengths), 2), dtype=np.float64)
    padded[:] = (1.0, 0.0)  # padding rows: TMRCA 1.0, span 0.0
    for row, arr in zip(padded, arrays):
        row[: arr.shape[0]] = arr
    data_pad = jnp.asarray(padded, dtype=jnp.float64)

    deme_names = cfg_list[0].keys()
    cfg_mat = np.array(
        [[cfg.get(name, 0) for name in deme_names] for cfg in cfg_list],
        dtype=np.int32,
    )
    # Distinct configurations plus, for every pair, the index of its match.
    # reshape(-1) guards against NumPy versions returning a 2-D inverse.
    unique_np, inverse = np.unique(cfg_mat, axis=0, return_inverse=True)
    unique_cfg = jnp.asarray(unique_np, dtype=jnp.int32)
    matching_indices = jnp.asarray(inverse.reshape(-1))

    return data_pad, deme_names, max_indices, unique_cfg, matching_indices


def reformat_data(data_pad, matching_indices, max_indices, chunking_length):
    """Regroup padded pair data by sampling configuration.

    Pairs are reordered so that all pairs sharing a sampling configuration are
    contiguous, with groups in increasing order of their index. For each group,
    the distinct values of column 0 of ``data_pad`` (TMRCAs, padding included)
    are collected, and every entry is mapped to its position in that list.

    Args:
        data_pad: ``(num_pairs, Lmax, 2)`` padded (tmrca, span) array.
        matching_indices: sampling-configuration index of each pair.
        max_indices: index of the last real row of each pair.
        chunking_length: divisor used to derive ``batch_size``.

    Returns:
        ``(padded_unique_times, rearranged_data, new_matching_indices,
        new_max_indices, associated_indices, unique_groups, batch_size,
        group_membership)``. ``padded_unique_times`` has one row per group,
        right-padded with zeros. ``associated_indices`` and
        ``group_membership`` have one entry per flattened TMRCA entry of the
        regrouped data. ``batch_size`` is
        ``group_membership.size // chunking_length``.
    """
    unique_groups = jnp.unique(matching_indices)

    distinct_times, data_parts, match_parts = [], [], []
    max_parts, index_parts, member_parts = [], [], []

    for grp in unique_groups:
        rows = jnp.where(matching_indices == grp)  # pairs using this config
        grp_data = data_pad[rows]
        flat_times = np.array(grp_data[:, :, 0].flatten())
        distinct = np.unique(flat_times)
        position_of = dict(zip(distinct, range(len(distinct))))
        time_idx = np.array(list(map(position_of.__getitem__, flat_times)))

        index_parts.append(time_idx)
        distinct_times.append(distinct)
        data_parts.append(grp_data)
        match_parts.append(matching_indices[rows])
        max_parts.append(max_indices[rows])
        member_parts.append(jnp.full(time_idx.size, grp))

    # Right-pad every group's distinct-time list with zeros to a common width.
    width = max(map(len, distinct_times))
    padded_unique_times = jnp.array(
        [
            np.pad(times, (0, width - len(times)), mode="constant", constant_values=0)
            for times in distinct_times
        ]
    )
    group_membership = jnp.concatenate(member_parts, axis=0)
    return (
        padded_unique_times,
        jnp.concatenate(data_parts, axis=0),
        jnp.concatenate(match_parts, axis=0),
        jnp.concatenate(max_parts, axis=0),
        jnp.concatenate(index_parts, axis=0),
        unique_groups,
        group_membership.size // chunking_length,
        group_membership,
    )


def plot_iicr_likelihood(
    demo,
    data_list,
    cfg_list,
    paths,
    vec_values,
    recombination_rate=1e-8,
    t_min=1e-8,
    num_t=2000,
    k=2,
):
    """Plot the IICR-based negative log-likelihood over a 1-D parameter sweep.

    For every value in ``vec_values``:
    - the parameter variable(s) in ``paths`` are set to that value;
    - the IICR coalescence rate is evaluated on a log-spaced time grid for each
      distinct sampling configuration;
    - the mean negative log-likelihood of all observed (TMRCA, span) sequences
      is computed.

    The resulting curve is plotted and returned.

    Args:
        demo: demographic model passed to ``IICRCurve(demo=...)``.
        data_list, cfg_list: per-pair TMRCA data and sampling configurations
            (e.g. from ``get_tmrca_data``).
        paths: mapping whose keys are the parameter variables being swept; only
            the key order is used.
        vec_values: 1-D array of parameter values to evaluate.
        recombination_rate: recombination rate ``r`` used in the likelihood ODE.
        t_min: smallest non-zero break point of the time grid.
        num_t: number of log-spaced break points between ``t_min`` and the
            largest value in the TMRCA column.
        k: passed through to ``IICRCurve``.

    Returns:
        Array with the mean negative log-likelihood for each entry of
        ``vec_values``.
    """
    import matplotlib.pyplot as plt

    path_order: List[Var] = list(paths)
    data_pad, deme_names, max_indices, unique_cfg, matching_indices = process_data(
        data_list, cfg_list
    )
    num_samples = len(max_indices)
    rho = recombination_rate
    iicr_call = IICRCurve(demo=demo, k=k).curve

    # Time grid of the piecewise-constant rate: 0, then num_t log-spaced points
    # from t_min to the largest value in the TMRCA column (padding included).
    global_max = jnp.max(data_pad[:, :, 0])
    print(global_max)
    t_breaks = jnp.insert(jnp.geomspace(t_min, global_max, num_t), 0, 0.0)

    def _positive(x):
        # Negative rates are replaced by a tiny positive value before use.
        return jnp.where(x < 0, 1e-30, x)

    def loglik(
        eta: Callable[[ScalarLike], ScalarLike],
        r: ScalarLike,
        data: Float[Array, "intervals 2"],
        max_index: Array,
    ) -> Scalar:
        """Log-likelihood of one pair's (TMRCA, span) sequence under rate ``eta``.

        Args:
            eta: piecewise-constant coalescence rate, callable at a time ``t``;
                its break points ``eta.t`` are used as ODE jump times and
                ``eta.R`` as the integrated rate between two times.
            r: recombination rate used in the ODE.
            data: ``(intervals, 2)`` array; column 0 is the TMRCA, column 1
                the span.
            max_index: index of the last real row; later rows are padding.

        Notes:
            - Successive rows with the same TMRCA should already be merged:
              ``<tmrca, span1> + <tmrca, span2> = <tmrca, span1 + span2>``.
            - Padding rows are excluded with a ``where`` mask, so non-finite
              terms there cannot leak into the sum (``0 * inf`` is NaN).
        """
        times, spans = data.T
        order = times.argsort()

        def f(t, y, _):
            c = _positive(eta(t))
            A = jnp.array([[-r, r, 0.0], [c, -2 * c, c], [0.0, 0.0, 0.0]])
            return A.T @ y

        sol = dfx.diffeqsolve(
            dfx.ODETerm(f),
            dfx.Tsit5(),
            0.0,
            times.max(),
            dt0=0.001,
            y0=jnp.array([1.0, 0.0, 0.0]),
            stepsize_controller=dfx.PIDController(
                rtol=1e-6, atol=1e-6, jump_ts=eta.t
            ),
            saveat=dfx.SaveAt(ts=times[order]),
        )
        # Undo the sort so that cscs[j] is the ODE state at times[j].
        cscs = sol.ys[order.argsort()]

        @vmap
        def p(t0, csc0, t1, csc1, span):
            p_nr_t0, p_float_t0, _ = csc0
            _, p_float_t1, _ = csc1
            # no recombination over the first span - 1 positions
            r1 = xlogy(span - 1, p_nr_t0)
            # coalescence at t1
            r2 = jnp.log(_positive(eta(t1)))
            # back-coalescence process up to t1; depends on whether t0 < t1
            r3 = jnp.where(
                t0 < t1,
                jnp.log(p_float_t0) - (eta.R(t1) - eta.R(t0)),
                jnp.log(p_float_t1),
            )
            return r1 + r2 + r3

        ll = p(times[:-1], cscs[:-1], times[1:], cscs[1:], spans[:-1])
        # Transition j (row j -> row j+1) is real only when j < max_index.
        valid = jnp.arange(ll.shape[0]) < max_index
        ll = jnp.sum(jnp.where(valid, ll, 0.0))

        # For the last real row we only know the span was at least this long.
        return ll + xlogy(spans[max_index], cscs[max_index, 0])

    def compute_loglik(c_map, c_index, data, max_index):
        eta = PiecewiseConstant(c=c_map[c_index], t=t_breaks)
        return loglik(eta, rho, data, max_index)

    def evaluate_at_vec(vec):
        params = _vec_to_dict_jax(jnp.atleast_1d(vec), path_order)
        # One row of coalescence rates on t_breaks per distinct sampling config.
        c_rows = []
        for sample_config in unique_cfg:
            ns = {name: sample_config[i] for i, name in enumerate(deme_names)}
            curve = iicr_call(params=params, num_samples=ns)
            c_rows.append(jax.vmap(curve)(t_breaks)["c"])
        c_map = jnp.array(c_rows)

        batched_loglik = vmap(compute_loglik, in_axes=(None, 0, 0, 0))(
            c_map, matching_indices, data_pad, max_indices
        )
        # Mean negative log-likelihood over pairs.
        return -jnp.sum(batched_loglik) / num_samples

    results = lax.map(evaluate_at_vec, vec_values)

    plt.figure(figsize=(10, 6))
    plt.plot(vec_values, results, "r-", linewidth=2)
    plt.xlabel("vec value")
    plt.ylabel("Negative Log-Likelihood")
    plt.title("IICR Likelihood Landscape")
    plt.grid(True)
    plt.show()

    return results

In [None]:
# Draw ONE random pair of sample nodes (num_samples=1). Its run-length-encoded
# TMRCA sequence is long: roughly 55 thousand (tmrca, span) rows for this data.
data_list, cfg_list = get_tmrca_data(
    ts,
    key=jr.PRNGKey(42),
    option="random",
    num_samples=1,
)

The next cell only plots the likelihood curve, which shows a strongly zig-zagging (non-smooth) profile. It takes about 51 seconds to run on GL.

In [None]:
# import jax.numpy as jnp
# from loguru import logger
# logger.disable("demestats")

# paths = {
#     frozenset({('demes', 1, 'epochs', 0, 'end_time'),
#             ('demes', 2, 'epochs', 0, 'end_time'),
#             ('migrations', 0, 'end_time'),
#             ('migrations', 1, 'end_time'),
#             ('migrations', 2, 'start_time'),
#             ('migrations', 3, 'start_time')}): 4000.,
# }
# vec_values = jnp.linspace(5, 120, 25)
# %time result = plot_iicr_likelihood(g, data_list, cfg_list, paths, vec_values, recombination_rate=3.94 * 1e-8, num_t=30)

In [None]:
def loglik(
    eta: Callable[[ScalarLike], ScalarLike],
    r: ScalarLike,
    data: Float[Array, "intervals 2"],
    max_index: Array,
) -> Tuple[Any, ...]:
    """Diagnostic log-likelihood that also returns its intermediate terms.

    Args:
        eta: piecewise-constant coalescence rate, callable at a time ``t``; its
            break points ``eta.t`` are used as ODE jump times and ``eta.R`` as
            the integrated rate between two times.
        r: recombination rate used in the ODE.
        data: ``(intervals, 2)`` array; column 0 is the TMRCA, column 1 the span.
        max_index: index of the last real row; later rows are padding.

    Returns:
        ``(ll, cscs, r1, r2, r3, eta.t, etaR)``:
            ll: total log-likelihood (scalar);
            cscs: ODE state ``(p_nr, p_float, p_coal)`` at each row's TMRCA;
            r1, r2, r3: per-transition terms (length ``intervals - 1``);
            eta.t: break points of the rate function;
            etaR: ``eta.R(t1) - eta.R(t0)`` for each transition.

    Notes:
        - Successive rows with the same TMRCA should already be merged:
          ``<tmrca, span1> + <tmrca, span2> = <tmrca, span1 + span2>``.
        - Padding rows are excluded from ``ll`` with a ``where`` mask, so
          non-finite terms there cannot leak into the sum (``0 * inf`` is NaN).
    """
    times, spans = data.T
    order = times.argsort()

    def f(t, y, _):
        c = eta(t)
        A = jnp.array([[-r, r, 0.0], [c, -2 * c, c], [0.0, 0.0, 0.0]])
        return A.T @ y

    sol = dfx.diffeqsolve(
        dfx.ODETerm(f),
        dfx.Tsit5(),
        0.0,
        times.max(),
        dt0=0.001,
        y0=jnp.array([1.0, 0.0, 0.0]),
        stepsize_controller=dfx.PIDController(rtol=1e-6, atol=1e-6, jump_ts=eta.t),
        saveat=dfx.SaveAt(ts=times[order]),
    )
    # Undo the sort so that cscs[j] is the ODE state at times[j].
    cscs = sol.ys[order.argsort()]

    @vmap
    def p(t0, csc0, t1, csc1, span):
        p_nr_t0, p_float_t0, _ = csc0
        _, p_float_t1, _ = csc1
        # no recombination over the first span - 1 positions
        r1 = xlogy(span - 1, p_nr_t0)
        # coalescence at t1
        r2 = jnp.log(eta(t1))
        # back-coalescence process up to t1; depends on whether t0 < t1
        d_r = eta.R(t1) - eta.R(t0)
        r3 = jnp.where(t0 < t1, jnp.log(p_float_t0) - d_r, jnp.log(p_float_t1))
        return r1 + r2 + r3, r1, r2, r3, d_r

    ll, r1, r2, r3, etaR = p(times[:-1], cscs[:-1], times[1:], cscs[1:], spans[:-1])
    # Transition j (row j -> row j+1) is real only when j < max_index.
    valid = jnp.arange(ll.shape[0]) < max_index
    ll = jnp.sum(jnp.where(valid, ll, 0.0))

    # For the last real row we only know the span was at least this long.
    ll += xlogy(spans[max_index], cscs[max_index, 0])
    return ll, cscs, r1, r2, r3, eta.t, etaR


def compute_loglik(c_map, c_index, data, max_index):
    """Evaluate ``loglik`` for one pair, using row ``c_index`` of ``c_map``.

    Reads the notebook globals ``t_breaks`` (break points of the rate) and
    ``rho`` (recombination rate); both must be defined before this is called.
    """
    rate_fn = PiecewiseConstant(t=t_breaks, c=c_map[c_index])
    return loglik(rate_fn, rho, data, max_index)


# `g` (the demes graph of the simulated demography) is only created by the
# commented-out data-generating cell. Fail early with a clear message when it
# is missing, e.g. when the tree sequence was loaded from disk instead.
if "g" not in globals():
    raise NameError(
        "`g` is not defined: run the data-generating cell (which builds "
        "g = demo.to_demes()) before this cell."
    )

# Parameter variable swept below: these deme end times and migration
# start/end times are tied together. The value 4000.0 is not read by the
# cells below; only the key order (path_order) is used.
paths = {
    frozenset(
        {
            ("demes", 1, "epochs", 0, "end_time"),
            ("demes", 2, "epochs", 0, "end_time"),
            ("migrations", 0, "end_time"),
            ("migrations", 1, "end_time"),
            ("migrations", 2, "start_time"),
            ("migrations", 3, "start_time"),
        }
    ): 4000.0,
}
path_order: List[Var] = list(paths)
data_pad, deme_names, max_indices, unique_cfg, matching_indices = process_data(
    data_list, cfg_list
)

num_samples = len(max_indices)
rho = recombination_rate = 3.94 * 1e-8
iicr = IICRCurve(demo=g, k=2)
iicr_call = iicr.curve

first_columns = data_pad[:, :, 0]
# Compute global max (single float value)
global_max = jnp.max(first_columns)
print(global_max)
# Same grid as plot_iicr_likelihood with t_min=1e-8 and num_t=30.
t_breaks = jnp.insert(jnp.geomspace(1e-8, global_max, 30), 0, 0.0)

The previous cell sets up everything needed for the data. The next cell stores every intermediate quantity computed for the log-likelihood at each parameter value, and takes approximately 2 minutes to run.

In [None]:
# Diagnostic sweep. For every value in vec_values, this cell recomputes the
# coalescence-rate table and runs the diagnostic compute_loglik for every pair.
# All outputs are kept in the *_list containers for later inspection.
# Uses names from earlier cells: path_order, unique_cfg, deme_names, iicr_call,
# t_breaks, matching_indices, data_pad, max_indices, compute_loglik.
# NOTE: loop variables (vec, params, c_map, batched_*, etaR, ...) remain in the
# notebook namespace after the loop and hold the values from the last iteration.
from loguru import logger

# Silence demestats' loguru output during the sweep.
logger.disable("demestats")

vec_values = jnp.linspace(5, 120, 25)  # 25 parameter values from 5 to 120
c_list = []  # per value: stacked "c" table, one row per unique sampling config
log_s_list = []  # per value: stacked "log_s" table (not used by the likelihood)
loglik_list = []  # per value: log-likelihood of every pair
cscs_list = []  # per value: ODE states at each TMRCA, per pair
r1_list = []  # per value: per-transition r1 terms, per pair
r2_list = []  # per value: per-transition r2 terms, per pair
r3_list = []  # per value: per-transition r3 terms, per pair
etaR_list = []  # per value: eta.R(t1) - eta.R(t0) per transition, per pair

for vec in vec_values:
    vec_array = jnp.atleast_1d(vec)
    params = _vec_to_dict_jax(vec_array, path_order)
    # print(params)
    c_map = []
    log_s = []
    # Evaluate the IICR curve on t_breaks once per distinct sampling config.
    for j in range(len(unique_cfg)):
        sample_config = unique_cfg[j]
        ns = {name: sample_config[i] for i, name in enumerate(deme_names)}
        single_config_func = iicr_call(params=params, num_samples=ns)
        dictionary = jax.vmap(single_config_func, in_axes=(0))(t_breaks)
        c_map.append(dictionary["c"])
        log_s.append(dictionary["log_s"])

    c_map = jnp.array(c_map)
    log_s = jnp.array(log_s)

    c_list.append(c_map)
    log_s_list.append(log_s)

    # Vectorize over pairs: each pair selects its rate row via matching_indices.
    # batched_eta_jumps (eta.t per pair) is unpacked but not stored.
    (
        batched_loglik,
        batched_cscs,
        batched_r1,
        batched_r2,
        batched_r3,
        batched_eta_jumps,
        etaR,
    ) = vmap(compute_loglik, in_axes=(None, 0, 0, 0))(
        c_map, matching_indices, data_pad, max_indices
    )
    loglik_list.append(batched_loglik)
    cscs_list.append(batched_cscs)
    r1_list.append(batched_r1)
    r2_list.append(batched_r2)
    r3_list.append(batched_r3)
    etaR_list.append(etaR)