In [None]:
import commot as ct
import scanpy as sc
import pandas as pd
import numpy as np

### Run COMMOT

In [None]:
# Read the AnnData object stored at the path below and display its summary.
h5ad_path = '/ix/djishnu/shared/djishnu_kor11/training_data_2025/snrna_human_tonsil.h5ad'
adata = sc.read_h5ad(h5ad_path)
adata

AnnData object with n_obs × n_vars = 5778 × 3549
    obs: 'cell_type', 'author_cell_type', 'cell_type_int', 'leiden', 'leiden_R', 'cell_type_2'
    uns: 'author_cell_type_colors', 'cell_type_2_colors', 'cell_type_colors', 'dendrogram_leiden', 'leiden', 'leiden_R', 'leiden_colors', 'neighbors', 'pca', 'umap'
    obsm: 'X_pca', 'X_umap', 'ora_estimate', 'ora_pvals', 'spatial', 'spatial_unscaled'
    varm: 'PCs'
    layers: 'imputed_count', 'normalized_count'
    obsp: 'connectivities', 'distances'

In [None]:
# Use the 'normalized_count' layer as adata.X for the steps below
# (the 'imputed_count' layer is the alternative that was tried here).
expression_layer = 'normalized_count'
adata.X = adata.layers[expression_layer]


In [None]:
# Fetch the CellChat human ligand-receptor table (signaling_type=None) and
# give its four columns the names the rest of the notebook relies on.
ligrec_columns = ['ligand', 'receptor', 'pathway', 'signaling']
df_ligrec = ct.pp.ligand_receptor_database(database='CellChat', species='human', signaling_type=None)
df_ligrec.columns = ligrec_columns

In [None]:
# Run COMMOT spatial communication inference with the user-supplied LR table.
# The per-pair results are read back from adata.obsp further down.
spatial_comm_kwargs = dict(
    database_name='user_database',
    df_ligrec=df_ligrec,
    dis_thr=200,
    heteromeric=True,
)
ct.tl.spatial_communication(adata, **spatial_comm_kwargs)

In [None]:
# Display the AnnData summary after the spatial communication step.
adata

AnnData object with n_obs × n_vars = 5778 × 3549
    obs: 'cell_type', 'author_cell_type', 'cell_type_int', 'leiden', 'leiden_R', 'cell_type_2'
    uns: 'author_cell_type_colors', 'cell_type_2_colors', 'cell_type_colors', 'dendrogram_leiden', 'leiden', 'leiden_R', 'leiden_colors', 'neighbors', 'pca', 'umap'
    obsm: 'X_pca', 'X_umap', 'ora_estimate', 'ora_pvals', 'spatial', 'spatial_unscaled'
    varm: 'PCs'
    layers: 'imputed_count', 'normalized_count'
    obsp: 'connectivities', 'distances'

In [None]:
# adata.write_h5ad('commot.h5ad')
# adata = sc.read_h5ad('commot.h5ad')

In [None]:
# Collect the COMMOT entries of adata.obsp, keyed by whatever remains of the
# obsp key once the 'commot-user_database-' marker is removed.
obsp_prefix = 'commot-user_database-'
lr_info = {
    obsp_key.replace(obsp_prefix, ''): obsp_value
    for obsp_key, obsp_value in adata.obsp.items()
    if obsp_prefix in obsp_key
}
len(lr_info)

239

In [None]:
# Name every pair "<ligand>-<receptor>" (the form compared against the
# lr_info keys in the next cell) and count the distinct names.
pair_names = df_ligrec['ligand'] + '-' + df_ligrec['receptor']
df_ligrec['name'] = pair_names
len(pair_names.unique())

1938

In [None]:
# Keep only the pairs for which COMMOT produced a matrix (present in lr_info).
# `.copy()` makes the filtered table an independent frame: a later cell assigns
# df_ligrec['name'] again, which on a plain boolean-indexed slice triggers
# pandas' SettingWithCopyWarning.
has_commot_result = df_ligrec['name'].isin(lr_info.keys())
df_ligrec = df_ligrec[has_commot_result].copy()
df_ligrec['signaling'].value_counts()

signaling
ECM-Receptor          89
Cell-Cell Contact     80
Secreted Signaling    69
Name: count, dtype: int64

### Get cluster communication scores

In [None]:
from tqdm import tqdm

# Score cluster-to-cluster communication (clusters = 'cell_type') once per
# distinct LR pair name; the results are read back from adata.uns below.
unique_pair_names = df_ligrec['name'].unique()
for name in tqdm(unique_pair_names):
    ct.tl.cluster_communication(
        adata,
        database_name='user_database',
        pathway_name=name,
        clustering='cell_type',
        random_seed=12,
        n_permutations=100,
    )


100%|██████████| 238/238 [06:40<00:00,  1.68s/it]


In [None]:
# adata.write_h5ad('commot_cluster.h5ad')

In [None]:
import pickle
from collections import defaultdict

# For every LR pair, gather the cluster-level communication matrix and its
# p-values from adata.uns, then pickle the whole collection.
data_dict = defaultdict(dict)
for name in df_ligrec['name']:
    cluster_result = adata.uns[f'commot_cluster-cell_type-user_database-{name}']
    for field in ('communication_matrix', 'communication_pvalue'):
        data_dict[name][field] = cluster_result[field]

output_path = '/ix/djishnu/shared/djishnu_kor11/miscellaneous/tonsil_commot_communication.pkl'
with open(output_path, 'wb') as f:
    pickle.dump(data_dict, f)

In [None]:
# Inspect the communication matrix of whichever pair `name` currently holds
# (it is a leftover loop variable, so this depends on the cell run before).
data_dict[name]['communication_matrix']

In [None]:
def get_sig_interactions(value_matrix, p_matrix, pval=0.05):
    """Zero out the entries of ``value_matrix`` that are not significant.

    An entry is kept when the matching entry of ``p_matrix`` is strictly
    below ``pval``; every other entry is multiplied by 0. The p-values are
    applied positionally (as a plain array), so ``value_matrix`` keeps its
    own type and labels.
    """
    significant = np.asarray(p_matrix) < pval
    return value_matrix * significant.astype(int)

# For every LR pair, mask the cluster communication matrix by significance and
# keep it only when the masked matrix still sums to more than zero.
interactions = {}
for lig, rec in tqdm(zip(df_ligrec['ligand'], df_ligrec['receptor'])):
    name = '-'.join((lig, rec))

    cluster_result = adata.uns[f'commot_cluster-cell_type-user_database-{name}']
    value_matrix = cluster_result['communication_matrix']
    p_matrix = cluster_result['communication_pvalue']

    sig_matrix = get_sig_interactions(value_matrix, p_matrix)
    masked_total = sig_matrix.sum().sum()

    if masked_total > 0:
        interactions[name] = sig_matrix

len(interactions)

238it [00:00, 4192.56it/s]


238

### Get expanded LR masks

In [16]:
import sys
# Put '../../src' on the import path so the project import below resolves.
# The path is relative, so it depends on the notebook's working directory.
sys.path.append('../../src')

from spaceoracle.tools.network import expand_paired_interactions

In [17]:
# Expand the LR table with the project helper, then list every distinct value
# found in the expanded table's ligand and receptor columns.
expanded = expand_paired_interactions(df_ligrec)

ligand_genes = set(expanded.ligand)
receptor_genes = set(expanded.receptor)
genes = list(ligand_genes | receptor_genes)
len(genes)

198

In [50]:
# Re-name every pair as "<ligand>@<receptor>" and count the distinct names.
# NOTE: this overwrites the '-' joined names. `interactions` and the adata.uns
# results are keyed with '-', so looking them up through df_ligrec['name']
# after this cell will not find them.
at_joined_names = df_ligrec['ligand'] + '@' + df_ligrec['receptor']
df_ligrec['name'] = at_joined_names
len(at_joined_names.unique())

238

In [59]:
# Split each "<ligand>@<receptor>" name back into its two halves:
# x holds the ligand sides (tuple), y the receptor sides (list).
split_names = (pair_name.split('@') for pair_name in df_ligrec['name'])
x, y = zip(*split_names)
y = list(y)

In [None]:
# Map every ligand unit to the list of names obtained by splitting it on '_'
# (multi-subunit units are written as names joined by '_').
units2genes = {lig: lig.split('_') for lig in x}
units = units2genes.keys()
# NOTE(review): this cell used to run `cell_thresholds = df`, but `df` is only
# created further down the notebook, so a fresh Restart & Run All stopped here
# with a NameError. Nothing reads `cell_thresholds`, so the assignment is dropped.
counts_df = adata.to_df(layer='imputed_count')


In [None]:
# For every ligand unit take, per cell, the minimum imputed count across the
# unit's member genes, then check the stacked shape (units x cells).
gene_values = []
for unit in units:
    member_genes = units2genes[unit]
    gene_values.append(counts_df[member_genes].min(axis=1))
np.array(gene_values).shape

(112, 5778)

In [88]:
# Compare the AnnData dimensions with the number of ligand units.
adata.shape, len(units)

((5778, 3549), 112)

In [None]:
from collections import defaultdict

# Two-level mapping whose innermost values default to empty lists.
lr_units = defaultdict(lambda: defaultdict(list))

# NOTE(review): this loop looks unfinished. Its body only evaluates
# lr_units[name], which (lr_units being a defaultdict) inserts an empty inner
# mapping for each name; `lig` and `rec` are never used.
for lig, rec, name in zip(expanded['ligand'], expanded['receptor'], expanded['name']):
    lr_units[name]

In [58]:
# Create the cell x gene matrix: every cell gets the significant COMMOT
# cluster-level totals of its own cell type, accumulated per ligand/receptor
# gene over all expanded LR pairs.

# One boolean cell mask per cell type. The comprehension variable is named
# `cell_type`, not `ct`, because `ct` is the alias of the commot module.
ct_masks = {
    cell_type: adata.obs['cell_type'] == cell_type
    for cell_type in adata.obs['cell_type'].unique()
}

# Start from an all-zero float frame so the float accumulation below never has
# to upcast an object/int column (pandas deprecates that implicit upcast).
df = pd.DataFrame(0.0, index=adata.obs_names, columns=genes)

for name, lig, rec in tqdm(zip(expanded.name, expanded.ligand, expanded.receptor), total=len(expanded)):

    # `interactions` only holds pairs with some significant signal; a pair
    # that was dropped there contributes nothing.
    if name not in interactions:
        continue
    sig_matrix = interactions[name]

    # Column sums (axis=0) are accumulated on the ligand gene.
    column_totals = sig_matrix.sum(axis=0)
    for cell_type in sig_matrix.index:
        df.loc[ct_masks[cell_type], lig] += column_totals[cell_type]

    # Row sums (axis=1) are accumulated on the receptor gene.
    row_totals = sig_matrix.sum(axis=1)
    for cell_type in sig_matrix.columns:
        df.loc[ct_masks[cell_type], rec] += row_totals[cell_type]

df

100%|██████████| 320/320 [00:07<00:00, 43.57it/s]


Unnamed: 0_level_0,EPHA1,PTPRM,CD86,TNFSF11,CD80,CCL4,IL18,CNTN1,JAM3,CD8A,...,EDA,CD69,ITGA5,CXCR3,EFNA3,ITGB2,ICOSLG,SEMA6A,LRP5,SPN
NAME,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAACCCAAGCGCCTTG-1,0.000000,0.001721,0.000932,0.000000,0.000425,0.000040,0.000258,0.000013,0.000016,0.000026,...,0.000072,0.000225,0.000074,0.000003,0.000002,0.000896,0.000225,0.000000,0.000003,0.000000
AAACCCAAGTGGACGT-1,0.000000,0.000192,0.000105,0.000000,0.000135,0.000000,0.000000,0.000000,0.000000,0.000014,...,0.000000,0.000095,0.000004,0.000014,0.000000,0.000781,0.000080,0.000000,0.000009,0.000101
AAACCCACAGAAGTGC-1,0.000001,0.000000,0.000000,0.000000,0.000011,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000074,0.000005,0.000000,0.000002,0.000339,0.000000,0.000000,0.000002,0.000083
AAACCCAGTCATTGCA-1,0.000206,0.004682,0.000179,0.000128,0.000287,0.000008,0.000000,0.000269,0.000651,0.000013,...,0.000028,0.000101,0.000083,0.000114,0.000022,0.001142,0.000015,0.000320,0.000007,0.000244
AAACCCATCATCGCAA-1,0.000000,0.000000,0.000000,0.000000,0.000014,0.000010,0.000000,0.000000,0.000560,0.000003,...,0.000000,0.000000,0.000008,0.000003,0.000010,0.000116,0.000000,0.000000,0.000002,0.000063
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTTGCAGGGACTA-1,0.000001,0.000000,0.000000,0.000000,0.000011,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000074,0.000005,0.000000,0.000002,0.000339,0.000000,0.000000,0.000002,0.000083
TTTGTTGCATTGTAGC-1,0.000000,0.000192,0.000105,0.000000,0.000135,0.000000,0.000000,0.000000,0.000000,0.000014,...,0.000000,0.000095,0.000004,0.000014,0.000000,0.000781,0.000080,0.000000,0.000009,0.000101
TTTGTTGGTACCACGC-1,0.000001,0.000000,0.000000,0.000000,0.000011,0.000000,0.000000,0.000000,0.000000,0.000000,...,0.000000,0.000074,0.000005,0.000000,0.000002,0.000339,0.000000,0.000000,0.000002,0.000083
TTTGTTGGTCTGTCCT-1,0.000000,0.000260,0.000709,0.000000,0.000307,0.000000,0.000060,0.000088,0.000000,0.000040,...,0.000265,0.000030,0.000046,0.000000,0.000001,0.000688,0.000091,0.000035,0.000001,0.000000


In [60]:
# Persist the cell x gene score matrix held in `df` as parquet.
df.to_parquet('/ix/djishnu/shared/djishnu_kor11/miscellaneous/tonsil_commot_LRs.parquet')

### Get true LR pairs

In [26]:
# Distinct values of the (unexpanded) ligand and receptor columns; these
# become the columns of the matrix built in the next cell.
ligand_units = set(df_ligrec.ligand)
receptor_units = set(df_ligrec.receptor)
genes = list(ligand_units | receptor_units)
len(genes)

199

In [27]:
# Create the cell x LR unit matrix: every cell gets the significant COMMOT
# cluster-level totals of its own cell type, accumulated per ligand/receptor
# unit over all LR pairs of df_ligrec.

# One boolean cell mask per cell type. The comprehension variable is named
# `cell_type`, not `ct`, because `ct` is the alias of the commot module.
ct_masks = {
    cell_type: adata.obs['cell_type'] == cell_type
    for cell_type in adata.obs['cell_type'].unique()
}

# Start from an all-zero float frame so the float accumulation below never has
# to upcast an object/int column (pandas deprecates that implicit upcast).
df = pd.DataFrame(0.0, index=adata.obs_names, columns=genes)

for lig, rec in tqdm(zip(df_ligrec.ligand, df_ligrec.receptor), total=len(df_ligrec)):

    # `interactions` is keyed by '<ligand>-<receptor>'. Rebuild that key here
    # rather than reading df_ligrec.name, which an earlier cell rewrites with
    # an '@' separator (a top-to-bottom run would raise KeyError otherwise).
    name = lig + '-' + rec

    # `interactions` only holds pairs with some significant signal; a pair
    # that was dropped there contributes nothing.
    if name not in interactions:
        continue
    sig_matrix = interactions[name]

    # Column sums (axis=0) are accumulated on the ligand unit.
    column_totals = sig_matrix.sum(axis=0)
    for cell_type in sig_matrix.index:
        df.loc[ct_masks[cell_type], lig] += column_totals[cell_type]

    # Row sums (axis=1) are accumulated on the receptor unit.
    row_totals = sig_matrix.sum(axis=1)
    for cell_type in sig_matrix.columns:
        df.loc[ct_masks[cell_type], rec] += row_totals[cell_type]

df

100%|██████████| 238/238 [00:02<00:00, 98.67it/s] 


Unnamed: 0_level_0,IL4R_IL13RA1,TNFSF8,CCL2,GP6,TNFRSF8,TNFRSF13B,IGSF11,IL6R_IL6ST,SEMA3A,TIGIT,...,WNT3,TNC,IL6,NRG3,NRG2,NPR2,PECAM1,COL4A1,PLXNA4,ITGA10_ITGB1
NAME,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAACCCAAGCGCCTTG-1,0.000004,0.000000,0.000004,0.000043,0.000071,0.000198,0.000037,0.000000,0.000016,0.000137,...,0.000000,0.000000,0.000008,0.000000,0.000000,0.000000,0.001013,0.000037,0.000053,0.000006
AAACCCAAGTGGACGT-1,0.000052,0.000000,0.000000,0.000109,0.000023,0.000000,0.000035,0.000000,0.000038,0.000113,...,0.000000,0.000076,0.000008,0.000000,0.000011,0.000051,0.001059,0.000041,0.000013,0.000026
AAACCCACAGAAGTGC-1,0.000021,0.000026,0.000000,0.000035,0.000022,0.000000,0.000000,0.000006,0.000000,0.000000,...,0.000000,0.000004,0.000000,0.000000,0.000000,0.000000,0.000000,0.000020,0.000038,0.000010
AAACCCAGTCATTGCA-1,0.000010,0.000069,0.000011,0.000802,0.000000,0.000511,0.000040,0.000000,0.000262,0.001175,...,0.000004,0.000216,0.000010,0.000208,0.000397,0.000123,0.001635,0.000126,0.000114,0.000253
AAACCCATCATCGCAA-1,0.000000,0.000032,0.000008,0.000298,0.000000,0.000075,0.000000,0.000025,0.000000,0.000000,...,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.000129,0.000022,0.000041,0.000045
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTTGCAGGGACTA-1,0.000021,0.000026,0.000000,0.000035,0.000022,0.000000,0.000000,0.000006,0.000000,0.000000,...,0.000000,0.000004,0.000000,0.000000,0.000000,0.000000,0.000000,0.000020,0.000038,0.000010
TTTGTTGCATTGTAGC-1,0.000052,0.000000,0.000000,0.000109,0.000023,0.000000,0.000035,0.000000,0.000038,0.000113,...,0.000000,0.000076,0.000008,0.000000,0.000011,0.000051,0.001059,0.000041,0.000013,0.000026
TTTGTTGGTACCACGC-1,0.000021,0.000026,0.000000,0.000035,0.000022,0.000000,0.000000,0.000006,0.000000,0.000000,...,0.000000,0.000004,0.000000,0.000000,0.000000,0.000000,0.000000,0.000020,0.000038,0.000010
TTTGTTGGTCTGTCCT-1,0.000000,0.000000,0.000000,0.000014,0.000027,0.000041,0.000021,0.000000,0.000000,0.000066,...,0.000000,0.000011,0.000000,0.000000,0.000000,0.000000,0.000966,0.000023,0.000020,0.000005


In [28]:
# Persist the cell x LR-unit score matrix held in `df` as parquet.
df.to_parquet('/ix/djishnu/shared/djishnu_kor11/miscellaneous/tonsil_commot_LRs_units.parquet')

### Scratch

In [None]:
# # def count_interactions(matrix):
# #     mask = matrix.astype(bool).toarray()
# #     mask = np.maximum(mask, mask.T)
# #     mask = np.triu(mask, k=1)
# #     return mask.sum()

# def count_interactions(matrix):
#     return matrix.sum()

# num_values = {k: count_interactions(lr_info[k]) for k in lr_info.keys()}

In [None]:
# import matplotlib.pyplot as plt

# plt.hist(list(num_values.values()), bins=1000)
# plt.semilogy()

# # threshold = round(adata.n_obs * 0.05)
# threshold = np.percentile(list(num_values.values()), 10)

# plt.text(threshold, plt.ylim()[1] * 0.9, f'Threshold: {threshold}', color='red', ha='center')
# plt.axvline(threshold, color='red', linestyle='dashed', linewidth=1)
# plt.axvspan(0, threshold, color='red', alpha=0.3)
# plt.xlim(0, 20000)
# plt.xlabel('Number of interactions')

# plt.show()

In [None]:
from collections import defaultdict

celltypes = adata.obs['cell_type'].unique()

# Boolean membership mask per cell type, computed once instead of being
# rebuilt inside the nested loops.
celltype_masks = {c: (adata.obs['cell_type'] == c).to_numpy() for c in celltypes}

# interactions[a][b][pair] = number of non-zero entries of the pair's matrix
# restricted to rows of cell type `a` and columns of cell type `b`.
# NOTE: this rebinds `interactions`, replacing the dict of significance-masked
# cluster matrices built earlier in the notebook.
interactions = defaultdict(lambda: defaultdict(dict))

for a in celltypes:
    a_mask = celltype_masks[a]

    for b in celltypes:
        b_mask = celltype_masks[b]

        for k, v in lr_info.items():
            # Skip the aggregate over all pairs. COMMOT stores it under
            # 'commot-user_database-total-total', i.e. 'total-total' once the
            # prefix is stripped; the previous 'total=total' never matched.
            if k == 'total-total':
                continue

            interactions[a][b][k] = np.sum(v[a_mask, :][:, b_mask].astype(bool))

len(interactions)

In [None]:
# Display the distinct cell types.
celltypes

In [None]:
# Sort every LR pair, per (sender, receiver) cell-type combination, into
# `sig_interactions` (p < 0.05) or `discard`.
sig_interactions = defaultdict(lambda: defaultdict(list))
discard = defaultdict(lambda: defaultdict(list))

cell_counts = {k: (adata.obs['cell_type'] == k).sum() for k in celltypes}

for sender in celltypes:
    for receiver in celltypes:

        # Currently unused; kept from the original scratch code.
        tot_cells = cell_counts[sender] + cell_counts[receiver]

        # Restored loop header: the body below was indented under a `for` line
        # that had been deleted, which made the whole cell an IndentationError
        # and left `name` / `v` as stale variables from other cells.
        for name, v in interactions[sender][receiver].items():

            # NOTE(review): `v` is a single count here, so
            # np.random.permutation(v) permutes np.arange(v) and every null sum
            # is the same constant. This is not a valid null distribution and
            # needs a redesign before the result is used for anything.
            observed = np.sum(v)
            null_distribution = [np.sum(np.random.permutation(v)) for _ in range(1000)]
            p_value = np.mean([null >= observed for null in null_distribution])

            if p_value < 0.05:
                sig_interactions[sender][receiver].append(name)
            else:
                discard[sender][receiver].append(name)

In [None]:
# Number of significant vs. discarded pairs for 'T cells' -> 'T cells'.
# Both containers are defaultdicts, so a missing key yields an empty list (length 0).
len(sig_interactions['T cells']['T cells']), len(discard['T cells']['T cells'])

In [None]:
# Show the per-pair counts for whatever `sender` / `receiver` currently hold
# (leftover loop variables unless the cell below has been run).
interactions[sender][receiver].items()

In [None]:
# Pick one sender/receiver combination and collect its per-pair counts,
# split into the significant and the discarded pairs.
sender, receiver = 'GC B', 'Tfh'

sig_vals = [
    interactions[sender][receiver][pair]
    for pair in sig_interactions[sender][receiver]
]
discard_vals = [
    interactions[sender][receiver][pair]
    for pair in discard[sender][receiver]
]


In [None]:
# Count stored for the 'CDH2-CDH2' pair for the selected sender and receiver.
interactions[sender][receiver]['CDH2-CDH2']

In [None]:
# Pairs flagged significant for the selected sender and receiver.
sig_interactions[sender][receiver]

In [None]:
# Histogram of the per-pair counts of the significant pairs.
# `plt` is never imported in this notebook (the only matplotlib import is in a
# commented-out cell), so the plot goes through pandas' plotting accessor, which
# returns the matplotlib Axes needed to set the limits.
ax = pd.Series(sig_vals, dtype=float).plot.hist(bins=1000, alpha=0.5, label='Significant', color='blue')
# _ = pd.Series(discard_vals, dtype=float).plot.hist(ax=ax, bins=1000, alpha=0.5, label='Discarded', color='red')
ax.set_ylim(0, 20)
ax.set_xlim(0, 5000)

In [None]:


# Pairs flagged significant with 'GC B' as sender key and 'Tfh' as receiver key.
sig_interactions['GC B']['Tfh']