# Analyzing results of grid search

This notebook assumes you have already downloaded the data and run a grid-search experiment:
```sh
make update  # many hours
python mutrans.py --grid-search  # many hours
```

In [None]:
import math
import re
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from pyrocov.util import pearson_correlation
from pyrocov.plotting import force_apart

# Render all figures in this notebook at 200 dpi.
matplotlib.rcParams["figure.dpi"] = 200

In [None]:
df = pd.read_csv("results/grid_search.tsv", sep="\t")
# Replace missing values with "" only in text-like columns (e.g. an empty
# cond_data, which would otherwise render as "nan" in point labels below).
# A blanket df.fillna("") would also turn any numeric column that has a
# missing value into an object column mixing floats and "", which breaks the
# arithmetic and plotting in later cells. All-NaN columns are still blanked,
# since read_csv parses a fully empty text column as float.
text_cols = [
    c for c in df.columns
    if not pd.api.types.is_numeric_dtype(df[c]) or df[c].isna().all()
]
if text_cols:
    df[text_cols] = df[text_cols].fillna("")
df

In [None]:
# Show the column names available in the grid-search table.
df.columns

In [None]:
# Per-run labels and metrics, one entry per row of the grid-search table.
model_type = df["model_type"].to_list()
cond_data = df["cond_data"].to_list()
mutation_corr = df["mutation_corr"].to_numpy()
# RMSE of the mutation coefficients divided by the mutation_stddev column.
mutation_error = df["mutation_rmse"].to_numpy() / df["mutation_stddev"].to_numpy()
mae_pred = df["England B.1.1.7 MAE"].to_numpy()

loss_raw = df["loss"].to_numpy()
min_loss, max_loss = loss_raw.min(), loss_raw.max()  # raw range, shown in the legend
assert (loss_raw > 0).all(), "you'll need to switch to symlog or sth"
# Rescale log-loss to [0, 1]; it drives point colors and label transparency.
log_loss = np.log(loss_raw)
log_span = log_loss.max() - log_loss.min()
if log_span > 0:
    loss = (log_loss - log_loss.min()) / log_span
else:
    # Every run has the same loss (e.g. a single run): avoid 0/0 = NaN, which
    # would give NaN colors and an invalid text alpha when plotting.
    loss = np.zeros_like(log_loss)
R_alpha = df["R(B.1.1.7)/R(A)"].to_numpy()
R_delta = df["R(B.1.617.2)/R(A)"].to_numpy()

def plot_concordance(filenames=(), colorby="R"):
    """Plot grid-search runs on two concordance panels and optionally save the figure.

    Top panel (log-log): ``mutation_error`` against ``mae_pred`` (England
    α-portion MAE). Bottom panel: ``R_alpha`` against ``R_delta`` with a dashed
    y = x reference line. Each run is a point colored by the normalized
    log-loss ``loss`` (blue = lowest, red = highest) and labeled with an
    abbreviated ``"{model_type}-{cond_data}"`` name. The abbreviations are
    listed in the top panel's legend.

    Reads the notebook-level globals ``mutation_error``, ``mae_pred``,
    ``R_alpha``, ``R_delta``, ``loss``, ``min_loss``, ``max_loss``,
    ``model_type`` and ``cond_data``.

    :param filenames: Paths to save the finished figure to; nothing is saved
        when empty.
    :param colorby: Unused; kept so existing calls remain valid.
    """
    legend = {}  # abbreviation -> full name, collected while labeling points

    def abbreviate_param(match):
        # "name-" -> "N" (upper-case initial).
        # NOTE: names sharing an initial share one legend entry; the last one wins.
        k = match.group()[:-1]
        v = k[0].upper()
        legend[v] = k
        return v

    def abbreviate_sample(match):
        # "name=" -> "n꞊" (lower-case initial plus a modifier-letter equals sign).
        k = match.group()[:-1]
        v = k[0]
        legend[v] = k
        return v + "꞊"

    def label(mt, cd):
        """Abbreviate a "{model_type}-{cond_data}" run name for on-plot text."""
        name = f"{mt}-{cd}"
        name = re.sub("[a-z_]+-", abbreviate_param, name)
        name = re.sub("[a-z_]+=", abbreviate_sample, name)
        return name.replace("-", "")

    panels = [
        # Alternative x-axis: mutation_corr ("Pearson correlation of mutations").
        (
            mutation_error,
            mae_pred,
            "Cross-validation error of mutation coefficients   (lower is better)",
            "England α portion MAE   (lower is better)",
        ),
        (R_alpha, R_delta, "R(α) / R(A)", "R(δ) / R(A)"),
    ]
    fig, axes = plt.subplots(2, figsize=(8, 12))
    for ax, (X, Y, xlabel, ylabel) in zip(axes, panels):
        ax.scatter(X, Y, s=30, c=loss, lw=0, alpha=0.8, cmap="coolwarm")
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)

        # Labels go at the positions returned by force_apart (presumably spread
        # out to reduce overlap); a faint leader line joins each label position
        # to its data point, with None entries separating the segments.
        X_, Y_ = force_apart(X, Y, stepsize=2)
        assert X_.dim() == 1
        X_X = []
        Y_Y = []
        for x_, x, y_, y in zip(X_, X, Y_, Y):
            X_X.extend([float(x_), float(x), None])
            Y_Y.extend([float(y_), float(y), None])
        ax.plot(X_X, Y_Y, "k-", lw=0.5, alpha=0.5, zorder=-10)
        for x, y, mt, cd, l in zip(X_, Y_, model_type, cond_data, loss):
            # Fade the labels of higher-loss (worse) runs.
            ax.text(x, y, label(mt, cd), fontsize=7, va="center", alpha=1 - 0.666 * l)

    axes[0].set_xscale("log")
    axes[0].set_yscale("log")
    # Proxy artists: endpoints of the loss color scale plus the abbreviations.
    axes[0].plot([], [], "bo", markeredgewidth=0, markersize=5, alpha=0.5,
                 label=f"loss={min_loss:0.2g} (better)")
    axes[0].plot([], [], "ro", markeredgewidth=0, markersize=5, alpha=0.5,
                 label=f"loss={max_loss:0.2g} (worse)")
    for k, v in sorted(legend.items()):
        axes[0].plot([], [], "wo", label=f"{k} = {v}")
    axes[0].legend(loc="upper right", fontsize="small")

    # Dashed y = x reference on the R panel over the range covered by both axes.
    # R_alpha/R_delta are named explicitly instead of reusing leaked loop variables.
    lo = max(R_alpha.min(), R_delta.min())
    hi = min(R_alpha.max(), R_delta.max())
    axes[1].plot([lo, hi], [lo, hi], "k--", alpha=0.2, zorder=-10)
    fig.subplots_adjust(hspace=0.15)
    for filename in filenames:
        fig.savefig(filename)

plot_concordance(["paper/grid_search.png"])

In [None]:
import torch
# Load the per-configuration holdout fits used by plot_mutation_agreements
# (presumably written by `python mutrans.py --grid-search` — confirm).
# Expected structure (from how it is used below): {name: {holdout: fit}}, where
# each fit has "mutations" and "coef" entries.
# NOTE(review): torch.load unpickles the file; only load results you trust.
grid = torch.load("results/mutrans.grid.pt")

def plot_mutation_agreements(grid):
    """Plot pairwise agreement of mutation coefficients between holdout fits.

    One row per grid entry (sorted by key) and one column per pair of that
    entry's three holdout fits. Each panel scatters the coefficients of the
    mutations the two fits have in common against each other, draws a dashed
    y = x reference line, and is titled with their Pearson correlation ρ. The
    middle column's title also carries the entry name.

    :param dict grid: Maps an entry name to a dict of exactly three holdout
        fits. Each fit is a dict with a ``"mutations"`` sequence and a
        ``"coef"`` tensor indexed in the same order.
    :raises ValueError: If an entry does not have exactly three holdout fits.
    """
    # squeeze=False keeps ``axes`` 2-D even for a single-entry grid, so each
    # row can always be indexed as axe[0], axe[1], axe[2].
    fig, axes = plt.subplots(
        len(grid), 3, figsize=(8, 1 + 3 * len(grid)), squeeze=False
    )
    for axe, (name, holdouts) in zip(axes, sorted(grid.items())):
        fits = list(holdouts.values())
        if len(fits) != 3:
            raise ValueError(f"expected 3 holdout fits for {name!r}, got {len(fits)}")
        # Columns compare fits (0, 1), (0, 2) and (1, 2).
        pairs = [(fits[0], fits[1]), (fits[0], fits[2]), (fits[1], fits[2])]

        # y = x reference range shared by the row, padded by 5% on each side.
        # Computed from the same (unscaled) coefficients that are plotted.
        lo = min(fit["coef"].min().item() for fit in fits)
        hi = max(fit["coef"].max().item() for fit in fits)
        lb = 1.05 * lo - 0.05 * hi
        ub = 1.05 * hi - 0.05 * lo

        axe[0].set_ylabel(str(name).replace("-", "\n").replace(",", "\n"), fontsize=8)
        for col, (ax, (fit_a, fit_b)) in enumerate(zip(axe, pairs)):
            shared = sorted(set(fit_a["mutations"]) & set(fit_b["mutations"]))
            coefs = []
            for fit in (fit_a, fit_b):
                m_to_i = {m: i for i, m in enumerate(fit["mutations"])}
                idx = torch.tensor([m_to_i[m] for m in shared], dtype=torch.long)
                # Keep only the shared mutations, in a common order, so the two
                # coefficient vectors are aligned element by element.
                coefs.append(fit["coef"][idx])
            ax.plot([lb, ub], [lb, ub], "k--", alpha=0.3, zorder=-100)
            ax.scatter(coefs[1].numpy(), coefs[0].numpy(), 30, alpha=0.3, lw=0, color="darkred")
            ax.axis("equal")
            rho = "ρ = {:0.2g}".format(float(pearson_correlation(coefs[0], coefs[1])))
            # The middle panel also shows the entry name; setting it separately
            # would be overwritten by the ρ title.
            ax.set_title(f"{name}\n{rho}" if col == 1 else rho)
plot_mutation_agreements(grid)

## Debugging plotting code

In [None]:
from pyrocov.plotting import force_apart
# Visual sanity check of force_apart on 200 seeded random points:
#   black dots  -- original positions,
#   red dots    -- positions returned by force_apart, drawn 8 times with
#                  small x-offsets (0, 0.05, ..., 0.35),
#   thin lines  -- connect each returned position to its original point.
torch.manual_seed(1234567890)
X, Y = torch.randn(2, 200)
X_, Y_ = force_apart(X, Y)
plt.plot(X, Y, "ko")
for offset in range(8):
    plt.plot(X_ + offset / 20, Y_, "r.")
X_X, Y_Y = [], []
for moved_x, orig_x, moved_y, orig_y in zip(X_, X, Y_, Y):
    # None separates consecutive two-point segments in a single plot call.
    X_X += [float(moved_x), float(orig_x), None]
    Y_Y += [float(moved_y), float(orig_y), None]
plt.plot(X_X, Y_Y, "k-", lw=0.5, alpha=0.5, zorder=-10);