## Experiment to determine prior regularization `coef_scale`

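This notebook sweeps the prior scale $\sigma_3$ (`coef_scale`) of the mutation coefficients and, for each value, compares the coefficients inferred from Europe alone with those inferred from the world excluding Europe. It assumes you have already run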
```sh
make update
make preprocess
python mutrans.py --vary-coef-scale=0.01,0.02,0.05,0.1,0.2,0.5
```

In [None]:
from collections import defaultdict
import matplotlib
import matplotlib.pyplot as plt
import torch
import pyro.distributions as dist
from pyro.ops.tensor_utils import convolve
from pyrocov.util import pearson_correlation
from pyrocov.sarscov2 import GENE_TO_POSITION, GENE_STRUCTURE, aa_mutation_to_position

# Global figure styling shared by every plot saved to paper/.
_figure_style = {
    "figure.dpi": 200,
    "axes.edgecolor": "gray",
    "figure.facecolor": "white",
    "savefig.bbox": "tight",
    "font.family": "sans-serif",
    "font.sans-serif": ["Arial", "Avenir", "DejaVu Sans"],
}
for _key, _value in _figure_style.items():
    matplotlib.rcParams[_key] = _value

In [None]:
# Sweep output: one entry per (coef_scale, holdout) pair, i.e. 6 coef_scale
# values times 3 holdout settings (no holdout / include / exclude).
results = torch.load("results/mutrans.vary_coef_scale.pt")
assert len(results) == 6 * 3

In [None]:
# Reorganize the sweep as collated[coef_scale][holdout] -> {"mean", "std"} of the
# mutation coefficients, where holdout is None, "include" or "exclude".
collated = defaultdict(dict)
mutations = None
for config, result in results.items():
    # The last config entry describes the holdout; an empty entry means no holdout.
    holdout = config[-1]
    if holdout:
        holdout = holdout[0][0]
        assert holdout in ("include", "exclude")
    else:
        holdout = None
    # The first config entry is a "name=value" string; its value is coef_scale.
    coef_scale = float(config[0].split("=")[-1])
    # NOTE(review): both coef_scale and the coefficients are divided by 100,
    # presumably converting from the model's internal units — confirm in mutrans.py.
    collated[coef_scale / 100][holdout] = {
        "mean": result["mean"]["coef"] / 100,
        "std": result["std"]["coef"] / 100,
    }
    # Coefficient vectors are compared positionally across runs (Pearson
    # correlation of include vs. exclude, and plot_density's use of the global
    # `mutations`), so every run must share the same mutation ordering.
    if mutations is None:
        mutations = result["mutations"]
    else:
        assert list(result["mutations"]) == list(mutations), "mutation ordering differs across runs"

In [None]:
# Agreement between coefficients fit with the holdout region included vs. excluded,
# stored alongside the per-holdout summaries for each coef_scale.
for coef_scale, col in collated.items():
    include_mean = col["include"]["mean"]
    exclude_mean = col["exclude"]["mean"]
    col["pearson"] = pearson_correlation(include_mean, exclude_mean)

In [None]:
# Pearson correlation between Europe-only and world-without-Europe coefficients
# as a function of the prior scale sigma_3. Keys are sorted so the curve is
# monotone in x and index 2 is the same coef_scale as the dark-red panel
# (axes[2]) in the grid figures, which also iterate in sorted order.
coef_scales = sorted(collated)
pearson = [float(collated[s]["pearson"]) for s in coef_scales]
plt.figure(figsize=(6, 2.5))
plt.plot(coef_scales, pearson, zorder=-1)
# Red marker on the third-smallest coef_scale, matching the highlighted panels.
plt.plot([coef_scales[2]], [pearson[2]], "ro", zorder=0)
plt.xscale("log")
plt.ylabel("$\\rho$ = Pearson correlation")
# Escape the backslash: "\s" is an invalid escape sequence (SyntaxWarning on 3.12+).
plt.xlabel("$\\sigma_3$ = regularization strength")
plt.tight_layout(pad=0)
plt.savefig("paper/coef_scale_pearson.png")

In [None]:
# One panel per coef_scale (ascending): Europe-only coefficients (y) against
# world-without-Europe coefficients (x), annotated with their Pearson correlation.
fig, axes = plt.subplots(1, 6, figsize=(8, 1.5), sharey=True, sharex=True)
for panel, (coef_scale, col) in zip(axes, sorted(collated.items())):
    x_vals = col["exclude"]["mean"].numpy()
    y_vals = col["include"]["mean"].numpy()
    panel.scatter(x_vals, y_vals, 0.5, color="darkred", alpha=0.5)
    panel.set_title(f"$\\sigma_3$ = {coef_scale}", fontsize=9)
    panel.set_xticks(())
    panel.set_yticks(())
    # Annotation placed at data coordinates (0, 0.4).
    panel.text(0, 0.4, f"$\\rho$ = {col['pearson']:0.3g}", fontsize=9, ha="center", va="center")

# Outline the third panel in dark red to mark the highlighted coef_scale.
highlighted = axes[2]
for spine in highlighted.spines.values():
    spine.set_color("darkred")
    spine.set_linewidth(2.0)

# Invisible full-figure axes that only carries the shared axis labels.
fig.add_subplot(111, frameon=False)
plt.xticks(())
plt.yticks(())
plt.xlabel("world w/o Europe")
plt.ylabel("only Europe")
plt.subplots_adjust(wspace=0.01)
plt.savefig("paper/coef_scale_scatter.png")

In [None]:
# Volcano plots for the fit without holdout: posterior mean coefficient (x)
# against |z| = |mean / std| (y), one panel per coef_scale (ascending).
sigma = 5.0  # z-score threshold used to color and draw guide lines
# One-sided standard-normal tail probability P(Z < -sigma); not used in the plots below.
p_sig = dist.Normal(torch.zeros(()).double(), 1).cdf(torch.tensor(-sigma).double()).item()

fig, axes = plt.subplots(1, 6, figsize=(8, 1.5), sharey=True, sharex=True)
# Shared x-limits padded outward by 5% of each extreme's magnitude. Multiplying
# by 1.05 would instead shrink the range (clipping points) whenever the minimum
# is positive or the maximum is negative; in the usual case the two agree.
x_lo = min(col[None]["mean"].min().item() for col in collated.values())
x_hi = max(col[None]["mean"].max().item() for col in collated.values())
x0 = x_lo - 0.05 * abs(x_lo)
x1 = x_hi + 0.05 * abs(x_hi)
y1 = max((col[None]["mean"] / col[None]["std"]).abs().max().item() * 1.05
         for col in collated.values())
for ax, (coef_scale, col) in zip(axes, sorted(collated.items())):
    mean = col[None]["mean"]
    std = col[None]["std"]
    z = (mean / std).abs()
    # Black: positive effect with |z| above the threshold; gray: everything else.
    ok = (mean > 0) & (z > sigma)
    ax.scatter(mean[ok].numpy(), z[ok].numpy(), 0.5, color="black", alpha=0.3)
    ax.scatter(mean[~ok].numpy(), z[~ok].numpy(), 0.5, color="gray", alpha=0.3)
    ax.set_title(f"$\\sigma_3$ = {coef_scale}", fontsize=9)
    ax.set_yscale("symlog")
    ax.set_xticks(())
    ax.set_yticks(())
    ax.set_xlim(x0, x1)
    ax.set_ylim(0, y1)

    # Dashed guide bounds the black (positive, above-threshold) region;
    # the faint solid line traces the remaining threshold edges.
    ax.plot([0, 0, x1], [y1, sigma, sigma], 'k--', lw=1, alpha=0.2)
    ax.plot([x0, 0, 0], [sigma, sigma, 0], 'k-', lw=1, alpha=0.1)
# Outline the third panel in dark red to mark the highlighted coef_scale.
for spine in axes[2].spines.values():
    spine.set_linewidth(2.0)
    spine.set_color("darkred")

# Invisible full-figure axes that only carries the shared axis labels.
fig.add_subplot(111, frameon=False)
plt.xticks(())
plt.yticks(())
plt.xlabel("effect size $R_m/R_{wt}$")
plt.ylabel("statistical significance\n(z-score)")
plt.subplots_adjust(wspace=0.01)
plt.savefig("paper/coef_scale_volcano.png")

In [None]:
def plot_density(mean, ax=None, *, gene_name=None, kernel_radius=200, mutation_names=None):
    """Plot a smoothed density of ``|mean|`` along amino-acid position.

    Each mutation's absolute value is deposited at its amino-acid position
    (via ``aa_mutation_to_position``), summed per position, smoothed with a
    triangle-convolved-with-triangle kernel, and drawn as a filled curve.

    :param torch.Tensor mean: one value per mutation, aligned with
        ``mutation_names`` (asserted to have the same shape).
    :param ax: matplotlib axes to draw on; if None a new 1x1 inch figure is created.
    :param str gene_name: if given, limit the x-axis to this gene's
        ``(start, end)`` from ``GENE_TO_POSITION``; otherwise show ``[0, N)``.
    :param int kernel_radius: controls the smoothing width (the final kernel
        has about ``2 * kernel_radius + 1`` taps); 0 disables smoothing.
    :param mutation_names: mutation names aligned with ``mean``. Defaults to
        the notebook-level ``mutations`` set by the collation cell.
    """
    if mutation_names is None:
        mutation_names = mutations
    fg_position = torch.tensor(
        [aa_mutation_to_position(m) for m in mutation_names], dtype=torch.long
    )
    assert fg_position.shape == mean.shape
    # Cover every mutated position and every gene's end position.
    N = 1 + max(fg_position.max().item(),
                max(end for start, end in GENE_TO_POSITION.values()))

    half = kernel_radius / 2
    kernel = torch.cat([torch.arange(1, 1.0 + half),
                        torch.arange(1.0 + half, 0, -1)])
    kernel = convolve(kernel, kernel)  # smooth out kernel
    kernel /= kernel.sum()
    # Center of the kernel; slicing the "full" convolution from here keeps the
    # output aligned with the input. (Slicing [r:-r] is empty for r == 0 and
    # has the wrong length for odd r.)
    offset = (len(kernel) - 1) // 2

    def smooth(signal):
        full = convolve(kernel, signal)
        result = full[offset:offset + N]
        assert len(result) == N
        return result

    # Match mean's dtype so scatter_add_ also works for float64 inputs.
    foreground = torch.zeros(N, dtype=mean.dtype).scatter_add_(0, fg_position, mean.abs())
    if ax is None:
        plt.figure(figsize=(1, 1))
        ax = plt.gca()
    X = torch.arange(N)
    Y0 = torch.zeros_like(X)
    Y_fg = smooth(foreground)
    ax.fill_between(X, Y0, Y_fg, lw=0.1, color="#005")
    ax.set_xticks(())
    ax.set_yticks(())
    ax.set_ylim(0, None)
    if gene_name:
        start, end = GENE_TO_POSITION[gene_name]
        ax.set_xlim(start, end)
    else:
        ax.set_xlim(0, N)
# Quick visual check on the first run: whole genome, then zoomed to the N gene.
first_coef = next(iter(results.values()))["mean"]["coef"]
plot_density(first_coef)
plot_density(first_coef, gene_name="N")

In [None]:
# Top row: smoothed |coefficient| density over the whole genome; bottom row:
# the N gene only. Columns are coef_scale values in ascending order.
fig, axes = plt.subplots(2, 6, figsize=(8, 2.5), dpi=300)
genome_row, n_gene_row = axes
for i, (coef_scale, col) in enumerate(sorted(collated.items())):
    mean = col[None]["mean"]
    plot_density(mean, genome_row[i], kernel_radius=400)
    plot_density(mean, n_gene_row[i], kernel_radius=80, gene_name="N")

# Outline column 2 across both rows as one dark-red box: the spine between the
# two rows (top row's bottom, bottom row's top) is left untouched.
highlight_spines = (
    (genome_row[2], ("left", "top", "right")),
    (n_gene_row[2], ("left", "right", "bottom")),
)
for ax, names in highlight_spines:
    for name in names:
        spine = ax.spines[name]
        spine.set_linewidth(2.0)
        spine.set_color("darkred")

# Invisible full-figure axes that only carries the shared x label.
fig.add_subplot(111, frameon=False)
plt.xticks(())
plt.yticks(())
plt.xlabel("aa position within genome (top) or N gene (bottom)")
genome_row[0].set_ylabel("whole genome")
n_gene_row[0].set_ylabel("N gene")
plt.subplots_adjust(hspace=0, wspace=0.01)
plt.savefig("paper/coef_scale_manhattan.png")