In [109]:
import sys
sys.path.append("/Users/jameskitchens/Documents/GitHub/terracotta")
import terracotta as tct
import importlib
importlib.reload(tct)
import pandas as pd
import numpy as np
import tskit
from glob import glob
import emcee

In [110]:
demes = pd.read_csv("dataset/demes.tsv", sep="\t")
world_map = tct.WorldMap(demes)

sample_locations = pd.read_csv("dataset/samples.tsv", sep="\t")
sample_location_vectors = world_map.build_sample_location_vectors(sample_locations=sample_locations)

trees = [tskit.load(ts).first() for ts in glob("dataset/trees/*")]

In [111]:
def lnprior(migration_rates):
    for mr in migration_rates:
        if not(0.0001 < mr < 1):
            return -np.inf
    return 0.0

def lnprob(migration_rates, world_map, trees, sample_location_vectors):
    lp = lnprior(migration_rates)
    if not np.isfinite(lp):
        return -np.inf
    return lp + tct.calc_migration_rate_log_likelihood(
        world_map=world_map,
        trees=trees,
        sample_location_vectors=sample_location_vectors,
        migration_rates={i:mr for i,mr in enumerate(migration_rates)}
    )[0]

In [112]:
nwalkers = 20
niter = 2000
initial_mr = np.array([0.1 for connection_type in world_map.connections.type.unique()])
ndim = len(initial_mr)
p0 = [np.array(initial_mr) + 1e-2 * np.random.randn(ndim) for i in range(nwalkers)]

In [113]:
def main(p0, nwalkers, niter, ndim, lnprob, world_map, trees, sample_location_vectors):
    sampler = emcee.EnsembleSampler(nwalkers, ndim, lnprob, args=[world_map, trees, sample_location_vectors])

    print("Running burn-in...")
    p0, _, _ = sampler.run_mcmc(p0, 2000)
    sampler.reset()

    print("Running production...")
    pos, prob, state = sampler.run_mcmc(p0, niter)

    return sampler, pos, prob, state

sampler, pos, prob, state = main(p0, nwalkers, niter, ndim, lnprob, world_map, trees, sample_location_vectors)

Running burn-in...
emcee: Exception while calling your likelihood function:
  params: [0.06870819 0.00292712 0.0667091 ]
  args: [<terracotta.WorldMap object at 0x1840f9790>, [<tskit.trees.Tree object at 0x183ff4c80>, <tskit.trees.Tree object at 0x183ff62d0>, <tskit.trees.Tree object at 0x183cf92b0>, <tskit.trees.Tree object at 0x183fe01a0>, <tskit.trees.Tree object at 0x183fe0800>, <tskit.trees.Tree object at 0x183fe0ce0>, <tskit.trees.Tree object at 0x1840f9ac0>, <tskit.trees.Tree object at 0x1840fa150>, <tskit.trees.Tree object at 0x1840f97c0>, <tskit.trees.Tree object at 0x1840fa4e0>, <tskit.trees.Tree object at 0x1840fa5a0>, <tskit.trees.Tree object at 0x1840fbc80>, <tskit.trees.Tree object at 0x1840facf0>, <tskit.trees.Tree object at 0x1840fad50>, <tskit.trees.Tree object at 0x183ff2d20>, <tskit.trees.Tree object at 0x183ff30b0>, <tskit.trees.Tree object at 0x183ff1970>, <tskit.trees.Tree object at 0x183ff1e80>, <tskit.trees.Tree object at 0x183ff1c10>, <tskit.trees.Tree object a

Traceback (most recent call last):
  File "/opt/anaconda3/envs/terracotta/lib/python3.12/site-packages/emcee/ensemble.py", line 640, in __call__
    return self.f(x, *self.args, **self.kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/folders/q8/b10jdgls4xvcv3wz767kf6f80000gn/T/ipykernel_14478/332190548.py", line 11, in lnprob
    return lp + tct.calc_migration_rate_log_likelihood(
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jameskitchens/Documents/GitHub/terracotta/terracotta/__init__.py", line 296, in calc_migration_rate_log_likelihood
    log_likelihoods.append(_calc_tree_log_likelihood(tree, sample_location_vectors, transition_matrix)[0])
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/jameskitchens/Documents/GitHub/terracotta/terracotta/__init__.py", line 259, in _calc_tree_log_likelihood
    outgoing_log_message = np.array([logsumexp(np.log(linalg.expm(transition_ma

In [None]:
samples = sampler.flatchain
theta_max  = samples[np.argmax(sampler.flatlnprobability)]
print(theta_max)

In [None]:
import matplotlib.pyplot as plt
plt.plot(range(len(sampler.flatlnprobability)), sampler.flatlnprobability)
plt.show()