In [2]:
import sys
sys.path += ['/Users/cswa648/git/phylo-hacking/pymc']
import theano
import newick
from xml.etree import ElementTree

%matplotlib inline

data_xml_filename = '../../data/ratites.SRD06.RLC.YULE.xml'
xml_root = ElementTree.parse(data_xml_filename)
newick_string = xml_root.find('.//newick').text
tree = newick.loads(newick_string)[0]
print(tree.ascii_art())

                            ┌─ANDI
                     ┌──────┤
                     │      └─DIGI
              ┌──────┤
              │      │             ┌─CASS
              │      │      ┌──────┤
       ┌──────┤      └──────┤      └─EMU
       │      │             └─KIWI
───────┤      │
       │      └─RHEA
       └─OST


In [27]:
# Map each taxon id (the idref of the <taxon> child of every <sequence>) to its
# sequence string. The sequence is taken to be the last text fragment of the
# element, i.e. the text following the <taxon> child.
# NOTE(review): re-check this assumption if the XML layout changes.
sequence_dict = {}
for seq_element in xml_root.findall('.//sequence'):
    taxon_element = seq_element.find('./taxon')
    if taxon_element is None or 'idref' not in taxon_element.attrib:
        raise ValueError('<sequence> element without a <taxon idref=...> child')
    taxon_id = taxon_element.attrib['idref']
    text_fragments = list(seq_element.itertext())
    if not text_fragments:
        raise ValueError('<sequence> for taxon {} has no sequence text'.format(taxon_id))
    sequence_dict[taxon_id] = text_fragments[-1].strip()

# Every leaf of the tree needs a sequence. Presumably build_sequence_table
# below looks up each leaf's data, so catch a mismatch here with a readable error.
missing_taxa = set(tree.get_leaf_names()) - set(sequence_dict)
if missing_taxa:
    raise ValueError('No sequence for tree leaves: {}'.format(sorted(missing_taxa)))

from importlib import reload
import pylo.topology
reload(pylo.topology)  # pick up edits to the local pylo package without a kernel restart

from pylo.topology import TreeTopology
topology = TreeTopology(tree)

import numpy as np
from pylo.transform import group_sequences, encode_sequences
# Encode the sequences, group them into site patterns with their counts
# (presumably unique alignment columns and their multiplicities), and lay the
# patterns out per node of the topology.
sequence_dict_encoded = encode_sequences(sequence_dict)
pattern_dict, pattern_counts = group_sequences(sequence_dict_encoded)
child_patterns = np.array(topology.build_sequence_table(pattern_dict))
pattern_counts = np.array(pattern_counts)

import theano.tensor as tt

# Constant Theano tensors consumed by the likelihood (LeafSequences) in the model cell.
child_patterns_ = tt.as_tensor_variable(child_patterns)
pattern_counts_ = tt.as_tensor_variable(pattern_counts)

In [34]:
import pylo.tree.coalescent
import pylo.pruning

# Pick up edits to the local pylo modules without restarting the kernel.
reload(pylo.tree.coalescent)
reload(pylo.pruning)

import pymc3 as pm
from pylo.tree.coalescent import CoalescentTree, ConstantPopulationFunction
from pylo.hky import HKYSubstitutionModel
from pylo.pruning import LeafSequences

# Phylogenetic model: coalescent tree prior + per-branch rates + HKY likelihood.
# Keep the order of the random-variable definitions: registration order is the
# order in which variables appear in model.basic_RVs / the test point (see the
# logp printout in the next cell).
with pm.Model() as model:    
    # Tree prior: coalescent with a constant population size, Gamma(2, 0.1) prior.
    population_size = pm.Gamma('population_size', alpha=2.0, beta=0.1)
    population_function = ConstantPopulationFunction(topology, population_size)
    tree_heights = CoalescentTree('tree', topology, population_function)
    # Branch lengths derived from the node heights, recorded in the trace.
    branch_lengths = pm.Deterministic('branch_lengths', topology.get_child_branch_lengths(tree_heights))
    # Lognormal rates with shape (internal nodes, 2).
    # NOTE(review): presumably one rate per child branch of each internal node
    # of a binary tree; confirm this matches the shape of branch_lengths.
    rates = pm.Lognormal('rates', shape=(topology.get_internal_node_count(), 2))
    
    # Elementwise rate * branch length, recorded as 'distances'.
    distances = pm.Deterministic('distances', rates*branch_lengths)
    
    # HKY parameters: kappa (Exponential prior, lam=0.1) and base frequencies pi
    # (Dirichlet with all-ones concentration, i.e. uniform over the simplex).
    kappa = pm.Exponential('kappa', lam=0.1)
    pi = pm.Dirichlet('pi', a=np.ones(4))
    
    substitution_model = HKYSubstitutionModel(kappa, pi)
    # Alignment likelihood built from the per-node site patterns and their
    # counts prepared in the previous cell.
    sequences = LeafSequences('sequences', topology, substitution_model, distances, child_patterns_, pattern_counts_)
    
model

<pymc3.model.Model at 0x1c14d77630>

In [35]:
# Sanity check before fitting: the log-density of every basic random variable
# evaluated at the model's test point. A non-finite value here points at a bad
# starting value or prior.
test_point = model.test_point
for RV in model.basic_RVs:
    print(RV.name, RV.logp(test_point))

population_size_log__ -0.6137056388801092
tree_tree_height_proportion__ -17.79500663732926
rates_log__ -11.027262398456068
kappa_log__ -1.0596601002984285
pi_stickbreaking__ -3.753417975251508


In [None]:
import matplotlib.pyplot as plt

# Mean-field ADVI (diagonal Gaussian approximation). The variable keeps the name
# `fullrank` from an earlier full-rank version of this cell.
with model:
    fullrank = pm.ADVI() # Not full rank

# Record the variational mean and standard deviation at every iteration.
# For the mean-field approximation `approx.std` is the per-parameter std vector.
# `approx.cov` (used previously under the 'std' key) is the full covariance
# matrix, whose flattened plot is mostly zero off-diagonal series.
fr_tracker = pm.callbacks.Tracker(
    mean=fullrank.approx.mean.eval,  # callable that returns the current mean
    std=fullrank.approx.std.eval  # callable that returns the current std
)

# Convergence tolerance scales with sqrt(total number of variational parameter
# elements). Count elements via the product of each shape, so that parameters
# with more than one dimension are counted correctly.
approx_dim = sum(int(np.prod(param.shape.eval())) for param in fullrank.approx.params)
tolerance = 0.1*np.sqrt(approx_dim)
convergence = pm.callbacks.CheckParametersConvergence(tolerance=tolerance, diff='relative')

fit = fullrank.fit(n=100000, callbacks=[fr_tracker, convergence])

fig = plt.figure(figsize=(16, 9))
mu_ax = fig.add_subplot(221)
std_ax = fig.add_subplot(222)
hist_ax = fig.add_subplot(212)
mu_ax.plot(fr_tracker['mean'])
mu_ax.set_title('Mean track')
mu_ax.set_xlabel('Iteration')
std_data = np.stack(fr_tracker['std'])
std_ax.plot(std_data.reshape((std_data.shape[0], -1)))
std_ax.set_title('Std track')
std_ax.set_xlabel('Iteration')
hist_ax.plot(fullrank.hist)
hist_ax.set_xlabel('Iteration')
hist_ax.set_title('Negative ELBO track');

In [None]:
import arviz as az

# Draw samples from the fitted variational approximation and wrap them in an
# ArviZ InferenceData object for plotting.
N_APPROX_DRAWS = 1000
trace = fit.sample(draws=N_APPROX_DRAWS)
inf_data = az.from_pymc3(trace=trace)

In [None]:
# az.plot_joint needs exactly two *scalar* variables. 'branch_lengths' and
# 'rates' are array-valued, so passing them whole raises an error about the
# number of variables. Instead, plot one flattened element of each;
# JOINT_ELEMENT picks which one (0 = first element of each array).
# NOTE(review): if the two arrays share a shape (as rates*branch_lengths in the
# model suggests), the same index refers to the same branch, which shows the
# rate / branch-length trade-off. Confirm.
JOINT_ELEMENT = 0
joint_samples = {}
for var_name in ['branch_lengths', 'rates']:
    values = inf_data.posterior[var_name].values  # ArviZ layout: (chain, draw, *var_shape)
    flat_values = values.reshape(values.shape[0], values.shape[1], -1)
    joint_samples['{}[{}]'.format(var_name, JOINT_ELEMENT)] = flat_values[:, :, JOINT_ELEMENT]
az.plot_joint(joint_samples, var_names=list(joint_samples));

In [None]:
# Symbolic (Theano) importance log-weight, log p - log q: the model log-density
# (sized to the approximation's sample size) minus the approximation's own
# log-density.
# NOTE(review): assumes both terms are evaluated on the same symbolic draw from
# q; confirm against PyMC3's Approximation docs.
# Displaying it shows the Theano variable, not numeric values.
log_weight = fit.sized_symbolic_logp - fit.symbolic_logq
log_weight

In [None]:
# Evaluate the importance log-weights for n_samples draws from the fitted
# approximation. The third argument (False) is PyMC3's "deterministic" flag;
# passing False presumably keeps the draws random rather than fixing them at the
# approximation's mean. Confirm against the set_size_and_deterministic docs.
n_samples = 10000
log_weights = fit.set_size_and_deterministic(log_weight, n_samples, False).eval()

fig, ax = plt.subplots()
ax.hist(log_weights)
ax.set_xlabel('log p - log q')
ax.set_ylabel('Count')
ax.set_title('Importance log-weights ({} draws from q)'.format(n_samples));

In [None]:
from scipy.special import logsumexp

# Self-normalise the log-weights so that exp(log_weights_norm) sums to one.
# The added trailing axis turns the (assumed 1-D) weights into an
# (n_samples, 1) column, the layout passed to the PSIS call in the next cell.
log_weights_norm = (log_weights - logsumexp(log_weights))[:, np.newaxis]

fig, ax = plt.subplots()
ax.hist(log_weights_norm.ravel())
ax.set_xlabel('Normalised log-weight')
ax.set_ylabel('Count')
ax.set_title('Self-normalised importance log-weights');

In [None]:
# Pareto-smoothed importance sampling (PSIS) of the normalised log-weights.
# `_psislw` is a private PyMC3 helper. Its second argument is `reff`, the
# relative efficiency ESS / n of the draws, NOT the number of draws.
# Draws from the variational approximation are independent, so reff = 1.0.
# (The Pareto tail size shrinks as reff grows, so passing n_samples here
# smooths with far too few tail draws.)
smoothed_log_weights, pareto_k = pm.stats._psislw(log_weights_norm, 1.0)
# Pareto k diagnostic (Vehtari et al.): k < 0.5 good, 0.5-0.7 acceptable,
# k > 0.7 means the importance weights (and q as a proposal) are unreliable.
pareto_k