[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cafferychen777/flashdeconv/blob/main/examples/spatial_deconvolution_tutorial.ipynb)

# Spatial Transcriptomics Cell Type Deconvolution Tutorial

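This tutorial demonstrates how to perform **cell type deconvolution** on spatial transcriptomics data using FlashDeconv, that is, how to estimate the proportion of each cell type in every spot using a single-cell RNA-seq reference.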
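Reference-based deconvolution generally starts from a linear mixing model: a spot's expected expression is a proportion-weighted sum of cell-type signatures, scaled by sequencing depth. The sketch below illustrates that idea with plain non-negative least squares (`scipy.optimize.nnls`) on a single noise-free toy spot. It is a textbook baseline for intuition only, not FlashDeconv's algorithm, and the signature values are made up.

```python
import numpy as np
from scipy.optimize import nnls

# Toy signature matrix: 3 cell types x 5 genes (made-up values)
signatures = np.array([
    [10.0, 1.0, 1.0, 5.0, 0.5],
    [1.0, 12.0, 1.0, 0.5, 4.0],
    [1.0, 1.0, 15.0, 2.0, 2.0],
])
true_props = np.array([0.6, 0.3, 0.1])

# Forward model: mix signatures by proportion, then scale to a sequencing depth
mixture = true_props @ signatures
spot = mixture / mixture.sum() * 2000.0

# Inverse problem: find non-negative weights w with signatures.T @ w ~= spot
weights, residual = nnls(signatures.T, spot)

# Depth only rescales the weights, so normalizing them recovers proportions
est_props = weights / weights.sum()
print(np.round(est_props, 3))  # [0.6 0.3 0.1]
```

With real Poisson noise, correlated signatures and many spots, this independent per-spot fit becomes noisy and ignores neighbouring spots; FlashDeconv's `lambda_spatial` parameter (spatial regularization strength) is aimed at that second limitation.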

**What you'll learn:**
1. How to load and prepare spatial transcriptomics data (AnnData format)
2. How to prepare a single-cell RNA-seq reference
3. How to run cell type deconvolution with FlashDeconv
4. How to visualize and interpret the results

**Supported platforms:**
- 10x Visium / Visium HD
- Slide-seq / Slide-seqV2
- MERFISH
- Xenium
- Other array-based or imaging-based spatial transcriptomics platforms

## Setup

### Install dependencies

In [None]:
# Uncomment to install
# !pip install flashdeconv scanpy matplotlib

In [None]:
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt

import flashdeconv as fd
from flashdeconv import FlashDeconv

print(f"FlashDeconv version: {fd.__version__}")

# Set visualization style
sc.settings.set_figure_params(dpi=100, frameon=False)
plt.rcParams['figure.figsize'] = (6, 6)

## Part 1: Generate Example Data

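For this tutorial, we'll generate synthetic spatial transcriptomics data. This lets you run the tutorial without downloading large datasets, and because the true proportions are known, you can check the deconvolution results against them in Part 5.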

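The synthetic data simulates:
- A tissue with 5,000 spatial spots on a slightly jittered square grid
- 8 cell types, each concentrated around a random spatial center and decaying smoothly with distance
- Gene expression drawn from log-normal baselines, with 50 marker genes per cell type upregulated 10-fold, sampled as Poisson counts
- A matched single-cell reference with 200 cells per cell type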
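The generator below builds each spot's counts in three steps: mix the signatures by the true proportions, rescale the mixture so it sums to the spot's sequencing depth, and draw Poisson counts. A worked example of those steps for one spot, with two made-up cell types and three genes:

```python
import numpy as np

# Two cell types x three genes (made-up signature values)
signatures = np.array([
    [10.0, 0.0, 10.0],   # type A
    [0.0, 20.0, 0.0],    # type B
])
proportions = np.array([0.5, 0.5])  # spot is half A, half B
depth = 1000.0                       # total counts for this spot

mixture = proportions @ signatures           # [5, 10, 5]
expected = mixture / mixture.sum() * depth   # rescale so the profile sums to depth
print(expected)  # [250. 500. 250.]

# Poisson sampling then adds count noise around this expectation
rng = np.random.default_rng(0)
counts = rng.poisson(expected)
```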

In [None]:
def generate_synthetic_spatial_data(
    n_spots=5000,
    n_genes=3000,
    n_cell_types=8,
    random_state=42
):
    """
    Generate synthetic spatial transcriptomics data with ground truth.
    
    Returns
    -------
    adata_st : AnnData
        Spatial data with coordinates in .obsm['spatial']
    adata_ref : AnnData
        Single-cell reference with cell type labels in .obs['cell_type']
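    ground_truth : ndarray
        True cell type proportions (n_spots, n_cell_types), with columns
        in the same order as cell_type_names
    cell_type_names : list of str
        Cell type labels used in the reference and the ground truth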
    """
    np.random.seed(random_state)
    
    cell_type_names = [
        'Epithelial', 'Fibroblast', 'Endothelial', 'Macrophage',
        'T_cell', 'B_cell', 'Neuron', 'Stem_cell'
    ][:n_cell_types]
    
    # === Generate reference signatures ===
    # Base expression (log-normal)
    signatures = np.exp(np.random.randn(n_cell_types, n_genes) * 0.5 + 2)
    
    # Add cell-type-specific markers
    for k in range(n_cell_types):
        markers = np.random.choice(n_genes, size=50, replace=False)
        signatures[k, markers] *= 10  # 10x upregulation
    
    # === Generate spatial coordinates (grid with noise) ===
    side = int(np.ceil(np.sqrt(n_spots)))
    x = np.tile(np.arange(side), side)[:n_spots].astype(float)
    y = np.repeat(np.arange(side), side)[:n_spots].astype(float)
    coords = np.column_stack([x, y])
    coords += np.random.randn(n_spots, 2) * 0.1  # Add slight jitter
    
    # === Generate true proportions with spatial patterns ===
    proportions = np.zeros((n_spots, n_cell_types))
    
    for k in range(n_cell_types):
        # Each cell type has a spatial center
        center = np.array([np.random.rand() * side, np.random.rand() * side])
        dist = np.sqrt(np.sum((coords - center) ** 2, axis=1))
        proportions[:, k] = np.exp(-dist / (side / 4))
    
    # Normalize to sum to 1
    proportions = proportions / proportions.sum(axis=1, keepdims=True)
    
    # === Generate spatial counts ===
    expected = proportions @ signatures
    depth = np.random.gamma(shape=5, scale=3000, size=n_spots)
    # Rescale each spot's mixture profile so it sums to that spot's sequencing depth
    expected = expected / expected.sum(axis=1, keepdims=True) * depth[:, np.newaxis]
    counts = np.random.poisson(expected).astype(float)
    
    # === Create spatial AnnData ===
    gene_names = [f'Gene_{i}' for i in range(n_genes)]
    spot_names = [f'Spot_{i}' for i in range(n_spots)]
    
    adata_st = sc.AnnData(
        X=counts,
        obs=pd.DataFrame(index=spot_names),
        var=pd.DataFrame(index=gene_names)
    )
    adata_st.obsm['spatial'] = coords
    
    # === Create reference AnnData ===
    n_cells_per_type = 200
    ref_counts = []
    ref_labels = []
    
    for k, ct in enumerate(cell_type_names):
        # Sample cells with Poisson noise around signature
        for _ in range(n_cells_per_type):
            cell_expr = np.random.poisson(signatures[k] * 0.1)
            ref_counts.append(cell_expr)
            ref_labels.append(ct)
    
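    adata_ref = sc.AnnData(
        X=np.array(ref_counts, dtype=float),
        # Explicit string cell names avoid AnnData's implicit index conversion
        obs=pd.DataFrame(
            {'cell_type': ref_labels},
            index=[f'Cell_{i}' for i in range(len(ref_labels))]
        ),
        var=pd.DataFrame(index=gene_names)
    )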
    
    return adata_st, adata_ref, proportions, cell_type_names

# Generate data
adata_st, adata_ref, ground_truth, cell_type_names = generate_synthetic_spatial_data()

print(f"Spatial data: {adata_st.n_obs:,} spots x {adata_st.n_vars:,} genes")
print(f"Reference data: {adata_ref.n_obs:,} cells x {adata_ref.n_vars:,} genes")
print(f"Cell types: {cell_type_names}")

## Part 2: Explore the Data

Before running deconvolution, let's visualize the spatial data.
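The QC cell below computes per-spot totals with `np.array(X.sum(axis=1)).flatten()`, which works whether `adata_st.X` is a dense array or a SciPy sparse matrix (real spatial data are often loaded sparse). A small helper that keeps both computations in one place; the name `per_spot_qc` is hypothetical:

```python
import numpy as np
import scipy.sparse as sp

def per_spot_qc(X):
    """Return (total counts, number of detected genes) per spot.

    Accepts a dense ndarray or a SciPy sparse matrix.
    """
    if sp.issparse(X):
        totals = np.asarray(X.sum(axis=1)).ravel()
        n_genes = np.asarray((X > 0).sum(axis=1)).ravel()
    else:
        X = np.asarray(X)
        totals = X.sum(axis=1)
        n_genes = (X > 0).sum(axis=1)
    return totals, n_genes

X_dense = np.array([[0, 3, 1], [2, 0, 0]], dtype=float)
totals, n_genes = per_spot_qc(sp.csr_matrix(X_dense))
totals_d, n_genes_d = per_spot_qc(X_dense)
```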

In [None]:
# Basic QC: total counts per spot
adata_st.obs['total_counts'] = np.array(adata_st.X.sum(axis=1)).flatten()
adata_st.obs['n_genes'] = np.array((adata_st.X > 0).sum(axis=1)).flatten()

fig, axes = plt.subplots(1, 2, figsize=(10, 4))

# Scatter plot of spatial locations colored by total counts
ax = axes[0]
scatter = ax.scatter(
    adata_st.obsm['spatial'][:, 0],
    adata_st.obsm['spatial'][:, 1],
    c=adata_st.obs['total_counts'],
    s=5, cmap='viridis'
)
plt.colorbar(scatter, ax=ax, label='Total counts')
ax.set_xlabel('X coordinate')
ax.set_ylabel('Y coordinate')
ax.set_title('Spatial distribution of sequencing depth')
ax.set_aspect('equal')

# Histogram of total counts
ax = axes[1]
ax.hist(adata_st.obs['total_counts'], bins=50, edgecolor='black')
ax.set_xlabel('Total counts per spot')
ax.set_ylabel('Number of spots')
ax.set_title('Distribution of sequencing depth')

plt.tight_layout()
plt.show()

In [None]:
# Check reference cell type distribution
print("Reference cell type counts:")
print(adata_ref.obs['cell_type'].value_counts())

## Part 3: Run FlashDeconv

FlashDeconv provides two APIs:

1. **Scanpy-style API** (recommended): `fd.tl.deconvolve()` - stores results in AnnData
2. **NumPy API**: `FlashDeconv` class - for more control
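As Option B shows, the NumPy API takes a signature matrix (one row of mean expression per cell type) rather than labeled cells. A vectorized way to build one from a dense reference matrix, sketched with a hypothetical helper `build_signatures`; `np.unique` returns labels in sorted order, so rows come out alphabetical, matching the `sorted(...)` call in Option B:

```python
import numpy as np

def build_signatures(X, labels):
    """Average expression per cell type.

    X      : (n_cells, n_genes) dense array of reference counts
    labels : length-n_cells sequence of cell type labels
    Returns (sorted unique labels, (n_types, n_genes) signature matrix).
    """
    types, idx = np.unique(np.asarray(labels), return_inverse=True)
    idx = idx.ravel()
    sums = np.zeros((len(types), X.shape[1]))
    np.add.at(sums, idx, X)                      # add each cell to its type's row
    counts = np.bincount(idx, minlength=len(types))
    return types, sums / counts[:, None]

X = np.array([[1.0, 0.0], [3.0, 2.0], [10.0, 10.0]])
labels = ['T_cell', 'T_cell', 'B_cell']
types, sig = build_signatures(X, labels)
print(types)  # ['B_cell' 'T_cell']
# sig rows: B_cell -> [10, 10], T_cell -> [2, 1]
```

For a sparse `adata_ref.X`, convert with `.toarray()` first, or replace `np.add.at` with a sparse indicator-matrix product to avoid densifying a large reference.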

### Option A: Scanpy-style API (Recommended)

In [None]:
# Run deconvolution using scanpy-style API
fd.tl.deconvolve(
    adata_st,                    # Spatial AnnData
    adata_ref,                   # Reference scRNA-seq AnnData
    cell_type_key='cell_type',   # Column in adata_ref.obs with cell type labels
    # Optional parameters:
    # sketch_dim=512,            # Sketch dimension (larger = more accurate but slower)
    # lambda_spatial=5000,       # Spatial regularization strength
    # n_hvg=2000,                # Number of highly variable genes to use
    # verbose=True,              # Print progress
)

# Results are stored in adata_st.obsm['flashdeconv']
print(f"Proportions shape: {adata_st.obsm['flashdeconv'].shape}")
print(f"Cell types: {adata_st.uns['flashdeconv']['cell_types']}")

### Option B: NumPy API (Alternative)

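For more control, the NumPy API works directly on arrays: you build the cell type signature matrix from the reference yourself and pass the counts and coordinates explicitly.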

In [None]:
# Alternative: NumPy API
# (Uncomment to run; not needed if you already ran the scanpy-style API above)

# from flashdeconv import FlashDeconv
# import time
#
# # Build signature matrix from reference
# cell_types_sorted = sorted(adata_ref.obs['cell_type'].unique())
# X_ref = np.zeros((len(cell_types_sorted), adata_ref.n_vars))
# for i, ct in enumerate(cell_types_sorted):
#     mask = adata_ref.obs['cell_type'] == ct
#     X_ref[i] = np.asarray(adata_ref[mask].X.mean(axis=0)).flatten()
#
# # Get spatial data
# Y = np.asarray(adata_st.X)
# coords = adata_st.obsm['spatial']
#
# # Create and fit model
# model = FlashDeconv(
#     sketch_dim=512,
#     lambda_spatial=5000,
#     n_hvg=2000,
#     verbose=True
# )
#
# start = time.time()
# proportions = model.fit_transform(Y, X_ref, coords, cell_type_names=cell_types_sorted)
# print(f"Runtime: {time.time() - start:.2f}s")

## Part 4: Visualize Results

### 4.1 Spatial cell type maps

In [None]:
# Get results
proportions = adata_st.obsm['flashdeconv']
cell_types = adata_st.uns['flashdeconv']['cell_types']

# Plot spatial distribution of each cell type
n_types = len(cell_types)
n_cols = 4
n_rows = int(np.ceil(n_types / n_cols))

fig, axes = plt.subplots(n_rows, n_cols, figsize=(3*n_cols, 3*n_rows))
axes = axes.flatten()

for i, ct in enumerate(cell_types):
    ax = axes[i]
    scatter = ax.scatter(
        adata_st.obsm['spatial'][:, 0],
        adata_st.obsm['spatial'][:, 1],
        c=proportions[:, i],
        s=3, cmap='Reds', vmin=0, vmax=0.5
    )
    ax.set_title(ct, fontsize=10)
    ax.set_aspect('equal')
    ax.axis('off')

# Hide unused subplots
for j in range(n_types, len(axes)):
    axes[j].axis('off')

plt.suptitle('Estimated Cell Type Proportions', y=1.02, fontsize=12)
plt.tight_layout()
plt.show()

### 4.2 Dominant cell type per spot
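Taking the argmax assigns every spot a dominant type even when no cell type clearly dominates. One possible refinement, sketched with a hypothetical helper and an arbitrary 0.4 cutoff (tune it for your data), is to label low-confidence spots as `'Mixed'`:

```python
import numpy as np

def dominant_with_threshold(proportions, cell_types, min_prop=0.4, mixed_label='Mixed'):
    """Argmax cell type per spot, or `mixed_label` if the top proportion < min_prop."""
    proportions = np.asarray(proportions)
    top_idx = proportions.argmax(axis=1)
    top_val = proportions.max(axis=1)
    labels = np.asarray(cell_types, dtype=object)[top_idx]  # fancy indexing copies
    labels[top_val < min_prop] = mixed_label
    return labels, top_val

props = np.array([[0.7, 0.2, 0.1], [0.35, 0.33, 0.32]])
labels, top = dominant_with_threshold(props, ['A', 'B', 'C'])
print(list(labels))  # ['A', 'Mixed']
```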

In [None]:
# Assign dominant cell type
dominant_idx = np.argmax(proportions, axis=1)
adata_st.obs['dominant_cell_type'] = [cell_types[i] for i in dominant_idx]
adata_st.obs['max_proportion'] = proportions.max(axis=1)

# Categorical colours: one fixed colour per cell type, indexed directly so the
# legend stays correct even if some cell type is never dominant
colors = plt.cm.tab10(np.arange(len(cell_types)) % 10)

fig, ax = plt.subplots(figsize=(8, 8))
scatter = ax.scatter(
    adata_st.obsm['spatial'][:, 0],
    adata_st.obsm['spatial'][:, 1],
    c=colors[dominant_idx],
    s=5
)

# Legend
handles = [plt.Line2D([0], [0], marker='o', color='w', 
                       markerfacecolor=colors[i], markersize=8, label=ct)
           for i, ct in enumerate(cell_types)]
ax.legend(handles=handles, loc='center left', bbox_to_anchor=(1, 0.5))

ax.set_title('Dominant Cell Type per Spot')
ax.set_aspect('equal')
ax.axis('off')
plt.tight_layout()
plt.show()

## Part 5: Evaluate Accuracy (with ground truth)

Since we have synthetic data with known ground truth, we can evaluate the deconvolution accuracy.
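Metrics are only meaningful when column *i* of the estimate and column *i* of the ground truth describe the same cell type. The generator returns ground truth in `cell_type_names` order, while the estimated columns follow `adata_st.uns['flashdeconv']['cell_types']`, and those orders need not match (Option B, for instance, sorts the labels). A sketch that aligns by name and then computes per-type Pearson r and RMSE in one vectorized pass; `align_and_score` is a hypothetical helper:

```python
import numpy as np

def align_and_score(pred, pred_types, truth, truth_types):
    """Reorder truth columns to match pred_types, then score each cell type.

    Returns (aligned truth, per-type Pearson r, per-type RMSE).
    """
    order = [list(truth_types).index(ct) for ct in pred_types]
    truth = np.asarray(truth)[:, order]
    pred = np.asarray(pred)
    # Column-wise Pearson r from mean-centred columns
    pc = pred - pred.mean(axis=0)
    tc = truth - truth.mean(axis=0)
    r = (pc * tc).sum(axis=0) / np.sqrt((pc ** 2).sum(axis=0) * (tc ** 2).sum(axis=0))
    rmse = np.sqrt(((pred - truth) ** 2).mean(axis=0))
    return truth, r, rmse

truth = np.array([[0.8, 0.2], [0.5, 0.5], [0.1, 0.9]])  # columns: 'X', 'Y'
pred = truth[:, ::-1]                                     # same values, columns: 'Y', 'X'
_, r, rmse = align_and_score(pred, ['Y', 'X'], truth, ['X', 'Y'])
# Once aligned, the two tables agree exactly: r is 1 and RMSE is 0 for both types
```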

In [None]:
from scipy.stats import pearsonr, spearmanr

# Align ground-truth columns with the estimated columns by cell type name.
# ground_truth follows cell_type_names; the estimates follow cell_types,
# and the two orders are not guaranteed to match (e.g. if labels get sorted).
gt_order = [cell_type_names.index(ct) for ct in cell_types]
truth = ground_truth[:, gt_order]

# Overall accuracy across all spots and cell types
pred_flat = proportions.flatten()
true_flat = truth.flatten()

pearson_r, _ = pearsonr(pred_flat, true_flat)
spearman_r, _ = spearmanr(pred_flat, true_flat)
rmse = np.sqrt(np.mean((pred_flat - true_flat) ** 2))

print("=" * 50)
print("OVERALL ACCURACY")
print("=" * 50)
print(f"Pearson correlation:  {pearson_r:.4f}")
print(f"Spearman correlation: {spearman_r:.4f}")
print(f"RMSE:                 {rmse:.4f}")

In [None]:
# Per-cell-type accuracy (uses the aligned `truth` from the previous cell)
print("\n" + "=" * 50)
print("PER CELL TYPE ACCURACY")
print("=" * 50)
print(f"{'Cell Type':<15} {'Pearson r':<12} {'Spearman r':<12} {'RMSE':<10}")
print("-" * 50)

for i, ct in enumerate(cell_types):
    pred = proportions[:, i]
    true = truth[:, i]
    r_p, _ = pearsonr(pred, true)
    r_s, _ = spearmanr(pred, true)
    rmse_ct = np.sqrt(np.mean((pred - true) ** 2))
    print(f"{ct:<15} {r_p:<12.4f} {r_s:<12.4f} {rmse_ct:<10.4f}")

In [None]:
# Scatter plot: predicted vs true (aligned ground truth)
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
axes = axes.flatten()

for i, ct in enumerate(cell_types):
    ax = axes[i]
    ax.scatter(truth[:, i], proportions[:, i], s=1, alpha=0.3)
    ax.plot([0, 1], [0, 1], 'r--', lw=1)  # Identity line

    r, _ = pearsonr(truth[:, i], proportions[:, i])
    ax.set_title(f'{ct}\nr = {r:.3f}', fontsize=9)
    ax.set_xlabel('True', fontsize=8)
    ax.set_ylabel('Predicted', fontsize=8)
    ax.set_xlim(-0.05, 1)
    ax.set_ylim(-0.05, 1)

plt.suptitle('Predicted vs True Proportions', y=1.02)
plt.tight_layout()
plt.show()

## Part 6: Export Results

Save the deconvolution results for downstream analysis.
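Before relying on an exported table downstream, it can help to confirm that it round-trips and that the proportion columns still sum to one per spot. A minimal sketch using a tiny made-up table and an in-memory buffer in place of a file:

```python
import io
import numpy as np
import pandas as pd

# Made-up proportions for two spots and two cell types
results_df = pd.DataFrame(
    {'A': [0.25, 0.9], 'B': [0.75, 0.1], 'dominant_type': ['B', 'A']},
    index=['Spot_0', 'Spot_1'],
)

buffer = io.StringIO()
results_df.to_csv(buffer)          # same call as results_df.to_csv('file.csv')
buffer.seek(0)
reloaded = pd.read_csv(buffer, index_col=0)

# Proportion columns should survive the round trip and still sum to 1 per spot
prop_cols = ['A', 'B']
row_sums = reloaded[prop_cols].sum(axis=1)
same_values = np.allclose(reloaded[prop_cols], results_df[prop_cols])
```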

In [None]:
# Collect results in a DataFrame (uncomment the to_csv line to save as CSV)
results_df = pd.DataFrame(
    proportions,
    index=adata_st.obs_names,
    columns=cell_types
)
results_df['dominant_type'] = adata_st.obs['dominant_cell_type']
results_df['x'] = adata_st.obsm['spatial'][:, 0]
results_df['y'] = adata_st.obsm['spatial'][:, 1]

# results_df.to_csv('deconvolution_results.csv')
print(results_df.head())

In [None]:
# Save updated AnnData (uncomment to write to disk)
# adata_st.write_h5ad('spatial_with_deconvolution.h5ad')
print("Results are stored in:")
print("  - adata_st.obsm['flashdeconv']: proportions matrix")
print("  - adata_st.obs['dominant_cell_type']: dominant type per spot")
print("  - adata_st.uns['flashdeconv']['cell_types']: cell type names")

## Summary

In this tutorial, you learned how to:

1. **Prepare data**: Set up spatial and single-cell reference data in AnnData format
2. **Run deconvolution**: Use `fd.tl.deconvolve()` for a scanpy-style workflow
3. **Visualize results**: Create spatial maps of cell type proportions
4. **Evaluate accuracy**: Compare predictions to ground truth (when available)

### Next steps

- Try FlashDeconv on your own data
- Explore the [multi-resolution analysis tutorial](resolution_horizon_analysis.ipynb) for Visium HD
- Combine with [Squidpy](https://squidpy.readthedocs.io/) for downstream spatial analysis

### Citation

If you use FlashDeconv in your research, please cite:

```
Yang, C., Chen, J. & Zhang, X. FlashDeconv enables atlas-scale,
multi-resolution spatial deconvolution via structure-preserving sketching.
bioRxiv (2025). https://doi.org/10.64898/2025.12.22.696108
```