# NB Estimation for Simulated data

In [2]:
%load_ext autoreload
%autoreload 2

from scdesigner.simulators import NegBinRegressionSimulator
import numpy as np
import pandas as pd
import torch
import anndata
from anndata import AnnData

We first generate negative binomial samples with both mean and dispersion comes from a linear model. Then we fit the negative binomial model with both mean and dispersion.

In [3]:
n, g, d, p = 1000, 20, 2, 2
X1 = np.random.normal(size=(n, d)) # 
X2 = np.random.normal(size=(n, p)) #    
B = np.random.normal(size=(d, g)) # feature x gene
D = np.random.normal(size=(p, g)) # feature x gene
mu = np.exp(X1 @ B) # cell x gene
r = np.exp(X2 @ D) # cell x gene

# generate samples
Y = np.random.negative_binomial(r, r / (r + mu))

X1 = pd.DataFrame(X1, columns=[f"mean_dim{j}" for j in range(d)]) # cell x feature
X2 = pd.DataFrame(X2, columns=[f"dispersion_dim{j}" for j in range(p)]) # cell x feature
obs = pd.concat([X1, X2], axis=1)

adata = AnnData(X=Y, obs=obs)



In [4]:
nb_simulator = NegBinRegressionSimulator()
nb_simulator.fit(adata, {"mean": "~ mean_dim0 + mean_dim1 - 1", 
                                      "dispersion": "~ dispersion_dim0 + dispersion_dim1 - 1"})
nb_params = nb_simulator.params
print("Parameter keys:", list(nb_params.keys()))

                                                           

Parameter keys: ['beta', 'gamma']




Now we compare the ground truth and estimated parameters.

In [5]:
# Compare with ground truth
print("\n=== Ground Truth vs Estimated Parameters ===")
print("Ground Truth Mean (B):")
display(pd.DataFrame(B))    
print("Estimated Mean:")
display(nb_params["beta"])
print("Mean parameter error:", np.mean(np.abs(B - nb_params["beta"].values)))

print("\nGround Truth Dispersion (D):")
display(pd.DataFrame(D))
print("Estimated Dispersion:")
display(nb_params["gamma"])
print("Dispersion parameter error:", np.mean(np.abs(D - nb_params["gamma"].values)))


=== Ground Truth vs Estimated Parameters ===
Ground Truth Mean (B):


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
0,1.594726,-0.866177,0.36576,1.108588,0.126472,-0.236877,0.460676,1.37611,0.22027,-1.250269,-0.659175,1.556995,-0.205124,-0.694896,-1.79165,0.472785,-0.915667,-1.356247,0.799043,0.465109
1,-0.406036,1.070041,0.392654,-1.329016,0.787252,1.266988,1.069468,-1.955259,1.006103,-1.06401,-1.98337,-1.378119,-1.472612,1.715388,0.691899,1.557015,0.675984,0.066121,0.696808,-0.16459


Estimated Mean:


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
mean_dim0,1.610046,-0.831127,0.349426,1.122214,0.08996,-0.279992,0.440286,1.405915,0.217838,-1.282284,-0.734286,1.597518,-0.226538,-0.712569,-1.825839,0.487697,-0.91956,-1.401875,0.840859,0.43361
mean_dim1,-0.377482,1.092808,0.392815,-1.348479,0.769217,1.24731,1.099206,-1.946407,0.937164,-1.090665,-1.967453,-1.360766,-1.460016,1.667837,0.700313,1.555358,0.68097,0.060816,0.725686,-0.083968


Mean parameter error: 0.025934528726703286

Ground Truth Dispersion (D):


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
0,0.136826,1.439632,0.157531,0.692717,-0.675076,-0.075433,0.159354,-0.840335,0.754416,-0.78479,0.166896,0.960192,0.4045,-2.271163,1.34054,0.997282,0.591537,0.547134,1.004401,-0.890365
1,1.505019,0.371515,-0.38645,-0.277598,-0.58838,1.027614,-0.637878,1.826335,-0.040392,0.205676,-0.546004,-0.225499,-0.536045,-1.110398,1.065309,-0.723435,-0.735611,-0.740706,1.025507,1.145504


Estimated Dispersion:


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
dispersion_dim0,0.134351,1.398004,0.161663,0.711883,-0.769289,-0.171935,0.257466,-0.755101,0.758259,-0.774401,0.132439,0.70213,0.450999,-2.452514,1.510876,1.044522,0.630705,0.485312,1.133515,-0.844386
dispersion_dim1,1.370066,0.357159,-0.302599,-0.260589,-0.704416,0.880418,-0.548894,1.543103,-0.148388,0.160141,-0.515342,-0.236053,-0.583263,-1.102276,0.997585,-0.809409,-0.919442,-0.78167,0.957447,1.207356


Dispersion parameter error: 0.07784573061829891


The estimation for the dispersion parameter is less accurate than the mean parameter.

# NB estimation for real data

In [6]:
import os
import requests

save_path = "data/example_sce.h5ad"
if not os.path.exists(save_path):
    response = requests.get("https://go.wisc.edu/69435h")
    with open(save_path, "wb") as f:
        f.write(response.content)

example_sce = anndata.read_h5ad(save_path)
example_sce

AnnData object with n_obs × n_vars = 2087 × 100
    obs: 'clusters_coarse', 'clusters', 'S_score', 'G2M_score', 'cell_type', 'sizeFactor', 'pseudotime'
    var: 'highly_variable_genes'
    uns: 'X_name', 'clusters_coarse_colors', 'clusters_colors', 'day_colors', 'neighbors', 'pca'
    obsm: 'PCA', 'UMAP', 'X_pca', 'X_umap'
    layers: 'counts', 'cpm', 'logcounts', 'spliced', 'unspliced'
    obsp: 'connectivities', 'distances'

In [7]:
sim = NegBinRegressionSimulator()
sim.fit(example_sce, {"mean": "~ pseudotime", "dispersion": "~ ."})

                                                          

In [8]:
display(sim.params['beta'])
display(sim.params['gamma'])

Unnamed: 0,Pyy,Iapp,Chgb,Rbp4,Spp1,Chga,Cck,Ins1,Nnat,Ins2,...,Nkx6-1,Fxyd3,Hn1,Smarcd2,Pdia6,Ffar2,Hes6,Serpinh1,Npy,1110012L19Rik
Intercept,1.768223,1.477402,2.061628,1.713396,3.422275,1.914225,2.269121,1.318398,1.289312,1.687198,...,0.702285,0.352568,2.193076,2.556926,0.691878,0.647325,2.242666,1.980175,-0.762403,0.619287
pseudotime,2.056158,1.667936,1.83168,2.29334,-5.852593,1.312713,0.694322,1.489479,2.196359,1.840136,...,0.979922,0.772675,-1.982561,-3.796952,1.187684,0.185511,-2.467906,-2.88411,2.159065,-0.174052


Unnamed: 0,Pyy,Iapp,Chgb,Rbp4,Spp1,Chga,Cck,Ins1,Nnat,Ins2,...,Nkx6-1,Fxyd3,Hn1,Smarcd2,Pdia6,Ffar2,Hes6,Serpinh1,Npy,1110012L19Rik
Intercept,-0.255255,-0.313529,-0.110844,0.006956,-0.204924,-0.037107,-0.195834,-0.377578,-0.133512,-0.314854,...,0.18074,0.149848,0.310313,-0.214509,0.303153,-0.012965,0.177621,0.340793,-1.179454,-0.395738
clusters_coarse[T.Ngn3 low EP],-0.506407,-0.853216,-1.949587,-1.018138,-0.186941,-2.442313,-2.003157,-0.966549,-1.502361,-2.312207,...,0.311675,-0.271375,-0.039985,-0.462057,0.408251,-0.878225,-0.53454,0.072971,-2.656342,-0.448245
clusters_coarse[T.Ngn3 high EP],-0.455978,-0.634149,-1.014705,-1.03623,-0.260413,-0.959062,-0.232813,-0.812435,-0.751226,-1.548665,...,0.198856,0.032466,0.256123,0.138052,0.165829,0.180643,0.254098,0.344766,-1.931066,-0.168606
clusters_coarse[T.Pre-endocrine],-0.338756,-0.49298,-0.194121,-0.397644,-0.494959,-0.032659,0.020459,-0.769262,-0.781942,-1.448584,...,0.046769,0.240056,0.360416,0.03045,-0.278936,0.149393,0.065505,-0.046933,-2.028451,-0.191925
clusters_coarse[T.Endocrine],-0.224659,-0.308601,0.108855,0.14263,0.02324,0.109799,-0.420989,-0.369293,-0.092068,-0.236639,...,0.228293,0.080792,0.488631,-0.640758,0.332392,0.108718,0.214258,0.018494,-0.531121,-0.785129
clusters[T.Ngn3 low EP],-0.506407,-0.853216,-1.949587,-1.018138,-0.186941,-2.442313,-2.003157,-0.966549,-1.502361,-2.312207,...,0.311675,-0.271375,-0.039985,-0.462057,0.408251,-0.878225,-0.53454,0.072971,-2.656342,-0.448245
clusters[T.Ngn3 high EP],-0.455978,-0.634149,-1.014705,-1.03623,-0.260413,-0.959062,-0.232813,-0.812435,-0.751226,-1.548665,...,0.198856,0.032466,0.256123,0.138052,0.165829,0.180643,0.254098,0.344766,-1.931066,-0.168606
clusters[T.Pre-endocrine],-0.338756,-0.49298,-0.194121,-0.397644,-0.494959,-0.032659,0.020459,-0.769262,-0.781942,-1.448584,...,0.046769,0.240056,0.360416,0.03045,-0.278936,0.149393,0.065505,-0.046933,-2.028451,-0.191925
clusters[T.Beta],-0.224659,-0.308601,0.108855,0.14263,0.02324,0.109799,-0.420989,-0.369293,-0.092068,-0.236639,...,0.228293,0.080792,0.488631,-0.640758,0.332392,0.108718,0.214258,0.018494,-0.531121,-0.785129
clusters[T.Alpha],0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [9]:
sim.sample(example_sce.obs)

AnnData object with n_obs × n_vars = 2087 × 100
    obs: 'clusters_coarse', 'clusters', 'S_score', 'G2M_score', 'cell_type', 'sizeFactor', 'pseudotime'

In [10]:
sim.predict(example_sce.obs)['mean']

Unnamed: 0,Pyy,Iapp,Chgb,Rbp4,Spp1,Chga,Cck,Ins1,Nnat,Ins2,...,Nkx6-1,Fxyd3,Hn1,Smarcd2,Pdia6,Ffar2,Hes6,Serpinh1,Npy,1110012L19Rik
AAACCTGAGAGGGATA,23.587781,13.558171,27.169871,26.220190,0.581982,16.498044,15.476786,10.248462,16.066986,18.791500,...,3.919345,2.400918,2.340614,0.985596,4.464784,2.166168,1.770584,1.027317,2.013339,1.651052
AAACCTGGTAAGTGGC,14.233480,9.000059,17.324536,14.926376,2.450958,11.950230,13.049745,7.107926,9.366991,11.957307,...,3.080800,1.985816,3.809373,2.504979,3.334908,2.069663,3.246555,2.086491,1.184573,1.723180
AAACGGGCAAAGAATC,27.359210,15.291678,31.007837,30.937332,0.381555,18.136683,16.271698,11.410953,18.825355,21.459037,...,4.206424,2.538540,2.028706,0.749456,4.864171,2.195351,1.481838,0.834352,2.352650,1.630452
AAACGGGGTACAGTTC,39.294159,20.511345,42.808616,46.327990,0.136156,22.852551,18.387569,14.832528,27.713290,29.669926,...,4.998538,2.908485,1.430943,0.384071,5.995499,2.268240,0.959603,0.502129,3.440730,1.581245
AAACGGGGTGAAATCA,12.810604,8.263048,15.772984,13.272005,3.307754,11.173096,12.593779,6.585789,8.370275,10.881720,...,2.929976,1.908754,4.216555,3.042786,3.138069,2.050089,3.684038,2.418670,1.060549,1.738612
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGGTTTCACTTACT,7.249919,5.206962,9.498808,7.033652,16.720924,7.768378,10.391290,4.360236,4.556646,6.537867,...,2.233759,1.541140,7.300366,8.706113,2.258671,1.947451,7.295749,5.374866,0.583339,1.824446
TTTGGTTTCCTTTCGG,29.924866,16.444994,33.585371,34.190230,0.295632,19.204858,16.771746,12.176481,20.716969,23.251394,...,4.390012,2.625505,1.860732,0.635125,5.122653,2.213177,1.330689,0.735776,2.584844,1.618127
TTTGTCAAGAATGTGT,24.668240,14.059816,28.275814,27.563263,0.512324,16.976596,15.712634,10.586418,16.854341,19.560005,...,4.003903,2.441669,2.241687,0.907361,4.581797,2.174939,1.677916,0.964764,2.110287,1.644804
TTTGTCAAGTGACATA,19.743754,11.736286,23.188074,21.501398,0.965635,14.726823,14.574463,9.009352,13.286458,16.025842,...,3.600761,2.245666,2.778573,1.368876,4.028792,2.131679,2.192021,1.318472,1.670295,1.676102


In [11]:
sim.predict(example_sce.obs)['dispersion']

Unnamed: 0,Pyy,Iapp,Chgb,Rbp4,Spp1,Chga,Cck,Ins1,Nnat,Ins2,...,Nkx6-1,Fxyd3,Hn1,Smarcd2,Pdia6,Ffar2,Hes6,Serpinh1,Npy,1110012L19Rik
AAACCTGAGAGGGATA,0.168147,0.080586,0.437008,0.348179,0.140867,1.066568,0.787982,0.028506,0.064098,0.006090,...,2.007514,1.560245,4.002741,13.839005,0.362364,1.246233,3.050163,0.794166,0.000578,1.240091
AAACCTGGTAAGTGGC,0.130280,0.059453,0.037330,0.048654,0.306235,0.062296,0.373731,0.029722,0.075461,0.004906,...,2.668619,0.922700,2.554637,6.474050,1.609267,1.562689,3.904554,4.142165,0.000858,0.893214
AAACGGGCAAAGAATC,0.220227,0.124499,1.057516,1.788017,0.668743,1.642260,0.213582,0.085011,0.498321,0.219992,...,3.594456,1.022629,5.860222,2.457057,2.209799,1.129209,4.970021,0.989708,0.056190,0.249487
AAACGGGGTACAGTTC,0.185458,0.094574,0.992613,1.852283,0.661110,1.744956,0.217461,0.064741,0.478205,0.194551,...,4.074877,0.964024,5.399809,4.958843,1.980448,1.153579,5.928307,1.038551,0.063207,0.359110
AAACGGGGTGAAATCA,0.143631,0.055531,0.035283,0.049541,0.600479,0.077903,0.600830,0.038893,0.119714,0.006668,...,4.693064,0.976078,1.631863,7.172658,1.776514,2.327558,5.386043,4.402268,0.002857,0.947100
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGGTTTCACTTACT,0.109413,0.046119,0.033571,0.044460,0.250306,0.047139,0.313891,0.022878,0.063041,0.003938,...,1.368153,1.456468,2.061719,1.023172,2.065242,2.053583,1.547395,20.662919,0.000841,0.644666
TTTGGTTTCCTTTCGG,0.247081,0.148104,1.120646,1.819799,0.722643,1.714411,0.227532,0.102456,0.534290,0.247071,...,4.193099,0.992894,6.716155,3.465155,2.168756,1.043543,6.065568,0.580939,0.056637,0.251633
TTTGTCAAGAATGTGT,0.151688,0.068861,0.416144,0.347672,0.133728,1.090505,0.754816,0.023915,0.061385,0.005540,...,1.950080,1.398730,3.485028,13.846082,0.346780,1.283760,2.995348,1.065560,0.000566,1.307076
TTTGTCAAGTGACATA,0.139696,0.060445,0.395112,0.337218,0.126554,1.011666,0.715928,0.021097,0.058248,0.005077,...,1.583474,1.564272,3.041989,7.418246,0.373253,1.440496,2.218866,1.997373,0.000585,1.182214


# Zero-inflated negative binomial

In [18]:
from scdesigner.simulators import ZeroInflatedNegBinRegressionSimulator
from scipy.stats import nbinom, bernoulli
from scipy.special import expit

n, g, d, p, z = 50000, 20, 2, 2, 2
X1 = np.random.normal(size=(n, d)) # 
X2 = np.random.normal(size=(n, p)) #  
X3 = np.random.normal(size=(n, z)) #  
B = np.random.normal(size=(d, g)) # feature x gene
D = np.random.normal(size=(p, g)) # feature x gene
Z = np.random.normal(size=(z, g)) # feature x gene
mu = np.exp(X1 @ B) # cell x gene
r = np.exp(X2 @ D) # cell x gene
pi = expit(X3 @ Z) # cell x gene

# generate samples
Y = nbinom(n=r, p=r / (r + mu)).rvs() * bernoulli(1 - pi).rvs()

X_1 = pd.DataFrame(X1, columns=[f"mean_dim{j}" for j in range(d)]) # cell x feature
X_2 = pd.DataFrame(X2, columns=[f"dispersion_dim{j}" for j in range(p)]) # cell x feature
X_3 = pd.DataFrame(X3, columns=[f"zero_inflation_dim{j}" for j in range(z)]) # cell x feature
obs = pd.concat([X_1, X_2, X_3], axis=1)

adata = AnnData(X=Y, obs=obs)




In [19]:
zinb = ZeroInflatedNegBinRegressionSimulator()
zinb.fit(adata, {"mean": "~ mean_dim0 + mean_dim1 - 1", 
                                      "dispersion": "~ dispersion_dim0 + dispersion_dim1 - 1",
                                      "zero_inflation": "~ zero_inflation_dim0 + zero_inflation_dim1 - 1"})
zinb_params = zinb.params

print("Ground Truth Mean (B):")
display(pd.DataFrame(B))    
print("Estimated Mean:")
display(zinb_params["beta_mean"])
print("Mean parameter error:", np.mean(np.abs(B - zinb_params["beta_mean"].values)))

print("Ground Truth Dispersion (D):")
display(pd.DataFrame(D))
print("Estimated Dispersion:")
display(zinb_params["beta_dispersion"])

print("Ground Truth Zero Inflation (Z):")
display(pd.DataFrame(Z))
print("Estimated Zero Inflation:")
display(zinb_params["beta_zero_inflation"])
print("Zero Inflation parameter error:", np.mean(np.abs(Z - zinb_params["beta_zero_inflation"].values)))

                                                            

Ground Truth Mean (B):




Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
0,0.830759,1.868547,0.104927,-0.92791,0.399936,-1.103933,0.243217,0.507953,0.281618,-1.549876,-0.301036,-1.517861,1.170399,0.227096,-1.089197,-0.417297,0.069886,-1.866273,0.312876,-1.99427
1,0.299466,1.857417,-0.443751,-0.037666,-1.234585,0.834531,0.133006,0.181133,0.36022,-0.884056,0.512352,-1.059508,-0.681821,0.41816,1.30127,-0.57941,-0.402473,0.031021,0.516724,-1.143915


Estimated Mean:


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
mean_dim0,0.803175,1.773409,0.125912,-0.928065,0.394043,-1.054852,0.299162,0.449718,0.389053,-1.42307,-0.353429,-1.535924,1.263638,0.190291,-1.101896,-0.457582,0.206871,-1.879572,0.371077,-2.068558
mean_dim1,0.352255,1.907713,-0.404977,-0.029255,-1.271859,0.809067,0.053356,0.117288,0.55224,-0.845322,0.526786,-1.078402,-0.683391,0.298215,1.341911,-0.525089,-0.370211,-0.054542,0.506756,-1.122645


Mean parameter error: 0.051740949611253396
Ground Truth Dispersion (D):


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
0,0.517337,1.612736,-0.177611,1.087662,0.072428,-1.210357,0.574411,-1.334476,-0.610085,-0.519971,-0.893339,-2.594068,2.095338,-0.029342,0.587215,1.006805,1.841024,0.059035,0.339229,1.084367
1,-0.51811,1.277952,-0.971916,1.216517,0.385764,-0.656876,0.115195,0.417191,1.229112,1.746698,0.202669,-1.422358,0.493081,0.531621,-1.401284,0.41517,-0.355651,-0.540177,-0.860018,-0.563338


Estimated Dispersion:


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
dispersion_dim0,0.498598,1.60546,-0.062564,1.11011,0.166944,-1.253726,0.539588,-1.248889,-0.569653,-0.535279,-0.879178,-2.841081,2.005572,-0.11063,0.409972,0.922134,1.85483,0.155143,0.4055,1.091263
dispersion_dim1,-0.595123,1.22777,-0.776773,1.11374,0.553886,-0.754215,0.127098,0.384087,1.237249,1.606957,0.195936,-1.609433,0.456225,0.629856,-1.323742,0.181745,-0.504788,-0.668864,-0.860833,-0.563892


Ground Truth Zero Inflation (Z):


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
0,-0.323643,1.598281,-1.926218,-0.978242,-1.169447,0.312451,-0.499746,0.439804,-0.99449,0.599093,2.059879,-1.280892,0.651386,-0.560845,-0.942254,-0.523457,-0.293852,-0.107265,0.152564,-0.000661
1,-0.015822,0.187606,-0.277698,1.244773,0.405467,1.137762,1.892541,-1.800827,-1.372026,0.37343,-0.378648,0.562407,-0.377945,-1.37134,-0.458047,0.89476,1.084383,0.227401,-0.139929,-0.914897


Estimated Zero Inflation:


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
zero_inflation_dim0,-0.266378,1.581174,-1.819479,-1.057382,-1.33872,0.312284,-0.424209,0.391051,-1.01416,0.608005,2.011078,-1.412448,0.683002,-0.372843,-0.944482,-0.611184,-0.451012,-0.245723,0.158835,-0.044981
zero_inflation_dim1,-0.089384,0.184316,-0.255163,1.299178,0.470648,1.088871,2.063005,-1.978328,-1.350777,0.384276,-0.487192,0.498364,-0.176748,-1.458619,-0.48464,0.873127,0.917546,0.279764,-0.210301,-1.181303


Zero Inflation parameter error: 0.07829732522515939


In [21]:
from scdesigner.simulators import ZeroInflatedNegBinRegressionSimulator
from scipy.stats import nbinom, bernoulli
from scipy.special import expit

n, g, d, p, z = 50000, 20, 2, 2, 2





zinb.sample(obs).X



array([[ 0,  0,  0, ...,  7,  1,  0],
       [ 0,  0,  0, ...,  0,  0,  0],
       [ 0,  0,  0, ...,  0,  0, 22],
       ...,
       [ 0,  0,  5, ...,  0,  0,  0],
       [ 4,  1,  0, ...,  0,  0,  0],
       [ 0,  0,  1, ...,  0,  0,  0]])