In [1]:
import dask

dask.config.set({"dataframe.query-planning": False})

import ot
import numpy as np
from pathlib import Path
import pandas as pd
import scanpy as sc
import scipy
import seaborn as sns
import matplotlib.pyplot as plt
from itertools import product
from tqdm.auto import tqdm
from scipy.spatial.distance import cdist, pdist, squareform
from matplotlib.patches import Patch
import fastcluster

import sys
sys.path.extend(['../../scripts','../../scripts/xenium'])
import readwrite
import preprocessing

cfg = readwrite.config()

def compute_energy_distance(
    adata,
    label_key: str,
    batch_key: str,
    use_rep: str = 'X_pca',
    n_subsample = None,
    random_state: int = 0
) -> pd.DataFrame:
    """
    (Robust Version) Computes the Energy Distance matrix by building a NumPy
    array first and then converting it to a pandas DataFrame.

    Args:
        adata: The AnnData object, expected to have PCA computed in .obsm.
        label_key: The key in adata.obs for cell type labels.
        batch_key: The key in adata.obs for batch/patient labels.
        use_rep: The representation to use from adata.obsm (e.g., 'X_pca').
        n_subsample: If not None, randomly subsamples this many cells from each
                     group before calculating the distance.
        random_state: Seed for the random subsampling for reproducibility.

    Returns:
        A square pandas DataFrame with a multi-index (cell_type, batch)
        containing the pairwise Energy Distances.
    """
    rng = np.random.default_rng(random_state)
    unique_labels = adata.obs[label_key].unique()
    unique_batches = adata.obs[batch_key].unique()
    combinations = list(product(unique_labels, unique_batches))

    combo_data = {}
    for label, batch in combinations:
        mask = (adata.obs[label_key] == label) & (adata.obs[batch_key] == batch)
        n_cells = np.sum(mask)
        if n_cells == 0:
            continue
        if n_subsample is not None and n_cells > n_subsample:
            full_data = adata[mask].obsm[use_rep]
            indices = rng.choice(n_cells, n_subsample, replace=False)
            combo_data[(label, batch)] = full_data[indices, :]
        else:
            combo_data[(label, batch)] = adata[mask].obsm[use_rep]

    combo_names = list(combo_data.keys())
    n_combos = len(combo_names)

    # --- THE ROBUST PATTERN: Use a NumPy array first ---
    dist_array = np.zeros((n_combos, n_combos), dtype=float)

    pbar_total = (n_combos**2 - n_combos) // 2
    pbar = tqdm(total=pbar_total, desc="Calculating Energy Distance")
    for i in range(n_combos):
        for j in range(i, n_combos):
            if i == j:
                continue # Distance is already 0.0
            
            name_i, name_j = combo_names[i], combo_names[j]
            X = combo_data[name_i]
            Y = combo_data[name_j]
            
            between_dist = cdist(X, Y).mean()
            within_dist_X = pdist(X).mean() if len(X) > 1 else 0
            within_dist_Y = pdist(Y).mean() if len(Y) > 1 else 0
            
            squared_edist = 2 * between_dist - within_dist_X - within_dist_Y
            dist = np.sqrt(max(0, squared_edist))
            pbar.update(1)
            
            # Use simple, unambiguous integer indexing on the NumPy array
            dist_array[i, j] = dist
            dist_array[j, i] = dist
            
    pbar.close()
    
    multi_index = pd.MultiIndex.from_tuples(combo_names, names=[label_key, batch_key])
    dist_matrix = pd.DataFrame(dist_array, index=multi_index, columns=multi_index)
    
    return dist_matrix

def compute_emd(
    adata,
    label_key: str,
    batch_key: str,
    use_rep: str = 'X_pca',
    n_subsample = None,
    random_state: int = 0
) -> pd.DataFrame:
    """  Computes the EMD matrix. """
    rng = np.random.default_rng(random_state)

    unique_labels = adata.obs[label_key].unique()
    unique_batches = adata.obs[batch_key].unique()
    combinations = list(product(unique_labels, unique_batches))

    combo_data = {}
    for label, batch in combinations:
        mask = (adata.obs[label_key] == label) & (adata.obs[batch_key] == batch)
        n_cells = np.sum(mask)
        if n_cells == 0:
            continue
        if n_subsample is not None and n_cells > n_subsample:
            full_data = adata[mask].obsm[use_rep]
            indices = rng.choice(n_cells, n_subsample, replace=False)
            combo_data[(label, batch)] = full_data[indices, :]
        else:
            combo_data[(label, batch)] = adata[mask].obsm[use_rep]

    combo_names = list(combo_data.keys())
    n_combos = len(combo_names)

    # --- THE ROBUST PATTERN: Use a NumPy array first ---
    dist_array = np.zeros((n_combos, n_combos), dtype=float)

    pbar = tqdm(total=(n_combos**2 - n_combos) // 2, desc="Calculating EMD")
    for i in range(n_combos):
        for j in range(i, n_combos):
            if i == j:
                continue # Distance is already 0
            
            name_i, name_j = combo_names[i], combo_names[j]
            data_i, data_j = combo_data[name_i], combo_data[name_j]
            
            weights_i = np.ones(len(data_i)) / len(data_i)
            weights_j = np.ones(len(data_j)) / len(data_j)
            
            cost_matrix = ot.dist(data_i, data_j, metric='euclidean')
            dist = ot.emd2(weights_i, weights_j, cost_matrix)
            pbar.update(1)
            
            # Use integer indices .iloc style on the NumPy array
            dist_array[i, j] = dist
            dist_array[j, i] = dist
            
    pbar.close()
    
    # --- Convert to a DataFrame at the very end ---
    multi_index = pd.MultiIndex.from_tuples(combo_names, names=[label_key, batch_key])
    dist_matrix = pd.DataFrame(dist_array, index=multi_index, columns=multi_index)
    
    return dist_matrix



def compute_euclidean_distance(
    adata,
    label_key: str,
    batch_key: str,
    use_rep: str = 'X_pca',
    n_subsample = None,
    random_state: int = 0
) -> pd.DataFrame:
    """
    Computes a Euclidean distance matrix between all cells, returning it in the
    same multi-index format as the distributional distance functions.

    The resulting DataFrame's index and columns will be a MultiIndex based on
    the provided label_key and batch_key from adata.obs.

    Args:
        adata: The AnnData object.
        label_key: The key in adata.obs for the first level of the MultiIndex (e.g., 'cell_type').
        batch_key: The key in adata.obs for the second level of the MultiIndex (e.g., 'sample').
        use_rep: The representation to use from adata.obsm (e.g., 'X_pca').

    Returns:
        A square pandas DataFrame where the index and columns are a MultiIndex
        of (label_key, batch_key) for each cell, and the values are the
        pairwise Euclidean distances.
    """
    
    # --- 1. Input Validation ---
    if use_rep not in adata.obsm:
        raise KeyError(f"Representation '{use_rep}' not found in adata.obsm. Please compute it first.")
    if label_key not in adata.obs.columns:
        raise KeyError(f"label_key '{label_key}' not found in adata.obs.")
    if batch_key not in adata.obs.columns:
        raise KeyError(f"batch_key '{batch_key}' not found in adata.obs.")
        

    # --- 2. Core Distance Calculation (same as before) ---
    X = adata.obsm[use_rep]
    annot = adata.obs[[label_key, batch_key]]

    n_cells = X.shape[0]

    rng = np.random.default_rng(random_state)

    if n_subsample is not None and n_cells > n_subsample:
        indices = rng.choice(n_cells, n_subsample, replace=False)
        X = X[indices, :]
        annot = annot.iloc[indices]

    condensed_distances = pdist(X, metric='euclidean')
    distance_matrix_np = squareform(condensed_distances)

    # --- 3. Create the MultiIndex for Labeling ---
    # This is the key change to match the desired format.
    # We create a MultiIndex directly from the columns of the .obs DataFrame.
    # The index will have the same length as the number of cells.
    multi_index = pd.MultiIndex.from_frame(annot)

    # --- 4. Wrap the result in a Labeled DataFrame ---
    distance_matrix_df = pd.DataFrame(
        distance_matrix_np,
        index=multi_index,
        columns=multi_index
    )
    
    
    return distance_matrix_df

def plot_annotated_heatmap(
    dist_matrix: pd.DataFrame,
    label_palette: dict = None,
    batch_palette: dict = None,
    linkage_method: str = 'average',
    save_path: str = None,
    show_label_legend: bool = True,
    title = None,
    show=False,
):
    """
    (From Scratch Legends) Plots a clustered heatmap with manually created and
    placed legends for maximum control and clarity.
    """

    # --- 1. Data and Palette Preparation (same as before) ---
    label_key, batch_key = dist_matrix.index.names
    labels = dist_matrix.index.get_level_values(label_key)
    batches = dist_matrix.index.get_level_values(batch_key)
    
    annot_df = pd.DataFrame({label_key: labels, batch_key: batches}, index=dist_matrix.index)
    
    if label_palette is None:
        unique_labels = annot_df[label_key].unique()
        label_palette = dict(zip(unique_labels, sns.color_palette("tab10", len(unique_labels))))
    
    if batch_palette is None:
        unique_batches = annot_df[batch_key].unique()
        batch_palette = dict(zip(unique_batches, sns.color_palette("Pastel1", len(unique_batches))))
        
    row_colors_df = annot_df.copy()
    row_colors_df[label_key] = annot_df[label_key].map(label_palette)
    row_colors_df[batch_key] = annot_df[batch_key].map(batch_palette)

    # linkage
    condensed_dist_matrix = squareform(dist_matrix.values)
    linkage = fastcluster.linkage(condensed_dist_matrix, method=linkage_method)
    
    # --- 2. Plot the Clustermap (without worrying about its legends) ---
    g = sns.clustermap(
        dist_matrix,
        cmap='viridis_r',
        row_colors=row_colors_df,
        col_colors=row_colors_df,
        row_linkage=linkage,
        col_linkage=linkage,
        # linewidths=.5,
        figsize=(12, 12),
        xticklabels=False,
        yticklabels=False,
        colors_ratio=0.05,
        dendrogram_ratio=(.1, .1)
    )
    g.ax_heatmap.set_xlabel("")
    g.ax_heatmap.set_ylabel("")
    # --- 3. Create Legends "From Scratch" ---

    # Create the graphical "handles" for the label legend
    if show_label_legend:
        label_legend_patches = [
            Patch(facecolor=color, edgecolor='black', label=label) 
            for label, color in label_palette.items() if label in labels
        ]
        
        # Place the first legend on the figure
        legend1 = g.fig.legend(
            handles=label_legend_patches,
            title=label_key,
            loc="upper right",
            bbox_to_anchor=(1.2, 0.5), # Position: 2% from left, 95% from bottom
            fontsize='medium',
            frameon=True,
            shadow=False,
        )

    # Create the graphical "handles" for the batch legend
    batch_legend_patches = [
        Patch(facecolor=color, edgecolor='black', label=label) 
        for label, color in batch_palette.items() if label in batches
    ]

    # Place the second legend on the figure
    legend2 = g.fig.legend(
        handles=batch_legend_patches,
        title=batch_key,
        loc="upper right",
        bbox_to_anchor=(1.2, 0.8), # Position: 98% from left, 95% from bottom
        fontsize='medium',
        frameon=True,
        shadow=False,
    )
    
    # --- 4. Final Adjustments ---
    
    # Set the main title for the entire figure
    if title:
        g.fig.suptitle(title, y=1.05, fontsize=16, weight='bold')
    
    if save_path is not None:
        p = Path(save_path)
        p.parent.mkdir(parents=True, exist_ok=True)
        plt.savefig(p, dpi=300, bbox_inches='tight')
    if show:
        plt.show()
    else:
        plt.close()

  from .autonotebook import tqdm as notebook_tqdm
  left = partial(_left_join_spatialelement_table)
  left_exclusive = partial(_left_exclusive_join_spatialelement_table)
  inner = partial(_inner_join_spatialelement_table)
  right = partial(_right_join_spatialelement_table)
  right_exclusive = partial(_right_exclusive_join_spatialelement_table)


## Params

In [2]:
# params
cell_type_annotation_dir = Path(cfg['xenium_cell_type_annotation_dir'])
xenium_processed_data_dir = Path(cfg['xenium_processed_data_dir'])
xenium_std_seurat_analysis_dir = Path(cfg['xenium_std_seurat_analysis_dir'])
normalisation = 'lognorm'
layer = 'data'
reference = 'matched_reference_combo'
method = 'rctd_class_aware'
level = 'Level2.1'
n_comps = 50
max_n_cells = 100_000
singlets = False

# qc params
min_counts = 10
min_features = 5
max_counts = float("inf")
max_features = float("inf")
min_cells = 5

# common genes and samples to use
genes = []
samples = []

# fixed params
OBSM_KEY = "X_pca"
CT_KEY = 'cell type'#(reference, method, level)
BATCH_KEY = "sample"
annotation_normalisation = "lognorm"  # fix this for now, even for sctransfrom

palette_dir = Path(cfg['xenium_metadata_dir'])
sample_palette = pd.read_csv(palette_dir / 'col_palette_sample.csv').set_index('sample').squeeze()
cell_type_palette_path = palette_dir / 'col_palette_cell_types_combo.csv'


if level == "Level2.1":
    palette_lvl2 = (
        pd.read_csv(cell_type_palette_path)[["Level2", "cols_Level2"]].drop_duplicates().set_index("Level2").squeeze()
    )
    cell_type_palette = pd.read_csv(cell_type_palette_path)[[level, f"cols_{level}"]].drop_duplicates().set_index(level).squeeze()
    for k, v in palette_lvl2.items():
        if k not in cell_type_palette.index:
            cell_type_palette[k] = palette_lvl2[k]

else:
    cell_type_palette = pd.read_csv(cell_type_palette_path)[[level, f"cols_{level}"]].drop_duplicates().set_index(level).squeeze()


## Compute metrics

In [3]:
correction_method = 'raw'

for condition, segmentation, panel_name in [
    ('NSCLC', '10x_5um', 'chuvio'),
    # ('NSCLC', '10x_5um', 'lung'),
    # ('NSCLC', '10x_5um', '5k'),
    # ('breast', '10x_5um', 'breast'), 
    ]:

    panel = xenium_std_seurat_analysis_dir / f"{segmentation}/{condition}/{panel_name}"

    # read xenium samples
    print(f"Processing {condition} {segmentation} {panel_name}")
    ads = {}
    for donor in (donors := panel.iterdir()):
        for sample in (samples_ := donor.iterdir()):
            if len(samples) and sample.stem not in samples:
                continue

                print(donor.stem, sample.stem)

            if segmentation == "proseg_expected":
                k = ("proseg", condition, panel.stem, donor.stem, sample.stem)
                k_annot = (segmentation, condition, panel.stem, donor.stem, sample.stem)

                name_sample = "/".join(k)
                name_annot = "/".join(k_annot)
                sample_dir = xenium_processed_data_dir / f"{name_sample}/raw_results"
            else:
                k = (segmentation.replace("proseg_mode", "proseg"), condition, panel.stem, donor.stem, sample.stem)
                k_annot = (segmentation, condition, panel.stem, donor.stem, sample.stem)
                name_sample = "/".join(k)
                name_annot = "/".join(k_annot)
                sample_dir = xenium_processed_data_dir / f"{name_sample}/normalised_results/outs"

            sample_normalised_counts_path = sample / f"{normalisation}/normalised_counts/{layer}.parquet"
            sample_idx_path = sample / f"{normalisation}/normalised_counts/cells.parquet"

            # read normalised data
            X_normalised = pd.read_parquet(sample_normalised_counts_path)
            X_normalised.index = pd.read_parquet(sample_idx_path).iloc[:, 0]
            X_normalised.columns = X_normalised.columns.str.replace(".", "-")  # undo seurat renaming

            if len(genes):
                # load raw data to reapply lower bounds QC filters
                ads[k] = readwrite.read_xenium_sample(sample_dir, anndata=True)
                if segmentation == "proseg_expected":
                    ads[k].obs_names = "proseg-" + ads[k].obs_names.astype(str)

                # filter cells
                ads[k] = ads[k][X_normalised.index, X_normalised.columns]
                ads[k].layers["X_normalised"] = X_normalised
                if layer != "scale_data":  # no need to sparsify scale_data which is dense
                    ads[k].layers["X_normalised"] = scipy.sparse.csr_matrix(ads[k].layers["X_normalised"])
            else:
                ads[k] = sc.AnnData(X_normalised)
                if layer != "scale_data":  # no need to sparsify scale_data which is dense
                    ads[k].X = scipy.sparse.csr_matrix(ads[k].X)

            # read cell type annotation
            sample_annotation_dir = cell_type_annotation_dir / f"{name_annot}/{annotation_normalisation}/reference_based"
            annot_file = sample_annotation_dir / f"{reference}/{method}/{level}/single_cell/labels.parquet"
            ads[k].obs[CT_KEY] = pd.read_parquet(annot_file).set_index("cell_id").iloc[:, 0]

            if singlets:
                # read spot class
                spot_class_file = (
                    sample_annotation_dir / f"{reference}/{method}/{level}/single_cell/output/results_df.parquet"
                )

                ads[k].obs["spot_class"] = pd.read_parquet(spot_class_file, columns=["cell_id", "spot_class"]).set_index(
                    "cell_id"
                )
                ads[k] = ads[k][ads[k].obs["spot_class"] == "singlet"]


    print("Concatenating")
    # concatenate
    xenium_levels = ["segmentation", "condition", "panel", "donor", "sample"]
    for k in ads.keys():
        for i, lvl in enumerate(xenium_levels):
            ads[k].obs[lvl] = k[i]
    ad_merge = sc.concat(ads)
    print("Done")

    # subset genes
    if len(genes):
        print("Subsetting")

        genes_found = [
            g
            for g in ad_merge.var_names
            if (g in genes) or (g.replace(".", "-") in genes)  # possible seurat renaming
        ]

        print(f"Found {len(genes_found)} out of {len(genes)} genes.")
        ad_merge = ad_merge[:, genes_found].copy()
        # reapply QC to subset of genes
        preprocessing.preprocess(
            ad_merge,
            min_counts=min_counts,
            min_genes=min_features,
            max_counts=max_counts,
            max_genes=max_features,
            min_cells=min_cells,
            save_raw=False,
        )
        # replace X
        ad_merge.X = ad_merge.layers["X_normalised"]

    # remove NaN  annotations
    ad_merge = ad_merge[ad_merge.obs[CT_KEY].notna()]

    #simplify malignant annot
    # if condition in ["NSCLC","mesothelioma_pilot"]:
    #     name_malignant = "malignant cell of lung"
    # elif condition == "breast":
    #     name_malignant = "malignant cell of breast"
    # else:
    #     name_malignant = "malignant cell"

    # ct_to_replace = ad_merge.obs[CT_KEY][ad_merge.obs[CT_KEY].str.contains("malignant cell")].unique()
    # replace_map = dict([[ct, name_malignant] for ct in ct_to_replace])
    # ad_merge.obs[CT_KEY] = ad_merge.obs[CT_KEY].replace(replace_map)

    # subsample to reasonable size
    # if len(ad_merge) > max_n_cells:
    #     sc.pp.subsample(ad_merge, n_obs=max_n_cells)

    # compute pca
    sc.tl.pca(ad_merge, n_comps=n_comps)

    D = compute_energy_distance(
        ad_merge,
        label_key=CT_KEY,
        batch_key='sample',
        use_rep="X_pca",
        n_subsample=1000
    )
    plot_annotated_heatmap(D,label_palette=cell_type_palette,batch_palette=sample_palette,
        save_path=cfg['figures_dir']+f'revision/{correction_method}/{segmentation}/{condition}/{panel.stem}/edistance_heatmap_{level}.png')


    D = compute_euclidean_distance(
        ad_merge,
        label_key=CT_KEY,
        batch_key='sample',
        use_rep="X_pca",
        n_subsample=10000
    )
    plot_annotated_heatmap(D,label_palette=cell_type_palette,batch_palette=sample_palette,
        save_path=cfg['figures_dir']+f'revision/{correction_method}/{segmentation}/{condition}/{panel.stem}/euclidean_distance_heatmap_{level}.png')


    for ct in ad_merge.obs[CT_KEY].unique():
        if 'malignant' in ct:
            continue
        D_ct = compute_euclidean_distance(
            ad_merge[ad_merge.obs[CT_KEY]==ct],
            label_key=CT_KEY,
            batch_key='sample',
            use_rep="X_pca",
            n_subsample=5000,
        )
        plot_annotated_heatmap(D_ct,label_palette=cell_type_palette,batch_palette=sample_palette,
            save_path=cfg['figures_dir']+f'revision/{correction_method}/{segmentation}/{condition}/{panel.stem}/euclidean_heatmap_{ct}_{level}.png',
            show_label_legend=False,
            title=ct
        )


    D_ct = compute_euclidean_distance(
        ad_merge[ad_merge.obs[CT_KEY].str.contains("malignant")],
        label_key=CT_KEY,
        batch_key='sample',
        use_rep="X_pca",
        n_subsample=5000,
    )
    plot_annotated_heatmap(D_ct,label_palette=cell_type_palette,batch_palette=sample_palette,
        save_path=cfg['figures_dir']+f'revision/{correction_method}/{segmentation}/{condition}/{panel.stem}/euclidean_heatmap_malignant_{level}.png',
        show_label_legend=False,
        title = "malignant cell"

    )

Processing NSCLC 10x_5um chuvio
Concatenating


  utils.warn_names_duplicates("obs")


Done


  adata.obsm[key_obsm] = X_pca
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
Calculating Energy Distance: 100%|██████████| 4005/4005 [01:40<00:00, 39.76it/s] 
