In [1]:
# Load IPython's autoreload extension and use mode 2, so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Ask the inline backend for 'retina' (high-DPI) figure output.
%config InlineBackend.figure_format = 'retina'

In [3]:
import scanpy as sc 
import numpy as np 
import pandas as pd 
import seaborn as sns
import matplotlib.pyplot as plt
import os, sys 
sys.path.append('../../src')

In [4]:
import sys
sys.path.append('../../src/')
# import celloracle as co

In [5]:
# Read the AnnData object stored in snrna_human_tonsil.h5ad and show its summary.
TONSIL_H5AD = '/ix/djishnu/shared/djishnu_kor11/training_data_2025/snrna_human_tonsil.h5ad'
adata = sc.read_h5ad(TONSIL_H5AD)
adata

AnnData object with n_obs × n_vars = 5778 × 3549
    obs: 'cell_type', 'author_cell_type', 'cell_type_int', 'leiden', 'leiden_R', 'cell_type_2'
    uns: 'author_cell_type_colors', 'cell_type_2_colors', 'cell_type_colors', 'dendrogram_leiden', 'leiden', 'leiden_R', 'leiden_colors', 'neighbors', 'pca', 'umap'
    obsm: 'X_pca', 'X_umap', 'ora_estimate', 'ora_pvals', 'spatial', 'spatial_unscaled'
    varm: 'PCs'
    layers: 'imputed_count', 'normalized_count'
    obsp: 'connectivities', 'distances'

In [6]:
from spaceoracle.tools.network import RegulatoryFactory

# Build the regulatory-network factory from the pickled links file, keyed on
# the 'cell_type_int' annotation.
COLINKS_PKL = '/ix/djishnu/shared/djishnu_kor11/training_data_2025/snrna_human_tonsil_colinks.pkl'
co_grn = RegulatoryFactory(colinks_path=COLINKS_PKL, annot='cell_type_int')

In [7]:
from spaceoracle.models.parallel_estimators import SpatialCellularProgramsEstimator

# Keyword settings for the PAX5 estimator, gathered in one place.
estimator_settings = dict(
    target_gene='PAX5',
    cluster_annot='cell_type_int',
    radius=200,
    contact_distance=30,
    grn=co_grn,
)
estimator = SpatialCellularProgramsEstimator(adata, **estimator_settings)

In [12]:
# NOTE(review): the wildcard import hides which names this notebook relies on.
# The commented-out cells reference received_ligands / get_ligands_df, which
# presumably come from this module — confirm and import them explicitly.
from spaceoracle.models.parallel_estimators import *

# Earlier approach, kept for reference: thresholds read from a CSV and stored
# as a column of adata.obs.
# cell_threshes = pd.read_csv('../../data/tonsil/cell_threshes.csv', index_col=0)
# adata.obs['cell_thresholds'] = cell_threshes.loc[adata.obs_names, '0']
# cell_threshes.mean().item()

# Current approach: load the table from tonsil_commot_LRs.parquet and store it
# in adata.uns under 'cell_thresholds'.
cell_threshes = pd.read_parquet('/ix/djishnu/shared/djishnu_kor11/miscellaneous/tonsil_commot_LRs.parquet')
adata.uns['cell_thresholds'] = cell_threshes
# Display the stored table's shape as a quick sanity check.
adata.uns['cell_thresholds'].shape

(5778, 198)

In [9]:
# df = received_ligands(
#     adata.obsm['spatial'], 
#     get_ligands_df(adata.to_df(layer='imputed_count'), adata.obs['cell_thresholds'], estimator.ligands),
#     lr_info=estimator.lr 
# )

# df_nofilt = received_ligands(
#     adata.obsm['spatial'], 
#     get_ligands_df(adata.to_df(layer='imputed_count'), np.zeros(adata.n_obs), estimator.ligands),
#     lr_info=estimator.lr 
# )

# (df - df_nofilt).sum(axis=0)

In [None]:
# Fit the estimator, then persist its betadata table as the "filtered" result.
fit_settings = dict(
    num_epochs=100,
    learning_rate=5e-3,
    score_threshold=0.1,
    coef_filter=0.001,
)
estimator.fit(**fit_settings)
estimator.betadata.to_parquet('filtered_betadata.parquet')

In [12]:
# "Unfiltered" run: write a constant 0 into adata.obs['cell_thresholds'],
# refit with the same hyperparameters, and save the betadata separately.
# NOTE(review): the thresholds loaded earlier were stored in
# adata.uns['cell_thresholds'], not adata.obs — confirm which one the estimator
# reads; if it reads .uns, this run may not differ from the filtered one.
# NOTE(review): this mutates `adata` in place; confirm estimator.adata is the
# same object rather than a copy, and that later cells expect the zeroed column.
adata.obs['cell_thresholds'] = 0
estimator.fit(num_epochs=100, learning_rate=5e-3, score_threshold=0.1, coef_filter=0.001)
estimator.betadata.to_parquet('unfiltered_betadata.parquet')

Fitting PAX5 with 278 modulators
	22 Transcription Factors
	244 Ligand-Receptor Pairs
	12 TranscriptionFactor-Ligand Pairs
0: 0.8724 | 0.8645
1: 0.9500 | 0.9351
2: 0.9538 | 0.9220
3: 0.9659 | 0.9470
4: 0.9421 | 0.9425
5: 0.9981 | 0.9983
6: 0.9947 | 0.9955
7: 0.9445 | 0.8869
8: 0.9567 | 0.9246


In [None]:
# For every receptor in the estimator's ligand-receptor table, average its
# expression within each cluster and keep the largest cluster mean.
_receptors = np.unique(estimator.lr.receptor.values)

# Prefer the normalized layer; fall back to the imputed one when it is absent.
if 'normalized_count' in estimator.adata.layers:
    _layer = 'normalized_count'
else:
    _layer = 'imputed_count'

_receptor_expr = estimator.adata.to_df(layer=_layer)[_receptors]
_cluster_labels = estimator.adata.obs[estimator.cluster_annot]
receptor_levels = (
    _receptor_expr
    .join(_cluster_labels)
    .groupby(estimator.cluster_annot)
    .mean()
    .max(axis=0)
    .to_frame(name='mean_max')
)

In [None]:
# Show the receptors whose highest cluster-mean expression is below 0.2.
_low_mask = receptor_levels['mean_max'] < 0.2
receptor_levels.loc[_low_mask]

In [None]:
# %matplotlib inline
# estimator.plot_modulators()

In [None]:
# Fit again with the same hyperparameters used by the earlier fit cells.
# NOTE(review): presumably this replaces estimator.betadata / estimator.models
# from any previous fit, and runs with whatever thresholds are currently set on
# adata — confirm that is the intended state for the analysis below.
estimator.fit(num_epochs=100, learning_rate=5e-3, score_threshold=0.1, coef_filter=0.001)

In [None]:
# Short handle on the estimator's betadata table for the cells that follow.
betadata = estimator.betadata

In [None]:
# Move the first model's anchors to the CPU and convert them to a NumPy array.
# NOTE(review): presumably a torch tensor given the .cpu().numpy() chain — confirm.
anchor = estimator.models[0].anchors.cpu().numpy()

In [None]:
# Count how many cells carry each 'cell_type_int' label.
estimator.adata.obs['cell_type_int'].value_counts()

In [None]:
# Zero out anchor entries with magnitude below 0.01, drop the first row, label
# the remaining rows with the estimator's modulators, and rank them by the
# absolute value in column 0 (smallest first).
_anchor_sparse = np.where(np.abs(anchor) < 0.01, 0, anchor)
(
    pd.DataFrame(_anchor_sparse[1:], index=estimator.modulators)
    .abs()
    .sort_values(by=0, ascending=True)
)

In [None]:
# Names of every betadata column that contains the 'beta_' marker.
beta_cols = [col for col in betadata.columns if 'beta_' in col]

In [None]:
# Average each selected beta column over all rows and list the largest first.
_mean_betas = betadata[beta_cols].mean()
_mean_betas.sort_values(ascending=False)

In [None]:
from spaceoracle.prophets import Prophet
import anndata as ad
import pandas as pd
import matplotlib

In [None]:
# Write the betadata table where the Prophet cell below looks for models.
# Create the directory first: to_parquet does not create missing parent
# directories, so on a fresh machine (or after /tmp is cleared) the write
# would otherwise fail.
os.makedirs('/tmp/models', exist_ok=True)
betadata.to_parquet(f'/tmp/models/{estimator.target_gene}_betadata.parquet')

In [None]:
# Build a Prophet over the estimator's AnnData, pointing it at the directory
# the betadata parquet was written to, with 'cell_type_int' as the annotation
# and 'cell_type' as its label column.
pythia = Prophet(
    adata=estimator.adata,
    models_dir='/tmp/models',
    annot='cell_type_int',
    annot_labels='cell_type'
)

In [None]:
# Assemble one wide per-cell table: weighted betas for the target gene, the
# obs annotations, x/y coordinates, and gene expression.

# Expression table taken from the layer the estimator is configured with.
gex_df = estimator.adata.to_df(layer=estimator.layer)
# NOTE(review): presumably populates pythia.beta_dict, which is read below —
# keep this call ahead of _get_wbetas_dict.
pythia.compute_betas()
# NOTE(review): 'imputed_count' is hard-coded here while gex_df uses
# estimator.layer — confirm the two are meant to be the same layer.
gene_mtx = pythia.adata.layers['imputed_count']
# NOTE(review): the next two calls use private (underscore-prefixed) Prophet
# methods; they may change without notice.
weighted_ligands = pythia._compute_weighted_ligands(gene_mtx)
beta_dict = pythia._get_wbetas_dict(
    pythia.beta_dict, weighted_ligands, gene_mtx)
# Weighted betas for the estimator's target gene.
wbetas = beta_dict.data[estimator.target_gene].wbetas
# Copy of beta_dict.xydf with its two columns renamed to x / y.
xy = beta_dict.xydf.copy()
xy.columns = ['x', 'y']
# Index-aligned joins: betas + obs annotations + coordinates + expression.
df = wbetas \
        .join(estimator.adata.obs) \
        .join(xy) \
        .join(gex_df)

In [None]:
# Columns of df that hold a beta for one of the estimator's ligands: the name
# contains 'beta_' and, with that marker removed, matches a ligand name.
# The ligand lookup is built once as a set; previously np.unique(...) was
# recomputed and scanned for every column of df.
_ligand_names = set(np.unique(estimator.ligands))
beta_cols_df = [
    col for col in df.columns
    if 'beta_' in col and col.replace('beta_', '') in _ligand_names
]

In [None]:
# Re-select beta columns, this time only those whose name also contains '$'.
# This rebinds the beta_cols list defined in an earlier cell.
beta_cols = [col for col in betadata.columns if '$' in col and 'beta_' in col]

In [None]:
# Average each selected beta column over all rows and list the largest first.
_mean_betas = betadata[beta_cols].mean()
_mean_betas.sort_values(ascending=False)

In [None]:
# Mean beta across all cells for the GAS6 and NPPC columns.
_beta_pair = ['beta_GAS6', 'beta_NPPC']
df.loc[:, _beta_pair].mean()

In [None]:
# Four-panel spatial overview for one modulator: cell types, modulator
# expression, target-gene expression, and the modulator's beta values.
fig, axes = plt.subplots(2, 2, figsize=(14, 10), dpi=200)
axes = axes.flatten()

modulator = 'CCL21'
_beta_col = f'beta_{modulator}'

# Beta, coordinates and labels from df, plus per-gene values from the imputed layer.
plot_data = df[[_beta_col, 'x', 'y', 'cell_type']].join(adata.to_df(layer='imputed_count'))

# Marker styling shared by every panel.
plot_settings = dict(linewidth=0.2, edgecolor='black', s=30)


def _colored_panel(ax, column, cmap, title):
    """Scatter the cells on `ax` colored by `column`, then add a title and colorbar."""
    points = ax.scatter(
        plot_data['x'],
        plot_data['y'],
        c=plot_data[column],
        cmap=cmap,
        **plot_settings
    )
    ax.set_title(title, fontsize=14)
    plt.colorbar(points, ax=ax, shrink=0.5)
    return points


# Panel 0: categorical cell-type map, legend placed outside the axes.
sns.scatterplot(
    data=plot_data, x='x', y='y',
    hue='cell_type', palette='tab20', legend='brief',
    ax=axes[0], **plot_settings
)
axes[0].set_title('Cell Type', fontsize=14)
axes[0].legend(bbox_to_anchor=(1.05, 1), loc='upper left')

# Panels 1-3: continuous values, each with its own colormap and colorbar.
scatter1 = _colored_panel(axes[1], modulator, 'viridis', modulator)
scatter2 = _colored_panel(axes[2], estimator.target_gene, 'magma', estimator.target_gene)
scatter3 = _colored_panel(axes[3], _beta_col, 'rainbow', f'Beta {modulator}')

# Strip ticks, labels and frames; keep equal aspect; unframe any legend.
for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_aspect('equal')
    ax.set_xlabel('')
    ax.set_ylabel('')
    ax.set_frame_on(False)
    legend = ax.get_legend()
    if legend is not None:
        legend.set_frame_on(False)

plt.tight_layout()
plt.show()

In [None]:
import commot as ct

In [None]:
# Fetch the human CellChat "Secreted Signaling" ligand-receptor table via
# COMMOT and give its four columns readable names.
_ligrec_query = dict(
    database='CellChat',
    species='human',
    signaling_type="Secreted Signaling",
)
df_ligrec = ct.pp.ligand_receptor_database(**_ligrec_query)
df_ligrec.columns = ['ligand', 'receptor', 'pathway', 'signaling']
          

In [None]:
# Rows of the ligand-receptor table whose ligand is CCL21.
_is_ccl21 = df_ligrec['ligand'] == "CCL21"
df_ligrec.loc[_is_ccl21]