# Path-Driven Xenium Analysis Notebook

This notebook replicates the core analysis from `/Users/chrislangseth/work/spatialist/intership1/00-batch-processing.ipynb` and is designed so you only need to change the dataset path.

## How to use
1. Open this notebook in Jupyter Lab or VS Code.
2. Run cells from top to bottom.
3. In **Step 1**, set `DATA_DIR` to your new dataset folder and `OUT_DIR` to a writable output folder.
4. Check generated outputs under `OUT_DIR`.

## Expected input folder layout
- `DATA_DIR/`
  - `output-.../`
    - `cell_feature_matrix.h5`
    - `cells.csv.gz`

If your run folder names differ, adjust `RUN_PREFIX` in Step 1.

## Step 0: Imports and plotting setup

This step loads the analysis libraries (`scanpy`, `pandas`, `seaborn`, etc.) and configures plotting.

In [None]:
import warnings
warnings.filterwarnings('ignore')

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import seaborn as sns
from scipy import sparse

plt.rcParams['figure.dpi'] = 120
sns.set_style('whitegrid')

## Step 1: Configure your paths and parameters

- `DATA_DIR`: your new dataset path.
- `OUT_DIR`: where all outputs will be saved.
- `RUN_PREFIX`: run folder naming pattern.
- `SAMPLE_ID_SPLIT` and `SAMPLE_ID_INDEX`: control how `sample_id` is extracted from run folder names.
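
Before choosing `SAMPLE_ID_INDEX`, it can help to list every part of a run folder name together with its index. This is a minimal sketch: the helper `preview_run_name_parts` and the folder name are hypothetical, for illustration only.

```python
def preview_run_name_parts(run_name: str, split_token: str = '__') -> list[tuple[int, str]]:
    """Return (index, part) pairs so you can see which index holds the sample ID."""
    return list(enumerate(run_name.split(split_token)))


# Hypothetical run folder name; replace it with one of your own folder names.
example_run = 'output-INSTRUMENT__SLIDE01__sampleA__20240101'
for idx, part in preview_run_name_parts(example_run):
    print(idx, part)
# 0 output-INSTRUMENT
# 1 SLIDE01
# 2 sampleA
# 3 20240101
```

With this naming scheme, `SAMPLE_ID_INDEX = 2` would select `sampleA`.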

In [None]:
# REQUIRED: change this to your new dataset location.
DATA_DIR = Path('/absolute/path/to/new/dataset').expanduser().resolve()

# REQUIRED: change this to a writable output directory (the placeholder path will not work as-is).
OUT_DIR = Path('/absolute/path/to/output').expanduser().resolve()

# Run folder prefix (reference workflow used output-*).
RUN_PREFIX = 'output-'

# Extract sample_id from run name, e.g. run.split('__')[2]
SAMPLE_ID_SPLIT = '__'
SAMPLE_ID_INDEX = 2

# Clustering / preprocessing parameters (same spirit as reference notebook)
MIN_COUNTS = 50
MIN_GENES = 15
TARGET_SUM = 100
N_NEIGHBORS = 15
N_PCS = 30
UMAP_MIN_DIST = 0.1
LEIDEN_RESOLUTIONS = [0.1, 0.5, 1, 1.5, 2]

DATA_OUT_DIR = OUT_DIR / 'data'
QC_DIR = OUT_DIR / 'xenium_qc'
DATA_OUT_DIR.mkdir(parents=True, exist_ok=True)
QC_DIR.mkdir(parents=True, exist_ok=True)

print(f'DATA_DIR: {DATA_DIR}')
print(f'OUT_DIR : {OUT_DIR}')

## Step 2: Discover and validate run folders

This step finds all run folders (for example `output-*`) and checks that each one contains:
- `cell_feature_matrix.h5`
- `cells.csv.gz`

Runs missing either file are skipped with a message; the cell stops with an error if no valid run remains.
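
The discovery logic can be tried on a throwaway directory before pointing it at real data. This is a sketch under stated assumptions: `find_valid_runs` is a hypothetical helper that mirrors the prefix and required-file checks, and the run names are made up.

```python
import tempfile
from pathlib import Path

REQUIRED_FILES = ('cell_feature_matrix.h5', 'cells.csv.gz')


def find_valid_runs(data_dir: Path, run_prefix: str = 'output-') -> tuple[list[Path], list[Path]]:
    """Split run folders into (valid, skipped) depending on whether all required files exist."""
    runs = sorted(p for p in data_dir.iterdir() if p.is_dir() and p.name.startswith(run_prefix))
    valid = [r for r in runs if all((r / f).exists() for f in REQUIRED_FILES)]
    skipped = [r for r in runs if r not in valid]
    return valid, skipped


# Demo on a temporary directory with hypothetical run folders.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    complete = root / 'output-runA'
    incomplete = root / 'output-runB'
    complete.mkdir()
    incomplete.mkdir()
    (root / 'notes').mkdir()  # ignored: does not start with the prefix
    for name in REQUIRED_FILES:
        (complete / name).touch()
    (incomplete / 'cells.csv.gz').touch()  # matrix file deliberately missing
    valid, skipped = find_valid_runs(root)
    print([r.name for r in valid], [r.name for r in skipped])
    # ['output-runA'] ['output-runB']
```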

In [None]:
if not DATA_DIR.is_dir():
    raise NotADirectoryError(f'Invalid DATA_DIR: {DATA_DIR}')

runs = [p for p in sorted(DATA_DIR.iterdir()) if p.is_dir() and p.name.startswith(RUN_PREFIX)]
if not runs:
    raise FileNotFoundError(f'No run directories starting with {RUN_PREFIX!r} in {DATA_DIR}')

valid_runs = []
for run in runs:
    h5_path = run / 'cell_feature_matrix.h5'
    cell_info_path = run / 'cells.csv.gz'
    if h5_path.exists() and cell_info_path.exists():
        valid_runs.append(run)
    else:
        print(f'Skipping {run.name}: missing required files')

if not valid_runs:
    raise RuntimeError('No valid runs found with required files.')

print(f'Found {len(valid_runs)} valid run(s):')
for r in valid_runs:
    print(' -', r.name)

## Step 3: Load each run and concatenate into one AnnData object

For each run:
1. Read the expression matrix from `cell_feature_matrix.h5`.
2. Read cell metadata from `cells.csv.gz`.
3. Attach metadata to `ad_int.obs`.
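4. Add a `run` column holding the folder name, and append the object to a list.

Finally, all runs are concatenated into `ad`, and a copy of the raw counts is kept in `ad.layers['counts']`.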
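
One way to make the metadata attachment more robust is to align `cells.csv.gz` rows to the matrix by cell ID rather than by row position, and to tag IDs with the run name so they stay unique after concatenation. This is a minimal pandas-only sketch; `align_cell_info`, the `cell_area` column, and the toy IDs are hypothetical.

```python
import pandas as pd


def align_cell_info(matrix_ids, cell_info: pd.DataFrame, run_name: str) -> pd.DataFrame:
    """Return cell metadata in the matrix's row order, tagged with the run and unique IDs."""
    info = cell_info.copy()
    info.index = info.index.astype(str)
    matrix_ids = pd.Index(matrix_ids).astype(str)
    missing = matrix_ids.difference(info.index)
    if len(missing):
        raise ValueError(f'{run_name}: {len(missing)} matrix cells have no metadata row')
    info = info.loc[matrix_ids]        # reorder metadata rows to match the matrix
    info['cell_id'] = info.index       # keep the original per-run cell ID
    info['run'] = run_name
    # Suffix the index with the run name so IDs stay unique across runs.
    info.index = [f'{cid}-{run_name}' for cid in info.index]
    return info


# Toy metadata: the same cell ID 'c1' appears in two runs.
run_a = pd.DataFrame({'cell_area': [30.0, 42.0]}, index=['c2', 'c1'])
run_b = pd.DataFrame({'cell_area': [25.0]}, index=['c1'])
merged = pd.concat([
    align_cell_info(['c1', 'c2'], run_a, 'runA'),
    align_cell_info(['c1'], run_b, 'runB'),
])
print(merged)
```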

In [None]:
ad_list = []

for run in valid_runs:
    print(f'Loading: {run.name}')
    ad_int = sc.read_10x_h5(str(run / 'cell_feature_matrix.h5'))
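    cell_info = pd.read_csv(run / 'cells.csv.gz', index_col=0)
    cell_info.index = cell_info.index.astype(str)

    if len(cell_info) != ad_int.n_obs:
        raise ValueError(
            f'Row mismatch for {run.name}: cells.csv.gz={len(cell_info)} vs matrix={ad_int.n_obs}'
        )

    # Align metadata rows to matrix rows by cell ID when the IDs match;
    # otherwise fall back to positional assignment (same row order assumed).
    if set(cell_info.index) == set(ad_int.obs_names):
        cell_info = cell_info.loc[ad_int.obs_names]

    ad_int.obs = cell_info
    ad_int.obs['run'] = run.name
    ad_list.append(ad_int)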

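# index_unique appends '-<run position>' to each cell ID so IDs stay unique across runs.
ad = sc.concat(ad_list, index_unique='-')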
ad.layers['counts'] = ad.X.copy()

print(ad)

## Step 4: Compute QC metrics and create `sample_id`

- `sc.pp.calculate_qc_metrics` adds standard QC columns such as `total_counts` and `n_genes_by_counts`.
- `sample_id` is parsed from the `run` name; if parsing fails, the run name itself is used.
- `cell_id` is copied from the cell index.
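
For intuition, the two QC columns used later can be reproduced by hand on a toy dense matrix: `total_counts` is the per-cell row sum and `n_genes_by_counts` is the number of nonzero entries per row. This sketch assumes a small dense NumPy array (`basic_qc` is a hypothetical helper) and is not how scanpy computes them internally.

```python
import numpy as np


def basic_qc(counts: np.ndarray) -> dict[str, np.ndarray]:
    """Per-cell QC metrics from a cells x genes count matrix (dense)."""
    counts = np.asarray(counts)
    return {
        'total_counts': counts.sum(axis=1),             # sum of counts per cell
        'n_genes_by_counts': (counts > 0).sum(axis=1),  # genes with at least one count
    }


toy = np.array([[0, 3, 1],
                [5, 0, 0],
                [2, 2, 2]])
qc = basic_qc(toy)
print(qc['total_counts'])       # [4 5 6]
print(qc['n_genes_by_counts'])  # [2 1 3]
```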

In [None]:
sc.pp.calculate_qc_metrics(ad, percent_top=None, log1p=False, inplace=True)

def infer_sample_id(run_name: str, split_token: str = SAMPLE_ID_SPLIT, split_index: int = SAMPLE_ID_INDEX) -> str:
    parts = str(run_name).split(split_token)
    if 0 <= split_index < len(parts):
        return parts[split_index]
    return str(run_name)

ad.obs['sample_id'] = ad.obs['run'].astype(str).apply(infer_sample_id)
ad.obs['cell_id'] = ad.obs.index.astype(str)

display(ad.obs[['run', 'sample_id']].head())

## Step 5: Save raw AnnData

This saves the merged, unfiltered object so you can reuse it later without reloading all runs.

In [None]:
raw_path = DATA_OUT_DIR / 'raw.h5ad'
ad.write(raw_path)
print(f'Saved: {raw_path}')

## Step 6: Generate QC tables and plots

This creates the same QC-style outputs as the reference workflow, with runs labelled by their parsed `sample_id`:
- `summary_by_run.csv`
- Cells-per-run bar plot
- Violin plots (`n_genes_by_counts`, `total_counts`)
- Counts-vs-genes hexbin
- Top-20 cell-type bar plot, if a `cell_types` column exists
- Gene detection outputs (`gene_detection_overall.csv`, top-30 bar plot, per-run heatmap)
- Optional cell area violin plot, if an area column exists
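
The per-run detection heatmap boils down to "fraction of cells with a nonzero count, per run and gene", followed by picking the genes whose detection varies most between runs. A minimal sketch on a toy dense matrix; `detection_by_group` is a hypothetical helper, and large Xenium matrices would normally stay sparse rather than be densified like this.

```python
import numpy as np
import pandas as pd


def detection_by_group(counts, groups, gene_names) -> pd.DataFrame:
    """Fraction of cells with a nonzero count, per group (rows) and gene (columns)."""
    detected = pd.DataFrame((np.asarray(counts) > 0).astype(float), columns=list(gene_names))
    return detected.groupby(np.asarray(groups)).mean()


counts = np.array([[1, 0], [0, 0], [3, 2], [0, 1]])
groups = ['s1', 's1', 's2', 's2']
det = detection_by_group(counts, groups, ['GeneA', 'GeneB'])
# s1: GeneA 0.5, GeneB 0.0 | s2: GeneA 0.5, GeneB 1.0
most_variable = det.var(axis=0).sort_values(ascending=False).index
print(det)
print(list(most_variable))  # ['GeneB', 'GeneA']
```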

In [None]:
RUN_COL = 'run'
SAMPLE_COL = 'sample_id'
CT_COL = 'cell_types'
COUNTS_COL = 'total_counts'
NGENES_COL = 'n_genes_by_counts'

def _p(series, p):
    return float(np.nanpercentile(series, p))

agg_dict = {'n_cells': ('cell_id', 'count')}
if COUNTS_COL in ad.obs.columns:
    agg_dict |= {
        'counts_mean': (COUNTS_COL, 'mean'),
        'counts_median': (COUNTS_COL, 'median'),
        'counts_p10': (COUNTS_COL, lambda x: _p(x, 10)),
        'counts_p90': (COUNTS_COL, lambda x: _p(x, 90)),
    }
if NGENES_COL in ad.obs.columns:
    agg_dict |= {
        'genes_mean': (NGENES_COL, 'mean'),
        'genes_median': (NGENES_COL, 'median'),
        'genes_p10': (NGENES_COL, lambda x: _p(x, 10)),
        'genes_p90': (NGENES_COL, lambda x: _p(x, 90)),
    }

summary = ad.obs.groupby(SAMPLE_COL).agg(**agg_dict).sort_values('n_cells', ascending=False)
display(summary)
summary.to_csv(QC_DIR / 'summary_by_run.csv')

plt.figure(figsize=(9, 4.5))
sns.barplot(y=summary.index, x=summary['n_cells'], palette='Set3')
plt.title('Cells per Xenium run')
plt.xlabel('# cells')
plt.ylabel('Run')
plt.tight_layout()
plt.savefig(QC_DIR / 'cells_per_run_bar.png', dpi=200)
plt.show()

if NGENES_COL in ad.obs.columns:
    plt.figure(figsize=(10, 4.5))
    sns.violinplot(data=ad.obs, x=SAMPLE_COL, y=NGENES_COL, inner='quartile', palette='rocket')
    plt.title('n_genes_by_counts per run')
    plt.xlabel('Run')
    plt.ylabel('n_genes_by_counts')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(QC_DIR / 'ngenes_violin.png', dpi=200)
    plt.show()

if COUNTS_COL in ad.obs.columns:
    plt.figure(figsize=(10, 4.5))
    sns.violinplot(data=ad.obs, x=SAMPLE_COL, y=COUNTS_COL, inner='quartile', palette='mako')
    plt.title('total_counts per run')
    plt.xlabel('Run')
    plt.ylabel('total_counts')
    plt.xticks(rotation=45, ha='right')
    plt.tight_layout()
    plt.savefig(QC_DIR / 'counts_violin.png', dpi=200)
    plt.show()

if {COUNTS_COL, NGENES_COL}.issubset(ad.obs.columns):
    plt.figure(figsize=(6, 5))
    plt.hexbin(ad.obs[COUNTS_COL], ad.obs[NGENES_COL], gridsize=50, mincnt=1)
    plt.xlabel('total_counts')
    plt.ylabel('n_genes_by_counts')
    plt.title('Counts vs genes (all runs)')
    plt.tight_layout()
    plt.savefig(QC_DIR / 'counts_vs_genes_hex.png', dpi=200)
    plt.show()

if CT_COL in ad.obs.columns:
    ct_counts = ad.obs[CT_COL].value_counts()
    plt.figure(figsize=(9, 5))
    sns.barplot(y=ct_counts.index[:20], x=ct_counts.values[:20], palette='Spectral')
    plt.title('Top cell types (all runs)')
    plt.xlabel('# cells')
    plt.ylabel('cell type')
    plt.tight_layout()
    plt.savefig(QC_DIR / 'celltypes_top20.png', dpi=200)
    plt.show()

x = ad.X
if sparse.issparse(x):
    detected = (x > 0).astype(np.int8)
    det_overall = np.array(detected.sum(axis=0)).ravel() / ad.n_obs
else:
    det_overall = np.asarray((x > 0).sum(axis=0)).ravel() / ad.n_obs

det_overall_series = pd.Series(det_overall, index=ad.var_names, name='fraction_cells')
det_overall_series.sort_values(ascending=False).to_csv(QC_DIR / 'gene_detection_overall.csv')

top30 = det_overall_series.sort_values(ascending=False).head(30)
plt.figure(figsize=(8, 5))
sns.barplot(y=top30.index, x=top30.values, palette='coolwarm')
plt.xlabel('fraction of cells detected')
plt.ylabel('gene')
plt.title('Panel coverage: top 30 genes')
plt.tight_layout()
plt.savefig(QC_DIR / 'gene_detection_top30.png', dpi=200)
plt.show()

# Per-run detection fraction; the same expression works for sparse and dense X.
runs_array = ad.obs[SAMPLE_COL].astype(str).values
det_run = {}
for run_name in np.unique(runs_array):
    idx = np.flatnonzero(runs_array == run_name)
    sub = x[idx, :]
    det_run[run_name] = np.asarray((sub > 0).sum(axis=0)).ravel() / len(idx)

if det_run:
    det_df = pd.DataFrame(det_run, index=ad.var_names).T
    sel = det_df.var(axis=0).sort_values(ascending=False).head(40).index
    plt.figure(figsize=(min(12, len(sel) * 0.3 + 4), max(4, len(det_df) * 0.35 + 2)))
    sns.heatmap(det_df[sel], cmap='rocket', vmin=0, vmax=1, cbar_kws={'label': 'fraction detected'})
    plt.title('Gene detection per run (top variable genes)')
    plt.xlabel('gene')
    plt.ylabel('run')
    plt.tight_layout()
    plt.savefig(QC_DIR / 'gene_detection_heatmap_runs.png', dpi=200)
    plt.show()

for cand in ['cell_area_um2', 'cell_area', 'area', 'nucleus_area_um2']:
    if cand in ad.obs.columns:
        plt.figure(figsize=(7, 4))
        sns.violinplot(data=ad.obs, x=SAMPLE_COL, y=cand, inner='quartile', palette='PuBuGn')
        plt.title(f'{cand} per run')
        plt.xlabel('Run')
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()
        plt.savefig(QC_DIR / f'{cand}_violin.png', dpi=200)
        plt.show()
        break

print(f'QC outputs saved to: {QC_DIR}')

## Step 7: Filter, normalize, reduce dimensions, and cluster

This performs the same core downstream processing pattern as the reference notebook. Below is the mathematical purpose of each step:

1. **Filter low-quality cells**
   - Removes cells with very low total counts or few detected genes (below `MIN_COUNTS` / `MIN_GENES`), which mostly add technical noise.
   - This improves signal-to-noise before any geometry-based methods.

2. **Normalize total counts and log-transform**
   - Each cell has a different total number of detected transcripts (library size), so raw counts are not directly comparable between cells.
   - Size-factor normalization rescales each cell to a common total (`target_sum`):
     - If cell `i` has total count `c_i` and raw count `x_ig` for gene `g`, the normalized value is `x_ig * (target_sum / c_i)`.
   - `log1p` then compresses the dynamic range (`log(1 + x)`), reducing the dominance of very highly expressed genes.

3. **PCA (linear dimensionality reduction)**
   - Gene space is very high-dimensional. Distances become unstable in very high dimensions (curse of dimensionality).
   - PCA projects data onto orthogonal directions of maximal variance (top eigenvectors of covariance matrix).
   - Keeping only the top `N_PCS` components tends to denoise the data while retaining most of the dominant biological variation.

4. **kNN graph in PC space (`sc.pp.neighbors`)**
   - For each cell, connect to its `k` nearest neighbors in PCA space.
   - This builds a graph that approximates the local manifold geometry of cell states.
   - Downstream methods (UMAP, Leiden) operate on this graph, not raw expression directly.

5. **UMAP (nonlinear embedding for visualization)**
   - UMAP uses a fuzzy neighbor graph in high-dimensional space (in scanpy, the one built by `sc.pp.neighbors`), then finds a 2D embedding with similar fuzzy connectivity.
   - It optimizes a cross-entropy-like objective between high-D and low-D neighbor probabilities.
   - Main purpose: human-interpretable visualization, not formal statistical inference.

6. **Leiden clustering (graph community detection)**
   - Leiden partitions the neighbor graph into communities with dense within-cluster edges and sparse between-cluster edges.
   - It optimizes a graph quality function (modularity or CPM family) and, unlike the older Louvain method, guarantees that every community is internally connected.
   - Resolution controls granularity: lower = broader groups, higher = finer subclusters.
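
Steps 2 to 4 can be sketched end to end in plain NumPy on toy data: normalize each cell to `target_sum`, apply `log1p`, project onto the top principal axes via SVD, and find each cell's nearest neighbours. This is an illustrative sketch, not scanpy's implementation: the helper names are hypothetical, the kNN search is brute-force and exact, and no UMAP-style edge weights are computed.

```python
import numpy as np


def normalize_log1p(counts: np.ndarray, target_sum: float = 100.0) -> np.ndarray:
    """Scale each cell (row) to `target_sum` total counts, then apply log(1 + x)."""
    counts = np.asarray(counts, dtype=float)
    totals = counts.sum(axis=1, keepdims=True)  # c_i for every cell
    return np.log1p(counts * (target_sum / totals))


def pca_scores(x: np.ndarray, n_pcs: int) -> np.ndarray:
    """Project mean-centred data onto its top `n_pcs` principal axes via SVD."""
    centred = x - x.mean(axis=0)
    u, s, _ = np.linalg.svd(centred, full_matrices=False)
    return u[:, :n_pcs] * s[:n_pcs]  # equals centred @ V[:, :n_pcs]


def knn_indices(scores: np.ndarray, k: int) -> np.ndarray:
    """Brute-force exact k nearest neighbours (Euclidean), excluding each cell itself."""
    diff = scores[:, None, :] - scores[None, :, :]
    d2 = (diff ** 2).sum(axis=-1)
    np.fill_diagonal(d2, np.inf)
    return np.argsort(d2, axis=1)[:, :k]


rng = np.random.default_rng(0)
# Two hypothetical toy "cell types" that express different genes highly.
type_a = rng.poisson(lam=[20, 20, 1, 1, 5], size=(10, 5))
type_b = rng.poisson(lam=[1, 1, 20, 20, 5], size=(10, 5))
counts = np.vstack([type_a, type_b]) + 1  # +1 keeps every toy cell's total above zero

x = normalize_log1p(counts, target_sum=100)
scores = pca_scores(x, n_pcs=2)
neighbours = knn_indices(scores, k=3)
print(neighbours[:3])  # neighbours of cells 0-2 should all lie in rows 0-9 (type A)
```

Leiden and UMAP are omitted here because they need dedicated libraries; in the notebook they run on the graph that `sc.pp.neighbors` builds.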

In [None]:
ad_cluster = ad.copy()

sc.pp.filter_cells(ad_cluster, min_counts=MIN_COUNTS)
sc.pp.filter_cells(ad_cluster, min_genes=MIN_GENES)

sc.pp.normalize_total(ad_cluster, target_sum=TARGET_SUM, inplace=True)
sc.pp.log1p(ad_cluster)

sc.tl.pca(ad_cluster)
sc.pp.neighbors(ad_cluster, n_neighbors=N_NEIGHBORS, n_pcs=N_PCS)
sc.tl.umap(ad_cluster, min_dist=UMAP_MIN_DIST)

for resolution in LEIDEN_RESOLUTIONS:
    key = f'leiden_{resolution}'
    sc.tl.leiden(ad_cluster, resolution=resolution, key_added=key)
    print(f'Computed {key}')

ad_cluster

## Step 8: UMAP quick checks and marker genes

- Plots UMAP colored by `sample_id` and by the highest-resolution Leiden labels (the last entry of `LEIDEN_RESOLUTIONS`).
- Computes marker genes with `rank_genes_groups`.
  - Here we use a t-test per gene (cluster vs rest) to rank genes by differential expression signal.
  - This gives interpretable candidate markers for cluster annotation.
- Saves clustered object and markers table.
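
To see what a cluster-vs-rest t-test ranks, the statistic can be computed per gene with NumPy. This is a sketch of an unequal-variance (Welch) t statistic on toy log-expression values; to our understanding scanpy's `'t-test'` option is Welch-style, but this is not its code and it computes no p-values or multiple-testing correction. `welch_t_vs_rest` and the gene names are hypothetical.

```python
import numpy as np


def welch_t_vs_rest(x: np.ndarray, labels, group) -> np.ndarray:
    """Welch t statistic per gene (column) for cells in `group` vs all other cells."""
    labels = np.asarray(labels)
    in_group = labels == group
    a, b = x[in_group], x[~in_group]
    mean_diff = a.mean(axis=0) - b.mean(axis=0)
    # Unequal-variance standard error: sqrt(var_a / n_a + var_b / n_b)
    se = np.sqrt(a.var(axis=0, ddof=1) / len(a) + b.var(axis=0, ddof=1) / len(b))
    # Genes with zero variance in both groups get t = 0 instead of a division error.
    return np.divide(mean_diff, se, out=np.zeros_like(mean_diff), where=se > 0)


# Toy log-expression: 6 cells x 3 genes, two clusters '0' and '1'.
x = np.array([
    [5.0, 0.1, 1.0],
    [4.0, 0.2, 1.2],
    [4.5, 0.0, 0.9],
    [0.5, 2.0, 1.1],
    [0.2, 2.5, 1.0],
    [0.4, 1.8, 0.8],
])
labels = ['0', '0', '0', '1', '1', '1']
genes = np.array(['GeneA', 'GeneB', 'GeneC'])
t = welch_t_vs_rest(x, labels, '0')
ranked = genes[np.argsort(-t)]
print(ranked)  # GeneA ranks first as the marker of cluster '0'
```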

In [None]:
final_leiden_key = f'leiden_{LEIDEN_RESOLUTIONS[-1]}'

sc.pl.umap(ad_cluster, color=['sample_id', final_leiden_key], wspace=0.35)

sc.tl.rank_genes_groups(ad_cluster, groupby=final_leiden_key, method='t-test')
markers = sc.get.rank_genes_groups_df(ad_cluster, group=None)
display(markers.head())

clustered_path = DATA_OUT_DIR / 'clustered.h5ad'
markers_path = DATA_OUT_DIR / 'markers_by_cluster.csv'

ad_cluster.write(clustered_path)
markers.to_csv(markers_path, index=False)

print(f'Saved clustered AnnData: {clustered_path}')
print(f'Saved marker genes   : {markers_path}')

## Troubleshooting

- If you get `No run directories...`, verify `DATA_DIR` and `RUN_PREFIX`.
- If a row mismatch appears in Step 3, that run's `cell_feature_matrix.h5` and `cells.csv.gz` contain different numbers of cells.
- If `sample_id` values look wrong, adjust `SAMPLE_ID_SPLIT` and `SAMPLE_ID_INDEX`.
- If `scanpy` is missing, install dependencies in your environment before rerunning.