In [1]:
import importlib
import os
import sys
import time
from glob import glob

import emcee
import numpy as np
import pandas as pd
import tskit
from scipy import linalg

# Make a local terracotta checkout importable. Set the TERRACOTTA_PATH environment
# variable to point elsewhere; the default is the original machine-specific checkout.
sys.path.append(
    os.environ.get("TERRACOTTA_PATH", "/Users/jameskitchens/Documents/GitHub/terracotta")
)
import terracotta as tct

# Re-import terracotta so edits to the library are picked up when this cell is re-run.
importlib.reload(tct)


directory = "."

# Bin edges handed to tct.nx_bin_ts when loading each tree sequence
# (presumably node-time bins; confirm units against terracotta).
TIME_BIN_EDGES = [0, 10, 100, 1000, 10000, 100000, 1000000, 10000000]

demes = pd.read_csv(f"{directory}/demes_elev_two_type.tsv", sep="\t")

demes["type"] = 0  # ignoring elevation type: every deme is collapsed into type 0

samples = pd.read_csv(f"{directory}/samples_elev_two_type.tsv", sep="\t")

world_map = tct.WorldMap(demes, samples)

# glob() returns paths in arbitrary, filesystem-dependent order; sort them so the order of
# `trees` (and anything indexed by it, e.g. trees[0]) is reproducible across runs/machines.
tree_paths = sorted(glob(f"{directory}/trees/*"))
trees = [
    tct.nx_bin_ts(tskit.load(path).simplify(), TIME_BIN_EDGES).first()
    for path in tree_paths
]
# Unbinned alternative: trees = [tskit.load(path).first() for path in tree_paths]

# Per-tree tuple-list representations (child lists, branch-above lists, roots).
cl = []
bal = []
r = []
for tree in trees:
    child_list, branch_above_list, roots = tct.convert_tree_to_tuple_list(tree)
    cl.append(child_list)
    bal.append(branch_above_list)
    r.append(roots)

# Distinct integer branch lengths across all trees, sorted ascending (np.unique sorts).
# Collected directly rather than into a preallocated buffer: tree.nodes() visits
# num_edges + num_roots nodes, so a buffer sized num_edges + 1 per tree overflows
# (IndexError) as soon as a tree has more than one root.
branch_lengths = np.unique(
    np.array(
        [int(tree.branch_length(node)) for tree in trees for node in tree.nodes()],
        dtype="int64",
    )
)

In [7]:
%%time
# Log-likelihood of all `trees` at a single migration rate: 0.02 for deme type 0, the only
# type present after the first cell sets demes["type"] = 0.
# The saved output is a tuple whose second element holds one value per tree; the first
# element appears to be their sum. The numbers printed above it are presumably per-tree
# timings emitted by terracotta — confirm.
tct.calc_migration_rate_log_likelihood(
    world_map=world_map,
    trees=trees,
    migration_rates={0:0.02},
    branch_lengths=branch_lengths
)

- 14.28492021560669
- 8.929627180099487
- 7.975739002227783
- 10.028924226760864
- 9.265682220458984
- 8.086036205291748
- 7.318737030029297
- 6.650225877761841
- 6.475908041000366
- 7.014204263687134
- 9.571678161621094
- 6.973972797393799
- 7.880432844161987
- 8.16109299659729
- 7.587339162826538
- 7.101822137832642
- 6.790343999862671
- 7.162989139556885
- 6.398159027099609
- 7.388814926147461
- 7.99402117729187
- 7.229645013809204
- 7.735358953475952
- 7.531213045120239
- 10.122698068618774
- 8.991009950637817
- 7.660728216171265
- 5.058655023574829
- 7.708830833435059
- 9.2101149559021
- 10.277617931365967
- 12.847586154937744
- 13.09038496017456
- 13.41930603981018
- 9.552277088165283
- 9.222786903381348
- 9.125576257705688
- 8.191262006759644
- 7.854036808013916
CPU times: user 1h 9min 25s, sys: 8min 2s, total: 1h 17min 28s
Wall time: 7min 57s


(-756115.7311965061,
 [-21038.478953005462,
  -19288.964438964806,
  -21510.936503161254,
  -19846.95199720352,
  -16631.551108706783,
  -19992.65665046741,
  -15771.348433841966,
  -23094.546570679442,
  -18622.594000619694,
  -17764.24220095693,
  -19258.25405925385,
  -19974.66662721844,
  -15658.399569752892,
  -15676.588070602022,
  -21210.252177049053,
  -19731.64719610874,
  -24226.363345793485,
  -19287.23688063402,
  -25304.0573718476,
  -14844.889822315065,
  -18009.01452039574,
  -21131.42460745062,
  -18540.43290022818,
  -17576.315166879813,
  -16719.865084056164,
  -21777.493679726624,
  -21915.58330190281,
  -26533.10320427965,
  -20014.049198899993,
  -17358.31591528033,
  -22715.784173004344,
  -17117.452285288808,
  -16174.619122574808,
  -15554.676128060066,
  -21421.014671444667,
  -18301.03555796306,
  -18836.393241978858,
  -18652.404271583124,
  -19032.128187325965])

In [2]:
# Tuple-list form of the first tree only, used by the single-tree check with
# tct.calc_tree_log_likelihood_new further down. This is the same call the first cell makes
# for every tree, so cl[0], bal[0], r[0] presumably hold equal values.
child_list, branch_above_list, roots  = tct.convert_tree_to_tuple_list(trees[0])

In [3]:
# Uses `branch_lengths` from the first cell (distinct integer branch lengths, sorted
# ascending by np.unique) instead of re-deriving it with a copy-pasted loop here.
# The incremental scheme below requires ascending order: a negative step would make
# np.linalg.matrix_power invert the matrix instead of advancing it.
assert np.all(np.diff(branch_lengths) >= 0), "branch_lengths must be sorted ascending"

transition_matrix = world_map.build_transition_matrix(migration_rates={0:0.02})
# Transition probabilities for one unit of branch length: expm of what is presumably an
# instantaneous rate matrix. Tiny negative entries from numerical error are clipped to 0.
exponentiated = linalg.expm(transition_matrix)
exponentiated[exponentiated < 0] = 0

start = time.perf_counter()
num_demes = len(world_map.demes)
precomputed_transitions = np.zeros((len(branch_lengths), num_demes, num_demes), dtype="float64")
precomputed_log = np.zeros_like(precomputed_transitions)
previous_length = None
previous_mat = None
for counter, bl in enumerate(branch_lengths):
    if previous_length is not None:
        # exp(Q)^bl = exp(Q)^previous_length @ exp(Q)^(bl - previous_length), so only the
        # increment needs a fresh matrix power.
        where_next = np.dot(previous_mat, np.linalg.matrix_power(exponentiated, bl - previous_length))
    else:
        where_next = np.linalg.matrix_power(exponentiated, bl)
    # Clip only the stored copy so log() stays finite; the unclipped product is carried
    # forward to the next increment.
    precomputed_transitions[counter] = where_next
    precomputed_transitions[counter][precomputed_transitions[counter] <= 0] = 1e-99
    # Stored transposed — presumably the layout calc_tree_log_likelihood_new expects; confirm.
    precomputed_log[counter] = np.log(precomputed_transitions[counter]).T
    previous_length = bl
    previous_mat = where_next
print(time.perf_counter() - start)

1.3260462284088135


In [4]:
# Flatten world_map.sample_location_vectors into two parallel arrays, in its iteration
# order: one row of per-deme location values per sample, plus that sample's id.
location_lookup = world_map.sample_location_vectors
num_samples = len(location_lookup)
sample_location_vectors = np.zeros((num_samples, len(world_map.demes)), dtype="float64")
sample_ids = np.zeros(num_samples, dtype="int64")
for row, sample_id in enumerate(location_lookup):
    sample_ids[row] = sample_id
    sample_location_vectors[row] = location_lookup[sample_id]

In [8]:
# Single-tree timing check: log-likelihood of the first tree with
# tct.calc_tree_log_likelihood_new, reusing the transition tables precomputed above.
# NOTE(review): in the saved outputs this gives -12478.40, while the first per-tree value
# from calc_migration_rate_log_likelihood (same 0.02 migration rate) is -21038.48. Confirm
# whether that per-tree list is in `trees` order and whether the two implementations agree.
start = time.perf_counter()
tree_log_likelihood = tct.calc_tree_log_likelihood_new(
    child_list=child_list,
    branch_above_list=branch_above_list,
    roots=roots,
    sample_ids=sample_ids,
    sample_location_vectors=sample_location_vectors,
    branch_lengths=branch_lengths,
    precomputed_transitions=precomputed_transitions,
    precomputed_log=precomputed_log
)
elapsed = time.perf_counter() - start  # time the call only, not the printing
print(tree_log_likelihood)
print(elapsed)

(-12478.400057058894, array([-12478.40005706]))
2.7816619873046875
