# Evaluating the Bethe model versus BEAST2

The results in this file were generated by:

1. Running BEAST 2.6.3 on `beast/M487.xml` to generate MCMC samples
2. Running [IQ-TREE 2.1.2](http://www.iqtree.org/#download) to generate a maximum likelihood tree
3. Running `python bethe.py` to fit a variational posterior

In [None]:
import sys
from collections import Counter
import matplotlib.pyplot as plt
import torch
import pyro
import pyro.poutine as poutine
from pyrophylo.phylo import Phylogeny
from pyrophylo.io import stack_nexus_trees

## Load results from BEAST2

In [None]:
%%time
beast_phylo = stack_nexus_trees('beast/M487.trees')
beast_phylo = beast_phylo[501:]
assert len(beast_phylo) == 1000

In [None]:
# Empirical distribution over tree topologies in the BEAST2 samples.
num_beast_samples = len(beast_phylo)
beast_counts = Counter(beast_phylo.hash_topology())
beast_probs = {
    topology: n / num_beast_samples for topology, n in beast_counts.items()
}
# Number of distinct topology hashes observed.
print(len(beast_counts))
# Empirical probabilities of the ten most frequent topologies.
print([beast_probs[topology] for topology, _ in beast_counts.most_common(10)])

In [None]:
def plot_skyline(phylo, max_samples=100):
    """Plot lineage counts against time for up to ``max_samples`` trees.

    Times are shifted by their maximum along the last axis so that each
    row (presumably one tree) ends at time 0. After transposing, each row
    is drawn as one thin, semi-transparent black line on a new figure.

    Args:
        phylo: a ``Phylogeny`` batch supporting slicing, ``.times`` and
            ``.num_lineages()``.
        max_samples: maximum number of leading samples to draw.
    """
    subset = phylo[:max_samples]
    lineage_counts = subset.num_lineages()
    latest = subset.times.max(-1, keepdim=True).values
    shifted_times = subset.times - latest
    plt.figure(figsize=(8, 4), dpi=300)
    plt.plot(shifted_times.T, lineage_counts.T, 'k-', lw=1, alpha=0.5)
    plt.xlabel("time")
    plt.ylabel("number of lineages")


plot_skyline(beast_phylo)

## Load results from IQ-TREE

These results were generated by running:
```sh
iqtree-2.1.2-MacOSX/bin/iqtree2 -s data/treebase/M487.nex \
  -m jc69 --date-root -1 --date-tip 0 --clock-sd 0
```

In [None]:
# IQ-TREE writes its dated maximum-likelihood tree next to the input alignment.
iqtree_timetree_path = "data/treebase/M487.nex.timetree.nex"
iqtree_phylo = stack_nexus_trees(iqtree_timetree_path)
# A single point estimate, not a posterior sample.
assert len(iqtree_phylo) == 1

In [None]:
# Only one tree is available, so this draws a single curve.
plot_skyline(phylo=iqtree_phylo)

## Load results from Bethe VI

In [None]:
# Start from an empty Pyro param store before restoring the fitted model/guide.
pyro.clear_param_store()
# NOTE(review): torch.load unpickles arbitrary Python objects; only load
# trusted checkpoints. Newer torch releases default to weights_only=True,
# which may refuse to load pickled model/guide objects like these.
bethe = torch.load("results/bethe.pt")
model, guide = bethe["model"], bethe["guide"]
# Presumably matches the floating-point precision used during fitting.
if bethe["args"].double:
    torch.set_default_dtype(torch.double)

In [None]:
%%time
num_samples = len(beast_phylo)
leaves = torch.arange(model.num_leaves)
trees = []
with torch.no_grad():
    for i in range(num_samples):
        trace = poutine.trace(guide).get_trace()
        with poutine.replay(trace=trace):
            codes, times, parents = model(mode="predict")
            tree = Phylogeny.from_unsorted(times, parents, leaves)
            trees.append(tree)
            if i % 10 == 0:
                sys.stdout.write(".")
                sys.stdout.flush()
    bethe_phylo = Phylogeny.stack(trees)

In [None]:
# Empirical distribution over tree topologies in the Bethe VI samples.
num_bethe_samples = len(bethe_phylo)
bethe_counts = Counter(bethe_phylo.hash_topology())
bethe_probs = {
    topology: n / num_bethe_samples for topology, n in bethe_counts.items()
}
# Number of distinct topology hashes observed.
print(len(bethe_counts))
# Empirical probabilities of the ten most frequent topologies.
print([bethe_probs[topology] for topology, _ in bethe_counts.most_common(10)])

In [None]:
# Lineage counts over time for the first trees drawn from the Bethe posterior.
plot_skyline(phylo=bethe_phylo)