In [1]:
import treeflow.tree_processing
from importlib import reload
reload(treeflow.tree_processing)

import treeflow.sequences
newick_file = '../data/analysis-tree.nwk'
fasta_file = '../data/sim-seq.fasta'

tree, taxon_names = treeflow.tree_processing.parse_newick(newick_file)
topology = treeflow.tree_processing.update_topology_dict(tree['topology'])
taxon_count = len(taxon_names)
taxon_count, topology.keys()

(100,
 dict_keys(['postorder_node_indices', 'child_indices', 'preorder_indices', 'preorder_node_indices', 'sibling_indices', 'parent_indices']))

# Variational approximation

In [47]:
import treeflow.tree_transform
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

from importlib import reload
reload(treeflow.tree_transform)

anchor_heights = treeflow.tree_processing.get_node_anchor_heights(tree['heights'], topology['postorder_node_indices'], topology['child_indices'])
anchor_heights = tf.convert_to_tensor(anchor_heights, dtype=tf.float32)
tree_chain = treeflow.tree_transform.TreeChain(
    tf.convert_to_tensor(topology['parent_indices'][taxon_count:] - taxon_count),
    tf.convert_to_tensor(topology['preorder_node_indices'][1:] - taxon_count),
    anchor_heights=anchor_heights)
init_heights = tf.convert_to_tensor(tree['heights'][taxon_count:], dtype=tf.float32)
init_heights_trans = tree_chain.inverse(init_heights)

q = tfd.JointDistributionNamed(dict(
    tree=treeflow.tree_transform.FixedTopologyDistribution(
        height_distribution=tfd.TransformedDistribution(
            distribution=tfd.Independent(tfd.Normal(
                loc=tf.Variable(init_heights_trans, name='q_tree_loc'),
                scale=tfp.util.DeferredTensor(tf.Variable(tf.ones_like(init_heights_trans), name='q_tree_scale'), tf.nn.softplus)
            ), reinterpreted_batch_ndims=1),
            bijector=tree_chain),
        topology=tree['topology']
    ),
    kappa=tfd.LogNormal(
        loc=tf.Variable(0.0, name='q_kappa_loc'),
        scale=tfp.util.DeferredTensor(tf.Variable(1.0, name='q_kappa_scale'), tf.nn.softplus)
    ),
    pop_size=tfd.LogNormal(
        loc=tf.Variable(0.0, name='q_pop_size_loc'),
        scale=tfp.util.DeferredTensor(tf.Variable(1.0, name='q_pop_size_scale'), tf.nn.softplus)
    ),
    frequencies=tfd.Dirichlet(concentration=tfp.util.DeferredTensor(tf.Variable([4.0, 4.0, 4.0, 4.0], name='q_frequencies_concentration'), tf.nn.softplus)),
    site_alpha=tfd.LogNormal(
        loc=tf.Variable(0.0, name='q_site_alpha_loc'),
        scale=tfp.util.DeferredTensor(tf.Variable(1.0, name='q_site_alpha_scale'), tf.nn.softplus)
    ),
    clock_rate=tfd.LogNormal(
        loc=tf.Variable(0.0, name='q_clock_rate_loc'),
        scale=tfp.util.DeferredTensor(tf.Variable(1.0, name='q_clock_rate_scale'), tf.nn.softplus)
    )
))
[(var.name, var.shape) for var in q.trainable_variables]

[('q_tree_loc:0', TensorShape([99])),
 ('q_tree_scale:0', TensorShape([99])),
 ('q_site_alpha_loc:0', TensorShape([])),
 ('q_site_alpha_scale:0', TensorShape([])),
 ('q_pop_size_loc:0', TensorShape([])),
 ('q_pop_size_scale:0', TensorShape([])),
 ('q_kappa_loc:0', TensorShape([])),
 ('q_kappa_scale:0', TensorShape([])),
 ('q_frequencies_concentration:0', TensorShape([4])),
 ('q_clock_rate_loc:0', TensorShape([])),
 ('q_clock_rate_scale:0', TensorShape([]))]

# Model
## Prior

In [50]:
import treeflow.coalescent
from importlib import reload
reload(treeflow.coalescent)

sampling_times = tf.convert_to_tensor(tree['heights'][:taxon_count], dtype=tf.float32)

prior = tfd.JointDistributionNamed(dict(
    frequencies=tfd.Dirichlet(concentration=[4,4,4,4]),
    kappa=tfd.LogNormal(loc=0, scale=1),
    pop_size=tfd.LogNormal(loc=0, scale=1),
    site_alpha=tfd.LogNormal(loc=0, scale=1),
    clock_rate=tfd.LogNormal(loc=0, scale=1),
    tree=lambda pop_size: treeflow.coalescent.ConstantCoalescent(pop_size=pop_size, sampling_times=sampling_times)
))

prior.log_prob(**q.sample())

<tf.Tensor: id=90490, shape=(), dtype=float32, numpy=48.715714>

In [51]:
res = tfp.vi.fit_surrogate_posterior(prior.log_prob, q, tf.optimizers.Adam(), 3, trace_fn=lambda loss, grads, variables: variables)
res

[<tf.Tensor: id=99185, shape=(3, 99), dtype=float32, numpy=
 array([[ 3.3782930e+00,  3.4117954e+00,  3.4442155e+00, -6.2570626e-01,
          8.1092781e-01,  1.5589943e+00,  3.5093479e+00,  8.8555747e-01,
          7.6570195e-01, -2.4412508e-01, -4.3821758e-01,  1.6858009e+00,
         -1.3917049e+00, -5.8846277e-01,  2.0737803e+00,  2.1140511e+00,
         -1.2818749e-01, -1.1141236e+00, -7.7625531e-01, -5.1379293e-01,
         -2.3565246e-01,  4.4677782e+00, -2.5858173e+00, -6.8568960e-02,
          3.0321419e+00, -9.1351002e-01, -7.1541443e-02, -1.5786476e+00,
         -6.8509263e-01, -6.9343168e-01, -5.2141678e-01, -5.8647221e-01,
         -1.0002382e-03, -4.3526316e+00,  1.5744197e+00,  5.2446499e+00,
          5.2499094e+00,  2.1429598e+00, -1.5861918e-01,  4.1033179e-01,
         -3.0476968e+00,  2.5808351e+00, -6.8954363e-02,  3.6514730e+00,
          3.6770725e+00,  2.3879592e-01, -6.9414717e-01,  3.6867785e-01,
         -4.0400076e+00,  1.3662626e+00,  1.5545955e+00, -1.3380

## Likelihood

In [43]:
reload(treeflow.sequences)
sequences, pattern_counts = treeflow.sequences.get_encoded_sequences(fasta_file, taxon_names)
category_count = 4
log_prob_conditioned = treeflow.sequences.log_prob_conditioned(dict(sequences=sequences, weights=pattern_counts), tree['topology'], category_count)

## Joint density function (conditioned)