In [None]:
import numpy as np
from pathlib import Path
import spatialdata_io
import spatialdata
import json
import time
import pandas as pd
import scanpy as sc
import seaborn as sns
import matplotlib.pyplot as plt
import concurrent.futures
from joblib import Parallel, delayed

import sys
sys.path.extend(['../../scripts','../../scripts/xenium'])
import coexpression
import readwrite
cfg = readwrite.config()

def diagonal_line(ax):
    """Draw a black y = x reference segment on `ax`.

    The segment starts at the larger of the two lower axis limits and ends at the
    smaller of the two upper limits, i.e. it covers only the range where the x and
    y views overlap. A straight y = x segment stays straight on log-log axes too,
    but the limits are read once, so call this after any scale change.
    """
    (x_lo, x_hi), (y_lo, y_hi) = ax.get_xlim(), ax.get_ylim()
    start = np.max([x_lo, y_lo])
    stop = np.min([x_hi, y_hi])
    ax.plot([start, stop], [start, stop], c='k', lw=1, alpha=0.8, zorder=1)

## Load data

In [None]:
# cfg paths
xenium_dir = Path(cfg['xenium_processed_data_dir'])
xenium_raw_data_dir = Path(cfg['xenium_raw_data_dir'])
results_dir = Path(cfg['results_dir'])


def _subdirs(path):
    """Return the sub-directories of `path`, sorted by path.

    Plain files (e.g. `.DS_Store`, READMEs) are skipped because calling
    `.iterdir()` on them raises `NotADirectoryError`; sorting makes the traversal
    order independent of the file system.
    """
    return sorted(p for p in path.iterdir() if p.is_dir())


# Segmentation name -> segmentation directory
dir_segmentations = {d.name: d for d in _subdirs(xenium_dir)}

# Collect each sample's output folder. The processed tree is walked as
# <segmentation>/<condition>/<panel>/<donor>/<sample>/normalised_results/outs.
# NOTE(review): keys use `.stem`, which drops anything after the last '.' in a
# directory name — kept for compatibility with downstream lookups.
xenium_levels = ('segmentation', 'condition', 'panel', 'donor', 'sample')
xenium_paths = {}
for segmentation in _subdirs(xenium_dir):
    for condition in _subdirs(segmentation):
        for panel in _subdirs(condition):
            for donor in _subdirs(panel):
                for sample in _subdirs(donor):
                    k = (segmentation.stem, condition.stem, panel.stem, donor.stem, sample.stem)
                    xenium_paths[k] = sample / 'normalised_results/outs'

# Load AnnData only (no transcripts), then index by the 5 levels above.
ads = readwrite.read_xenium_samples(xenium_paths, anndata_only=True, transcripts=False, sample_name_as_key=False)
ads = pd.Series(list(ads.values()),
                index=pd.MultiIndex.from_tuples(list(ads.keys()), names=xenium_levels),
                dtype=object).sort_index()

In [None]:
# Full 5-level sample keys, one (the first) per distinct (condition, panel) pair.
# The code below assumes every sample of a panel shares the same gene_panel.json.
u_condition_panel = ads.index.to_frame()[['condition', 'panel']].drop_duplicates().index

# load probe count info
gene_panels_info = {}
for sample_key in u_condition_panel:
    condition, panel = sample_key[1:3]
    # Read the panel file from the sample's own output folder (derived from cfg)
    # instead of a hard-coded absolute cluster path.
    gene_panel_path = xenium_paths[sample_key] / 'gene_panel.json'
    gene_panels_info[condition, panel] = readwrite.get_gene_panel_info(str(gene_panel_path))

## Xenium: number of probes per gene vs. mean expression

In [None]:
ref_segmentation = '10x_0um'

# merge ads and probe info: one row per (sample, gene) with the gene's probe
# coverage (from the panel's gene_panel.json) and its mean counts over cells
df = {}
for k, ad in ads.items():
    condition, panel = k[1:3]
    df_ = pd.DataFrame()
    df_['gene_coverage'] = (gene_panels_info[condition, panel]
        .set_index('name')['gene_coverage']
        .loc[ad.var_names])
    # `.A1` only exists on np.matrix (what scipy sparse *matrices* return from
    # .mean); np.asarray(...).ravel() also works for dense X and sparse arrays.
    df_['mean_counts_per_gene'] = np.asarray(ad.X.mean(0)).ravel()
    df_['name'] = ad.var_names

    df[k] = df_

df = pd.concat(df).reset_index()
# the leading columns come from the 5-level dict keys of `df`
df.columns = xenium_levels + tuple(df.columns[len(xenium_levels):])
df['condition_panel'] = df[['condition', 'panel']].agg('-'.join, axis=1)
df['condition_panel_sample'] = df[['condition', 'panel', 'sample']].agg(' '.join, axis=1)
df_seg = df[df['segmentation'] == ref_segmentation]

# One boxplot facet per sample: mean counts per gene grouped by number of probes.
# Ask for as many colours as there are groups (tab10 cycles past 10) so no hue
# level is silently dropped by zip().
condition_panels = df_seg['condition_panel'].unique()
palette = dict(zip(condition_panels, sns.color_palette('tab10', len(condition_panels))))
g = sns.FacetGrid(df_seg, col="condition_panel_sample", height=4, aspect=1.2, sharex=True, sharey=True, col_wrap=5)
g.map_dataframe(sns.boxplot, x='gene_coverage', y='mean_counts_per_gene', hue='condition_panel', palette=palette, log_scale=True)
g.add_legend(title="Condition-Panel")
g.set_axis_labels("Gene Coverage", "Mean Counts per Gene")
g.set_titles("{col_name}")
plt.show()

## scRNA: number of probes per gene vs. mean expression

In [None]:
scrna_probe_path = '../../../data/markers/Chromium_Human_Transcriptome_Probe_Set_v1.1.0_GRCh38-2024-A.csv'
# skiprows=5 skips the header lines at the top of the probe-set CSV.
# NOTE(review): the count is tied to this file version — confirm if the file changes.
scrna_gene_panel_info = pd.read_csv(scrna_probe_path, skiprows=5)
# Number of rows (presumably probes) per gene_name
scrna_gene_coverage = scrna_gene_panel_info['gene_name'].value_counts()

seurat_to_h5_dir = results_dir / 'seurat_to_h5'

# One scRNA reference per sub-directory, keyed by directory name. Sorted so the
# facet order and palette assignment in later plots are reproducible.
ads_sc = {}
for ref_dir in sorted(seurat_to_h5_dir.iterdir()):
    if ref_dir.is_dir():
        ads_sc[ref_dir.name] = sc.read_10x_h5(ref_dir / 'RNA_counts.h5')

In [None]:
# merge ads and probe info: for each scRNA reference, one row per gene that is
# present in the scRNA probe set, with its probe count and mean counts over cells
df = {}
for k, ad in ads_sc.items():
    genes_found = [gene for gene in ad.var_names if gene in scrna_gene_coverage.index]
    df_ = pd.DataFrame()
    df_['gene_coverage'] = scrna_gene_coverage.loc[genes_found]
    # np.asarray(...).ravel() instead of `.A1`: also works for dense X / sparse arrays
    df_['mean_counts_per_gene'] = np.asarray(ad[:, genes_found].X.mean(0)).ravel()
    df_['name'] = genes_found

    df[k] = df_

df = pd.concat(df).reset_index()
df.columns = ('reference',) + tuple(df.columns[1:])

# One facet per reference; request one colour per reference so zip() never truncates
references = df['reference'].unique()
palette = dict(zip(references, sns.color_palette('tab10', len(references))))
g = sns.FacetGrid(df, col="reference", height=4, aspect=1.2, sharex=True, sharey=True, col_wrap=3)
g.map_dataframe(sns.boxplot, x='gene_coverage', y='mean_counts_per_gene', hue='reference', palette=palette, log_scale=True)
g.add_legend(title="reference")
g.set_axis_labels("Gene Coverage", "Mean Counts per Gene")
g.set_titles("{col_name}")
g.figure.subplots_adjust(top=0.9)
g.figure.suptitle('scRNA probes gene coverage')
plt.show()

In [None]:
# For each Xenium (condition, panel): scRNA mean counts of the panel's genes,
# grouped by the number of Xenium probes per gene (from gene_panel.json).
for condition_panel in u_condition_panel:
    condition, panel = condition_panel[1:3]
    # Same for every reference -> look it up once per panel
    gene_coverage_ = gene_panels_info[condition, panel].set_index('name')['gene_coverage']

    # merge ads and probe info
    df = {}
    for k, ad in ads_sc.items():
        genes_found = [gene for gene in ad.var_names if gene in gene_coverage_.index]
        df_ = pd.DataFrame()
        # the index carries the 'name' label from set_index('name'), so no separate
        # 'name' column is added (it would clash in reset_index below)
        df_['gene_coverage'] = gene_coverage_.loc[genes_found]
        df_['mean_counts_per_gene'] = np.asarray(ad[:, genes_found].X.mean(0)).ravel()

        df[k] = df_

    df = pd.concat(df).reset_index()
    df.columns = ('reference',) + tuple(df.columns[1:])

    # One facet per reference; one colour per reference so zip() never truncates
    references = df['reference'].unique()
    palette = dict(zip(references, sns.color_palette('tab10', len(references))))
    g = sns.FacetGrid(df, col="reference", height=4, aspect=1.2, sharex=True, sharey=True, col_wrap=3)
    g.map_dataframe(sns.boxplot, x='gene_coverage', y='mean_counts_per_gene', hue='reference', palette=palette, log_scale=True)
    g.add_legend(title="reference")
    g.set_axis_labels("Gene Coverage", "Mean Counts per Gene")
    g.set_titles("{col_name}")
    g.figure.subplots_adjust(top=0.9)
    g.figure.suptitle(f'{condition} {panel} gene coverage')
    plt.show()

## scRNA vs. Xenium: mean counts per gene

In [None]:
# Mean counts per gene (over cells) in each scRNA reference, indexed by gene name.
# np.asarray(...).ravel() instead of `.A1`: also works for dense X / sparse arrays.
df_sc_means = {
    k_sc: pd.Series(np.asarray(ad_sc.X.mean(0)).ravel(), index=ad_sc.var_names)
    for k_sc, ad_sc in ads_sc.items()
}

# merge ads and probe info, plus one scRNA mean-count column per reference
df = {}
for k, ad in ads.items():
    condition, panel = k[1:3]
    df_ = pd.DataFrame()
    df_['gene_coverage'] = (gene_panels_info[condition, panel]
        .set_index('name')['gene_coverage']
        .loc[ad.var_names])
    df_['mean_counts_per_gene'] = np.asarray(ad.X.mean(0)).ravel()
    df_['name'] = ad.var_names

    for k_sc, df_sc_ in df_sc_means.items():
        # align on gene name; Xenium genes missing from the reference become NaN
        df_[f'mean_counts_per_gene_{k_sc}'] = df_sc_.reindex(df_.index)

    df[k] = df_

df = pd.concat(df).reset_index()
df.columns = xenium_levels + tuple(df.columns[len(xenium_levels):])
df['condition_panel'] = df[['condition', 'panel']].agg('-'.join, axis=1)
df['condition_panel_sample'] = df[['condition', 'panel', 'sample']].agg(' '.join, axis=1)
df['gene_coverage'] = df['gene_coverage'].astype(float)
# same reference segmentation as the Xenium-only coverage plot
df_seg = df[df['segmentation'] == ref_segmentation]

In [None]:
# One figure per scRNA reference: per-gene mean counts, Xenium (x) vs scRNA (y),
# one facet per sample, coloured by the number of probes per gene.
for k_sc in ads_sc.keys():
    g = sns.lmplot(df_seg,
        col="condition_panel_sample", x='mean_counts_per_gene', y=f'mean_counts_per_gene_{k_sc}',
        hue='gene_coverage', palette='viridis',
        height=5, aspect=1.2, col_wrap=5, fit_reg=False,
        scatter_kws=dict(s=6, alpha=.5),
        facet_kws=dict(sharex=True, sharey=True),
    )

    # Switch to log-log first so diagonal_line reads the final, positive log limits;
    # drawn earlier it used linear limits whose lower bound can be <= 0.
    g.set(xscale="log", yscale="log")
    for ax in g.axes.flat:
        diagonal_line(ax)

    g.set_axis_labels("Mean Counts per Gene Xenium", "Mean Counts per Gene scRNA")
    g.set_titles("{col_name}")
    g.figure.subplots_adjust(top=0.95)
    g.figure.suptitle(k_sc)
    plt.show()

## scRNA vs. Xenium: genes shared by all samples

In [None]:
# Genes measured in every loaded Xenium sample
panels_genes = [ad.var_names for ad in ads.values]
common_genes = list(set.intersection(*map(set, panels_genes)))
df_common_genes = df_seg[df_seg['name'].isin(common_genes)]

# Only the first scRNA reference is plotted; widen the slice to plot them all.
for k_sc in list(ads_sc)[:1]:
    g = sns.lmplot(df_common_genes,
        col="condition_panel_sample", x='mean_counts_per_gene', y=f'mean_counts_per_gene_{k_sc}',
        hue='gene_coverage', palette='viridis',
        height=5, aspect=1.2, col_wrap=5, fit_reg=False,
        scatter_kws=dict(s=20, alpha=.5),
        facet_kws=dict(sharex=True, sharey=True),
    )

    # Log-log first, then the y = x line from the final positive limits
    g.set(xscale="log", yscale="log")
    for ax in g.axes.flat:
        diagonal_line(ax)

    g.set_axis_labels("Mean Counts per Gene Xenium", "Mean Counts per Gene scRNA")
    g.set_titles("{col_name}")
    g.figure.subplots_adjust(top=0.95)
    g.figure.suptitle(k_sc)
    plt.show()

## scRNA vs. Xenium: genes shared by all non-breast samples

In [None]:
# Genes shared by all samples except breast ones. `k` is the 5-level key tuple, so
# this drops samples where any level equals exactly 'breast'. Breast samples are
# still plotted below, restricted to these genes.
panels_genes = [ad.var_names for k, ad in ads.items() if 'breast' not in k]
common_genes = list(set.intersection(*map(set, panels_genes)))
df_common_genes = df_seg[df_seg['name'].isin(common_genes)]

# Only the first scRNA reference is plotted; widen the slice to plot them all.
for k_sc in list(ads_sc)[:1]:
    g = sns.lmplot(df_common_genes,
        col="condition_panel_sample", x='mean_counts_per_gene', y=f'mean_counts_per_gene_{k_sc}',
        hue='gene_coverage', palette='viridis',
        height=5, aspect=1.2, col_wrap=5, fit_reg=False,
        scatter_kws=dict(s=20, alpha=.5),
        facet_kws=dict(sharex=True, sharey=True),
    )

    # Log-log first, then the y = x line from the final positive limits
    g.set(xscale="log", yscale="log")
    for ax in g.axes.flat:
        diagonal_line(ax)

    g.set_axis_labels("Mean Counts per Gene Xenium", "Mean Counts per Gene scRNA")
    g.set_titles("{col_name}")
    g.figure.subplots_adjust(top=0.95)
    g.figure.suptitle(k_sc)
    plt.show()