# Predictive accuracy of mutrans model on new lineages

This notebook assumes you have run the following commands:
```sh
make update
make preprocess
python mutrans.py --vary-leaves
```

In [None]:
import pickle
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import torch
from pyrocov import pangolin
from pyrocov.util import pearson_correlation, quotient_central_moments

# Notebook-wide matplotlib defaults: high-resolution figures on a white
# background, gray axes, tight bounding boxes when saving, sans-serif fonts.
matplotlib.rcParams.update(
    {
        "figure.dpi": 200,
        "axes.edgecolor": "gray",
        "figure.facecolor": "white",
        "savefig.bbox": "tight",
        "font.family": "sans-serif",
        # Fonts are tried in order; later entries act as fallbacks.
        "font.sans-serif": ["Arial", "Avenir", "DejaVu Sans"],
    }
)

In [None]:
# Load the preprocessed dataset, mapping any stored tensors onto the CPU.
# NOTE(review): the file name presumably encodes the preprocessing options
# used by `make preprocess` / `mutrans.py` -- confirm it matches your run.
dataset = torch.load("results/mutrans.data.single.3000.1.50.None.pt", map_location="cpu")
# Show which entries the dataset provides.
print(dataset.keys())
# Expose every dataset entry as a notebook-level variable (later cells use
# e.g. `mutations` without defining it). This relies on the cell running at
# top level, where locals() is the same namespace as globals().
locals().update(dataset)

In [None]:
# Pull out explicitly the two dataset entries that later cells depend on.
lineage_id, clade_id_to_lineage_id = (
    dataset["lineage_id"],
    dataset["clade_id_to_lineage_id"],
)

In [None]:
# Load the results written by `mutrans.py --vary-leaves`: a mapping from a
# fit configuration to that fit's results.
loo = torch.load("results/mutrans.vary_leaves.pt", map_location="cpu")
print(len(loo))
# Peek at the last entry: its configuration key and the available result fields.
_last_config, _last_fit = list(loo.items())[-1]
print(_last_config)
print(_last_fit.keys())

In [None]:
# Sanity check: compare the number of mutations in the dataset with the shape
# of a fitted `rate_loc`.
# `mutations` is read from `dataset` explicitly rather than relying on the
# implicit `locals().update(dataset)` injection.
print(len(dataset["mutations"]))
# Only the first fit is needed, so avoid materializing the full list of values.
print(next(iter(loo.values()))["median"]["rate_loc"].shape)

In [None]:
# Split the fits into the full-data fit (empty holdout) and the per-lineage
# leave-one-out fits, keyed by the held-out lineage's name.
best_rate_loc = None
loo_rate_loc = {}
for config, fit in loo.items():
    # NOTE(review): element [1] of quotient_central_moments' result is
    # presumably the per-lineage mean of the clade-level rates -- confirm.
    lineage_rate = quotient_central_moments(
        fit["median"]["rate_loc"], clade_id_to_lineage_id
    )[1]
    heldout = config[-1]  # the last component of the key describes the holdout
    if not heldout:
        # Nothing was held out: this is the reference fit on all data.
        best_rate_loc = lineage_rate
        continue
    # The lineage name is the innermost last element of the holdout spec;
    # strip the "$" and "^" characters to recover the plain name.
    raw_name = heldout[-1][-1][-1][-1]
    lineage_name = raw_name.replace("$", "").replace("^", "")
    loo_rate_loc[lineage_name] = lineage_rate
print(" ".join(loo_rate_loc))

In [None]:
def plot_prediction(filenames=(), debug=False, use_who=True):
    """Scatter-plot leave-one-out (LOO) estimates against full-data estimates.

    For every held-out lineage in the notebook-level ``loo_rate_loc`` this
    compares three numbers:

    * ``truth``: the lineage's value in ``best_rate_loc`` (fit on all data),
    * ``guess``: the lineage's value in the fit where it was held out,
    * ``baseline``: the value, in that same held-out fit, of the lineage's
      most recent ancestor among the known lineages.

    Two panels are drawn: raw ``guess`` vs ``truth``, and the same after
    subtracting ``baseline`` from both. Each panel shows the Pearson
    correlation. The mean absolute difference between ``truth`` and
    ``baseline`` is printed.

    Args:
        filenames: Iterable of paths to save the figure to (may be empty).
        debug: Currently unused; accepted for backward compatibility.
        use_who: If true, lineages with a WHO alias are labeled by that
            alias; otherwise they are labeled by their lineage name.

    Reads the notebook globals ``lineage_id``, ``loo_rate_loc`` and
    ``best_rate_loc``.
    """
    outlier_threshold = 0.2  # unlabeled points farther than this from y=x get a label
    label_pad = 0.012  # vertical offset of a label above its marker

    X1, Y1, X2, Y2, labels, debug_labels = [], [], [], [], [], []
    # Map the first lineage listed for each WHO name back to that WHO name.
    who = {vs[0]: k for k, vs in pangolin.WHO_ALIASES.items()}
    ancestors = set(lineage_id)
    for child, rate_loc in loo_rate_loc.items():
        parent = pangolin.compress(
            pangolin.get_most_recent_ancestor(
                pangolin.decompress(child), ancestors
            )
        )
        c = lineage_id[child]
        p = lineage_id[parent]
        truth = best_rate_loc[c].item()
        baseline = rate_loc[p].item()
        guess = rate_loc[c].item()
        Y1.append(truth)
        X1.append(guess)
        Y2.append(truth - baseline)
        X2.append(guess - baseline)
        labels.append(who.get(child))  # None when the lineage has no WHO alias
        debug_labels.append(child)

    mae = np.abs(np.array(Y2)).mean()
    print(f"MAE(baseline - full estimate) = {mae:0.4g}")

    fig, axes = plt.subplots(1, 2, figsize=(7, 3.5))
    for ax, X, Y in zip(axes, [X1, X2], [Y1, Y2]):
        X = np.array(X)
        Y = np.array(Y)
        # A white halo underneath keeps overlapping translucent points legible.
        ax.scatter(X, Y, 40, lw=0, alpha=1, color="white", zorder=-5)
        ax.scatter(X, Y, 20, lw=0, alpha=0.3, color="darkred")
        # Use identical, slightly padded limits on both axes so that the
        # dashed y = x reference line is the diagonal of the panel.
        lb = min(X.min(), Y.min())
        ub = max(X.max(), Y.max())
        d = ub - lb
        lb -= 0.03 * d
        ub += 0.05 * d
        ax.plot([lb, ub], [lb, ub], "k--", alpha=0.2, zorder=-10)
        ax.set_xlim(lb, ub)
        ax.set_ylim(lb, ub)
        rho = pearson_correlation(X, Y)
        ax.text(0.3 * lb + 0.7 * ub, 0.8 * lb + 0.2 * ub,
                f" ρ = {rho:0.3f}",
                backgroundcolor="white", ha="center", va="center")
        for x, y, label, debug_label in zip(X, Y, labels, debug_labels):
            if label is not None:
                # WHO-named lineages are always highlighted.
                text = label if use_who else debug_label
            elif abs(x - y) > outlier_threshold:
                # Other lineages are highlighted only when far from y = x.
                text = debug_label
            else:
                continue
            # The "ko" format string already sets a black edge color, so no
            # separate color keyword is passed.
            ax.plot([x], [y], "ko", mfc="#c77", ms=4, mew=0.5)
            ax.text(x, y + label_pad, text,
                    va="bottom", ha="center", fontsize=6)
    axes[0].set_ylabel("full estimate")
    axes[0].set_xlabel("LOO estimate")
    axes[1].set_ylabel("full estimate − baseline")
    axes[1].set_xlabel("LOO estimate − baseline")
    plt.tight_layout()
    for f in filenames:
        # Save this figure explicitly rather than whichever figure is current.
        fig.savefig(f)
# Inline preview, labeling WHO-named lineages by their WHO alias.
# (`debug` is accepted by plot_prediction but not read by it.)
plot_prediction(debug=True)
# Paper figure: label lineages by lineage name and save to disk.
plot_prediction(filenames=["paper/lineage_prediction.png"], use_who=False)