# Estimating rates and dates from time-stamped sequences
Data and models are taken from the [BEAST documentation](https://beast.community/rates_and_dates),
though we tweak the priors.

In [1]:
from treeflow import Alignment, AlignmentFormat

alignment = Alignment("demo-data/YFV.nex", format=AlignmentFormat.NEXUS)
alignment

Silencing TensorFlow...


Alignment(taxon_count=71, pattern_count=654)

In [5]:
from treeflow import parse_newick

starting_tree = parse_newick("demo-data/YFV.newick")
sequence_tensor = alignment.get_encoded_sequence_tensor(starting_tree.taxon_set)
sequence_tensor.shape

TensorShape([654, 71, 4])
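The sequence tensor has one entry per site pattern (654 here, matching `pattern_count`), one per taxon (presumably ordered to match `starting_tree.taxon_set`, which is why it is passed in) and one per nucleotide state. As a rough illustration of that layout, the sketch below one-hot encodes a toy alignment in plain Python. The function `encode_alignment`, the `ACGT` state order, treating every column as its own pattern, and the all-ones encoding of gaps and ambiguity codes are assumptions for illustration, not TreeFlow's implementation.

```python
# Hypothetical sketch: one-hot encode an alignment as nested lists of shape
# [site_count, taxon_count, 4]. Not TreeFlow's implementation.
NUCLEOTIDES = "ACGT"  # assumed state order


def encode_alignment(sequences, taxa):
    """Encode `sequences` (taxon name -> aligned string) in the order of `taxa`."""
    site_count = len(sequences[taxa[0]])
    tensor = []
    for site in range(site_count):
        site_rows = []
        for taxon in taxa:
            char = sequences[taxon][site].upper()
            if char in NUCLEOTIDES:
                # Exactly one observed state.
                site_rows.append([1.0 if state == char else 0.0 for state in NUCLEOTIDES])
            else:
                # Gaps and ambiguity codes: treat every state as possible.
                site_rows.append([1.0] * len(NUCLEOTIDES))
        tensor.append(site_rows)
    return tensor


toy = encode_alignment({"x": "AC-", "y": "AGT"}, ["x", "y"])
print(len(toy), len(toy[0]), len(toy[0][0]))  # 3 2 4
```

Ordering rows by an explicit taxon list keeps each leaf of the tree paired with the right sequence.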

We define a model using [TreeFlow's YAML model definition format](https://github.com/christiaanjs/treeflow/blob/master/docs/model-definition.md). We parse the YAML into a nested dictionary with Python's `yaml` library and pass that dictionary to the constructor of TreeFlow's `PhyloModel` class. `phylo_model_to_joint_distribution` then converts the model, together with the starting tree and the alignment, into a TensorFlow Probability `JointDistribution`, which implements methods such as `log_prob` and `sample`. The final statement of this code block examines the structure of a sample from this joint distribution: a nested structure of Tensors representing the variables of the model.
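To illustrate what a joint distribution's `sample` and `log_prob` do, here is a toy plain-Python stand-in restricted to the four scalar parameters that have LogNormal priors in the YAML model definition. It assumes `loc` and `scale` are those of the underlying normal, as in TensorFlow Probability's `LogNormal`. The class `ToyJointDistribution` is hypothetical: unlike the real model it treats the parameters as independent and omits the tree, frequencies and alignment, whose distributions depend on other variables.

```python
import math
import random

# (loc, scale) of each LogNormal prior, taken from the YAML model definition.
PRIORS = {
    "clock_rate": (-2.0, 2.0),
    "site_gamma_shape": (0.0, 1.0),
    "kappa": (1.0, 1.25),
    "pop_size": (1.0, 1.5),
}


def lognormal_log_prob(x, loc, scale):
    """Log density of LogNormal(loc, scale) at x > 0."""
    z = (math.log(x) - loc) / scale
    return -math.log(x) - math.log(scale) - 0.5 * math.log(2.0 * math.pi) - 0.5 * z * z


class ToyJointDistribution:
    """Hypothetical joint distribution of independent LogNormal priors."""

    def __init__(self, priors):
        self.priors = priors

    def sample(self, rng=random):
        # One draw per named variable, returned as a dict (a simple nested structure).
        return {name: rng.lognormvariate(loc, scale) for name, (loc, scale) in self.priors.items()}

    def log_prob(self, values):
        # Independent components, so the joint log density is a sum of terms.
        return sum(
            lognormal_log_prob(values[name], loc, scale)
            for name, (loc, scale) in self.priors.items()
        )


dist = ToyJointDistribution(PRIORS)
draw = dist.sample(random.Random(0))
print(sorted(draw), dist.log_prob(draw))
```

In the real model, `log_prob` additionally includes the coalescent density of the tree given `pop_size` and the likelihood of the alignment given everything else.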

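The `tree` variable is a `TensorflowRootedTree`. For these 71 taxa it holds `node_heights` for the 70 internal nodes, `sampling_times` for the 71 leaves, and a `topology` described by `parent_indices`, `child_indices` and `preorder_indices` over all 141 nodes (the root has no parent, so there are 140 parent indices). We draw a single sample here only to inspect this structure; the sampled values themselves are not used.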
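To make the tree representation concrete, here is a plain-Python sketch of a parent-index topology. It assumes leaves take indices `0` to `n - 1`, internal nodes follow, and the root takes the last index. All helper names are hypothetical, not TreeFlow functions, and TreeFlow stores children as a fixed-width `[node_count, 2]` array rather than lists. The last helper shows why heights matter for rates and dates: under a strict clock, the expected number of substitutions per site along a branch is `clock_rate` times the branch's length in time.

```python
# Hypothetical sketch of a rooted tree stored as parent indices plus heights.
# Assumed layout: leaves 0..n-1, internal nodes n..2n-2, root = last index.


def child_indices(parent_indices, node_count):
    """Invert a parent array into a list of children per node (empty for leaves)."""
    children = [[] for _ in range(node_count)]
    for child, parent in enumerate(parent_indices):
        children[parent].append(child)
    return children


def preorder(children, root):
    """Return node indices so that every parent comes before its children."""
    order, stack = [], [root]
    while stack:
        node = stack.pop()
        order.append(node)
        # Push children in reverse so the first child is visited first.
        stack.extend(reversed(children[node]))
    return order


def branch_lengths(parent_indices, sampling_times, node_heights):
    """Time length of the branch above each non-root node."""
    heights = list(sampling_times) + list(node_heights)  # leaves first, root last
    return [heights[parent] - heights[child] for child, parent in enumerate(parent_indices)]


def strict_clock_substitutions(lengths, clock_rate):
    """Expected substitutions per site on each branch under a strict clock."""
    return [clock_rate * length for length in lengths]


# Three leaves sampled at heights 0.0, 1.0 and 0.5; internal node 3; root 4.
parents = [3, 3, 4, 4]
kids = child_indices(parents, 5)
print(preorder(kids, 4))                                       # [4, 2, 3, 0, 1]
print(branch_lengths(parents, [0.0, 1.0, 0.5], [2.0, 3.0]))    # [2.0, 1.0, 2.5, 1.0]
```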

In [15]:
import yaml
import tensorflow as tf
from treeflow import (
    PhyloModel,
    convert_tree_to_tensor,
    phylo_model_to_joint_distribution
)

model_string = """
clock:
  strict:
    clock_rate:
      lognormal:
        loc: -2.0
        scale: 2.0
site:
  discrete_gamma:
    category_count: 4
    site_gamma_shape:
      lognormal:
        loc: 0.0
        scale: 1.0
substitution:
  hky:
    kappa:
      lognormal:
        loc: 1.0
        scale: 1.25
    frequencies:
      dirichlet:
        concentration:
        - 2.0
        - 2.0
        - 2.0
        - 2.0
tree:
  coalescent:
    pop_size:
      lognormal:
        loc: 1.0
        scale: 1.5
"""

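# Parse the YAML model definition into a nested dictionary
model_dict = yaml.safe_load(model_string)
model = PhyloModel(model_dict)
# Convert the parsed starting tree into TreeFlow's tensor representation
starting_tensor_tree = convert_tree_to_tensor(starting_tree)
# Build a TensorFlow Probability joint distribution over all model variables
model_dist = phylo_model_to_joint_distribution(model, starting_tensor_tree, alignment)
# Draw one sample and inspect the shape of each component
model_dist_sample = model_dist.sample()
tf.nest.map_structure(lambda x: x.shape, model_dist_sample)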

StructTuple(
  pop_size=TensorShape([]),
  tree=TensorflowRootedTree(node_heights=TensorShape([70]), sampling_times=TensorShape([71]), topology=TensorflowTreeTopology(parent_indices=TensorShape([140]), child_indices=TensorShape([141, 2]), preorder_indices=TensorShape([141]))),
  kappa=TensorShape([]),
  frequencies=TensorShape([4]),
  clock_rate=TensorShape([]),
  site_gamma_shape=TensorShape([]),
  alignment=TensorShape([654, 71, 4])
)
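The `kappa` and `frequencies` variables parameterise the HKY substitution model: `kappa` is the ratio of transition to transversion rates and `frequencies` are the stationary base frequencies. A minimal sketch of the corresponding rate matrix follows, assuming the `ACGT` state order and normalisation to one expected substitution per unit time (a common convention; whether TreeFlow normalises this way is an assumption here). `hky_rate_matrix` is a hypothetical helper, not part of TreeFlow.

```python
# Hypothetical sketch of the HKY85 rate matrix in plain Python.
NUCLEOTIDES = "ACGT"  # assumed state order
TRANSITIONS = {frozenset("AG"), frozenset("CT")}  # purine-purine, pyrimidine-pyrimidine


def hky_rate_matrix(kappa, frequencies):
    """Return Q with q[i][j] = rate from state i to state j (rows sum to zero)."""
    q = [[0.0] * 4 for _ in range(4)]
    for i, a in enumerate(NUCLEOTIDES):
        for j, b in enumerate(NUCLEOTIDES):
            if i != j:
                is_transition = frozenset((a, b)) in TRANSITIONS
                q[i][j] = (kappa if is_transition else 1.0) * frequencies[j]
        q[i][i] = -sum(q[i])
    # Scale so the expected substitution rate at stationarity, -sum_i pi_i q_ii, is 1.
    rate = -sum(frequencies[i] * q[i][i] for i in range(4))
    return [[x / rate for x in row] for row in q]


print(hky_rate_matrix(2.0, [0.1, 0.2, 0.3, 0.4]))
```

With `kappa = 1` and equal frequencies this reduces to the Jukes-Cantor model, where every off-diagonal rate is 1/3.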

In [None]:
from treeflow.vi import fit_fixed_topology_variational_approximation, RobustOptimizer
from tqdm import tqdm
import matplotlib.pyplot as plt

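# Condition the model on the observed sequences
model_dist_pinned = model_dist.experimental_pin(alignment=sequence_tensor)
# Wrap Adam in TreeFlow's RobustOptimizer
optimizer = RobustOptimizer(tf.optimizers.Adam(learning_rate=0.001))
num_steps = 30000
# Fit a variational approximation with the topology fixed to the starting tree's
approx, (loss, approx_vars) = fit_fixed_topology_variational_approximation(
    model_dist_pinned,
    topologies=dict(tree=starting_tensor_tree.topology),
    optimizer=optimizer,
    num_steps=num_steps,
    init_loc=dict(tree=starting_tensor_tree),  # initialise the tree part from the starting tree
    progress_bar=tqdm
)
# Plot the loss trace over optimisation steps
plt.plot(loss)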
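Assuming `loss` is the negative evidence lower bound (ELBO), as with TensorFlow Probability's `tfp.vi.fit_surrogate_posterior`, the quantity being minimised is a Monte Carlo estimate of `-E_q[log p(z, data) - log q(z)]`. The one-dimensional sketch below estimates the ELBO for a Gaussian approximation using reparameterised samples; the names `elbo_estimate` and `target` are hypothetical and the sketch says nothing about the form of TreeFlow's approximation.

```python
import math
import random


def normal_log_prob(x, loc, scale):
    """Log density of Normal(loc, scale) at x."""
    z = (x - loc) / scale
    return -math.log(scale) - 0.5 * math.log(2.0 * math.pi) - 0.5 * z * z


def elbo_estimate(target_log_prob, q_loc, q_scale, num_samples, rng):
    """Monte Carlo ELBO for a Normal(q_loc, q_scale) approximation.

    Samples are reparameterised as z = q_loc + q_scale * eps with eps ~ N(0, 1),
    which is what makes the estimate differentiable in (q_loc, q_scale).
    """
    total = 0.0
    for _ in range(num_samples):
        z = q_loc + q_scale * rng.gauss(0.0, 1.0)
        total += target_log_prob(z) - normal_log_prob(z, q_loc, q_scale)
    return total / num_samples


def target(z):
    # Normalised standard normal target, so the best achievable ELBO is log(1) = 0.
    return normal_log_prob(z, 0.0, 1.0)


rng = random.Random(1)
print(elbo_estimate(target, 0.0, 1.0, 1000, rng))   # 0.0: q equals the target
print(elbo_estimate(target, 1.0, 1.0, 20000, rng))  # roughly -0.5 = -KL(q || p)
```

The gap between the ELBO and the log evidence is the KL divergence from the approximation to the posterior, so a flattening loss curve means further steps barely improve the fit.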


### Try a model with sequences partitioned by codon position

In [None]:
codon_partitioned_sequence_tensor = alignment.get_codon_partitioned_sequence_tensor(starting_tree.taxon_set)
codon_partitioned_sequence_tensor.shape
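Partitioning by codon position groups sites 1, 4, 7, ... into the first partition, sites 2, 5, 8, ... into the second, and so on, so each position can get its own substitution or site-rate parameters. A plain-Python sketch of the split follows, assuming the alignment starts at a first codon position and has a length divisible by three (654 = 3 × 218). `partition_by_codon_position` is a hypothetical helper; it does not describe the layout of `get_codon_partitioned_sequence_tensor`'s output.

```python
# Hypothetical sketch: split an in-frame alignment into three codon-position partitions.


def partition_by_codon_position(sequences):
    """Return [first, second, third] sub-alignments (taxon name -> string)."""
    for taxon, seq in sequences.items():
        if len(seq) % 3 != 0:
            raise ValueError(f"{taxon}: length {len(seq)} is not a multiple of 3")
    # seq[position::3] keeps every third character starting at `position`.
    return [
        {taxon: seq[position::3] for taxon, seq in sequences.items()}
        for position in range(3)
    ]


parts = partition_by_codon_position({"x": "ATGAAA", "y": "ATGAAG"})
print(parts[2])  # {'x': 'GA', 'y': 'GG'}
```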