In [1]:
import commot as ct
import scanpy as sc
import pandas as pd
import numpy as np

### Run Commot

In [2]:
# Load the AnnData object that every later cell reads from and writes into.
# NOTE(review): this is an absolute, user-specific path. The outputs saved further
# down in this notebook show 5778 cells and write files named "tonsil_*", which does
# not match the 8509-cell object displayed below -- confirm which dataset this
# notebook is meant to process before re-running.
adata = sc.read_h5ad('/Users/koush/Projects/SpaceOracle/data/survey/mouse_kidney_13.h5ad')
adata

AnnData object with n_obs × n_vars = 8509 × 3058
    obs: 'ct1', 'ct2', 'ct3', 'ct4', 'cond', 'medulla_cortex', 'domain', 'cell_type', 'cell_type_int'
    obsm: 'spatial', 'spatial_unscaled'
    layers: 'imputed_count', 'normalized_count', 'raw_count'
    obsp: 'connectivities', 'distances'

In [3]:
# adata.X = adata.layers['imputed_count']
# Use the 'normalized_count' layer as adata.X for the steps below
# (the 'imputed_count' alternative is kept above, commented out).
adata.X = adata.layers['normalized_count']


In [4]:
# Fetch the mouse CellChat ligand-receptor table from COMMOT (no restriction on
# signaling type), then give its columns the names used throughout this notebook.
df_ligrec = ct.pp.ligand_receptor_database(species='mouse', database='CellChat', signaling_type=None)
df_ligrec.columns = ['ligand', 'receptor', 'pathway', 'signaling']

In [5]:
# Run COMMOT's spatial communication step on `adata` with the user-supplied
# ligand-receptor table; results are stored on `adata` under 'user_database'.
ct.tl.spatial_communication(
    adata,
    database_name='user_database',
    df_ligrec=df_ligrec,
    dis_thr=200,
    heteromeric=True,
)

In [6]:
adata

AnnData object with n_obs × n_vars = 5778 × 3549
    obs: 'cell_type', 'author_cell_type', 'cell_type_int', 'leiden', 'leiden_R', 'cell_type_2'
    uns: 'author_cell_type_colors', 'cell_type_2_colors', 'cell_type_colors', 'dendrogram_leiden', 'leiden', 'leiden_R', 'leiden_colors', 'neighbors', 'pca', 'umap', 'commot-user_database-info'
    obsm: 'X_pca', 'X_umap', 'ora_estimate', 'ora_pvals', 'spatial', 'spatial_unscaled', 'commot-user_database-sum-sender', 'commot-user_database-sum-receiver'
    varm: 'PCs'
    layers: 'imputed_count', 'normalized_count'
    obsp: 'connectivities', 'distances', 'commot-user_database-EFNA5-EPHA3', 'commot-user_database-EFNA5-EPHA2', 'commot-user_database-EFNA5-EPHA1', 'commot-user_database-EFNA5-EPHA4', 'commot-user_database-SEMA6A-PLXNA4', 'commot-user_database-CD96-NECTIN1', 'commot-user_database-CD96-PVR', 'commot-user_database-CXCL10-CXCR3', 'commot-user_database-LCK-CD8A_CD8B', 'commot-user_database-NRXN3-NLGN3', 'commot-user_database-NRXN3-NLGN1'

In [7]:
# adata.write_h5ad('commot.h5ad')
# adata = sc.read_h5ad('commot.h5ad')

In [None]:
# Collect the matrices COMMOT stored in adata.obsp, keyed by whatever follows the
# database prefix (e.g. 'EFNA5-EPHA3').
# A prefix test plus prefix strip is used instead of a substring test plus
# str.replace, so a key that merely contains the prefix somewhere in the middle is
# neither picked up nor mangled.
COMMOT_PREFIX = 'commot-user_database-'
lr_info = {
    key[len(COMMOT_PREFIX):]: matrix
    for key, matrix in adata.obsp.items()
    if key.startswith(COMMOT_PREFIX)
}
len(lr_info)

In [21]:
# Label every database row as "<ligand>-<receptor>" and count the distinct labels.
df_ligrec['name'] = df_ligrec['ligand'].str.cat(df_ligrec['receptor'], sep='-')
len(df_ligrec['name'].unique())

1938

In [10]:
# Keep only the ligand-receptor pairs for which a COMMOT matrix exists in lr_info.
# .copy() makes the result an independent frame, so the later
# `df_ligrec['name'] = ...` assignment does not write into a slice of the full
# table (SettingWithCopyWarning).
# NOTE(review): this rebinds df_ligrec, so the unfiltered table can only be
# recovered by re-running the cell that loads the database.
df_ligrec = df_ligrec[df_ligrec['name'].isin(list(lr_info))].copy()
df_ligrec['signaling'].value_counts()

signaling
ECM-Receptor          89
Cell-Cell Contact     80
Secreted Signaling    69
Name: count, dtype: int64

### Get cluster communication scores

In [11]:
from tqdm import tqdm

# Run COMMOT's cluster-level communication analysis once per distinct pair label,
# grouping cells by the 'cell_type' column, with a fixed seed and 100 permutations.
pair_names = df_ligrec['name'].unique()
for name in tqdm(pair_names):
    ct.tl.cluster_communication(
        adata,
        database_name='user_database',
        pathway_name=name,
        clustering='cell_type',
        n_permutations=100,
        random_seed=12,
    )


100%|██████████| 238/238 [21:02<00:00,  5.30s/it]


In [12]:
# adata.write_h5ad('commot_cluster.h5ad')

In [13]:
from collections import defaultdict
import pickle

# For every pair label, copy the cluster-level communication matrix and its
# p-value matrix out of adata.uns into a nested dict keyed by the label.
data_dict = defaultdict(dict)
for name in df_ligrec['name']:
    cluster_result = adata.uns[f'commot_cluster-cell_type-user_database-{name}']
    for field in ('communication_matrix', 'communication_pvalue'):
        data_dict[name][field] = cluster_result[field]

# Persist the collected matrices so later cells can reload them.
with open('/ix/djishnu/shared/djishnu_kor11/commot_outputs/tonsil_communication.pkl', 'wb') as f:
    pickle.dump(data_dict, f)

In [16]:
# check outputs
# Reload the pickle written by the previous cell into `info` and count its entries.
# NOTE: pickle.load can execute arbitrary code from the file; only load files that
# this notebook (or another trusted source) produced.

import pickle
with open('/ix/djishnu/shared/djishnu_kor11/commot_outputs/tonsil_communication.pkl', 'rb') as f:
    info = pickle.load(f)

len(info.keys())

238

In [36]:
# Drop the 'TGFB1-TGFBR1_TGFBR2' entry before the significance filtering below.
# Guarded so that re-running this cell (or running it when the key is already
# absent) does not raise KeyError, which the bare `del` did.
if 'TGFB1-TGFBR1_TGFBR2' in info:
    del info['TGFB1-TGFBR1_TGFBR2']

In [37]:
def get_sig_interactions(value_matrix, p_matrix, pval=0.05):
    """Zero out the entries of ``value_matrix`` whose p-value is not below ``pval``.

    Args:
        value_matrix: array-like or DataFrame of communication values.
        p_matrix: array-like or DataFrame of p-values, same shape as ``value_matrix``.
        pval: significance cut-off; an entry is kept only when its p-value is
            strictly smaller than this.

    Returns:
        ``value_matrix`` multiplied element-wise by a 0/1 indicator array.
    """
    # The indicator is a plain ndarray, so the product is positional (no label alignment).
    indicator = np.asarray(p_matrix < pval).astype(int)
    return value_matrix * indicator

# For every ligand-receptor row that has an entry in `info`, mask its
# communication matrix by significance and keep it only if anything positive remains.
interactions = {}
for lig, rec in tqdm(zip(df_ligrec['ligand'], df_ligrec['receptor'])):
    name = lig + '-' + rec

    if name not in info:
        continue

    value_matrix = info[name]['communication_matrix']
    p_matrix = info[name]['communication_pvalue']
    sig_matrix = get_sig_interactions(value_matrix, p_matrix)

    # Grand total over rows and columns; pairs with no surviving signal are dropped.
    if sig_matrix.sum().sum() > 0:
        interactions[name] = sig_matrix

len(interactions)

1939it [00:00, 30709.11it/s]


238

### Get expanded LR masks

In [38]:
import sys
sys.path.append('../../src')

from spaceoracle.tools.network import expand_paired_interactions

In [39]:
# Run the project helper expand_paired_interactions on the ligand-receptor table.
expanded = expand_paired_interactions(df_ligrec)

# Every gene that appears as a ligand or a receptor in the expanded table.
# sorted() instead of list(set(...)): the iteration order of a set of strings
# changes between interpreter runs, which made the column order of the matrix
# built from `genes` (and of the parquet file written from it) vary run to run.
genes = sorted(set(expanded.ligand) | set(expanded.receptor))
len(genes)

958

In [40]:
# Overwrite the 'name' column with "<ligand>@<receptor>" labels and count the
# distinct labels.
df_ligrec['name'] = df_ligrec['ligand'].str.cat(df_ligrec['receptor'], sep='@')
len(df_ligrec['name'].unique())

1938

In [41]:
# Split each "<left>@<right>" label into its two halves:
# x is a tuple of the left halves, y a list of the right halves.
left_halves, right_halves = zip(*(label.split('@') for label in df_ligrec['name']))
x = left_halves
y = list(right_halves)

In [42]:
# units2genes = {lig: lig.split('_') for lig in x}
# units = units2genes.keys()
# cell_thresholds = df
# counts_df = adata.to_df(layer='imputed_count')


In [43]:
from collections import defaultdict

# Two-level mapping: name -> key -> list.
lr_units = defaultdict(lambda: defaultdict(list))

# Merely accessing lr_units[name] makes the defaultdict create an empty inner
# mapping for that name; `lig` and `rec` are not used and nothing is appended here.
# NOTE(review): this looks like an unfinished cell -- confirm whether it is still needed.
for lig, rec, name in zip(expanded['ligand'], expanded['receptor'], expanded['name']):
    lr_units[name]

In [55]:
# create cell x gene matrix
# One boolean mask per cell type, selecting the cells that carry that label.
ct_masks = {label: adata.obs['cell_type'] == label for label in adata.obs['cell_type'].unique()}

# Start from an all-zero float matrix (cells x genes). Building an empty
# object-dtype frame and calling fillna(0) on it relies on pandas' deprecated
# silent downcasting and then needs a dtype change as soon as a float is added.
df = pd.DataFrame(0.0, index=adata.obs_names, columns=genes)

for name, lig, rec in tqdm(zip(expanded.name, expanded.ligand, expanded.receptor), total=len(expanded)):

    interaction_df = interactions.get(name)
    if interaction_df is None:
        continue

    # The loop variables below are deliberately not called `ct`: a for-loop at cell
    # level rebinds the name globally, which would replace the `commot` module
    # imported as `ct` at the top of the notebook with a cell-type label.

    # Row sums of the significance-masked matrix are added to the ligand column
    # for all cells of the row's cell type.
    for row_label, row_total in interaction_df.sum(axis=1).items():
        df.loc[ct_masks[row_label], lig] += row_total

    # Column sums are added to the receptor column for all cells of the column's
    # cell type.
    for col_label, col_total in interaction_df.sum(axis=0).items():
        df.loc[ct_masks[col_label], rec] += col_total

df.shape

100%|██████████| 2914/2914 [00:05<00:00, 527.30it/s] 


(5778, 958)

In [56]:
print('Number of LR filtered using celltype specificity:')
# Fraction of entries in df that are strictly positive.
(df > 0).sum().sum() / (df.shape[0] * df.shape[1])

Number of LR filtered using celltype specificity:


0.10439858624355142

In [57]:
# df.to_parquet('/ix/djishnu/shared/djishnu_kor11/miscellaneous/tonsil_commot_LRs.parquet')
# Write the cell x gene matrix built above to parquet.
# NOTE(review): absolute cluster path, and the file name says "tonsil" while the
# load cell at the top reads a mouse-kidney file -- confirm the intended target.
df.to_parquet('/ix/djishnu/shared/djishnu_kor11/commot_outputs/tonsil_LRs.parquet')


### Get true LR pairs

In [58]:
# Every distinct ligand / receptor entry of the (unexpanded) database table.
# sorted() instead of list(set(...)): the iteration order of a set of strings
# changes between interpreter runs, which made the column order of the matrix
# built from `genes` (and of the parquet file written from it) vary run to run.
genes = sorted(set(df_ligrec.ligand) | set(df_ligrec.receptor))
len(genes)

1002

In [62]:
# create cell x LR unit matrix
# One boolean mask per cell type, selecting the cells that carry that label.
ct_masks = {label: adata.obs['cell_type'] == label for label in adata.obs['cell_type'].unique()}

# Start from an all-zero float matrix (cells x ligand/receptor entries). Building
# an empty object-dtype frame and calling fillna(0) on it relies on pandas'
# deprecated silent downcasting and then needs a dtype change as soon as a float
# is added.
df = pd.DataFrame(0.0, index=adata.obs_names, columns=genes)

for lig, rec in tqdm(zip(df_ligrec.ligand, df_ligrec.receptor), total=len(df_ligrec)):
    name = lig + '-' + rec
    interaction_df = interactions.get(name)
    if interaction_df is None:
        continue

    # The loop variables below are deliberately not called `ct`: a for-loop at cell
    # level rebinds the name globally, which would replace the `commot` module
    # imported as `ct` at the top of the notebook with a cell-type label.

    # Row sums of the significance-masked matrix are added to the ligand column
    # for all cells of the row's cell type.
    for row_label, row_total in interaction_df.sum(axis=1).items():
        df.loc[ct_masks[row_label], lig] += row_total

    # Column sums are added to the receptor column for all cells of the column's
    # cell type.
    for col_label, col_total in interaction_df.sum(axis=0).items():
        df.loc[ct_masks[col_label], rec] += col_total

df.shape

100%|██████████| 1939/1939 [00:04<00:00, 452.96it/s]


(5778, 1002)

In [63]:
print('Number of LR filtered using celltype specificity:')
# Fraction of entries in df that are strictly positive.
(df > 0).sum().sum() / (df.shape[0] * df.shape[1])

Number of LR filtered using celltype specificity:


0.09818904938478874

In [61]:
# df.to_parquet('/ix/djishnu/shared/djishnu_kor11/miscellaneous/tonsil_commot_LRs_units.parquet')
# Write the cell x LR-unit matrix to parquet.
# NOTE(review): this cell's execution count (61) is lower than that of the cell
# that builds `df` (62), so the file on disk may predate the matrix shown above --
# re-run in order before relying on it.
df.to_parquet('/ix/djishnu/shared/djishnu_kor11/commot_outputs/tonsil_LRs_units.parquet')


### Scratch

In [None]:
# # def count_interactions(matrix):
# #     mask = matrix.astype(bool).toarray()
# #     mask = np.maximum(mask, mask.T)
# #     mask = np.triu(mask, k=1)
# #     return mask.sum()

# def count_interactions(matrix):
#     return matrix.sum()

# num_values = {k: count_interactions(lr_info[k]) for k in lr_info.keys()}

In [None]:
# import matplotlib.pyplot as plt

# plt.hist(list(num_values.values()), bins=1000)
# plt.semilogy()

# # threshold = round(adata.n_obs * 0.05)
# threshold = np.percentile(list(num_values.values()), 10)

# plt.text(threshold, plt.ylim()[1] * 0.9, f'Threshold: {threshold}', color='red', ha='center')
# plt.axvline(threshold, color='red', linestyle='dashed', linewidth=1)
# plt.axvspan(0, threshold, color='red', alpha=0.3)
# plt.xlim(0, 20000)
# plt.xlabel('Number of interactions')

# plt.show()

In [None]:
from collections import defaultdict

# NOTE(review): this rebinds `interactions`, replacing the dict of
# significance-masked matrices built earlier in the notebook with a nested
# count table interactions[a][b][pair].
celltypes = adata.obs['cell_type'].unique()
interactions = defaultdict(lambda: defaultdict(dict))

# Build each cell-type mask once, as a plain boolean array, instead of
# recomputing it for every (a, b) combination.
type_masks = {label: (adata.obs['cell_type'] == label).to_numpy() for label in celltypes}

for a in celltypes:
    a_mask = type_masks[a]

    for b in celltypes:
        b_mask = type_masks[b]

        for k, v in lr_info.items():
            # lr_info keys are what follows the 'commot-user_database-' prefix, so
            # they are '-'-joined. The original test compared against
            # 'total=total', which can never match such a key; presumably the
            # aggregate entry is named 'total-total' -- confirm against adata.obsp.
            if k == 'total-total':
                continue

            # Number of nonzero entries in the block of v whose rows are cells of
            # type a and whose columns are cells of type b.
            interactions[a][b][k] = np.sum(v[a_mask, :][:, b_mask].astype(bool))

len(interactions)

In [None]:
celltypes

In [None]:
# Sort every (sender, receiver, pair name) entry of `interactions` into
# `sig_interactions` or `discard` according to a permutation p-value.
sig_interactions = defaultdict(lambda: defaultdict(list))
discard = defaultdict(lambda: defaultdict(list))

cell_counts = {k: (adata.obs['cell_type'] == k).sum() for k in celltypes}

for sender in celltypes:
    for receiver in celltypes:

        # The original cell had no loop header here: the statements below were
        # indented one level deeper with nothing to belong to, and `name` / `v`
        # were never bound, so the cell could not even be parsed. Iterating the
        # per-pair entries of `interactions` is the reconstruction that binds both.
        for name, v in interactions[sender][receiver].items():

            observed = np.sum(v)
            # NOTE(review): as built by the cell above, v is a single count.
            # np.random.permutation(n) on an integer permutes arange(n), whose sum
            # is the same for every draw, so p_value is always exactly 0 or 1.
            # Confirm the intended null model before trusting this split.
            null_distribution = [np.sum(np.random.permutation(v)) for _ in range(1000)]
            p_value = np.mean([null >= observed for null in null_distribution])

            if p_value < 0.05:
                sig_interactions[sender][receiver].append(name)
            else:
                discard[sender][receiver].append(name)

In [None]:
len(sig_interactions['T cells']['T cells']), len(discard['T cells']['T cells'])

In [None]:
interactions[sender][receiver].items()

In [None]:
# Pick one sender / receiver cell-type pair to inspect.
sender = 'GC B'
receiver = 'Tfh'

# Look up the stored value in `interactions` for every pair name that was sorted
# into the significant list and into the discarded list for this sender/receiver.
sig_vals = [interactions[sender][receiver][s] for s in sig_interactions[sender][receiver]]
discard_vals = [interactions[sender][receiver][s] for s in discard[sender][receiver]]


In [None]:
interactions[sender][receiver]['CDH2-CDH2']

In [None]:
sig_interactions[sender][receiver]

In [None]:
# Histogram of sig_vals, with both axes clipped to a fixed window.
# NOTE(review): `plt` is only imported in a commented-out scratch cell, so on a
# fresh kernel this cell raises NameError until matplotlib.pyplot is imported.
plt.hist(sig_vals, bins=1000, alpha=0.5, label='Significant', color='blue')
# _ = plt.hist(discard_vals, bins=1000, alpha=0.5, label='Discarded', color='red')
plt.ylim(0, 20)
plt.xlim(0, 5000)

In [None]:


sig_interactions['GC B']['Tfh']