In [4]:
import tempfile

import bio2zarr.tskit
import msprime as msp
import sgkit

from demestats.fit.read_data import get_het_data

demo = msp.Demography()
demo.add_population(initial_size=5000, name="anc")
demo.add_population(initial_size=5000, name="P0")
demo.add_population(initial_size=5000, name="P1")
demo.set_symmetric_migration_rate(populations=("P0", "P1"), rate=0.0001)
tmp = [f"P{i}" for i in range(2)]
demo.add_population_split(time=4000, derived=tmp, ancestral="anc")
g = demo.to_demes()
sample_size = 4
samples = {f"P{i}": sample_size for i in range(2)}
anc = msp.sim_ancestry(
    samples=samples,
    demography=demo,
    recombination_rate=1e-8,
    sequence_length=1e7,
    random_seed=12,
)
ts = msp.sim_mutations(anc, rate=1e-8, random_seed=13)

# demesdraw.tubes(g)


d = tempfile.TemporaryDirectory()
bio2zarr.tskit.convert(ts, d.name + "/ts")
ds = sgkit.load_dataset(d.name + "/ts")

het_matrix, cfg_list = get_het_data(ts, ds, num_samples=5)

In [6]:
from typing import Any, List, Mapping, Set, Tuple

import jax
import jax.numpy as jnp
from jax import vmap
from loguru import logger
from phlashlib.iicr import PiecewiseConstant
from phlashlib.loglik import loglik

from demestats.fit.util import _vec_to_dict_jax, process_data
from demestats.iicr import IICRCurve

paths = {
    ("migrations", 0, "rate"): 0.0009,
    ("migrations", 1, "rate"): 0.0009,
}


logger.disable("demestats")

Path = Tuple[Any, ...]
Var = Path | Set[Path]
Params = Mapping[Var, float]

path_order: List[Var] = list(paths)
cfg_mat, deme_names, unique_cfg, matching_indices = process_data(cfg_list)
num_samples = len(cfg_mat)
rho = theta = 1e-8 * 100
k = 2
iicr = IICRCurve(demo=g, k=2)
iicr_call = jax.jit(iicr.__call__)


def process_base_model(deme_names, cfg):
    curve = iicr.curve(num_samples=dict(zip(deme_names, cfg)))
    timepoints = jax.vmap(curve.quantile)(jnp.linspace(0, 1, 32)[1:-1])
    timepoints = jnp.insert(timepoints, 0, 0.0)
    return timepoints


times = jax.vmap(process_base_model, in_axes=(None, 0))(
    deme_names, jnp.array(unique_cfg)
)


def compute_loglik(c_index, data, c_map, times, theta, rho):
    c = c_map[c_index]
    t = times[c_index]
    eta = PiecewiseConstant(c=c, t=t)
    return loglik(data, eta, t, theta, rho)


def neg_loglik(
    vec,
    path_order,
    unique_cfg,
    times,
    matching_indices,
    het_matrix,
    theta,
    rho,
    num_samples,
    deme_names,
):
    if (vec > jnp.array([0.001, 0.001])).any():
        return jnp.inf

    if (vec < jnp.array([0, 0])).any():
        return jnp.inf

    params = _vec_to_dict_jax(vec, path_order)
    jax.debug.print("Param values: {}", jnp.array(list(params.values())))

    c_map = jax.vmap(
        lambda cfg, time: iicr_call(
            params=params, t=time, num_samples=dict(zip(deme_names, cfg))
        )["c"]
    )(jnp.array(unique_cfg), times)

    # Batched over cfg_mat (matching_indices) and all_tmrca_spans (het_matrix)
    batched_loglik = vmap(compute_loglik, in_axes=(0, 0, None, None, None, None))(
        matching_indices, het_matrix, c_map, times, theta, rho
    )
    loss = -jnp.sum(batched_loglik) / num_samples
    jax.debug.print("Loss: {loss}", loss=loss)
    return loss


vec = jnp.array([0.0001, 0.0001])
loss = neg_loglik(
    vec,
    path_order,
    unique_cfg,
    times,
    matching_indices,
    het_matrix,
    theta,
    rho,
    num_samples,
    deme_names,
)
print(loss)
grad = jax.grad(neg_loglik)(
    vec,
    path_order,
    unique_cfg,
    times,
    matching_indices,
    jnp.array(het_matrix),
    theta,
    rho,
    num_samples,
    deme_names,
)
print(grad)

Param values: [0.0001 0.0001]
Loss: 11592.283790431211
11592.283790431211
Param values: [0.0001 0.0001]
Loss: 11592.2841796875
[-10543.97273662    533.54649831]
