# Use MANA for the entire dataset

## MANA-7: Full Dataset Implementation

**Objective:** Apply MANA to the complete dataset with all samples.

**Workflow:**
1. ✅ Load full dataset (with fix for h5ad reading issues)
2. ✅ Run scVI on all samples for unified latent representation
3. ✅ Build spatial neighborhood graphs
4. ✅ Apply MANA with optimal parameters (gaussian kernel, hop_decay=0.2, n_layers=3)
5. ✅ Cluster cells in MANA feature space
6. ✅ Visualize results across all samples
7. ✅ Evaluate clustering quality (spatial & expression coherence)
8. ✅ Optional: Compare with CellCharter on full dataset

**Key Parameters (from MANA-6 benchmark):**
- `distance_kernel='gaussian'` (winner with 0.693 composite score)
- `hop_decay=0.2` (optimal from MANA-4)
- `n_layers=3` (optimal balance from MANA-5)
- `aggregations='mean'` (standard approach)
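To make `hop_decay` concrete, here is a small sketch. It assumes, and the source does not confirm this, that hop *k* gets weight `hop_decay**k`, where hop 0 is the cell itself when `include_self=True`, and that the hop weights are then normalized to sum to 1. Under that assumption, `hop_decay=0.2` with `n_layers=3` puts roughly 80% of the weight on the cell itself and 16% on its direct neighbors. `hop_weights` is a hypothetical helper, not part of `utils`.

```python
import numpy as np

def hop_weights(hop_decay: float, n_layers: int, include_self: bool = True) -> np.ndarray:
    """Hypothetical per-hop weights, ASSUMING weight_k = hop_decay**k, normalized to sum to 1."""
    start = 0 if include_self else 1           # hop 0 = the cell itself
    w = hop_decay ** np.arange(start, n_layers + 1)
    return w / w.sum()

print(np.round(hop_weights(0.2, 3), 3))  # MANA setting: weight drops quickly with distance
print(np.round(hop_weights(1.0, 3), 3))  # hop_decay=1.0: every hop weighted equally
```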

In [3]:
import scanpy as sc


## Read the Entire Dataset

In [4]:
# Fix for anndata reading issues with problematic .uns entries.
# Some .uns entries are stored with encoding_type='null', which this anndata
# version cannot deserialize ("No read method registered for
# IOSpec(encoding_type='null', ...)"). Backed mode fails with the same error,
# because it still reads .uns into memory. Upgrading anndata is the cleanest fix;
# otherwise the fallback below reads each element and skips unreadable .uns entries.
import warnings
import h5py
import anndata as ad
from anndata.experimental import read_elem  # anndata >= 0.8 (also anndata.io.read_elem in >= 0.11)

warnings.filterwarnings('ignore', category=UserWarning)

h5ad_path = '/Volumes/processing2/RRmap/data/EAE_proseg_clustered_louvain_leiden_all_sections_annotated_rotated_cellcharter_neigh2_251219.h5ad'

try:
    adata = sc.read_h5ad(h5ad_path)
except Exception as e:
    print(f"Initial read failed: {e}")
    print("Reading element by element and skipping unreadable .uns entries...")
    with h5py.File(h5ad_path, 'r') as f:
        # Core elements (matrix, annotations, embeddings, graphs, layers); .raw is not restored here
        elems = {k: read_elem(f[k])
                 for k in ('X', 'obs', 'var', 'obsm', 'varm', 'obsp', 'varp', 'layers') if k in f}
        uns = {}
        for key in (f['uns'].keys() if 'uns' in f else []):
            try:
                uns[key] = read_elem(f['uns'][key])
            except Exception as err:
                print(f"  Skipping .uns['{key}']: {err}")
    adata = ad.AnnData(**elems, uns=uns)

print(f"Successfully loaded: {adata.shape[0]} cells × {adata.shape[1]} genes")
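The read error reports an element stored with `encoding_type='null'`, but it does not say which one. The following h5py sketch lists the offending HDF5 paths. It relies on anndata's on-disk convention of recording each element's type in an `encoding-type` attribute. As far as I know, this encoding is how newer anndata releases store `None` values, which is why the error suggests updating anndata. `find_encoding` is a hypothetical helper; call it on the same `.h5ad` path as the loading cell.

```python
import h5py

def find_encoding(h5ad_path: str, encoding_type: str = 'null') -> list:
    """List HDF5 paths whose anndata 'encoding-type' attribute equals encoding_type."""
    hits = []

    def visit(name, obj):
        enc = obj.attrs.get('encoding-type')
        if isinstance(enc, bytes):          # older files may store attributes as bytes
            enc = enc.decode()
        if enc == encoding_type:
            hits.append(name)

    with h5py.File(h5ad_path, 'r') as f:
        f.visititems(visit)                 # walks every group and dataset below the root
    return hits
```

Any path it returns under `uns/` identifies a key to drop before re-saving.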

In [None]:
# Check data structure
print(f"Total cells: {adata.n_obs:,}")
print(f"Total genes: {adata.n_vars:,}")
print(f"\nSamples: {adata.obs['sample_id'].nunique() if 'sample_id' in adata.obs else 'sample_id not found'}")
print(f"\nAvailable .obs columns:\n{list(adata.obs.columns)}")
print(f"\nAvailable .obsm keys:\n{list(adata.obsm.keys())}")
print(f"\nAvailable .uns keys:\n{list(adata.uns.keys())}")

## Part 1: Run scVI on All Samples

We need to train scVI on the full dataset to get a unified latent representation across all samples.

In [None]:
import scvi

# Set up scVI
# Assumes you have a 'sample_id' or similar batch key
batch_key = 'sample_id'  # Adjust this to your actual batch column name

# Setup anndata for scVI
scvi.model.SCVI.setup_anndata(
    adata,
    layer=None,  # Use .X, which must contain raw counts; otherwise set layer= to a raw-count layer
    batch_key=batch_key,  # Important: correct for batch effects between samples
)

print(f"scVI setup complete. Batch key: {batch_key}")
print(f"Number of batches: {adata.obs[batch_key].nunique()}")
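This file has already been clustered and annotated, so `.X` may hold normalized or log-transformed values rather than the raw counts that scVI's negative binomial likelihood expects. The quick heuristic below checks a random subset of rows for non-negative, integer-valued entries. `looks_like_raw_counts` is a hypothetical helper. If it returns `False` for `adata.X`, point `layer=` in `setup_anndata` at a raw-count layer, if one exists in `adata.layers`.

```python
import numpy as np
import scipy.sparse as sp

def looks_like_raw_counts(X, n_check: int = 10_000, seed: int = 0) -> bool:
    """Heuristic: True if a random subset of rows is non-negative and integer-valued."""
    rng = np.random.default_rng(seed)
    rows = rng.choice(X.shape[0], size=min(n_check, X.shape[0]), replace=False)
    sub = X[rows]
    # For sparse matrices only the stored (non-zero) values need checking
    vals = sub.data if sp.issparse(sub) else np.asarray(sub).ravel()
    return bool(np.all(vals >= 0) and np.all(np.mod(vals, 1) == 0))

# Usage: looks_like_raw_counts(adata.X)
```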

In [None]:
# Train scVI model
model = scvi.model.SCVI(
    adata,
    n_layers=2,
    n_latent=30,  # 30-dimensional latent space (standard)
    gene_likelihood='nb',  # Negative binomial (standard for count data)
)

# Train (this will take time with large datasets!)
model.train(
    max_epochs=400,
    early_stopping=True,
    early_stopping_patience=20,
    accelerator='auto',  # Uses a GPU if available; use_gpu was deprecated in favor of accelerator/devices in recent scvi-tools
)

print("scVI training complete!")

In [None]:
# Get scVI latent representation
adata.obsm['X_scVI'] = model.get_latent_representation()

print(f"Added X_scVI to adata.obsm with shape: {adata.obsm['X_scVI'].shape}")

## Part 2: Build Spatial Neighborhoods

Build spatial neighborhood graphs for each sample separately.

In [None]:
import squidpy as sq

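# Build spatial graph
# This should be done per-sample to avoid creating edges between samples:
# library_key builds one graph per sample (squidpy expects a categorical column)
adata.obs['sample_id'] = adata.obs['sample_id'].astype('category')
sq.gr.spatial_neighbors(
    adata,
    spatial_key='spatial',     # coordinates stored in adata.obsm['spatial']
    library_key='sample_id',   # Adjust to your actual sample column
    coord_type='generic',
    delaunay=True,  # Delaunay triangulation (connects nearby cells)
    key_added='spatial',
)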

print(f"Spatial graph built: {adata.obsp['spatial_connectivities'].shape}")
print(f"Average neighbors per cell: {adata.obsp['spatial_connectivities'].sum(axis=1).mean():.1f}")
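The graph is only meaningful if no edge connects cells from different sections, so it is worth checking after building it. This sketch counts stored edges whose endpoints carry different sample labels. For a symmetric connectivity matrix, each undirected edge is counted twice. `count_cross_sample_edges` is a hypothetical helper.

```python
import numpy as np
import scipy.sparse as sp

def count_cross_sample_edges(conn, sample_labels) -> int:
    """Number of stored edges whose two endpoints belong to different samples."""
    coo = sp.coo_matrix(conn)               # row/col arrays give each edge's endpoints
    labels = np.asarray(sample_labels)
    return int(np.sum(labels[coo.row] != labels[coo.col]))

# Usage (expect 0 for a per-sample graph):
# count_cross_sample_edges(adata.obsp['spatial_connectivities'], adata.obs['sample_id'].values)
```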

## Part 3: Run MANA Weighted Aggregation

Apply MANA with optimal parameters from MANA-6 benchmarking.

In [None]:
# Import MANA functions
import sys
sys.path.insert(0, '..')
from utils import aggregate_neighbors_weighted, plot_spatial_compact_fast

In [None]:
# Run MANA with optimal parameters from MANA-6 benchmarking
# Using gaussian kernel (winner from benchmark)
aggregate_neighbors_weighted(
    adata,
    n_layers=3,              # Optimal balance (from MANA-5)
    aggregations='mean',     # Standard aggregation
    use_rep='X_scVI',        # Use scVI latent space
    out_key='X_mana_gauss',  # Output key
    hop_decay=0.2,           # Optimal decay (from MANA-4)
    distance_kernel='gaussian',  # Winner from MANA-6 benchmark
    spatial_key='spatial',
    normalize_weights=True,
    include_self=True,
)

print("MANA aggregation complete!")
print(f"Output stored in adata.obsm['X_mana_gauss'] with shape: {adata.obsm['X_mana_gauss'].shape}")
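`aggregate_neighbors_weighted` comes from the project's `utils` module, and its internals aren't shown in this notebook. The numpy/scipy sketch below shows what a hop-decayed, Gaussian-weighted mean aggregation *could* look like. It rests on four assumptions:

- Edge weights are `exp(-d²/(2σ²))`, with σ set to the median edge length.
- Rows are normalized so that each hop is a weighted mean.
- Hop *k* is obtained by propagating hop *k-1* through that matrix. This is walk-based, not exact k-hop rings.
- Hops are combined with weights `hop_decay**k`.

The function name `mana_like_aggregate` and the σ choice are hypothetical.

```python
import numpy as np
import scipy.sparse as sp

def mana_like_aggregate(X, dist, n_layers=3, hop_decay=0.2, include_self=True):
    """Hypothetical sketch of hop-decayed, Gaussian-weighted neighbor averaging.

    X:    (n_cells, n_features) dense features, e.g. a scVI latent space.
    dist: sparse (n_cells, n_cells) edge lengths (no entry = no edge),
          e.g. adata.obsp['spatial_distances'].
    """
    dist = sp.csr_matrix(dist, dtype=float)
    sigma = np.median(dist.data)                      # assumed bandwidth: median edge length
    W = dist.copy()
    W.data = np.exp(-W.data ** 2 / (2 * sigma ** 2))  # Gaussian kernel on each edge
    row_sums = np.asarray(W.sum(axis=1)).ravel()
    row_sums[row_sums == 0] = 1.0                     # isolated cells keep an all-zero row
    W = sp.diags(1.0 / row_sums) @ W                  # row-normalize -> weighted mean

    hop = np.asarray(X, dtype=float)
    out = hop.copy() if include_self else np.zeros_like(hop)
    total = 1.0 if include_self else 0.0
    for k in range(1, n_layers + 1):
        hop = W @ hop                                 # propagate one more hop
        out += hop_decay ** k * hop
        total += hop_decay ** k
    return out / total                                # normalize the hop weights to sum to 1

# Usage, for comparison only:
# adata.obsm['X_mana_sketch'] = mana_like_aggregate(adata.obsm['X_scVI'], adata.obsp['spatial_distances'])
```

Because every hop is a weighted mean, a constant feature passes through unchanged. With `hop_decay=0`, the output collapses to the cell's own features.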

## Part 4: Clustering on MANA Features

Cluster cells based on the MANA-aggregated features.

In [None]:
# Build neighborhood graph in MANA feature space
sc.pp.neighbors(
    adata,
    use_rep='X_mana_gauss',
    n_neighbors=15,
    key_added='mana'
)

print("Neighbor graph built in MANA feature space")

In [None]:
# Leiden clustering
# Try a few resolutions to see what works best
resolutions = [0.3, 0.5, 0.8, 1.0]

for res in resolutions:
    sc.tl.leiden(
        adata,
        resolution=res,
        key_added=f'leiden_mana_{res}',
        neighbors_key='mana'
    )
    n_clusters = adata.obs[f'leiden_mana_{res}'].nunique()
    print(f"Resolution {res}: {n_clusters} clusters")
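Cluster counts alone don't show how much the partitions change between resolutions. Pairwise adjusted Rand index (ARI) answers that directly: 1.0 means identical partitions up to relabeling, and values near 0 mean chance-level agreement. `resolution_agreement` is a hypothetical helper built on scikit-learn's `adjusted_rand_score`.

```python
import itertools
import pandas as pd
from sklearn.metrics import adjusted_rand_score

def resolution_agreement(obs: pd.DataFrame, keys: list) -> pd.DataFrame:
    """Symmetric matrix of pairwise adjusted Rand index between clusterings in obs columns."""
    ari = pd.DataFrame(1.0, index=keys, columns=keys)
    for a, b in itertools.combinations(keys, 2):
        ari.loc[a, b] = ari.loc[b, a] = adjusted_rand_score(obs[a], obs[b])
    return ari

# Usage:
# resolution_agreement(adata.obs, [f'leiden_mana_{r}' for r in resolutions])
```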

## Part 5: Visualization

Visualize the MANA clustering results across all samples.

In [None]:
# Visualize MANA clustering (resolution 0.5 as starting point)
plot_spatial_compact_fast(
    adata,
    color='leiden_mana_0.5',
    groupby='sample_id',  # Adjust to your actual sample column
    spot_size=8,
    cols=3,
    height=10,
    background='white',
    dpi=120
)

In [None]:
# UMAP visualization in MANA space
sc.tl.umap(adata, neighbors_key='mana')

sc.pl.umap(
    adata,
    color=['leiden_mana_0.5', 'sample_id'],
    ncols=2,
    frameon=False
)
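The UMAP colored by `sample_id` gives a visual check of sample mixing. A composition table makes it quantitative. A cluster drawn almost entirely from one sample could be a section-specific niche or residual batch effect, and is worth inspecting. `cluster_sample_composition` is a hypothetical helper using `pandas.crosstab`.

```python
import pandas as pd

def cluster_sample_composition(obs: pd.DataFrame, cluster_key: str, sample_key: str = 'sample_id') -> pd.DataFrame:
    """Fraction of each cluster's cells that come from each sample (each row sums to 1)."""
    return pd.crosstab(obs[cluster_key], obs[sample_key], normalize='index')

# Usage:
# cluster_sample_composition(adata.obs, 'leiden_mana_0.5').round(2)
```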

## Part 6: Quality Control & Evaluation

Evaluate the clustering quality using metrics from MANA-6.

In [None]:
from sklearn.metrics import silhouette_score
import numpy as np

# Helper function from MANA-6
def local_purity(adata, cluster_key):
    """Spatial coherence: mean over cells of the fraction of spatial neighbors sharing the cell's cluster label."""
    conn = adata.obsp['spatial_connectivities']
    labels = adata.obs[cluster_key].astype('category').cat.codes.values
    purities = []
    
    for i in range(adata.n_obs):
        neighbors = conn[i].nonzero()[1]
        if len(neighbors) > 0:
            neighbor_labels = labels[neighbors]
            purity = (neighbor_labels == labels[i]).mean()
            purities.append(purity)
    
    return np.mean(purities)

# Evaluate clustering
cluster_key = 'leiden_mana_0.5'

# Spatial coherence
purity = local_purity(adata, cluster_key)
print(f"Local purity (spatial coherence): {purity:.3f}")

# Expression coherence (in scVI space)
labels = adata.obs[cluster_key].astype('category').cat.codes.values
sil = silhouette_score(adata.obsm['X_scVI'], labels, metric='euclidean', sample_size=10000, random_state=0)  # fixed seed: the 10k-cell subsample is random
print(f"Silhouette score (expression coherence): {sil:.3f}")

# Cluster sizes
print(f"\nCluster sizes:")
print(adata.obs[cluster_key].value_counts().sort_index())
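The `local_purity` loop slices the sparse matrix once per cell, which gets slow on a full multi-section dataset. The sketch below is a vectorized version that should compute the same quantity directly from the CSR index arrays, checked here on a four-cell toy graph. `local_purity_fast` is a hypothetical name. Like the loop, it ignores explicitly stored zeros and skips cells without neighbors.

```python
import numpy as np
import scipy.sparse as sp

def local_purity_fast(conn, labels) -> float:
    """Mean over cells (with >=1 neighbor) of the fraction of neighbors sharing the cell's label."""
    conn = sp.csr_matrix(conn, copy=True)
    conn.eliminate_zeros()                              # match .nonzero() in the loop version
    labels = np.asarray(labels)
    deg = np.diff(conn.indptr)                          # neighbors per cell
    rows = np.repeat(np.arange(conn.shape[0]), deg)     # source cell of every stored edge
    same = (labels[conn.indices] == labels[rows]).astype(float)
    per_cell = np.bincount(rows, weights=same, minlength=conn.shape[0])
    has_nb = deg > 0
    return float(np.mean(per_cell[has_nb] / deg[has_nb]))

# Path graph 0-1-2-3 with labels a, a, b, b
conn = np.array([[0, 1, 0, 0],
                 [1, 0, 1, 0],
                 [0, 1, 0, 1],
                 [0, 0, 1, 0]])
print(local_purity_fast(conn, ['a', 'a', 'b', 'b']))  # (1 + 0.5 + 0.5 + 1) / 4 = 0.75
```

On the real data, call `local_purity_fast(adata.obsp['spatial_connectivities'], adata.obs[cluster_key].values)`.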

## Part 7: Save Results

Save the annotated data with MANA clustering.

In [None]:
# Save annotated data
output_path = '/Volumes/processing2/RRmap/data/EAE_MANA_annotated.h5ad'

# Optional: clean up .uns to avoid serialization issues
# Remove problematic entries if they exist
if 'cytetype_jobDetails' in adata.uns:
    del adata.uns['cytetype_jobDetails']

adata.write_h5ad(output_path)
print(f"Saved annotated data to: {output_path}")
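Deleting `cytetype_jobDetails` only helps if that entry is the one causing trouble. Suppose the `encoding_type='null'` elements that broke the initial read come from `None` values in `.uns`, which is how newer anndata versions appear to store them. In that case, stripping every `None` before writing keeps the saved file readable by older anndata installs. `drop_none_entries` is a hypothetical helper; run it on `adata.uns` before `adata.write_h5ad`.

```python
def drop_none_entries(d) -> list:
    """Recursively delete None values from a (nested) dict in place; return removed key paths."""
    removed = []
    for key in list(d.keys()):              # copy keys: the dict is modified while iterating
        value = d[key]
        if value is None:
            del d[key]
            removed.append(str(key))
        elif isinstance(value, dict):
            removed += [f"{key}/{sub}" for sub in drop_none_entries(value)]
    return removed

# Usage:
# print(drop_none_entries(adata.uns))
```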

## Optional: Compare with CellCharter

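To check whether MANA's advantage from the MANA-6 benchmark holds on the full dataset, run a CellCharter-style baseline for comparison: the same aggregation with uniform hop weights and no distance kernel.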

In [None]:
# CellCharter-style baseline: same aggregation with uniform hop weights and no distance weighting
aggregate_neighbors_weighted(
    adata,
    n_layers=3,
    aggregations='mean',
    use_rep='X_scVI',
    out_key='X_cellcharter',
    hop_decay=1.0,  # Uniform weights
    distance_kernel='none',  # No distance weighting
    spatial_key='spatial',
    normalize_weights=True,
    include_self=True,
)

# Cluster with CellCharter features
sc.pp.neighbors(adata, use_rep='X_cellcharter', n_neighbors=15, key_added='cellcharter')
sc.tl.leiden(adata, resolution=0.5, key_added='leiden_cellcharter', neighbors_key='cellcharter')

print("CellCharter aggregation and clustering complete!")

In [None]:
# Compare MANA vs CellCharter
print("=== MANA vs CellCharter Comparison ===\n")

# MANA metrics
mana_purity = local_purity(adata, 'leiden_mana_0.5')
mana_labels = adata.obs['leiden_mana_0.5'].astype('category').cat.codes.values
mana_sil = silhouette_score(adata.obsm['X_scVI'], mana_labels, metric='euclidean', sample_size=10000, random_state=0)

# CellCharter metrics
cc_purity = local_purity(adata, 'leiden_cellcharter')
cc_labels = adata.obs['leiden_cellcharter'].astype('category').cat.codes.values
cc_sil = silhouette_score(adata.obsm['X_scVI'], cc_labels, metric='euclidean', sample_size=10000, random_state=0)

print(f"Local Purity (Spatial Coherence):")
print(f"  MANA:        {mana_purity:.3f}")
print(f"  CellCharter: {cc_purity:.3f}")
print(f"  Improvement: {((mana_purity - cc_purity) / cc_purity * 100):+.1f}%\n")

print(f"Silhouette Score (Expression Coherence):")
print(f"  MANA:        {mana_sil:.3f}")
print(f"  CellCharter: {cc_sil:.3f}")
print(f"  Difference:  {(mana_sil - cc_sil):+.3f}")
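Local purity rises mechanically when clusters are fewer and larger. The same Leiden resolution can also produce different cluster counts for the two feature spaces. One simple reference point, added here and not part of MANA-6, is the chance level. If labels were shuffled across cells while keeping the observed cluster proportions p_k, a neighbor shares a cell's label with probability about p_k. The expected purity is therefore about Σ p_k². Comparing each method's purity against its own chance level is fairer than comparing raw values. `purity_chance_level` is a hypothetical helper.

```python
import numpy as np

def purity_chance_level(labels) -> float:
    """Approximate local purity expected under shuffled labels: sum of squared cluster proportions."""
    _, counts = np.unique(np.asarray(labels), return_counts=True)
    p = counts / counts.sum()
    return float(np.sum(p ** 2))

# Usage:
# print(f"MANA chance level:        {purity_chance_level(adata.obs['leiden_mana_0.5']):.3f}")
# print(f"CellCharter chance level: {purity_chance_level(adata.obs['leiden_cellcharter']):.3f}")
```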