# Atlas Matching: Patch-seq Cells to Allen Brain Atlas

This notebook maps Patch-seq neurons onto the Allen Brain Cell Atlas (ABC) reference
using the `patch_seq_transcriptome_mapping` pipeline. The resulting cell type assignments
are then integrated with the GENCIC genetic bias analysis framework.

## Pipeline Steps
1. Load reference atlas (Allen ABC) and Patch-seq query data
2. Preprocess: gene intersection, HVG selection, normalization
3. Train scVI/scANVI model on reference atlas
4. Map query Patch-seq cells via scArches surgery
5. Hierarchical cell type assignment (subclass + cluster)
6. Compute canonical expression profiles
7. Integrate with GENCIC bias analysis

In [None]:
%load_ext autoreload
%autoreload 2

import sys
import os
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Locations of the project checkout and of the atlas-matching package inside it.
# Edit PROJECT_ROOT so that it points at your own checkout.
PROJECT_ROOT = Path("/home/jw3514/Work/ASD_Circuits_CellType/")
ATLAS_MATCHING_ROOT = PROJECT_ROOT.joinpath("atlas_matching")

# Prepend both source directories to the import search path in one step;
# `src` ends up ahead of the atlas-matching directory.
sys.path[:0] = [str(PROJECT_ROOT / "src"), str(ATLAS_MATCHING_ROOT)]

import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
import matplotlib.pyplot as plt
import seaborn as sns

# Atlas matching modules
from patch_seq_transcriptome_mapping import (
    load_reference_atlas,
    load_patchseq_v1,
    load_patchseq_m1,
    load_ion_channel_genes,
    intersect_genes,
    preprocess_pipeline,
    train_scvi,
    train_scanvi,
    map_query_scarches,
    hierarchical_assignment,
    compute_canonical_pipeline,
)

# GENCIC modules
from ASD_Circuits import *

# Echo the resolved locations so a wrong checkout path is visible right away.
print(
    f"Project root: {PROJECT_ROOT}",
    f"Atlas matching root: {ATLAS_MATCHING_ROOT}",
    sep="\n",
)

## 1. Configuration

Set paths to data directories. The atlas data lives under the CellType_Psy project,
and the Patch-seq data under TransEphys.

In [None]:
# --- Run selection -----------------------------------------------------------
DATASET = "V1"  # Patch-seq dataset to process: "V1" or "M1"
SUBSAMPLE_ATLAS = 50000  # None = use the full atlas (requires ~100GB RAM)

# --- Input locations (edit for your machine) ---------------------------------
ATLAS_DATA_DIR = Path("/home/jw3514/Work/CellType_Psy/AllenBrainCellAtlas/dat/")
PATCHSEQ_DATA_DIR = Path("/home/jw3514/Work/NeurSim/TransEphys/dat/expression/")
# Ion channel gene files are read from the same directory as the Patch-seq data.
ION_CHANNEL_DIR = PATCHSEQ_DATA_DIR
GENCIC_DATA_DIR = PROJECT_ROOT / "dat"

# --- Output location (created on first run) ----------------------------------
RESULTS_DIR = ATLAS_MATCHING_ROOT.joinpath("results")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

print(
    f"Atlas data: {ATLAS_DATA_DIR}",
    f"Patch-seq data: {PATCHSEQ_DATA_DIR}",
    f"Results: {RESULTS_DIR}",
    sep="\n",
)

## 2. Load Data

In [None]:
# Read the reference (Allen Brain Cell Atlas - Isocortex), subsampled to
# SUBSAMPLE_ATLAS cells when that setting is not None.
print("Loading reference atlas...")
adata_ref = load_reference_atlas(ATLAS_DATA_DIR, subsample=SUBSAMPLE_ATLAS)

# Quick sanity summary: matrix shape and the first ten obs columns.
_ref_obs_preview = list(adata_ref.obs.columns[:10])
print(f"Reference atlas: {adata_ref.shape}")
print(f"Obs columns: {_ref_obs_preview}")

In [None]:
# Pick the Patch-seq loader that matches DATASET. Both loaders are called the
# same way and return three values; the third is not used in this notebook.
print(f"\nLoading {DATASET} Patch-seq data...")
_patchseq_loaders = {"V1": load_patchseq_v1, "M1": load_patchseq_m1}
_load_patchseq = _patchseq_loaders.get(DATASET)
if _load_patchseq is None:
    raise ValueError(f"Unknown dataset: {DATASET}")
adata_query, metadata_query, _ = _load_patchseq(PATCHSEQ_DATA_DIR, load_ephys=False)

_query_meta_preview = list(metadata_query.columns[:10])
print(f"Query: {adata_query.shape}")
print(f"Metadata columns: {_query_meta_preview}")

In [None]:
# Ion channel genes are optional (used for biophysical modeling). When the
# gene files are missing, continue with an empty 'all' list instead of failing.
print("\nLoading ion channel genes...")
try:
    ion_channel_genes = load_ion_channel_genes(ION_CHANNEL_DIR, source='both')
except FileNotFoundError:
    print("Ion channel gene files not found, using empty list")
    ion_channel_genes = {'all': []}
else:
    print(f"Ion channel genes: {len(ion_channel_genes['all'])}")

## 3. Preprocess Data

In [None]:
# Run the shared preprocessing (gene intersection, normalization, HVG
# selection, scaling) on reference and query together.
_forced_genes = ion_channel_genes.get('all', [])
results = preprocess_pipeline(
    adata_ref,
    adata_query,
    ion_channel_genes=_forced_genes,
    n_hvgs=3000,
    n_hvgs_query=500,
    ref_is_log2=False,  # the reference holds raw counts
    return_log_normalized=True,
)

# With return_log_normalized=True the pipeline hands back seven objects.
(
    adata_ref_pp,
    adata_query_pp,
    final_genes,
    gene_means,
    gene_stds,
    adata_ref_log,
    adata_query_log,
) = results

_ref_has_counts = adata_ref_pp.layers.get('counts') is not None
print(f"\nPreprocessed reference: {adata_ref_pp.shape}")
print(f"Preprocessed query: {adata_query_pp.shape}")
print(f"Final gene set: {len(final_genes)} genes")
print(f"Counts layer preserved: ref={_ref_has_counts}")

## 4. Train Integration Model

Choose between scVI (unsupervised) and scANVI (semi-supervised, recommended). The cell below defaults to scVI (`USE_SCANVI = False`); set it to `True` to train scANVI instead.

In [None]:
# Model choice: False trains scVI (unsupervised); True trains scANVI
# (semi-supervised, uses the reference 'subclass' labels).
USE_SCANVI = False

MODEL_DIR = RESULTS_DIR / "models"
MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Hyperparameters common to both model types.
_shared_train_kwargs = dict(n_latent=50, n_hidden=256, max_epochs=400, use_gpu=True)

# Both trainers receive the log-normalized (not Z-scored) reference.
if not USE_SCANVI:
    model, adata_ref_trained = train_scvi(adata_ref_log, **_shared_train_kwargs)
    model_name = f"scvi_{DATASET}"
else:
    model, adata_ref_trained = train_scanvi(
        adata_ref_log,
        labels_key='subclass',
        **_shared_train_kwargs,
    )
    model_name = f"scanvi_{DATASET}"

# Persist the trained model; an existing directory of the same name is replaced.
model_path = MODEL_DIR / model_name
model.save(str(model_path), overwrite=True)
print(f"Model saved to: {model_path}")

## 5. Map Query Cells via scArches Surgery

In [None]:
# Project the query cells into the reference latent space. The saved model is
# passed by path (as a string) rather than as the in-memory object.
_surgery_kwargs = dict(
    model=str(model_path),
    adata_ref=adata_ref_trained,
    adata_query=adata_query_log,
    use_gpu=True,
    max_epochs=200,
)
adata_ref_mapped, adata_query_mapped = map_query_scarches(**_surgery_kwargs)

# Both outputs carry their embedding under obsm['X_latent'].
_ref_latent = adata_ref_mapped.obsm['X_latent']
_query_latent = adata_query_mapped.obsm['X_latent']
print(f"Reference latent shape: {_ref_latent.shape}")
print(f"Query latent shape: {_query_latent.shape}")

## 6. Hierarchical Cell Type Assignment

In [None]:
# Two-level assignment in latent space: subclass (coarse) and cluster (fine).
_assignment_kwargs = {
    'latent_key': 'X_latent',
    'subclass_key': 'subclass',
    'cluster_key': 'cluster',
    'k_candidates': 30,
    'k_nn': 20,
    'conf_threshold_subclass': 0.7,
    'conf_threshold_cluster': 0.5,
}
mapping_results = hierarchical_assignment(
    adata_ref_mapped,
    adata_query_mapped,
    **_assignment_kwargs,
)

# Summaries: table size, status breakdown, and the ten most frequent subclasses.
print(f"\nMapping results shape: {mapping_results.shape}")

print("\nMapping status distribution:")
_status_counts = mapping_results['mapping_status'].value_counts()
print(_status_counts)

print("\nTop assigned subclasses:")
_top_subclass_counts = mapping_results['assigned_subclass'].value_counts().head(10)
print(_top_subclass_counts)

In [None]:
# Write the per-cell mapping table under results/scvi_mapped/<model_name>/.
# The "scvi_mapped" folder name is used for scANVI runs as well.
mapped_dir = RESULTS_DIR.joinpath("scvi_mapped", model_name)
mapped_dir.mkdir(parents=True, exist_ok=True)

_mapping_csv = mapped_dir / f"{DATASET}_mapping_results.csv"
mapping_results.to_csv(_mapping_csv)
print(f"Mapping results saved to: {mapped_dir}")

## 7. Compute Canonical Expression Profiles

In [None]:
# Derive canonical expression profiles (per cluster and per subclass) and the
# effective expression of the mapped query cells, using soft assignment.
canonical_cluster, canonical_subclass, effective_expr = compute_canonical_pipeline(
    adata_ref_mapped,
    adata_query_mapped,
    mapping_results,
    use_soft_assignment=True,
    k_nn=20,
)

print(f"\nCluster profiles: {canonical_cluster.shape}")
print(f"Subclass profiles: {canonical_subclass.shape}")
print(f"Query effective expression: {effective_expr.shape}")

# Write each table to results/canonical/<model_name>/<DATASET>_<suffix>.csv.
canonical_dir = RESULTS_DIR.joinpath("canonical", model_name)
canonical_dir.mkdir(parents=True, exist_ok=True)
_canonical_outputs = (
    ("canonical_cluster", canonical_cluster),
    ("canonical_subclass", canonical_subclass),
    ("effective_expression", effective_expr),
)
for _suffix, _table in _canonical_outputs:
    _table.to_csv(canonical_dir / f"{DATASET}_{_suffix}.csv")
print(f"Canonical profiles saved to: {canonical_dir}")

## 8. Visualize Mapping Results

In [None]:
# Side-by-side histograms of assignment confidence, with the threshold used
# at each level drawn as a dashed red line.
fig, axes = plt.subplots(1, 2, figsize=(12, 5), dpi=150)

# (column, threshold, bar colour, panel title) for each panel, left to right.
_confidence_panels = (
    ('conf_subclass', 0.7, 'steelblue', 'Subclass Assignment Confidence'),
    ('conf_cluster', 0.5, 'coral', 'Cluster Assignment Confidence'),
)
for _panel_ax, (_column, _threshold, _colour, _title) in zip(axes, _confidence_panels):
    _panel_ax.hist(
        mapping_results[_column],
        bins=50,
        color=_colour,
        edgecolor='black',
        alpha=0.7,
    )
    _panel_ax.axvline(
        _threshold,
        color='red',
        linestyle='--',
        label=f'Threshold ({_threshold})',
    )
    _panel_ax.set_xlabel('Confidence')
    _panel_ax.set_ylabel('Count')
    _panel_ax.set_title(_title)
    _panel_ax.legend()

plt.tight_layout()
plt.show()

In [None]:
# Horizontal bar chart of the 20 most frequently assigned subclasses,
# largest at the top.
subclass_counts = mapping_results['assigned_subclass'].value_counts().head(20)

fig, ax = plt.subplots(figsize=(10, 6), dpi=150)
subclass_counts.plot.barh(ax=ax, color='steelblue', edgecolor='black')
ax.set(
    xlabel='Number of Cells',
    ylabel='Subclass',
    title=f'{DATASET} Patch-seq Cells: Assigned Subclass Distribution',
)
# barh draws the first row at the bottom; flip so the biggest count is on top.
ax.invert_yaxis()
plt.tight_layout()
plt.show()

## 9. Integration with GENCIC Genetic Bias Analysis

Now connect the Patch-seq cell type assignments to the GENCIC bias framework.
This lets us identify which Patch-seq cell types carry the highest ASD genetic bias.

In [None]:
# Load the GENCIC cluster annotations and the cell type hierarchy table.
ClusterAnn = pd.read_csv(GENCIC_DATA_DIR / "MouseCT_Cluster_Anno.csv", index_col="cluster_id_label")
CellTypesDF = pd.read_csv(GENCIC_DATA_DIR / "CellTypeHierarchy.csv")

# Build class -> clusters and subclass -> clusters lookups.
# The hierarchy table is read positionally and must have exactly four columns,
# in the order (cluster, class, subclass, supertype); any other column count
# raises ValueError on unpacking. itertuples() yields plain tuples, avoiding
# the per-row Series construction and dtype coercion of iterrows().
Class2Cluster = {}
Subclass2Cluster = {}
for _cluster, _class, _subclass, _supertype in CellTypesDF.itertuples(index=False, name=None):
    Class2Cluster.setdefault(_class, []).append(_cluster)
    Subclass2Cluster.setdefault(_subclass, []).append(_cluster)

print(f"Cluster annotations: {ClusterAnn.shape}")
print(f"Cell type classes: {len(Class2Cluster)}")
print(f"Cell type subclasses: {len(Subclass2Cluster)}")

In [None]:
# Load the cluster-level Z2 bias matrix.
# The file sits under the atlas data directory, so the path is derived from
# ATLAS_DATA_DIR rather than hard-coded; changing the configuration cell then
# relocates this input as well. With the default configuration the resolved
# path is the same as before.
_z2_matrix_path = ATLAS_DATA_DIR / "SC_UMI_Mats" / "Cluster_Z2Mat_ISHMatch.z1clip3.csv.gz"
MouseSC_Z2 = pd.read_csv(_z2_matrix_path, index_col=0)
print(f"Z2 bias matrix: {MouseSC_Z2.shape}")

In [None]:
# Gene identifier lookup tables from the GENCIC helpers, followed by the ASD
# gene-weight file, which Fil2Dict reads into ASD_GW.
HGNC, ENSID2Entrez, GeneSymbol2Entrez, Entrez2Symbol = LoadGeneINFO()

_asd_weight_file = GENCIC_DATA_DIR / "Genetics/GeneWeights_DN/Spark_Meta_EWS.GeneWeight.DN.gw"
ASD_GW = Fil2Dict(str(_asd_weight_file))  # the helper is given a plain string path
print(f"ASD gene weights: {len(ASD_GW)} genes")

In [None]:
# Per-cluster ASD bias from the Z2 matrix and the ASD gene weights. The result
# is then passed through add_class together with ClusterAnn (presumably to
# attach class labels; a later cell reads a 'class_id_label' column).
_cluster_bias = MouseCT_AvgZ_Weighted(MouseSC_Z2, ASD_GW)
ASD_SC_Bias = add_class(_cluster_bias, ClusterAnn)
print(f"ASD bias computed for {len(ASD_SC_Bias)} clusters")
ASD_SC_Bias.head()

In [None]:
# Attach GENCIC bias to every Patch-seq cell by looking up the bias ('EFFECT')
# and class label of the cluster the cell was assigned to.
patchseq_bias = mapping_results.copy()

# Lookup tables keyed by the ASD_SC_Bias index (the cluster label).
cluster_to_bias = ASD_SC_Bias['EFFECT'].to_dict()
cluster_to_class = ASD_SC_Bias['class_id_label'].to_dict()

# Clusters missing from the lookup tables yield NaN.
_assigned_cluster = patchseq_bias['assigned_cluster']
patchseq_bias['asd_bias'] = _assigned_cluster.map(cluster_to_bias)
patchseq_bias['class_label'] = _assigned_cluster.map(cluster_to_class)

# Keep only cells whose mapping status counts as valid.
# NOTE(review): 'ok_subclass_only' cells are kept, yet their bias is still
# looked up through assigned_cluster - confirm this is intended.
_valid_statuses = ['ok_cluster', 'ok_subclass_only']
_has_valid_status = patchseq_bias['mapping_status'].isin(_valid_statuses)
patchseq_valid = patchseq_bias[_has_valid_status]

_n_with_bias = patchseq_valid['asd_bias'].notna().sum()
print(f"Patch-seq cells with valid mapping: {len(patchseq_valid)} / {len(patchseq_bias)}")
print(f"Cells with bias values: {_n_with_bias}")

# Ten subclasses with the highest mean bias, with their cell counts.
_bias_by_subclass = (
    patchseq_valid.groupby('assigned_subclass')['asd_bias']
    .agg(['mean', 'count'])
    .sort_values('mean', ascending=False)
)
print("\nMean ASD bias by subclass:")
print(_bias_by_subclass.head(10))

In [None]:
# Box plot of per-cell ASD bias for the 15 subclasses that have the most cells
# with a bias value, ordered from highest to lowest median bias.
_cells_with_bias = patchseq_valid.groupby('assigned_subclass')['asd_bias'].count()
top_subclasses = _cells_with_bias.nlargest(15).index

_in_top_subclass = patchseq_valid['assigned_subclass'].isin(top_subclasses)
plot_data = patchseq_valid[_in_top_subclass]

_median_bias = plot_data.groupby('assigned_subclass')['asd_bias'].median()
order = _median_bias.sort_values(ascending=False).index

fig, ax = plt.subplots(figsize=(10, 6), dpi=150)
# NOTE(review): newer seaborn releases may warn about `palette` being passed
# without `hue` - check against the installed version.
_boxplot_kwargs = dict(
    data=plot_data,
    x='asd_bias',
    y='assigned_subclass',
    order=order,
    palette='RdBu_r',
    ax=ax,
)
sns.boxplot(**_boxplot_kwargs)

# Dashed reference line at zero bias.
ax.axvline(0, color='gray', linestyle='--', alpha=0.5)
ax.set_xlabel('ASD Genetic Bias (Effect Size)', fontsize=12)
ax.set_ylabel('Assigned Subclass', fontsize=12)
ax.set_title(f'{DATASET} Patch-seq Cells: ASD Genetic Bias by Assigned Cell Type', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# When the query metadata carries the original annotation column, tabulate it
# next to the atlas-assigned labels for a quick side-by-side look.
_original_label_col = 'corresponding_AIT2.3.1_alias'
if _original_label_col not in adata_query_mapped.obs.columns:
    print("No original cell type annotations found in query metadata")
else:
    # NOTE(review): the columns are combined by index label, so this assumes
    # mapping_results shares its index with adata_query_mapped.obs - verify.
    _comparison_columns = {
        'original_type': adata_query_mapped.obs[_original_label_col],
        'atlas_subclass': mapping_results['assigned_subclass'],
        'atlas_cluster': mapping_results['assigned_cluster'],
        'confidence': mapping_results['conf_subclass'],
        'status': mapping_results['mapping_status'],
    }
    comparison = pd.DataFrame(_comparison_columns)

    print("Sample comparison of original vs atlas-assigned types:")
    print(comparison.head(20).to_string())

In [None]:
# Write the full per-cell table (mapping plus bias columns, including cells
# without a valid mapping) to the results directory.
_integrated_filename = f"{DATASET}_patchseq_atlas_matched_with_bias.csv"
output_path = RESULTS_DIR.joinpath(_integrated_filename)
patchseq_bias.to_csv(output_path)

_saved_columns = list(patchseq_bias.columns)
print(
    f"\nIntegrated results saved to: {output_path}",
    f"\nColumns: {_saved_columns}",
    sep="\n",
)