In [None]:
import pandas as pd                     # Data manipulation and analysis (tables/dataframes)
import matplotlib.pyplot as plt         # Plotting/figures
import seaborn as sns                   # Statistical visualization with clean defaults
import scanpy as sc                     # Single-cell/spatial omics analysis toolkit
import numpy as np                      # Numerical computing & arrays
from sklearn.decomposition import PCA   # Dimensionality reduction (principal components)
from sklearn.neighbors import NearestNeighbors  # k-NN graph construction / neighbor search
import igraph as ig                     # Graph data structures and algorithms
import leidenalg as la                  # Leiden community detection (clustering on graphs)
from umap.umap_ import fuzzy_simplicial_set  # Build UMAP fuzzy graph (advanced usage)
import umap                             # UMAP embedding for visualization / DR
import os                               # OS utilities (paths, files)
import math                             # Math functions
import warnings                         # Control/suppress warnings
warnings.filterwarnings('ignore')       # Hide non-critical warnings for cleaner logs
import gzip                             # Read/write .gz compressed files
import json                             # Read/write JSON (configs, metadata)
import pickle                           # Serialize/deserialize Python objects
import scanpy as sc                     # (Duplicate) Already imported above
from scipy.sparse import csr_matrix     # Efficient sparse matrices (memory-light)

# import tensorflow as tf               # (Optional) Deep learning backend (commented out)
import sys                              # Access Python path, argv, version, etc.
sys.path.append('/home/shamini/')       # Add custom code directory to Python path

from sklearn.decomposition import PCA   # (Duplicate) PCA already imported above
from sklearn.neighbors import NearestNeighbors  # (Duplicate) Already imported above
import igraph as ig                     # (Duplicate) Already imported above
import leidenalg as la                  # (Duplicate) Already imported above
# from umap.umap_ import fuzzy_simplicial_set  # (Duplicate) Already imported above
import umap                             # (Duplicate) Already imported above
from sklearn.model_selection import train_test_split  # Split data into train/test sets

import gc                               # Manual garbage collection (free memory)

In [None]:
# Root of the hackathon project. Defaults to the original author's location; set the
# HACKATHON_WORKING_DIR environment variable to run the notebook elsewhere.
# os.path.join(..., '') guarantees a trailing separator, because every path below is built by
# plain string concatenation.
working_dir = os.path.join(os.environ.get('HACKATHON_WORKING_DIR', '/home/mystique27m/ext_gpu_hd/hackathon/'), '')

# Main output directory that holds all results of this pipeline
main_out = working_dir+'out/'

### inputs: cleaned transcript tables (script01a outputs)
src_gene_tfs = main_out+'script01a_cleaned_tf_annot_gene_probes/'        # gene-probe transcript files
src_control_tfs = main_out+'script01a_cleaned_tf_annot_control_probes/'  # control-probe transcript files

### meta data
#gene_list_dir = working_dir+'metadata/version1/'
# (Optional) directory for gene-list metadata (currently unused)

### destination output directories for figures and objects
out_fig_destdir = main_out+'script01b_figures/'                           # figures produced by script01b
out_obj_destdir = main_out+'script01b_output_objects/'                    # AnnData objects built from gene probes
out_obj_destdir_control = main_out+'script01b_output_objects_control/'    # AnnData objects built from control probes

# Create every output directory (no-op when it already exists)
for _dest_dir in (out_fig_destdir, out_obj_destdir, out_obj_destdir_control):
    os.makedirs(_dest_dir, exist_ok=True)
del _dest_dir

In [None]:
# 25 distinct hex colours; the violin-plot cell picks one per sample by file position.
# Each comment gives the colour name and its R,G,B value.
colors_palette = [
    '#ebac23', # yellow      235,172,35
    '#b80058', # lipstick    184,0,88
    '#008cf9', # azure       0,140,249
    '#006e00', # green       0,110,0
    '#00bbad', # caribbean   0,187,173
    '#d163e6', # lavender    209,99,230
    '#b24502', # brown       178,69,2
    '#ff9287', # coral       255,146,135
    '#5954d6', # indigo      89,84,214
    '#00c6f8', # turquoise   0,198,248
    '#878500', # olive       135,133,0
    '#00a76c', # jade        0,167,108
    '#274d52', # plantation  39,77,82
    '#c7a2a6', # eunry       199,162,166
    '#818b70', # battleship  129,139,112
    '#604e3c', # kabul       96,78,60
    '#8c9fb7', # balihai     140,159,183
    '#796880', # rum         121,104,128
    '#56641a', # fernfrond   86,100,26
    '#c0affb', # perfume     192,175,251
    '#e6a176', # apricot     230,161,118
    '#00678a', # orient      0,103,138
    '#984464', # vinrouge    152,68,100
    '#5eccab', # downy       94,204,171
    '#bdbdbd'] # gray        189,189,189

In [None]:
# Every entry (files and sub-directories alike) of the cleaned gene-probe transcript directory
adatas_filenames = [entry for entry in os.listdir(src_gene_tfs)]

In [None]:
original_xenium_file_path = main_out + 'script01_original_xenium_data/'
# Path to the original Xenium dataset (raw data + geometry segmentation); only needed by the
# optional, commented-out geometry block inside the loop

written_h5ad_paths = set()
# Output paths written during this run. Lets the loop fail loudly if two input files would be
# saved under the same name (silently overwriting one sample with another).

for folder in adatas_filenames:
    # Each entry of the cleaned gene-probe directory; sample_id is the filename prefix before '_'
    tf_path = os.path.join(src_gene_tfs, folder)

    ### skip anything that is not a regular file (e.g. a sub-directory)
    if not os.path.isfile(tf_path):
        continue

    xe_tf = pd.read_csv(tf_path)
    # One row per transcript

    ### report which probe groups and assignment labels the file contains
    print(f'File contains {xe_tf["group"].unique()}')
    print(f'File contains {xe_tf["binary"].unique()}')

    ### keep only transcripts assigned to a cell
    xe_tf = xe_tf[xe_tf['binary']=='assigned']

    ### drop features whose name contains 'Unassigned' (number of unique features before / after)
    print(f'File contains {xe_tf["feature_name"].nunique()}')
    xe_tf = xe_tf[~xe_tf['feature_name'].str.contains('Unassigned')]
    print(f'File contains {xe_tf["feature_name"].nunique()}')

    ### per-cell metadata: mean transcript position (x, y, z) and number of transcripts
    centroids = (
        xe_tf.groupby('cell_id')[['x_location', 'y_location', 'z_location']]
        .mean()
        .reset_index()
        .rename(columns={'x_location':'x', 'y_location':'y', 'z_location':'z'})
    )
    cell_gene_counts = xe_tf.groupby('cell_id').size().reset_index(name='n_counts')

    ### cell x feature count matrix, and number of distinct features detected per cell
    gene_mtx = pd.pivot_table(xe_tf[['cell_id', 'feature_name']], index='cell_id',
                              columns='feature_name', aggfunc='size', fill_value=0)
    n_genes = (gene_mtx > 0).sum(axis=1).reset_index(name='n_genes')
    gene_mtx = gene_mtx.rename_axis('', axis=1)  # drop the 'feature_name' column-axis label

    obs = (
        centroids.merge(cell_gene_counts, on='cell_id')
        .merge(n_genes, on='cell_id')
        .set_index('cell_id')
        .rename_axis('', axis=0)
    )

    ### create the AnnData object (matrix rows aligned to obs, stored sparse)
    gene_mtx = gene_mtx.reindex(obs.index)
    adata = sc.AnnData(X=csr_matrix(gene_mtx), obs=obs, var=pd.DataFrame(index=gene_mtx.columns))

    adata.var['feature_names'] = adata.var.index
    adata.var_names = adata.var.index
    sample_id = folder.split('_')[0]
    adata.obs['sample_id'] = sample_id
    adata.uns['sample_id'] = sample_id

    adata.obs.index = adata.obs.index.astype(str)      # cell ids as strings
    adata.obs.index.name = 'cell_id'
    adata.var.index = adata.var.index.astype(str)      # feature names as strings
    adata.var.index.name = 'feature_name'

    # --- (Optional geometry) ---
    #nuc_geom = pd.read_parquet([... nucleus segmentation file ...])
    #cyto_geom = pd.read_parquet([... cytoplasm segmentation file ...])
    #adata.uns['xe_nuc_polygon'] = nuc_geom
    #adata.uns['xe_cyto_polygon'] = cyto_geom

    ### save one file per input so samples do not overwrite each other
    ### (same naming scheme as the control-probe objects)
    out_path = out_obj_destdir + f'{"_".join(folder.split("_")[0:2])}_adata.h5ad'
    if out_path in written_h5ad_paths:
        raise ValueError(f'{folder!r} maps to an output file already written in this run: {out_path!r}')
    written_h5ad_paths.add(out_path)
    adata.write_h5ad(out_path)

    # Free per-sample intermediates before the next file
    del adata, gene_mtx, cell_gene_counts, n_genes, centroids, xe_tf
    gc.collect()

In [None]:
### QC: genes detected vs total transcripts per cell, pooled over every gene-probe sample
fig, axes = plt.subplots(1, 3, figsize=(21, 6.5))  # scatter + two histograms

# Per-sample summary statistics, one entry per file. They set the reference lines below and
# are written to disk by the summary cell.
median_gene_counts = []   # median n_genes
median_n_counts = []      # median n_counts
q99_gene_counts = []      # 99th percentile of n_genes
q99_n_counts = []         # 99th percentile of n_counts
qpt01_gene_counts = []    # 1st percentile of n_genes (low-quality cutoff candidates)
qpt01_n_counts = []       # 1st percentile of n_counts (low-quality cutoff candidates)

# Only .h5ad files, sorted so the per-sample lists come out in a reproducible order
# (os.listdir order is arbitrary and the directory may contain other entries)
adatas_filenames = sorted(name for name in os.listdir(out_obj_destdir) if name.endswith('.h5ad'))
#adatas_filenames = adatas_filenames[3:6]       # (Optional) Subset files for quick testing
#adatas_filenames = [file for file in adatas_filenames if '45' not in file]  # (Optional) Filter
if not adatas_filenames:
    raise FileNotFoundError(f'No .h5ad files found in {out_obj_destdir}')

for filename in adatas_filenames:
    cell_obs = sc.read_h5ad(out_obj_destdir+f'{filename}').obs  # only per-cell metadata is used here

    # Panel 1: total transcripts vs genes detected (one point per cell), rug marks on the margins
    sns.scatterplot(x='n_counts', y='n_genes', data=cell_obs, ax=axes[0], s=1)
    sns.rugplot(x='n_counts', y='n_genes', data=cell_obs, ax=axes[0], alpha=0.1, clip_on=False, lw=0.1)

    # Panels 2 and 3: distributions of transcripts per cell and genes per cell
    sns.histplot(cell_obs['n_counts'], ax=axes[1])
    sns.histplot(cell_obs['n_genes'], ax=axes[2])

    median_gene_counts.append(cell_obs['n_genes'].median())
    median_n_counts.append(cell_obs['n_counts'].median())
    q99_gene_counts.append(cell_obs['n_genes'].quantile(0.99))
    q99_n_counts.append(cell_obs['n_counts'].quantile(0.99))
    qpt01_gene_counts.append(cell_obs['n_genes'].quantile(0.01))
    qpt01_n_counts.append(cell_obs['n_counts'].quantile(0.01))

    del cell_obs  # free memory per iteration

for ax, title, xlabel, ylabel in zip(
    axes,
    ('Genes vs transcripts per cell', 'Transcripts per cell', 'Genes per cell'),
    ('Total number of transcripts', 'Total number of transcripts', 'Number of genes'),
    ('Number of genes', 'Frequency', 'Frequency'),
):
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

# Pooled thresholds across samples
median_of_medians_gene_counts = np.median(median_gene_counts)  # robust centre for n_genes
median_of_medians_n_counts = np.median(median_n_counts)        # robust centre for n_counts
mean_q99_gene_counts = np.mean(q99_gene_counts)                # typical high end for n_genes
mean_q99_n_counts = np.mean(q99_n_counts)                      # typical high end for n_counts

# Reference lines: red dashed = median of per-sample medians, green dashed = mean of per-sample 99th pct
axes[0].axhline(median_of_medians_gene_counts, color='red', linestyle='--')
axes[0].axvline(median_of_medians_n_counts, color='red', linestyle='--')
axes[0].axhline(mean_q99_gene_counts, color='green', linestyle='--')
axes[0].axvline(mean_q99_n_counts, color='green', linestyle='--')

axes[1].axvline(median_of_medians_n_counts, color='red', linestyle='--')
axes[1].axvline(mean_q99_n_counts, color='green', linestyle='--')

axes[2].axvline(median_of_medians_gene_counts, color='red', linestyle='--')
axes[2].axvline(mean_q99_gene_counts, color='green', linestyle='--')

plt.tight_layout()
sns.despine()

plt.savefig(out_fig_destdir+'figure02_gene_probes_qc_01.png', dpi=300, bbox_inches='tight')

In [None]:
### per-sample distribution of transcripts per cell (log scale) with 1st / 99th percentile markers
fig, axes = plt.subplots(1, 1, figsize=(22, 4.5))  # one wide panel; each sample adds one violin

# Only .h5ad files, sorted so each sample gets the same palette colour on every run
adatas_filenames = sorted(name for name in os.listdir(out_obj_destdir) if name.endswith('.h5ad'))
#adatas_filenames = adatas_filenames[3:6]       # (Optional) subset for quick testing
#adatas_filenames = [file for file in adatas_filenames if '45' not in file]  # (Optional) filter

for i, filename in enumerate(adatas_filenames):
    adata = sc.read_h5ad(out_obj_destdir+f'{filename}')

    color = colors_palette[i % len(colors_palette)]   # wrap around instead of IndexError past the palette
    sample_id = adata.obs['sample_id'].iloc[0]        # positional access: obs is indexed by string cell ids

    q99 = adata.obs['n_counts'].quantile(0.99)        # 99th percentile of transcripts/cell
    qpt01 = adata.obs['n_counts'].quantile(0.01)      # 1st percentile of transcripts/cell

    # One violin per sample; log scale because n_counts is heavy-tailed
    sns.violinplot(
        x=adata.obs['sample_id'],      # same value for every cell in this file
        y=adata.obs['n_counts'],
        ax=axes,
        inner='stick',
        inner_kws={'linewidth': 0.1},
        color=color,
        cut=0,                         # do not extend the density beyond the observed range
        log_scale=True,
    )

    # Mark the 99th and 1st percentiles for this sample
    sns.pointplot(x=[sample_id], y=[q99], ax=axes, markers='o', color=color)
    sns.pointplot(x=[sample_id], y=[qpt01], ax=axes, markers='o', color=color)

    del adata  # free memory per iteration

axes.set_title('Total transcripts per cell by sample (markers: 1st and 99th percentiles)')
axes.set_xlabel('Samples')
axes.set_ylabel('Total transcripts per cell')

plt.tight_layout()
sns.despine()

plt.savefig(out_fig_destdir+'figure02_gene_probes_qc_02.png', dpi=300, bbox_inches='tight')

In [None]:
### write ranges of per-sample medians, 99th and 1st percentiles (plus pooled values) to a text file
# Inputs are the per-sample lists filled by the gene-probe QC figure cell. Other cells reuse some
# of these names, so check the lists are still intact before summarising.
_per_sample_stats = [median_n_counts, median_gene_counts, q99_n_counts, q99_gene_counts,
                     qpt01_n_counts, qpt01_gene_counts]
if (not all(isinstance(stat, list) and stat for stat in _per_sample_stats)
        or len({len(stat) for stat in _per_sample_stats}) != 1):
    raise ValueError('Per-sample QC lists are missing, empty, overwritten or of unequal length; '
                     're-run the gene-probe QC figure cell before this one.')

save_file_path = main_out+'script01b_output_summary_files/'
os.makedirs(save_file_path, exist_ok=True)

# All ranges are stored as [max, min] across samples
range_median_n_counts = [np.max(median_n_counts), np.min(median_n_counts)]
range_median_gene_counts = [np.max(median_gene_counts), np.min(median_gene_counts)]
range_q99_n_counts = [np.max(q99_n_counts), np.min(q99_n_counts)]
range_q99_gene_counts = [np.max(q99_gene_counts), np.min(q99_gene_counts)]
range_qpt01_n_counts = [np.max(qpt01_n_counts), np.min(qpt01_n_counts)]        # low-quality cutoff candidates
range_qpt01_gene_counts = [np.max(qpt01_gene_counts), np.min(qpt01_gene_counts)]

# Median across samples of the 1st-percentile values. Note: "_q1" means the 1st percentile here,
# not the first quartile. The n_counts value is the lower cutoff used by the filtering cell.
median_of_medians_n_counts_q1 = np.median(qpt01_n_counts)
median_of_medians_gene_counts_q1 = np.median(qpt01_gene_counts)

summary_lines = [
    f'Median of medians of gene counts: {median_of_medians_gene_counts}',
    f'Median of medians of n counts: {median_of_medians_n_counts}',
    f'Mean of q99 of gene counts: {mean_q99_gene_counts}',
    f'Mean of q99 of n counts: {mean_q99_n_counts}',
    f'Range of median of gene counts: {range_median_gene_counts}',
    f'Range of median of n counts: {range_median_n_counts}',
    f'Range of q99 of gene counts: {range_q99_gene_counts}',
    f'Range of q99 of n counts: {range_q99_n_counts}',
    f'Range of qpt01 of gene counts: {range_qpt01_gene_counts}',
    f'Range of qpt01 of n counts: {range_qpt01_n_counts}',
    # Also record the pooled 1st-percentile values, since they drive the low-count filter
    f'Median of qpt01 of gene counts: {median_of_medians_gene_counts_q1}',
    f'Median of qpt01 of n counts: {median_of_medians_n_counts_q1}',
]

with open(save_file_path+'gene_counts_summary.txt', 'w') as f:
    f.writelines(line + '\n' for line in summary_lines)

In [None]:
# Inspect the low-end QC values. Jupyter only displays a cell's LAST expression, so the four
# values are gathered into one dict to show them all at once.
{
    # [max, min] of the per-sample 1st-percentile n_genes (spread of the lowest gene counts)
    'range_qpt01_gene_counts': range_qpt01_gene_counts,
    # [max, min] of the per-sample 1st-percentile n_counts (spread of the lowest transcript counts)
    'range_qpt01_n_counts': range_qpt01_n_counts,
    # median across samples of the 1st-percentile n_genes
    'median_of_medians_gene_counts_q1': median_of_medians_gene_counts_q1,
    # median across samples of the 1st-percentile n_counts (the filtering cell's lower cutoff)
    'median_of_medians_n_counts_q1': median_of_medians_n_counts_q1,
}

In [None]:
### Per-sample QC filtering, normalisation, log1p, scaling and PCA of the gene-probe objects.
# Uses `median_of_medians_n_counts_q1` (summary cell) as the global lower n_counts cutoff.

out_obj_destdir_processed = main_out+'script01b_output_objects_processed/'
# Processed objects get their own directory and one file per input. This stops them from
# overwriting each other and from being re-read as raw inputs when `out_obj_destdir` is listed.
os.makedirs(out_obj_destdir_processed, exist_ok=True)

# Only .h5ad files, in a stable order
adatas_filenames = sorted(name for name in os.listdir(out_obj_destdir) if name.endswith('.h5ad'))
#adatas_filenames = adatas_filenames[3:6]                        # (Optional) subset for quick testing
#adatas_filenames = [file for file in adatas_filenames if '45' not in file]  # (Optional) filter by name

for folder in adatas_filenames:
    adata = sc.read_h5ad(out_obj_destdir+f'{folder}')

    # Per-sample thresholds. Deliberately not named q99_n_counts / median_n_counts: those globals
    # hold the per-sample lists read by the summary cell, and reassigning them broke it on re-run.
    sample_q99_n_counts = adata.obs['n_counts'].quantile(0.99)   # upper outlier cutoff
    sample_median_n_counts = adata.obs['n_counts'].median()      # normalisation target (pre-filter median)

    ### QC filtering (shape logged before and after each step)
    print(adata.shape)
    adata = adata[adata.obs['n_counts'] < sample_q99_n_counts, :]             # drop top 1% high-count cells
    print(adata.shape)
    adata = adata[adata.obs['n_counts'] > median_of_medians_n_counts_q1, :]   # drop very low-count cells
    print(adata.shape)
    adata = adata[adata.obs['n_genes'] > 3, :].copy()                         # >3 detected genes; materialise the view
    print(adata.shape)

    ### keep the filtered counts before transforming X
    adata.raw = adata.copy()
    adata.layers["counts"] = adata.X.copy()

    ### normalisation, log1p transformation, scaling and PCA
    #adata = sc.pp.subsample(adata, fraction=0.001, copy=True) # (Optional) downsample cells for speed
    sc.pp.normalize_total(adata, target_sum=sample_median_n_counts)
    sc.pp.log1p(adata)
    sc.pp.scale(adata)
    sc.pp.pca(adata, n_comps=21)
    #sc.pp.neighbors(adata, n_neighbors=25)                   # (Optional) build KNN graph
    #sc.tl.umap(adata, min_dist=0.1, spread=1.0)              # (Optional) UMAP embedding
    #sc.tl.leiden(adata, resolution=.9)                       # (Optional) Leiden clustering

    out_name = os.path.splitext(folder)[0] + '_processed.h5ad'
    adata.write_h5ad(out_obj_destdir_processed + out_name)

    del adata
    gc.collect()

In [None]:
ct_files = os.listdir(src_control_tfs)
# Every entry of the cleaned control-probe transcript directory; non-files are skipped below

original_xenium_file_path = main_out + 'script01_original_xenium_data/'
# Original Xenium data location; only relevant to the commented-out geometry block below

for folder in ct_files:
    csv_path = src_control_tfs + f'/{folder}'
    if not os.path.isfile(csv_path):
        continue  # ignore sub-directories and other non-file entries

    xe_tf = pd.read_csv(csv_path)

    # Report which probe groups and assignment labels this control file contains
    print(f'File contains {xe_tf["group"].unique()}')
    print(f'File contains {xe_tf["binary"].unique()}')

    # Keep assigned transcripts only, then drop features whose name mentions 'Unassigned'
    # (number of unique features printed before and after)
    xe_tf = xe_tf.loc[xe_tf['binary'] == 'assigned']
    print(f'File contains {xe_tf["feature_name"].nunique()}')
    xe_tf = xe_tf.loc[~xe_tf['feature_name'].str.contains('Unassigned')]
    print(f'File contains {xe_tf["feature_name"].nunique()}')

    # Per-cell metadata: mean transcript position renamed to x/y/z, plus transcripts per cell
    centroids = (
        xe_tf.groupby('cell_id')[['x_location', 'y_location', 'z_location']]
        .mean()
        .rename(columns={'x_location': 'x', 'y_location': 'y', 'z_location': 'z'})
        .reset_index()
    )
    cell_gene_counts = xe_tf.groupby('cell_id').size().reset_index(name='n_counts')

    # Cell x feature count matrix and number of distinct features detected per cell
    gene_mtx = pd.pivot_table(
        xe_tf[['cell_id', 'feature_name']],
        index='cell_id',
        columns='feature_name',
        aggfunc='size',
        fill_value=0,
    )
    n_genes = gene_mtx.gt(0).sum(axis=1).reset_index(name='n_genes')
    gene_mtx = gene_mtx.rename_axis('', axis=1)

    obs = (
        centroids.merge(cell_gene_counts, on='cell_id')
        .merge(n_genes, on='cell_id')
        .set_index('cell_id')
        .rename_axis('', axis=0)
    )

    # Align matrix rows with obs, then build a sparse AnnData
    gene_mtx = gene_mtx.reindex(obs.index)
    csr_mtx = csr_matrix(gene_mtx)
    adata = sc.AnnData(X=csr_mtx, obs=obs, var=pd.DataFrame(index=gene_mtx.columns))

    adata.var['feature_names'] = adata.var.index
    adata.var_names = adata.var.index
    sample_id = folder.split('_')[0]          # filename prefix before the first '_'
    adata.obs['sample_id'] = sample_id
    adata.uns['sample_id'] = sample_id

    adata.obs.index = adata.obs.index.astype(str)
    adata.obs.index.name = 'cell_id'
    adata.var.index = adata.var.index.astype(str)
    adata.var.index.name = 'feature_name'

    #nuc_geom = pd.read_parquet([... nucleus segmentation parquet ...])               # (Optional) nuclei polygons
    #cyto_geom = pd.read_parquet([... cell segmentation parquet ...])                 # (Optional) cell polygons
    #adata.uns['xe_nuc_polygon'] = nuc_geom
    #adata.uns['xe_cyto_polygon'] = cyto_geom

    # One file per input, named from the first two '_'-separated tokens of the filename
    out_name = '_'.join(folder.split('_')[:2]) + '_adata.h5ad'
    adata.write_h5ad(out_obj_destdir_control + out_name)

    del adata, gene_mtx, csr_mtx, cell_gene_counts, n_genes, centroids, xe_tf, csv_path, sample_id, out_name
    gc.collect()

In [None]:
### Control-probe QC: genes detected vs total transcripts per cell, pooled over all control samples
fig, axes = plt.subplots(1, 3, figsize=(18, 4.5))
# Panel 0: transcripts vs genes scatter; panel 1: transcripts histogram; panel 2: genes histogram

# Per-sample statistics (one entry per control file), reset here and refilled below.
# NOTE(review): these reuse the gene-probe list names read by the summary cell, so re-running that
# cell after this one would summarise control probes instead of gene probes.
median_gene_counts, median_n_counts = [], []
q99_gene_counts, q99_n_counts = [], []

panel_labels = [
    ('Total number of transcripts', 'Number of genes'),
    ('Total number of transcripts', 'Frequency'),
    ('Number of genes', 'Frequency'),
]

adata_ct_files = os.listdir(out_obj_destdir_control)

for filename in adata_ct_files:
    adata = sc.read_h5ad(out_obj_destdir_control+f'{filename}')
    cell_obs = adata.obs

    sns.scatterplot(x='n_counts', y='n_genes', data=cell_obs, ax=axes[0], s=9)
    sns.histplot(cell_obs['n_counts'], ax=axes[1])
    sns.histplot(cell_obs['n_genes'], ax=axes[2])
    for panel, (x_label, y_label) in zip(axes, panel_labels):
        panel.set_xlabel(x_label)
        panel.set_ylabel(y_label)

    cell_n_genes, cell_n_counts = cell_obs['n_genes'], cell_obs['n_counts']
    median_gene_counts.append(cell_n_genes.median())
    median_n_counts.append(cell_n_counts.median())
    q99_gene_counts.append(cell_n_genes.quantile(0.99))
    q99_n_counts.append(cell_n_counts.quantile(0.99))

    del adata, cell_obs, cell_n_genes, cell_n_counts

# Pooled thresholds over all control files
median_of_medians_gene_counts = np.median(median_gene_counts)
median_of_medians_n_counts = np.median(median_n_counts)
mean_q99_gene_counts = np.mean(q99_gene_counts)
mean_q99_n_counts = np.mean(q99_n_counts)

median_line = {'color': 'red', 'linestyle': '--'}   # red dashed: median of per-sample medians
q99_line = {'color': 'green', 'linestyle': '--'}    # green dashed: mean of per-sample 99th percentiles

axes[0].axhline(median_of_medians_gene_counts, **median_line)
axes[0].axvline(median_of_medians_n_counts, **median_line)
axes[0].axhline(mean_q99_gene_counts, **q99_line)
axes[0].axvline(mean_q99_n_counts, **q99_line)

axes[1].axvline(median_of_medians_n_counts, **median_line)
axes[1].axvline(mean_q99_n_counts, **q99_line)

axes[2].axvline(median_of_medians_gene_counts, **median_line)
axes[2].axvline(mean_q99_gene_counts, **q99_line)

plt.tight_layout()
sns.despine()

plt.savefig(out_fig_destdir+'figure02_control_probes_qc_01.png', dpi=300, bbox_inches='tight')