In [1]:
import dask
dask.config.set({"dataframe.query-planning": False})

import itertools
import liana
import scanpy as sc
import numpy as np
import pandas as pd
import sys
import gc
import matplotlib.patches as mpatches
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path
import PyComplexHeatmap as pch
from PyComplexHeatmap import HeatmapAnnotation, DotClustermapPlotter, anno_simple, anno_label

sys.path.append("../../../workflow/scripts/")
import _utils
import readwrite
cfg = readwrite.config()
sns.set_style('ticks')

def remove_self_edges(knn_graph, labels, cti, ctj):
    """
    Remove homotypic edges (cti-cti and ctj-ctj) from a CSR neighbour graph.

    An edge i -> j is dropped when both endpoints are labelled ``cti`` or both
    are labelled ``ctj``. All other edges, including edges touching cells of
    any other label, are kept.

    The graph is modified in place (its explicit zeros are eliminated) and the
    same object is returned.

    Args:
        knn_graph (scipy.sparse.csr_matrix): neighbour graph whose rows/columns
            follow the order of ``labels``.
        labels (array-like): one cell-type label per node.
        cti (str): first cell-type label.
        ctj (str): second cell-type label.

    Returns:
        scipy.sparse.csr_matrix: ``knn_graph`` with the homotypic edges removed.
    """
    labels = np.asarray(labels)
    is_cti = labels == cti
    is_ctj = labels == ctj

    # Row index of every stored entry, so the whole CSR structure is masked at once
    rows = np.repeat(np.arange(knn_graph.shape[0]), np.diff(knn_graph.indptr))
    cols = knn_graph.indices
    same_type = (is_cti[rows] & is_cti[cols]) | (is_ctj[rows] & is_ctj[cols])

    knn_graph.data[same_type] = 0
    knn_graph.eliminate_zeros()
    return knn_graph


def ccc_score(
    ad_cti_ctj,
    knn_graph,
    cti_receptor,
    ctj_ligand,
    labels_key,
    cti,
    correction_method,
    k,
):
    """
    Custom cell-cell communication score of one ligand-receptor pair.

    For every cell labelled ``cti``, its receptor expression is multiplied by
    the mean ligand expression over its neighbours in ``knn_graph``. Edge
    weights are ignored; only the stored neighbour indices are used. The score
    is the mean of these products over all ``cti`` cells, skipping cells
    without any neighbour.

    Args:
        ad_cti_ctj (anndata.AnnData): cells to score; ``X`` may be sparse or dense.
        knn_graph (scipy.sparse.csr_matrix): neighbour graph over the rows of
            ``ad_cti_ctj``.
        cti_receptor (str): receptor gene name in ``ad_cti_ctj.var_names``.
        ctj_ligand (str): ligand gene name in ``ad_cti_ctj.var_names``.
        labels_key (str): column of ``ad_cti_ctj.obs`` holding cell-type labels.
        cti (str): label of the receptor-expressing cell type.
        correction_method (str): unused; kept for call compatibility.
        k (tuple): unused sample key; kept for call compatibility.

    Returns:
        float: the score, or NaN when no ``cti`` cell has a neighbour.
    """
    def _gene_vector(gene):
        # Flat float vector of one gene's expression, for sparse or dense X
        x = ad_cti_ctj[:, gene].X
        if hasattr(x, "toarray"):
            x = x.toarray()
        return np.asarray(x, dtype=float).reshape(-1)

    receptor = _gene_vector(cti_receptor)
    ligand = _gene_vector(ctj_ligand)
    cti_idx = np.flatnonzero(np.asarray(ad_cti_ctj.obs[labels_key] == cti))

    n_cells = knn_graph.shape[0]
    n_neighbors = np.diff(knn_graph.indptr)
    # Row index of every stored edge -> per-row neighbour ligand sums in one bincount
    rows = np.repeat(np.arange(n_cells), n_neighbors)
    ligand_sum = np.bincount(rows, weights=ligand[knn_graph.indices], minlength=n_cells)
    with np.errstate(invalid="ignore", divide="ignore"):
        ligand_neighbor_mean = ligand_sum / n_neighbors  # NaN for cells without neighbours

    scores = receptor[cti_idx] * ligand_neighbor_mean[cti_idx]
    scores = scores[~np.isnan(scores)]
    return scores.mean() if scores.size else np.nan



# Params

In [None]:
# cfg paths: every directory below is read from the workflow config
(
    xenium_processed_data_dir,
    xenium_count_correction_dir,
    xenium_std_seurat_analysis_dir,
    xenium_cell_type_annotation_dir,
    results_dir,
    palette_dir,
    scrnaseq_processed_data_dir,
) = (
    Path(cfg[cfg_key])
    for cfg_key in (
        'xenium_processed_data_dir',
        'xenium_count_correction_dir',
        'xenium_std_seurat_analysis_dir',
        'xenium_cell_type_annotation_dir',
        'results_dir',
        'xenium_metadata_dir',
        'scrnaseq_processed_data_dir',
    )
)
std_seurat_analysis_dir = Path(cfg['xenium_std_seurat_analysis_dir'])  # same directory as xenium_std_seurat_analysis_dir
seurat_to_h5_dir = results_dir / 'seurat_to_h5'

# Params
# Count-correction methods to compare; one ovrlpy entry per signal-integrity threshold
signal_integrity_thresholds = [0.5, 0.7]
correction_methods = [
    'raw',
    'split_fully_purified',
    'resolvi',
    'resolvi_supervised',
    *[f'ovrlpy_correction_{signal_integrity_threshold=}' for signal_integrity_threshold in signal_integrity_thresholds],
]

# ResolVI parameters (they appear in the corrected-count paths)
num_samples = 30
mixture_k = 50

# Normalisation and cell-type annotation settings used to locate inputs
normalisation = 'lognorm'
layer = 'data'
reference = 'matched_reference_combo'
method = 'rctd_class_aware'
level = 'Level2.1'

# Colour palettes
segmentation_palette = palette_dir / 'col_palette_segmentation.csv'
count_correction_palette = palette_dir / 'col_palette_correction_method.csv'

list_n_markers = [10, 20, 30, 40, 50]
radius = 15
top_n = 20
markers_mode = 'diffexpr'

xenium_levels = ["segmentation", "condition", "panel", "donor", "sample"]
order = ['breast', 'chuvio', 'lung', '5k']

# Plot hue keys and their display orders
hue_segmentation = "segmentation"
hue_segmentation_order = [
    "MM 0µm", "MM", "MM 15µm",
    "0µm", "5µm", "15µm",
    "Baysor", "ProSeg", "ProSeg mode", "Segger",
]

hue_correction = 'correction_method'
hue_correction_order = [
    'raw', 'ResolVI', 'ResolVI supervised',
    'ovrlpy 0.5', 'ovrlpy 0.7', 'SPLIT',
]

rank_metrics = [
    "logfoldchanges",
    "-log10pvals_x_logfoldchanges",
    "-log10pvals_x_sign_logfoldchanges",
    "mean_zscore",
]
plot_metrics = ['hypergeometric_pvalue', 'NES', f"n_hits_{top_n=}", "mean_zscore_pvalue"]

labels_key = level

# Load raw and corrected counts with cell-type labels

In [None]:
xenium_paths = {}
xenium_annot_paths = {}

# Segmentations that are not part of this analysis
excluded_segmentations = ['proseg_mode', 'baysor', 'segger', '10x_15um', '10x_0um']


def _raw_sample_dir(k):
    """Directory of the raw Xenium output for sample key ``k``; every segmentation containing 'proseg' reads the 'proseg' output."""
    if 'proseg' in k[0]:
        name_proseg = '/'.join(('proseg',) + k[1:])
        return xenium_processed_data_dir / f'{name_proseg}/raw_results'
    name = '/'.join(k)
    return xenium_processed_data_dir / f'{name}/normalised_results/outs'


def _annotation_path(name):
    """Cell-type label parquet of sample ``name`` for the configured reference/method/level."""
    return xenium_cell_type_annotation_dir / f'{name}/{normalisation}/reference_based/{reference}/{method}/{level}/single_cell/labels.parquet'


def _corrected_counts_path(correction_method, name):
    """Path of corrected_counts.h5 of sample ``name`` for a non-raw correction method."""
    if correction_method == "split_fully_purified":
        name_corrected = f'{name}/{normalisation}/reference_based/{reference}/{method}/{level}/single_cell/split_fully_purified/'
        return xenium_count_correction_dir / f"{name_corrected}/corrected_counts.h5"
    if correction_method == "resolvi":
        name_corrected = f'{name}/{mixture_k=}/{num_samples=}/'
    elif correction_method == "resolvi_supervised":
        name_corrected = f'{name}/{normalisation}/reference_based/{reference}/{method}/{level}/{mixture_k=}/{num_samples=}'
    elif "ovrlpy" in correction_method:
        name_corrected = f'{name}'
    else:
        # no known path layout: fail loudly instead of building a wrong path
        raise ValueError(f"Unknown correction method: {correction_method!r}")
    return results_dir / f"{correction_method}/{name_corrected}/corrected_counts.h5"


# Sample keys (segmentation, condition, panel, donor, sample), walked once from the directory tree
sample_keys = [
    (segmentation.stem, condition.stem, panel.stem, donor.stem, sample.stem)
    for segmentation in xenium_std_seurat_analysis_dir.iterdir()
    if segmentation.stem not in excluded_segmentations
    for condition in segmentation.iterdir()
    for panel in condition.iterdir()
    for donor in panel.iterdir()
    for sample in donor.iterdir()
]

for correction_method in correction_methods:
    xenium_annot_paths[correction_method] = {}
    if correction_method == 'raw':
        xenium_paths['raw'] = {k: _raw_sample_dir(k) for k in sample_keys}
        xenium_annot_paths['raw'] = {k: _annotation_path('/'.join(k)) for k in sample_keys}
    else:
        xenium_paths[correction_method] = {
            k: _corrected_counts_path(correction_method, '/'.join(k)) for k in sample_keys
        }

corrected_methods = [m for m in correction_methods if m != 'raw']
ads = readwrite.read_count_correction_samples(xenium_paths, corrected_methods)
ads['raw'] = readwrite.read_xenium_samples(xenium_paths['raw'], anndata=True, transcripts=False, max_workers=6)


def _prepare_raw(ad, k):
    """Attach simplified cell-type labels to a raw sample and drop unlabelled cells (annotation is done after QC, so this also applies QC)."""
    if k[0] == "proseg_expected":
        ad.obs_names = "proseg-" + ad.obs_names.astype(str)
    ad.obs[labels_key] = pd.read_parquet(xenium_annot_paths['raw'][k]).set_index("cell_id").iloc[:, 0]
    ad = ad[ad.obs[labels_key].notna()].copy()

    labels = ad.obs[labels_key].astype(str)
    if labels_key == "Level2.1":
        # for custom Level2.1, collapse subtypes
        labels = labels.mask(labels.str.contains("malignant"), "malignant cell")
        labels = labels.mask(labels.str.contains("T cell"), "T cell")
    # remove tissue from cell type name
    ad.obs[labels_key] = labels.str.replace(r" of .+", "", regex=True)
    return ad


def _prepare_corrected(ad, raw_ad):
    """Keep the cells of a corrected sample that survive in the processed raw sample (raw order) and copy their labels and coordinates."""
    keep = raw_ad.obs_names[raw_ad.obs_names.isin(ad.obs_names)]
    ad = ad[keep].copy()
    ad.obs[labels_key] = raw_ad.obs.loc[keep, labels_key]
    ad.obsm['spatial'] = raw_ad[keep].obsm['spatial']
    return ad


# Raw samples must be processed first: corrected samples take labels and coordinates from them
for correction_method in ['raw'] + corrected_methods:
    for k, ad in ads[correction_method].items():
        if ad is None:
            continue

        if correction_method == 'raw':
            ad = _prepare_raw(ad, k)
        else:
            raw_ad = ads['raw'].get(k)
            if raw_ad is None:
                print(f"Skipping {correction_method} {k}: raw sample not available")
                ads[correction_method][k] = None
                continue
            ad = _prepare_corrected(ad, raw_ad)

        # normalize
        sc.pp.normalize_total(ad)
        sc.pp.log1p(ad)

        ads[correction_method][k] = ad
            

# Ligand–receptor communication between cti cells and their neighbouring ctj cells

In [None]:
cti = 'T cell'
ctj = 'malignant cell'
radius = 15
obsm = 'spatial'

df_lr=liana.resource.select_resource('consensus')
df_lr_pd1_reverse = df_lr[(df_lr =='PDCD1').any(axis=1) | (df_lr == 'CD274').any(axis=1)]
df_lr_pd1_reverse.columns = ['receptor','ligand']
df_lr = pd.concat((df_lr,df_lr_pd1_reverse))

lrdatas = {}
df_plot_liana = {}
df_plot_expr_cti = {}
df_plot_expr_ctj = {}
df_ccc_score = {}
for correction_method in correction_methods:
    lrdatas[correction_method] = {}

    for k, ad in ads[correction_method].items():
        if k[1] != 'NSCLC' or k[0] not in ['proseg_expected','10x_5um'] or k[2] != 'lung':
            continue

        if ad is not None:

            if not np.all(np.isin([cti,ctj], ad.obs[labels_key].unique())):
                continue

            print(correction_method,k)


            # subset to cti ctj
            ad_cti_ctj = ad[ad.obs[labels_key].isin([cti, ctj])].copy()

            # get kNN graph
            knnlabels, knndis, knnidx, knn_graph = _utils.get_knn_labels(
                ad_cti_ctj,radius=radius,
                label_key=labels_key,obsm=obsm,
                return_sparse_neighbors=True)
                
            ad_cti_ctj.obsp[f'{obsm}_connectivities'] = knn_graph
            ad_cti_ctj.obs[f"has_{cti}_neighbor"] = knnlabels[cti] > 0
            ad_cti_ctj.obs[f"has_{ctj}_neighbor"] = knnlabels[ctj] > 0


            if (ad_cti_ctj.obs[f"has_{ctj}_neighbor"]).sum() < 30 or (~ad_cti_ctj.obs[f"has_{ctj}_neighbor"]).sum() < 30:
                print(f"Not enough cells from each class to test {cti} with {ctj} neighbors")
                continue
            
            ad_cti_ctj.obsp[f'{obsm}_connectivities'] = remove_self_edges(ad_cti_ctj.obsp[f'{obsm}_connectivities'], 
                                                                          ad_cti_ctj.obs[labels_key].values, cti, ctj)
            # liana CCC score
            lrdata = liana.mt.sp.bivariate(
                ad_cti_ctj,
                connectivity_key = f'{obsm}_connectivities',
                resource=df_lr, # NOTE: uses HUMAN gene symbols!
                local_name='cosine', # Name of the function
                global_name=None,
                n_perms=1, # Number of permutations to calculate a p-value
                mask_negatives=True, # Whether to mask LowLow/NegativeNegative interactions
                add_categories=False, # Whether to add local categories to the results
                nz_prop=0.0, # Minimum expr. proportion for ligands/receptors and their subunits
                use_raw=False,
                verbose=False
                )

            # custom CCC score
            for ctj_ligand, cti_receptor in df_lr.values:
                if ctj_ligand in ad_cti_ctj.var_names and cti_receptor in ad_cti_ctj.var_names:
                    df_ccc_score[correction_method,*k,f'{ctj_ligand}_{cti_receptor}'] = ccc_score(
                            ad_cti_ctj,
                            knn_graph,
                            cti_receptor,
                            ctj_ligand,
                            labels_key,
                            cti,
                            correction_method,
                            k,
                        )

            lrdata.var[f"has_{ctj}_neighbor_mean"] = lrdata[(lrdata.obs[labels_key] == cti) & (lrdata.obs[f"has_{ctj}_neighbor"])].X.mean(0).A1
            lrdata.var[f"has_no_{ctj}_neighbor_mean"] = lrdata[(lrdata.obs[labels_key] == cti) & (~lrdata.obs[f"has_{ctj}_neighbor"])].X.mean(0).A1

            ad_cti_ctj.var[f"has_{ctj}_neighbor_mean"] = ad_cti_ctj[(ad_cti_ctj.obs[labels_key] == cti) & (ad_cti_ctj.obs[f"has_{ctj}_neighbor"])].X.mean(0).A1
            ad_cti_ctj.var[f"has_no_{ctj}_neighbor_mean"] = ad_cti_ctj[(ad_cti_ctj.obs[labels_key] == cti) & (~ad_cti_ctj.obs[f"has_{ctj}_neighbor"])].X.mean(0).A1

            ad_cti_ctj.var[f"has_{cti}_neighbor_mean"] = ad_cti_ctj[(ad_cti_ctj.obs[labels_key] == ctj) & (ad_cti_ctj.obs[f"has_{cti}_neighbor"])].X.mean(0).A1
            ad_cti_ctj.var[f"has_no_{cti}_neighbor_mean"] = ad_cti_ctj[(ad_cti_ctj.obs[labels_key] == ctj) & (~ad_cti_ctj.obs[f"has_{cti}_neighbor"])].X.mean(0).A1

            lrdatas[correction_method][k] = lrdata
            df_plot_liana[correction_method,*k] = lrdata.var[[f"has_{ctj}_neighbor_mean",f"has_no_{ctj}_neighbor_mean"]]
            df_plot_expr_cti[correction_method,*k] = ad_cti_ctj.var[[f"has_{ctj}_neighbor_mean",f"has_no_{ctj}_neighbor_mean"]]
            df_plot_expr_ctj[correction_method,*k] = ad_cti_ctj.var[[f"has_{cti}_neighbor_mean",f"has_no_{cti}_neighbor_mean"]]

id_vars = ['correction_method'] + xenium_levels

df_plot_liana = pd.concat(df_plot_liana).reset_index()
df_plot_liana.columns = id_vars + df_plot_liana.columns[len(id_vars):].tolist()
df_plot_liana = df_plot_liana.melt(id_vars=id_vars+['interaction'])
_utils.rename_methods(df_plot_liana)
df_plot_liana['ID'] = df_plot_liana[id_vars+['variable']].astype(str).agg('-'.join, axis=1)

df_ccc_score = pd.Series(df_ccc_score).reset_index()
df_ccc_score.columns = id_vars + ["interaction","ccc_score"]
_utils.rename_methods(df_ccc_score)
df_ccc_score['ID'] = df_ccc_score[id_vars].astype(str).agg('-'.join, axis=1)

dfs_plot_expr = {(cti,ctj):df_plot_expr_cti,(ctj,cti):df_plot_expr_ctj}
for ct, df_ in dfs_plot_expr.items():
    df_ = pd.concat(df_).reset_index()
    df_.columns = id_vars + ['gene'] + df_.columns[len(id_vars)+1:].tolist()
    df_ = df_.melt(id_vars=id_vars+['gene'])
    _utils.rename_methods(df_)
    df_['ID'] = df_[id_vars+['variable']].astype(str).agg('-'.join, axis=1)
    dfs_plot_expr[ct]=df_

## Heatmap liana

In [None]:
rank_metric = 'value'
list_ref_panel = ['lung']
list_ref_segmentation = ['5µm', 'ProSeg']

heatmap_dir = Path('../../scratch/heatmaps_ccc/')
heatmap_dir.mkdir(parents=True, exist_ok=True)
# Seeded generator so the tiny jitter below is reproducible across runs
rng = np.random.default_rng(0)

for ref_panel, ref_segmentation in itertools.product(list_ref_panel, list_ref_segmentation):
    # for the 5k panel, '5µm' is looked up as 'MM'
    if ref_panel == '5k' and ref_segmentation == '5µm':
        ref_segmentation = 'MM'
    df_plot_sub = df_plot_liana.query("segmentation == @ref_segmentation and panel == @ref_panel")
    df_plot_sub = df_plot_sub[df_plot_sub['variable'] == f'has_{ctj}_neighbor_mean']
    if df_plot_sub.empty:
        print(f"No LIANA scores for {ref_panel=} {ref_segmentation=}, skipping")
        continue

    # Tiny jitter, presumably so constant-valued rows do not break row clustering — confirm
    df_plot_sub = df_plot_sub.assign(
        **{rank_metric: df_plot_sub[rank_metric] + rng.normal(0, 1e-10, len(df_plot_sub))}
    )

    ### heatmap column annotation: one column per sample/variable ID
    df_col = df_plot_sub[['correction_method', 'segmentation', 'panel', 'variable', 'ID']].drop_duplicates()
    df_col = df_col.set_index('ID')

    col_ha = HeatmapAnnotation(
        correction_method=anno_simple(df_col['correction_method'], legend=True, add_text=True),
        verbose=0, label_side='left')

    plt.figure(figsize=(15, 6))
    cm = DotClustermapPlotter(
        data=df_plot_sub,
        x='ID', y='interaction', value=rank_metric, c=rank_metric,
        s=1,
        row_cluster=True, col_cluster=False,
        top_annotation=col_ha,
        col_split=df_col['correction_method'],
        col_split_order=hue_correction_order,
        cmap='Purples',
        col_split_gap=0, row_split_gap=5,
        show_rownames=True, show_colnames=False, row_dendrogram=False,
        verbose=0, legend_gap=10, linewidths=.1, spines=True,
        vmax=np.percentile(df_plot_sub[rank_metric], 95),
    )  # if the size of dot in legend is too large, use alpha to control, for example: alpha=0.8

    plt.savefig(heatmap_dir / f"heatmap_ccc_{ref_panel=}_{ref_segmentation=}.png", bbox_inches='tight', dpi=300)
    plt.show()
    plt.close()

## Heatmap of ligand/receptor gene expression by neighbour status

In [None]:
# The melted expression tables in dfs_plot_expr store their values in 'value' (there is no 'ccc_score' column)
rank_metric = 'value'
list_ref_panel = ['lung']
list_ref_segmentation = ['5µm', 'ProSeg']
# All genes appearing as a ligand or receptor in the resource
u_genes = np.unique(df_lr.values.flatten())

heatmap_dir = Path('../../scratch/heatmaps_ccc/')
heatmap_dir.mkdir(parents=True, exist_ok=True)
# Seeded generator so the tiny jitter below is reproducible across runs
rng = np.random.default_rng(0)

for (cti_, ctj_), df_ in dfs_plot_expr.items():
    df_plot_expr_lr = df_[df_['gene'].isin(u_genes)]

    for ref_panel, ref_segmentation in itertools.product(list_ref_panel, list_ref_segmentation):
        # for the 5k panel, '5µm' is looked up as 'MM'
        if ref_panel == '5k' and ref_segmentation == '5µm':
            ref_segmentation = 'MM'
        df_plot_sub = df_plot_expr_lr.query("segmentation == @ref_segmentation and panel == @ref_panel")
        if df_plot_sub.empty:
            print(f"No expression values for {cti_} {ref_panel=} {ref_segmentation=}, skipping")
            continue

        # Tiny jitter, presumably so constant-valued rows do not break row clustering — confirm
        df_plot_sub = df_plot_sub.assign(
            **{rank_metric: df_plot_sub[rank_metric] + rng.normal(0, 1e-10, len(df_plot_sub))}
        )

        ### heatmap column annotation: one column per sample/variable ID
        df_col = df_plot_sub[['correction_method', 'segmentation', 'panel', 'variable', 'ID']].drop_duplicates()
        df_col = df_col.set_index('ID')

        col_ha = HeatmapAnnotation(
            variable=anno_simple(df_col['variable'], legend=True, add_text=True),
            correction_method=anno_simple(df_col['correction_method'], legend=True, add_text=True),
            verbose=0, label_side='left')

        plt.figure(figsize=(15, 25))
        cm = DotClustermapPlotter(
            data=df_plot_sub,
            x='ID', y='gene', value=rank_metric, c=rank_metric,
            s=None,
            row_cluster=True, col_cluster=False,
            top_annotation=col_ha,
            col_split=df_col['variable'],
            col_split_order=[f'has_{ctj_}_neighbor_mean', f'has_no_{ctj_}_neighbor_mean'],
            cmap='Purples',
            col_split_gap=0, row_split_gap=5,
            show_rownames=True, show_colnames=False, row_dendrogram=False,
            verbose=0, legend_gap=10, linewidths=.1, spines=True,
            vmin=0.0, vmax=1,
        )  # if the size of dot in legend is too large, use alpha to control, for example: alpha=0.8

        plt.savefig(heatmap_dir / f"heatmap_ccc_expr_{cti_}_{ref_panel=}_{ref_segmentation=}.png", bbox_inches='tight', dpi=300)
        plt.show()
        plt.close()