# Pan-Cancer scRNA-seq Integration with scVI

## Overview
This notebook integrates multiple scRNA-seq datasets across cancer types using scVI (single-cell Variational Inference). scVI provides a probabilistic framework that accounts for batch effects while preserving biological variation.

### Objectives
1. Merge all preprocessed datasets
2. Train scVI model for batch correction
3. Generate integrated latent representation
4. Evaluate integration quality

### Why scVI?
- **Probabilistic model**: Accounts for count noise and dropout
- **Scalable**: Handles millions of cells with GPU acceleration
- **Flexible**: Can include categorical and continuous covariates
- **scverse ecosystem**: Native AnnData support

---

## 1. Setup

In [None]:
import scanpy as sc
import anndata as ad
import scvi
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import yaml
import torch
import warnings

# Suppress all Python warnings for the rest of the session.
warnings.filterwarnings('ignore')

# Check GPU (query once, reuse the answer)
cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")
if cuda_available:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Project paths, resolved relative to the current working directory
PROJECT_ROOT = Path("../..").resolve()
DATA_PROCESSED = PROJECT_ROOT / 'data' / 'processed' / 'scrna'
MODELS = PROJECT_ROOT / 'results' / 'models'
FIGURES = PROJECT_ROOT / 'results' / 'figures'
CONFIG_PATH = PROJECT_ROOT / 'config' / 'analysis_params.yaml'

# Create the output directories up front. Later cells write figures and the
# trained model into them, and plt.savefig does not create missing parent
# directories, so a fresh checkout would otherwise fail at the first plot.
for output_dir in (MODELS, FIGURES):
    output_dir.mkdir(parents=True, exist_ok=True)

# Load configuration (explicit encoding so parsing does not depend on locale)
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# scVI settings
scvi.settings.seed = config['random_seed']
scvi.settings.progress_bar_style = 'rich'

# Display scVI parameters
scvi_params = config['integration']['scvi']
print("\nscVI parameters:")
for key, value in scvi_params.items():
    print(f"  {key}: {value}")

## 2. Load and Merge Datasets

In [None]:
# Dataset entries come from the analysis config. Each entry is read for its
# 'id' and 'cancer_type', plus an optional 'treatment'.
datasets = config['datasets']['scrna']

adata_list = []

for dataset in datasets:
    geo_id, cancer_type = dataset['id'], dataset['cancer_type']
    filepath = DATA_PROCESSED / f'{geo_id}_final.h5ad'

    # Report and skip datasets that have no preprocessed file on disk
    if not filepath.exists():
        print(f"Not found: {filepath}")
        continue

    adata = sc.read_h5ad(filepath)

    # Add metadata: one constant-valued obs column per annotation
    for column, value in (
        ('dataset', geo_id),
        ('cancer_type', cancer_type),
        ('treatment', dataset.get('treatment', 'Unknown')),
    ):
        adata.obs[column] = value

    adata_list.append(adata)
    print(f"Loaded {geo_id}: {adata.n_obs} cells, {cancer_type}")

print(f"\nTotal datasets loaded: {len(adata_list)}")

In [None]:
# Concatenate datasets
if not adata_list:
    # Fail fast: every later cell needs `adata`. Merely printing here would
    # let a "Run All" continue into an unrelated NameError, or silently reuse
    # a stale `adata` left over in the kernel.
    raise RuntimeError("No datasets to merge. Run preprocessing first.")

adata = ad.concat(
    adata_list,
    join='inner',  # Keep only common genes
    merge='same',
    uns_merge='same'
)

# Make observation names unique (barcodes may repeat across datasets)
adata.obs_names_make_unique()

print("\nMerged dataset:")
print(f"  Total cells: {adata.n_obs}")
print(f"  Common genes: {adata.n_vars}")
print("\nCells per cancer type:")
print(adata.obs['cancer_type'].value_counts())

## 3. Prepare for scVI

scVI models raw (unnormalized) counts, so the count matrix is restored first and the AnnData object is then registered with the model.

In [None]:
# Use raw counts: put the 'counts' layer back into X when it is available
has_counts_layer = 'counts' in adata.layers
if not has_counts_layer:
    print("Warning: No raw counts found. Using current X matrix.")
else:
    adata.X = adata.layers['counts'].copy()
    print("Using raw counts from layers['counts']")

# Re-identify HVGs on merged data.
# NOTE(review): with subset=False the genes are only flagged in
# adata.var['highly_variable']; the later cells still pass the full `adata`
# to scVI, so the model trains on all common genes. Confirm this is intended.
hvg_kwargs = dict(
    n_top_genes=config['feature_selection']['n_top_genes'],
    flavor='seurat_v3',
    batch_key='dataset',  # Account for batch in HVG selection
    subset=False,
)
sc.pp.highly_variable_genes(adata, **hvg_kwargs)

n_hvg = adata.var['highly_variable'].sum()
print(f"\nSelected {n_hvg} highly variable genes across batches")

In [None]:
# Register the AnnData object with scVI.
# Read counts from the 'counts' layer when present; otherwise layer=None is
# passed.
counts_layer = 'counts' if 'counts' in adata.layers else None

scvi.model.SCVI.setup_anndata(
    adata,
    layer=counts_layer,
    batch_key=scvi_params['batch_key'],
    # Optional covariates: a missing config entry is passed as None
    categorical_covariate_keys=scvi_params.get('categorical_covariate_keys'),
    continuous_covariate_keys=scvi_params.get('continuous_covariate_keys'),
)

print("AnnData setup for scVI complete")

## 4. Train scVI Model

In [None]:
# Build the scVI model from the configured architecture settings
model_kwargs = {
    'n_latent': scvi_params['n_latent'],
    'n_layers': scvi_params['n_layers'],
    'gene_likelihood': 'nb',  # Negative binomial for counts
}
model = scvi.model.SCVI(adata, **model_kwargs)

# Echo the architecture that was requested
print(f"Model initialized:")
print(f"  Latent dimensions: {scvi_params['n_latent']}")
print(f"  Encoder layers: {scvi_params['n_layers']}")

In [None]:
# Training options: epoch budget and early stopping come from the config,
# the remaining values are fixed here.
# NOTE(review): patience is presumably counted in validation checks, not
# epochs. With a check every 10 epochs, patience=15 would then tolerate
# ~150 epochs without improvement. Confirm against the scvi-tools docs.
train_kwargs = {
    'max_epochs': scvi_params['n_epochs'],
    'early_stopping': scvi_params['early_stopping'],
    'early_stopping_patience': 15,
    'early_stopping_monitor': 'elbo_validation',
    'plan_kwargs': {'lr': 1e-3},
    'check_val_every_n_epoch': 10,
}

model.train(**train_kwargs)

print("Training complete!")

In [None]:
# Training curves: drop the first recorded entry of each history series
train_elbo = model.history['elbo_train'][1:]
val_elbo = model.history['elbo_validation'][1:]

history_fig, history_ax = plt.subplots(figsize=(8, 4))
for curve, curve_label in ((train_elbo, 'Train'), (val_elbo, 'Validation')):
    history_ax.plot(curve, label=curve_label)

history_ax.set_xlabel('Epoch')
history_ax.set_ylabel('ELBO')
history_ax.set_title('scVI Training History')
history_ax.legend()

# Write the figure to disk before displaying it
history_fig.savefig(FIGURES / 'scvi_training_history.png', dpi=150, bbox_inches='tight')
plt.show()

## 5. Extract Latent Representation

In [None]:
# Take the latent embedding from the trained model and store it on the
# AnnData object under the key used by the neighbor graph below
latent = model.get_latent_representation()
adata.obsm['X_scVI'] = latent

print(f"Latent representation shape: {adata.obsm['X_scVI'].shape}")

In [None]:
# Neighbor graph on the scVI embedding (not on expression/PCA)
n_neighbors = config['dim_reduction']['n_neighbors']
sc.pp.neighbors(adata, n_neighbors=n_neighbors, use_rep='X_scVI')

# UMAP is computed from that graph; seeded from the config
sc.tl.umap(adata, random_state=config['random_seed'])

print("UMAP computed on scVI latent space")

In [None]:
# Leiden clustering on the integrated neighbor graph; labels are written to
# adata.obs['leiden_integrated']
leiden_kwargs = {
    'resolution': config['clustering']['default_resolution'],
    'key_added': 'leiden_integrated',
    'random_state': config['random_seed'],
}
sc.tl.leiden(adata, **leiden_kwargs)

print(f"Integrated clusters: {adata.obs['leiden_integrated'].nunique()}")

## 6. Visualize Integration

In [None]:
# Three UMAP panels side by side, each colored by a different obs column
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

panels = (
    ('dataset', 'Dataset'),
    ('cancer_type', 'Cancer Type'),
    ('leiden_integrated', 'Integrated Clusters'),
)
for panel_ax, (color_key, panel_title) in zip(axes, panels):
    # show=False keeps scanpy from displaying each panel on its own
    sc.pl.umap(adata, color=color_key, ax=panel_ax, show=False, title=panel_title)

plt.tight_layout()
plt.savefig(FIGURES / 'scvi_integration_overview.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Fraction of each cluster's cells that comes from each dataset
# (normalize='index' makes every cluster row sum to 1)
batch_per_cluster = pd.crosstab(
    adata.obs['leiden_integrated'],
    adata.obs['dataset'],
    normalize='index'
)

heatmap_fig, heatmap_ax = plt.subplots(figsize=(12, 8))
sns.heatmap(batch_per_cluster, cmap='YlOrRd', annot=False, ax=heatmap_ax)
heatmap_ax.set_title('Dataset Distribution per Cluster')
heatmap_ax.set_xlabel('Dataset')
heatmap_ax.set_ylabel('Cluster')
heatmap_fig.savefig(FIGURES / 'scvi_batch_mixing.png', dpi=150, bbox_inches='tight')
plt.show()

## 7. Save Model and Data

In [None]:
# Output locations for the trained model and the integrated AnnData
model_path = MODELS / 'scvi_model'
output_path = DATA_PROCESSED / 'integrated_atlas.h5ad'

# Save scVI model (an existing saved model at this path is overwritten)
model.save(str(model_path), overwrite=True)
print(f"Model saved to: {model_path}")

# Save integrated data
adata.write(output_path)
print(f"Integrated data saved to: {output_path}")

## 8. Summary

### Integration Complete
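- Merged all available datasets across cancer types (the number of datasets loaded and the cells per cancer type are printed in Section 2)
- Total cells: see "Total cells" in the merged-dataset output of Section 2
- Latent dimensions: `n_latent` from `config['integration']['scvi']` (printed in Section 4)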

### Next Steps
1. Benchmark integration quality in `03c_integration_benchmarking.ipynb`
2. Proceed to cell type annotation in `04_cell_annotation/`