# NB estimation for simulated data

In [1]:
%load_ext autoreload
%autoreload 2

from scdesigner.simulators import NegBinRegressionSimulator
import numpy as np
import pandas as pd
import torch
import anndata
from anndata import AnnData

We first generate negative binomial samples whose mean and dispersion both come from log-linear models of the covariates. We then fit a negative binomial regression model that estimates both the mean and the dispersion coefficients.

In [2]:
# Simulation sizes: n cells, g genes, d mean covariates, p dispersion covariates.
n, g, d, p = 10000, 20, 2, 2

# Seed a local generator so that Restart & Run All regenerates the same data
# set (and therefore comparable estimation errors) on every run.
SEED = 0
rng = np.random.default_rng(SEED)

X1 = rng.normal(size=(n, d))  # cell x feature (covariates of the mean)
X2 = rng.normal(size=(n, p))  # cell x feature (covariates of the dispersion)
B = rng.normal(size=(d, g))  # feature x gene (ground-truth mean coefficients)
D = rng.normal(size=(p, g))  # feature x gene (ground-truth dispersion coefficients)
mu = np.exp(X1 @ B)  # cell x gene, log link keeps the mean positive
r = np.exp(X2 @ D)  # cell x gene, log link keeps the dispersion positive

# generate samples
# NumPy parameterizes the negative binomial by (n, p); with n = r and
# p = r / (r + mu) the distribution has mean mu and dispersion r.
Y = rng.negative_binomial(r, r / (r + mu))

# Wrap the covariates in named columns so they can be referenced in formulas.
X1 = pd.DataFrame(X1, columns=[f"mean_dim{j}" for j in range(d)])  # cell x feature
X2 = pd.DataFrame(X2, columns=[f"dispersion_dim{j}" for j in range(p)])  # cell x feature
obs = pd.concat([X1, X2], axis=1)

adata = AnnData(X=Y, obs=obs)



In [3]:
# Regression formulas over the simulated obs columns. The trailing "- 1"
# removes the intercept, matching the intercept-free model used to simulate.
simulated_formulas = {
    "mean": "~ mean_dim0 + mean_dim1 - 1",
    "dispersion": "~ dispersion_dim0 + dispersion_dim1 - 1",
}

nb_simulator = NegBinRegressionSimulator()
nb_simulator.fit(adata, simulated_formulas)

# Keep a handle on the fitted parameters for the comparison that follows.
nb_params = nb_simulator.params
print("Parameter keys:", list(nb_params.keys()))

                                                            

Parameter keys: ['beta', 'gamma']




Now we compare the ground truth and estimated parameters.

In [4]:
# Compare with ground truth
print("\n=== Ground Truth vs Estimated Parameters ===")
print("Ground Truth Mean (B):")
display(pd.DataFrame(B))    
print("Estimated Mean:")
display(nb_params["beta"])
print("Mean parameter error:", np.mean(np.abs(B - nb_params["beta"].values)))

print("\nGround Truth Dispersion (D):")
display(pd.DataFrame(D))
print("Estimated Dispersion:")
display(nb_params["gamma"])
print("Dispersion parameter error:", np.mean(np.abs(D - nb_params["gamma"].values)))


=== Ground Truth vs Estimated Parameters ===
Ground Truth Mean (B):


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
0,-0.885021,-1.466365,0.448272,-0.168992,1.041965,-0.515934,0.553246,1.138602,-1.518466,0.728219,-0.441761,0.128771,-0.624948,0.652953,-0.48825,-0.141961,1.612828,0.688569,0.92988,-0.765566
1,1.065042,-0.74625,-0.191624,1.16328,0.791408,0.135344,0.035549,0.119952,1.140553,-0.443818,-1.472146,-0.816129,-0.543193,-0.474946,1.859756,-0.671229,-1.137285,-0.509316,-1.304134,1.430872


Estimated Mean:


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
mean_dim0,-0.871707,-1.479767,0.39306,-0.206132,1.050088,-0.448879,0.57435,1.156746,-1.497602,0.688133,-0.414141,0.091266,-0.641935,0.70978,-0.571389,-0.184057,1.600609,0.681295,0.89275,-0.715734
mean_dim1,1.014223,-0.7655,-0.134166,1.210657,0.756265,0.103501,0.056013,0.131817,1.103213,-0.464384,-1.450575,-0.844111,-0.502702,-0.47196,1.857537,-0.681866,-1.123567,-0.539182,-1.297758,1.434106


Mean parameter error: 0.0289069619613604

Ground Truth Dispersion (D):


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
0,0.303772,-1.11647,-0.193148,0.173487,2.256355,-1.117915,0.769341,-0.833273,-0.209936,0.573649,0.51716,-0.04196,-2.113994,1.063512,-1.326901,1.896008,0.565633,-0.127796,-0.784898,1.373176
1,-0.304074,-0.84207,-0.611705,0.660281,0.767458,0.653207,3.084871,0.683469,-2.017148,-1.041547,1.986966,0.478158,0.259265,1.517845,-1.237194,-2.662761,-1.043966,-2.419723,0.057098,-1.364817


Estimated Dispersion:


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
dispersion_dim0,0.203923,-1.058633,-0.1744,0.19497,2.31394,-1.052671,0.675396,-0.842439,-0.14165,0.491889,0.523717,0.011723,-2.102712,0.96884,-1.291272,1.882042,0.599416,-0.095142,-0.830948,1.390869
dispersion_dim1,-0.262755,-0.945501,-0.686431,0.519616,0.760761,0.674216,3.073125,0.644052,-2.0301,-1.005876,2.025144,0.549112,0.203463,1.632842,-1.225474,-2.611338,-1.063972,-2.381649,0.124127,-1.303215


Dispersion parameter error: 0.048432226688146424


In this run, the dispersion parameters are estimated less accurately than the mean parameters (mean absolute error of about 0.048 versus about 0.029).

# NB estimation for real data

In [5]:
import os
import requests

# Local cache location and remote source of the example data set.
save_path = "data/example_sce.h5ad"
data_url = "https://go.wisc.edu/69435h"

if not os.path.exists(save_path):
    # Make sure the target directory exists on a fresh checkout.
    os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
    # A timeout keeps the cell from hanging forever on a stalled connection.
    response = requests.get(data_url, timeout=60)
    # Fail loudly on an HTTP error instead of caching the error page as the
    # .h5ad file, which would make every later run skip the download and
    # then fail while reading a corrupt file.
    response.raise_for_status()
    with open(save_path, "wb") as f:
        f.write(response.content)

example_sce = anndata.read_h5ad(save_path)
example_sce

AnnData object with n_obs × n_vars = 2087 × 100
    obs: 'clusters_coarse', 'clusters', 'S_score', 'G2M_score', 'cell_type', 'sizeFactor', 'pseudotime'
    var: 'highly_variable_genes'
    uns: 'X_name', 'clusters_coarse_colors', 'clusters_colors', 'day_colors', 'neighbors', 'pca'
    obsm: 'PCA', 'UMAP', 'X_pca', 'X_umap'
    layers: 'counts', 'cpm', 'logcounts', 'spliced', 'unspliced'
    obsp: 'connectivities', 'distances'

In [7]:
# The mean is modelled as a function of pseudotime. The dispersion formula
# "~ ." is passed through as-is (presumably shorthand for "all available
# covariates" -- TODO confirm against the scdesigner formula handling).
real_data_formulas = {"mean": "~ pseudotime", "dispersion": "~ ."}

sim = NegBinRegressionSimulator()
sim.fit(example_sce, real_data_formulas)

                                                           

In [11]:
# Show the fitted coefficient tables: 'beta' (mean model) first, then
# 'gamma' (dispersion model).
for param_name in ('beta', 'gamma'):
    display(sim.params[param_name])

Unnamed: 0,Pyy,Iapp,Chgb,Rbp4,Spp1,Chga,Cck,Ins1,Nnat,Ins2,...,Nkx6-1,Fxyd3,Hn1,Smarcd2,Pdia6,Ffar2,Hes6,Serpinh1,Npy,1110012L19Rik
Intercept,1.768224,1.477402,2.061628,1.713396,3.422275,1.914225,2.269121,1.318398,1.289312,1.687198,...,0.702285,0.352568,2.193075,2.556925,0.691878,0.647325,2.242665,1.980175,-0.762403,0.619287
pseudotime,2.056158,1.667936,1.83168,2.29334,-5.852592,1.312713,0.694322,1.489479,2.196359,1.840136,...,0.979922,0.772675,-1.982561,-3.796952,1.187684,0.185511,-2.467906,-2.88411,2.159065,-0.174052


Unnamed: 0,Pyy,Iapp,Chgb,Rbp4,Spp1,Chga,Cck,Ins1,Nnat,Ins2,...,Nkx6-1,Fxyd3,Hn1,Smarcd2,Pdia6,Ffar2,Hes6,Serpinh1,Npy,1110012L19Rik
Intercept,-0.255255,-0.313529,-0.110844,0.006956,-0.204924,-0.037107,-0.195834,-0.377578,-0.133512,-0.314853,...,0.180739,0.149848,0.310314,-0.21451,0.303152,-0.012965,0.177621,0.340793,-1.179454,-0.395739
clusters_coarse[T.Ngn3 low EP],-0.506407,-0.853216,-1.949587,-1.018138,-0.186941,-2.442313,-2.003157,-0.96655,-1.502361,-2.312207,...,0.311675,-0.271376,-0.039985,-0.462057,0.408251,-0.878225,-0.53454,0.072972,-2.656342,-0.448245
clusters_coarse[T.Ngn3 high EP],-0.455978,-0.634149,-1.014705,-1.03623,-0.260413,-0.959062,-0.232813,-0.812435,-0.751226,-1.548665,...,0.198856,0.032466,0.256123,0.138053,0.165828,0.180643,0.254098,0.344766,-1.931066,-0.168606
clusters_coarse[T.Pre-endocrine],-0.338756,-0.49298,-0.19412,-0.397644,-0.494959,-0.032659,0.020459,-0.769262,-0.781942,-1.448584,...,0.046769,0.240056,0.360415,0.030449,-0.278936,0.149393,0.065505,-0.046933,-2.028451,-0.191925
clusters_coarse[T.Endocrine],-0.224659,-0.308601,0.108855,0.14263,0.02324,0.109799,-0.420989,-0.369293,-0.092068,-0.236639,...,0.228293,0.080792,0.488632,-0.640758,0.332392,0.108718,0.214259,0.018495,-0.531121,-0.785129
clusters[T.Ngn3 low EP],-0.506407,-0.853216,-1.949587,-1.018138,-0.186941,-2.442313,-2.003157,-0.96655,-1.502361,-2.312207,...,0.311675,-0.271376,-0.039985,-0.462057,0.408251,-0.878225,-0.53454,0.072972,-2.656342,-0.448245
clusters[T.Ngn3 high EP],-0.455978,-0.634149,-1.014705,-1.03623,-0.260413,-0.959062,-0.232813,-0.812435,-0.751226,-1.548665,...,0.198856,0.032466,0.256123,0.138053,0.165828,0.180643,0.254098,0.344766,-1.931066,-0.168606
clusters[T.Pre-endocrine],-0.338756,-0.49298,-0.19412,-0.397644,-0.494959,-0.032659,0.020459,-0.769262,-0.781942,-1.448584,...,0.046769,0.240056,0.360415,0.030449,-0.278936,0.149393,0.065505,-0.046933,-2.028451,-0.191925
clusters[T.Beta],-0.224659,-0.308601,0.108855,0.14263,0.02324,0.109799,-0.420989,-0.369293,-0.092068,-0.236639,...,0.228293,0.080792,0.488632,-0.640758,0.332392,0.108718,0.214259,0.018495,-0.531121,-0.785129
clusters[T.Alpha],0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [14]:
# Sample a new AnnData object from the fitted simulator, passing the observed
# cells' covariate table (example_sce.obs) as the input.
sim.sample(example_sce.obs)

AnnData object with n_obs × n_vars = 2087 × 100
    obs: 'clusters_coarse', 'clusters', 'S_score', 'G2M_score', 'cell_type', 'sizeFactor', 'pseudotime'

In [16]:
# Predict from the fitted model for the observed covariates and show the
# 'mean' entry, which is displayed as a cell x gene table.
sim.predict(example_sce.obs)['mean']

Unnamed: 0,Pyy,Iapp,Chgb,Rbp4,Spp1,Chga,Cck,Ins1,Nnat,Ins2,...,Nkx6-1,Fxyd3,Hn1,Smarcd2,Pdia6,Ffar2,Hes6,Serpinh1,Npy,1110012L19Rik
AAACCTGAGAGGGATA,23.587784,13.558169,27.169871,26.220190,0.581983,16.498044,15.476788,10.248462,16.066990,18.791500,...,3.919346,2.400918,2.340613,0.985596,4.464784,2.166169,1.770584,1.027317,2.013339,1.651052
AAACCTGGTAAGTGGC,14.233482,9.000058,17.324536,14.926376,2.450958,11.950230,13.049747,7.107926,9.366992,11.957307,...,3.080800,1.985816,3.809371,2.504977,3.334908,2.069663,3.246555,2.086491,1.184572,1.723180
AAACGGGCAAAGAATC,27.359213,15.291675,31.007837,30.937332,0.381555,18.136683,16.271701,11.410953,18.825359,21.459037,...,4.206425,2.538541,2.028704,0.749455,4.864172,2.195351,1.481838,0.834352,2.352650,1.630452
AAACGGGGTACAGTTC,39.294164,20.511341,42.808616,46.327990,0.136156,22.852551,18.387573,14.832528,27.713299,29.669926,...,4.998538,2.908485,1.430942,0.384071,5.995499,2.268241,0.959604,0.502129,3.440729,1.581245
AAACGGGGTGAAATCA,12.810605,8.263047,15.772984,13.272005,3.307754,11.173096,12.593780,6.585789,8.370275,10.881720,...,2.929976,1.908754,4.216552,3.042784,3.138069,2.050089,3.684038,2.418670,1.060549,1.738612
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGGTTTCACTTACT,7.249920,5.206961,9.498808,7.033652,16.720922,7.768378,10.391290,4.360236,4.556646,6.537867,...,2.233759,1.541140,7.300362,8.706107,2.258671,1.947451,7.295747,5.374866,0.583339,1.824446
TTTGGTTTCCTTTCGG,29.924869,16.444990,33.585371,34.190230,0.295632,19.204858,16.771749,12.176481,20.716974,23.251394,...,4.390012,2.625506,1.860731,0.635125,5.122654,2.213178,1.330689,0.735776,2.584844,1.618127
TTTGTCAAGAATGTGT,24.668243,14.059814,28.275814,27.563263,0.512324,16.976596,15.712636,10.586418,16.854345,19.560005,...,4.003903,2.441669,2.241686,0.907361,4.581797,2.174939,1.677917,0.964764,2.110287,1.644804
TTTGTCAAGTGACATA,19.743757,11.736284,23.188074,21.501398,0.965636,14.726823,14.574465,9.009352,13.286460,16.025842,...,3.600761,2.245666,2.778571,1.368876,4.028792,2.131680,2.192021,1.318472,1.670294,1.676102
