In [3]:
import treeflow.tree_processing
import treeflow.sequences
# Input files, given relative to the notebook's directory.
newick_file = '../data/analysis-tree.nwk'
# NOTE(review): fasta_file is defined here but not read anywhere in the visible cells.
fasta_file = '../data/sim-seq.fasta'

# parse_newick returns the tree together with its taxon names. Later cells index
# the tree with the keys 'topology' and 'heights'.
tree, taxon_names = treeflow.tree_processing.parse_newick(newick_file)
# Topology dict consumed by the variational approximation below; its keys are
# shown in this cell's output.
topology = treeflow.tree_processing.update_topology_dict(tree['topology'])
taxon_count = len(taxon_names)
# Displayed as the cell result: number of taxa and the available topology keys.
taxon_count, topology.keys()

(100,
 dict_keys(['postorder_node_indices', 'child_indices', 'preorder_indices', 'preorder_node_indices', 'sibling_indices', 'parent_indices']))

# Variational approximation

In [8]:
import treeflow.tree_transform
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

from importlib import reload
reload(treeflow.tree_transform)

# Per-node anchor heights, computed from the parsed node heights and the
# postorder/child index arrays, then converted to a float32 tensor.
anchor_heights = tf.convert_to_tensor(
    treeflow.tree_processing.get_node_anchor_heights(
        tree['heights'],
        topology['postorder_node_indices'],
        topology['child_indices']),
    dtype=tf.float32)

# Index arrays handed to the bijector: entries from `taxon_count` onwards (parent
# indices) and from position 1 onwards (preorder node indices), each shifted down
# by `taxon_count`.
# NOTE(review): presumably the first `taxon_count` nodes are the leaves, so the
# shift re-indexes internal nodes from zero - confirm against treeflow.
internal_parent_indices = topology['parent_indices'][taxon_count:] - taxon_count
internal_preorder_indices = topology['preorder_node_indices'][1:] - taxon_count
tree_chain = treeflow.tree_transform.TreeChain(
    internal_parent_indices,
    internal_preorder_indices,
    anchor_heights=anchor_heights)

# Starting point for the variational parameters: the parsed heights of the nodes
# after the first `taxon_count`, mapped through the inverse of the bijector.
init_heights = tf.convert_to_tensor(tree['heights'][taxon_count:], dtype=tf.float32)
init_heights_trans = tree_chain.inverse(init_heights)

# Surrogate posterior q: one independent factor per model variable.
#   tree        - Normal over the transformed node heights, pushed through
#                 `tree_chain`, on the fixed topology parsed from the Newick file
#   kappa       - LogNormal
#   pop_size    - LogNormal
#   frequencies - Dirichlet
# Scale and concentration parameters are stored as unconstrained variables and
# passed through softplus via DeferredTensor so they stay positive.
q = tfd.JointDistributionNamed(dict(
    tree=treeflow.tree_transform.FixedTopologyDistribution(
        height_distribution=tfd.TransformedDistribution(
            distribution=tfd.Independent(
                tfd.Normal(
                    loc=tf.Variable(init_heights_trans, name='q_tree_loc'),
                    scale=tfp.util.DeferredTensor(tf.Variable(tf.ones_like(init_heights_trans), name='q_tree_scale'), tf.nn.softplus)
                ),
                # The Normal parameters are rank 1 (one entry per height). With the
                # default reinterpreted_batch_ndims=None, Independent keeps the first
                # batch axis and would reinterpret nothing here, so the single batch
                # axis is explicitly turned into the event axis: the heights form one
                # vector-valued event with a single log-density.
                reinterpreted_batch_ndims=1
            ),
            bijector=tree_chain),
        topology=tree['topology']
    ),
    kappa=tfd.LogNormal(
        loc=tf.Variable(0.0, name='q_kappa_loc'),
        scale=tfp.util.DeferredTensor(tf.Variable(1.0, name='q_kappa_scale'), tf.nn.softplus)
    ),
    pop_size=tfd.LogNormal(
        loc=tf.Variable(0.0, name='q_pop_size_loc'),
        scale=tfp.util.DeferredTensor(tf.Variable(1.0, name='q_pop_size_scale'), tf.nn.softplus)
    ),
    frequencies=tfd.Dirichlet(concentration=tfp.util.DeferredTensor(tf.Variable([4.0, 4.0, 4.0, 4.0], name='q_frequencies_concentration'), tf.nn.softplus))
))
# Displayed as the cell result: name and shape of every trainable variable of q.
[(var.name, var.shape) for var in q.trainable_variables]

[('q_tree_loc:0', TensorShape([99])),
 ('q_tree_scale:0', TensorShape([99])),
 ('q_pop_size_loc:0', TensorShape([])),
 ('q_pop_size_scale:0', TensorShape([])),
 ('q_kappa_loc:0', TensorShape([])),
 ('q_kappa_scale:0', TensorShape([])),
 ('q_frequencies_concentration:0', TensorShape([4]))]

# Model
## Prior

In [5]:
import treeflow.coalescent
reload(treeflow.coalescent)

# Heights of the first `taxon_count` nodes, passed to the coalescent prior as its
# sampling times.
sampling_times = tf.convert_to_tensor(tree['heights'][:taxon_count], dtype=tf.float32)

# Prior over the model variables. The tree prior is a constant-population
# coalescent whose population size is the `pop_size` variable of the same joint
# distribution (the lambda's argument name selects that dependency).
prior = tfd.JointDistributionNamed({
    'frequencies': tfd.Dirichlet(concentration=[4,4,4,4]),
    'kappa': tfd.LogNormal(loc=0, scale=1),
    'pop_size': tfd.LogNormal(loc=0, scale=1),
    # Priors currently disabled:
    #'site_alpha': tfd.LogNormal(loc=0, scale=1),
    #'clock_rate': tfd.LogNormal(loc=0, scale=1),
    'tree': lambda pop_size: treeflow.coalescent.ConstantCoalescent(pop_size=pop_size, sampling_times=sampling_times)
})

# Sanity check: evaluate the prior log-density at a single draw from q.
prior.log_prob(**q.sample())



<tf.Tensor: id=2711, shape=(), dtype=float32, numpy=-240.13509>

In [10]:
# Fit the surrogate posterior q against the prior log-density only (no likelihood
# term is included here), using Adam for 1000 steps.
# FIXME: the recorded run of this cell raises a ValueError. A ConcatV2 op inside
# the ConstantCoalescent log_prob receives inputs of shapes [100], [1,99] and [].
# NOTE(review): presumably the rank-2 [1,99] input is the sampled node heights
# carrying a leading sample dimension added by the Monte Carlo variational loss,
# which the coalescent log_prob does not handle - confirm against
# treeflow.coalescent before relying on this cell.
tfp.vi.fit_surrogate_posterior(prior.log_prob, q, tf.optimizers.Adam(), 1000)

ValueError: Shape must be rank 1 but is rank 2 for 'monte_carlo_variational_loss/expectation/JointDistributionNamed/log_prob/monte_carlo_variational_loss_expectation_JointDistributionNamed_log_prob_ConstantCoalescent_1/log_prob/concat' (op: 'ConcatV2') with input shapes: [100], [1,99], [].

## Likelihood

## Joint density function (conditioned)