In [1]:
%load_ext autoreload
%autoreload 2

from scdesigner.simulators import NegBinRegressionSimulator
import numpy as np
import pandas as pd
import torch
from anndata import AnnData


n, g, d, p = 1000, 20, 2, 2
X1 = np.random.normal(size=(n, d)) # 
X2 = np.random.normal(size=(n, p)) #    
B = np.random.normal(size=(d, g)) # feature x gene
D = np.random.normal(size=(p, g)) # feature x gene
mu = np.exp(X1 @ B) # cell x gene
r = np.exp(X2 @ D) # cell x gene

# generate samples
Y = np.random.negative_binomial(r, r / (r + mu))

X1 = pd.DataFrame(X1, columns=[f"mean_dim{j}" for j in range(d)]) # cell x feature
X2 = pd.DataFrame(X2, columns=[f"dispersion_dim{j}" for j in range(p)]) # cell x feature
obs = pd.concat([X1, X2], axis=1)

adata = AnnData(X=Y, obs=obs)

print("Testing the fixed NegBinRegressionSimulator...")
nb_simulator = NegBinRegressionSimulator()
nb_simulator.fit(adata, {"mean": "~ mean_dim0 + mean_dim1 - 1", 
                                      "dispersion": "~ dispersion_dim0 + dispersion_dim1 - 1"})
nb_params = nb_simulator.params

print("Success! Parameters fitted successfully.")
print("Parameter keys:", list(nb_params.keys()))





Testing the fixed NegBinRegressionSimulator...
entering multiple formula regression factory


                                                       

Success! Parameters fitted successfully.
Parameter keys: ['beta', 'gamma']


In [3]:
# Compare with ground truth
print("\n=== Ground Truth vs Estimated Parameters ===")
print("Ground Truth Mean (B):")
display(pd.DataFrame(B))    
print("Estimated Mean:")
display(nb_params["beta"])
print("Mean parameter error:", np.mean(np.abs(B - nb_params["beta"].values)))

print("\nGround Truth Dispersion (D):")
display(pd.DataFrame(D))
print("Estimated Dispersion:")
display(nb_params["gamma"])
print("Dispersion parameter error:", np.mean(np.abs(D - nb_params["gamma"].values)))


=== Ground Truth vs Estimated Parameters ===
Ground Truth Mean (B):


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
0,-0.613864,-2.104499,-1.382038,-0.189769,-2.418826,0.438826,0.852657,1.066234,-0.632252,0.193501,0.132949,-1.073473,-1.409686,1.377358,-0.834224,-0.30419,-0.145774,-0.456245,-0.501939,0.29927
1,-0.270848,0.345033,1.052446,0.339236,2.073807,-0.407651,-0.622693,-0.387045,2.113464,0.234854,1.005284,0.810163,0.371313,-0.193813,-0.790187,-2.365495,0.255173,1.128326,0.382451,-0.559645


Estimated Mean:


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
mean_dim0,-0.62954,-2.105627,-1.396842,-0.207962,-2.420953,0.481886,0.755186,1.066853,-0.592213,0.222169,0.07826,-1.001152,-1.423582,1.35648,-0.805646,-0.304494,-0.184405,-0.415714,-0.618809,0.263614
mean_dim1,-0.339657,0.305642,1.06981,0.307658,2.082364,-0.416282,-0.503697,-0.389013,2.124462,0.255303,0.972483,0.833231,0.424597,-0.175618,-0.756754,-2.366086,0.235245,1.237029,0.337295,-0.593393


Mean parameter error: 0.03449457738247254

Ground Truth Dispersion (D):


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
0,-1.01727,-0.074272,1.807313,0.359854,1.19734,0.596052,-1.018463,-0.523826,1.069157,1.427817,0.628363,0.207551,0.912121,-0.214772,1.597167,-1.026347,1.096947,-0.096751,-1.475969,-1.359288
1,-0.970722,0.152219,0.554345,-0.19616,-0.465202,-1.033048,0.661043,0.573958,-1.313762,0.440222,1.497319,1.595154,0.295845,2.179896,0.514175,-0.728146,0.072044,2.34717,0.565068,1.149194


Estimated Dispersion:


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
dispersion_dim0,-1.03188,-0.072074,1.870452,0.389137,1.163701,0.533187,-1.059803,-0.62454,1.124853,1.488173,0.583153,0.215901,0.895988,-0.105632,1.623277,-1.001229,1.056538,-0.222109,-1.368697,-1.288661
dispersion_dim1,-1.160128,0.165182,0.563131,-0.164266,-0.520858,-1.111349,0.610582,0.521644,-1.329093,0.448263,1.614028,1.708055,0.329675,2.214091,0.468331,-0.749706,0.055794,2.255846,0.646498,1.167925


Dispersion parameter error: 0.052837529131105586


In [24]:
nb_params["beta_mean"]

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
mean_dim0,1.297626,-0.246695,-2.135586,0.286499,0.442747,-0.713648,0.448765,0.264362,0.875127,0.117486,1.110423,-1.095093,0.445033,-2.064075,0.424929,-1.587959,-0.936134,-0.345938,-0.36356,-0.706762
mean_dim1,0.603265,1.610911,-0.819141,-0.877957,0.693311,-0.192277,-0.10983,1.057928,-0.045859,1.137581,-0.249566,0.554592,-1.777147,0.096214,0.322327,1.765976,1.078549,-0.382957,-0.918674,1.176236


In [25]:
nb_params["beta_dispersion"]

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
dispersion_dim0,0.852666,-0.304257,0.465278,1.621052,0.620949,1.465132,0.885164,-1.283047,0.23287,-0.718458,-0.038469,1.088672,1.046635,0.738067,-1.360299,-0.439278,0.255443,1.60427,-0.697086,-0.983592
dispersion_dim1,1.070909,1.333931,0.499282,-0.622785,-0.435394,0.971504,1.410034,-0.076684,1.329404,2.677855,-1.792265,1.269645,-0.60944,-0.048674,0.365388,-0.170894,-0.875945,0.378463,1.559213,-1.425431


In [2]:
nb_simulator.predict(adata.obs)['mean']

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
0,2.982721,1.261885,0.386741,1.703782,1.334727,4.854627,0.168939,0.747625,2.598862,1.511230,0.099983,1.866795,2.078059,0.328218,3.890270,0.122204,1.336038,0.324144,0.587159,0.813708
1,0.134458,0.126885,0.434794,0.591089,11.049615,0.318113,0.088464,0.602579,0.248238,1.181892,0.299389,1.444642,1.959059,0.205616,1.186935,0.330251,1.110011,0.211895,1.089029,1.587644
2,0.979199,1.285991,1.523792,0.922189,0.628735,0.737229,2.518923,1.183272,0.928034,0.858468,2.444609,0.779804,0.719517,1.801359,0.642176,2.263865,0.900275,1.799778,1.161552,0.990914
3,0.014148,0.019848,0.353068,0.288407,71.644396,0.053640,0.028934,0.457648,0.046958,1.098385,0.357841,1.424639,2.360786,0.096877,0.678627,0.386361,1.043037,0.103055,1.541139,2.604852
4,2.876019,1.436006,0.495254,1.611588,1.034292,3.976377,0.288572,0.823586,2.442681,1.379520,0.170113,1.610798,1.713276,0.461431,2.974699,0.198626,1.254589,0.455621,0.644009,0.813598
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,2.301730,5.726159,5.712075,0.973574,0.075226,0.621621,59.801102,2.169140,1.467740,0.565056,31.366196,0.377852,0.253716,13.773852,0.219657,23.385595,0.678300,13.550180,1.565366,0.788648
996,1.462503,2.193297,2.177881,0.990687,0.313084,0.814066,6.231360,1.414148,1.193931,0.775336,4.650418,0.647752,0.541885,3.231967,0.509498,4.079668,0.841148,3.208073,1.219829,0.897859
997,1.005491,1.607005,2.095999,0.879750,0.429108,0.606883,5.132099,1.348866,0.905460,0.766951,4.753510,0.647648,0.560577,2.836841,0.466007,4.158033,0.833238,2.830808,1.290742,0.975081
998,0.241953,0.345942,1.038966,0.617555,2.678475,0.290177,0.718785,0.900356,0.341959,0.898513,1.599472,0.897333,0.985153,0.789949,0.590029,1.530343,0.922143,0.805531,1.319826,1.358695


In [3]:
nb_simulator.predict(adata.obs)['dispersion']

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19
0,0.866851,0.763338,6.317244,1.318886,2.570328,1.160147,0.897264,1.343635,1.262858,0.778239,0.293174,1.146162,1.458707,1.808278,2.748887,0.727826,0.385139,1.633967,1.723109,0.578388
1,0.875857,1.181346,60.777358,0.708161,3.181914,6.231706,0.535537,4.636204,1.485882,1.443550,0.031493,0.539273,0.280623,1.606156,9.397582,0.412842,0.303177,0.665509,1.269013,1.217588
2,0.947452,0.764768,0.772194,1.410979,1.276866,0.539150,1.182618,0.676896,1.016801,0.716356,1.584802,1.416105,2.199609,1.289076,0.872179,1.121849,0.788480,1.697418,1.365768,0.614128
3,0.936326,0.864076,2.062135,1.172130,1.521938,0.980386,0.977654,1.072851,1.103071,0.863546,0.641455,1.106948,1.294794,1.318771,1.488391,0.890902,0.654910,1.311353,1.302696,0.749132
4,1.329451,3.883166,2.800038,0.176863,0.264831,20.556479,0.444043,6.573732,0.893221,5.345214,0.121750,0.176986,0.019592,0.264730,1.715870,0.591954,3.657208,0.069178,0.200424,11.868016
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,0.966436,1.481488,21.537813,0.552117,1.706718,6.627084,0.548507,4.273589,1.287440,1.818217,0.058570,0.456785,0.184937,1.060271,5.312791,0.485289,0.567716,0.436522,0.850158,1.900672
996,0.701302,0.507649,93.208964,2.009623,10.373215,1.403506,0.771105,2.036467,1.779523,0.530898,0.049497,1.421767,2.627861,4.359300,12.036796,0.459110,0.094054,3.436706,3.880185,0.253594
997,1.086056,0.894762,0.074236,1.253074,0.483900,0.310513,1.489182,0.376187,0.778942,0.787197,8.978160,1.488555,2.270925,0.745452,0.242086,1.753220,2.114338,1.308548,0.866197,0.873610
998,0.777023,0.510728,8.443337,2.160145,4.641520,0.589211,1.055092,0.934542,1.388328,0.485223,0.339090,1.801932,4.156684,2.948524,3.242284,0.752011,0.214278,3.562305,2.960563,0.273540


In [4]:
from formulaic import model_matrix

x_mean = model_matrix("~ mean_dim0 + mean_dim1 - 1", adata.obs)
x_dispersion = model_matrix("~ dispersion_dim0 + dispersion_dim1 - 1", adata.obs)

x_mean

Unnamed: 0,mean_dim0,mean_dim1
0,-0.935438,0.271606
1,-0.535490,-0.953262
2,0.370245,0.065848
3,-0.501341,-1.894861
4,-0.715730,0.300590
...,...,...
995,1.443378,0.642395
996,0.643961,0.290063
997,0.646430,0.132734
998,0.169654,-0.563513


In [8]:
nb_params["gamma"].columns = ['gene_1', 'gene_2', 'gene_3', 'gene_4', 'gene_5', 'gene_6', 'gene_7', 'gene_8', 'gene_9', 'gene_10', 'gene_11', 'gene_12', 'gene_13', 'gene_14', 'gene_15', 'gene_16', 'gene_17', 'gene_18', 'gene_19', 'gene_20']
nb_params['gamma']

Unnamed: 0,gene_1,gene_2,gene_3,gene_4,gene_5,gene_6,gene_7,gene_8,gene_9,gene_10,gene_11,gene_12,gene_13,gene_14,gene_15,gene_16,gene_17,gene_18,gene_19,gene_20
dispersion_dim0,0.1971,0.451053,-2.09131,-0.495108,-1.249162,0.11361,0.050898,-0.170255,-0.288125,0.458748,1.255785,-0.328262,-0.825969,-0.831373,-1.149651,0.327156,1.258118,-0.839718,-0.800753,0.887922
dispersion_dim1,0.063011,0.494335,1.343484,-0.663419,-0.162879,1.456306,-0.423462,1.003005,0.058557,0.649984,-1.545529,-0.72946,-1.624561,-0.329301,0.725588,-0.390586,0.143441,-0.992518,-0.479694,0.875386


In [9]:
x_dispersion @ nb_params["gamma"] 

Unnamed: 0,gene_1,gene_2,gene_3,gene_4,gene_5,gene_6,gene_7,gene_8,gene_9,gene_10,gene_11,gene_12,gene_13,gene_14,gene_15,gene_16,gene_17,gene_18,gene_19,gene_20
0,-0.142888,-0.270054,1.843283,0.276787,0.944033,0.148547,-0.108405,0.295379,0.233377,-0.250722,-1.226989,0.136419,0.377551,0.592375,1.011196,-0.317694,-0.954152,0.491011,0.544130,-0.547510
1,-0.132552,0.166655,4.107217,-0.345083,1.157483,1.829650,-0.624485,1.533896,0.396009,0.367105,-3.457984,-0.617534,-1.270743,0.473844,2.240452,-0.884691,-1.193440,-0.407202,0.238239,0.196872
2,-0.053979,-0.268182,-0.258519,0.344283,0.244408,-0.617762,0.167731,-0.390238,0.016661,-0.333578,0.460459,0.347910,0.788279,0.253926,-0.136760,0.114978,-0.237648,0.529108,0.311717,-0.487551
3,-0.065792,-0.146094,0.723742,0.158822,0.419984,-0.019808,-0.022599,0.070320,0.098098,-0.146708,-0.444016,0.101606,0.258351,0.276700,0.397696,-0.115521,-0.423258,0.271059,0.264436,-0.288841
4,0.284766,1.356651,1.029633,-1.732380,-1.328662,3.023176,-0.811834,1.883082,-0.112921,1.676202,-2.105785,-1.731686,-3.932625,-1.329044,0.539920,-0.524327,1.296700,-2.671068,-1.607321,2.473847
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,-0.034140,0.393047,3.069810,-0.593995,0.534572,1.891165,-0.600555,1.452454,0.252656,0.597856,-2.837535,-0.783542,-1.687739,0.058524,1.670117,-0.723011,-0.566133,-0.828917,-0.162332,0.642207
996,-0.354816,-0.677965,4.534844,0.697947,2.339227,0.338973,-0.259931,0.711216,0.576346,-0.633185,-3.005843,0.351901,0.966170,1.472312,2.487968,-0.778466,-2.363884,1.234513,1.355883,-1.372019
997,0.082553,-0.111197,-2.600511,0.225600,-0.725876,-1.169530,0.398227,-0.977668,-0.249819,-0.239277,2.194795,0.397806,0.820187,-0.293764,-1.418463,0.561454,0.748742,0.268918,-0.143643,-0.135122
998,-0.252286,-0.671919,2.133378,0.770176,1.535042,-0.528971,0.053628,-0.067699,0.328100,-0.723146,-1.081489,0.588860,1.424718,1.081305,1.176278,-0.285005,-1.540481,1.270408,1.085379,-1.296307
