This notebook reproduces the example from the scDesign3 package vignette [Simulate datasets with multiple lineages](https://songdongyuan1994.github.io/scDesign3/docs/articles/scDesign3-multipleLineages-vignette.html).

In [1]:
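import anndata
import os
import requests

save_path = "data/GSE72859.h5ad"
if not os.path.exists(save_path):
    # Create the target directory first; open() fails if "data/" is missing.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    response = requests.get("https://go.wisc.edu/p615r4")
    # Stop on an HTTP error instead of saving an error page as the .h5ad file.
    response.raise_for_status()
    with open(save_path, "wb") as f:
        f.write(response.content)

# Load the dataset and keep only the first 100 variables (columns).
example_sce = anndata.read_h5ad(save_path)
example_sce = example_sce[:, :100].to_memory()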

In [2]:
example_sce.obs[["pseudotime1", "pseudotime2", "l1", "l2"]]

| cell | pseudotime1 | pseudotime2 | l1 | l2 |
|---|---|---|---|---|
| W31105 | 0.950862 | 0.568357 | TRUE | TRUE |
| W31106 | 9.168276 | -1.000000 | TRUE | FALSE |
| W31107 | -1.000000 | 7.981990 | FALSE | TRUE |
| W31108 | 11.394132 | -1.000000 | TRUE | FALSE |
| W31109 | -1.000000 | 8.080133 | FALSE | TRUE |
| ... | ... | ... | ... | ... |
| W39164 | -1.000000 | 10.675819 | FALSE | TRUE |
| W39165 | 8.708608 | -1.000000 | TRUE | FALSE |
| W39166 | 0.136108 | 0.117735 | TRUE | TRUE |
| W39167 | 11.376395 | -1.000000 | TRUE | FALSE |
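
In the rows displayed above, a pseudotime of -1 appears exactly when the matching lineage indicator is FALSE, and a cell can belong to both lineages. That reading is inferred from the displayed rows only; the notebook does not state the encoding. A minimal sketch that checks it on the first five rows, copied by hand from the table (the helper name `sentinel_matches_indicator` is hypothetical):

```python
# Rows copied from the table above: (cell, pseudotime1, pseudotime2, l1, l2).
rows = [
    ("W31105", 0.950862, 0.568357, True, True),
    ("W31106", 9.168276, -1.0, True, False),
    ("W31107", -1.0, 7.981990, False, True),
    ("W31108", 11.394132, -1.0, True, False),
    ("W31109", -1.0, 8.080133, False, True),
]


def sentinel_matches_indicator(pseudotime, on_lineage, sentinel=-1.0):
    """Return True when the sentinel appears exactly for off-lineage cells."""
    return (pseudotime == sentinel) == (not on_lineage)


# Assumption: -1 is a placeholder for "not on this lineage", not a real pseudotime.
consistent = all(
    sentinel_matches_indicator(pt1, l1) and sentinel_matches_indicator(pt2, l2)
    for _, pt1, pt2, l1, l2 in rows
)

# Cells that sit on both lineages (before the branch point, presumably).
shared = [cell for cell, _, _, l1, l2 in rows if l1 and l2]

print(consistent, shared)  # → True ['W31105']
```

The same check could be run on the full `example_sce.obs` table to confirm the pattern beyond the rows shown here.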


scDesign3 can fit a separate spline for each group using the `by` syntax of mgcv; see page 138 of the [mgcv manual](https://cran.r-project.org/web/packages/mgcv/mgcv.pdf). We have not implemented an equivalent in our estimators, though a dedicated estimator for this case seems feasible. As an alternative, we fit an interaction between each pseudotime and its lineage indicator. This is not exactly the same model, but it has a similar qualitative interpretation and still appears to produce plausible data.
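
To make the interaction concrete, the sketch below expands a single term of the form `~ pseudotime * l` by hand. It is an illustration only: the column order, the 0/1 coding of the indicator, and the coefficient values are assumptions, and the formula parser used by the package may name or order columns differently.

```python
def design_row(pseudotime, on_lineage):
    """Columns for `~ pseudotime * l`: intercept, pseudotime, l, pseudotime:l."""
    l = 1.0 if on_lineage else 0.0
    return [1.0, pseudotime, l, pseudotime * l]


def linear_predictor(row, beta):
    """Dot product of one design row with the coefficient vector."""
    return sum(x * b for x, b in zip(row, beta))


# Hypothetical coefficients (b0, b1, b2, b3), chosen only for illustration.
beta = [0.5, 0.1, 1.0, 0.2]

on = design_row(9.168276, True)   # a cell on the lineage
off = design_row(-1.0, False)     # an off-lineage cell with pseudotime -1

# On the lineage the predictor is (b0 + b2) + (b1 + b3) * pseudotime,
# so the lineage gets its own intercept and its own slope.
print(linear_predictor(on, beta))

# Off the lineage the interaction column is zero and pseudotime is the
# constant -1, so every such cell gets the same value b0 - b1.
print(linear_predictor(off, beta))
```

Under these assumptions the interaction gives each lineage its own intercept and slope in pseudotime, whereas a `by` smooth would give each group its own curve. That is the sense in which the two approaches are similar but not identical.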

In [3]:
from scdesigner.experimental.estimators import negative_binomial_copula
from scdesigner.experimental.samplers import negative_binomial_copula_sample

formula = "~ pseudotime1 * l1 + pseudotime2 * l2"
params = negative_binomial_copula(example_sce, formula, epochs=100)
simulated = negative_binomial_copula_sample(params, example_sce.obs, formula)
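
The function names suggest negative binomial marginals joined by a copula. The package internals are not shown in this notebook, so the following is a generic, standard-library sketch of sampling two correlated counts through a Gaussian copula, not scdesigner's implementation. All names and parameter values here (`nb_quantile`, `sample_pair`, the means, sizes, and `rho`) are hypothetical.

```python
import math
import random


def nb_quantile(u, mean, size, max_count=10_000):
    """Smallest k whose negative binomial CDF reaches u (mean/size form)."""
    p_success = size / (size + mean)
    pmf = p_success ** size  # P(X = 0)
    cdf = pmf
    k = 0
    while cdf < u and k < max_count:
        # Recurrence: P(k + 1) = P(k) * (k + size) / (k + 1) * (1 - p_success)
        pmf *= (k + size) / (k + 1) * (1.0 - p_success)
        cdf += pmf
        k += 1
    return k


def normal_cdf(z):
    """Standard normal CDF via the error function."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def sample_pair(rng, means, sizes, rho):
    """One draw of two correlated counts through a Gaussian copula."""
    z1 = rng.gauss(0.0, 1.0)
    z2 = rho * z1 + math.sqrt(1.0 - rho ** 2) * rng.gauss(0.0, 1.0)
    # Map each normal draw to a uniform, then through the NB quantile function.
    return tuple(
        nb_quantile(normal_cdf(z), m, s)
        for z, m, s in zip((z1, z2), means, sizes)
    )


rng = random.Random(0)
draws = [
    sample_pair(rng, means=(5.0, 20.0), sizes=(2.0, 4.0), rho=0.8)
    for _ in range(2000)
]
mean_gene1 = sum(d[0] for d in draws) / len(draws)
mean_gene2 = sum(d[1] for d in draws) / len(draws)
print(mean_gene1, mean_gene2)
```

In a regression setting the marginal mean of each gene would depend on the covariates in the formula, while the copula carries the dependence between genes.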

In [4]:
import altair
from scdesigner.experimental.diagnose import plot_umap

altair.data_transformers.enable("vegafusion")
altair.renderers.enable("jupyter")

combined = anndata.concat({"real": example_sce, "sim": simulated}, label="source")
plot_umap(combined, color="pseudotime1", shape="l1", facet="source")

[Figure: UMAP of real and simulated cells, colored by pseudotime1, point shape by l1, faceted by source.]

In [5]:
plot_umap(combined, color="pseudotime2", shape="l2", facet="source")

[Figure: UMAP of real and simulated cells, colored by pseudotime2, point shape by l2, faceted by source.]