In [1]:
import pandas as pd
import numpy as np
from glob import glob
import tskit
import sys
sys.path.append("/Users/jameskitchens/Documents/GitHub/terracotta")
import terracotta as tct
import importlib
importlib.reload(tct)
import time
from scipy.optimize import minimize
from numba import njit, guvectorize, int32
from scipy import linalg
from scipy.special import logsumexp

In [24]:
directory = "."

demes = pd.read_csv(f"{directory}/demes_elev_two_type.tsv", sep="\t")

demes["type"] = 0 #ignoring elevation type

samples = pd.read_csv(f"{directory}/samples_elev_two_type.tsv", sep="\t")

world_map = tct.WorldMap(demes, samples)

#trees = [tct.nx_bin_ts(tskit.load(ts).simplify(), [0, 10, 100, 1000, 10000, 100000, 1000000, 10000000]).first() for ts in glob(f"{directory}/trees/*")]
trees = [tskit.load(ts).simplify().first() for ts in glob(f"{directory}/trees/*")]
print(len(trees))

migration_rates = {0:0.025}
transition_matrix = world_map.build_transition_matrix(migration_rates=migration_rates)
exponentiated = linalg.expm(transition_matrix)
exponentiated[exponentiated < 0] = 0

39


In [25]:
%%time
total_number_of_edges = 0
for tree in trees:
    total_number_of_edges += tree.num_edges+1
branch_lengths = np.zeros(total_number_of_edges, dtype="int64")
edge_counter = 0
for tree in trees:
    for node in tree.nodes(order="timeasc"):
        branch_lengths[edge_counter] = int(tree.branch_length(node))
        edge_counter += 1
branch_lengths = np.unique(np.array(branch_lengths))

CPU times: user 120 ms, sys: 5.6 ms, total: 125 ms
Wall time: 124 ms


In [28]:
%%time
previous_length = None
previous_mat = None
precomputed_transitions_new = {}
for bl in branch_lengths:
    if previous_length != None:
        diff = bl - previous_length
        where_next = np.dot(previous_mat, np.linalg.matrix_power(exponentiated, diff))
    else:
        where_next = np.linalg.matrix_power(exponentiated, bl)
    precomputed_transitions_new[bl] = where_next
    precomputed_transitions_new[bl][precomputed_transitions_new[bl] <= 0] = 1e-99
    previous_length = bl
    previous_mat = where_next

CPU times: user 14min 35s, sys: 47.6 s, total: 15min 22s
Wall time: 1min 21s


In [31]:
%%time
precomputed_transitions = {}
for bl in branch_lengths:
    where_next = np.linalg.matrix_power(exponentiated, bl)
    where_next[where_next <= 0] = 1e-99
    precomputed_transitions[bl] = where_next

CPU times: user 53min 52s, sys: 1min 28s, total: 55min 21s
Wall time: 4min 47s


In [30]:
%%time
precomputed_transitions_old = {}
for tree in trees:
    for node in tree.nodes(order="timeasc"):
        bl = int(tree.branch_length(node))
        if bl not in precomputed_transitions_old:
            where_next = np.linalg.matrix_power(exponentiated, bl)
            where_next[where_next <= 0] = 1e-99
            precomputed_transitions_old[bl] = where_next

CPU times: user 53min 45s, sys: 1min 18s, total: 55min 3s
Wall time: 4min 47s


In [9]:
def _calc_tree_log_likelihood(tree, sample_location_vectors, transition_matrix=None, precomputed_transitions=None):
    """Calculates the log_likelihood of the tree using Felsenstein's Pruning Algorithm.

    NOTE: Assumes that samples are always tips on the tree.
    NOTE: Ignores samples that are completely detached from the tree(s).
    NOTE: Parent of sample cannot have the same time as sample.

    Parameters
    ----------
    tree : tskit.Tree
        This is a tree taken from the tskit.TreeSequence.
    sample_location_vectors : dict
        Contains all of the location vectors for the samples
    transition_matrix : np.matrix
        Instantaneous migration rate matrix between demes

    Returns
    -------
    tree_likelihood : float
        likelihood of the tree (product of the root likelihoods)
    root_log_likes : list
        List of root likelihoods (sum of the root locations vector)
    """

    if precomputed_transitions == None:
        if not isinstance(transition_matrix, np.ndarray):
            raise RuntimeError("Must provide either a transition matrix or precomputed transitions.")
        else:
            precomputed_transitions = {}
            for node in tree.nodes(order="timeasc"):
                bl = tree.branch_length(node)
                if bl not in precomputed_transitions:
                    where_next = linalg.expm(transition_matrix*bl)
                    if np.any(where_next <= 0):
                        where_next[where_next <= 0] = 1e-99
                    precomputed_transitions[bl] = where_next
   
    #log_messages = {l:calculate_first_messages(sample_location_vectors[l], precomputed_transitions[tree.branch_length(l)]) for l in sample_location_vectors if tree.branch_length(l)>0}
    log_messages = {}
    for l in sample_location_vectors:
        bl = int(tree.branch_length(l))
        if bl > 0:
            log_messages[l] = np.log(np.dot(sample_location_vectors[l], precomputed_transitions[bl])) #np.log(np.matmul(sample_location_vectors[l], precomputed_transitions[bl]))
    
    for node in tree.nodes(order="timeasc"):
        children = tree.children(node)
        if len(children) > 0:
            incoming_log_messages = []
            for child in children:
                incoming_log_messages.append(log_messages[child])
            summed_log_messages = np.sum(incoming_log_messages, axis=0)
            bl = int(tree.branch_length(node))
            if bl > 0:
                outgoing_log_message = np.array([logsumexp(np.log(precomputed_transitions[bl]).T + summed_log_messages, axis=1)])
            else:
                outgoing_log_message = summed_log_messages
            log_messages[node] = outgoing_log_message
    roots = tree.roots
    root_log_likes = [logsumexp(log_messages[r]) for r in roots if r not in sample_location_vectors]
    tree_likelihood = sum(root_log_likes)
    return tree_likelihood, root_log_likes

print(_calc_tree_log_likelihood(tree=trees[0], sample_location_vectors=world_map.sample_location_vectors, precomputed_transitions=precomputed_transitions))
print(_calc_tree_log_likelihood(tree=trees[0], sample_location_vectors=world_map.sample_location_vectors, precomputed_transitions=precomputed_transitions_new))

(np.float64(-12445.65279415648), [np.float64(-12445.65279415648)])
(np.float64(-12445.65279415648), [np.float64(-12445.65279415648)])


In [10]:
for i in precomputed_transitions:
    print(np.sum(precomputed_transitions[i] - precomputed_transitions_new[i]))

0.0
-8.890301229800124e-97
6.5233403819991486e-15
-3.208401792303294e-14
-2.1211372623202385e-14
-2.878881318122407e-14
-8.614466752270221e-15
-8.644452341444175e-15
1.9612729409285157e-14
-2.06304699885973e-14
-3.181266014506967e-15
1.1797637519683768e-14
4.449110350968066e-14
5.482723650085397e-14
2.504311862050468e-14
5.047871842744911e-14


In [93]:
%%time
branch_lengths = []
for tree in trees:
    for node in tree.nodes(order="timeasc"):
        bl = int(tree.branch_length(node))
        if bl not in branch_lengths:
            branch_lengths.append(bl)

CPU times: user 242 ms, sys: 2.24 ms, total: 244 ms
Wall time: 243 ms


In [120]:
%%time
precomputed_transitions = {}
for tree in trees[:2]:
    for node in tree.nodes(order="timeasc"):
        bl = int(tree.branch_length(node))
        if bl not in precomputed_transitions:
            where_next = np.linalg.matrix_power(exponentiated, bl) #forces branch lengths to be integer. Could be an issue for bl<1
            where_next[where_next <= 0] = 1e-99
            precomputed_transitions[bl] = where_next

CPU times: user 15min 10s, sys: 24.7 s, total: 15min 34s
Wall time: 1min 30s


In [None]:
CPU times: user 8min 9s, sys: 9.31 s, total: 8min 18s
Wall time: 45.8 s

In [120]:
def calculate_first_messages(loc_vec, transition_prob):
    return np.log(np.dot(loc_vec, np.exp(transition_prob.T)))
    
    
def _calc_tree_log_likelihood(tree, sample_location_vectors, transition_matrix=None, precomputed_transitions=None):
    """Calculates the log_likelihood of the tree using Felsenstein's Pruning Algorithm.

    NOTE: Assumes that samples are always tips on the tree.
    NOTE: Ignores samples that are completely detached from the tree(s).
    NOTE: Parent of sample cannot have the same time as sample.

    Parameters
    ----------
    tree : tskit.Tree
        This is a tree taken from the tskit.TreeSequence.
    sample_location_vectors : dict
        Contains all of the location vectors for the samples
    transition_matrix : np.matrix
        Instantaneous migration rate matrix between demes

    Returns
    -------
    tree_likelihood : float
        likelihood of the tree (product of the root likelihoods)
    root_log_likes : list
        List of root likelihoods (sum of the root locations vector)
    """

    if precomputed_transitions == None:
        if not isinstance(transition_matrix, np.ndarray):
            raise RuntimeError("Must provide either a transition matrix or precomputed transitions.")
        else:
            precomputed_transitions = {}
            for node in tree.nodes(order="timeasc"):
                bl = tree.branch_length(node)
                if bl not in precomputed_transitions:
                    where_next = linalg.expm(transition_matrix*bl)
                    if np.any(where_next <= 0):
                        where_next[where_next <= 0] = 1e-99
                    precomputed_transitions[bl] = where_next
   
    #log_messages = {l:calculate_first_messages(sample_location_vectors[l], precomputed_transitions[tree.branch_length(l)]) for l in sample_location_vectors if tree.branch_length(l)>0}
    log_messages = {}
    for l in sample_location_vectors:
        bl = tree.branch_length(l)
        if bl > 0:
            log_messages[l] = np.log(np.dot(sample_location_vectors[l], precomputed_transitions[bl])) #np.log(np.matmul(sample_location_vectors[l], precomputed_transitions[bl]))

    for node in tree.nodes(order="timeasc"):
        children = tree.children(node)
        if len(children) > 0:
            incoming_log_messages = []
            for child in children:
                incoming_log_messages.append(log_messages[child])
            summed_log_messages = np.sum(incoming_log_messages, axis=0)
            bl = tree.branch_length(node)
            if bl > 0:
                outgoing_log_message = np.array([logsumexp(np.log(precomputed_transitions[bl]).T + summed_log_messages, axis=1)])
            else:
                outgoing_log_message = summed_log_messages
            log_messages[node] = outgoing_log_message
    roots = tree.roots
    root_log_likes = [logsumexp(log_messages[r]) for r in roots if r not in sample_location_vectors]
    tree_likelihood = sum(root_log_likes)
    return tree_likelihood, root_log_likes

In [121]:
%time _calc_tree_log_likelihood(tree=trees[0], sample_location_vectors=world_map.sample_location_vectors, transition_matrix=world_map.build_transition_matrix(migration_rates=migration_rates))

CPU times: user 15min 7s, sys: 28.9 s, total: 15min 36s
Wall time: 1min 45s


(np.float64(9.24395820829972), [np.float64(9.24395820829972)])

In [118]:
%time _calc_tree_log_likelihood(tree=trees[0], sample_location_vectors=world_map.sample_location_vectors, transition_matrix=world_map.build_transition_matrix(migration_rates=migration_rates))

CPU times: user 14min 16s, sys: 20.4 s, total: 14min 36s
Wall time: 1min 28s


(np.float64(-12730.671311846883), [np.float64(-12730.671311846883)])

In [111]:
%time _calc_tree_log_likelihood(tree=trees[0], sample_location_vectors=world_map.sample_location_vectors, precomputed_transitions=precomputed_transitions)

CPU times: user 2min 58s, sys: 6.45 s, total: 3min 4s
Wall time: 20.7 s


(np.float64(-12445.65279415648), [np.float64(-12445.65279415648)])

In [103]:
%time _calc_tree_log_likelihood(tree=trees[0], sample_location_vectors=world_map.sample_location_vectors, precomputed_transitions=precomputed_transitions)

CPU times: user 28.3 s, sys: 496 ms, total: 28.8 s
Wall time: 17.1 s


(np.float64(-12730.79775661428), [np.float64(-12730.79775661428)])

In [3]:
initial_mr = np.array([np.random.uniform(0, 0.02) for connection_type in world_map.connections.type.unique()])
minimize(tct.optimize.run, initial_mr, args=(world_map, trees[:1]), method="BFGS", options={"disp":True})

  return _core_matmul(x1, x2)
  log_messages[l] = np.log(np.matmul(sample_location_vectors[l], precomputed_transitions[bl]))
  eAw = eAw @ eAw
  eAw = eAw @ eAw
  return _core_matmul(x1, x2)


KeyboardInterrupt: 

In [None]:
#previous_length = None
#previous_mat = None
#precomputed_transitions = {}
#for bl in branch_lengths:
#    if previous_length != None:
#        diff = bl - previous_length
#        where_next = np.dot(previous_mat, np.linalg.matrix_power(exponentiated, diff))
#    else:
#        where_next = np.linalg.matrix_power(exponentiated, bl)
#    precomputed_transitions[bl] = where_next
#    previous_length = bl
#    previous_mat = where_next



#precomputed_transitions = {}
#for bl in branch_lengths:
#    where_next = np.linalg.matrix_power(exponentiated, bl)
#    where_next[where_next <= 0] = 1e-99
#    precomputed_transitions[bl] = where_next

#np.linalg.matrix_power(exponentiated, branch_lengths[0])


#@njit()
#def raise_to_power(bl):
#    where_next = np.linalg.matrix_power(exponentiated, bl) #forces branch lengths to be integer. Could be an issue for bl<1
#    where_next[where_next <= 0] = 1e-99
#    return where_next

#for bl in branch_lengths:
#    raise_to_power(bl)