## Applies COMMOT to HVGs, SVGs, and unfiltered expression data set for human brain ST 

- matches cell barcodes in seurat and HVG/SVG matrices
- creates anndata object for COMMOT input
- runs COMMOT on different LR filtered SVG/HVG/unfiltered genes

author: @emilyekstrum
1/20/26

In [12]:
import pandas as pd
import numpy as np
import scanpy as sc
import commot as ct
import os
import pyreadr
import rpy2
from rpy2.robjects.packages import importr
import anndata as ad
import copy
from collections import defaultdict
from scipy import sparse
import math
import matplotlib.pyplot as plt

np.random.seed(42)

### Prepare data for COMMOT input
- align by barcode 
- get non negative expression matrix

In [13]:
# load in human data HVGs and SVGs - using unLR filtered data
# sorted by p value in ascending order
hvg_gene_cell_mat = os.path.join("/Users/emilyekstrum/repos/zhangLab_Rotation/data/processed/hvgs", "human_hvg_gene_cell_matrix.csv")
svg_gene_cell_mat = os.path.join("/Users/emilyekstrum/repos/zhangLab_Rotation/data/processed/svgs", "nnSVG_human_svg_gene_cell_matrix.csv")

In [14]:
# check first lines of HVG and SVG gene cell matrices
hvg_df = pd.read_csv(hvg_gene_cell_mat, index_col=0)
svg_df = pd.read_csv(svg_gene_cell_mat, index_col=0)

hvg_df.head() # gene by spot matrix
#svg_df.head()

Unnamed: 0,AAACAAGTATCTCCCA.1,AAACAATCTACTAGCA.1,AAACACCAATAACTGC.1,AAACAGCTTTCAGAAG.1,AAACAGGGTCTATATT.1,AAACAGTGTTCCTGGG.1,AAACATTTCCCGGATT.1,AAACCCGAACGAAATC.1,AAACCGGGTAGGTACC.1,AAACCGTTCGTCCAGG.1,...,TTGTGGTGGTACTAAG.1,TTGTGTATGCCACCAA.1,TTGTGTTTCCCGAAAG.1,TTGTTAGCAAATTCGA.1,TTGTTCAGTGTGCTAC.1,TTGTTGTGTGTCAAGA.1,TTGTTTCACATCCAGG.1,TTGTTTCATTAGTCTA.1,TTGTTTGTATTACACG.1,TTGTTTGTGTAAATTC.1
MBP,1.609438,2.197225,4.70953,2.639057,3.401197,3.73767,1.386294,0.0,1.791759,4.477337,...,1.609438,3.178054,2.890372,2.484907,1.791759,2.995732,3.89182,4.110874,3.850148,0.0
PLP1,1.386294,0.693147,3.871201,1.609438,3.218876,3.7612,1.386294,1.098612,2.197225,3.988984,...,0.693147,1.94591,1.791759,1.098612,1.609438,1.098612,4.59512,3.610918,2.890372,1.098612
IGKC,1.098612,0.0,0.693147,2.197225,0.693147,0.0,2.079442,0.0,1.098612,0.0,...,1.098612,0.0,0.0,0.693147,0.693147,1.386294,0.0,1.791759,0.0,0.0
NPY,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.693147,...,0.693147,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
GFAP,0.0,2.302585,2.995732,0.0,1.386294,2.70805,0.0,0.693147,1.386294,3.044522,...,1.386294,0.693147,2.079442,0.0,0.0,0.0,1.94591,3.583519,2.564949,0.693147


In [15]:
# convert matrices to gene lists
SVG = set(svg_df.index.tolist())
HVG = set(hvg_df.index.tolist())

In [16]:
# look at lists
print(f"First 10 HVGs: {list(HVG)[:10]}")
print(f"First 10 SVGs: {list(SVG)[:10]}")

First 10 HVGs: ['ESAM', 'NPTX1', 'RGS4', 'NUAK1', 'SNX29', 'LAGE3', 'TAF8', 'ZDHHC17', 'TNFAIP1', 'AC018647.1']
First 10 SVGs: ['NPTX1', 'RGS4', 'NUAK1', 'LAGE3', 'ZDHHC17', 'PTK2', 'CHD2', 'STIM1', 'EEF1A2', 'RPL38']


In [17]:
# get non negative expression and coordiantes from human seurat

import rpy2.robjects as ro
from rpy2.robjects import pandas2ri

# load in seurat object
ro.r('library(Seurat)')
ro.r('obj <- readRDS("/Users/emilyekstrum/repos/zhangLab_Rotation/data/processed/seurat_objs/humanbrain_seurat.rds")')

In [18]:
# look for images slot
print(ro.r('head(colnames(obj@meta.data), 50)'))  # show first 50 meta columns

 [1] "orig.ident"        "nCount_RNA"        "nFeature_RNA"     
 [4] "...1"              "imagerow"          "imagecol"         
 [7] "Manual.annotation" "percent.mt"        "nCount_SCT"       
[10] "nFeature_SCT"     



In [19]:
# get seurat barcodes
seurat_barcodes = list(ro.r('rownames(obj@meta.data)'))
seurat_barcodes = [str(x) for x in seurat_barcodes]

print("Seurat n spots:", len(seurat_barcodes))
print("CSV n spots:", svg_df.shape[1])

Seurat n spots: 2897
CSV n spots: 2897


In [20]:
# map csv columns to seurat barcodes by order
assert svg_df.shape[1] == len(seurat_barcodes), "Counts differ —> can't map by order"

# rename csv columns to seurat barcodes
svg_df.columns = seurat_barcodes
hvg_df.columns = seurat_barcodes

In [21]:
# map csv columns to seurat barcodes by order
assert svg_df.shape[1] == len(seurat_barcodes), "Counts differ —> can't map by order"

# rename csv columns to seurat barcodes
svg_df.columns = seurat_barcodes
hvg_df.columns = seurat_barcodes

In [30]:
# columns are available in the metadata
available_cols = list(ro.r('colnames(obj@meta.data)'))
print("Available metadata columns:")
for col in available_cols:
    print(f"  {col}")

# check for imagecol and imagerow or x and y
desired_cols = ['cell_id', 'imagecol', 'imagerow', 'x', 'y']
existing_cols = [col for col in desired_cols if col in available_cols]
print(f"\nCoordinate cols: {existing_cols}")

Available metadata columns:
  orig.ident
  nCount_RNA
  nFeature_RNA
  ...1
  imagerow
  imagecol
  Manual.annotation
  percent.mt
  nCount_SCT
  nFeature_SCT

Coordinate cols: ['imagecol', 'imagerow']


In [31]:
# get meta.data with rownames kept as a column
cols_to_select = ['"barcode"']  # include barcode
for col in existing_cols:
    cols_to_select.append(f'"{col}"')

col_selection = ', '.join(cols_to_select)
print(f"Selecting columns: {col_selection}")

r_command = f'''
md <- obj@meta.data
md$barcode <- rownames(md)
md[, c({col_selection})]
'''

md = ro.r(r_command)
md = pandas2ri.rpy2py(md)

# check types 
md["barcode"] = md["barcode"].astype(str)

# process cell_id if it exists
if 'cell_id' in md.columns:
    md["cell_id"] = md["cell_id"].astype(str)
    md = md.dropna(subset=["cell_id"])

print(f"\nMetadata shape: {md.shape}")
print(f"Metadata columns: {list(md.columns)}")
print("\nFirst few rows:")
print(md.head())

Selecting columns: "barcode", "imagecol", "imagerow"

Metadata shape: (2897, 3)
Metadata columns: ['barcode', 'imagecol', 'imagerow']

First few rows:
                               barcode    imagecol    imagerow
AAACAAGTATCTCCCA-1  AAACAAGTATCTCCCA-1  440.639079  381.098123
AAACAATCTACTAGCA-1  AAACAATCTACTAGCA-1  259.630972  126.327637
AAACACCAATAACTGC-1  AAACACCAATAACTGC-1  183.078314  427.767792
AAACAGCTTTCAGAAG-1  AAACAGCTTTCAGAAG-1  152.700275  341.269139
AAACAGGGTCTATATT-1  AAACAGGGTCTATATT-1  164.941500  362.916304


In [24]:
# check matches between svg spot ids and meta data cell ids/barcodes
spot_ids_csv = pd.Index(svg_df.columns.astype(str))

# check matches with barcodes (should always work)
n_match_barcode = spot_ids_csv.isin(pd.Index(md["barcode"])).sum()
print(f"CSV -> barcode matches: {n_match_barcode} of {len(spot_ids_csv)}")

# check matches with cell_id if it exists
if 'cell_id' in md.columns:
    n_match_cellid = spot_ids_csv.isin(pd.Index(md["cell_id"])).sum()
    print(f"CSV -> cell_id matches: {n_match_cellid} of {len(spot_ids_csv)}")
else:
    print("cell_id column not available in metadata")

CSV -> barcode matches: 2897 of 2897
cell_id column not available in metadata


In [25]:
# show renamed svg spot ids
print(pd.Index(svg_df.columns.astype(str))[:20].tolist())

# check overlap between seurat barcodes and csv barcodes
seurat_barcodes = pd.Index([str(x) for x in ro.r('colnames(obj)')])
csv_barcodes = pd.Index(svg_df.columns.astype(str))

print("Seurat n:", len(seurat_barcodes))
print("CSV n:", len(csv_barcodes))
print("Overlap:", len(seurat_barcodes.intersection(csv_barcodes)))
print("Seurat example:", seurat_barcodes[:5].tolist())

['AAACAAGTATCTCCCA-1', 'AAACAATCTACTAGCA-1', 'AAACACCAATAACTGC-1', 'AAACAGCTTTCAGAAG-1', 'AAACAGGGTCTATATT-1', 'AAACAGTGTTCCTGGG-1', 'AAACATTTCCCGGATT-1', 'AAACCCGAACGAAATC-1', 'AAACCGGGTAGGTACC-1', 'AAACCGTTCGTCCAGG-1', 'AAACCTAAGCAGCCGG-1', 'AAACCTCATGAAGTTG-1', 'AAACGAAGAACATACC-1', 'AAACGAGACGGTTGAT-1', 'AAACGGGCGTACGGGT-1', 'AAACGGTTGCGAACTG-1', 'AAACTCGGTTCGCAAT-1', 'AAACTCGTGATATAAG-1', 'AAACTGCTGGCTCCAA-1', 'AAACTTAATTGCACGC-1']
Seurat n: 2897
CSV n: 2897
Overlap: 2897
Seurat example: ['AAACAAGTATCTCCCA-1', 'AAACAATCTACTAGCA-1', 'AAACACCAATAACTGC-1', 'AAACAGCTTTCAGAAG-1', 'AAACAGGGTCTATATT-1']


In [26]:
# get coordinates
coords = ro.r('''
md <- obj@meta.data
md[, c("imagecol","imagerow")]
''')
coords = pandas2ri.rpy2py(coords)
coords.index = coords.index.astype(str)
coords.columns = ["x", "y"]  


print(coords.head()) # cell ID by location

                             x           y
AAACAAGTATCTCCCA-1  440.639079  381.098123
AAACAATCTACTAGCA-1  259.630972  126.327637
AAACACCAATAACTGC-1  183.078314  427.767792
AAACAGCTTTCAGAAG-1  152.700275  341.269139
AAACAGGGTCTATATT-1  164.941500  362.916304


In [27]:
# get counts
expr_df = ro.r('''
mat <- NULL
try({ mat <- GetAssayData(obj, assay="SCT", slot="counts") }, silent=TRUE)
if (is.null(mat)) {
  try({ mat <- GetAssayData(obj, assay="SCT", layer="counts") }, silent=TRUE)
}
if (is.null(mat)) stop("Couldn't access SCT counts.")
as.data.frame(as.matrix(mat))
''')
expr_df = pandas2ri.rpy2py(expr_df)
expr_df.index = expr_df.index.astype(str) # genes
expr_df.columns = expr_df.columns.astype(str)  # barcodes

print("expr_df shape (genes x spots):", expr_df.shape)
print(expr_df.head())  # first 5 spots

expr_df shape (genes x spots): (17650, 2897)
            AAACAAGTATCTCCCA-1  AAACAATCTACTAGCA-1  AAACACCAATAACTGC-1  \
AL627309.1                 0.0                 0.0                 0.0   
AL669831.5                 0.0                 0.0                 0.0   
LINC00115                  0.0                 0.0                 0.0   
FAM41C                     0.0                 0.0                 0.0   
AL645608.1                 0.0                 0.0                 0.0   

            AAACAGCTTTCAGAAG-1  AAACAGGGTCTATATT-1  AAACAGTGTTCCTGGG-1  \
AL627309.1                 0.0                 0.0                 0.0   
AL669831.5                 0.0                 0.0                 0.0   
LINC00115                  0.0                 0.0                 0.0   
FAM41C                     0.0                 0.0                 0.0   
AL645608.1                 0.0                 0.0                 0.0   

            AAACATTTCCCGGATT-1  AAACCCGAACGAAATC-1  AAACCGGGTAGGT

In [28]:
# align by barcodes
common = csv_barcodes.intersection(expr_df.columns).intersection(coords.index)
print("Common spots:", len(common))

# subset to common spots
expr_df = expr_df.loc[:, common]
coords_sub = coords.loc[common]

# nonnegative expression for COMMOT
expr_df = np.log1p(expr_df)

# spots x genes
spot_gene_mat = expr_df.T
print("Spots by gene shape:", spot_gene_mat.shape)

Common spots: 2897
Spots by gene shape: (2897, 17650)


In [29]:
# build AnnData object for COMMOT

adata = ad.AnnData(X=spot_gene_mat.values)
adata.obs_names = spot_gene_mat.index.astype(str)
adata.var_names = spot_gene_mat.columns.astype(str)
adata.obsm["spatial"] = coords_sub.loc[adata.obs_names, ["x","y"]].values

print("adata:", adata.shape)
print("spatial:", adata.obsm["spatial"].shape)
print("X min:", adata.X.min())

adata: (2897, 17650)
spatial: (2897, 2)
X min: 0.0


### Filter LR pairs with CellChat

In [32]:
# get human LR pairs from CellChat
df_lr = ct.pp.ligand_receptor_database(database="CellChat", species="human")

# filter for LR pairs present in data
df_lr_expr = ct.pp.filter_lr_database(df_lr, adata, heteromeric=True, heteromeric_delimiter="_",
                                      heteromeric_rule="min", filter_criteria="min_cell_pct",
                                      min_cell_pct=0.05)

# SVG/HVG filtering
def receptor_in_set(rec, geneset):
    return any(p in geneset for p in str(rec).split("_"))

df_lr_svg = df_lr_expr[df_lr_expr.iloc[:,0].isin(SVG) | df_lr_expr.iloc[:,1].apply(lambda r: receptor_in_set(r, SVG))].copy()
df_lr_hvg = df_lr_expr[df_lr_expr.iloc[:,0].isin(HVG) | df_lr_expr.iloc[:,1].apply(lambda r: receptor_in_set(r, HVG))].copy()

# LR df for unfiltered data
df_lr_unfiltered = ct.pp.filter_lr_database(df_lr, adata, heteromeric=True, heteromeric_delimiter="_",
                                      heteromeric_rule="min", filter_criteria="none")

In [33]:
# look at lr dataframes
print("SVG LR dataframe:")
print(df_lr_svg.head())
print("n LR pairs:", df_lr_svg.shape[0])

SVG LR dataframe:
        0             1    2                   3
1    BMP7   ACVR1_BMPR2  BMP  Secreted Signaling
3    BMP7  BMPR1A_BMPR2  BMP  Secreted Signaling
5    BMP7  BMPR1B_BMPR2  BMP  Secreted Signaling
7  WNT10B     FZD3_LRP6  WNT  Secreted Signaling
8   WNT7A     FZD3_LRP6  WNT  Secreted Signaling
n LR pairs: 58


In [34]:
print("HVG LR dataframe:")
print(df_lr_hvg.head())
print("n LR pairs:", df_lr_hvg.shape[0])

HVG LR dataframe:
      0              1    2                   3
0  BMP7   ACVR1_ACVR2A  BMP  Secreted Signaling
1  BMP7    ACVR1_BMPR2  BMP  Secreted Signaling
2  BMP7  BMPR1A_ACVR2A  BMP  Secreted Signaling
3  BMP7   BMPR1A_BMPR2  BMP  Secreted Signaling
4  BMP7  BMPR1B_ACVR2A  BMP  Secreted Signaling
n LR pairs: 80


In [35]:
print("Unfiltered LR dataframe:")
print(df_lr_unfiltered.head())
print("n LR pairs:", df_lr_unfiltered.shape[0])

Unfiltered LR dataframe:
       0              1     2                   3
0  TGFB1  TGFBR1_TGFBR2  TGFb  Secreted Signaling
1  TGFB2  TGFBR1_TGFBR2  TGFb  Secreted Signaling
2  TGFB3  TGFBR1_TGFBR2  TGFb  Secreted Signaling
3  TGFB1  ACVR1B_TGFBR2  TGFb  Secreted Signaling
4  TGFB1  ACVR1C_TGFBR2  TGFb  Secreted Signaling
n LR pairs: 642


# ADD INTERPRETATION OF DFs

In [36]:
# ligand/receptor columns
lig_col = df_lr.columns[0]
rec_col = df_lr.columns[2]

print(f"First 10 ligand names: {df_lr[lig_col].dropna().astype(str).tolist()[:10]}")
print(f"First 10 receptor names: {df_lr[rec_col].dropna().astype(str).tolist()[:10]}")

def split_parts(x):
    if pd.isna(x):
        return []
    # if receptor stored as list/tuple/set
    if isinstance(x, (list, tuple, set)):
        parts = []
        for y in x:
            parts.extend(str(y).split("_"))
        return parts
    return str(x).split("_")

lr_ligands = set(df_lr[lig_col].dropna().astype(str))
lr_receptors = set(p for r in df_lr[rec_col].dropna() for p in split_parts(r))
lr_genes = lr_ligands | lr_receptors

print("Using columns:", lig_col, rec_col)
print("LR genes:", len(lr_genes))
print("LR ∩ SVG:", len(lr_genes & SVG))
print("LR ∩ HVG:", len(lr_genes & HVG))
print("LR ∩ SVG ∩ HVG:", len(lr_genes & SVG & HVG))
print("SVG-only within LR:", len((lr_genes & SVG) - HVG))
print("HVG-only within LR:", len((lr_genes & HVG) - SVG))


First 10 ligand names: ['TGFB1', 'TGFB2', 'TGFB3', 'TGFB1', 'TGFB1', 'TGFB2', 'TGFB2', 'TGFB3', 'TGFB3', 'TGFB1']
First 10 receptor names: ['TGFb', 'TGFb', 'TGFb', 'TGFb', 'TGFb', 'TGFb', 'TGFb', 'TGFb', 'TGFb', 'TGFb']
Using columns: 0 2
LR genes: 446
LR ∩ SVG: 24
LR ∩ HVG: 69
LR ∩ SVG ∩ HVG: 24
SVG-only within LR: 0
HVG-only within LR: 45


In [37]:
# colnames in lr df
LIG = "0"
REC = "1"
PATH = "2"
CAT = "3"  

# make copies & ensure strings
lr_df = df_lr.copy()
lr_df[LIG] = lr_df[LIG].astype(str)
lr_df[REC] = lr_df[REC].astype(str)

varset = set(adata.var_names)

def receptor_parts(rec: str):
    # heteromeric receptors are encoded as "A_B_C"
    return str(rec).split("_")

def receptor_in_var(rec: str) -> bool:
    parts = receptor_parts(rec)
    return any(p in varset for p in parts)

def receptor_in_set(rec: str, geneset: set) -> bool:
    parts = receptor_parts(rec)
    return any(p in geneset for p in parts)

# filter LR pairs to those present in the dataset
lr_expr = lr_df[
    lr_df[LIG].isin(varset) &
    lr_df[REC].apply(receptor_in_var)
].copy()

# SVG / HVG LR filters 
lr_svg = lr_expr[
    lr_expr[LIG].isin(SVG) |
    lr_expr[REC].apply(lambda r: receptor_in_set(r, SVG))
].copy()

lr_hvg = lr_expr[
    lr_expr[LIG].isin(HVG) |
    lr_expr[REC].apply(lambda r: receptor_in_set(r, HVG))
].copy()

print("LR total:", lr_df.shape)
print("LR expressed in data:", lr_expr.shape)
print("LR SVG-filtered:", lr_svg.shape)
print("LR HVG-filtered:", lr_hvg.shape)


LR total: (1199, 4)
LR expressed in data: (651, 4)
LR SVG-filtered: (201, 4)
LR HVG-filtered: (450, 4)


## Use stricter rules on LR filtering pairings
- ligand and receptor must be an SVG or HVG
- ligand must be a SVG or HVG
- receptor must be a SVG or HVG
- prevalence:
    - ligand and at least one receptor part expressed in at least 5% of spots

In [38]:
# filter based on LR both must be SVG or HVG
lr_svg_strict = lr_expr[
    lr_expr[LIG].isin(SVG) &
    lr_expr[REC].apply(lambda r: all(p in SVG for p in receptor_parts(r)))
].copy()

lr_hvg_strict = lr_expr[
    lr_expr[LIG].isin(HVG) &
    lr_expr[REC].apply(lambda r: all(p in HVG for p in receptor_parts(r)))
].copy()

print("LR SVG-strict filtered:", lr_svg_strict.shape)
print("LR HVG-strict filtered:", lr_hvg_strict.shape)

LR SVG-strict filtered: (14, 4)
LR HVG-strict filtered: (114, 4)


In [39]:
# filter based on ligand must be SVG or HVG
lr_svg_ligand = lr_expr[lr_expr[LIG].isin(SVG)].copy()
lr_hvg_ligand = lr_expr[lr_expr[LIG].isin(HVG)].copy()
lr_ligand = lr_expr.copy()

print("LR SVG-ligand filtered:", lr_svg_ligand.shape)
print("LR HVG-ligand filtered:", lr_hvg_ligand.shape)

LR SVG-ligand filtered: (68, 4)
LR HVG-ligand filtered: (210, 4)


In [40]:
# filter based on receptor must be SVG or HVG
lr_svg_receptor = lr_expr[lr_expr[REC].apply(lambda r: receptor_in_set(r, SVG))].copy()
lr_hvg_receptor = lr_expr[lr_expr[REC].apply(lambda r: receptor_in_set(r, HVG))].copy()

print("LR SVG-receptor filtered:", lr_svg_receptor.shape)
print("LR HVG-receptor filtered:", lr_hvg_receptor.shape)

LR SVG-receptor filtered: (162, 4)
LR HVG-receptor filtered: (393, 4)


In [41]:
# filter based on prevalence: ligand and at least one receptor part expressed in at least 5% of spots
Xdf = pd.DataFrame(adata.X, index=adata.obs_names, columns=adata.var_names)
pct_expr = (Xdf > 0).mean(axis=0)

min_pct = 0.05  # 5% of spots

def pair_prevalence_ok(df):
    lig_ok = df[LIG].map(lambda g: pct_expr.get(g, 0) >= min_pct)
    rec_ok = df[REC].map(lambda r: any(pct_expr.get(p, 0) >= min_pct for p in receptor_parts(r)))
    return df[lig_ok & rec_ok].copy()

prevalence_lr_expr = pair_prevalence_ok(lr_expr)
prevalence_lr_svg = pair_prevalence_ok(lr_svg)
prevalence_lr_hvg = pair_prevalence_ok(lr_hvg)

print("After prevalence filtering:")
print("LR expressed in data:", prevalence_lr_expr.shape)
print("LR SVG-filtered:", prevalence_lr_svg.shape)
print("LR HVG-filtered:", prevalence_lr_hvg.shape)

After prevalence filtering:
LR expressed in data: (131, 4)
LR SVG-filtered: (84, 4)
LR HVG-filtered: (131, 4)


## Apply OT with COMMOT

## SVGs
- simple SVGs
- L&R SVG
- Ligand SVG
- Receptor SVG
- 5% prevalence

In [None]:
# SVG LR sets
svg_lr_sets = {
    "simple_svg": lr_svg,
    "LR_svg" : lr_svg_strict,
    "L_svg" : lr_svg_ligand,
    "R_svg" : lr_svg_receptor,
    "prevalence_svg" : prevalence_lr_svg,
}

svg_results = {}

# run COMMOT for each LR set
for name, df in svg_lr_sets.items():
    ad = adata.copy()  # use shallow copy
    ct.tl.spatial_communication(
        adata=ad,
        database_name="cellchat",
        df_ligrec=df,
        dis_thr=500,
        heteromeric=True,
        pathway_sum=True,
    )
    # store results
    svg_results[name] = ad

print(svg_results.keys())


In [None]:
# save results
output_dir = "/Users/emilyekstrum/repos/zhangLab_Rotation/data/processed/COMMOT/human_svg_lr_sets"
os.makedirs(output_dir, exist_ok=True)
for name, ad in svg_results.items():
    output_path = os.path.join(output_dir, f"human_commot_{name}.h5ad")
    ad.write_h5ad(output_path)
    print(f"Saved {name} results to {output_path}")

In [None]:
# look at COMMOT outputs for SVG sets

def commot_inventory(ad):
    return {
        "n_obs": ad.n_obs,
        "n_vars": ad.n_vars,
        "n_obsp": len(ad.obsp.keys()),
        "obsp_keys_head": list(ad.obsp.keys())[:15],
        "uns_keys_head": list(ad.uns.keys())[:15],
    }

print("COMMOT results for SVG LR sets:")
for name, ad in svg_results.items():
    inv = commot_inventory(ad)
    print("\n==", name, "==")
    print("n_obsp:", inv["n_obsp"])
    print("obsp head:", inv["obsp_keys_head"])


In [None]:
# pathway-sum matrices vs LR pair matrices

def split_pathway_lr_keys(ad):
    keys = list(ad.obsp.keys())
    # pathway matrices typically have much larger nnz and fewer keys overall;
    # also often contain "pathway" or "sum" in key names (depends on version).
    path_keys = [k for k in keys if ("path" in k.lower()) or ("sum" in k.lower())]
    # fallback: if that fails, treat everything as LR and you can manually adjust later
    if len(path_keys) == 0:
        path_keys = []
    lr_keys = [k for k in keys if k not in path_keys]
    return path_keys, lr_keys

print("Pathway vs LR matrices for SVG LR sets:")
for name, ad in  svg_results.items():
    pk, lk = split_pathway_lr_keys(ad)
    print(name, "path:", len(pk), "lr:", len(lk))


In [None]:
# rank pathways by total communication strength
def rank_obsp_by_total(ad, keys=None):
    if keys is None:
        keys = list(ad.obsp.keys())
    rows = []
    for k in keys:
        M = ad.obsp[k]
        total = M.sum() if sparse.issparse(M) else np.sum(M)
        rows.append((k, float(total)))
    rows.sort(key=lambda x: x[1], reverse=True)
    return rows

print("Top pathways by total communication strength for SVG LR sets:")
for name, ad in svg_results.items():
    pk, lk = split_pathway_lr_keys(ad)
    keys_to_rank = pk if len(pk) > 0 else lk
    top = rank_obsp_by_total(ad, keys_to_rank)[:10]
    print("\nTOP in", name)
    for k, s in top:
        print(f"{s: .3e}", k)


In [None]:
# per spot LR sender and receiver scores
def add_sender_receiver(ad, key, prefix):
    M = ad.obsp[key]
    out = np.array(M.sum(axis=1)).ravel()
    inn = np.array(M.sum(axis=0)).ravel()
    ad.obs[f"{prefix}__out"] = out
    ad.obs[f"{prefix}__in"] = inn
    return out, inn

print("Adding top 3 sender/receiver scores for SVG LR sets")
for name, ad in svg_results.items():
    pk, lk = split_pathway_lr_keys(ad)
    keys_to_rank = pk if len(pk) > 0 else lk
    top3 = [k for k, _ in rank_obsp_by_total(ad, keys_to_rank)[:3]]
    for i, k in enumerate(top3, 1):
        add_sender_receiver(ad, k, prefix=f"{name}_top{i}")


In [None]:
# plot top LR pairs for a chosen LR set
lr_set_name = "simple_svg"
ad = svg_results[lr_set_name]

# spatial coordinates
xy = ad.obsm["spatial"]

# COMMOT LR keys ranked by total communication
lr_keys = list(ad.obsp.keys())
lr_ranked = sorted(
    lr_keys,
    key=lambda k: ad.obsp[k].sum() if sparse.issparse(ad.obsp[k]) else np.sum(ad.obsp[k]),
    reverse=True
)

# choose top 4 LR pairs
top_lr = lr_ranked[:4]

# plot
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.flatten()

panel = 0

for lr_key in top_lr:
    # sender / receiver values
    M = ad.obsp[lr_key]
    out_vals = np.array(M.sum(axis=1)).ravel()
    in_vals  = np.array(M.sum(axis=0)).ravel()

    lr_pretty = lr_key.replace("commot-cellchat-", "").replace("-", " → ", 1)

    for vals, direction in [(out_vals, "outgoing"), (in_vals, "incoming")]:
        ax = axes[panel]

        sca = ax.scatter(
            xy[:, 0],
            xy[:, 1],
            c=vals,
            s=20,
            cmap="viridis"
        )
        ax.invert_yaxis()
        ax.set_title(f"{lr_pretty}\n({direction})", fontsize=10)
        ax.axis("off")

        # colorbar
        cbar = fig.colorbar(sca, ax=ax, fraction=0.046, pad=0.02)
        cbar.set_label(
            "COMMOT communication strength\n(total OT mass per spot)",
            fontsize=9
        )

        panel += 1

fig.suptitle(
    "Spatial ligand–receptor communication inferred by COMMOT\n"
    f"LR set: {lr_set_name}",
    fontsize=16,
    fontweight='bold',
    y=1.02
)

plt.tight_layout()
plt.show()

In [None]:
# compare LR sets summary table

def top_totals_table(results_dict, top_n=20):
    # get keys from first result
    ref_name = list(results_dict.keys())[0]
    ref = results_dict[ref_name]
    pk, lk = split_pathway_lr_keys(ref)
    keys = pk if len(pk) > 0 else lk

    # rank by totals in reference
    ref_rank = rank_obsp_by_total(ref, keys)[:top_n]
    top_keys = [k for k, _ in ref_rank]

    out = pd.DataFrame(index=top_keys)
    for name, ad in results_dict.items():
        vals = []
        for k in top_keys:
            if k in ad.obsp:
                M = ad.obsp[k]
                vals.append(float(M.sum() if sparse.issparse(M) else np.sum(M)))
            else:
                vals.append(0.0)
        out[name] = vals
    return out

# summary table for SVG LR sets
svg_tbl = top_totals_table(svg_results, top_n=20)
svg_tbl.style.format("{:.2e}")


### Top pathways for SVGs:

# ADD INTERPRETATION

### SVG Null model (shuffled coords)

In [None]:
# null model testing with COMMOT
# shuffle spot locations and rerun COMMOT
n_shuffles = 5
null_svg_results = defaultdict(list)
for name, df in svg_lr_sets.items():
    print("Null model for LR set:", name)
    for i in range(n_shuffles):
        print(" Shuffle", i+1)
        ad = adata.copy()
        # shuffle spatial coordinates
        shuffled_xy = ad.obsm["spatial"].copy()
        np.random.shuffle(shuffled_xy)
        ad.obsm["spatial"] = shuffled_xy
        # run COMMOT
        ct.tl.spatial_communication(
            adata=ad,
            database_name="cellchat",
            df_ligrec=df,
            dis_thr=500,
            heteromeric=True,
            pathway_sum=True,
        )
        null_svg_results[name].append(ad)

In [None]:
# save null models
output_dir_null = "/Users/emilyekstrum/repos/zhangLab_Rotation/data/processed/COMMOT/human_svg_lr_sets_null"
os.makedirs(output_dir_null, exist_ok=True)
for name, ad_list in null_svg_results.items():
    for i, ad in enumerate(ad_list):
        output_path = os.path.join(output_dir_null, f"human_commot_{name}_null_{i+1}.h5ad")
        ad.write_h5ad(output_path)
        print(f"Saved null {name} shuffle {i+1} results to {output_path}")

In [None]:
# check null results
print("Null model COMMOT results for SVG LR sets:")
for name, ad_list in null_svg_results.items():
    print("\n==", name, "==")
    for i, ad in enumerate(ad_list):
        inv = commot_inventory(ad)
        print(f" Shuffle {i+1}: n_obsp: {inv['n_obsp']}, obsp head: {inv['obsp_keys_head']}")

In [None]:
# look at significance of top pathways in real vs null model
def pathway_significance(real_ad, null_ads, top_n=10):
    pk, lk = split_pathway_lr_keys(real_ad)
    keys_to_rank = pk if len(pk) > 0 else lk
    top_paths = [k for k, _ in rank_obsp_by_total(real_ad, keys_to_rank)[:top_n]]

    results = []
    for path in top_paths:
        real_M = real_ad.obsp[path]
        real_total = real_M.sum() if sparse.issparse(real_M) else np.sum(real_M)

        null_totals = []
        for null_ad in null_ads:
            null_M = null_ad.obsp.get(path)
            if null_M is not None:
                null_total = null_M.sum() if sparse.issparse(null_M) else np.sum(null_M)
                null_totals.append(null_total)
            else:
                null_totals.append(0.0)

        mean_null = np.mean(null_totals)
        std_null = np.std(null_totals)
        z_score = (real_total - mean_null) / std_null if std_null > 0 else float('inf')

        results.append({
            "pathway": path,
            "real_total": real_total, # total OT mass in real data
            "mean_null_total": mean_null, # mean total OT mass in null data
            "std_null_total": std_null, # std dev of total OT mass in null data
            "z_score": z_score 
        })

    return pd.DataFrame(results)

# significance for SVG LR sets
print("Pathway significance for SVG LR sets:")
for name, ad in svg_results.items():
    null_ads = null_svg_results[name]
    sig_df = pathway_significance(ad, null_ads, top_n=10)
    print("\n==", name, "==")
    print(sig_df)
    

### Null Results Interpretation

# ADD

## HVGs
- simple HVGs
- L&R HVG
- ligand HVG
- receptor HVG
- 5% prevalence

In [None]:
# HVG LR sets
hvg_lr_sets = {
    "simple_hvg": lr_hvg,
    "LR_hvg" : lr_hvg_strict,
    "L_hvg" : lr_hvg_ligand,
    "R_hvg" : lr_hvg_receptor,
    "prevalence_hvg" : prevalence_lr_hvg,
}

hvg_results = {}

# run COMMOT for each LR set
for name, df in hvg_lr_sets.items():
    ad = adata.copy()  # use shallow copy
    ct.tl.spatial_communication(
        adata=ad,
        database_name="cellchat",
        df_ligrec=df,
        dis_thr=500,
        heteromeric=True,
        pathway_sum=True,
    )
    hvg_results[name] = ad

print(hvg_results.keys())

dict_keys(['simple_hvg', 'LR_hvg', 'L_hvg', 'R_hvg', 'prevalence_hvg'])


In [None]:
# look at results
print("COMMOT results for HVG LR sets:")
for name, ad in hvg_results.items():
    inv = commot_inventory(ad)
    print("\n==", name, "==")
    print("n_obsp:", inv["n_obsp"])
    print("obsp head:", inv["obsp_keys_head"])

In [None]:
# save results
output_dir = "/Users/emilyekstrum/repos/zhangLab_Rotation/data/processed/COMMOT/human_hvg_lr_sets"
os.makedirs(output_dir, exist_ok=True)
for name, ad in svg_results.items():
    output_path = os.path.join(output_dir, f"human_commot_{name}.h5ad")
    ad.write_h5ad(output_path)
    print(f"Saved {name} results to {output_path}")

In [None]:
# pathway-sum matrices vs LR pair matrices
print("Pathway vs LR matrices for HVG LR sets:")
for name, ad in  hvg_results.items():
    pk, lk = split_pathway_lr_keys(ad)
    print(name, "path:", len(pk), "lr:", len(lk))

In [None]:
# rank pathways by total communication strength
print("Top pathways by total communication strength for HVG LR sets:")
for name, ad in hvg_results.items():
    pk, lk = split_pathway_lr_keys(ad)
    keys_to_rank = pk if len(pk) > 0 else lk
    top = rank_obsp_by_total(ad, keys_to_rank)[:10]
    print("\nTOP in", name)
    for k, s in top:
        print(f"{s: .3e}", k)

In [None]:
# per spot sender and receiver scores for top pathway in each HVG LR set
print("Adding top 3 sender/receiver scores for HVG LR sets")
for name, ad in hvg_results.items():
    pk, lk = split_pathway_lr_keys(ad)
    keys_to_rank = pk if len(pk) > 0 else lk
    top3 = [k for k, _ in rank_obsp_by_total(ad, keys_to_rank)[:3]]
    for i, k in enumerate(top3, 1):
        add_sender_receiver(ad, k, prefix=f"{name}_top{i}")

In [None]:
# plot top LR pairs for a chosen LR set
lr_set_name = "simple_hvg"
ad = hvg_results[lr_set_name]

# spatial coordinates
xy = ad.obsm["spatial"]

# COMMOT LR keys ranked by total communication
lr_keys = list(ad.obsp.keys())
lr_ranked = sorted(
    lr_keys,
    key=lambda k: ad.obsp[k].sum() if sparse.issparse(ad.obsp[k]) else np.sum(ad.obsp[k]),
    reverse=True
)

# choose top 4 LR pairs
top_lr = lr_ranked[:4]

# plot
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.flatten()

panel = 0

for lr_key in top_lr:
    # sender / receiver values
    M = ad.obsp[lr_key]
    out_vals = np.array(M.sum(axis=1)).ravel()
    in_vals  = np.array(M.sum(axis=0)).ravel()

    lr_pretty = lr_key.replace("commot-cellchat-", "").replace("-", " → ", 1)

    for vals, direction in [(out_vals, "outgoing"), (in_vals, "incoming")]:
        ax = axes[panel]

        sca = ax.scatter(
            xy[:, 0],
            xy[:, 1],
            c=vals,
            s=20,
            cmap="viridis"
        )
        ax.invert_yaxis()
        ax.set_title(f"{lr_pretty}\n({direction})", fontsize=10)
        ax.axis("off")

        # colorbar
        cbar = fig.colorbar(sca, ax=ax, fraction=0.046, pad=0.02)
        cbar.set_label(
            "COMMOT communication strength\n(total OT mass per spot)",
            fontsize=9
        )

        panel += 1

fig.suptitle(
    "Spatial ligand–receptor communication inferred by COMMOT\n"
    f"LR set: {lr_set_name}",
    fontsize=16,
    fontweight='bold',
    y=1.02
)

plt.tight_layout()
plt.show()

In [None]:
# summary table for HVG LR sets
hvg_tbl = top_totals_table(hvg_results, top_n=20)
hvg_tbl.style.format("{:.2e}")

### Top pathways for HVGs:

# ADD

### HVG Null model (shuffled coords):

In [None]:
# null model testing with COMMOT
# shuffle spot locations and rerun COMMOT
n_shuffles = 5
null_hvg_results = defaultdict(list)
for name, df in hvg_lr_sets.items():
    print("Null model for LR set:", name)
    for i in range(n_shuffles):
        print(" Shuffle", i+1)
        ad = adata.copy()
        # shuffle spatial coordinates
        shuffled_xy = ad.obsm["spatial"].copy()
        np.random.shuffle(shuffled_xy)
        ad.obsm["spatial"] = shuffled_xy
        # run COMMOT
        ct.tl.spatial_communication(
            adata=ad,
            database_name="cellchat",
            df_ligrec=df,
            dis_thr=500,
            heteromeric=True,
            pathway_sum=True,
        )
        null_hvg_results[name].append(ad)

In [None]:
# check null results
print("Null model COMMOT results for SVG LR sets:")
for name, ad_list in null_hvg_results.items():
    print("\n==", name, "==")
    for i, ad in enumerate(ad_list):
        inv = commot_inventory(ad)
        print(f" Shuffle {i+1}: n_obsp: {inv['n_obsp']}, obsp head: {inv['obsp_keys_head']}")

In [None]:
# save null models
output_dir_null = "/Users/emilyekstrum/repos/zhangLab_Rotation/data/processed/COMMOT/human_hvg_lr_sets_null"
os.makedirs(output_dir_null, exist_ok=True)
for name, ad_list in null_hvg_results.items():
    for i, ad in enumerate(ad_list):
        output_path = os.path.join(output_dir_null, f"human_commot_{name}_null_{i+1}.h5ad")
        ad.write_h5ad(output_path)
        print(f"Saved null {name} shuffle {i+1} results to {output_path}")

In [None]:
# significance for HVG LR sets
print("Pathway significance for HVG LR sets:")
for name, ad in hvg_results.items():
    null_ads = null_hvg_results[name]
    sig_df = pathway_significance(ad, null_ads, top_n=10)
    print("\n==", name, "==")
    print(sig_df)

### Null Results Interpretaion

# ADD

# OVERALL: add

## Unfiltered genes

In [None]:
# commot for unfiltered LR pairs
lr_unfiltered = lr_expr.copy()   

unfiltered_results = {}

name = "unfiltered_lr"
ad = adata.copy()

ct.tl.spatial_communication(
    adata=ad,
    database_name="cellchat",
    df_ligrec=lr_unfiltered,
    dis_thr=500,
    heteromeric=True,
    pathway_sum=True,
)

unfiltered_results[name] = ad
print(unfiltered_results.keys())


In [None]:
# look at results
print("COMMOT results for unfiltered LR sets:")
for name, ad in unfiltered_results.items():
    inv = commot_inventory(ad)
    print("\n==", name, "==")
    print("n_obsp:", inv["n_obsp"])
    print("obsp head:", inv["obsp_keys_head"])

In [None]:
# pathway-sum matrices vs LR pair matrices
print("Pathway vs LR matrices for unfiltered LR sets:")
for name, ad in  hvg_results.items():
    pk, lk = split_pathway_lr_keys(ad)
    print(name, "path:", len(pk), "lr:", len(lk))

In [None]:
# per spot sender and receiver scores for top pathway in unfiltered LR set
print("Adding top 3 sender/receiver scores for unfiltered LR set:")
for name, ad in hvg_results.items():
    pk, lk = split_pathway_lr_keys(ad)
    keys_to_rank = pk if len(pk) > 0 else lk
    top3 = [k for k, _ in rank_obsp_by_total(ad, keys_to_rank)[:3]]
    for i, k in enumerate(top3, 1):
        add_sender_receiver(ad, k, prefix=f"{name}_top{i}")

In [None]:
# plot top LR pairs for a chosen LR set
lr_set_name = "unfiltered_lr"
ad = unfiltered_results[lr_set_name]

# spatial coordinates
xy = ad.obsm["spatial"]

# COMMOT LR keys ranked by total communication
lr_keys = list(ad.obsp.keys())
lr_ranked = sorted(
    lr_keys,
    key=lambda k: ad.obsp[k].sum() if sparse.issparse(ad.obsp[k]) else np.sum(ad.obsp[k]),
    reverse=True
)

# choose top 4 LR pairs
top_lr = lr_ranked[:4]

# plot
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.flatten()

panel = 0

for lr_key in top_lr:
    # sender / receiver values
    M = ad.obsp[lr_key]
    out_vals = np.array(M.sum(axis=1)).ravel()
    in_vals  = np.array(M.sum(axis=0)).ravel()

    lr_pretty = lr_key.replace("commot-cellchat-", "").replace("-", " → ", 1)

    for vals, direction in [(out_vals, "outgoing"), (in_vals, "incoming")]:
        ax = axes[panel]

        sca = ax.scatter(
            xy[:, 0],
            xy[:, 1],
            c=vals,
            s=20,
            cmap="viridis"
        )
        ax.invert_yaxis()
        ax.set_title(f"{lr_pretty}\n({direction})", fontsize=10)
        ax.axis("off")

        # colorbar
        cbar = fig.colorbar(sca, ax=ax, fraction=0.046, pad=0.02)
        cbar.set_label(
            "COMMOT communication strength\n(total OT mass per spot)",
            fontsize=9
        )

        panel += 1

fig.suptitle(
    "Spatial ligand–receptor communication inferred by COMMOT\n"
    f"LR set: {lr_set_name}",
    fontsize=16,
    fontweight='bold',
    y=1.02
)

plt.tight_layout()
plt.show()

In [None]:
# summary table for unfiltered LR sets
unfilt_tbl = top_totals_table(unfiltered_results, top_n=20)
unfilt_tbl.style.format("{:.2e}")

### Top pathways for unfiltered genes: same as SVGs and HVGs

In [None]:
# check differences between LR sets for SVG, HVG, unfiltered
def compare_lr_sets(sets_dict):
    lr_sets = {}
    for name, df in sets_dict.items():
        pairs = set()
        for _, row in df.iterrows():
            lig = row[LIG]
            rec = row[REC]
            pairs.add((lig, rec))
        lr_sets[name] = pairs

    all_names = list(lr_sets.keys())
    n = len(all_names)

    comparison = pd.DataFrame(index=all_names, columns=all_names, dtype=int)

    for i in range(n):
        for j in range(n):
            set_i = lr_sets[all_names[i]]
            set_j = lr_sets[all_names[j]]
            intersection = set_i.intersection(set_j)
            comparison.iloc[i, j] = len(intersection)

    return comparison

# compare SVG LR sets
svg_comparison = compare_lr_sets(svg_lr_sets)
print("SVG LR sets comparison:")
print(svg_comparison)   
# compare HVG LR sets
hvg_comparison = compare_lr_sets(hvg_lr_sets)
print("\nHVG LR sets comparison:")
print(hvg_comparison)   
# compare unfiltered LR set to SVG and HVG simple sets
unfiltered_comparison = compare_lr_sets({
    "unfiltered_lr": lr_unfiltered,
    "simple_svg": lr_svg,
    "simple_hvg": lr_hvg,
})
print("\nUnfiltered vs SVG and HVG LR sets comparison:")
print(unfiltered_comparison)

### **SVGs:**
- ADD

***Spatial organization of LR signaling is WHAT***

### **HVGs:**
- ADD

### **All 3:**
- ADD

***SVG filtering DID WHAT***


In [None]:
# look at 5 interactions removed by SVG filtering but kept by HVG filtering
removed_by_svg = hvg_lr_sets["simple_hvg"].merge(
    svg_lr_sets["simple_svg"],
    on=[LIG, REC],
    how="left",
    indicator=True
)
removed_by_svg = removed_by_svg[removed_by_svg["_merge"] == "left_only"]
print("Interactions removed by SVG filtering but kept by HVG filtering:")
print(removed_by_svg[[LIG, REC]].head(5))

# ADD INTERP