In [1]:
import pandas as pd
import numpy as np
from glob import glob
import tskit
import sys
sys.path.append("/Users/jameskitchens/Documents/GitHub/terracotta")
import terracotta as tct
import importlib
importlib.reload(tct)
import time
import emcee
from scipy import linalg


# Read the demes and samples tables, build the WorldMap from them, and load
# the first tree of the first tree-sequence file found under `directory`/trees.
directory = "."

demes = pd.read_csv(f"{directory}/demes_elev_two_type.tsv", sep="\t")

demes["type"] = 0 #ignoring elevation type

samples = pd.read_csv(f"{directory}/samples_elev_two_type.tsv", sep="\t")

world_map = tct.WorldMap(demes, samples)

# Second argument to tct.nx_bin_ts (presumably bin boundaries — confirm
# against the terracotta implementation).
bin_edges = [0, 10, 100, 1000, 10000, 100000, 1000000, 10000000]

# glob() returns paths in arbitrary, filesystem-dependent order, so sort
# before slicing; otherwise "the first file" can differ between runs/machines.
tree_paths = sorted(glob(f"{directory}/trees/*"))[:1]

trees = [
    tct.nx_bin_ts(tskit.load(path).simplify(), bin_edges).first()
    for path in tree_paths
]

In [2]:
# Apply tct.convert_tree_to_tuple_list to every loaded tree, preserving the
# order of `trees`.
trees_list = list(map(tct.convert_tree_to_tuple_list, trees))

In [34]:
# Collect the distinct integer branch lengths occurring across all trees.
# The lengths are gathered into a list instead of a buffer pre-sized to
# num_edges + 1 per tree: an undersized buffer raises IndexError while an
# oversized one leaves spurious zeros in the result.
observed_lengths = [
    int(tree.branch_length(node))  # int() drops any fractional part
    for tree in trees
    for node in tree.nodes(order="timeasc")
]
# np.unique returns the sorted distinct values; the explicit dtype keeps the
# result int64 even when no nodes were visited.
branch_lengths = np.unique(np.array(observed_lengths, dtype="int64"))

# Exponentiate the world map's transition matrix and precompute the matrix
# power of that exponential for every entry of `branch_lengths`, printing the
# elapsed time of the precomputation.
transition_matrix = world_map.build_transition_matrix(migration_rates={0:0.02})
exponentiated = linalg.expm(transition_matrix)
# Clamp negative entries of the matrix exponential to zero.
exponentiated[exponentiated < 0] = 0

# perf_counter is a monotonic clock intended for measuring intervals.
start = time.perf_counter()
num_demes = len(world_map.demes)
precomputed_transitions = np.zeros((len(branch_lengths), num_demes, num_demes), dtype="float64")
previous_length = None
previous_mat = None
for counter, bl in enumerate(branch_lengths):
    # Identity check: `!= None` on a NumPy scalar goes through NumPy's
    # comparison machinery instead of testing for the sentinel.
    if previous_length is None:
        where_next = np.linalg.matrix_power(exponentiated, bl)
    else:
        # Reuse the previous result: E**bl == E**previous_length @ E**(bl - previous_length)
        diff = bl - previous_length
        where_next = np.dot(previous_mat, np.linalg.matrix_power(exponentiated, diff))
    precomputed_transitions[counter] = where_next
    previous_length = bl
    previous_mat = where_next
print(time.perf_counter() - start)

1.1469879150390625


In [29]:
# Pack world_map.sample_location_vectors into two parallel arrays: row i of
# `sample_location_vectors` holds the vector stored under the key recorded in
# `sample_ids[i]`.
num_samples = len(world_map.sample_location_vectors)
sample_location_vectors = np.zeros((num_samples, len(world_map.demes)), dtype="float64")
sample_ids = np.zeros(num_samples, dtype="int64")
for row, sample_id in enumerate(world_map.sample_location_vectors):
    sample_location_vectors[row] = world_map.sample_location_vectors[sample_id]
    sample_ids[row] = sample_id

In [4]:
# Time tct.calc_tree_log_likelihood on the first tree and print its result
# followed by the elapsed seconds.
start = time.time()
reference_result = tct.calc_tree_log_likelihood(
    trees[0],
    world_map.sample_location_vectors,
    precomputed_transitions=precomputed_transitions,
)
print(reference_result)
print(time.time() - start)

# Same measurement for tct.calc_tree_log_likelihood_new, fed the tuple-list
# form of the first tree (second element wrapped in a one-item list).
start = time.time()
first_tuple_list = trees_list[0]
new_result = tct.calc_tree_log_likelihood_new(
    first_tuple_list[0],
    [first_tuple_list[1]],
    world_map.sample_location_vectors,
    precomputed_transitions=precomputed_transitions,
)
print(new_result)
print(time.time() - start)

(-12478.400057058894, [-12478.400057058894])
6.3912200927734375
(-12478.400057058894, [-12478.400057058894])
4.504210948944092
