This notebook reproduces an example from the scDesign3 package vignette [Simulate datasets with condition effect](https://songdongyuan1994.github.io/scDesign3/docs/articles/scDesign3-conditionEffect-vignette.html).

In [None]:
import anndata
import os
import requests

# Download the stxBrain example once and cache it locally, then load it.
save_path = "data/stxBrain.h5ad"
if not os.path.exists(save_path):
    # Create the cache directory on a fresh checkout; open() below would
    # otherwise fail with FileNotFoundError.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    response = requests.get("https://go.wisc.edu/o1y03l", timeout=60)
    # Fail loudly on HTTP errors. Without this an error page would be cached
    # as .h5ad, and the exists() guard above would never retry the download.
    response.raise_for_status()
    # Write to a temporary name and rename into place, so an interrupted write
    # never leaves a truncated file at save_path.
    tmp_path = save_path + ".part"
    with open(tmp_path, "wb") as f:
        f.write(response.content)
    os.replace(tmp_path, save_path)

example_sce = anndata.read_h5ad(save_path)

In [None]:
# Keep only the first 10 genes so model fitting below stays fast, then call
# to_memory() on the sliced result (presumably to materialize the view as a
# standalone object -- confirm against the anndata version in use).
# Re-running this cell is harmless: the first 10 genes of a 10-gene object are
# the same 10 genes.
example_sce = example_sce[:, slice(10)].to_memory()
example_sce.obs

This is not exactly the simulator used in the scDesign3 vignette: here a single copula correlation structure is shared across all groups. Reproducing the vignette faithfully requires a version of `negative_binomial_copula` that also accepts a grouping variable in its formula.

In [None]:
from scdesigner.experimental.estimators import negative_binomial_copula

# Mean model: a tensor-product B-spline surface over the two spatial coordinates.
# NOTE(review): assuming patsy-style formula semantics, crossing two df=100 bases
# with `*` yields about 100*100 interaction columns plus main effects. Check that
# this is identifiable given the number of spots in example_sce.
spline_df = 100
formula = f"~ bs(spatial1, df={spline_df}) * bs(spatial2, df={spline_df})"
params = negative_binomial_copula(example_sce, formula)

In [None]:
from scdesigner.experimental.samplers import negative_binomial_copula_sample

# Draw a synthetic dataset at the observed spot covariates, reusing the fitted
# parameters and the same model formula.
# TODO(review): no random seed is set in the visible cells, so draws presumably
# differ between runs -- confirm which RNG the sampler uses and seed it.
sim_covariates = example_sce.obs
simulated = negative_binomial_copula_sample(params, sim_covariates, formula)

In [None]:
import altair
from scdesigner.experimental.diagnose import plot_umap
altair.data_transformers.enable("vegafusion")

# Stack real and simulated data into one AnnData; the `label` argument adds an
# obs["source"] column recording which input ("real" or "sim") each row came from.
sources = {"real": example_sce, "sim": simulated}
combined = anndata.concat(sources, label="source")
plot_umap(combined, color="seurat_clusters", facet="source", n_comps=5)

In [None]:
import pandas as pd
import numpy as np

def plot_spatial(adata, spatial_names=("spatial1", "spatial2"), columns=5):
    """Facet a scatter of log1p expression over spatial coordinates, one panel per gene.

    Parameters
    ----------
    adata : anndata.AnnData
        Object whose ``obs`` holds the coordinate columns and whose ``X`` holds
        counts (dense array or scipy sparse matrix), one column per entry of
        ``adata.var_names``. Rows of ``obs`` and ``X`` are matched by position.
    spatial_names : sequence of str, default ("spatial1", "spatial2")
        Names of the two ``obs`` columns used as the x and y coordinates.
    columns : int, default 5
        Number of facet columns in the chart grid.

    Returns
    -------
    altair.FacetChart
        Point chart (200x200 per panel) colored by log1p expression on a
        viridis scale, faceted by gene.
    """
    spatial_names = list(spatial_names)

    counts = adata.X
    # np.log1p does not operate elementwise on scipy sparse matrices (a common
    # X format in .h5ad files), so densify first. This is affordable for the
    # small gene subsets this helper is used on.
    if hasattr(counts, "toarray"):
        counts = counts.toarray()
    expression = pd.DataFrame(
        np.log1p(np.asarray(counts)), columns=list(adata.var_names)
    )

    # obs rows and X rows correspond by position, so drop the obs index before
    # placing the two frames side by side.
    coords = adata.obs[spatial_names].reset_index(drop=True)
    wide = pd.concat([coords, expression], axis=1)

    # One row per (spot, gene) so Altair can facet on the gene column.
    tidy = wide.melt(id_vars=spatial_names, var_name="gene", value_name="expression")

    viridis = altair.Scale(scheme="viridis")
    return (
        altair.Chart(tidy)
        .mark_point(size=1)
        .encode(
            x=spatial_names[0],
            y=spatial_names[1],
            fill=altair.Fill("expression", scale=viridis),
            color=altair.Color("expression", scale=viridis),
        )
        .properties(width=200, height=200)
        .facet(facet="gene", columns=columns)
    )

In [None]:
# Observed log1p expression of each gene across the tissue coordinates.
plot_spatial(example_sce, spatial_names=["spatial1", "spatial2"])

In [None]:
# Simulated log1p expression, on the same layout for side-by-side comparison.
plot_spatial(simulated, spatial_names=["spatial1", "spatial2"])