This notebook reproduces the scDesign3 vignette [Simulate datasets with cell library size](https://songdongyuan1994.github.io/scDesign3/docs/articles/scDesign3-librarySize-vignette.html) using scdesigner.

In [1]:
import anndata
import os
import requests
import numpy as np

save_path = "data/sce_filteredExpr10_Zhengmix4eq.h5ad"
if not os.path.exists(save_path):
    # The target directory may not exist on a fresh checkout.
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    response = requests.get("https://go.wisc.edu/6wj08l", timeout=60)
    # Fail loudly rather than caching an error page under the .h5ad name: the
    # exists() guard above would otherwise skip re-downloading on every later run.
    response.raise_for_status()
    # Write to a temporary name and rename, so an interrupted write never
    # leaves a truncated file at save_path.
    tmp_path = save_path + ".part"
    with open(tmp_path, "wb") as f:
        f.write(response.content)
    os.replace(tmp_path, save_path)

sce = anndata.read_h5ad(save_path)
# Per-cell total counts (library size). If X is scipy-sparse, .sum(axis=1)
# returns an (n_obs, 1) np.matrix, so flatten to a 1-D array before storing it.
sce.obs["library"] = np.asarray(sce.X.sum(axis=1)).ravel()

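Ideally, library size would enter the model as a true offset: a term whose coefficient is fixed at 1. Here we instead include the log of the library size as an ordinary predictor, so its coefficient is estimated from the data.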

In [2]:
from scdesigner.experimental.estimators import negative_binomial_copula
from scdesigner.experimental.samplers import negative_binomial_copula_sample

# Model formula: cell type (`phenoid`) plus the log of `library`, the per-cell
# total-count column added to `sce.obs` in the first cell.
formula = "~ phenoid + log(library)"
# Fit the negative-binomial copula model to the reference AnnData.
params = negative_binomial_copula(sce, formula)
# Simulate new counts from the fitted parameters, using the reference cells'
# covariates (`sce.obs`) as the design. Presumably this means the simulated cells
# carry the reference's cell types and library-size covariate — TODO confirm.
# `samples` is used by the plotting cells below, so keep the name.
samples = negative_binomial_copula_sample(params, sce.obs, formula)

In [3]:
import altair
from scdesigner.experimental.diagnose import plot_umap

# Route Altair data through VegaFusion and render charts as Jupyter widgets.
altair.data_transformers.enable("vegafusion")
altair.renderers.enable("jupyter")

# Stack reference and simulated cells into one AnnData; the new `source` obs
# column records the origin of each cell and is used to facet the UMAP.
sources = {"real": sce, "sim": samples}
combined = anndata.concat(sources, label="source")
plot_umap(combined, color="phenoid", facet="source")

JupyterChart()

In [4]:
import altair as alt
import pandas as pd
import numpy as np


def _library_size(adata):
    """Return per-cell total counts (row sums of ``adata.X``) as a 1-D array.

    Works for dense and scipy-sparse ``X``. A sparse ``.sum(axis=1)`` returns an
    (n_obs, 1) matrix, so the result is flattened.
    """
    return np.asarray(adata.X.sum(axis=1)).ravel()


# Compare library sizes computed from the *counts*. `samples.obs["library"]` is
# the covariate handed to the sampler from `sce.obs`, so it would match the
# reference by construction. The simulated library size is the row sum of `samples.X`.
samples_library = _library_size(samples)
sce_library = _library_size(sce)
plot_data = pd.DataFrame({
    "Library Size": np.concatenate([samples_library, sce_library]),
    "Source": ["Simulated"] * len(samples_library) + ["Reference"] * len(sce_library),
})

# Violin plot: mirrored kernel density per source, one facet column each.
# https://altair-viz.github.io/gallery/violin_plot.html
alt.Chart(plot_data).transform_density(
    "Library Size",
    as_=["Library Size", "density"],
    groupby=["Source"],
).mark_area(orient="horizontal").encode(
    y="Library Size:Q",
    color="Source:N",
    x=alt.X("density:Q", stack="center").title(None).axis(labels=False, ticks=False),
    column=alt.Column("Source:N").header(titleOrient="bottom"),
).properties(
    width=200,
    height=400,
)

JupyterChart()