In [2]:
from importlib import reload

import treeflow.tree_processing
reload(treeflow.tree_processing)
import treeflow.sequences
newick_file = '../data/analysis-tree.nwk'
fasta_file = '../data/sim-seq.fasta'

tree, taxon_names = treeflow.tree_processing.parse_newick(newick_file)
topology = treeflow.tree_processing.update_topology_dict(tree['topology'])
taxon_count = len(taxon_names)
taxon_count, topology.keys()

(100,
 dict_keys(['postorder_node_indices', 'child_indices', 'preorder_indices', 'preorder_node_indices', 'sibling_indices', 'parent_indices']))

# Variational approximation

In [3]:
import treeflow.tree_transform
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions


import treeflow.tf_util
reload(treeflow.tf_util)

reload(treeflow.tree_transform)

anchor_heights = treeflow.tree_processing.get_node_anchor_heights(tree['heights'], topology['postorder_node_indices'], topology['child_indices'])
anchor_heights = tf.convert_to_tensor(anchor_heights, dtype=tf.float32)
tree_chain = treeflow.tree_transform.TreeChain(
    topology['parent_indices'][taxon_count:] - taxon_count,
    topology['preorder_node_indices'][1:] - taxon_count,
    anchor_heights=anchor_heights)
init_heights = tf.convert_to_tensor(tree['heights'][taxon_count:], dtype=tf.float32)
init_heights_trans = tree_chain.inverse(init_heights)
leaf_heights = tf.convert_to_tensor(tree['heights'][:taxon_count], dtype=tf.float32)

height_dist = tfd.Blockwise([
    tfd.Independent(tfd.Deterministic(leaf_heights), reinterpreted_batch_ndims=1),
    tfd.TransformedDistribution(
        distribution=tfd.Independent(tfd.Normal(
            loc=tf.Variable(init_heights_trans, name='q_tree_loc'),
            scale=tfp.util.DeferredTensor(tf.Variable(tf.ones_like(init_heights_trans), name='q_tree_scale'), tf.nn.softplus)
        ), reinterpreted_batch_ndims=1),
        bijector=tree_chain
    )
])

q = tfd.JointDistributionNamed(dict(
    tree=treeflow.tree_transform.FixedTopologyDistribution(
        height_distribution=height_dist,
        topology=tree['topology']
    ),
    kappa=tfd.LogNormal(
        loc=tf.Variable(0.0, name='q_kappa_loc'),
        scale=tfp.util.DeferredTensor(tf.Variable(1.0, name='q_kappa_scale'), tf.nn.softplus)
    ),
    pop_size=tfd.LogNormal(
        loc=tf.Variable(0.0, name='q_pop_size_loc'),
        scale=tfp.util.DeferredTensor(tf.Variable(1.0, name='q_pop_size_scale'), tf.nn.softplus)
    ),
    frequencies=tfd.Dirichlet(concentration=tfp.util.DeferredTensor(tf.Variable([4.0, 4.0, 4.0, 4.0], name='q_frequencies_concentration'), tf.nn.softplus))
))
#[(var.name, var.shape) for var in q.trainable_variables]
q.sample()

{'tree': {'topology': {'parent_indices': <tf.Tensor: shape=(198,), dtype=int64, numpy=
   array([100, 100, 101, 102, 103, 104, 105, 106, 106, 107, 108, 109, 111,
          111, 112, 112, 115, 115, 116, 117, 118, 118, 120, 120, 121, 122,
          122, 123, 125, 125, 127, 127, 128, 128, 129, 129, 130, 132, 132,
          133, 136, 140, 140, 141, 141, 142, 144, 145, 146, 146, 147, 148,
          148, 151, 151, 153, 153, 154, 156, 156, 157, 158, 158, 159, 161,
          161, 162, 164, 165, 166, 167, 167, 168, 169, 170, 173, 174, 174,
          175, 176, 177, 177, 179, 181, 181, 183, 183, 184, 185, 185, 186,
          186, 187, 188, 189, 190, 192, 192, 193, 194, 101, 102, 103, 104,
          105, 110, 107, 108, 109, 110, 114, 113, 113, 114, 139, 116, 117,
          119, 119, 138, 121, 124, 123, 124, 126, 126, 137, 135, 131, 130,
          131, 134, 133, 134, 135, 136, 137, 138, 139, 198, 143, 142, 143,
          144, 145, 150, 147, 149, 149, 150, 152, 152, 155, 154, 155, 172,
          157

# Model
## Prior

In [4]:
import treeflow.coalescent
reload(treeflow.coalescent)

sampling_times = tf.convert_to_tensor(tree['heights'][:taxon_count], dtype=tf.float32)

prior = tfd.JointDistributionNamed(dict(
    frequencies=tfd.Dirichlet(concentration=[4,4,4,4]),
    kappa=tfd.LogNormal(loc=0, scale=1),
    pop_size=tfd.LogNormal(loc=0, scale=1),
    #site_alpha=tfd.LogNormal(loc=0, scale=1),
    #clock_rate=tfd.LogNormal(loc=0, scale=1),
    tree=lambda pop_size: treeflow.coalescent.ConstantCoalescent(pop_size=pop_size, sampling_times=sampling_times)
))

prior.log_prob_parts(q.sample(sample_shape=[2]))

{'pop_size': <tf.Tensor: shape=(2,), dtype=float32, numpy=array([-8.627281 , -2.3740983], dtype=float32)>,
 'tree': <tf.Tensor: shape=(2,), dtype=float32, numpy=array([-329.62326 , -105.402954], dtype=float32)>,
 'kappa': <tf.Tensor: shape=(2,), dtype=float32, numpy=array([-0.5493054, -0.5087566], dtype=float32)>,
 'frequencies': <tf.Tensor: shape=(2,), dtype=float32, numpy=array([2.8152523, 1.9390144], dtype=float32)>}

In [5]:
tfp.vi.fit_surrogate_posterior(prior.log_prob, q, tf.optimizers.Adam(), 10)



<tf.Tensor: shape=(10,), dtype=float32, numpy=
array([ 528.4222 ,  595.3603 ,  804.103  ,  532.0404 ,  599.3718 ,
        534.39667,  665.00134,  762.3739 , 3053.35   ,  713.8525 ],
      dtype=float32)>

## Likelihood

In [10]:
import treeflow.sequences
reload(treeflow.sequences)
import treeflow.substitution_model

subst_model = treeflow.substitution_model.HKY()
category_weights = tf.ones(1)
category_rates = tf.ones(1)

alignment = treeflow.sequences.get_encoded_sequences(fasta_file, taxon_names)
log_prob_conditioned = treeflow.sequences.log_prob_conditioned(alignment, tree['topology'], 1)
                 
def log_likelihood(tree, kappa, frequencies):  
    return log_prob_conditioned(
        subst_model=subst_model,
        category_weights=category_weights,
        category_rates=category_rates,
        branch_lengths=treeflow.sequences.get_branch_lengths(tree, ),
        frequencies=frequencies,
        kappa=kappa
    )
                 
q_sample = q.sample(sample_shape=[2])
                 
log_likelihood(q_sample['tree'], q_sample['kappa'], q_sample['frequencies'])

InvalidArgumentError: Incompatible shapes: [2,198] vs. [1,199] [Op:Sub]