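# Predictive accuracy of the mutrans model on new lineages

For each held-out lineage, this notebook compares its rate estimate from the full fit with the rate predicted from its mutation features using leave-one-out (LOO) coefficients. It makes the comparison both directly and relative to the full-fit estimate of the lineage's parent.

This notebook assumes you have already run:
```sh
make update  # downloads and preprocesses data
python mutrans.py --vary-leaves=50  # or another reasonably large number
```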

In [None]:
import pickle
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import torch
from pyrocov import pangolin
from pyrocov.util import pearson_correlation

# Figure defaults for the whole notebook: high-resolution output, tight
# bounding boxes when saving, grey axes, and a sans-serif font preferring
# Arial, then Avenir, then DejaVu Sans.
matplotlib.rcParams.update(
    {
        "figure.dpi": 200,
        "axes.edgecolor": "gray",
        "savefig.bbox": "tight",
        "font.family": "sans-serif",
        "font.sans-serif": ["Arial", "Avenir", "DejaVu Sans"],
    }
)

In [None]:
# Load the preprocessed dataset onto the CPU.
dataset = torch.load("results/mutrans.data.single.None.pt", map_location="cpu")
print(dataset.keys())
# Expose every dataset entry as a notebook global. At notebook top level this
# is the same namespace as `locals()`. Later cells presumably get
# `lineage_id_inv`, `mutations` and `features` from here (they are not defined
# elsewhere in the visible cells).
globals().update(dataset)

In [None]:
# Inverse of `lineage_id_inv`: lineage name -> integer index.
lineage_id = dict((lineage, index) for index, lineage in enumerate(lineage_id_inv))

In [None]:
# Fitted model results. map_location="cpu" lets results saved from a GPU run
# load on a CPU-only machine, matching how the dataset file is loaded.
results = torch.load("results/mutrans.pt", map_location="cpu")
for key in results:
    print(key)

In [None]:
# Use the first fit stored in `results` as the reference fit.
# NOTE(review): assumes the preferred fit is stored first — confirm.
best_fit = [*results.values()][0]
print(best_fit.keys())

In [None]:
# Leave-one-out fits written by `mutrans.py --vary-leaves=...`.
# map_location="cpu" lets fits saved from a GPU run load on a CPU-only machine.
loo = torch.load("results/mutrans.vary_leaves.pt", map_location="cpu")
print(len(loo))
print(list(loo)[0])
print(list(loo.values())[0].keys())

In [None]:
# Sanity-check sizes: the number of mutations, the feature matrix shape, and
# the shape of `rate_loc` in the full fit vs. the first leave-one-out fit.
print(len(mutations))
print(features.shape)
print(best_fit["median"]["rate_loc"].shape)
print(list(loo.values())[0]["median"]["rate_loc"].shape)

In [None]:
# Map each held-out lineage name to the `coef` (the "median" summary) of its
# leave-one-out fit. The name is taken from the last item five levels deep in
# the LOO key, with any "$" and "^" characters removed.
loo_coef = dict(
    (
        held_out_key[-1][-1][-1][-1][-1].translate(str.maketrans("", "", "$^")),
        loo_fit["median"]["coef"],
    )
    for held_out_key, loo_fit in loo.items()
)
print(list(loo_coef))

In [None]:
# Full-fit `rate_loc` ("median" summary), indexed by lineage id. It is used
# below both as each held-out lineage's reference value and as its parent's
# baseline.
best_rate_loc = best_fit["median"]["rate_loc"]

In [None]:
def plot_prediction(filenames=()):
    """Scatter full-fit rate estimates against leave-one-out (LOO) predictions.

    For every held-out lineage in the notebook global ``loo_coef``:

    * left panel: the lineage's ``rate_loc`` from the full fit
      (``best_rate_loc``) against the rate predicted from its mutation
      ``features`` using the LOO coefficients ``loo_coef[child]``;
    * right panel: the same two quantities, each minus the full-fit
      ``rate_loc`` of the lineage's parent.

    Each panel draws the y = x diagonal on square limits and annotates the
    Pearson correlation.

    Reads notebook globals ``loo_coef``, ``lineage_id``, ``best_rate_loc`` and
    ``features``.

    Args:
        filenames: Iterable of paths to save the figure to. Defaults to an
            empty tuple, so nothing is saved.
    """
    X1, Y1 = [], []  # full estimate vs. LOO estimate
    X2, Y2 = [], []  # both relative to the parent's full estimate
    for child, coef in loo_coef.items():
        parent = pangolin.compress(pangolin.get_parent(pangolin.decompress(child)))
        c = lineage_id[child]
        p = lineage_id[parent]
        truth = best_rate_loc[c].item()
        naive = best_rate_loc[p].item()
        # NOTE(review): 0.01 is a hard-coded scale on the feature/coefficient
        # dot product; presumably it mirrors the model's rate parameterization
        # in mutrans.py — confirm.
        guess = 0.01 * torch.dot(features[c], coef).item()
        X1.append(truth)
        Y1.append(guess)
        X2.append(truth - naive)
        Y2.append(guess - naive)

    fig, axes = plt.subplots(1, 2, figsize=(7, 3.5))
    for ax, X, Y in zip(axes, [X1, X2], [Y1, Y2]):
        X = np.array(X)
        Y = np.array(Y)
        # A white halo under semi-transparent markers keeps overlaps readable.
        ax.scatter(X, Y, 50, lw=0, alpha=1, color="white")
        ax.scatter(X, Y, 30, lw=0, alpha=0.5, color="darkred")
        # Shared, square limits padded by 3% so the diagonal is meaningful.
        lb = min(min(X), min(Y))
        ub = max(max(X), max(Y))
        d = ub - lb
        lb -= 0.03 * d
        ub += 0.03 * d
        ax.plot([lb, ub], [lb, ub], "k--", alpha=0.2, zorder=-10)
        ax.set_xlim(lb, ub)
        ax.set_ylim(lb, ub)
        ax.text(lb + 0.06 * d, ub - 0.1 * d, f"ρ = {pearson_correlation(X, Y):0.2g}")
    axes[0].set_xlabel("full estimate")
    axes[0].set_ylabel("LOO estimate")
    axes[1].set_xlabel("full estimate − parent estimate")
    axes[1].set_ylabel("LOO estimate − parent estimate")
    # Operate on this figure explicitly rather than pyplot's "current figure".
    fig.tight_layout()
    for f in filenames:
        fig.savefig(f)
# Render the LOO comparison and save it to the paper's figure directory.
plot_prediction(("paper/lineage_prediction.png",))