In [142]:
import dask

# Set before the remaining imports run.
# NOTE(review): presumably required so that packages imported below see the
# legacy (non query-planning) dask dataframe implementation — confirm.
dask.config.set({"dataframe.query-planning": False})

import numpy as np
from pathlib import Path
import pandas as pd
import scanpy as sc
import scipy
import seaborn as sns
import matplotlib.pyplot as plt

import sys
# Extend the module search path so the project-local modules imported on the
# next lines can be found (paths are relative to the notebook's working directory).
sys.path.extend(['../../scripts','../../scripts/xenium'])
import readwrite
import preprocessing

# Project configuration; indexed like a mapping by later cells
# (e.g. cfg['xenium_cell_type_annotation_dir']).
cfg = readwrite.config()

In [246]:
from scib_metrics.benchmark import Benchmarker, BioConservation, BatchCorrection
# Private scib-metrics names: _LABELS is used further down to overwrite the
# label column of the Benchmarker's internal AnnData; _BATCH is not used in
# the visible cells. Being private, they may change between library versions.
from scib_metrics.benchmark._core import _LABELS, _BATCH

# Root directory of the per-sample cell type annotations, taken from the project config.
xenium_cell_type_annotation_dir = Path(cfg['xenium_cell_type_annotation_dir'])

# Run parameters, hard-coded in this cell (no argument parser is actually used here).
segmentation = 'proseg_expected'
condition = 'breast'
panel = 'breast'
normalisation_method = "lognorm"  # fixed for now, even for sctransform; leftover reference: args.normalisation_method
layer = 'data'
n_comps = 50  # number of principal components requested from sc.tl.pca below

# Placeholder values: both `k` and `name` are reassigned for every sample
# inside the reading loop below.
k = (segmentation,condition,panel,normalisation_method)
name = '/'.join(k)

# NOTE: `panel` changes type here, from the panel name (str) to the panel
# directory (Path); from this point on the panel name is `panel.stem`.
# NOTE(review): absolute, machine-specific path — consider deriving it from `cfg`.
panel = Path(f'/work/PRTNR/CHUV/DIR/rgottar1/spatial/env/xenium_paper/data/xenium/processed/std_seurat_analysis/{segmentation}/{condition}/{panel}')
# Placeholder values: `reference`, `method` and `level` are rebound by the
# loop variables of the annotation-reading loops below.
reference = 'matched_reference_combo'
method = 'rctd_class_aware'
level = 'Level2'

# Re-derive segmentation and condition from the two directory levels above the panel directory.
segmentation = panel.parents[1].stem
condition = panel.parents[0].stem

# variables
OBSM_KEY = 'X_pca'  # obsm key of the embedding passed to the Benchmarker
BATCH_KEY = 'blobs'  # obs column that receives the per-sample batch label
# NOTE: this value is never read in the visible cells; the name CT_KEY is
# rebound by the benchmarking loop, which iterates over CT_KEYS.
CT_KEY = ('matched_reference_combo', 'rctd_class_unaware', 'Level2')

# read xenium samples
# `ads` maps (segmentation, condition, panel, donor, sample, normalisation_method)
# to the AnnData of that sample, with one obs column per available
# (reference, method, level) cell type annotation.
ads = {}
# sorted(): Path.iterdir() yields entries in arbitrary order, which would make
# the cell order after concatenation (and therefore the later subsample)
# depend on the filesystem. is_dir(): skip stray files next to the
# donor/sample/annotation directories instead of failing on them.
for donor in sorted(p for p in panel.iterdir() if p.is_dir()):
    for sample in sorted(p for p in donor.iterdir() if p.is_dir()):
        k = (
            segmentation,
            condition,
            panel.stem,
            donor.stem,
            sample.stem,
            normalisation_method,
        )
        name = "/".join(k)

        counts_dir = sample / normalisation_method / "normalised_counts"
        sample_counts_path = counts_dir / f"{layer}.parquet"
        sample_idx_path = counts_dir / "cells.parquet"

        sample_ad = sc.AnnData(pd.read_parquet(sample_counts_path))
        if layer != "scale_data":  # no need to sparsify scale_data which is dense
            sample_ad.X = scipy.sparse.csr_matrix(sample_ad.X)
        # cell identifiers are stored separately, in the first column of cells.parquet
        sample_ad.obs_names = pd.read_parquet(sample_idx_path).iloc[:, 0]

        sample_annotation_dir = xenium_cell_type_annotation_dir / f'{name}/reference_based'
        for reference in sorted(p for p in sample_annotation_dir.iterdir() if p.is_dir()):
            for method in sorted(p for p in reference.iterdir() if p.is_dir()):
                for level in sorted(p for p in method.iterdir() if p.is_dir()):
                    # Build the path from the directory that was actually found,
                    # so directory names containing a dot still resolve.
                    annot_file = level / "single_cell" / "labels.parquet"
                    labels = pd.read_parquet(annot_file).set_index("cell_id").iloc[:, 0]
                    # Assignment aligns on cell_id; cells without a label become NaN.
                    sample_ad.obs[(reference.stem, method.stem, level.stem)] = labels

        ads[k] = sample_ad


# concatenate
xenium_levels = ["segmentation", "condition", "panel", "donor", "sample"]
# Annotate every sample with the leading entries of its key; zip stops after
# the five levels, so the trailing normalisation entry of the key is not stored.
for k, sample_ad in ads.items():
    for lvl, value in zip(xenium_levels, k):
        sample_ad.obs[lvl] = value
ad_merge = sc.concat(ads)
# One batch label per sample: the five levels joined with "_".
ad_merge.obs[BATCH_KEY] = ad_merge.obs[xenium_levels].agg("_".join, axis=1)

# drop NaN annotations
ad_merge = ad_merge[ad_merge.obs.notna().all(1)].copy()

# Cell type annotation columns = every obs column that is neither a xenium
# level nor the batch column. BATCH_KEY has to be excluded explicitly: it was
# added to obs above and would otherwise be benchmarked as if it were a
# cell type annotation.
non_label_columns = {*xenium_levels, BATCH_KEY}
CT_KEYS = [c for c in ad_merge.obs.columns if c not in non_label_columns]

# subsample to reasonable size
# The call has no copy=True, so ad_merge itself is reduced and the cells that
# are not drawn are gone for the rest of this cell.
# NOTE(review): 1_000 cells is a small cap — confirm it is not a leftover debugging value.
if len(ad_merge) > 1_000:
    sc.pp.subsample(ad_merge, n_obs=1_000)

# compute pca
# n_comps is set in the parameter section above; the Benchmarker below reads
# the embedding from obsm[OBSM_KEY] ('X_pca').
sc.tl.pca(ad_merge, n_comps=n_comps)

# set up metrics
# Batch-correction metrics: only ilisi_knn is enabled, every other one is switched off.
batchcor = BatchCorrection(
    silhouette_batch=False,
    ilisi_knn=True,
    kbet_per_label=False,
    graph_connectivity=False,
    pcr_comparison=False,
)

# Bio-conservation metrics: nmi_ari_cluster_labels_leiden and clisi_knn are
# enabled, every other one is switched off.
biocons = BioConservation(
    isolated_labels=False,
    nmi_ari_cluster_labels_leiden=True,
    nmi_ari_cluster_labels_kmeans=False,
    silhouette_label=False,
    clisi_knn=True,
)

# benchmark all cell type keys
# One column of df_results per annotation key. The Benchmarker is built for
# the first key only; for every later key just the label column of its
# internal embedding AnnData is swapped before re-running the metrics.
df_results = pd.DataFrame()
for idx, CT_KEY in enumerate(CT_KEYS):
    if idx > 0:
        # to avoid recomputing kNN graph
        bm._emb_adatas[OBSM_KEY].obs[_LABELS] = ad_merge.obs[CT_KEY].values
    else:
        bm = Benchmarker(
            ad_merge,
            batch_key=BATCH_KEY,
            label_key=CT_KEY,
            embedding_obsm_keys=[OBSM_KEY],
            pre_integrated_embedding_obsm_key=OBSM_KEY,
            bio_conservation_metrics=biocons,
            batch_correction_metrics=batchcor,
            n_jobs=-1,
        )
    bm.benchmark()

    # first row of the results table = scores of the single embedding
    scores = bm.get_results(min_max_scale=False)
    df_results[CT_KEY] = scores.iloc[0]

# df_results.to_parquet(
#     f"{out_dir}/scib_metrics_{CT_KEY}.parquet"
# )

  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
Computing neighbors: 100%|██████████| 1/1 [00:00<00:00,  2.43it/s]
Embeddings: 100%|[32m██████████[0m| 1/1 [00:00<00:00,  4.65it/s]
Embeddings: 100%|[32m██████████[0m| 1/1 [00:00<00:00,  5.25it/s]
Embeddings: 100%|[32m██████████[0m| 1/1 [00:00<00:00,  3.89it/s]
Embeddings: 100%|[32m██████████[0m| 1/1 [00:00<00:00,  3.99it/s]
Embeddings: 100%|[32m██████████[0m| 1/1 [00:00<00:00,  5.24it/s]
Embeddings: 100%|[32m██████████[0m| 1/1 [00:00<00:00,  5.28it/s]
Embeddings: 100%|[32m██████████[0m| 1/1 [00:00<00:00,  5.23it/s]
Embeddings: 100%|[32m██████████[0m| 1/1 [00:00<00:00,  5.22it/s]
Embeddings: 100%|[32m██████████[0m| 1/1 [00:00<00:00,  3.90it/s]
Embeddings: 100%|[32m██████████[0m| 1/1 [00:00<00:00,  2.97it/s]
Embeddings: 100%|[32m██████████[0m| 1/1 [00:00<00:00,  5.26it/s]
Embeddings:

In [249]:
# Rebuild ad_merge from the per-sample objects in `ads`. This replaces the
# previous ad_merge: the new object has no BATCH_KEY column in obs, has not
# been filtered for NaN annotations, is not subsampled and carries no PCA
# embedding (compare the summary printed in the next cell).
ad_merge = sc.concat(ads)


  utils.warn_names_duplicates("obs")


In [250]:
# Display the AnnData summary (dimensions and obs columns) of the merged object.
ad_merge

AnnData object with n_obs × n_vars = 1660238 × 280
    obs: ('matched_reference_combo', 'rctd_class_unaware', 'Level4'), ('matched_reference_combo', 'rctd_class_unaware', 'Level2'), ('matched_reference_combo', 'rctd_class_unaware', 'Level3'), ('matched_reference_combo', 'rctd_class_unaware', 'Level1'), ('matched_reference_combo', 'rctd_class_aware', 'Level4'), ('matched_reference_combo', 'rctd_class_aware', 'Level2'), ('matched_reference_combo', 'rctd_class_aware', 'Level3'), ('matched_reference_combo', 'rctd_class_aware', 'Level1'), ('external_reference', 'rctd_class_unaware', 'Level2'), ('external_reference', 'rctd_class_unaware', 'Level3'), ('external_reference', 'rctd_class_unaware', 'Level1'), ('external_reference', 'rctd_class_aware', 'Level2'), ('external_reference', 'rctd_class_aware', 'Level3'), ('external_reference', 'rctd_class_aware', 'Level1'), 'segmentation', 'condition', 'panel', 'donor', 'sample'

In [None]:
# Benchmark every cell type annotation on (a subsample of) ad_merge.

# Keep only cells that carry a label for every annotation, as in the first
# benchmarking cell; cells with NaN labels cannot be scored.
annotated = ad_merge.obs.notna().all(axis=1).to_numpy()
adata = ad_merge[annotated]

# subsample to reasonable size; always work on a copy so that the columns and
# embedding added below never modify ad_merge
if adata.n_obs > 1_000:
    adata = sc.pp.subsample(adata, n_obs=1_000, copy=True)
else:
    adata = adata.copy()

# ad_merge may have been rebuilt with sc.concat(ads), in which case it has
# neither the batch column nor the PCA embedding this cell relies on.
if BATCH_KEY not in adata.obs.columns:
    adata.obs[BATCH_KEY] = adata.obs[xenium_levels].agg("_".join, axis=1)
if OBSM_KEY not in adata.obsm.keys():
    sc.tl.pca(adata, n_comps=n_comps)

# Annotation columns only: neither the xenium levels nor the batch column.
CT_KEYS = [c for c in adata.obs.columns if c not in {*xenium_levels, BATCH_KEY}]

# Batch-correction metrics: only ilisi_knn is enabled.
batchcor = BatchCorrection(
    silhouette_batch=False,
    ilisi_knn=True,
    kbet_per_label=False,
    graph_connectivity=False,
    pcr_comparison=False,
)

# Bio-conservation metrics: nmi_ari_cluster_labels_leiden and clisi_knn are enabled.
biocons = BioConservation(
    isolated_labels=False,
    nmi_ari_cluster_labels_leiden=True,
    nmi_ari_cluster_labels_kmeans=False,
    silhouette_label=False,
    clisi_knn=True,
)

# One column of df_results per annotation key.
df_results = pd.DataFrame()
for i, CT_KEY in enumerate(CT_KEYS):
    if i == 0:
        # Same arguments as the first benchmarking cell, including the
        # pre-integrated embedding key.
        bm = Benchmarker(
            adata,
            batch_key=BATCH_KEY,
            label_key=CT_KEY,
            embedding_obsm_keys=[OBSM_KEY],
            pre_integrated_embedding_obsm_key=OBSM_KEY,
            bio_conservation_metrics=biocons,
            batch_correction_metrics=batchcor,
            n_jobs=-1,
        )
    else:
        # Swap in the labels of the *current* key (not a fixed CT_KEYS[1]) so
        # every column of df_results is scored against its own annotation,
        # while the already computed kNN graph is reused.
        bm._emb_adatas[OBSM_KEY].obs[_LABELS] = adata.obs[CT_KEY].values
    bm.benchmark()

    # first row of the results table = scores of the single embedding
    df_results[CT_KEY] = bm.get_results(min_max_scale=False).iloc[0]

# df_results.to_parquet(
#     f"{out_dir}/scib_metrics_{CT_KEY}.parquet"
# )

  utils.warn_names_duplicates("obs")
  utils.warn_names_duplicates("obs")
Computing neighbors:   0%|          | 0/1 [00:00<?, ?it/s]

Computing neighbors: 100%|██████████| 1/1 [00:00<00:00,  2.34it/s]
Computing neighbors: 100%|██████████| 1/1 [00:00<00:00,  2.34it/s]
Embeddings: 100%|[32m██████████[0m| 1/1 [00:00<00:00,  4.86it/s]
Embeddings: 100%|[32m██████████[0m| 1/1 [00:00<00:00,  4.85it/s]
Embeddings: 100%|[32m██████████[0m| 1/1 [00:00<00:00,  5.07it/s]
Embeddings: 100%|[32m██████████[0m| 1/1 [00:00<00:00,  4.98it/s]
Embeddings: 100%|[32m██████████[0m| 1/1 [00:00<00:00,  4.99it/s]
Embeddings: 100%|[32m██████████[0m| 1/1 [00:00<00:00,  5.09it/s]
Embeddings: 100%|[32m██████████[0m| 1/1 [00:00<00:00,  4.95it/s]
Embeddings: 100%|[32m██████████[0m| 1/1 [00:00<00:00,  5.03it/s]
Embeddings: 100%|[32m██████████[0m| 1/1 [00:00<00:00,  5.06it/s]
Embeddings: 100%|[32m██████████[0m| 1/1 [00:00<00:00,  4.94it/s]
Embeddings: 100%|[32m██████████[0m| 1/1 [00:00<00:00,  4.93it/s]
Embeddings: 100%|[32m██████████[0m| 1/1 [00:00<00:00,  4.91it/s]
Embeddings: 100%|[32m██████████[0m| 1/1 [00:00<00:00,  4.91i

Embeddings: 100%|[32m██████████[0m| 1/1 [01:03<00:00, 63.14s/it]
