# Analyzing results of grid search

This notebook assumes you have already downloaded the data and run a grid-search experiment:
```sh
make update  # many hours
python mutrans.py --grid-search  # many hours
```

In [None]:
import math
import re
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt

# Raise the default figure resolution for every plot in this notebook.
matplotlib.rcParams.update({"figure.dpi": 200})

In [None]:
# Load the tab-separated grid-search results and replace missing cells
# with empty strings (so e.g. a blank cond_data formats as "" in labels).
df = (
    pd.read_csv("results/grid_search.tsv", sep="\t")
    .fillna("")
)
df

In [None]:
def _normalize_unit(x):
    """Affinely rescale ``x`` so its minimum maps to 0 and maximum to 1.

    A constant input (e.g. a grid search with a single run) returns all
    zeros instead of the NaNs that ``0 / 0`` would produce.
    """
    x = x - x.min()
    span = x.max()
    return np.zeros_like(x) if span == 0 else x / span


# Per-run labels and the two evaluation metrics plotted on the axes.
model_type = df["model_type"].to_list()
cond_data = df["cond_data"].to_list()
corr = df["ρ_mutation"].to_numpy()
mae = df["England B.1.1.7 MAE"].to_numpy()

# Loss: keep the raw extrema for legend labels, then log-scale and map to
# [0, 1] for use as a color value and as label transparency.
loss = df["loss"].to_numpy()
min_loss, max_loss = loss.min(), loss.max()
if not (loss > 0).all():
    raise ValueError(
        "loss must be positive to take its log; "
        "you'll need to switch to symlog or sth"
    )
loss = _normalize_unit(np.log(loss))

# Ratio R(B.1.617.2)/R(B.1.1.7), derived from the two ratios against R(A).
RB_1_1_7 = df["R(B.1.1.7)/R(A)"].to_numpy()
RB_1_617_2 = df["R(B.1.617.2)/R(A)"].to_numpy()
R = RB_1_617_2 / RB_1_1_7
# Clip in log space: the symmetric [-2, 2] bounds only make sense for log R
# (a lower bound of -2 on the ratio itself is a no-op for positive values).
log_R = np.clip(np.log(R), -2, 2)
# Extrema are taken after clipping so the legend labels match the colors
# at the two ends of the colormap.
min_R, max_R = np.exp(log_R.min()), np.exp(log_R.max())
R = _normalize_unit(log_R)

def plot_corr_vs_mae(filenames=(), colorby="R", x_min=0):
    """Scatter-plot mutation-coefficient correlation against B.1.1.7 MAE.

    Each grid-search run is one point. The point is annotated with an
    abbreviated ``model_type-cond_data`` label whose opacity fades as the
    run's normalized loss increases. This reads the module-level arrays
    ``corr``, ``mae``, ``R``, ``loss``, ``model_type`` and ``cond_data``,
    and the extrema ``min_R``/``max_R``/``min_loss``/``max_loss``. All of
    these are computed earlier in the notebook.

    Args:
        filenames: Paths to save the figure to. An empty iterable only draws it.
        colorby: ``"R"`` colors points by normalized log
            R(B.1.617.2)/R(B.1.1.7). ``"loss"`` colors them by normalized
            log loss.
        x_min: Runs with correlation below this value are hidden, and the
            x axis starts here.

    Raises:
        ValueError: If ``colorby`` is neither ``"R"`` nor ``"loss"``.
    """
    if colorby not in ("R", "loss"):
        raise ValueError(f"colorby must be 'R' or 'loss', got {colorby!r}")

    legend = {}  # abbreviation -> full name, listed as extra legend entries

    def abbreviate_param(match):
        # "name-" -> "N": drop the trailing dash, keep an upper-case initial.
        k = match.group()[:-1]
        v = k[0].upper()
        legend[v] = k
        return v

    def abbreviate_sample(match):
        # "name=" -> "n꞊": lower-case initial plus a look-alike equals sign.
        k = match.group()[:-1]
        v = k[0]
        legend[v] = k
        return v + "꞊"

    plt.figure(figsize=(8, 6))
    X = corr
    Y = mae
    mask = X >= x_min
    colors = R if colorby == "R" else loss
    plt.scatter(X[mask], Y[mask], 30, colors[mask],
                lw=0, alpha=0.8, cmap="coolwarm")
    plt.xlabel("Pearson correlation of mutation coefficients   (higher is better)")
    plt.ylabel("England B.1.1.7 portion MAE   (lower is better)")

    for x, y, mt, cd, loss_i in zip(corr, mae, model_type, cond_data, loss):
        if x < x_min:
            continue
        name = f"  {mt}-{cd}"
        name = re.sub(r"[a-z_]+-", abbreviate_param, name)
        name = re.sub(r"[a-z_]+=", abbreviate_sample, name)
        name = name.replace("-", "")
        # Fade labels of high-loss runs (loss is normalized to [0, 1]).
        plt.text(x, y, name, fontsize=7, va="center", alpha=1 - 0.666 * loss_i)

    # Invisible plots whose only purpose is to label the colormap extremes.
    if colorby == "R":
        ratio = "R(B.1.617.2)/R(B.1.1.7)"
        low_label = f"{ratio}={min_R:0.2g}"
        high_label = f"{ratio}={max_R:0.2g}"
    else:
        low_label = f"loss={min_loss:0.2g} (better)"
        high_label = f"loss={max_loss:0.2g} (worse)"
    for fmt, label in (("bo", low_label), ("ro", high_label)):
        plt.plot([], [], fmt, markeredgewidth=0, markersize=5, alpha=0.5,
                 label=label)
    for k, v in sorted(legend.items()):
        plt.plot([], [], "wo", label=f"{k} = {v}")

    plt.legend(loc="upper left", fontsize="small")
    plt.xlim(x_min, 1)
    plt.yscale("log")
    plt.tight_layout()
    for filename in filenames:
        plt.savefig(filename)
        
# Export the full plot, then a zoomed view of runs with correlation >= 0.8.
plot_corr_vs_mae(filenames=["paper/grid_search.png"])
plot_corr_vs_mae(filenames=["paper/grid_search_zoom.png"], x_min=0.8)