In [2]:
import dask
dask.config.set({"dataframe.query-planning": False})

import scanpy as sc
import scipy
import numpy as np
import pandas as pd
import sys
import matplotlib.patches as mpatches
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path

sys.path.append("../../../workflow/scripts/")
import _utils
import readwrite
cfg = readwrite.config()

## Compute metrics diffexpr logreg

In [None]:
segmentation = '10x_mm_5um'
condition = 'NSCLC'
panel = '5k'
donor = '1GQ9'
sample = '1GQ9'
normalisation = 'lognorm'
layer = 'data'
reference = 'matched_reference_combo'
method = 'rctd_class_aware'
level = 'Level2.1'

k = (segmentation,condition,panel,donor,sample)
if 'proseg' in segmentation:
    k_dir = ('proseg',condition,panel,donor,sample)
else:
    k_dir = k
name = '/'.join(k)
name_dir = '/'.join(k_dir)
sample_corrected_counts_path = Path(f"../../../results/resolvi_supervised/{name}/lognorm/reference_based/matched_reference_combo/rctd_class_aware/Level2.1/mixture_k=50/num_samples=30/corrected_counts.h5")
sample_dir = Path(f'../../../data/xenium/processed/segmentation/{name_dir}') / 'normalised_results/outs' #'raw_results'
sample_counts = Path(f'../../../data/xenium/processed/std_seurat_analysis/{name}/lognorm/normalised_counts/data.parquet')
sample_idx = Path(f'../../../data/xenium/processed/std_seurat_analysis/{name}/lognorm/normalised_counts/cells.parquet')
sample_annotation = Path(f'../../../data/xenium/_backups/_processed_problematic_matched_combo/cell_type_annotation/{name}/lognorm/reference_based/matched_reference_combo/rctd_class_aware/{level}/single_cell/labels.parquet')
# ctj_markers_file = Path(f'../../../results/contamination_metrics_diffexpr/{name}/lognorm/{layer}_{reference}_{method}_{level}_marker_genes.parquet')
ctj_markers_file = None
# precomputed_adata_obs = adata.obs.copy()
precomputed_adata_obs = None

out_file_df_permutations = sample_dir / 'permutation_summary.parquet'
out_file_df_importances = sample_dir / 'importances.parquet'
out_file_df_diffexpr = sample_dir / 'diffexpr.parquet'
out_file_df_markers_rank_significance_logreg = sample_dir / 'markers_rank_significance_logreg.json'
out_file_df_markers_rank_significance_diffexpr = sample_dir / 'markers_rank_significance_diffexpr.json'
# out_dir_liana_lrdata = sample_dir / 'liana_lrdata_folder'

n_markers = 50
pct_expression_threshold = 0.05
radius = 20
n_neighbors = None
n_permutations = 30
n_splits= 5
top_n = 20
top_n_lr = 10
cti = "macrophage"
ctj = "malignant cell"
scoring = 'f1'
markers = 'diffexpr'
# markers = "xenium_common_markers_file"
rank_metrics = ["logfoldchanges", "-log10pvals_x_logfoldchanges", "-log10pvals_x_sign_logfoldchanges"]
index_diffexpr_metrics=["Name","Term","ES","NES","NOM p-val","FDR q-val","FWER p-val","Tag %","Gene %","Lead_genes","hypergeometric_pvalue","mean_zscore","mean_zscore_pvalue"]
label_key = "label_key"

####
#### READ DATA
####
# read raw data to get spatial coordinates
adata = readwrite.read_xenium_sample(
    sample_dir,
    cells_as_circles=False,
    cells_boundaries=False,
    cells_boundaries_layers=False,
    nucleus_boundaries=False,
    cells_labels=False,
    nucleus_labels=False,
    transcripts=False,
    morphology_mip=False,
    morphology_focus=False,
    aligned_images=False,
    anndata=True,
)
if 'proseg_expected' in sample_counts.as_posix():
    adata.obs_names = 'proseg-'+adata.obs_names.astype(str)

# read corrected counts
if sample_corrected_counts_path is not None:
    adata_corrected_counts = sc.read_10x_h5(
        sample_corrected_counts_path,
    )

    adata_corrected_counts.obsm["spatial"] = adata[adata_corrected_counts.obs_names].obsm["spatial"]
    adata = adata_corrected_counts


# read normalised data, filter cells
X_normalised = pd.read_parquet(sample_counts)
X_normalised.index = pd.read_parquet(sample_idx).iloc[:, 0]
X_normalised.columns = X_normalised.columns.str.replace('.','-')
obs_found = [c for c in X_normalised.index if c in adata.obs_names]
var_found = [g for g in X_normalised.columns if g in adata.var_names]
adata = adata[obs_found,var_found]
adata.layers['X_normalised'] = X_normalised.loc[obs_found,var_found]

# # reapply QC to corrected counts data
# preprocessing.preprocess(
#     adata,
#     min_counts=min_counts,
#     min_genes=min_features,
#     max_counts=max_counts,
#     max_genes=max_features,
#     min_cells=min_cells,
#     save_raw=False,
# )

# read labels
adata.obs[label_key] = pd.read_parquet(sample_annotation).set_index("cell_id").iloc[:, 0]
adata = adata[adata.obs[label_key].notna()]


if "Level2.1" in sample_annotation.as_posix():
    # for custom Level2.1, simplify malignant subtypes to malignant
    adata.obs.loc[adata.obs[label_key].str.contains("malignant"), label_key] = "malignant cell"
    adata.obs.loc[adata.obs[label_key].str.contains("T cell"), label_key] = "T cell"

# read markers if needed
if markers != "diffexpr":
    if markers == "xenium_common_markers_file":
        level_simplified = 'Level1'
        palette = pd.read_csv('/work/PRTNR/CHUV/DIR/rgottar1/spatial/env/xenium_paper/data/xenium/metadata/col_palette_cell_types_combo.csv')
        cell_types_mapping = palette.set_index(level)[level_simplified].replace(r' of .+', '', regex=True)
        cell_types_mapping[cell_types_mapping.str.contains('malignant')] = 'malignant cell'
        adata.obs[label_key] = adata.obs[label_key].replace(cell_types_mapping)
        df_markers = pd.read_csv('/work/PRTNR/CHUV/DIR/rgottar1/spatial/env/xenium_paper/data/markers/Xenium_panels_common_markers.csv')[["cell_type","gene_name"]]
    else:
        df_markers = pd.read_csv(markers)[["cell_type","gene_name"]]

    ct_not_found = adata.obs[label_key][~adata.obs[label_key].isin(df_markers['cell_type'])].unique()
    print(f"Could not find {ct_not_found} in markers file")
    adata = adata[adata.obs[label_key].isin(df_markers['cell_type'])]
else:
    # get precomputed markers from raw data
    if ctj_markers_file is not None:
        print(f"Loading precomputed {ctj} markers")
        df_ctj_marker_genes_precomputed = pd.read_parquet(ctj_markers_file)

# define target (cell type j presence in kNN)
if precomputed_adata_obs is not None:
    print("Loading precomputed adata obs")
    # adata.obs = pd.read_parquet(precomputed_adata_obs)
    adata.obs = precomputed_adata_obs
else:
    # get kNN graph
    obsm = "spatial"
    knnlabels, knndis, knnidx, knn_graph = _utils.get_knn_labels(
        adata, n_neighbors=n_neighbors, radius=radius, label_key=label_key, obsm=obsm, return_sparse_neighbors=True
    )

    adata.obsp[f"{obsm}_connectivities"] = knn_graph

# summary stats
adata.obs["n_genes"] = (adata.X > 0).sum(axis=1).A1
adata.obs["n_counts"] = (adata.X).sum(axis=1).A1.astype(float)
df_percent_expressed = _utils.get_expression_percent_per_celltype(adata=adata, label_key=label_key)

# log-normalize before DE
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)


# iterate over targets permutations (cell type i with cell type j presence in kNN)
df_diffexpr = {}
df_markers_rank_significance_diffexpr = {}
df_markers_rank_significance_diffexpr_expressed = {}
u_cell_types = adata.obs[label_key].unique()
df_ctj_marker_genes = {}

for ctj in u_cell_types:
    if (adata.obs[label_key] == ctj).sum() < 30:
        print(f"Not enough cells from class {ctj}")
        continue

    # get markers
    if markers == "diffexpr":
        if ctj_markers_file is not None:
            print(f"Loading precomputed {ctj} markers")
            ctj_marker_genes_precomputed = df_ctj_marker_genes_precomputed[ctj]

        sc.tl.rank_genes_groups(adata, groupby=label_key, groups=[ctj], reference="rest", method="wilcoxon")
        ctj_marker_genes = sc.get.rank_genes_groups_df(adata, group=ctj)["names"][: 50].tolist()
    else:
        ctj_marker_genes = df_markers[df_markers["cell_type"] == ctj]["gene"].tolist()
        ctj_marker_genes = [g for g in ctj_marker_genes if g in adata.var_names]

        assert len(ctj_marker_genes), f"no markers found for {ctj}"

    df_ctj_marker_genes[ctj] = ctj_marker_genes

    for cti in u_cell_types:
        if cti == ctj:
            continue
        print(cti, ctj)
        
        if precomputed_adata_obs is None:
            adata.obs[f"has_{ctj}_neighbor"] = knnlabels[ctj] > 0

        # Filter for cti
        adata_cti = adata[adata.obs[label_key] == cti]

        if (adata_cti.obs[f"has_{ctj}_neighbor"]).sum() < 30 or (~adata_cti.obs[f"has_{ctj}_neighbor"]).sum() < 30:
            print(f"Not enough cells from each class to test {cti} with {ctj} neighbors")
            continue

        expressed_genes_cti = df_percent_expressed[df_percent_expressed[cti] > pct_expression_threshold].index

        ###
        ### DIFF EXPR TEST: check DE genes between cti with ctj neighbor or not
        ###

        adata_cti.obs[f"has_{ctj}_neighbor_str"] = adata_cti.obs[f"has_{ctj}_neighbor"].astype(str)
        sc.tl.rank_genes_groups(
            adata_cti, groupby=f"has_{ctj}_neighbor_str", groups=["True"], reference="False", method="wilcoxon"
        )
        df_diffexpr[cti, ctj] = sc.get.rank_genes_groups_df(adata_cti, group="True")

        # add ranking score by -log10pvals_x_logfoldchanges. Use offset to avoid -log10(pval) = inf
        pvals = df_diffexpr[cti, ctj]["pvals"]
        df_diffexpr[cti, ctj]["pvals_offset"] = pvals + pvals[pvals > 0].min() * 0.1
        df_diffexpr[cti, ctj]["-log10pvals_x_logfoldchanges"] = (
            -np.log10(df_diffexpr[cti, ctj]["pvals_offset"]) * df_diffexpr[cti, ctj]["logfoldchanges"]
        )

        # add ranking score by log10pvals_x_signlogFC
        df_diffexpr[cti, ctj]["-log10pvals_x_sign_logfoldchanges"] = -np.log10(df_diffexpr[cti, ctj]["pvals"]) * np.sign(
            df_diffexpr[cti, ctj]["logfoldchanges"]
        )

        # get significance from gsea and hypergeometric test
        df_markers_rank_significance_diffexpr[cti, ctj] = pd.DataFrame(index=index_diffexpr_metrics)
        df_markers_rank_significance_diffexpr_expressed[cti, ctj] = pd.DataFrame(index=index_diffexpr_metrics)
        dict_ctj_marker_genes = {'':ctj_marker_genes}

        if ctj_markers_file is not None:
            # also compute scores for precomputed marker gene list
            dict_ctj_marker_genes['_precomputed'] = (ctj_marker_genes_precomputed)

        for n in [10,20,30,40,50]:
            for k_, markers_ in dict_ctj_marker_genes.items():
                markers_n_ = markers_[:n]
                for rank_metric in rank_metrics:
                    df_markers_rank_significance_diffexpr[cti, ctj][rank_metric + k_ + f'_{n=}'] = (
                        _utils.get_marker_rank_significance(
                            rnk=df_diffexpr[cti, ctj].set_index("names")[rank_metric].sort_values(ascending=False),
                            gene_set=markers_n_,
                            top_n=top_n,
                        ).iloc[0]
                    )


# count number of True/False for each has_{ctj}_neighbor column
cols = [f"has_{ctj}_neighbor" for ctj in u_cell_types]
df_has_neighbor_counts = (
    adata.obs.melt(id_vars=[label_key], value_vars=cols).groupby([label_key, 'variable'],observed=True)['value']
    .value_counts()
    .reset_index(name='count')
)

adata.obs['n_genes'] = (adata.X>0).sum(axis=1).A1

###
### SAVE OUTPUTS
###
# general stats
summary_stats = {
    "n_cells": len(adata),
    "n_cells_by_type": adata.obs[label_key].value_counts().to_dict(),
    "mean_n_genes_by_type": adata.obs.groupby(label_key,observed=True)["n_genes"].mean().to_dict(),
    "median_n_genes_by_type": adata.obs.groupby(label_key,observed=True)["n_genes"].median().to_dict(),
    "mean_n_genes": adata.obs["n_genes"].mean(),
    "median_n_genes": adata.obs["n_genes"].median(),
    "df_has_neighbor_counts": df_has_neighbor_counts.to_dict(),  # Storing the DataFrame
}

# with open("summary_stats.json", "w") as f:
#     json.dump(summary_stats,f)

# df_permutations_logreg = pd.concat(df_permutations_logreg)
# df_importances_logreg = pd.concat(df_importances_logreg)
# df_diffexpr = pd.concat(df_diffexpr)
# df_markers_rank_significance_logreg = pd.concat(df_markers_rank_significance_logreg)
# df_markers_rank_significance_diffexpr = pd.concat(df_markers_rank_significance_diffexpr)

#logreg
# df_permutations.to_parquet(out_file_df_permutations)
# df_importances.to_parquet(out_file_df_importances)
# df_markers_rank_significance_logreg.to_parquet(out_file_df_markers_rank_significance_logreg)

# #diffexpr
# df_diffexpr.to_parquet(out_file_df_diffexpr)
# df_markers_rank_significance_diffexpr.to_parquet(out_file_df_markers_rank_significance_diffexpr)

#liana
# readwrite.write_anndata_folder(lrdata, outm_dir_liana_lrdata)

[34mINFO    [0m reading                                                                                                   
         ..[35m/../../data/xenium/processed/segmentation/10x_mm_5um/NSCLC/5k/1GQ9/1GQ9/normalised_results/outs/[0m[95mcell_fea[0m
         [95mture_matrix.h5[0m                                                                                            


  self.validate_table_in_spatialdata(v)
  adata.layers['X_normalised'] = X_normalised.loc[obs_found,var_found]
  adata.obsp[f"{obsm}_connectivities"] = knn_graph


Calculation complete.




monocyte macrophage


  adata_cti.obs[f"has_{ctj}_neighbor_str"] = adata_cti.obs[f"has_{ctj}_neighbor"].astype(str)
The order of those genes will be arbitrary, which may produce unexpected results.
The order of those genes will be arbitrary, which may produce unexpected results.
The order of those genes will be arbitrary, which may produce unexpected results.
The order of those genes will be arbitrary, which may produce unexpected results.


KeyboardInterrupt: 

In [62]:
import sklearn
import numpy as np
import pandas as pd
import scipy
import gseapy
import anndata
from sklearn.model_selection import train_test_split, permutation_test_score, StratifiedKFold
from sklearn.linear_model import LogisticRegression
from sklearn.inspection import permutation_importance
from sklearn.preprocessing import StandardScaler
from typing import Dict, List



def logreg(
    X,
    y,
    feature_names=None,
    scoring="precision",
    test_size=0.2,
    n_splits=5,
    n_permutations=30,
    n_repeats=5,
    max_iter=100,
    random_state=0,
    importance_mode="coef",
    class_weight="balanced",
    cv_mode="spatial",
    spatial_coords=None,

):
    """
    Perform logistic regression with permutation test and compute feature importances.

    Parameters:
    - X (array-like): Input data for model training/testing.
    - y (vector-like): Input vector of labels for model training/testing.
    - feature_names (vector-like): Names of X features (optional).
    - scoring (str): Scoring metric for the permutation test (e.g., 'f1', 'accuracy').
    - test_size (float): Proportion of data to use for testing.
    - max_iter (int): Maximum number of iterations for the logistic regression model.
    - n_splits (int): Number of splits for cross-validation.
    - n_permutations (int): Number of permutations for the permutation test.
    - n_repeats (int): Number of repeats for the permutation importance calculation.
    - random_state (int): Random seed for reproducibility.
    - importance_mode (str): Mode for feature importance calculation ('permutation' or 'coef').
    - class_weight (str): Class weight for the logistic regression model ('balanced' or None or dict).
    - cv_mode (str): Cross-validation mode ('stratified' or 'spatial').
    - spatial_coords (array-like): Spatial coordinates for spatial cross-validation.

    Returns:
    - df_permutations (pd.DataFrame): Summary of permutation test results.
    - df_importances (pd.DataFrame): Feature importances from permutation importance.
    """


    if cv_mode == 'stratified':
        cv = StratifiedKFold(n_splits=n_splits, shuffle = False)
    elif cv_mode == 'spatial':
        if spatial_coords is None:
            raise ValueError("spatial_coords must be provided when cv_mode is 'spatial'")
        cv = list(_utils.SpatialClusterGroupKFold(algorithm='bisectingkmeans',n_splits=n_splits).split(spatial_coords, y))

    # Initialize logistic regression model
    model = LogisticRegression(max_iter=max_iter, class_weight=class_weight)

    # Empirical p-value calculation using permutation test
    score, perm_scores, p_value = permutation_test_score(
        model, X, y, scoring=scoring, n_permutations=n_permutations, cv=cv, n_jobs=-1, verbose=1
    )

    # Summarize permutation test results
    df_permutations = pd.DataFrame(
        [[score, perm_scores.mean(), perm_scores.std(), p_value]],
        columns=[f"{scoring}_score", f"perm_mean{scoring}_score", f"perm_std{scoring}_score", "p_value"],
    )
    df_permutations["effect_size"] = (
        df_permutations[f"{scoring}_score"] - df_permutations[f"perm_mean{scoring}_score"]
    ) / df_permutations[f"perm_std{scoring}_score"]

    # Fit the model and compute feature importances from permutations

    if importance_mode == "permutation":
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, stratify=y, random_state=random_state
        )
        model.fit(X_train, y_train)
        importances = permutation_importance(
            model, pd.DataFrame.sparse.from_spmatrix(X_test), y_test, scoring=scoring, n_repeats=n_repeats, n_jobs=-1
        )
        importances.pop("importances")
        df_importances = pd.DataFrame(importances, index=feature_names).sort_values("importances_mean", ascending=False)

    elif importance_mode == "coef":
        model.fit(X, y)
        # Feature importances from model coefs
        # cv_results = cross_validate(model,X,y,return_estimator=True, scoring=scoring, n_jobs=-1)
        # importances = np.std(X, axis=0) * np.vstack([m.coef_[0] for m in cv_results["estimator"]])
        importances = StandardScaler(with_mean=False).fit(X).scale_ * model.coef_[0]
        df_importances = pd.DataFrame(importances, index=feature_names, columns=["importances"]).sort_values(
            "importances", ascending=False
        )

        # coef pvalues from formula
        # df_importances["pvalues"] = logit_pvalue(model, X.toarray())[1:]
    else:
        raise ValueError("Importance mode must be 'permutation' or 'coef'")

    return df_permutations, df_importances

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 6 concurrent workers.
[Parallel(n_jobs=-1)]: Done  19 out of  30 | elapsed:    0.4s remaining:    0.3s
[Parallel(n_jobs=-1)]: Done  30 out of  30 | elapsed:    0.5s finished


## Compute metrics marker purity

In [None]:
segmentation = 'proseg_expected'
condition = 'NSCLC'
panel = 'lung'
donor = '0S8R'
sample = '0S8R'
k = (segmentation,condition,panel,donor,sample)
if 'proseg' in segmentation:
    k_dir = ('proseg',condition,panel,donor,sample)
else:
    k_dir = k
name = '/'.join(k)
name_dir = '/'.join(k_dir)
sample_dir = Path(f'/work/PRTNR/CHUV/DIR/rgottar1/spatial/env/xenium_paper/data/xenium/processed/segmentation/{name_dir}') / 'raw_results'
sample_counts = Path(f'/work/PRTNR/CHUV/DIR/rgottar1/spatial/env/xenium_paper/data/xenium/processed/std_seurat_analysis/{name}/lognorm/normalised_counts/data.parquet')
sample_idx = Path(f'/work/PRTNR/CHUV/DIR/rgottar1/spatial/env/xenium_paper/data/xenium/processed/std_seurat_analysis/{name}/lognorm/normalised_counts/cells.parquet')
cell_type_labels = Path(f'/work/PRTNR/CHUV/DIR/rgottar1/spatial/env/xenium_paper/data/xenium/processed/cell_type_annotation/{name}/lognorm/reference_based/matched_reference_combo/rctd_class_aware/Level2.1/single_cell/labels.parquet')

out_file_df_permutations = sample_dir / 'permutation_summary.parquet'
out_file_df_importances = sample_dir / 'importances.parquet'
out_file_df_diffexpr = sample_dir / 'diffexpr.parquet'
out_file_df_markers_rank_significance_logreg = sample_dir / 'markers_rank_significance_logreg.json'
out_file_df_markers_rank_significance_diffexpr = sample_dir / 'markers_rank_significance_diffexpr.json'
# out_dir_liana_lrdata = sample_dir / 'liana_lrdata_folder'

n_neighbors = 10
n_permutations = 30
n_splits= 5
top_n = 20
top_n_lr = 10
cti = "macrophage"
ctj = "malignant cell"
scoring = 'f1'
markers = '/work/PRTNR/CHUV/DIR/rgottar1/spatial/env/xenium_paper/data/markers/Xenium_panels_common_markers.csv'
# markers = 'diffexpr'


level_simplified = 'Level1'
palette = pd.read_csv('/work/PRTNR/CHUV/DIR/rgottar1/spatial/env/xenium_paper/data/xenium/metadata/col_palette_cell_types_combo.csv')
cell_types_mapping = palette.set_index('Level2.1')[level_simplified]
cell_types_mapping[cell_types_mapping.str.contains('malignant')] = 'malignant cell'

####
#### READ DATA
####
# read raw data to get spatial coordinates
adata = readwrite.read_xenium_sample(
    sample_dir,
    cells_as_circles=False,
    cells_boundaries=False,
    cells_boundaries_layers=False,
    nucleus_boundaries=False,
    cells_labels=False,
    nucleus_labels=False,
    transcripts=False,
    morphology_mip=False,
    morphology_focus=False,
    aligned_images=False,
    anndata=True,
)
if 'proseg_expected' in sample_counts.as_posix():
    adata.obs_names = 'proseg-'+adata.obs_names.astype(str)


# read normalised data, filter cells
X_normalised = pd.read_parquet(sample_counts)
X_normalised.index = pd.read_parquet(sample_idx).iloc[:, 0]
X_normalised.columns = X_normalised.columns.str.replace('.','-')
adata = adata[X_normalised.index,X_normalised.columns]
adata.layers['X_normalised'] = X_normalised

# log-normalize before DE
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

# read labels
label_key = "label_key"
adata.obs[label_key] = pd.read_parquet(cell_type_labels).set_index("cell_id").iloc[:, 0]
adata = adata[adata.obs[label_key].notna()]
adata.obs[label_key] = adata.obs[label_key].replace(cell_types_mapping)

# read markers if needed
if markers != "diffexpr":
    if markers == "xenium_common_markers_file":
        level_simplified = 'Level1'
        palette = pd.read_csv('/work/PRTNR/CHUV/DIR/rgottar1/spatial/env/xenium_paper/data/xenium/metadata/col_palette_cell_types_combo.csv')
        cell_types_mapping = palette.set_index(level)[level_simplified].replace(r' of .+', '', regex=True)
        cell_types_mapping[cell_types_mapping.str.contains('malignant')] = 'malignant cell'
        adata.obs[label_key] = adata.obs[label_key].replace(cell_types_mapping)
        df_markers = pd.read_csv('/work/PRTNR/CHUV/DIR/rgottar1/spatial/env/xenium_paper/data/markers/Xenium_panels_common_markers.csv')[["cell_type","gene_name"]]
    else:
        df_markers = pd.read_csv(markers)[["cell_type","gene_name"]]

    ct_not_found = adata.obs[label_key][~adata.obs[label_key].isin(df_markers['cell_type'])].unique()
    print(f"Could not find {ct_not_found} in markers file")
    adata = adata[adata.obs[label_key].isin(df_markers['cell_type'])]


# get kNN graph
# obsm = 'spatial'
# knnlabels, knndis, knnidx, knn_graph = _utils.get_knn_labels(
#     adata,n_neighbors=n_neighbors,
#     label_key=label_key,obsm=obsm,
#     return_sparse_neighbors=True)
# adata.obsp[f'{obsm}_connectivities'] = knn_graph


# iterate over targets permutations (cell type i with cell type j presence in kNN)
u_cell_types = adata.obs[label_key].unique()
df_ctj_marker_genes = {}


for ctj in u_cell_types:
    if  (adata.obs[label_key]==ctj).sum() < 30:
        print(f"Not enough cells from class {ctj}")
        continue

    # get markers
    if markers == "diffexpr":
        sc.tl.rank_genes_groups(adata, groupby=label_key, groups=[ctj], reference='rest', method="wilcoxon")
        ctj_marker_genes = sc.get.rank_genes_groups_df(adata, group=ctj)['names'][: top_n].tolist()
    else:
        ctj_marker_genes = df_markers[df_markers["cell_type"] == ctj]["gene_name"].tolist()
        ctj_marker_genes = [g for g in ctj_marker_genes if g in adata.var_names]

        if len(ctj_marker_genes) == 0:
            print(f"no markers found for {ctj}")
            continue

    df_ctj_marker_genes[ctj] = ctj_marker_genes

    for cti in u_cell_types:
        if cti == ctj:
            continue
        print(cti, ctj)

        adata.obs[f"has_{ctj}_neighbor"] = knnlabels[ctj]>0

        # Filter for cti
        adata_cti = adata[adata.obs[label_key] == cti]

        if (   (adata_cti.obs[f"has_{ctj}_neighbor"]).sum() < 30
            or (~adata_cti.obs[f"has_{ctj}_neighbor"]).sum() < 30
            ):
            print(f"Not enough cells from each class to test {cti} with {ctj} neighbors")
            continue


        ###
        ### DIFF EXPR TEST: check DE genes between cti with ctj neighbor or not
        ###
        # adata_cti.obs[f'has_{ctj}_neighbor_str'] = adata_cti.obs[f'has_{ctj}_neighbor'].astype(str)
        # sc.tl.rank_genes_groups(adata_cti, groupby=f"has_{ctj}_neighbor_str", groups=['True'], reference='False', method="wilcoxon")
        # df_diffexpr[cti,ctj] = sc.get.rank_genes_groups_df(adata_cti, group='True').sort_values('pvals_adj')


        # # get significance from gsea and hypergeometric test
        # df_markers_rank_significance_diffexpr[cti, ctj] = _utils.get_marker_rank_significance(
        #     rnk=df_diffexpr[cti, ctj].set_index("names")["logfoldchanges"],
        #     gene_set=ctj_marker_genes,
        #     top_n=top_n,
        # )

AnnData expects .obs.index to contain strings, but got values like:
    [0, 1, 2, 3, 4]

    Inferred to be: integer



metrics_summary.csv not found at: /work/PRTNR/CHUV/DIR/rgottar1/spatial/env/xenium_paper/data/xenium/processed/segmentation/proseg/NSCLC/lung/0S8R/0S8R/raw_results/metrics_summary.csv
Could not find ['cycling lymphocyte'] in markers file
stromal cell malignant cell




## Plot results diffexpr

In [None]:
# cfg paths
xenium_dir = Path(cfg['xenium_processed_data_dir'])
xenium_std_seurat_analysis_dir = Path(cfg['xenium_std_seurat_analysis_dir'])
xenium_cell_type_annotation_dir = Path(cfg['xenium_cell_type_annotation_dir'])
results_dir = Path(cfg['results_dir'])
palette_dir = Path(cfg['xenium_metadata_dir'])

# Params
# probably only need to run for lognorm data
normalisations = ['lognorm',]
layers = ['data',]
reference = 'matched_reference_combo'
method = 'rctd_class_aware'
level = 'Level2.1'
segmentation_palette = palette_dir / 'col_palette_segmentation.csv'

n_neighbors = 10
n_permutations = 30
n_splits= 5
top_n = 20
scoring = 'f1'
markers = 'diffexpr' #'/work/PRTNR/CHUV/DIR/rgottar1/spatial/env/xenium_paper/data/markers/cellmarker_cell_types_markers.json'

# needed to get unique cell types names for each level
# cell_types_palette = pd.read_csv(palette_dir / 'col_palette_cell_types_combo.csv')

df_diffexpr = {}
df_markers_rank_significance_diffexpr = {}
for segmentation in (segmentations := xenium_std_seurat_analysis_dir.iterdir()):
    for condition in (conditions := segmentation.iterdir()): 
        for panel in (panels := condition.iterdir()):
            for donor in (donors := panel.iterdir()):
                for sample in (samples := donor.iterdir()):
                    for normalisation in normalisations:
                        for layer in layers:
                            # for reference in references:
                            #     for method in methods:
                            #         for level in levels:

                            k = (segmentation.stem,condition.stem,panel.stem,donor.stem,sample.stem)
                            name = '/'.join(k)

                            out_file_df_diffexpr = results_dir / f'contamination_metrics_diffexpr/{name}/{normalisation}/{layer}_{reference}_{method}_{level}_diffexpr.parquet'
                            out_file_df_markers_rank_significance_diffexpr = results_dir / f'contamination_metrics_diffexpr/{name}/{normalisation}/{layer}_{reference}_{method}_{level}_markers_rank_significance_diffexpr.parquet'

                            if out_file_df_diffexpr.exists():
                                # df_diffexpr[k] = pd.read_parquet(out_file_df_diffexpr)
                                df_markers_rank_significance_diffexpr[k] = pd.read_parquet(out_file_df_markers_rank_significance_diffexpr)

In [None]:
cti = 'T cell'
ctj = 'malignant cell'
xenium_levels = ["segmentation", "condition", "panel", "donor", "sample","cti","ctj"]

hue = "segmentation"
hue_order = [
    "10x_mm_0um",
    "10x_mm_5um",
    "10x_mm_15um",
    "10x_0um",
    "10x_5um",
    "10x_15um",
    "baysor",
    "proseg_expected",
    "proseg_mode",
    "segger",
]


palette = pd.read_csv(segmentation_palette, index_col=0).iloc[:, 0]


df = pd.concat(df_markers_rank_significance_diffexpr).reset_index()
df.columns = xenium_levels + df.columns[len(xenium_levels) :].tolist()
df = df.query("cti == @cti and ctj == @ctj")
df['-log10pvalue'] = -np.log10(df['hypergeometric_pvalue'])


# plotting params, palette
title = f"Reference: {reference}, Method: {method}, Level: {level} \n{cti} contaminated by {ctj}"
unique_labels = [c for c in hue_order if c in np.unique(df[hue].dropna())]
unique_labels = unique_labels + [c for c in np.unique(df[hue].dropna()) if c not in unique_labels]
palette = {u: palette[u] for u in unique_labels}
legend_handles = [mpatches.Patch(color=color, label=label) for label, color in palette.items()]

sns.set(style="ticks")

### hypergeometric pvalue boxplot
f = plt.figure(figsize=(6, 6))
ax = plt.subplot()
g = sns.boxplot(df,x='panel',y='-log10pvalue', 
                hue=hue, hue_order=unique_labels, 
                legend=False, palette=palette,ax=ax
                )

plt.ylabel(r'$-\log_{10} \text{ p-value}$')
sns.despine(offset=10, trim=True)
ax.yaxis.grid(True)

plt.suptitle(title)
f.legend(
    handles=legend_handles,
    loc="center left",
    bbox_to_anchor=(1, 0.5),
    title=hue,
    frameon=False,
)
plt.tight_layout(rect=[0, 0, 1, 0.95])
# plt.savefig(out_file, dpi=dpi, bbox_inches="tight")
plt.show()


### NES boxplot
f = plt.figure(figsize=(6, 6))
ax = plt.subplot()
g = sns.boxplot(df,x='panel',y='NES', 
                hue=hue, hue_order=unique_labels, 
                legend=False, palette=palette,ax=ax
                )

sns.despine(offset=10, trim=True)
ax.yaxis.grid(True)

plt.suptitle(title)
f.legend(
    handles=legend_handles,
    loc="center left",
    bbox_to_anchor=(1, 0.5),
    title=hue,
    frameon=False,
)
plt.tight_layout(rect=[0, 0, 1, 0.95])
# plt.savefig(out_file, dpi=dpi, bbox_inches="tight")
plt.show()


### number of hits boxplot
f = plt.figure(figsize=(6, 6))
ax = plt.subplot()
g = sns.boxplot(df,x='panel',y=f'n_hits_{top_n=}', 
                hue=hue, hue_order=unique_labels, 
                legend=False, palette=palette,ax=ax
                )

sns.despine(offset=10, trim=True)
ax.yaxis.grid(True)

plt.suptitle(title)
f.legend(
    handles=legend_handles,
    loc="center left",
    bbox_to_anchor=(1, 0.5),
    title=hue,
    frameon=False,
)
plt.tight_layout(rect=[0, 0, 1, 0.95])
# plt.savefig(out_file, dpi=dpi, bbox_inches="tight")
plt.show()

In [None]:
sns.set_style('ticks')
ref_segmentation = '10x_5um'

df = pd.concat(df_markers_rank_significance_diffexpr).reset_index()
df['-log10pvalue'] = -np.log10(df['hypergeometric_pvalue'])
df.columns = xenium_levels + df.columns[len(xenium_levels) :].tolist()
u_condition_panel = df[['condition','panel']].drop_duplicates().values

metrics = ['NES', '-log10pvalue', f'n_hits_{top_n=}']

for metric in metrics:
    for condition,panel in u_condition_panel:

        df_plot = df.query(f"segmentation == '{ref_segmentation}' and condition == '{condition}' and panel == '{panel}'")
        df_plot = df_plot.groupby(['cti', 'ctj'])[metric].mean().unstack()
        df_plot = df_plot.loc[df_plot.sum(1).sort_values(ascending=False).index]
        df_plot = df_plot[df_plot.sum(0).sort_values(ascending=False).index]

        f = plt.figure(figsize=(8,8))
        ax = plt.subplot()
        ax.set_title(f"{condition=} {panel=} {metric=}",fontsize=20)
        g = sns.heatmap(df_plot,cmap='coolwarm',center=-np.log10(0.05) if metric == '-log10pvalue' else 0.)
        plt.show()

## Plot results logreg

In [None]:
# cfg paths
xenium_dir = Path(cfg['xenium_processed_data_dir'])
xenium_std_seurat_analysis_dir = Path(cfg['xenium_std_seurat_analysis_dir'])
xenium_cell_type_annotation_dir = Path(cfg['xenium_cell_type_annotation_dir'])
results_dir = Path(cfg['results_dir'])
palette_dir = Path(cfg['xenium_metadata_dir'])

# Params
normalisations = ['lognorm',]
layers = ['data',]
reference = 'matched_reference_combo'
method = 'rctd_class_aware'
level = 'Level2.1'
segmentation_palette = palette_dir / 'col_palette_segmentation.csv'

n_neighbors = 10
n_permutations = 30
n_splits= 5
top_n = 20
scoring = 'f1'
markers = 'diffexpr' #'/work/PRTNR/CHUV/DIR/rgottar1/spatial/env/xenium_paper/data/markers/cellmarker_cell_types_markers.json'

# needed to get unique cell types names for each level
# cell_types_palette = pd.read_csv(palette_dir / 'col_palette_cell_types_combo.csv')

df_diffexpr = {}
df_markers_rank_significance_diffexpr = {}
for segmentation in (segmentations := xenium_std_seurat_analysis_dir.iterdir()):
    for condition in (conditions := segmentation.iterdir()): 
        for panel in (panels := condition.iterdir()):
            for donor in (donors := panel.iterdir()):
                for sample in (samples := donor.iterdir()):
                    for normalisation in normalisations:
                        for layer in layers:
                            # for reference in references:
                            #     for method in methods:
                            #         for level in levels:

                            k = (segmentation.stem,condition.stem,panel.stem,donor.stem,sample.stem)
                            name = '/'.join(k)

                            out_file_df_permutations_logreg = results_dir / f'contamination_metrics_logreg/{name}/{normalisation}/{layer}_{reference}_{method}_{level}_permutations_logreg.parquet'
                            out_file_df_importances_logreg = results_dir / f'contamination_metrics_logreg/{name}/{normalisation}/{layer}_{reference}_{method}_{level}_importances_logreg.parquet'
                            out_file_df_markers_rank_significance_logreg = results_dir / f'contamination_metrics_logreg/{name}/{normalisation}/{layer}_{reference}_{method}_{level}_markers_rank_significance_logreg.json'

                            if out_file_df_permutations_logreg.exists():
                                # df_diffexpr[k] = pd.read_parquet(out_file_df_diffexpr)
                                df_permutations_logreg[k] = pd.read_parquet(out_file_df_permutations_logreg)
                                df_importances_logreg[k] = pd.read_parquet(out_file_df_importances_logreg)
                                df_markers_rank_significance_logreg[k] = pd.read_parquet(out_file_df_markers_rank_significance_logreg)

In [6]:
cti = 'T cell'
ctj = 'malignant cell'
xenium_levels = ["segmentation", "condition", "panel", "donor", "sample","cti","ctj"]

df = pd.concat(df_markers_rank_significance_logreg).reset_index()
df.columns = xenium_levels + df.columns[len(xenium_levels) :].tolist()
df = df.query("cti == @cti and ctj == @ctj")
df['-log10pvalue'] = -np.log10(df['hypergeometric_pvalue'])


std_seurat_analysis_dir = Path(cfg['xenium_std_seurat_analysis_dir'])
cell_type_annotation_dir = Path(cfg['xenium_cell_type_annotation_dir'])
results_dir = Path(cfg['results_dir'])
palette_dir = Path(cfg['xenium_metadata_dir'])
segmentation_palette = palette_dir / 'col_palette_segmentation.csv'

hue = "segmentation"
hue_order = [
    "10x_mm_0um",
    "10x_mm_5um",
    "10x_mm_15um",
    "10x_0um",
    "10x_5um",
    "10x_15um",
    "baysor",
    "proseg_expected",
    "proseg_mode",
    "segger",
]


palette = pd.read_csv(segmentation_palette, index_col=0).iloc[:, 0]


# plotting params, palette
title = f"Reference: {reference}, Method: {method}, Level: {level} \n{cti} contaminated by {ctj}"
unique_labels = [c for c in hue_order if c in np.unique(df[hue].dropna())]
unique_labels = unique_labels + [c for c in np.unique(df[hue].dropna()) if c not in unique_labels]
palette = {u: palette[u] for u in unique_labels}
legend_handles = [mpatches.Patch(color=color, label=label) for label, color in palette.items()]


### hypergeometric pvalue boxplot
f = plt.figure(figsize=(6, 6))
ax = plt.subplot()
g = sns.boxplot(df,x='panel',y='-log10pvalue', 
                hue=hue, hue_order=unique_labels, 
                legend=False, palette=palette,ax=ax
                )

plt.ylabel(r'$-\log_{10} \text{ p-value}$')
sns.despine(offset=10, trim=True)
ax.yaxis.grid(True)

plt.suptitle(title)
f.legend(
    handles=legend_handles,
    loc="center left",
    bbox_to_anchor=(1, 0.5),
    title=hue,
    frameon=False,
)
plt.tight_layout(rect=[0, 0, 1, 0.95])
# plt.savefig(out_file, dpi=dpi, bbox_inches="tight")
plt.show()


### NES boxplot
f = plt.figure(figsize=(6, 6))
ax = plt.subplot()
g = sns.boxplot(df,x='panel',y='NES', 
                hue=hue, hue_order=unique_labels, 
                legend=False, palette=palette,ax=ax
                )

sns.despine(offset=10, trim=True)
ax.yaxis.grid(True)

plt.suptitle(title)
f.legend(
    handles=legend_handles,
    loc="center left",
    bbox_to_anchor=(1, 0.5),
    title=hue,
    frameon=False,
)
plt.tight_layout(rect=[0, 0, 1, 0.95])
# plt.savefig(out_file, dpi=dpi, bbox_inches="tight")
plt.show()


### number of hits boxplot
f = plt.figure(figsize=(6, 6))
ax = plt.subplot()
g = sns.boxplot(df,x='panel',y=f'n_hits_{top_n=}', 
                hue=hue, hue_order=unique_labels, 
                legend=False, palette=palette,ax=ax
                )

sns.despine(offset=10, trim=True)
ax.yaxis.grid(True)

plt.suptitle(title)
f.legend(
    handles=legend_handles,
    loc="center left",
    bbox_to_anchor=(1, 0.5),
    title=hue,
    frameon=False,
)
plt.tight_layout(rect=[0, 0, 1, 0.95])
# plt.savefig(out_file, dpi=dpi, bbox_inches="tight")
plt.show()

NameError: name 'df_markers_rank_significance_logreg' is not defined

In [None]:
sns.set_style('ticks')
ref_segmentation = '10x_5um'

df = pd.concat(df_markers_rank_significance_diffexpr).reset_index()
df['-log10pvalue'] = -np.log10(df['hypergeometric_pvalue'])
df.columns = xenium_levels + df.columns[len(xenium_levels) :].tolist()
u_condition_panel = df[['condition','panel']].drop_duplicates().values

metrics = ['NES', '-log10pvalue', f'n_hits_{top_n=}']

for metric in metrics:
    for condition,panel in u_condition_panel:

        df_plot = df.query(f"segmentation == '{ref_segmentation}' and condition == '{condition}' and panel == '{panel}'")
        df_plot = df_plot.groupby(['cti', 'ctj'])[metric].mean().unstack()
        df_plot = df_plot.loc[df_plot.sum(1).sort_values(ascending=False).index]
        df_plot = df_plot[df_plot.sum(0).sort_values(ascending=False).index]

        f = plt.figure(figsize=(8,8))
        ax = plt.subplot()
        ax.set_title(f"{condition=} {panel=} {metric=}",fontsize=20)
        g = sns.heatmap(df_plot,cmap='coolwarm',center=-np.log10(0.05) if metric == '-log10pvalue' else 0.)
        plt.show()

In [None]:
cti = 'T cell'
ctj = 'malignant cell'
xenium_levels = ["segmentation", "condition", "panel", "donor", "sample","cti","ctj"]

df = pd.concat(df_markers_rank_significance_logreg).reset_index()
df.columns = xenium_levels + df.columns[len(xenium_levels) :].tolist()
df = df.query("cti == @cti and ctj == @ctj")
df['-log10pvalue'] = -np.log10(df['hypergeometric_pvalue'])


### hypergeometric pvalue boxplot
f = plt.figure(figsize=(6, 6))
ax = plt.subplot()
g = sns.boxplot(df,x='panel',y='-log10pvalue', 
                hue=hue, hue_order=unique_labels, 
                legend=False, palette=palette,ax=ax
                )

plt.ylabel(r'$-\log_{10} \text{ p-value}$')
sns.despine(offset=10, trim=True)
ax.yaxis.grid(True)

plt.suptitle(title)
f.legend(
    handles=legend_handles,
    loc="center left",
    bbox_to_anchor=(1, 0.5),
    title=hue,
    frameon=False,
)
plt.tight_layout(rect=[0, 0, 1, 0.95])
# plt.savefig(out_file, dpi=dpi, bbox_inches="tight")
plt.show()


### NES boxplot
f = plt.figure(figsize=(6, 6))
ax = plt.subplot()
g = sns.boxplot(df,x='panel',y='NES', 
                hue=hue, hue_order=unique_labels, 
                legend=False, palette=palette,ax=ax
                )

sns.despine(offset=10, trim=True)
ax.yaxis.grid(True)

plt.suptitle(title)
f.legend(
    handles=legend_handles,
    loc="center left",
    bbox_to_anchor=(1, 0.5),
    title=hue,
    frameon=False,
)
plt.tight_layout(rect=[0, 0, 1, 0.95])
# plt.savefig(out_file, dpi=dpi, bbox_inches="tight")
plt.show()


### number of hits boxplot
f = plt.figure(figsize=(6, 6))
ax = plt.subplot()
g = sns.boxplot(df,x='panel',y=f'n_hits_{top_n=}', 
                hue=hue, hue_order=unique_labels, 
                legend=False, palette=palette,ax=ax
                )

sns.despine(offset=10, trim=True)
ax.yaxis.grid(True)

plt.suptitle(title)
f.legend(
    handles=legend_handles,
    loc="center left",
    bbox_to_anchor=(1, 0.5),
    title=hue,
    frameon=False,
)
plt.tight_layout(rect=[0, 0, 1, 0.95])
# plt.savefig(out_file, dpi=dpi, bbox_inches="tight")
plt.show()

In [None]:
df = pd.concat(df_permutations_logreg).reset_index()
df.columns = xenium_levels + df.columns[len(xenium_levels) :].tolist()
df