In [1]:
import tempfile

import bio2zarr.tskit
import msprime as msp
import sgkit

from demestats.fit.read_data import get_het_data

demo = msp.Demography()
demo.add_population(initial_size=5000, name="anc")
demo.add_population(initial_size=5000, name="P0")
demo.add_population(initial_size=5000, name="P1")
demo.add_population(initial_size=5000, name="P2")
demo.add_population(initial_size=5000, name="P3")
demo.add_population(initial_size=5000, name="P4")
demo.set_symmetric_migration_rate(populations=("P0", "P1"), rate=0.0001)
demo.set_symmetric_migration_rate(populations=("P1", "P2"), rate=0.0001)
demo.set_symmetric_migration_rate(populations=("P2", "P3"), rate=0.0001)
demo.set_symmetric_migration_rate(populations=("P3", "P4"), rate=0.0001)
tmp = [f"P{i}" for i in range(5)]
demo.add_population_split(time=4000, derived=tmp, ancestral="anc")
g = demo.to_demes()
sample_size = 10
samples = {f"P{i}": sample_size for i in range(5)}
anc = msp.sim_ancestry(
    samples=samples,
    demography=demo,
    recombination_rate=1e-8,
    sequence_length=1e8,
    random_seed=12,
)
ts = msp.sim_mutations(anc, rate=1e-8, random_seed=13)

# demesdraw.tubes(g)


d = tempfile.TemporaryDirectory()
bio2zarr.tskit.convert(ts, d.name + "/ts")
ds = sgkit.load_dataset(d.name + "/ts")

het_matrix, cfg_list = get_het_data(ts, ds, num_samples=100)

In [4]:
from typing import Any, List, Mapping, Set, Tuple

import jax
import jax.numpy as jnp
from jax import vmap
from loguru import logger
from phlashlib.iicr import PiecewiseConstant
from phlashlib.loglik import loglik

from demestats.fit.util import _vec_to_dict_jax, process_data
from demestats.iicr import IICRCurve

logger.disable("demestats")

Path = Tuple[Any, ...]
Var = Path | Set[Path]
Params = Mapping[Var, float]

paths = {
    frozenset(
        {("demes", 0, "epochs", 0, "end_size"), ("demes", 0, "epochs", 0, "start_size")}
    ): 5000.0,
    frozenset(
        {("demes", 1, "epochs", 0, "end_size"), ("demes", 1, "epochs", 0, "start_size")}
    ): 5000.0,
    frozenset(
        {("demes", 2, "epochs", 0, "end_size"), ("demes", 2, "epochs", 0, "start_size")}
    ): 5000.0,
    frozenset(
        {("demes", 3, "epochs", 0, "end_size"), ("demes", 3, "epochs", 0, "start_size")}
    ): 5000.0,
    frozenset(
        {("demes", 4, "epochs", 0, "end_size"), ("demes", 4, "epochs", 0, "start_size")}
    ): 5000.0,
    frozenset(
        {("demes", 5, "epochs", 0, "end_size"), ("demes", 5, "epochs", 0, "start_size")}
    ): 5000.0,
    # ('migrations', 0, 'rate'):0.0001,
    # ('migrations', 1, 'rate'):0.0001,
    # ('migrations', 2, 'rate'):0.0001,
    # ('migrations', 3, 'rate'):0.0001,
    # ('migrations', 4, 'rate'):0.0001,
    # ('migrations', 5, 'rate'):0.0001,
    # ('migrations', 6, 'rate'):0.0001,
    # ('migrations', 7, 'rate'):0.0001,
    frozenset(
        {
            ("demes", 0, "epochs", 0, "end_time"),
            ("demes", 1, "start_time"),
            ("demes", 2, "start_time"),
            ("demes", 3, "start_time"),
            ("demes", 4, "start_time"),
            ("demes", 5, "start_time"),
            ("migrations", 0, "start_time"),
            ("migrations", 1, "start_time"),
            ("migrations", 2, "start_time"),
            ("migrations", 3, "start_time"),
            ("migrations", 4, "start_time"),
            ("migrations", 5, "start_time"),
            ("migrations", 6, "start_time"),
            ("migrations", 7, "start_time"),
        }
    ): 4000.0,
}
# paths = {
#      ('migrations', 0, 'rate'):0.0001,
#      ('migrations', 1, 'rate'):0.0001,
#      ('migrations', 2, 'rate'):0.0001,
#      ('migrations', 3, 'rate'):0.0001,
#      ('migrations', 4, 'rate'):0.0001,
#      ('migrations', 5, 'rate'):0.0001,
#      ('migrations', 6, 'rate'):0.0001,
#      ('migrations', 7, 'rate'):0.0001,
#      }
path_order: List[Var] = list(paths)
cfg_mat, deme_names, unique_cfg, matching_indices = process_data(cfg_list)
num_samples = len(cfg_mat)
rho = theta = 1e-8 * 100
k = 2
iicr = IICRCurve(demo=g, k=2)
iicr_call = jax.jit(iicr.__call__)


def process_base_model(deme_names, cfg):
    curve = iicr.curve(num_samples=dict(zip(deme_names, cfg)))
    timepoints = jax.vmap(curve.quantile)(jnp.linspace(0, 1, 17)[1:-1])
    timepoints = jnp.insert(timepoints, 0, 0.0)
    return timepoints


vec = jnp.array(
    [
        5000.0,
        5000.0,
        5000.0,
        5000.0,
        5000.0,
        5000.0,
        0.0001,
        0.0001,
        0.0001,
        0.0001,
        0.0001,
        0.0001,
        0.0001,
        0.0001,
        4000.0,
    ]
)
params = _vec_to_dict_jax(vec, path_order)


def compute_loglik(c_index, data, c_map, times, theta, rho):
    c = c_map[c_index]
    t = times[c_index]
    eta = PiecewiseConstant(c=c, t=t)
    return loglik(data, eta, t, theta, rho)

## Time it takes to extract 16 time points from base model

In [11]:
%time times = jax.vmap(process_base_model, in_axes=(None, 0))(deme_names, jnp.array(unique_cfg))

CPU times: user 5.78 s, sys: 280 ms, total: 6.06 s
Wall time: 6.55 s


## Time it takes to evaluate c_map for 15 unique sampling configurations and each configuration has 16 time points

In [12]:
%time c_map = jax.vmap(lambda cfg, time: iicr_call(params=params, t=time, num_samples=dict(zip(deme_names, cfg)))["c"])(jnp.array(unique_cfg), times)

CPU times: user 11.1 ms, sys: 0 ns, total: 11.1 ms
Wall time: 11.2 ms


## Time it takes to evaluate the likelihood once

In [13]:
%time batched_loglik = vmap(compute_loglik, in_axes=(0, 0, None, None, None, None))(matching_indices, het_matrix, c_map, times, theta, rho)

CPU times: user 2.57 s, sys: 165 ms, total: 2.74 s
Wall time: 2.95 s


## Error for running out of memory. You just need to crank num_samples higher to get that error. 

In [5]:
# %time batched_loglik = vmap(compute_loglik, in_axes=(0, 0, None, None, None, None))(matching_indices, het_matrix, c_map, times, theta, rho)

W1017 15:18:50.616907 2929795 bfc_allocator.cc:501] Allocator (GPU_0_bfc) ran out of memory trying to allocate 2.98GiB (rounded to 3199996928)requested by op 
If the cause is memory fragmentation maybe the environment variable 'TF_GPU_ALLOCATOR=cuda_malloc_async' will improve the situation. 
Current allocation summary follows.
Current allocation summary follows.
W1017 15:18:50.617111 2929795 bfc_allocator.cc:512] ******______________________________________________________________________________________________


CPU times: user 264 ms, sys: 14 ms, total: 278 ms
Wall time: 10.4 s


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 3199996800 bytes.

## Time it takes to evaluate likelihood with everything put together

In [14]:
def neg_loglik(
    vec,
    path_order,
    unique_cfg,
    times,
    matching_indices,
    het_matrix,
    theta,
    rho,
    num_samples,
    deme_names,
    iicr_call,
):
    params = _vec_to_dict_jax(vec, path_order)
    jax.debug.print("Param values: {}", jnp.array(list(params.values())))

    c_map = jax.vmap(
        lambda cfg, time: iicr_call(
            params=params, t=time, num_samples=dict(zip(deme_names, cfg))
        )["c"]
    )(jnp.array(unique_cfg), times)

    # Batched over cfg_mat (matching_indices) and all_tmrca_spans (het_matrix)
    batched_loglik = vmap(compute_loglik, in_axes=(0, 0, None, None, None, None))(
        matching_indices, het_matrix, c_map, times, theta, rho
    )
    loss = -jnp.sum(batched_loglik) / num_samples
    jax.debug.print("Loss: {loss}", loss=loss)
    return loss

In [16]:
%time neg_loglikelihood = neg_loglik(vec, path_order, unique_cfg, times, matching_indices, jnp.array(het_matrix), theta, rho, num_samples, deme_names, iicr_call)

Param values: [5.e+03 5.e+03 5.e+03 5.e+03 5.e+03 5.e+03 1.e-04]
Loss: 118358.76040414307
CPU times: user 2.8 s, sys: 591 ms, total: 3.39 s
Wall time: 3.62 s


## Attempt to take the gradient

In [18]:
%time gradients = jax.grad(neg_loglik)(vec, path_order, unique_cfg, times, matching_indices, jnp.array(het_matrix), theta, rho, num_samples, deme_names, iicr_call)
gradients

Param values: [5.e+03 5.e+03 5.e+03 5.e+03 5.e+03 5.e+03 1.e-04]
Loss: 118358.7578125
CPU times: user 5.15 s, sys: 878 ms, total: 6.02 s
Wall time: 6.56 s


Array([ 7.52598936e-03, -1.20917500e-05, -5.29833800e-05,  0.00000000e+00,
       -1.00451765e-04, -1.63062744e-05,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00], dtype=float64)

## Gradient error when sample_size is too large

In [48]:
# %time gradients = jax.grad(neg_loglik)(vec, path_order, unique_cfg, times, matching_indices, jnp.array(het_matrix), theta, rho, num_samples, deme_names, iicr_call)

Param values: [5.e+03 5.e+03 5.e+03 5.e+03 5.e+03 5.e+03 1.e-04]


E1016 22:40:16.848600 2142686 gpu_hlo_schedule.cc:817] The byte size of input/output arguments (17351256104) exceeds the base limit (12696256512). This indicates an error in the calculation!
W1016 22:40:16.856940 2142686 hlo_rematerialization.cc:3198] Can't reduce memory use below 16.13GiB (17325120104 bytes) by rematerialization; only reduced to 27.33GiB (29345376264 bytes), down from 27.33GiB (29345456296 bytes) originally
W1016 22:40:27.479932 2142686 bfc_allocator.cc:501] Allocator (GPU_0_bfc) ran out of memory trying to allocate 2.23GiB (rounded to 2400000000)requested by op 
If the cause is memory fragmentation maybe the environment variable 'TF_GPU_ALLOCATOR=cuda_malloc_async' will improve the situation. 
Current allocation summary follows.
Current allocation summary follows.
W1016 22:40:27.480583 2142686 bfc_allocator.cc:512] *****__********____________****************************************************************_________
E1016 22:40:27.480635 2142686 pjrt_stream_executor_cl

CPU times: user 15.7 s, sys: 2.23 s, total: 17.9 s
Wall time: 31.1 s


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2400000000 bytes.