# Varying the set of genes used in regression

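This notebook assumes you have run preprocessing and the `--vary-gene` experiment, as shown below. The final section additionally loads `results/mutrans.vary_nsp.pt`, which must be produced separately (presumably by the `--vary-nsp` experiment).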
```sh
make update
make preprocess
python mutrans.py --vary-gene
```

In [None]:
import torch
import matplotlib
import matplotlib.pyplot as plt

# Global figure styling shared by every plot in this notebook, applied in a
# single update so the whole configuration can be read at a glance.
matplotlib.rcParams.update({
    "figure.dpi": 200,
    "axes.edgecolor": "gray",
    "figure.facecolor": "white",
    "savefig.bbox": "tight",
    "savefig.pad_inches": 0.01,
    "font.family": "sans-serif",
    "font.sans-serif": ["Arial", "Avenir", "DejaVu Sans"],
    # "text.usetex": True,  # left disabled; the preamble below only matters if enabled
    "text.latex.preamble": r"\usepackage{amsfonts}",
})

In [None]:
# Load the results file for the vary-gene experiment (see the shell commands at
# the top of the notebook). NOTE(review): torch.load unpickles the file, so
# only load results you generated yourself or otherwise trust.
result = torch.load("results/mutrans.vary_gene.pt")

In [None]:
# Show the experiment keys. The plotting code below treats each key as a tuple
# of (name, value) pairs such as ("include", ...) or ("exclude", ...), with the
# empty tuple `()` used for the run labeled "All genes".
result.keys()

In [None]:
# One metric name per line, taken from the run stored under the empty key `()`.
# Iterating a dict yields its keys, so no explicit .keys() call is needed.
print("\n".join(result[()]))

In [None]:
def plot_metric(metric, descending=False, ylabel=None, filenames=(), results=None):
    """Plot one metric for every single-gene model, ranked by its value.

    Runs whose key contains an ``"include"`` entry are drawn as points (one
    per gene pattern). Two horizontal reference lines show the same metric for
    the run keyed by ``()`` (labeled "All genes") and for the run that
    excludes every gene (labeled "No genes").

    The reference line drawn first is dark green and dotted, the one drawn
    last is black and dashed; ``descending`` decides which of the two
    reference runs is drawn first (and therefore the legend order).

    Args:
        metric: Name of the metric to read from each run's dict.
        descending: If true, rank genes from largest to smallest value.
        ylabel: Label for the y axis; defaults to ``metric``.
        filenames: Paths to which the finished figure is saved.
        results: Mapping from experiment key to a dict of metrics. Defaults
            to the notebook-level ``result`` (looked up at call time), so
            existing calls keep working while callers can pin the data
            explicitly.

    Raises:
        KeyError: If ``metric`` or one of the two reference runs is missing.
    """
    if results is None:
        # Backward-compatible fallback to the notebook-level variable.
        results = result
    plt.figure(figsize=(8, 4))
    empty = (("exclude", (("gene", ".*"),)),)

    # Collect the metric of every run that includes a gene pattern. Runs that
    # only exclude genes (including `empty`) carry no "include" entry and are
    # therefore skipped without touching their metrics.
    include = {}
    for key, metrics in results.items():
        spec = dict(key)
        if "include" in spec:
            gene = spec["include"][0][1]
            include[gene] = metrics[metric]

    ranked = sorted(((include[g], g) for g in include), reverse=descending)
    genes = [g for _, g in ranked]
    X = list(range(len(genes)))

    def plot_line(top, color, linestyle):
        # `top` selects the "No genes" run, otherwise the "All genes" run.
        key, label = (empty, "No genes") if top else ((), "All genes")
        plt.axhline(results[key][metric], color=color, linestyle=linestyle,
                    lw=(1 if color == "black" else 1.5),
                    label=label, zorder=-2)

    plot_line(not descending, "darkgreen", ":")
    values = [include[g] for g in genes]
    # The format string carries only the marker; the color comes from the
    # keyword. ("k." plus color= made matplotlib warn about a redundant color.)
    plt.plot(X, values, ".", color="darkred", label="A single gene")
    # Larger white discs underneath act as a halo separating points from lines.
    plt.plot(X, values, "wo", markersize=8, zorder=-1)
    plot_line(descending, "black", "--")
    plt.legend(loc="best")
    # Gene patterns look like "^S:"; strip the first and last character for display.
    plt.xticks(X, labels=[g[1:-1] for g in genes], fontsize=9)
    plt.ylabel(metric if ylabel is None else ylabel)
    plt.xlabel("Gene")
    for f in filenames:
        plt.savefig(f)

In [None]:
# Loss of every single-gene model, written out as a paper figure.
plot_metric(
    "loss",
    ylabel="ELBO loss",
    filenames=["paper/vary_gene_loss.png"],
)

In [None]:
# Save under a dedicated name: this cell previously reused
# "paper/vary_gene_loss.png", which silently overwrote the loss figure
# written by the preceding plot_metric("loss", ...) cell.
plot_metric("rate_scale", ylabel=r"E[$\sigma_4$]",
            filenames=["paper/vary_gene_rate_scale.png"])

In [None]:
# "KL" metric per single-gene model; displayed only, not saved.
plot_metric(metric="KL")

In [None]:
# "MAE" metric per single-gene model; displayed only, not saved.
plot_metric(metric="MAE")

In [None]:
# "RMSE" metric per single-gene model; displayed only, not saved.
plot_metric(metric="RMSE")

In [None]:
# Expected log likelihood, ranked from largest to smallest, saved for the paper.
plot_metric(
    "ELL",
    descending=True,
    ylabel="expected log likelihood",
    filenames=["paper/vary_gene_likelihood.png"],
)

In [None]:
def plot_elbo(descending=True, filenames=(), results=None):
    """Plot the ELBO (negated ``"loss"``) of every single-gene model.

    Runs whose key contains an ``"include"`` entry are drawn as points, ranked
    by ELBO. Horizontal reference lines show the ELBO of the run keyed by
    ``()`` ("All genes") and of the run excluding every gene ("No genes").

    Args:
        descending: If true (default), rank genes from largest to smallest ELBO.
        filenames: Paths to which the finished figure is saved.
        results: Mapping from experiment key to a dict of metrics. Defaults
            to the notebook-level ``result`` (looked up at call time).

    Raises:
        KeyError: If ``"loss"`` or one of the two reference runs is missing.
    """
    if results is None:
        # Backward-compatible fallback to the notebook-level variable.
        results = result
    plt.figure(figsize=(8, 4))
    empty = (("exclude", (("gene", ".*"),)),)

    # Only runs with an "include" entry contribute points; exclude-only runs
    # (including `empty`) have no such entry and are skipped.
    include = {}
    for key, metrics in results.items():
        spec = dict(key)
        if "include" in spec:
            gene = spec["include"][0][1]
            include[gene] = -metrics["loss"]

    ranked = sorted(((include[g], g) for g in include), reverse=descending)
    genes = [g for _, g in ranked]
    X = list(range(len(genes)))

    plt.axhline(-results[()]["loss"], color="darkgreen", linestyle=":",
                label="All genes", zorder=-2)
    # Marker-only format string: the color is given by the keyword. ("ko" plus
    # color= made matplotlib warn about a redundantly defined color.)
    plt.plot(X, [include[g] for g in genes], "o", color="darkred", label="One gene",
             markeredgecolor="white", markersize=7)
    plt.axhline(-results[empty]["loss"], color="k", linestyle="--", lw=1,
                label="No genes", zorder=-2)
    plt.legend(loc="best")
    # Gene patterns look like "^S:"; strip the first and last character for display.
    plt.xticks(X, labels=[g[1:-1] for g in genes])
    plt.ylabel("ELBO")
    plt.xlabel("Gene")
    plt.tight_layout()
    for f in filenames:
        plt.savefig(f)
        
# Render the per-gene ELBO figure and save it for the paper.
plot_elbo(
    filenames=["paper/vary_gene_elbo.png"],
)

## Varying the nsp used in regression (`--vary-nsp`)

In [None]:
# NOTE: this rebinds the notebook-level `result` from the vary-gene results to
# the vary-nsp results. Cells above that read `result` will see this data if
# they are re-run after this point.
# NOTE(review): torch.load unpickles the file, so only load trusted results.
result = torch.load("results/mutrans.vary_nsp.pt")
result.keys()

In [None]:
def plot_elbo(descending=True, filenames=(), results=None):
    """Plot the ELBO (negated ``"loss"``) of every single-region (nsp) model.

    Runs whose ``"include"`` spec has a ``"region"`` entry are drawn as
    points, ranked by ELBO and labeled by the last element of that entry.
    Horizontal reference lines show the ELBO of the run including all of
    ``^ORF1[ab]:`` ("All of ORF1") and of the run excluding every gene
    ("No nsps").

    NOTE: this definition reuses the name ``plot_elbo``; running this cell
    replaces any function of the same name defined earlier in the notebook.

    Args:
        descending: If true (default), rank regions from largest to smallest ELBO.
        filenames: Paths to which the finished figure is saved.
        results: Mapping from experiment key to a dict of metrics. Defaults
            to the notebook-level ``result`` (looked up at call time).

    Raises:
        KeyError: If ``"loss"`` or one of the two reference runs is missing.
    """
    if results is None:
        # Backward-compatible fallback to the notebook-level variable.
        results = result
    plt.figure(figsize=(8, 4))
    empty = (("exclude", (("gene", ".*"),)),)
    full = (("include", (("gene", "^ORF1[ab]:"),)),)

    # Keep only runs restricted to one region; `full` and `empty` have no
    # "region" entry and are handled as reference lines below.
    include = {}
    for key, metrics in results.items():
        spec = dict(key)
        if "include" not in spec:
            continue
        included = dict(spec["include"])
        if "region" not in included:
            continue
        gene, region = included["region"]
        include[gene, region] = -metrics["loss"]

    ranked = sorted(((include[gr], gr[-1]) for gr in include), reverse=descending)
    labels = [label for _, label in ranked]
    elbos = [elbo for elbo, _ in ranked]
    X = list(range(len(labels)))

    plt.axhline(-results[full]["loss"], color="darkgreen", linestyle=":",
                label="All of ORF1", zorder=-2)
    # Marker-only format string: the color is given by the keyword. ("ko" plus
    # color= made matplotlib warn about a redundantly defined color.)
    plt.plot(X, elbos, "o", color="darkred", label="One nsp",
             markeredgecolor="white", markersize=7)
    plt.axhline(-results[empty]["loss"], color="k", linestyle="--", lw=1,
                label="No nsps", zorder=-2)
    plt.legend(loc="best")
    plt.xticks(X, labels=labels)
    plt.ylabel("ELBO")
    plt.xlabel("Nonstructural protein within ORF1")
    plt.tight_layout()
    for f in filenames:
        plt.savefig(f)
        
# Render the per-nsp ELBO figure and save it for the paper.
plot_elbo(
    filenames=["paper/vary_nsp_elbo.png"],
)

In [None]:
def plot_ell(descending=True, filenames=(), results=None):
    """Plot the ``"ELL"`` metric of every single-region (nsp) model.

    Runs whose ``"include"`` spec has a ``"region"`` entry are drawn as
    points, ranked by their ``"ELL"`` value and labeled by the last element of
    that entry. Horizontal reference lines show the value for the run
    including all of ``^ORF1[ab]:`` ("All of ORF1") and for the run excluding
    every gene ("No nsps").

    Args:
        descending: If true (default), rank regions from largest to smallest value.
        filenames: Paths to which the finished figure is saved.
        results: Mapping from experiment key to a dict of metrics. Defaults
            to the notebook-level ``result`` (looked up at call time).

    Raises:
        KeyError: If ``"ELL"`` or one of the two reference runs is missing.
    """
    if results is None:
        # Backward-compatible fallback to the notebook-level variable.
        results = result
    plt.figure(figsize=(8, 4))
    empty = (("exclude", (("gene", ".*"),)),)
    full = (("include", (("gene", "^ORF1[ab]:"),)),)

    # Keep only runs restricted to one region; `full` and `empty` have no
    # "region" entry and are handled as reference lines below.
    include = {}
    for key, metrics in results.items():
        spec = dict(key)
        if "include" not in spec:
            continue
        included = dict(spec["include"])
        if "region" not in included:
            continue
        gene, region = included["region"]
        include[gene, region] = metrics["ELL"]

    ranked = sorted(((include[gr], gr[-1]) for gr in include), reverse=descending)
    labels = [label for _, label in ranked]
    ells = [ell for ell, _ in ranked]
    X = list(range(len(labels)))

    plt.axhline(results[full]["ELL"], color="darkgreen", linestyle=":",
                label="All of ORF1", zorder=-2)
    # Marker-only format string: the color is given by the keyword. ("ko" plus
    # color= made matplotlib warn about a redundantly defined color.)
    plt.plot(X, ells, "o", color="darkred", label="One nsp",
             markeredgecolor="white", markersize=7)
    plt.axhline(results[empty]["ELL"], color="k", linestyle="--", lw=1,
                label="No nsps", zorder=-2)
    plt.legend(loc="best")
    plt.xticks(X, labels=labels)
    plt.ylabel("expected log likelihood")
    plt.xlabel("Nonstructural protein within ORF1")
    plt.tight_layout()
    for f in filenames:
        plt.savefig(f)
        
# Render the per-nsp expected-log-likelihood figure and save it for the paper.
plot_ell(
    filenames=["paper/vary_nsp_likelihood.png"],
)