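# Predictive accuracy of the mutrans model on new lineages

This notebook compares each lineage's rate estimate from the full mutrans fit with the estimate from a leave-one-out (LOO) fit in which that lineage was held out. It assumes you have already run: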
```sh
make update  # downloads and preprocesses data
python mutrans.py --vary-leaves=50  # or some largish number
```

In [None]:
import pickle
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import torch
from pyrocov import pangolin
from pyrocov.util import pearson_correlation

# Figure styling shared by every plot in this notebook.
matplotlib.rcParams.update(
    {
        "figure.dpi": 200,
        "axes.edgecolor": "gray",
        "savefig.bbox": "tight",
        "font.family": "sans-serif",
        "font.sans-serif": ["Arial", "Avenir", "DejaVu Sans"],
    }
)

In [None]:
dataset = torch.load("results/mutrans.data.single.None.pt", map_location="cpu")
print(dataset.keys())
# Later cells read these entries directly as notebook globals.
assert {"mutations", "lineage_id_inv"} <= set(dataset), "unexpected dataset format"
# Expose every dataset entry as a notebook global. Use globals() explicitly:
# locals().update() only has this effect because, at the top level of a
# notebook cell, locals() happens to be the same dict as globals().
globals().update(dataset)

In [None]:
# Inverse of lineage_id_inv: lineage name -> integer index.
lineage_id = {}
for index, lineage_name in enumerate(lineage_id_inv):
    lineage_id[lineage_name] = index

In [None]:
# Load the full (non-LOO) fits onto the CPU, matching the other torch.load
# calls in this notebook, so tensors saved from a GPU run can still be read
# on a CPU-only machine.
results = torch.load("results/mutrans.pt", map_location="cpu")
for key in results:
    print(key)

In [None]:
# Use the first fit stored in results/mutrans.pt as the reference full fit.
all_fits = list(results.values())
best_fit = all_fits[0]
print(best_fit.keys())

In [None]:
# Leave-one-out fits produced by `python mutrans.py --vary-leaves=...`.
loo = torch.load("results/mutrans.vary_leaves.pt", map_location="cpu")
loo_keys = list(loo)
print(len(loo))
print(loo_keys[0])
print(loo[loo_keys[0]].keys())

In [None]:
# Sanity check: the full fit and the LOO fits should have compatible rate_loc shapes.
print(len(mutations))
print(best_fit["median"]["rate_loc"].shape)
first_loo_median = list(loo.values())[0]["median"]
print(first_loo_median["rate_loc"].shape)

In [None]:
# Map each held-out lineage name to the rate_loc of the fit that excluded it.
# The lineage pattern sits five levels deep in each vary_leaves key; it
# presumably carries regex anchors (e.g. "^B.1.1.7$" -- TODO confirm), so
# "$" and "^" are stripped to recover the plain lineage name.
loo_rate_loc = {}
for loo_key, loo_fit in loo.items():
    held_out = loo_key[-1][-1][-1][-1][-1]
    held_out = held_out.replace("$", "").replace("^", "")
    loo_rate_loc[held_out] = loo_fit["median"]["rate_loc"]
print(" ".join(loo_rate_loc))

In [None]:
# Median rate_loc of the full fit, used as the "truth" for LOO predictions.
best_median = best_fit["median"]
best_rate_loc = best_median["rate_loc"]

In [None]:
# Hand-tuned (dx, dy) label offsets, in data units, for WHO-named lineages.
# plot_prediction() does not apply them by default (labels are centred just
# above each point); pass label_offsets=WHO_LABEL_OFFSETS to draw each label
# at its offset with a faint leader line instead.
WHO_LABEL_OFFSETS = {
    "Alpha": (0.0, 0.012),
    "Beta": (-0.03, 0.03),
    "Gamma": (-0.06, 0.06),
    "Delta": (-0.03, 0.03),
    "Eta": (0.03, -0.03),
    "Iota": (0.0, -0.012),
    "Kappa": (0.04, -0.04),
    "Lambda": (0.03, -0.03),
}


def _collect_loo_points():
    """Pair each held-out lineage's full-fit and LOO rate estimates.

    Reads the notebook globals ``loo_rate_loc``, ``best_rate_loc`` and
    ``lineage_id``.

    Returns six parallel lists ``(X1, Y1, X2, Y2, labels, debug_labels)``:
    ``X1``/``Y1`` are the full-fit and LOO estimates for the lineage;
    ``X2``/``Y2`` are the same estimates minus a baseline, the LOO estimate
    of the ancestor returned by ``pangolin.get_most_recent_ancestor``;
    ``labels`` is the WHO name (or None) and ``debug_labels`` the lineage name.
    """
    X1, Y1, X2, Y2, labels, debug_labels = [], [], [], [], [], []
    # Invert pangolin.WHO_ALIASES, keyed by the first lineage listed per name.
    who = {vs[0]: k for k, vs in pangolin.WHO_ALIASES.items()}
    # The ancestor search receives a decompressed child name and its result is
    # compressed afterwards, so the candidate ancestors must be decompressed
    # too; otherwise an aliased ancestor (whose compressed and decompressed
    # names differ) could never be matched.
    ancestors = {pangolin.decompress(name) for name in lineage_id}
    for child, rate_loc in loo_rate_loc.items():
        parent = pangolin.compress(
            pangolin.get_most_recent_ancestor(pangolin.decompress(child), ancestors)
        )
        c = lineage_id[child]
        p = lineage_id[parent]
        truth = best_rate_loc[c].item()  # full-fit estimate for the child
        baseline = rate_loc[p].item()  # LOO-fit estimate for the parent
        guess = rate_loc[c].item()  # LOO-fit estimate for the held-out child
        X1.append(truth)
        Y1.append(guess)
        X2.append(truth - baseline)
        Y2.append(guess - baseline)
        labels.append(who.get(child))
        debug_labels.append(child)
    return X1, Y1, X2, Y2, labels, debug_labels


def _plot_panel(ax, X, Y, labels, debug_labels, label_offsets, debug, debug_threshold):
    """Draw one full-vs-LOO scatter panel on ``ax``.

    Adds a dashed y = x reference line, equal x/y limits with a small margin,
    the Pearson correlation near the top-left corner, and a text label for
    every point whose ``labels`` entry is not None. With ``debug``, points with
    ``|x - y| > debug_threshold`` are also annotated with their
    ``debug_labels`` entry.
    """
    X = np.asarray(X)
    Y = np.asarray(Y)
    # Opaque white underlay beneath the semi-transparent red markers.
    ax.scatter(X, Y, 40, lw=0, alpha=1, color="white", zorder=-5)
    ax.scatter(X, Y, 20, lw=0, alpha=0.5, color="darkred")
    lb = min(X.min(), Y.min())
    ub = max(X.max(), Y.max())
    d = ub - lb
    lb -= 0.03 * d
    ub += 0.05 * d
    ax.plot([lb, ub], [lb, ub], "k--", alpha=0.2, zorder=-10)
    ax.set_xlim(lb, ub)
    ax.set_ylim(lb, ub)
    ax.text(lb + 0.06 * d, ub - 0.1 * d, f"ρ = {pearson_correlation(X, Y):0.2g}",
            backgroundcolor="white")
    # Vertical gap (data units) between a point and its text. Bound up front
    # because debug annotations use it even before any labelled point is drawn.
    pad = 0.01
    for x, y, label, debug_label in zip(X, Y, labels, debug_labels):
        if label is not None:
            ax.plot([x], [y], "ko", mfc="#c77", c="black", ms=4, mew=0.5)
            if label in label_offsets:
                dx, dy = label_offsets[label]
                ax.plot([x, x + dx], [y, y + dy], "k-", lw=0.5, alpha=0.25, zorder=-1)
                ax.text(x + dx, y + dy, label, fontsize=7, alpha=0.7,
                        ha="center" if dx == 0 else "left" if dx > 0 else "right",
                        va="center" if dy == 0 else "bottom" if dy > 0 else "top")
            else:
                ax.text(x, y + pad, label, va="bottom", ha="center", fontsize=7)
        if debug and abs(x - y) > debug_threshold:
            ax.text(x, y - pad, debug_label, va="top", ha="center", fontsize=5)


def plot_prediction(filenames=(), debug=False, label_offsets=None, debug_threshold=0.2):
    """Plot full-fit vs leave-one-out (LOO) rate estimates for held-out lineages.

    The left panel compares the raw estimates. The right panel compares them
    after subtracting a baseline: the LOO estimate of each lineage's ancestor.

    Args:
        filenames: Paths to save the figure to; nothing is saved if empty.
        debug: If true, annotate points whose full and LOO estimates differ by
            more than ``debug_threshold`` with their lineage name.
        label_offsets: Optional mapping from WHO label to a ``(dx, dy)``
            offset in data units, e.g. ``WHO_LABEL_OFFSETS``. Labels not in
            the mapping (all of them by default) are centred above their point.
        debug_threshold: Minimum ``|full - LOO|`` difference annotated when
            ``debug`` is true.
    """
    if label_offsets is None:
        label_offsets = {}
    X1, Y1, X2, Y2, labels, debug_labels = _collect_loo_points()
    fig, axes = plt.subplots(1, 2, figsize=(7, 3.5))
    for ax, X, Y in zip(axes, [X1, X2], [Y1, Y2]):
        _plot_panel(ax, X, Y, labels, debug_labels, label_offsets, debug, debug_threshold)
    axes[0].set_xlabel("full estimate")
    axes[0].set_ylabel("LOO estimate")
    axes[1].set_xlabel("full estimate − baseline")
    axes[1].set_ylabel("LOO estimate − baseline")
    fig.tight_layout()
    for f in filenames:
        fig.savefig(f)
# Exploratory figure: also names lineages whose full and LOO estimates
# disagree strongly (not saved).
plot_prediction(filenames=[], debug=True)
# Publication figure, written to the paper/ directory.
paper_figure = "paper/lineage_prediction.png"
plot_prediction(filenames=[paper_figure], debug=False)