In [1]:
import demes
import demesdraw
import msprime as msp

demo = msp.Demography()
demo.add_population(initial_size=5000, name="anc")
demo.add_population(initial_size=5000, name="P0")
demo.add_population(initial_size=5000, name="P1")
demo.add_population(initial_size=5000, name="P2")
demo.add_population(initial_size=5000, name="P3")
demo.add_population(initial_size=5000, name="P4")
demo.set_symmetric_migration_rate(populations=("P0", "P1"), rate=0.0001)
demo.set_symmetric_migration_rate(populations=("P1", "P2"), rate=0.0001)
demo.set_symmetric_migration_rate(populations=("P2", "P3"), rate=0.0001)
demo.set_symmetric_migration_rate(populations=("P3", "P4"), rate=0.0001)
tmp = [f"P{i}" for i in range(5)]
demo.add_population_split(time=4000, derived=tmp, ancestral="anc")
g = demo.to_demes()
sample_size = 10
samples = {f"P{i}": sample_size for i in range(5)}
anc = msp.sim_ancestry(
    samples=samples,
    demography=demo,
    recombination_rate=1e-8,
    sequence_length=1e8,
    random_seed=12,
)
ts = msp.sim_mutations(anc, rate=1e-8, random_seed=13)

# demesdraw.tubes(g)

import jax
import jax.numpy as jnp
import jax.random as jr


def compile(ts, subkey, a=None, b=None):
    # using a set to pull out all unique populations that the samples can possibly belong to
    pop_cfg = {
        ts.population(ts.node(n).population).metadata["name"] for n in ts.samples()
    }
    pop_cfg = {pop_name: 0 for pop_name in pop_cfg}

    if a == None and b == None:
        samples = jax.random.choice(subkey, ts.num_samples, shape=(2,), replace=False)
        a, b = samples[0].item(0), samples[1].item(0)

    spans = []
    curr_t = None
    curr_L = 0.0
    for tree in ts.trees():
        L = tree.interval.right - tree.interval.left
        t = tree.tmrca(a, b)
        if curr_t is None or t != curr_t:
            if curr_t is not None:
                spans.append([curr_t, curr_L])
            curr_t = t
            curr_L = L
        else:
            curr_L += L
    spans.append([curr_t, curr_L])
    data = jnp.asarray(spans, dtype=jnp.float64)
    pop_cfg[ts.population(ts.node(a).population).metadata["name"]] += 1
    pop_cfg[ts.population(ts.node(b).population).metadata["name"]] += 1
    return data, pop_cfg


def get_tmrca_data(ts, key=jax.random.PRNGKey(2), num_samples=200, option="random"):
    data_list = []
    cfg_list = []
    key, subkey = jr.split(key)
    if option == "random":
        for i in range(num_samples):
            data, cfg = compile(ts, subkey)
            data_list.append(data)
            cfg_list.append(cfg)
            key, subkey = jr.split(key)
    elif option == "all":
        from itertools import combinations

        all_config = list(combinations(ts.samples(), 2))
        for a, b in all_config:
            data, cfg = compile(ts, subkey, a, b)
            data_list.append(data)
            cfg_list.append(cfg)
    elif option == "unphased":
        all_config = ts.samples().reshape(-1, 2)
        for a, b in all_config:
            data, cfg = compile(ts, subkey, a, b)
            data_list.append(data)
            cfg_list.append(cfg)

    return data_list, cfg_list


def process_data(data_list, cfg_list):
    max_indices = jnp.array([arr.shape[0] - 1 for arr in data_list])
    num_samples = len(max_indices)
    lens = jnp.array([d.shape[0] for d in data_list], dtype=jnp.int32)
    Lmax = int(lens.max())
    Npairs = len(data_list)
    data_pad = jnp.full((Npairs, Lmax, 2), jnp.array([1.0, 0.0]), dtype=jnp.float64)

    for i, d in enumerate(data_list):
        data_pad = data_pad.at[i, : d.shape[0], :].set(d)

    deme_names = cfg_list[0].keys()
    D = len(deme_names)
    cfg_mat = jnp.zeros((num_samples, D), dtype=jnp.int32)
    for i, cfg in enumerate(cfg_list):
        for j, n in enumerate(deme_names):
            cfg_mat = cfg_mat.at[i, j].set(cfg.get(n, 0))

    unique_cfg = jnp.unique(cfg_mat, axis=0)

    # Find matching indices for sampling configs
    def find_matching_index(row, unique_arrays):
        matches = jnp.all(row == unique_arrays, axis=1)
        return jnp.where(matches)[0][0]

    # Vectorize over all rows in `arr`
    matching_indices = jnp.array(
        [find_matching_index(row, unique_cfg) for row in cfg_mat]
    )

    return data_pad, deme_names, max_indices, unique_cfg, matching_indices


def reformat_data(data_pad, matching_indices, max_indices, chunking_length):
    unique_groups = jnp.unique(matching_indices)
    group_unique_times = []
    rearranged_data = []
    new_max_indices = []
    new_matching_indices = []
    associated_indices = []
    group_membership = []

    # Each group is a sampling configuration
    for group in unique_groups:
        positions = jnp.where(
            matching_indices == group
        )  # extract positions matching a group
        group_data = data_pad[
            positions
        ]  # find all tmrca + spans that share a sampling config
        all_first_col = np.array(group_data[:, :, 0].flatten())
        unique_values = np.unique(
            all_first_col
        )  # extra all unique tmrca associated to a sampling config

        unique_value_to_index = {value: idx for idx, value in enumerate(unique_values)}
        indices_in_mapping = np.array(
            [unique_value_to_index[value] for value in all_first_col]
        )
        associated_indices.append(
            indices_in_mapping
        )  # figure how ever tmrca in data_pad gets mapped to unique_values

        group_unique_times.append(unique_values)
        rearranged_data.append(group_data)
        new_matching_indices.append(matching_indices[positions])
        new_max_indices.append(max_indices[positions])
        group_membership.append(jnp.full(indices_in_mapping.size, group))

    # Find the maximum length
    max_length = max(len(arr) for arr in group_unique_times)

    # Pad each array with zeros at the end
    padded_unique_times = []
    for arr in group_unique_times:
        pad_length = max_length - len(arr)
        padded = np.pad(arr, (0, pad_length), mode="constant", constant_values=0)
        padded_unique_times.append(padded)

    padded_unique_times = jnp.array(padded_unique_times)
    rearranged_data = jnp.concatenate(rearranged_data, axis=0)
    new_matching_indices = jnp.concatenate(new_matching_indices, axis=0)
    new_max_indices = jnp.concatenate(new_max_indices, axis=0)
    associated_indices = jnp.concatenate(associated_indices, axis=0)
    group_membership = jnp.concatenate(group_membership, axis=0)
    total_elements = group_membership.size
    batch_size = total_elements // chunking_length
    return (
        padded_unique_times,
        rearranged_data,
        new_matching_indices,
        new_max_indices,
        associated_indices,
        unique_groups,
        batch_size,
        group_membership,
    )


data_list, cfg_list = get_tmrca_data(
    ts, key=jax.random.PRNGKey(42), num_samples=400, option="random"
)

  data = jnp.asarray(spans, dtype=jnp.float64)


In [13]:
from typing import Any, Dict, List, Mapping, Optional, Sequence, Set, Tuple

import numpy as np
from loguru import logger

from demesinfer.fit.util import _vec_to_dict, _vec_to_dict_jax
from demesinfer.iicr import IICRCurve

logger.disable("demesinfer")

Path = Tuple[Any, ...]
Var = Path | Set[Path]
Params = Mapping[Var, float]

import diffrax as dfx
from jax import vmap
from jax.scipy.special import xlog1py, xlogy
from jaxtyping import Array, Float, Scalar, ScalarLike

paths = {
    frozenset(
        {("demes", 0, "epochs", 0, "end_size"), ("demes", 0, "epochs", 0, "start_size")}
    ): 5000.0,
    frozenset(
        {("demes", 1, "epochs", 0, "end_size"), ("demes", 1, "epochs", 0, "start_size")}
    ): 5000.0,
    frozenset(
        {("demes", 2, "epochs", 0, "end_size"), ("demes", 2, "epochs", 0, "start_size")}
    ): 5000.0,
    frozenset(
        {("demes", 3, "epochs", 0, "end_size"), ("demes", 3, "epochs", 0, "start_size")}
    ): 5000.0,
    frozenset(
        {("demes", 4, "epochs", 0, "end_size"), ("demes", 4, "epochs", 0, "start_size")}
    ): 5000.0,
    frozenset(
        {("demes", 5, "epochs", 0, "end_size"), ("demes", 5, "epochs", 0, "start_size")}
    ): 5000.0,
    # ('migrations', 0, 'rate'):0.0001,
    # ('migrations', 1, 'rate'):0.0001,
    # ('migrations', 2, 'rate'):0.0001,
    # ('migrations', 3, 'rate'):0.0001,
    # ('migrations', 4, 'rate'):0.0001,
    # ('migrations', 5, 'rate'):0.0001,
    # ('migrations', 6, 'rate'):0.0001,
    # ('migrations', 7, 'rate'):0.0001,
    frozenset(
        {
            ("demes", 0, "epochs", 0, "end_time"),
            ("demes", 1, "start_time"),
            ("demes", 2, "start_time"),
            ("demes", 3, "start_time"),
            ("demes", 4, "start_time"),
            ("demes", 5, "start_time"),
            ("migrations", 0, "start_time"),
            ("migrations", 1, "start_time"),
            ("migrations", 2, "start_time"),
            ("migrations", 3, "start_time"),
            ("migrations", 4, "start_time"),
            ("migrations", 5, "start_time"),
            ("migrations", 6, "start_time"),
            ("migrations", 7, "start_time"),
        }
    ): 4000.0,
}

path_order: List[Var] = list(paths)

data_pad, deme_names, max_indices, unique_cfg, matching_indices = process_data(
    data_list, cfg_list
)
chunking_length = data_pad.shape[1]
(
    unique_times,
    rearranged_data,
    new_matching_indices,
    new_max_indices,
    associated_indices,
    unique_groups,
    batch_size,
    group_membership,
) = reformat_data(data_pad, matching_indices, max_indices, chunking_length)

num_samples = len(max_indices)
rho = recombination_rate = 1e-8
iicr = IICRCurve(demo=g, k=2)
iicr_call = iicr.curve

vec = jnp.array([5000.0, 5000.0, 5000.0, 5000.0, 5000.0, 5000.0, 4000.0])
params = _vec_to_dict_jax(vec, path_order)


def compute_c(params, sample_config, deme_names, iicr_call, times):
    ns = {name: sample_config[i] for i, name in enumerate(deme_names)}
    single_config_func = iicr_call(params=params, num_samples=ns)
    dictionary = jax.vmap(single_config_func, in_axes=(0))(times)
    return dictionary["c"], dictionary["log_s"]


def loglik_ode(
    params,
    sample_config,
    deme_names,
    iicr_call,
    r: ScalarLike,
    times: Float[Array, "intervals 1"],
) -> Scalar:
    # times = data.T
    i = times.argsort()
    sorted_times = times[i]
    ns = {name: sample_config[i] for i, name in enumerate(deme_names)}
    single_config_func = iicr_call(params=params, num_samples=ns)

    def f(t, y, _):
        c = jnp.clip(single_config_func(t)["c"], a_min=1e-21, a_max=None)
        A = jnp.array([[-r, r, 0.0], [c, -2 * c, c], [0.0, 0.0, 0.0]])
        return A.T @ y

    y0 = jnp.array([1.0, 0.0, 0.0])
    solver = dfx.Tsit5()
    term = dfx.ODETerm(f)
    ssc = dfx.PIDController(rtol=1e-6, atol=1e-6, jump_ts=single_config_func.jump_ts)
    # ssc = dfx.PIDController(rtol=1e-6, atol=1e-6)
    T = times.max()
    sol = dfx.diffeqsolve(
        term,
        solver,
        0.0,
        T,
        dt0=0.001,
        y0=y0,
        stepsize_controller=ssc,
        saveat=dfx.SaveAt(ts=sorted_times),
    )

    # invert the sorting so that cscs matches times
    i_inv = i.argsort()
    cscs = sol.ys[i_inv]
    return cscs


def compute_loglik(data, cscs, max_index, c_map, log_s):
    times, spans = data.T

    @vmap
    def p(t0, csc0, t1, csc1, span, c_rate_t1, log_s_t0, log_s_t1):
        p_nr_t0, p_float_t0, p_coal_t0 = csc0
        p_nr_t1, p_float_t1, p_coal_t1 = csc1
        # no recomb for first span - 1 positions
        r1 = xlogy(span - 1, p_nr_t0)
        # coalescence at t1
        r2 = jnp.log(c_rate_t1)
        # back-coalescence process up to t1, depends to t0 >< t1
        r3 = jnp.where(
            t0 < t1,
            jnp.log(p_float_t0) + log_s_t1 - log_s_t0,
            jnp.log(
                p_float_t1
            ),  # t0 < t1, jnp.log(p_float_t0) - eta.R(t0, t1), jnp.log(p_float_t1)
        )
        return r1 + r2 + r3

    ll = p(
        times[:-1],
        cscs[:-1],
        times[1:],
        cscs[1:],
        spans[:-1],
        c_map[1:],
        log_s[:-1],
        log_s[1:],
    )
    ll = jnp.dot(ll, jnp.arange(len(times[:-1])) < max_index)

    # for the last position, we only know span was at least as long
    ll += xlogy(spans[max_index], cscs[max_index, 0])
    return ll

## Time it takes to compute c_map

In [14]:
%time c_map, log_s = jax.vmap(compute_c, in_axes=(None, 0, None, None, 0))(params, unique_cfg, deme_names, iicr_call, unique_times)

CPU times: user 4.18 s, sys: 255 ms, total: 4.44 s
Wall time: 4.77 s


## Time it takes to compute cscs

In [15]:
%time all_cscs = jax.vmap(loglik_ode, in_axes=(None, 0, None, None, None, 0))(params, unique_cfg, deme_names, iicr_call, rho, unique_times)

CPU times: user 33.7 s, sys: 1.87 s, total: 35.6 s
Wall time: 36.7 s


## Time to compute the likelihood

In [16]:
extracted = all_cscs[group_membership, associated_indices]
final_cscs_flat = extracted.reshape(batch_size, chunking_length, 3)
log_s = log_s[group_membership, associated_indices].reshape(batch_size, chunking_length)
c_map = c_map[group_membership, associated_indices].reshape(batch_size, chunking_length)

In [17]:
%time batched_loglik = vmap(compute_loglik, in_axes=(0, 0, 0, 0, 0))(rearranged_data, final_cscs_flat, new_max_indices, c_map, log_s)

CPU times: user 43.2 ms, sys: 6.92 ms, total: 50.1 ms
Wall time: 50.8 ms


## Time to compute likelihood with everything put together

In [18]:
def compute_ode(
    params,
    sample_config,
    deme_names,
    iicr_call,
    r: ScalarLike,
    times: Float[Array, "intervals 1"],
) -> Scalar:
    # times = data.T
    i = times.argsort()
    sorted_times = times[i]
    ns = {name: sample_config[i] for i, name in enumerate(deme_names)}
    single_config_func = iicr_call(params=params, num_samples=ns)
    dictionary = jax.vmap(single_config_func, in_axes=(0))(times)

    def f(t, y, _):
        c = jnp.clip(single_config_func(t)["c"], a_min=1e-21, a_max=None)
        A = jnp.array([[-r, r, 0.0], [c, -2 * c, c], [0.0, 0.0, 0.0]])
        return A.T @ y

    y0 = jnp.array([1.0, 0.0, 0.0])
    solver = dfx.Tsit5()
    term = dfx.ODETerm(f)
    ssc = dfx.PIDController(rtol=1e-6, atol=1e-6, jump_ts=single_config_func.jump_ts)
    # ssc = dfx.PIDController(rtol=1e-6, atol=1e-6)
    T = times.max()
    sol = dfx.diffeqsolve(
        term,
        solver,
        0.0,
        T,
        dt0=0.001,
        y0=y0,
        stepsize_controller=ssc,
        saveat=dfx.SaveAt(ts=sorted_times),
    )

    # invert the sorting so that cscs matches times
    i_inv = i.argsort()
    cscs = sol.ys[i_inv]
    return cscs, jnp.clip(dictionary["c"], a_min=1e-21, a_max=None), dictionary["log_s"]


def neg_loglik(
    vec,
    path_order,
    unique_cfg,
    deme_names,
    iicr_call,
    rho,
    unique_times,
    group_membership,
    associated_indices,
    batch_size,
    chunking_length,
    rearranged_data,
    new_max_indices,
    num_samples,
):
    params = _vec_to_dict_jax(vec, path_order)
    all_cscs, c_map, log_s = jax.vmap(
        compute_ode, in_axes=(None, 0, None, None, None, 0)
    )(params, unique_cfg, deme_names, iicr_call, rho, unique_times)

    jax.debug.print("Param values: {}", jnp.array(list(params.values())))

    extracted = all_cscs[group_membership, associated_indices]
    final_cscs_flat = extracted.reshape(batch_size, chunking_length, 3)
    log_s = log_s[group_membership, associated_indices].reshape(
        batch_size, chunking_length
    )
    c_map = c_map[group_membership, associated_indices].reshape(
        batch_size, chunking_length
    )
    batched_loglik = vmap(compute_loglik, in_axes=(0, 0, 0, 0, 0))(
        rearranged_data, final_cscs_flat, new_max_indices, c_map, log_s
    )
    loss = -jnp.sum(batched_loglik) / num_samples
    jax.debug.print("Loss: {}", loss)
    return loss

In [19]:
%time neg_loglikelihood = neg_loglik(vec, path_order, unique_cfg, deme_names, iicr_call, rho, unique_times, group_membership, associated_indices, batch_size, chunking_length, rearranged_data, new_max_indices, num_samples)

Param values: [5000. 5000. 5000. 5000. 5000. 5000. 4000.]
Loss: 343287.94309517136
CPU times: user 33.5 s, sys: 2.31 s, total: 35.8 s
Wall time: 37 s


## Attempt to take the gradient

In [None]:
%time gradient = jax.grad(neg_loglik)(vec, path_order, unique_cfg, deme_names, iicr_call, rho, unique_times, group_membership, associated_indices, batch_size, chunking_length, rearranged_data, new_max_indices, num_samples)

## Computing c_map for migration

In [22]:
paths = {
    ("migrations", 0, "rate"): 0.0001,
    ("migrations", 1, "rate"): 0.0001,
    ("migrations", 2, "rate"): 0.0001,
    ("migrations", 3, "rate"): 0.0001,
    ("migrations", 4, "rate"): 0.0001,
    ("migrations", 5, "rate"): 0.0001,
    ("migrations", 6, "rate"): 0.0001,
    ("migrations", 7, "rate"): 0.0001,
}
path_order: List[Var] = list(paths)
vec = jnp.array([0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.0001])
params = _vec_to_dict_jax(vec, path_order)

In [23]:
%time c_map, log_s = jax.vmap(compute_c, in_axes=(None, 0, None, None, 0))(params, unique_cfg, deme_names, iicr_call, unique_times)

CPU times: user 3.9 s, sys: 204 ms, total: 4.1 s
Wall time: 4.39 s


## Error computing cmap jointly with migration and frozensets

In [24]:
paths = {
    frozenset(
        {("demes", 0, "epochs", 0, "end_size"), ("demes", 0, "epochs", 0, "start_size")}
    ): 5000.0,
    frozenset(
        {("demes", 1, "epochs", 0, "end_size"), ("demes", 1, "epochs", 0, "start_size")}
    ): 5000.0,
    frozenset(
        {("demes", 2, "epochs", 0, "end_size"), ("demes", 2, "epochs", 0, "start_size")}
    ): 5000.0,
    frozenset(
        {("demes", 3, "epochs", 0, "end_size"), ("demes", 3, "epochs", 0, "start_size")}
    ): 5000.0,
    frozenset(
        {("demes", 4, "epochs", 0, "end_size"), ("demes", 4, "epochs", 0, "start_size")}
    ): 5000.0,
    frozenset(
        {("demes", 5, "epochs", 0, "end_size"), ("demes", 5, "epochs", 0, "start_size")}
    ): 5000.0,
    ("migrations", 0, "rate"): 0.0001,
    ("migrations", 1, "rate"): 0.0001,
    ("migrations", 2, "rate"): 0.0001,
    ("migrations", 3, "rate"): 0.0001,
    ("migrations", 4, "rate"): 0.0001,
    ("migrations", 5, "rate"): 0.0001,
    ("migrations", 6, "rate"): 0.0001,
    ("migrations", 7, "rate"): 0.0001,
    frozenset(
        {
            ("demes", 0, "epochs", 0, "end_time"),
            ("demes", 1, "start_time"),
            ("demes", 2, "start_time"),
            ("demes", 3, "start_time"),
            ("demes", 4, "start_time"),
            ("demes", 5, "start_time"),
            ("migrations", 0, "start_time"),
            ("migrations", 1, "start_time"),
            ("migrations", 2, "start_time"),
            ("migrations", 3, "start_time"),
            ("migrations", 4, "start_time"),
            ("migrations", 5, "start_time"),
            ("migrations", 6, "start_time"),
            ("migrations", 7, "start_time"),
        }
    ): 4000.0,
}
vec = jnp.array(
    [
        5000.0,
        5000.0,
        5000.0,
        5000.0,
        5000.0,
        5000.0,
        0.0001,
        0.0001,
        0.0001,
        0.0001,
        0.0001,
        0.0001,
        0.0001,
        0.0001,
        4000.0,
    ]
)
path_order: List[Var] = list(paths)
params = _vec_to_dict_jax(vec, path_order)

In [25]:
%time c_map, log_s = jax.vmap(compute_c, in_axes=(None, 0, None, None, 0))(params, unique_cfg, deme_names, iicr_call, unique_times)

CPU times: user 1.03 ms, sys: 990 Î¼s, total: 2.02 ms
Wall time: 7.23 ms


ValueError: Comparator raised exception while sorting pytree dictionary keys.