# Evaluating Bethe variational inference

The results analyzed in this notebook were generated by running
```sh
python bethe_vi.py
```

In [None]:
import torch
import matplotlib.pyplot as plt

In [None]:
# Load the saved output of bethe_vi.py. The file contains non-tensor Python
# objects (e.g. samples["trees"] exposes .num_nodes/.parents and supports len();
# results["args"] is presumably the script's argument namespace), which
# torch.load refuses to unpickle under weights_only=True -- the default since
# PyTorch 2.6. weights_only=False runs arbitrary pickle code, so only load
# results files you produced yourself or otherwise trust.
results = torch.load("results/bethe_vi.pt", weights_only=False)
print("results:", list(results))
data = results["data"]
samples = results["samples"]
print("data:", list(data))
print("samples:", list(samples))
print("args:", results["args"])

In [None]:
# Training curve: one loss value per SVI step, as recorded by the run.
fig, ax = plt.subplots(figsize=(8, 3), dpi=300)
ax.plot(results["losses"])
ax.set_xlabel("SVI step")
ax.set_ylabel("loss")
fig.tight_layout();

In [None]:
# Posterior samples: a batch of trees plus per-node codes whose last axis
# holds the embedding coordinates (plotted below as "embedding 0" / "embedding 1").
trees, codes = samples["trees"], samples["codes"]
N = trees.num_nodes
L = trees.num_leaves
# The plotting code assumes the leaves occupy the final L node indices.
assert set(trees[0].leaves.tolist()) == set(range(N-L, N))

# Choose one posterior sample to draw as the example phylogeny.
s = 0
tree = trees[s]
times = tree.times
codex = codes[s, :, 0]
codey = codes[s, :, 1]

def plot(ax, X, Y):
    """Draw the example ``tree`` on ``ax``, placing node ``i`` at ``(X[i], Y[i])``.

    Every node with a parent (parent index != -1) gets an edge to that parent.
    All edges are drawn as a single polyline whose segments are separated by
    ``None``. Internal nodes (indices below ``N - L``) are drawn as small blue
    dots. Leaves (the last ``L`` indices) are drawn as larger black dots.

    Note: reads the module-level ``tree``, ``N`` and ``L`` instead of taking
    them as arguments.
    """
    edge_x, edge_y = [], []
    for child, parent in enumerate(tree.parents.tolist()):
        if parent == -1:
            continue  # a root has no edge to draw
        edge_x += [X[child], X[parent], None]
        edge_y += [Y[child], Y[parent], None]
    ax.plot(edge_x, edge_y, color="blue", markersize=3, lw=1/3, zorder=4, label="example tree")

    split = N - L  # first leaf index
    ax.scatter(X[:split], Y[:split], s=3, color="blue", zorder=1)
    ax.scatter(X[split:], Y[split:], s=10, color="black", zorder=0, label="leaf")

fig, axes = plt.subplots(2, 2, figsize=(8, 8), sharex="col", sharey="row", dpi=300)
# (panel, horizontal values, vertical values) for the example tree; the
# bottom-left panel is left empty here.
example_panels = [
    (axes[0][1], codey, codex),
    (axes[0][0], times, codex),
    (axes[1][1], codey, times),
]
for panel, horiz, vert in example_panels:
    plot(panel, horiz, vert)

# Plot distributions of codes.
# Overlay the node positions from every posterior sample on the three panels:
# internal (coalescent) nodes in red, leaves in black.
alpha = 0.1
N, L = tree.num_nodes, tree.num_leaves
for span, color, zorder in [(slice(0, N-L), "red", 3), (slice(N-L, N), "black", 2)]:
    # Use dedicated names for the flattened all-sample values. The globals
    # times/codex/codey hold the single example sample's per-node values;
    # rebinding them here would make any later plot(...) call silently index
    # the wrong data.
    span_times = trees.times[:, span].reshape(-1)
    span_codex = codes[:, span, 0].reshape(-1)
    span_codey = codes[:, span, 1].reshape(-1)
    options = dict(s=1, alpha=alpha, color=color, zorder=zorder)
    axes[0][1].scatter(span_codey, span_codex, **options)
    axes[0][0].scatter(span_times, span_codex, **options)
    axes[1][1].scatter(span_codey, span_times, **options)

# Limit the time axes using the first sample's earliest time plus a 25% margin.
# NOTE(review): this appears to assume times are non-positive (upper limit 0.2),
# and points from other samples earlier than this bound are clipped -- confirm.
axes[0][0].set_xlim(1.25 * trees[0].times.min(), 0.2)
axes[1][1].set_ylim(1.25 * trees[0].times.min(), 0.2)
axes[0][0].set_ylabel("embedding 0")
axes[1][0].set_ylabel("time")
axes[1][0].set_xlabel("time")
axes[1][1].set_xlabel("embedding 1")

# The bottom-left panel has no data; it hosts the shared legend, built from
# empty proxy artists.
ax = axes[1][0]
ax.plot([], [], "ko", label="leaf")
ax.plot([], [], "ro", label="sampled coalescent events")
ax.plot([], [], "b-", marker="o", label="a single sample phylogeny")
ax.legend(loc="best")
fig.suptitle(f"Sampled {len(trees)} posterior phylogenies for {L} taxa")
plt.subplots_adjust(hspace=0, wspace=0, top=0.95)