In [None]:
import anndata
from scdesigner.simulator import simulator
from scdesigner.margins.marginal import NB

I downloaded the data from the scDesign3 [quickstart](https://songdongyuan1994.github.io/scDesign3/docs/articles/scDesign3.html). There are only 100 genes in this demo.

In [None]:
import os
import requests

# Cache the scDesign3 quickstart object locally so later runs skip the download.
save_path = "data/example_sce.h5ad"
if not os.path.exists(save_path):
    # `open` does not create parent directories; make sure data/ exists first.
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    response = requests.get("https://go.wisc.edu/69435h", timeout=60)
    # Fail loudly on HTTP errors instead of caching an error page as an .h5ad file.
    response.raise_for_status()
    # Write to a temporary file and rename it, so an interrupted write never
    # leaves a partial file that the `exists` check above would trust next time.
    tmp_path = save_path + ".part"
    with open(tmp_path, "wb") as f:
        f.write(response.content)
    os.replace(tmp_path, save_path)

example_sce = anndata.read_h5ad(save_path)
example_sce

The result seems quite sensitive to the learning rate. How can we pick a good default? One option is to systematically survey which learning rates work well across a range of public datasets. Alternatively, is there a good way to adapt the learning rate during training (e.g., a learning-rate schedule)?

In [None]:
import numpy as np

# Convert the count matrix to a dense float32 array before fitting.
# Only call `.toarray()` while X is still sparse: after a first run X is already
# a dense ndarray (which has no `.toarray()`), so this keeps the cell re-runnable.
if hasattr(example_sce.X, "toarray"):
    example_sce.X = example_sce.X.toarray()
example_sce.X = np.asarray(example_sce.X, dtype=np.float32)

# Baseline model: NB with a mean that is linear in pseudotime.
sim = simulator(example_sce, NB("~ pseudotime"))
sim

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

def plot_gene(sim, example_sce, ix=0, ax=None):
    """Plot observed counts for one gene with the fitted mean and a +/-1 SD band.

    Parameters
    ----------
    sim : fitted scdesigner simulator
        ``sim.predict(obs)`` must return a mapping with ``"mu"`` and ``"alpha"``
        entries that can be indexed by gene name (as done below). Prediction rows
        are assumed to be in the same order as ``example_sce.obs``.
    example_sce : anndata.AnnData
        The data the simulator was fit on. ``obs`` must have a ``"pseudotime"``
        column, and ``X`` must be a dense array (``X[:, ix]`` is plotted directly).
    ix : int, default 0
        Column index (into ``var_names``) of the gene to plot.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. Defaults to the current pyplot axes, matching the
        original behaviour.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that were drawn on.
    """
    gene = example_sce.var_names[ix]
    pseudotime = np.asarray(example_sce.obs["pseudotime"])
    if ax is None:
        ax = plt.gca()

    # Predict once and reuse the result for both parameters.
    params = sim.predict(example_sce.obs)
    mu = params["mu"]
    # Variance mu + alpha * mu**2 (NB mean/dispersion form).
    # NOTE(review): assumes scdesigner's `alpha` uses this parameterisation — TODO confirm.
    variance = mu + (mu ** 2) * params["alpha"]
    mu_gene = np.asarray(mu[gene].values)
    sd_gene = np.sqrt(np.asarray(variance[gene].values))

    # Sort by pseudotime so fill_between draws one contiguous band.
    order = np.argsort(pseudotime, kind="stable")
    t_sorted = pseudotime[order]
    mu_sorted = mu_gene[order]
    sd_sorted = sd_gene[order]
    # NB counts are non-negative, so clip the lower edge of the band at zero.
    lower = np.clip(mu_sorted - sd_sorted, 0, None)
    upper = mu_sorted + sd_sorted

    sns.scatterplot(x=pseudotime, y=example_sce.X[:, ix], ax=ax, label="observed")
    ax.fill_between(t_sorted, lower, upper, color="orange", alpha=0.3, label="mean ± 1 SD")
    sns.scatterplot(x=t_sorted, y=mu_sorted, ax=ax, label="fitted mean")
    ax.set(xlabel="pseudotime", ylabel=gene, title=f"{gene}: observed vs. fitted")
    ax.legend()
    return ax

# Inspect the linear-pseudotime fit on the first five genes, one figure per gene.
for gene_ix in range(5):
    plot_gene(sim, example_sce, ix=gene_ix)
    plt.show()

In [None]:
# Refit with a 10-degree-of-freedom B-spline in pseudotime instead of a linear term.
spline_mean_formula = "~ bs(pseudotime, df=10)"
sim = simulator(example_sce, NB(spline_mean_formula), max_epochs=6)

In [None]:
# Inspect the spline fit on the same first five genes.
for gene_ix in range(5):
    plot_gene(sim, example_sce, ix=gene_ix)
    plt.show()

In [None]:
# Give each NB parameter its own formula: a 10-df spline for the mean (mu) and
# a lower-flexibility 4-df spline for alpha.
fmla = dict(
    mu="~ bs(pseudotime, df=10)",
    alpha="~ bs(pseudotime, df=4)",
)
sim = simulator(example_sce, NB(fmla), max_epochs=6)

In [None]:
# Inspect the per-parameter fit on the same first five genes.
for gene_ix in range(5):
    plot_gene(sim, example_sce, ix=gene_ix)
    plt.show()