# Tutorial: Spatial Interpolation with FEAST

In this notebook, we demonstrate how to use FEAST to interpolate between two spatial transcriptomics slices. After aligning the two slices with Spateo, we generate intermediate slices between them and use the resulting stack to reconstruct the 3D spatial expression patterns of genes.

In [None]:
# Install FEAST if not already installed
# !pip install FEAST-py
!pip install gdown

In [1]:
import sys
import os
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns
from pathlib import Path
import torch


import spateo as st
print(f"Spateo version: {st.__version__}")

# Import FEAST interpolation modules
from FEAST.interpolation import (
    interpolate_slices,
    InterpolationConfig
)

# Set figure params
sc.settings.set_figure_params(dpi=100, facecolor="white")

# Seed the NumPy and PyTorch RNGs so stochastic steps are reproducible.
# NOTE(review): spateo/FEAST may also draw from their own generators; confirm
# whether they expose a seed argument of their own.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Running on: {device}")

2025-11-20 05:09:03.424784: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.



<stdin>:1:10: fatal error: cuda.h: No such file or directory
compilation terminated.


<stdin>:1:10: fatal error: nvrtc.h: No such file or directory
compilation terminated.


<stdin>:1:10: fatal error: cuda.h: No such file or directory
compilation terminated.


<stdin>:1:10: fatal error: nvrtc.h: No such file or directory
compilation terminated.

[KeOps] Compiling cuda jit compiler engine ... 
<stdin>:1:10: fatal error: cuda.h: No such file or directory
compilation terminated.


<stdin>:1:10: fatal error: nvrtc.h: No such file or directory
compilation terminated.

[KeOps] Compiling cuda jit compiler engine ... 
/maiziezhou_lab6/chen_yr/miniconda3/envs/simulator/lib/python3.9/site-packages/keopscore/binders/nvrtc/nvrtc_jit.cpp:16:10: fatal error: cuda.h: No such file or directory
   16 | #include <cuda.h>
      |          ^~~~~~~~
compilation terminated.

OK

/maiziezhou_lab6/chen_yr/miniconda3/envs/simulator/lib/python3.9/site-packages/keopscore/binders/nvrtc/nvrtc_jit.cpp:16:10: fata

## 1. Load and Preprocess Data

We download two MERFISH slices (skipped if they are already present locally) and apply standard preprocessing: cell/gene filtering, total-count normalization, log-transformation, and HVG selection. This prepares them for alignment. Raw counts are kept in `layers['counts']`.

In [None]:
import gdown

# Define paths and download data
data_dir = "./example_data"
os.makedirs(data_dir, exist_ok=True)

# Google Drive file IDs (pass the bare ID to gdown, not a full URL)
files = {
    "merfish_0610.h5ad": "1lOQasZ9nxIDIZwlqEQDCBY0kJaA7GwZD",
    "merfish_0613.h5ad": "1L2oCrL23sjiOCdG2IOam7Q0_06bQQqWY"
}

for filename, file_id in files.items():
    output_path = os.path.join(data_dir, filename)
    if os.path.exists(output_path):
        continue  # already cached locally
    print(f"Downloading {filename}...")
    # Some gdown versions return None on failure instead of raising, so check
    # both the return value and the file on disk before moving on.
    downloaded = gdown.download(id=file_id, output=output_path, quiet=False)
    if downloaded is None or not os.path.exists(output_path):
        raise RuntimeError(
            f"Could not download {filename} (Google Drive id {file_id}). "
            f"Download it manually from https://drive.google.com/uc?id={file_id} "
            f"and save it as {output_path}."
        )

print("Loading data...")
slice1 = st.read(os.path.join(data_dir, "merfish_0610.h5ad"))
slice2 = st.read(os.path.join(data_dir, "merfish_0613.h5ad"))

def preprocess_slice(adata, label, min_genes=10, min_cells=3, n_top_genes=2000):
    """Filter, normalize, log-transform and flag HVGs on one slice, in place.

    Parameters
    ----------
    adata : AnnData
        Slice to preprocess. It is modified in place and also returned.
    label : str
        Name used in progress messages.
    min_genes : int, default 10
        Drop cells expressing fewer than this many genes.
    min_cells : int, default 3
        Drop genes detected in fewer than this many cells.
    n_top_genes : int, default 2000
        Number of highly variable genes to flag. Genes are flagged only;
        the slice is not subset.

    Returns
    -------
    AnnData
        The same object. Raw (filtered) counts are in ``layers['counts']``
        and ``X`` holds log-normalized values.
    """
    print(f"\nPreprocessing {label}...")
    # Quality filtering on both axes
    sc.pp.filter_cells(adata, min_genes=min_genes)
    sc.pp.filter_genes(adata, min_cells=min_cells)

    # Keep the raw counts before X is normalized in place
    adata.layers["counts"] = adata.X.copy()

    # Per-cell total-count normalization, then log1p
    sc.pp.normalize_total(adata)
    sc.pp.log1p(adata)

    # Flag (but do not subset to) highly variable genes
    sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes)
    print(f"Shape after preprocessing: {adata.shape}")
    return adata


slice1 = preprocess_slice(slice1, "Slice 1")
slice2 = preprocess_slice(slice2, "Slice 2")

Downloading merfish_0610.h5ad...


FileURLRetrievalError: Failed to retrieve file url:

	Cannot retrieve the public link of the file. You may need to change
	the permission to 'Anyone with the link', or have had many accesses.
	Check FAQ in https://github.com/wkentaro/gdown?tab=readme-ov-file#faq.

You may still be able to access the file from the browser:

	https://drive.google.com/uc?id=https://drive.google.com/uc?id=1lOQasZ9nxIDIZwlqEQDCBY0kJaA7GwZD

but Gdown can't. Please check connections and permissions.

## 2. Spateo Alignment

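We first compute a PCA embedding of both slices (`group_pca`). We then use Spateo's `morpho_align` to align the slices based on their spatial distribution and gene expression, measured as the cosine dissimilarity between PCA embeddings. This produces an alignment matrix $\pi$ between the cells of the two slices. It also transforms the spatial coordinates of both slices into a common coordinate system.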

In [None]:
print("Computing Group PCA...")
try:
    st.align.group_pca([slice1, slice2], pca_key='X_pca')
except Exception as e:
    # Separate per-slice PCAs would place each slice in its own, unrelated basis,
    # which makes the cosine dissimilarity below meaningless. Instead, fit a single
    # PCA on both slices, restricted to their shared genes, so the embeddings
    # are comparable.
    print(f"Group PCA failed ({e}), falling back to a joint PCA on shared genes.")
    _genes2 = set(slice2.var_names)
    _shared_genes = [g for g in slice1.var_names if g in _genes2]
    _joint = sc.concat([slice1[:, _shared_genes], slice2[:, _shared_genes]], index_unique='-')
    sc.pp.pca(_joint)
    # Rows are concatenated in order: slice 1 first, then slice 2
    slice1.obsm['X_pca'] = _joint.obsm['X_pca'][:slice1.n_obs]
    slice2.obsm['X_pca'] = _joint.obsm['X_pca'][slice1.n_obs:]
    del _joint, _genes2, _shared_genes

print("\nPerforming Morphological Alignment...")
aligned_slices, pis = st.align.morpho_align(
    models=[slice1, slice2],
    rep_layer='X_pca',
    rep_field='obsm',
    dissimilarity='cos',
    verbose=True,
    spatial_key='spatial',
    key_added='align_spatial',
    device=device,
)

# Extract results
slice_ref1, slice_ref2 = aligned_slices
alignment_pi = pis[0]

# Convert the alignment to a dense NumPy array. Depending on the backend it
# may be a torch tensor, a sparse matrix, or already an array.
if hasattr(alignment_pi, 'cpu'):
    alignment = alignment_pi.detach().cpu().numpy()
elif hasattr(alignment_pi, 'toarray'):
    alignment = alignment_pi.toarray()
else:
    alignment = np.asarray(alignment_pi)

print("\nAlignment complete.")
print(f"Alignment matrix shape: {alignment.shape}")

# Sanity check: the alignment should relate the cells of the two slices.
# NOTE(review): only the set of dimensions is checked. Confirm which axis
# `interpolate_slices` expects for each slice.
if sorted(alignment.shape) != sorted((slice_ref1.n_obs, slice_ref2.n_obs)):
    print(f"Warning: alignment shape {alignment.shape} does not match the slice "
          f"sizes ({slice_ref1.n_obs}, {slice_ref2.n_obs}).")

# Replace 'spatial' with aligned coordinates. Prefer the rigid result when present.
# NOTE(review): the exact keys written for key_added='align_spatial' depend on
# the spateo version; the candidates are tried in order.
for coord_key in ('align_spatial_rigid', 'align_spatial', 'align_spatial_nonrigid'):
    if coord_key in slice_ref1.obsm and coord_key in slice_ref2.obsm:
        slice_ref1.obsm['spatial'] = slice_ref1.obsm[coord_key]
        slice_ref2.obsm['spatial'] = slice_ref2.obsm[coord_key]
        print(f"Spatial coordinates updated to aligned positions (obsm['{coord_key}']).")
        break
else:
    print("Warning: no aligned coordinates found in .obsm; 'spatial' left unchanged.")

## 3. Prepare for Interpolation

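We restrict both slices to the genes they share, in the same order, so their expression matrices have identical columns for the interpolation step. We also check that the raw-count layer (`counts`) saved during preprocessing is still present.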

In [None]:
# Find common genes. Iterating over slice 1's genes gives a stable, reproducible
# column order; a Python set of strings does not, because str hashing is
# randomized per process.
genes_in_slice2 = set(slice_ref2.var_names)
common_genes = [gene for gene in slice_ref1.var_names if gene in genes_in_slice2]
print(f"Common genes: {len(common_genes)}")
if not common_genes:
    raise ValueError("The two slices share no genes; cannot interpolate.")

# Subset both slices to the same genes in the same column order
slice_ref1 = slice_ref1[:, common_genes].copy()
slice_ref2 = slice_ref2[:, common_genes].copy()

# The raw-count layer saved during preprocessing should survive alignment and subsetting
if 'counts' not in slice_ref1.layers:
    print("Warning: 'counts' layer missing in Slice 1")
if 'counts' not in slice_ref2.layers:
    print("Warning: 'counts' layer missing in Slice 2")

print("Data preparation complete.")
print(f"Final Slice 1 shape: {slice_ref1.shape}")
print(f"Final Slice 2 shape: {slice_ref2.shape}")

## 4. Run Interpolation

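We generate `n_slices` intermediate slices between the two reference slices, at evenly spaced interpolation times $t \in (0, 1)$, where $t = 0$ is slice 1 and $t = 1$ is slice 2. For visualization, each slice is placed at depth $z = 100\,t$.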

In [None]:
# Configuration
n_slices = 5         # number of intermediate slices to generate
z_spacing = 100.0    # Z distance between the two reference slices (plot units)
# Evenly spaced interpolation times strictly inside (0, 1); the endpoints are the references
t_values = np.linspace(0, 1, n_slices + 2)[1:-1]

# Slices ordered along Z, with one metadata dict ('type', 't', 'z') per slice
all_slices = []
slice_metadata = []

# Reference 1 sits at t = 0, z = 0
slice_ref1.obs['slice_type'] = 'reference'
all_slices.append(slice_ref1)
slice_metadata.append({'type': 'reference', 't': 0.0, 'z': 0.0})

# Run Interpolation Loop
print("Starting interpolation...")
for i, t in enumerate(t_values, start=1):
    print(f"Generating slice {i}/{len(t_values)} at t={t:.2f}...")

    # Interpolation parameters (regularization and feature weights were tuned)
    config = InterpolationConfig(
        t=t,
        use_normalized=True,
        ot_method='sinkhorn',
        ot_regularization=0.05,
        feature_weights={'mean': 3.0, 'variance': 1.0, 'zero_prop': 1.0},
        sigma=1.0,
        verbose=False
    )

    interpolated_slice = interpolate_slices(
        adata1=slice_ref1,
        adata2=slice_ref2,
        alignment_matrix=alignment,
        config=config
    )

    interpolated_slice.obs['slice_type'] = 'interpolated'
    all_slices.append(interpolated_slice)
    # Z scales linearly with t, so slices sit evenly between the references
    slice_metadata.append({'type': 'interpolated', 't': t, 'z': t * z_spacing})

# Reference 2 sits at t = 1, z = z_spacing
slice_ref2.obs['slice_type'] = 'reference'
all_slices.append(slice_ref2)
slice_metadata.append({'type': 'reference', 't': 1.0, 'z': z_spacing})

# Every slice must have exactly one metadata entry (plotting zips the two lists)
assert len(all_slices) == len(slice_metadata) == n_slices + 2
print(f"Interpolation complete. Total slices: {len(all_slices)}")

## 5. Visualization

We will visualize the 3D spatial expression of specific genes: **Chat** and **Dlk1**.

In [None]:
def plot_gene_3d(all_slices, slice_metadata, gene_name):
    """Plot one gene's expression in 3D, stacking each slice at its Z position.

    Parameters
    ----------
    all_slices : sequence of AnnData
        Slices to draw, using ``obsm['spatial']`` columns 0 and 1 as (x, y).
        Slices without ``obsm['spatial']`` or without ``gene_name`` are skipped.
    slice_metadata : sequence of dict
        One dict per slice with key ``'type'`` (``'reference'`` or other) and
        key ``'z'``, the Z coordinate at which the slice is drawn.
    gene_name : str
        Gene used to color the points.

    All layers share one color scale, so colors are comparable across slices
    and the single colorbar is valid for every layer. Prints a message and
    returns without plotting when no slice can be drawn.
    """
    # Gather drawable layers first so the color scale can span all of them
    layers = []
    for slice_data, metadata in zip(all_slices, slice_metadata):
        var_names = list(slice_data.var_names)
        if gene_name not in var_names or 'spatial' not in slice_data.obsm:
            continue
        # Densify only this gene's column rather than the whole matrix
        column = slice_data.X[:, var_names.index(gene_name)]
        column = column.toarray() if hasattr(column, 'toarray') else np.asarray(column)
        coords = np.asarray(slice_data.obsm['spatial'])
        layers.append((coords, np.ravel(column), metadata))

    if not layers:
        print(f"Gene {gene_name} not found in dataset.")
        return

    all_expr = np.concatenate([expr for _, expr, _ in layers])
    vmin, vmax = (all_expr.min(), all_expr.max()) if all_expr.size else (None, None)

    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')

    for coords, expr, metadata in layers:
        is_reference = metadata['type'] == 'reference'
        z = np.full(len(coords), metadata['z'])
        # Reference slices are drawn larger and more opaque than interpolated ones
        scatter = ax.scatter(
            coords[:, 0], coords[:, 1], z,
            c=expr, cmap='viridis', vmin=vmin, vmax=vmax,
            s=15 if is_reference else 10,
            alpha=0.4 if is_reference else 0.2,
            linewidth=0,
        )

    ax.set_title(f"3D Expression: {gene_name}")
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z (Depth)")
    fig.colorbar(scatter, ax=ax, label='Expression', shrink=0.5)
    plt.show()

# Draw one 3D figure for each gene of interest
for gene_of_interest in ("Chat", "Dlk1"):
    plot_gene_3d(all_slices, slice_metadata, gene_of_interest)

## 6. Evaluation

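We evaluate the quality of interpolation using the **Gene Continuity Score**, which measures how smoothly a gene's expression changes along the Z-axis. Let $m_k$ be the gene's mean expression in slice $k$, with slices ordered by Z. The score is
$$\text{continuity} = 1 - \frac{\operatorname{std}\left(m_{k+1} - m_k\right)}{\operatorname{mean}\left(m_k\right)}.$$
A value of 1 means the mean expression changes by a constant step between consecutive slices.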

In [None]:
def calculate_continuity(all_slices, slice_metadata, gene_name):
    """Score how smoothly a gene's mean expression changes across slices.

    The score is ``1 - std(diff(m)) / mean(m)``, where ``m`` holds the gene's
    per-slice mean expression, in ``all_slices`` order, over the slices that
    contain the gene. A score of 1.0 means the mean changes by a constant step.

    Parameters
    ----------
    all_slices : sequence of AnnData-like objects
        Each must expose ``var_names`` and a 2-D ``X`` (dense, or sparse
        with ``toarray``).
    slice_metadata : sequence of dict
        Not used by the computation; kept for interface compatibility.
    gene_name : str
        Gene to score.

    Returns
    -------
    tuple[float, numpy.ndarray]
        ``(score, means)``. ``means`` has exactly one entry per slice, with NaN
        where the slice lacks the gene, so it can be plotted against per-slice
        Z values. ``score`` is 0.0 when fewer than two slices contain the gene
        or their average expression is zero.
    """
    means = np.full(len(all_slices), np.nan)
    for k, slice_data in enumerate(all_slices):
        var_names = list(slice_data.var_names)
        if gene_name not in var_names:
            continue  # leave NaN so positions stay aligned with the slices
        # Densify only this gene's column rather than the whole matrix
        column = slice_data.X[:, var_names.index(gene_name)]
        column = column.toarray() if hasattr(column, 'toarray') else np.asarray(column)
        means[k] = column.mean()

    present = means[~np.isnan(means)]
    if present.size < 2 or present.mean() == 0:
        # Always return the same (score, means) shape the caller unpacks
        return 0.0, means

    # Continuity = 1 - (std of consecutive differences / mean expression)
    score = 1.0 - np.std(np.diff(present)) / present.mean()
    return float(score), means

# Genes whose interpolation quality we report
genes_to_check = ["Chat", "Dlk1"]

print("Interpolation Performance (Continuity Score):")
print("-" * 40)

# Z position of every slice, in the same order as all_slices
z_vals = [meta['z'] for meta in slice_metadata]

for gene in genes_to_check:
    score, means = calculate_continuity(all_slices, slice_metadata, gene)
    print(f"{gene}: {score:.4f}")

    # Mean expression per slice plotted against the slice's Z position
    fig, ax = plt.subplots(figsize=(6, 3))
    ax.plot(z_vals, means, 'o-', label=gene)
    ax.set_title(f"{gene} Expression Trajectory")
    ax.set_xlabel("Z Position")
    ax.set_ylabel("Mean Expression")
    ax.grid(True, alpha=0.3)
    plt.show()