In [1]:
import sys
sys.path.append("..")

import numpy as np
import tensorflow as tf
from tcmc.nwk_utils import nwk_read
import tcmc.tcmc as tcmc
import tensorflow_probability as tfp

# without the next line, tensorflow-2.0.0, gives an error on 
# float32 versus float64 incompatibility
tf.keras.backend.set_floatx('float64')

In [2]:
s = 4 # alphabet size
n = 5 # number of sequences
k = 4 # number of ancestral nodes in tree
M = 2 # number of models / rate matrices

### Example tree
![](example_tree.png "200px")

In [3]:
T = np.full((n + k, n + k), 0, dtype = np.float64)
T[0,5] = 0.1
T[1,5] = 0.1
T[5,6] = 0.1
T[2,6] = 0.2
T[3,7] = 0.3
T[4,7] = 0.3
T[6,8] = 0.2
T[7,8] = 0.1

In [6]:
inputs = tf.keras.Input(shape = (n, s, ), dtype = tf.float64)
p = tcmc.TCMCProbability(M, T, n)
xprime = p(inputs)
theta = p.get_weights()
print("weights:", theta)
print("\nrates:", theta[0].shape, "\tsums =", np.sum(theta[0], axis = -1))
print("pi:", theta[1].shape, "\tsums =", np.sum(theta[1], axis = -1))


weights: [array([[0.15399915, 0.1205418 , 0.27481729, 0.14334083, 0.18599185,
        0.12130907],
       [0.23395279, 0.12189349, 0.12704466, 0.14470567, 0.2144538 ,
        0.1579496 ]]), array([[0.25312674, 0.2512489 , 0.22999236],
       [0.2260969 , 0.29365459, 0.23431882]])]

rates: (2, 6) 	sums = [0.99999999 1.00000001]
pi: (2, 3) 	sums = [0.734368   0.75407031]


In [5]:
# contract edges adjacent to root
n = 5
k = 3
s = 4
M = 2

T_ex = np.full((n+k,n+k), 0, dtype=np.float64)
T_ex[0,5] = .01
T_ex[1,5] = .2
T_ex[3,6] = .3
T_ex[4,6] = .02
T_ex[3,7] = .3
T_ex[5,7] = .15
T_ex[6,7] = .05


R = np.array([[1,2,1,1,2,1], [1,1,1,1,1,1]]) # Kimura 80 model versus Jukes-Kantor 
R_inv = R ** .5
pi = np.array([[0.4,0.15,0.05,0.4], [.25,.25,.25,.25]])
pi_inv = tcmc.stereographic_projection(pi ** .5)
x_test = np.array([
    [
        [1,0,0,0],
        [0,1,0,0],
        [0,1,0,0],
        [0,1,0,0],
        [0,0,1,0]
    ],
], dtype=np.float64)

in_ex = tf.keras.Input(shape=(n,s,), dtype=tf.float64)
p_ex = tcmc.TCMCProbability(M, T_ex, n)
#p_ex(in_ex)
#p_ex.set_weights([R_inv,pi_inv])
#p.get_weights()

p_ex(x_test)

<tf.Tensor: id=1992, shape=(1, 2), dtype=float64, numpy=array([[7.84516397e-07, 4.19774086e-07]])>