In [5]:
import msprime as msp
import demes
import demesdraw

demo = msp.Demography()
demo.add_population(initial_size = 1e4, name = "anc")
demo.add_population(initial_size = 1e4, name = "P0")
demo.add_population(initial_size = 1e4, name = "P1")
demo.set_symmetric_migration_rate(populations=("P0", "P1"), rate=0.01)
tmp = [f"P{i}" for i in range(2)]
demo.add_population_split(time = 1000, derived=tmp, ancestral="anc")
g = demo.to_demes()
sample_size = 10
samples = {f"P{i}": sample_size for i in range(2)}
anc = msp.sim_ancestry(samples=samples, demography=demo, recombination_rate=1e-8, sequence_length=1e7)
ts = msp.sim_mutations(anc, rate=1e-8)

In [None]:
from momi3 import Momi3
import numpy as np
momi_object = Momi3(g).iicr(2)
params = [("demes", 0, "epochs", 0, "start_size")]
f, x = momi_object.reparameterize(list(params))
parameters = list(x.keys())
x

{frozenset({('demes', 0, 'epochs', 0, 'end_size'),
            ('demes',
             0,
             'epochs',
             0,
             'start_size')}): Array(10000., dtype=float64)}

In [9]:
from momi3.jsfs import JSFS
from momi3.momi import Momi3
import jax
momi_sfs_object = Momi3(g).sfs({'P0':20, 'P1':20})
afs = ts.allele_frequency_spectrum(sample_sets=[ts.samples([1]), ts.samples([2])], span_normalise=False)
jsfs = JSFS.from_dense(afs, ["P0", "P1"])

In [None]:
from momi3 import Momi3
import numpy as np
params = [("demes", 0, "epochs", 0, "start_size")]
f, x = momi_sfs_object.reparameterize(list(params))
parameters = list(x.keys())

In [10]:
from jax import vmap
import jax.numpy as jnp
x_values = jnp.linspace(5000, 20000, 100)  # adjust these steps as needed

# Vectorize the likelihood computation over x_values
def compute_likelihood(val):
    updated_x = x.copy()
    updated_x[parameters[0]] = val
    params = updated_x
    return momi_sfs_object.loglik(params, jsfs)

# Use vmap to compute likelihoods for all x_values
likelihoods = vmap(compute_likelihood)(x_values)

[32m2025-06-12 15:17:20.334[0m | [34m[1mDEBUG   [0m | [36mmomi3.sfs.migration[0m:[36mlift_cm[0m:[36m117[0m - [34m[1musing diffeq solver for {'axes': OrderedDict({'P0': 21, 'P1': 21}), 'drift': {'P0': BCOO(float64[21, 21], nse=57), 'P1': BCOO(float64[21, 21], nse=57)}, 'mig': {('P0', 'P0'): ({0: BCOO(float64[21, 21], nse=64), 1: BCOO(float64[21, 21], nse=60)}, {1: BCOO(float64[21, 21], nse=40)}), ('P0', 'P1'): ({0: BCOO(float64[21, 21], nse=64), 1: BCOO(float64[21, 21], nse=60)}, {1: BCOO(float64[21, 21], nse=40)}), ('P1', 'P0'): ({0: BCOO(float64[21, 21], nse=64), 1: BCOO(float64[21, 21], nse=60)}, {1: BCOO(float64[21, 21], nse=40)}), ('P1', 'P1'): ({0: BCOO(float64[21, 21], nse=64), 1: BCOO(float64[21, 21], nse=60)}, {1: BCOO(float64[21, 21], nse=40)})}, 'mut': {'P0': BCOO(float64[21, 21], nse=61), 'P1': BCOO(float64[21, 21], nse=61)}}[0m


AttributeError: jax.tree_map was removed in JAX v0.6.0: use jax.tree.map (jax v0.4.25 or newer) or jax.tree_util.tree_map (any JAX version).

In [None]:
import matplotlib.pyplot as plt
# Plot
plt.figure(figsize=(10, 6))
plt.plot(x_values, likelihoods, label='Likelihood')
plt.xlabel('x (parameter values)')
plt.ylabel('Debugger Likelihood')
plt.title('Debugger likelihood over parameters')
plt.legend()
plt.grid(True)
plt.show()