# Ranking mutations by transmissibility

This assumes you've acquired GISAID data and run
```sh
make update               # ~15 minutes on CPU
python rank_mutations.py  # ~10 minutes on GPU
```

In [None]:
import math
import pickle
from collections import Counter
import matplotlib
import matplotlib.pyplot as plt
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.distributions import constraints
from pyrocov import pangolin

# Render all figures in this notebook at 200 dots per inch.
matplotlib.rcParams.update({'figure.dpi': 200})

## Loading data

In [None]:
# Amino-acid feature data; `aa_features['mutations']` is used below as the list
# of mutation names. Map tensors onto CPU (as the rank_results load does) so
# this cell does not depend on the device the file was saved from.
aa_features = torch.load("results/nextclade.features.pt", map_location="cpu")
print(aa_features.keys())

In [None]:
# Results written by rank_mutations.py (run on GPU per the instructions above);
# map every tensor onto CPU so the analysis below can run anywhere.
rank_results = torch.load(f="results/rank_mutations.pt", map_location="cpu")
print(rank_results.keys())

## Assessing model accuracy

The inference approach in `rank_mutations.py` is to:
1. Fit a mean-field variational model via stochastic variational inference (SVI).
2. Initially rank mutations by `|mean|/stddev` of their growth rate coefficients `log_rate_coef`.
3. Fit MAP model parameters via SVI.
4. For each of the most positive and most negative mutations, fit MAP model parameters with that feature removed.
5. Rank features by the change in loss between the base model and the model with that feature removed.

Let's examine the initial versus final ranking metric.

In [None]:
initial_ranks = rank_results["initial_ranks"]
# Feature ids that were refit with the feature dropped, paired (in dict order)
# with the MAP log-likelihood ratio of dropping each one.
feature_ids = torch.tensor([*rank_results["log_likelihood"]])
log_likelihood = torch.tensor([*rank_results["log_likelihood"].values()])
# Order those features from largest to smallest log-likelihood ratio.
log_likelihood, ind = torch.sort(log_likelihood, dim=0, descending=True)
feature_ids = feature_ids[ind]
# Gather the per-feature statistics in that same order.
log_rate_coef = rank_results["log_rate_coef"][feature_ids]
mean = initial_ranks["mean"][feature_ids]
# Signed mean/std ratio of the variational posterior (sign follows `mean`).
sigma = mean / initial_ranks["std"][feature_ids]

In [None]:
plt.figure(figsize=(5, 5))
# Colour by the sign of the MAP coefficient: red if > 0, blue otherwise.
pos = log_rate_coef > 0
plt.scatter(x=mean[pos], y=log_likelihood[pos], s=1, color="red")
plt.scatter(x=mean[~pos], y=log_likelihood[~pos], s=1, color="blue")
plt.xlabel("variational posterior mean")
plt.ylabel("MAP log likelihood ratio of dropping feature")
plt.title("Initial versus final mutation metric");

In [None]:
# Compare the initial ranking score |mean|/std against the final MAP
# log-likelihood ratio, coloured by the sign of the MAP coefficient
# (red: > 0, blue: <= 0).
pos = log_rate_coef > 0
plt.figure(figsize=(5, 5))
# `sigma` carries the sign of the variational mean, which need not agree with
# the sign of the MAP coefficient used for colouring; take abs() on both
# subsets so every point matches the "|mean|/std" axis label.
plt.scatter(sigma[pos].abs(), log_likelihood[pos], 1, color="red")
plt.scatter(sigma[~pos].abs(), log_likelihood[~pos], 1, color="blue")
plt.xlabel("variational posterior |mean|/std")
plt.ylabel("MAP log likelihood ratio of dropping feature")
plt.title("Initial versus final mutation metric");

In [None]:
def plot_features(num_labels=50):
    """Plot every mutation's MAP growth-rate coefficient, sorted ascending.

    The ``num_labels`` smallest coefficients are labelled in blue to the left
    of the curve and the ``num_labels`` largest in red to the right, with thin
    leader lines from each point to its label.

    Args:
        num_labels: number of mutations to label at each end (default 50).
            Clamped to the number of mutations so the two label loops never
            index past the ends of the sorted array.

    Reads the module-level ``aa_features`` and ``rank_results``.
    """
    mutations = aa_features['mutations']
    xs, idx = rank_results["log_rate_coef"].sort(0)
    assert len(idx) == len(mutations)
    plt.figure(figsize=(6, 6))
    plt.title("Regression coefficients (mutations)")
    plt.plot(xs, 'k.', lw=0, markersize=1, zorder=10)
    plt.axhline(0, color='black', lw=0.5, linestyle='--', alpha=0.5)
    plt.xlabel(f"rank among {len(xs)} mutations")
    plt.ylabel("increased transmissibility")

    num_mutations = len(idx)
    y0 = float(xs.min())
    y1 = float(xs.max())
    N = min(num_labels, num_mutations)
    # Labels are spread evenly over [y0, y1]; avoid dividing by zero when a
    # single label is drawn per side.
    steps = max(N - 1, 1)
    # Most negative coefficients: labels stacked bottom-to-top on the left.
    for i in range(N):
        x = -num_mutations / 8
        y = y0 + (y1 - y0) * i / steps
        plt.plot([i, x], [xs[i], y], color='blue', lw=0.3)
        plt.text(x, y, mutations[int(idx[i])] + " ", fontsize=5, color='blue',
                 verticalalignment="center", horizontalalignment="right")
    # Most positive coefficients: labels stacked so the largest is at the top.
    for i in range(num_mutations - N, num_mutations):
        x = num_mutations + num_mutations / 8
        y = y1 + (y0 - y1) * (num_mutations - i - 1) / steps
        plt.plot([i, x], [xs[i], y], color='red', lw=0.3)
        plt.text(x, y, " " + mutations[int(idx[i])], fontsize=5, color='red',
                 verticalalignment="center", horizontalalignment="left")
    plt.ylim(y0 - (y1 - y0) / 40, y1 + (y1 - y0) / 40)
    plt.xlim(-0.35 * num_mutations, 1.35 * num_mutations)
    plt.xticks(())

plot_features()

In [None]:
def plot_volcano(mean, std, num_labels=50):
    """Draw a volcano plot of per-mutation variational posterior statistics.

    Each mutation is a point at x = posterior mean, y = |mean| / std, on a
    symlog y axis. Within each side (mean < 0, mean > 0) mutations are ordered
    by decreasing |mean| / std. The first ``num_labels`` per side are drawn in
    black and labelled: blue labels at the left edge, red at the right edge.
    The remaining points are drawn in grey.

    Args:
        mean: posterior means, one per entry of ``aa_features['mutations']``
            (assumed 1-D; the length is asserted).
        std: posterior standard deviations, elementwise-compatible with
            ``mean``.
        num_labels: number of mutations to label per side (default 50).

    Reads the module-level ``aa_features``.
    """
    mutations = aa_features['mutations']
    xs = mean
    ys = mean.abs() / std
    assert len(xs) == len(mutations)
    x0, x1 = float(xs.min()), float(xs.max())
    ys, idx = ys.sort(0, descending=True)
    xs = xs[idx]
    # Split by sign of the mean; zeros and infinities belong to neither side.
    pos = (0 < xs) & (xs < math.inf)
    neg = (-math.inf < xs) & (xs < 0)
    ys_pos, ys_neg = ys[pos], ys[neg]
    xs_pos, xs_neg = xs[pos], xs[neg]
    idx_pos, idx_neg = idx[pos], idx[neg]
    N = num_labels

    plt.figure(figsize=(6, 6))
    plt.title(f"Increased transmissibility of {len(mutations)} mutations")
    for xs_side, ys_side in [(xs_pos, ys_pos), (xs_neg, ys_neg)]:
        # Labelled top-N in black, the rest in grey. The grey call's format
        # string holds only the marker so it does not conflict with color=.
        plt.plot(xs_side[:N], ys_side[:N], 'k.', lw=0, markersize=1, zorder=10)
        plt.plot(xs_side[N:], ys_side[N:], '.', lw=0, markersize=1, zorder=10,
                 color="#aaa")
    plt.xlabel("variational posterior mean")
    plt.yscale("symlog")
    plt.ylabel("variational posterior |mean| / stddev")
    plt.xlim(x0 - (x1 - x0) * 0.18, x1 + (x1 - x0) * 0.18)
    plt.ylim(0, None)
    plt.xticks((0,))
    yticks = (0, 1, 2, 5, 10, 20, 50)
    plt.yticks(yticks, list(map(str, yticks)))
    ax = plt.gca()
    for spine in ax.spines.values():
        spine.set_linewidth(0.5)

    # Maps axes-fraction coordinates back to data coordinates (through the
    # symlog scale), so labels are evenly spaced vertically on screen.
    t = (ax.transScale + ax.transLimits).inverted()
    sides = [
        # (label column x, xs, ys, mutation indices, colour, text alignment)
        (x0, xs_neg, ys_neg, idx_neg, 'blue', "right"),
        (x1, xs_pos, ys_pos, idx_pos, 'red', "left"),
    ]
    for x, xs_side, ys_side, idx_side, color, align in sides:
        # A side may hold fewer than N mutations; label only those present.
        for i in range(min(N, len(idx_side))):
            _, y = t.transform((0, 1 - (i + 1) / (N + 1)))
            plt.plot([x, xs_side[i]], [y, ys_side[i]], color=color, lw=0.1)
            name = mutations[int(idx_side[i])]
            # Pad on the side facing the plot.
            text = name + " " if align == "right" else " " + name
            plt.text(x, y, text, fontsize=5, color=color,
                     verticalalignment="center", horizontalalignment=align)

# plot_volcano asserts one entry per mutation, so pass the full initial
# variational ranking rather than the `mean` subset reindexed by feature_ids
# above (and `std` was never defined).
# NOTE(review): assumes initial_ranks["mean"/"std"] cover every mutation, as
# their indexing by feature ids above suggests — confirm.
plot_volcano(initial_ranks["mean"], initial_ranks["std"])