# 03 - Integration with scVI

**COVID-19 GSE171524 Single-Cell Analysis**

This notebook performs batch correction and integration using scVI.

## Objectives
1. Normalize and log-transform data
2. Select highly variable genes (HVGs)
3. Train scVI model for batch correction
4. Extract latent representation
5. Compute UMAP embedding
6. Save integrated data and model

## Why scVI?
- Probabilistic deep learning approach
- Handles batch effects while preserving biological variation
- Provides uncertainty estimates
- Scales well to large datasets

In [None]:
# Import libraries
import os
import sys
import warnings
import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
import scvi
import matplotlib.pyplot as plt
from pathlib import Path

# Suppress library warnings so notebook output stays readable
warnings.filterwarnings('ignore')

# Make the project's helper scripts (../scripts) importable
sys.path.insert(0, '../scripts')
from plotting import plot_umap_celltype, COVID_COLORS

# Global verbosity/figure defaults and a fixed seed for reproducible scVI runs
sc.settings.verbosity = 3
sc.settings.set_figure_params(dpi=100, facecolor='white')
scvi.settings.seed = 42

for _label, _module in (("Scanpy", sc), ("scvi-tools", scvi)):
    print(f"{_label}: {_module.__version__}")

In [None]:
# Report which compute backend (CUDA, Apple MPS, or CPU) is available
import torch

print(f"PyTorch: {torch.__version__}")
has_cuda = torch.cuda.is_available()
print(f"CUDA available: {has_cuda}")

if has_cuda:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    print("Apple MPS (Metal) available - will use for acceleration")
else:
    print("No GPU detected - training will use CPU (slower but works)")

In [None]:
# Input/output locations, relative to the notebooks directory
OUTPUT_DIR = Path('../data/processed_data')
INPUT_PATH = OUTPUT_DIR / 'adata_qc.h5ad'
MODEL_DIR = OUTPUT_DIR / 'scvi_model'
FIGURE_DIR = Path('../results/figures/integration')
FIGURE_DIR.mkdir(parents=True, exist_ok=True)

# Load the QC-filtered AnnData
print(f"Loading: {INPUT_PATH}")
adata = sc.read_h5ad(INPUT_PATH)
n_cells_loaded, n_genes_loaded = adata.shape
print(f"Loaded: {n_cells_loaded:,} cells, {n_genes_loaded:,} genes")

## Preprocessing

In [None]:
# Preserve the unnormalized matrix before the next cell overwrites X in place.
# Both seurat_v3 HVG selection and scVI read from this layer.
# NOTE(review): assumes adata_qc.h5ad stores raw counts in X - confirm upstream.
counts_key = 'counts'
adata.layers[counts_key] = adata.X.copy()
print(f"Stored raw counts in layers['{counts_key}']")

In [None]:
# Scale each cell to 10,000 total counts, then apply log1p.
# The transformed X is kept for visualization and marker plots;
# scVI trains on the raw 'counts' layer instead.
TARGET_SUM = 1e4
sc.pp.normalize_total(adata, target_sum=TARGET_SUM)
sc.pp.log1p(adata)
print("Normalized and log-transformed")

In [None]:
# Highly variable gene selection with the seurat_v3 flavor. That flavor
# models raw counts, hence layer='counts'. Passing batch_key makes the
# selection sample-aware.
N_HVG = 4000
hvg_kwargs = dict(
    n_top_genes=N_HVG,
    flavor='seurat_v3',
    batch_key='sample_id',
    layer='counts',
)
sc.pp.highly_variable_genes(adata, **hvg_kwargs)

n_selected = int(adata.var['highly_variable'].sum())
print(f"Selected {n_selected:,} highly variable genes")

In [None]:
# Diagnostic plot of the HVG selection, written next to the other figures
hvg_fig_path = FIGURE_DIR / 'hvg_selection.png'
sc.pl.highly_variable_genes(adata, show=False)
plt.savefig(hvg_fig_path, dpi=150)
plt.show()

In [None]:
# Independent all-gene copy (log-normalized X plus the counts layer),
# used later for marker plots and the saved output
adata_full = adata.copy()

# scVI gets only the HVGs, with X swapped back to raw counts
hvg_mask = adata.var['highly_variable']
adata_hvg = adata[:, hvg_mask].copy()
adata_hvg.X = adata_hvg.layers['counts'].copy()

print("\nData for scVI:")
print(f"  Cells: {adata_hvg.n_obs:,}")
print(f"  Genes: {adata_hvg.n_vars:,}")

## Train scVI Model

In [None]:
# Register the HVG object with scVI. X already holds raw counts (layer=None),
# and each sample_id is treated as a batch to correct for.
scvi.model.SCVI.setup_anndata(adata_hvg, layer=None, batch_key='sample_id')

print("AnnData setup complete")

In [None]:
# Model architecture: 2 hidden layers of 128 units, a 30-d latent space,
# and a negative-binomial likelihood for UMI counts
scvi_params = {
    'n_hidden': 128,
    'n_latent': 30,
    'n_layers': 2,
    'dropout_rate': 0.1,
    'gene_likelihood': 'nb',
}
model = scvi.model.SCVI(adata_hvg, **scvi_params)

print("scVI model created")
print(model)

In [None]:
# Train scVI (rough timing: GPU ~10-20 min, CPU ~30-60 min)
print("Training scVI model...")

# Larger minibatches only with CUDA; MPS and CPU runs use 128
use_gpu = torch.cuda.is_available()
batch_size = 128
if use_gpu:
    batch_size = 256

train_kwargs = dict(
    max_epochs=200,
    early_stopping=True,
    early_stopping_patience=10,
    train_size=0.9,
    batch_size=batch_size,
    plan_kwargs={'lr': 1e-3},
    accelerator='auto',  # let the trainer pick GPU/MPS/CPU
)
model.train(**train_kwargs)

print("Training complete!")

In [None]:
# Plot training history.
# model.history maps metric names to DataFrames indexed by epoch. Key names
# vary between scvi-tools versions, so plot the ELBO curves when present and
# otherwise fall back to the loss/ELBO-like metrics that were logged.
fig, ax = plt.subplots(figsize=(8, 5))
history = model.history

elbo_curves = {
    'elbo_train': 'Train ELBO',
    'elbo_validation': 'Validation ELBO',
}
for key, label in elbo_curves.items():
    if key in history:
        ax.plot(history[key].index, history[key].values, label=label)

ylabel = 'ELBO'
if not ax.lines:
    # Fallback: prefer loss/ELBO-like keys so unrelated metrics on other
    # scales don't share the axis; plot everything only if none exist.
    fallback_keys = [
        k for k in history if 'elbo' in k.lower() or 'loss' in k.lower()
    ] or list(history)
    for key in fallback_keys:
        ax.plot(history[key].index, history[key].values, label=key)
    # The y-axis is no longer specifically the ELBO
    ylabel = 'Metric value'

ax.set_xlabel('Epoch')
ax.set_ylabel(ylabel)
ax.set_title('scVI Training History')
if ax.lines:
    # Avoid matplotlib's "no artists with labels" warning on an empty plot
    ax.legend()
plt.savefig(FIGURE_DIR / 'scvi_training.png', dpi=150)
plt.show()

In [None]:
# Persist the trained model so later notebooks can reload it without retraining
MODEL_DIR.mkdir(parents=True, exist_ok=True)
model.save(dir_path=MODEL_DIR, overwrite=True)
print(f"Model saved to: {MODEL_DIR}")

## Extract Latent Representation

In [None]:
# Latent embedding from the trained model (one row per cell), stored as
# the representation for downstream neighbors/UMAP
latent = model.get_latent_representation()
adata_hvg.obsm['X_scVI'] = latent

print(f"Latent representation shape: {latent.shape}")

In [None]:
# Build the kNN graph in the scVI latent space.
# n_pcs is intentionally omitted. With use_rep, scanpy would only slice the
# first n_pcs columns and raise if n_pcs exceeded the latent width. Using
# the full embedding keeps this cell valid if n_latent is changed when the
# model is created.
sc.pp.neighbors(
    adata_hvg,
    use_rep='X_scVI',
    n_neighbors=30,
)

print("Computed neighbors from scVI latent space")

In [None]:
# 2-D UMAP of the scVI neighbor graph; min_dist=0.3 packs points more
# tightly than scanpy's default
UMAP_MIN_DIST = 0.3
sc.tl.umap(adata_hvg, min_dist=UMAP_MIN_DIST)
print("Computed UMAP embedding")

## Visualize Integration

In [None]:
# UMAP colored by sample: well-mixed colors suggest batch effects were removed
samples_fig_path = FIGURE_DIR / 'umap_samples.png'
fig, ax = plt.subplots(figsize=(10, 8))
sc.pl.umap(adata_hvg, color='sample_id', title='UMAP - Samples (scVI integrated)', ax=ax, show=False)
plt.savefig(samples_fig_path, dpi=150)
plt.show()

In [None]:
# UMAP colored by disease condition, using the project's condition palette
condition_fig_path = FIGURE_DIR / 'umap_condition.png'
fig, ax = plt.subplots(figsize=(10, 8))
sc.pl.umap(
    adata_hvg, color='condition', palette=COVID_COLORS,
    title='UMAP - Condition (scVI integrated)', ax=ax, show=False,
)
plt.savefig(condition_fig_path, dpi=150)
plt.show()

In [None]:
# Side-by-side UMAPs, one panel per condition, each showing only that condition's cells
conditions = ['Control', 'COVID']
fig, axes = plt.subplots(1, len(conditions), figsize=(14, 6))

for panel, condition in zip(axes, conditions):
    in_condition = adata_hvg.obs['condition'] == condition
    n_in = in_condition.sum()
    sc.pl.umap(
        adata_hvg[in_condition],
        color='condition',
        palette=COVID_COLORS,
        title=f'{condition} (n={n_in:,})',
        ax=panel,
        show=False,
    )

plt.tight_layout()
plt.savefig(FIGURE_DIR / 'umap_condition_split.png', dpi=150)
plt.show()

In [None]:
# Check batch mixing: one panel per sample. Each panel draws that sample's
# cells, colored by its condition, over all cells drawn in light gray.
# The grid grows with the number of samples, so no sample is silently
# dropped (the old fixed 3x9 grid truncated after 27).
import math

samples = sorted(adata_hvg.obs['sample_id'].unique())
umap_xy = adata_hvg.obsm['X_umap']

N_COLS = 9
ROW_HEIGHT = 7 / 3  # keeps the original (20, 7) size for a 3-row grid
n_rows = max(1, math.ceil(len(samples) / N_COLS))
fig, axes = plt.subplots(
    n_rows, N_COLS, figsize=(20, ROW_HEIGHT * n_rows), squeeze=False
)
axes = axes.flatten()

for ax, sample in zip(axes, samples):
    mask = (adata_hvg.obs['sample_id'] == sample).to_numpy()
    # Each sample is assumed to belong to a single condition; take the first
    cond = adata_hvg.obs.loc[mask, 'condition'].iloc[0]
    color = COVID_COLORS[cond]

    # All cells as a faint background
    ax.scatter(umap_xy[:, 0], umap_xy[:, 1], c='lightgray', s=0.5, alpha=0.3)
    # This sample's cells highlighted
    ax.scatter(umap_xy[mask, 0], umap_xy[mask, 1], c=color, s=0.5, alpha=0.5)
    ax.set_title(f'{sample}\n(n={mask.sum():,})', fontsize=8)
    ax.set_xticks([])
    ax.set_yticks([])

# Hide unused panels in the last row (safe even when there are no samples)
for ax in axes[len(samples):]:
    ax.set_visible(False)

plt.suptitle('Sample Distribution on UMAP', fontsize=12)
plt.tight_layout()
plt.savefig(FIGURE_DIR / 'umap_sample_distribution.png', dpi=150)
plt.show()

## Transfer Results to Full Data

In [None]:
# Transfer the scVI embedding, UMAP, and neighbor graph to the all-gene object.
# These arrays are aligned to cells by row position. That only holds because
# adata_hvg was subset on genes alone, so verify it explicitly: re-running
# cells out of order could otherwise attach results to the wrong cells.
if not adata_full.obs_names.equals(adata_hvg.obs_names):
    raise ValueError(
        "adata_full and adata_hvg do not contain the same cells in the same "
        "order; cannot transfer scVI results by position"
    )

for key in ('X_scVI', 'X_umap'):
    adata_full.obsm[key] = adata_hvg.obsm[key]

# Neighbor graph and its metadata, so downstream clustering can reuse them
for key in ('connectivities', 'distances'):
    adata_full.obsp[key] = adata_hvg.obsp[key]
for key in ('neighbors', 'umap'):
    adata_full.uns[key] = adata_hvg.uns[key]

print("Transferred integration results to full data")

In [None]:
# Sanity check: the transferred embedding should plot from the all-gene object
fig, ax = plt.subplots(figsize=(10, 8))
umap_kwargs = dict(
    color='condition',
    palette=COVID_COLORS,
    title='UMAP - Full Data (scVI integrated)',
    ax=ax,
    show=False,
)
sc.pl.umap(adata_full, **umap_kwargs)
plt.show()

In [None]:
# Plot expression of key marker genes on the integrated UMAP
# (log-normalized values from adata_full.X).
# Missing genes are reported and their panels hidden, instead of leaving
# blank, tick-labelled axes.
import math

markers = ['EPCAM', 'CD68', 'COL1A1', 'CD3D', 'CD79A', 'PECAM1']
missing = [gene for gene in markers if gene not in adata_full.var_names]
if missing:
    print(f"Markers not found in data (skipped): {', '.join(missing)}")

# Grid sized from the marker list; 6 markers gives the original 2x3, (12, 8) layout
N_COLS = 3
n_rows = max(1, math.ceil(len(markers) / N_COLS))
fig, axes = plt.subplots(n_rows, N_COLS, figsize=(4 * N_COLS, 4 * n_rows), squeeze=False)
axes = axes.flatten()

for ax, gene in zip(axes, markers):
    if gene in missing:
        ax.set_visible(False)
        continue
    sc.pl.umap(
        adata_full,
        color=gene,
        ax=ax,
        show=False,
        title=gene,
        cmap='viridis'
    )

# Hide panels beyond the number of markers
for ax in axes[len(markers):]:
    ax.set_visible(False)

plt.tight_layout()
plt.savefig(FIGURE_DIR / 'umap_markers.png', dpi=150)
plt.show()

In [None]:
# Write the integrated object (log-normalized X, raw 'counts' layer, scVI
# embedding, UMAP, and neighbor graph) as a gzip-compressed h5ad
output_path = OUTPUT_DIR / 'adata_integrated.h5ad'
print(f"Saving to: {output_path}")
adata_full.write_h5ad(output_path, compression='gzip')

size_gb = output_path.stat().st_size / 1e9
print(f"\nFile saved: {output_path}")
print(f"File size: {size_gb:.2f} GB")

## Summary

### Integration Pipeline
1. **HVG Selection**: 4,000 genes using the `seurat_v3` flavor (computed on raw counts, batch-aware by `sample_id`)
2. **scVI Training**: 200 max epochs with early stopping
3. **Latent Representation**: 30-dimensional embedding
4. **UMAP**: Computed from scVI latent space

### Model Parameters
- Hidden layers: 2 x 128 units
- Latent dimensions: 30
- Gene likelihood: Negative binomial
- Batch key: sample_id

### Output
- `data/processed_data/adata_integrated.h5ad` - Integrated AnnData
- `data/processed_data/scvi_model/` - Saved scVI model

### Next Steps
→ **04_clustering_annotation.ipynb**: Cluster cells and annotate cell types

In [None]:
# Record library versions for reproducibility
print("\n=== Session Info ===")
for lib_name, lib in [("Scanpy", sc), ("scvi-tools", scvi), ("PyTorch", torch), ("NumPy", np)]:
    print(f"{lib_name}: {lib.__version__}")