In [1]:
import sys
sys.path += ['/Users/cswa648/git/phylo-hacking/pymc']
import theano
import newick
from xml.etree import ElementTree

%matplotlib inline

data_xml_filename = '../../data/ratites.SRD06.RLC.YULE.xml'
xml_root = ElementTree.parse(data_xml_filename)
newick_string = xml_root.find('.//newick').text
tree = newick.loads(newick_string)[0]
print(tree.ascii_art())

                            ┌─ANDI
                     ┌──────┤
                     │      └─DIGI
              ┌──────┤
              │      │             ┌─CASS
              │      │      ┌──────┤
       ┌──────┤      └──────┤      └─EMU
       │      │             └─KIWI
───────┤      │
       │      └─RHEA
       └─OST


In [131]:
sequence_dict = { seq_element.find('./taxon').attrib['idref']: list(seq_element.itertext())[-1].strip()
 for seq_element in xml_root.findall('.//sequence') }

from importlib import reload
import pylo.topology
reload(pylo.topology)

from pylo.topology import TreeTopology
topology = TreeTopology(tree)

import numpy as np
from pylo.transform import group_sequences, encode_sequences
sequence_dict_encoded = encode_sequences(sequence_dict)
pattern_dict, pattern_counts = group_sequences(sequence_dict_encoded)
child_patterns = np.array(topology.build_sequence_table(pattern_dict))
pattern_counts = np.array(pattern_counts)

import theano.tensor as tt

child_patterns_ = tt.as_tensor_variable(child_patterns)
pattern_counts_ = tt.as_tensor_variable(pattern_counts)

In [132]:
import pylo.tree.coalescent
import pylo.pruning

reload(pylo.tree.coalescent)
reload(pylo.pruning)

import pymc3 as pm
from pylo.tree.coalescent import CoalescentTree, ConstantPopulationFunction
from pylo.hky import HKYSubstitutionModel
from pylo.pruning import LeafSequences

with pm.Model() as model:    
    population_size = pm.Gamma('population_size', alpha=2.0, beta=0.1)
    population_function = ConstantPopulationFunction(topology, population_size)
    tree_heights = CoalescentTree('tree', topology, population_function)
    branch_lengths = pm.Deterministic('branch_lengths', topology.get_child_branch_lengths(tree_heights))
    rates = pm.Lognormal('rates', shape=(topology.get_internal_node_count(), 2))
    
    distances = pm.Deterministic('distances', rates*branch_lengths)
    
    kappa = pm.Exponential('kappa', lam=0.1)
    pi = pm.Dirichlet('pi', a=np.ones(4))
    
    substitution_model = HKYSubstitutionModel(kappa, pi)
    sequences = LeafSequences('sequences', topology, substitution_model, distances, child_patterns_, pattern_counts_)
    
model

<pymc3.model.Model at 0x1c29a821d0>

In [14]:
for RV in model.basic_RVs:
    print(RV.name, RV.logp()(model.test_point))

population_size_log__ [2.22044605e-16]
tree_tree_height_proportion__ [-1.05698841 -0.40460858 -0.19809907  0.27117483 -0.14753601  2.25793952
 -2.61624531]
rates_log__ [-0. -0. -0. -0. -0. -0. -0. -0. -0. -0. -0. -0.]
kappa_log__ [0.30685282]
pi_stickbreaking__ [2.22044605e-16 1.11022302e-16 0.00000000e+00]


In [88]:
model.test_point

def random_point():
    return { key: np.random.normal(size=val.shape) for key, val in model.test_point.items() }

sequences_fn = model.fn(sequences)

def logps(point):
    return np.concatenate([[RV.logp(point) for RV in model.basic_RVs], [sequences_fn(point)]])

[logps(random_point()) for i in range(3)]

[array([ -2.73323126, -19.79893947, -14.22547334,  -1.21857168,
         -7.31402012,          nan]),
 array([  -6.42265503, -165.94033232,  -17.03279796,   -2.14978332,
          -4.48153137,           nan]),
 array([ -3.83118748, -22.3522723 , -14.78518453,  -3.6675207 ,
         -4.06782641,          nan])]

In [148]:
topology.get_node_child_leaf_mask()

array([[ True,  True],
       [ True,  True],
       [False,  True],
       [False, False],
       [False,  True],
       [False,  True]])

In [150]:
topology.get_init_heights()

array([ 0.,  0., 15., 10., 10., 45., 10., 55., 65., 20., 75., 10., 95.])

In [149]:
topology.child_indices[topology.node_mask]

array([[ 0,  1],
       [ 3,  4],
       [ 5,  6],
       [ 2,  7],
       [ 8,  9],
       [10, 11]])

In [154]:
topology.init_heights[topology.node_mask, np.newaxis] - topology.init_heights[topology.child_indices[topology.node_mask]]

array([[15., 15.],
       [35., 35.],
       [10., 45.],
       [50., 10.],
       [10., 55.],
       [20., 85.]])

In [161]:
topology.get_node_child_leaf_mask()

array([[ True,  True],
       [ True,  True],
       [False,  True],
       [False, False],
       [False,  True],
       [False,  True]])

In [160]:
topology.node_index_mapping[topology.child_indices[topology.node_mask]]

array([[-1, -1],
       [-1, -1],
       [ 1, -1],
       [ 0,  2],
       [ 3, -1],
       [ 4, -1]])

In [147]:
point = random_point()

sequences_fn(point), branch_lengths_fn(point)

(array(453999.21132449), array([[-10.58468291, -10.58468291],
        [ -9.497373  ,  -9.497373  ],
        [  1.01926153,  -8.47811147],
        [  4.49602817,   2.38945673],
        [  6.03634872,  -0.05230602],
        [  0.05230602,   0.        ]]))

In [44]:
import matplotlib.pyplot as plt

with model:
    fullrank = pm.ADVI() # Not full rank

fr_tracker = pm.callbacks.Tracker(
    mean=fullrank.approx.mean.eval,  # callable that returns mean
    std=fullrank.approx.std.eval  # callable that returns cov
)

approx_dim = np.sum([param.shape.eval() for param in fullrank.approx.params])
tolerance = 0.1*np.sqrt(approx_dim)
convergence = pm.callbacks.CheckParametersConvergence(tolerance=tolerance, diff='relative')

fit = fullrank.fit(n=100000, callbacks=[fr_tracker, convergence])

fig = plt.figure(figsize=(16, 9))
mu_ax = fig.add_subplot(221)
std_ax = fig.add_subplot(222)
hist_ax = fig.add_subplot(212)
mu_ax.plot(fr_tracker['mean'])
mu_ax.set_title('Mean track')
cov_data = np.stack(fr_tracker['std'])
std_ax.plot(cov_data.reshape((cov_data.shape[0], -1)))
std_ax.set_title('Cov track')
hist_ax.plot(fullrank.hist)
hist_ax.set_title('Negative ELBO track');

  0%|          | 0/100000 [00:00<?, ?it/s]


FloatingPointError: NaN occurred in optimization. 
The current approximation of RV `rates_log__`.ravel()[0] is NaN.
The current approximation of RV `rates_log__`.ravel()[1] is NaN.
The current approximation of RV `rates_log__`.ravel()[2] is NaN.
The current approximation of RV `rates_log__`.ravel()[3] is NaN.
The current approximation of RV `rates_log__`.ravel()[4] is NaN.
The current approximation of RV `rates_log__`.ravel()[5] is NaN.
The current approximation of RV `rates_log__`.ravel()[6] is NaN.
The current approximation of RV `rates_log__`.ravel()[7] is NaN.
The current approximation of RV `rates_log__`.ravel()[8] is NaN.
The current approximation of RV `rates_log__`.ravel()[9] is NaN.
The current approximation of RV `rates_log__`.ravel()[10] is NaN.
The current approximation of RV `rates_log__`.ravel()[11] is NaN.
The current approximation of RV `kappa_log__`.ravel()[0] is NaN.
The current approximation of RV `pi_stickbreaking__`.ravel()[0] is NaN.
The current approximation of RV `pi_stickbreaking__`.ravel()[1] is NaN.
The current approximation of RV `pi_stickbreaking__`.ravel()[2] is NaN.
The current approximation of RV `tree_tree_height_proportion__`.ravel()[0] is NaN.
The current approximation of RV `tree_tree_height_proportion__`.ravel()[1] is NaN.
The current approximation of RV `tree_tree_height_proportion__`.ravel()[2] is NaN.
The current approximation of RV `tree_tree_height_proportion__`.ravel()[3] is NaN.
The current approximation of RV `tree_tree_height_proportion__`.ravel()[4] is NaN.
The current approximation of RV `tree_tree_height_proportion__`.ravel()[5] is NaN.
Try tracking this parameter: http://docs.pymc.io/notebooks/variational_api_quickstart.html#Tracking-parameters

In [None]:
import arviz as az

trace = fit.sample(draws=1000)
inf_data = az.from_pymc3(trace=trace)

In [None]:
az.plot_joint(inf_data, ['branch_lengths', 'rates'])

In [None]:
log_weight = fit.sized_symbolic_logp - fit.symbolic_logq
log_weight

In [None]:
n_samples = 10000
log_weights = fit.set_size_and_deterministic(log_weight, n_samples, False).eval()
plt.hist(log_weights)

In [None]:
from scipy.special import logsumexp

log_weights_norm = (log_weights - logsumexp(log_weights))[:, np.newaxis]

plt.hist(log_weights_norm)

In [None]:
pm.stats._psislw(log_weights_norm, n_samples)