# Varying the set of genes used in regression

This notebook assumes you have already run preprocessing and the `--vary-gene` experiment:
```sh
make update
python mutrans.py --vary-gene
```

In [None]:
import torch
import matplotlib
import matplotlib.pyplot as plt

# Plot defaults shared by every figure in this notebook: high-resolution
# output, a muted axis frame, tight bounding boxes when saving, and a
# sans-serif font stack with fallbacks.
matplotlib.rcParams.update({
    "figure.dpi": 200,
    "axes.edgecolor": "gray",
    "savefig.bbox": "tight",
    "font.family": "sans-serif",
    "font.sans-serif": ["Arial", "Avenir", "DejaVu Sans"],
    # Uncomment to render text through LaTeX; the preamble below is what
    # LaTeX text rendering would load.
    # "text.usetex": True,
    "text.latex.preamble": r"\usepackage{amsfonts}",
})

In [None]:
# Output of `python mutrans.py --vary-gene`.
# torch.load unpickles the file, so only load results you generated yourself.
# NOTE(review): newer PyTorch releases default to weights_only=True, which may
# reject non-tensor objects stored here — pass weights_only=False if loading fails.
VARY_GENE_RESULTS = "results/mutrans.vary_gene.pt"
result = torch.load(VARY_GENE_RESULTS)

In [None]:
# Each key is a tuple of (option, value) pairs identifying one fit;
# the empty tuple () is the fit that uses all genes.
result.keys()

In [None]:
# List the metric names recorded for the all-genes fit, one per line.
print("\n".join(metric_name for metric_name in result[()]))

In [None]:
def plot_metric(metric, descending=False, ylabel=None, filenames=()):
    """Plot how one metric changes when genes are added to or removed from the fit.

    Reads the module-level ``result`` dict (loaded from
    ``results/mutrans.vary_gene.pt``). For every gene that has both an
    ``include`` fit (that gene only) and an ``exclude`` fit (all genes but that
    one), both metric values are drawn at the gene's x position. Horizontal
    reference lines mark the no-genes fit and the all-genes fit.

    Args:
        metric: Name of the metric to read from each fit's result dict,
            e.g. ``"loss"`` or ``"RMSE"``. Values are assumed numeric.
        descending: Genes are ordered by ``include - exclude`` (ties broken by
            gene name). The order is ascending by default and descending if True.
        ylabel: Y-axis label; defaults to ``metric``.
        filenames: Paths the figure is saved to, one ``savefig`` call per path.

    Raises:
        AssertionError: If a gene has an ``include`` fit without a matching
            ``exclude`` fit, or vice versa.
    """
    fig, ax = plt.subplots(figsize=(8, 4))

    # Excluding every gene (regex ".*") is the "no genes" baseline rather than a
    # leave-one-out experiment, so it is drawn as a reference line instead.
    empty = (("exclude", (("gene", ".*"),)),)
    include = {}
    exclude = {}
    for key, fit in result.items():
        if key == empty:
            continue
        options = dict(key)
        if "include" in options:
            gene = options["include"][0][1]
            include[gene] = fit[metric]
        if "exclude" in options:
            gene = options["exclude"][0][1]
            exclude[gene] = fit[metric]
    unmatched = set(include) ^ set(exclude)
    assert not unmatched, (
        f"genes missing either an include or an exclude fit: {sorted(unmatched)}"
    )

    genes = sorted(
        include,
        key=lambda gene: (include[gene] - exclude[gene], gene),
        reverse=descending,
    )
    X = list(range(len(genes)))
    include_values = [include[gene] for gene in genes]
    exclude_values = [exclude[gene] for gene in genes]

    ax.axhline(result[empty][metric], color="darkgreen", linestyle=":",
               label="No genes", zorder=-2)
    # The format strings carry only the marker. The color comes from `color=`.
    # Putting a color in both would trigger matplotlib's "redundantly defined" warning.
    ax.plot(X, include_values, "+", color="darkred", label="A single gene")
    ax.plot(X, exclude_values, "o", color="darkblue", label="All but one gene",
            markerfacecolor="None")
    # Unlabeled white discs sit beneath the markers (zorder=-1) but above the
    # reference lines (zorder=-2), so the lines do not show through the markers.
    ax.plot(X, include_values, "wo", markersize=8, zorder=-1)
    ax.plot(X, exclude_values, "wo", markersize=8, zorder=-1)
    ax.axhline(result[()][metric], color="k", lw=1, linestyle="--",
               label="All genes", zorder=-2)

    ax.legend(loc="best")
    ax.set_xticks(X)
    # Drop the first and last character of each gene pattern for display
    # (presumably regex anchors — confirm against mutrans.py).
    ax.set_xticklabels([gene[1:-1] for gene in genes])
    ax.set_ylabel(metric if ylabel is None else ylabel)
    for filename in filenames:
        fig.savefig(filename)

In [None]:
# ELBO loss comparison; this figure is also written to paper/vary_gene_loss.png.
plot_metric(
    "loss",
    ylabel="ELBO loss",
    filenames=("paper/vary_gene_loss.png",),
)

In [None]:
# Same comparison for the "KL" metric.
plot_metric(metric="KL")

In [None]:
# Same comparison for the "MAE" metric.
plot_metric(metric="MAE")

In [None]:
# Same comparison for the "RMSE" metric.
plot_metric(metric="RMSE")

In [None]:
# Same comparison for the "naive KL" metric.
plot_metric(metric="naive KL")

In [None]:
# Same comparison for the "naive MAE" metric.
plot_metric(metric="naive MAE")

In [None]:
# Same comparison for the "naive RMSE" metric.
plot_metric(metric="naive RMSE")