# Spatial Deconvolution with Cell2Location

## Overview
This notebook uses cell2location to deconvolve spatial transcriptomics spots into cell type compositions using our scRNA-seq reference atlas.

### Objectives
1. Prepare scRNA-seq reference signatures
2. Train cell2location model on spatial data
3. Map cell types to spatial coordinates
4. Identify spatial niches

### Why Cell2Location?
- Bayesian model accounting for technical noise
- Estimates absolute cell abundances
- GPU-accelerated for scalability

---

In [None]:
import scanpy as sc
import squidpy as sq
import anndata as ad
import cell2location
from cell2location.utils.filtering import filter_genes
from cell2location.models import RegressionModel
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import yaml
import torch
import warnings

# Silence library warnings to keep cell output readable.
# NOTE(review): this is a blanket filter, so deprecation notices from the
# imported libraries are hidden too; narrow it when debugging.
warnings.filterwarnings('ignore')

print(f"CUDA available: {torch.cuda.is_available()}")

# Project paths (resolved relative to the notebook's working directory)
PROJECT_ROOT = Path("../..").resolve()
DATA_PROCESSED_SCRNA = PROJECT_ROOT / 'data' / 'processed' / 'scrna'
DATA_RAW_SPATIAL = PROJECT_ROOT / 'data' / 'raw' / 'spatial'
DATA_PROCESSED_SPATIAL = PROJECT_ROOT / 'data' / 'processed' / 'spatial'
MODELS = PROJECT_ROOT / 'results' / 'models'
FIGURES = PROJECT_ROOT / 'results' / 'figures'
CONFIG_PATH = PROJECT_ROOT / 'config' / 'analysis_params.yaml'

# Analysis parameters (model hyper-parameters, niche count, random seed)
with open(CONFIG_PATH, 'r') as f:
    config = yaml.safe_load(f)

# Create the output directories up front so that the savefig / write / save
# calls further down cannot fail on a fresh checkout.
for out_dir in (DATA_PROCESSED_SPATIAL, MODELS, FIGURES):
    out_dir.mkdir(parents=True, exist_ok=True)

# Seed the NumPy and PyTorch global RNGs with the configured seed (the same
# value that is handed to KMeans later in the notebook) for reproducible runs.
seed = config.get('random_seed')
if seed is not None:
    np.random.seed(seed)
    torch.manual_seed(seed)

## 1. Load Reference scRNA-seq Atlas

In [None]:
# Read the annotated, integrated scRNA-seq atlas that serves as the reference
ref_atlas_path = DATA_PROCESSED_SCRNA / 'integrated_atlas_annotated.h5ad'
adata_ref = sc.read_h5ad(ref_atlas_path)

# Report atlas dimensions and the number of distinct major cell type labels
for summary_line in (
    f"Reference atlas:",
    f"  Cells: {adata_ref.n_obs}",
    f"  Genes: {adata_ref.n_vars}",
    f"\nCell types: {adata_ref.obs['cell_type_major'].nunique()}",
):
    print(summary_line)

In [None]:
# Use raw counts for reference: restore them from the 'counts' layer when present.
if 'counts' in adata_ref.layers:
    adata_ref.X = adata_ref.layers['counts'].copy()
else:
    # Do not abort (X may already hold raw counts), but make the assumption
    # visible instead of silently continuing with whatever X contains.
    print("WARNING: no 'counts' layer in adata_ref; assuming adata_ref.X already contains raw counts")

# Filter genes for cell2location.
# The three cutoffs are passed straight to cell2location's filter_genes helper;
# see its documentation for their exact semantics before tuning them.
selected_genes = filter_genes(
    adata_ref,
    cell_count_cutoff=5,
    cell_percentage_cutoff2=0.03,
    nonz_mean_cutoff=1.12
)

# Keep only the selected genes; .copy() materialises the view into a new object
adata_ref = adata_ref[:, selected_genes].copy()
print(f"Selected genes for reference: {adata_ref.n_vars}")

## 2. Train Reference Model

Estimate reference cell type signatures using negative binomial regression.

In [None]:
# Register the cell-type label column and the batch column of the reference
# AnnData with the regression model
RegressionModel.setup_anndata(
    adata_ref,
    labels_key='cell_type_major',
    batch_key='dataset',
)

# Instantiate the signature-estimation model on the prepared reference
mod_ref = RegressionModel(adata_ref)

print("Reference model initialized")

In [None]:
# Fit the reference regression model; the GPU is requested only when CUDA is available
ref_train_kwargs = dict(
    max_epochs=250,
    use_gpu=torch.cuda.is_available(),
    batch_size=2500,
)
mod_ref.train(**ref_train_kwargs)

# Draw the training curve and keep a copy under the figures directory.
# NOTE(review): 20 is plot_history's first positional argument, presumably the
# number of initial iterations to leave out of the plot. Confirm.
mod_ref.plot_history(20)
ref_history_png = FIGURES / 'cell2location_ref_training.png'
plt.savefig(ref_history_png, dpi=150)
plt.show()

In [None]:
# Export reference signatures.
# use_gpu follows CUDA availability (as the training cell does) instead of being
# hard-coded to True, so the export also runs on CPU-only machines.
adata_ref = mod_ref.export_posterior(
    adata_ref,
    sample_kwargs={
        'num_samples': 1000,
        'batch_size': 2500,
        'use_gpu': torch.cuda.is_available(),
    }
)

# Extract signature matrix (rows: genes, columns: cell types)
inf_aver = adata_ref.varm['means_per_cluster_mu_fg'].copy()

# The exported columns carry a 'means_per_cluster_mu_fg_' prefix in front of the
# cell type label. Strip it so that downstream labels (obs columns, plot titles)
# are the plain cell type names; columns without the prefix are left untouched.
signature_prefix = 'means_per_cluster_mu_fg_'
inf_aver.columns = [
    str(col)[len(signature_prefix):] if str(col).startswith(signature_prefix) else col
    for col in inf_aver.columns
]
print(f"Reference signatures shape: {inf_aver.shape}")

## 3. Load Spatial Data

In [None]:
# Load spatial data (example for GSE203612)
spatial_id = "GSE203612"
sample_name = "sample1"  # Modify based on your data

spatial_path = DATA_RAW_SPATIAL / spatial_id / sample_name

# Fail fast with a clear error: every later cell needs adata_st, so merely
# printing a message here would only postpone the failure to a NameError.
if not spatial_path.exists():
    raise FileNotFoundError(f"Spatial data not found at: {spatial_path}")

adata_st = sc.read_visium(spatial_path, count_file='filtered_feature_bc_matrix.h5')
# Gene symbols can repeat; make them unique before any name-based indexing
adata_st.var_names_make_unique()

print(f"Loaded spatial data:")
print(f"  Spots: {adata_st.n_obs}")
print(f"  Genes: {adata_st.n_vars}")

In [None]:
# Intersect genes between the reference and the spatial data (reference order kept)
shared_genes = [g for g in adata_ref.var_names if g in adata_st.var_names]
if not shared_genes:
    raise ValueError(
        "No genes are shared between the reference and the spatial data; "
        "check that both use the same gene identifiers"
    )

# Subset BOTH the spatial data and the signature matrix to the shared genes, in
# the same order: the signatures are passed to the spatial model as
# cell_state_df and must line up gene-for-gene with adata_st.var_names.
adata_st = adata_st[:, shared_genes].copy()
inf_aver = inf_aver.loc[shared_genes, :].copy()
print(f"Shared genes: {len(shared_genes)}")

## 4. Train Cell2Location Model

In [None]:
# Register the spatial AnnData with the cell2location model (no batch column)
c2l_model_cls = cell2location.models.Cell2location
c2l_model_cls.setup_anndata(adata_st, batch_key=None)

# Hyper-parameters for the spatial model come from the analysis config
c2l_params = config['spatial']['cell2location']

# Build the spatial model on top of the reference signatures
mod_st = c2l_model_cls(
    adata_st,
    cell_state_df=inf_aver,
    N_cells_per_location=c2l_params['n_cells_per_location'],
    detection_alpha=c2l_params['detection_alpha'],
)

print("Spatial model initialized")

In [None]:
# Fit the spatial model; the epoch count is read from the config and the GPU is
# requested only when CUDA is available.
# NOTE(review): batch_size=None with train_size=1 presumably means full-batch
# training on all spots. Confirm against the cell2location documentation.
st_train_kwargs = dict(
    max_epochs=c2l_params['max_epochs'],
    batch_size=None,
    train_size=1,
    use_gpu=torch.cuda.is_available(),
)
mod_st.train(**st_train_kwargs)

# Draw the training curve and store it next to the other figures
mod_st.plot_history(1000)
st_history_png = FIGURES / f'{spatial_id}_{sample_name}_c2l_training.png'
plt.savefig(st_history_png, dpi=150)
plt.show()

In [None]:
# Export posterior of the spatial model back into adata_st.
# use_gpu follows CUDA availability (as the training cell does) instead of being
# hard-coded to True, so the export also runs on CPU-only machines.
adata_st = mod_st.export_posterior(
    adata_st,
    sample_kwargs={
        'num_samples': 1000,
        # one batch covering every spot held by the model
        'batch_size': mod_st.adata.n_obs,
        'use_gpu': torch.cuda.is_available(),
    }
)

# Cell type abundances are in adata_st.obsm['q05_cell_abundance_w_sf']
print("Deconvolution complete!")
print(f"Cell abundances shape: {adata_st.obsm['q05_cell_abundance_w_sf'].shape}")

## 5. Visualize Spatial Cell Type Distribution

In [None]:
# Add cell type abundances to obs for plotting
cell_types = list(inf_aver.columns)

# Take the abundance values positionally via np.asarray. If the obsm entry is a
# DataFrame, pd.DataFrame(df, columns=cell_types) would *reindex* by column
# label, and because its labels differ from cell_types every column would come
# out as NaN. Columns are assumed to be in the same order as inf_aver's columns.
abundance_df = pd.DataFrame(
    np.asarray(adata_st.obsm['q05_cell_abundance_w_sf']),
    index=adata_st.obs_names,
    columns=cell_types
)

# One obs column per cell type so plotting functions can colour spots by it
for ct in cell_types:
    adata_st.obs[f'abundance_{ct}'] = abundance_df[ct].values

In [None]:
# Plot spatial distribution of cell types on a fixed 2 x 4 grid (at most 8 panels)
fig, axes = plt.subplots(2, 4, figsize=(20, 10))
axes = axes.flatten()

plotted_cell_types = cell_types[:len(axes)]
for i, ct in enumerate(plotted_cell_types):
    sq.pl.spatial_scatter(
        adata_st,
        color=f'abundance_{ct}',
        ax=axes[i],
        title=ct,
        cmap='viridis',
        size=1.5
    )

# Hide panels that received no cell type so the figure has no empty frames
for unused_ax in axes[len(plotted_cell_types):]:
    unused_ax.axis('off')

# Say so explicitly when the grid is too small instead of dropping types silently
if len(cell_types) > len(axes):
    print(f"Note: showing {len(axes)} of {len(cell_types)} cell types; "
          f"not shown: {cell_types[len(axes):]}")

plt.tight_layout()
plt.savefig(FIGURES / f'{spatial_id}_{sample_name}_cell_abundances.png', dpi=150)
plt.show()

## 6. Identify Spatial Niches

In [None]:
# Group spots into niches according to their cell type composition
from sklearn.cluster import KMeans

# Number of niches is taken from the analysis config
n_niches = config['spatial']['niche']['n_neighborhoods']

# K-means on the per-spot abundance matrix, seeded from the config
kmeans = KMeans(n_clusters=n_niches, random_state=config['random_seed'])
niche_labels = kmeans.fit_predict(abundance_df.values)

# Store the labels as strings so they are treated as categories, not numbers
adata_st.obs['spatial_niche'] = niche_labels.astype(str)

print(f"Identified {n_niches} spatial niches")

In [None]:
# Visualize niches.
# Pass an absolute path under FIGURES: a bare file name would be written to the
# plotting library's default figure directory instead of alongside the other
# figures produced by this notebook.
sq.pl.spatial_scatter(
    adata_st,
    color='spatial_niche',
    size=1.5,
    save=FIGURES / f'{spatial_id}_{sample_name}_niches.png'
)

In [None]:
# Mean abundance of every cell type within each niche (rows: niches)
niche_labels_per_spot = adata_st.obs['spatial_niche']
niche_composition = abundance_df.groupby(niche_labels_per_spot).mean()

# Heatmap with cell types on the rows and niches on the columns
plt.figure(figsize=(12, 8))
heatmap_kwargs = dict(cmap='YlOrRd', annot=True, fmt='.1f')
sns.heatmap(niche_composition.T, **heatmap_kwargs)
plt.title('Cell Type Composition per Spatial Niche')
plt.xlabel('Niche')
plt.ylabel('Cell Type')

composition_png = FIGURES / f'{spatial_id}_{sample_name}_niche_composition.png'
plt.savefig(composition_png, dpi=150)
plt.show()

## 7. Save Results

In [None]:
# Save deconvolved spatial data
output_path = DATA_PROCESSED_SPATIAL / f'{spatial_id}_{sample_name}_deconvolved.h5ad'
# Make sure the target directory exists so the write cannot fail on a fresh checkout
output_path.parent.mkdir(parents=True, exist_ok=True)
adata_st.write(output_path)
print(f"Saved deconvolved data to: {output_path}")

# Save model (overwrite=True replaces a model saved by a previous run)
model_path = MODELS / f'cell2location_{spatial_id}_{sample_name}'
model_path.parent.mkdir(parents=True, exist_ok=True)
mod_st.save(str(model_path), overwrite=True)
print(f"Saved model to: {model_path}")

## Summary

### Completed
- Trained reference signatures from scRNA-seq
- Deconvolved spatial spots to cell types
- Identified spatial niches

### Next Steps
1. Spatial statistics in `05d_spatial_statistics.ipynb`
2. Correlate niches with resistance in `06_resistance_analysis/`