# Spatial Topic Modeling for IMC Data Analysis

## Overview
This notebook demonstrates the application of **Latent Dirichlet Allocation (LDA) topic modeling** to Imaging Mass Cytometry (IMC) data to discover spatial cellular organization patterns (spatial motifs).

---

## 🎯 What This Notebook Does

### The Big Idea
Just as topic modeling discovers themes in documents by analyzing word co-occurrence, **spatial topic modeling** discovers recurring patterns of cell type co-occurrence in tissue space. This helps identify:
- **Spatial motifs**: Recurring cellular neighborhoods (e.g., "tumor-immune interface", "lymphoid aggregates")
- **Tissue organization**: How different cell types organize spatially
- **Biological structures**: Functionally relevant spatial patterns

### Key Concepts

**1. Spatial LDA Analogy**
- **Documents** = Cell neighborhoods (spatial regions around each cell)
- **Words** = Cell phenotypes/types found in each neighborhood
- **Topics** = Spatial motifs (recurring patterns of cell type co-occurrence)
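Here is the analogy on a small scale. The sketch turns one hypothetical neighborhood (the phenotype labels of the cells inside the radius around a query cell) into the bag-of-words counts that LDA works with. The labels are made up for illustration.

```python
from collections import Counter

# Hypothetical neighborhood: phenotypes of all cells within the radius of one query cell
toy_neighborhood = ['Tumor', 'Tumor', 'T_cell', 'Macrophage', 'Tumor', 'T_cell']

# LDA ignores word order within a "document", so a neighborhood reduces to phenotype counts
bag_of_cells = Counter(toy_neighborhood)
print(bag_of_cells)  # Counter({'Tumor': 3, 'T_cell': 2, 'Macrophage': 1})
```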

**2. Why Coherence Analysis?**
- Scimap's `spatial_lda` requires you to specify the number of topics (`num_motifs`)
- Choosing this number arbitrarily can lead to suboptimal results
- **Coherence analysis** fits models over a range of topic numbers and selects the one with the highest coherence score, a common proxy for topic quality and interpretability
- This makes the choice of topic number **data-driven** rather than arbitrary
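Coherence measures score a topic by how strongly its top words co-occur across documents. gensim's `c_v` combines normalized pointwise mutual information (NPMI) with a sliding window and an indirect cosine similarity. The sketch below is a deliberately simplified stand-in: it averages pairwise NPMI using document-level co-occurrence. It only illustrates the idea and will not reproduce `c_v` scores. The toy neighborhoods are hypothetical.

```python
import math
from itertools import combinations

def npmi(w1, w2, docs):
    """Normalized PMI of two words, using document-level co-occurrence."""
    n = len(docs)
    p1 = sum(w1 in d for d in docs) / n
    p2 = sum(w2 in d for d in docs) / n
    p12 = sum(w1 in d and w2 in d for d in docs) / n
    if p12 == 0:
        return -1.0  # never co-occur: minimum NPMI by convention
    if p12 == 1:
        return 1.0   # co-occur in every document (avoids dividing by -log(1) = 0)
    return math.log(p12 / (p1 * p2)) / -math.log(p12)

def simple_coherence(top_words, docs):
    """Average pairwise NPMI over a topic's top words (simplified; not gensim's c_v)."""
    doc_sets = [set(d) for d in docs]
    pairs = list(combinations(top_words, 2))
    return sum(npmi(a, b, doc_sets) for a, b in pairs) / len(pairs)

# Hypothetical neighborhoods: tumor cells sit with macrophages, B cells with T cells
toy_docs = [
    ['Tumor', 'Macrophage', 'Tumor'],
    ['Tumor', 'Macrophage'],
    ['B_cell', 'T_cell'],
    ['B_cell', 'T_cell', 'T_cell'],
]
print(simple_coherence(['Tumor', 'Macrophage'], toy_docs))  # 1.0 (always found together)
print(simple_coherence(['Tumor', 'B_cell'], toy_docs))      # -1.0 (never found together)
```

A topic whose top cell types really do co-occur in the same neighborhoods scores higher. This is the intuition behind choosing the topic number that maximizes coherence.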

**3. The Workflow**
1. Extract spatial neighborhoods around each cell
2. Use coherence analysis to find the optimal number of topics
3. Fit a gensim LDA model with that number of topics and inspect the topics
4. Run scimap's `spatial_lda` with the same, data-driven topic number to score every cell
5. Cluster cells by their motif profiles and visualize the resulting spatial patterns

---

## 📋 Prerequisites

- IMC data with:
  - Expression matrix (cells × features)
  - Metadata with: X/Y coordinates, cell types, image IDs
- Required packages: `scimap`, `anndata`, `scanpy`, `gensim`, `pyLDAvis`

---

## 🚀 Quick Start

1. **Load your data** (Section 2) - Update file paths and column names
2. **Run coherence analysis** (Section 5) - Determines optimal topics
3. **Apply spatial LDA** (Section 8) - Uses optimal topic number
4. **Visualize results** (Section 10) - Explore discovered patterns

**Note**: All column names (e.g., `Location_Center_X`, `major_celltype`) should be adjusted to match your metadata.
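Every later cell looks these columns up by name. A quick check right after loading turns a confusing downstream `KeyError` into a clear message. `check_obs_columns` is a hypothetical helper and is not part of scimap. In the notebook you would call it on `adata.obs`.

```python
import pandas as pd

def check_obs_columns(obs, required):
    """Raise a KeyError listing any required metadata columns that are missing (hypothetical helper)."""
    missing = [col for col in required if col not in obs.columns]
    if missing:
        raise KeyError(f"Missing metadata columns: {missing}. Available: {list(obs.columns)}")

# Column names used throughout this notebook; replace with your own
REQUIRED_COLUMNS = ['Location_Center_X', 'Location_Center_Y', 'major_celltype', 'ImageNumber']

# Hypothetical one-row metadata table standing in for adata.obs
toy_obs = pd.DataFrame({'Location_Center_X': [1.0], 'Location_Center_Y': [2.0],
                        'major_celltype': ['Tumor'], 'ImageNumber': [1]})
check_obs_columns(toy_obs, REQUIRED_COLUMNS)  # passes silently when all columns exist
```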

## 1. Import Libraries

In [None]:
# Core libraries
import scimap as sm
import anndata as ad
import pandas as pd
import scanpy as sc
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt

# Machine learning and topic modeling
from sklearn.neighbors import BallTree
from sklearn.cluster import MiniBatchKMeans
from gensim import corpora
from gensim.models import LdaModel, CoherenceModel

# Visualization
import pyLDAvis
import pyLDAvis.gensim_models as gensimvis

# Settings
sns.set(color_codes=True)
sc.settings.verbosity = 3
import warnings
warnings.filterwarnings('ignore')

print("Libraries imported successfully")

## 2. Load Data

Load your IMC expression data and metadata. Adjust file paths and column names to match your data structure.

In [None]:
# OPTION A: Load your own data
# Uncomment and adjust paths:
# data = pd.read_csv('expression_data.csv')  # Expression matrix (cells x features)
# meta = pd.read_csv('metadata.csv')        # Cell metadata (cells x annotations)
# adata = ad.AnnData(data)
# adata.obs = meta

# OPTION B: Generate synthetic data for testing (active by default)
# Comment out this block when using Option A:
print("Generating sample IMC data for demonstration...")
np.random.seed(42)

# Generate synthetic data
n_cells = 2000
n_features = 10
n_images = 3

# Expression data (random for demo)
data = pd.DataFrame(
    np.random.rand(n_cells, n_features),
    columns=[f'Marker_{i+1}' for i in range(n_features)]
)

# Metadata with spatial coordinates and cell types
cell_types = ['T_cell', 'B_cell', 'Macrophage', 'Tumor', 'Stroma']
meta = pd.DataFrame({
    'Location_Center_X': np.random.uniform(0, 1000, n_cells),
    'Location_Center_Y': np.random.uniform(0, 1000, n_cells),
    'major_celltype': np.random.choice(cell_types, n_cells),
    'ImageNumber': np.random.choice(range(1, n_images+1), n_cells),
    'celltype_detail': np.random.choice(cell_types, n_cells)  # For optional analyses
})

# Create AnnData object
adata = ad.AnnData(data)
adata.obs = meta

print(f"✓ Generated {adata.n_obs} cells and {adata.n_vars} features")
print(f"✓ Metadata columns: {list(adata.obs.columns)}")
print(f"✓ Cell types: {adata.obs['major_celltype'].unique()}")
print(f"✓ Images: {adata.obs['ImageNumber'].unique()}")
print("\nNote: This is synthetic data. Replace with your own data for real analysis.")

## 3. Extract Spatial Neighborhoods

This function extracts spatial neighborhoods around each cell, which will be used as "documents" for topic modeling. This is needed to perform coherence analysis for optimal topic selection.
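The core of the extraction is a radius query on a `BallTree`. The sketch runs it on four hypothetical cells and shows two things. First, each neighborhood includes the query cell itself. Second, the returned indices are positions in the coordinate array, so they can be mapped straight back to phenotype labels.

```python
import numpy as np
from sklearn.neighbors import BallTree

# Four hypothetical cells: three close together, one far away
toy_coords = np.array([[0.0, 0.0], [10.0, 0.0], [0.0, 20.0], [500.0, 500.0]])
toy_phenotypes = np.array(['Tumor', 'T_cell', 'Macrophage', 'Stroma'])

toy_tree = BallTree(toy_coords)
toy_neighbor_indices = toy_tree.query_radius(toy_coords, r=30, return_distance=False)

# Indices are positional, so they index the phenotype array directly (sorted for readability)
toy_neighborhoods = [sorted(toy_phenotypes[idx].tolist()) for idx in toy_neighbor_indices]
for cell, hood in zip(toy_phenotypes, toy_neighborhoods):
    print(cell, '->', hood)
```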

In [None]:
def extract_spatial_neighborhoods(adata, x_coordinate='X_centroid', y_coordinate='Y_centroid',
                                  phenotype='phenotype', method='radius', radius=30, knn=10,
                                  imageid='imageid', subset=None):
    """
    Extract spatial neighborhoods for each cell.
    
    Parameters
    ----------
    adata : AnnData
        Annotated data object
    x_coordinate : str
        Column name for X coordinates
    y_coordinate : str
        Column name for Y coordinates
    phenotype : str
        Column name for cell phenotypes/types
    method : str
        Method for neighborhood identification ('radius' or 'knn')
    radius : float
        Radius in pixels (used when method='radius')
    knn : int
        Number of nearest neighbours, including the cell itself (used when method='knn')
    imageid : str
        Column name for image IDs
    subset : int or None
        Process a specific image ID, or None for all images
        
    Returns
    -------
    all_neighborhoods : list
        List of neighborhoods (each a list of cell phenotypes, including the
        query cell itself), concatenated image by image
    """
    def process_image(adata_subset):
        """Return the neighborhoods of all cells in a single image."""
        image_id = np.unique(adata_subset.obs[imageid])[0]
        print(f'Processing image: {image_id}')
        
        # Coordinates and phenotype labels as positional arrays
        coords = adata_subset.obs[[x_coordinate, y_coordinate]].to_numpy()
        phenotypes = adata_subset.obs[phenotype].astype(str).to_numpy()
        
        # Find neighbors using BallTree
        tree = BallTree(coords, leaf_size=2)
        if method == 'radius':
            print(f"  Identifying neighbours within {radius} pixels of every cell")
            neighbor_indices = tree.query_radius(coords, r=radius, return_distance=False)
        elif method == 'knn':
            k = min(knn, len(coords))  # query() needs k <= number of cells in the image
            print(f"  Identifying the {k} nearest neighbours of every cell")
            neighbor_indices = tree.query(coords, k=k, return_distance=False)
        else:
            raise ValueError(f"method must be 'radius' or 'knn', got {method!r}")
        
        # Neighbor indices are positions within this image, so index phenotypes directly
        return [phenotypes[idx].tolist() for idx in neighbor_indices]
    
    # Process all images or subset
    if subset is not None:
        adata_list = [adata[adata.obs[imageid] == subset]]
    else:
        adata_list = [adata[adata.obs[imageid] == i] for i in adata.obs[imageid].unique()]
    
    # Extract neighborhoods for all images
    all_neighborhoods = []
    for adata_subset in adata_list:
        all_neighborhoods.extend(process_image(adata_subset))
    
    return all_neighborhoods

In [None]:
# Extract spatial neighborhoods
# Adjust column names to match your metadata
spatial_neighborhoods = extract_spatial_neighborhoods(
    adata,
    x_coordinate='Location_Center_X',    # X coordinate column name
    y_coordinate='Location_Center_Y',    # Y coordinate column name
    phenotype='major_celltype',         # Cell type/phenotype column name
    method='radius',
    radius=30,                           # Neighborhood radius in pixels
    imageid='ImageNumber',               # Image ID column name
    subset=None                          # None = process all images
)

print(f"\nExtracted {len(spatial_neighborhoods)} spatial neighborhoods")

## 4. Prepare Data for Topic Modeling

Convert spatial neighborhoods to the format required by Gensim LDA.
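Under the hood, `corpora.Dictionary` assigns an integer ID to every distinct cell type, and `doc2bow` turns each neighborhood into sorted `(id, count)` pairs. The dependency-free sketch below reproduces that structure on two hypothetical neighborhoods. gensim may assign IDs in a different order, so only the shape of the output should be compared, not the exact IDs.

```python
from collections import Counter

def build_vocab(texts):
    """Assign an integer ID to each distinct token, in order of first appearance."""
    token2id = {}
    for text in texts:
        for token in text:
            token2id.setdefault(token, len(token2id))
    return token2id

def to_bow(text, token2id):
    """Return sorted (token_id, count) pairs, the same shape gensim's doc2bow returns."""
    counts = Counter(token2id[t] for t in text if t in token2id)
    return sorted(counts.items())

toy_texts = [['Tumor', 'Tumor', 'T_cell'], ['B_cell', 'T_cell']]  # hypothetical neighborhoods
toy_token2id = build_vocab(toy_texts)
toy_corpus = [to_bow(t, toy_token2id) for t in toy_texts]
print(toy_token2id)  # {'Tumor': 0, 'T_cell': 1, 'B_cell': 2}
print(toy_corpus)    # [[(0, 2), (1, 1)], [(1, 1), (2, 1)]]
```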

In [None]:
# Convert neighborhoods to list of lists (documents)
texts = spatial_neighborhoods

# Create dictionary mapping cell types to IDs
id2word = corpora.Dictionary(texts)

# Create corpus (bag of words representation)
corpus = [id2word.doc2bow(text) for text in texts]

print(f"Total neighborhoods (documents): {len(corpus)}")
print(f"Unique cell types (vocabulary size): {len(id2word)}")
print(f"\nCell types in vocabulary: {list(id2word.values())[:10]}...")  # Show first 10

## 5. Coherence Analysis for Optimal Topic Selection

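This step chooses the number of topics by fitting LDA models over a range of topic numbers and computing the c_v coherence score for each. A higher c_v score generally means that a topic's top cell types co-occur more consistently across neighborhoods, which is a common proxy for topic quality.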
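c_v scores each topic from its `topn` highest-probability words. gensim's `CoherenceModel` defaults to `topn=20`. When the vocabulary is only a handful of cell types, that cut-off can exceed the vocabulary size, and then every topic contributes the same word set, so scores may differ little between topic numbers. The sketch below shows this on a random, hypothetical topic-word matrix. If it applies to your data, passing a smaller `topn` to `CoherenceModel` may make the comparison more informative. Treat that as an assumption to check, not a guarantee.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, n_topics, topn = 5, 4, 20  # e.g. 5 cell types; topn mirrors gensim's default

# Hypothetical topic-word matrix: each row is a probability distribution over the vocabulary
topic_word = rng.dirichlet(np.ones(vocab_size), size=n_topics)

# Top-n word IDs per topic; with topn >= vocab_size every topic keeps the whole vocabulary
top_sets = [set(np.argsort(row)[::-1][:topn].tolist()) for row in topic_word]
print(all(s == set(range(vocab_size)) for s in top_sets))  # True
```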

In [None]:
def compute_coherence_values(dictionary, corpus, texts, limit, start=2, step=2, random_state=0):
    """
    Compute c_v coherence for various number of topics.

    Parameters:
    ----------
    dictionary : Gensim dictionary
        Dictionary mapping words to IDs
    corpus : list
        Gensim corpus (bag of words)
    texts : list
        List of input texts (neighborhoods)
    limit : int
        Upper bound on the number of topics (exclusive, as in range())
    start : int
        Starting number of topics
    step : int
        Step size for topic numbers
    random_state : int
        Random seed for reproducibility

    Returns:
    -------
    model_list : list
        List of LDA topic models
    coherence_values : list
        Coherence values corresponding to each model
    """
    coherence_values = []
    model_list = []
    
    for num_topics in range(start, limit, step):
        print(f"Testing {num_topics} topics...")
        model = LdaModel(corpus=corpus, num_topics=num_topics, id2word=dictionary, 
                        random_state=random_state, passes=10, alpha='auto', per_word_topics=True)
        model_list.append(model)
        
        coherencemodel = CoherenceModel(model=model, texts=texts, dictionary=dictionary, 
                                       coherence='c_v')
        coherence = coherencemodel.get_coherence()
        coherence_values.append(coherence)
        print(f"  Coherence score: {coherence:.4f}")
    
    return model_list, coherence_values

In [None]:
# Compute coherence for different numbers of topics
limit = 12      # Upper bound (exclusive): tests 4, 6, 8 and 10 topics
start = 4       # Starting number of topics
step = 2        # Step size

model_list, coherence_values = compute_coherence_values(
    dictionary=id2word,
    corpus=corpus,
    texts=texts,
    start=start,
    limit=limit,
    step=step,
    random_state=0
)

In [None]:
# Plot coherence values
x = range(start, limit, step)

plt.figure(figsize=(10, 6))
plt.plot(x, coherence_values, marker='o', linewidth=2, markersize=8)
plt.xlabel("Number of Topics", fontsize=12)
plt.ylabel("Coherence Score", fontsize=12)
plt.title("Topic Model Coherence Analysis", fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Print results
print("\nCoherence Scores:")
for num_topics, coherence in zip(x, coherence_values):
    print(f"  {num_topics} topics: {coherence:.4f}")

# Find optimal number of topics (highest coherence)
optimal_idx = np.argmax(coherence_values)
optimal_topics = x[optimal_idx]
optimal_coherence = coherence_values[optimal_idx]

print(f"\n✓ Optimal number of topics: {optimal_topics} (coherence: {optimal_coherence:.4f})")

## 6. Build Final LDA Model with Optimal Topics

Use the optimal number of topics determined from coherence analysis.

In [None]:
# Build final LDA model with optimal number of topics
# You can also manually set num_topics if you prefer
num_topics = optimal_topics  # or set manually, e.g., num_topics = 6

lda_model = LdaModel(
    corpus=corpus,
    id2word=id2word,
    num_topics=num_topics,
    random_state=0,
    passes=20,           # Number of passes through the corpus
    alpha='auto',        # Automatic alpha estimation
    per_word_topics=True
)

print(f"LDA model created with {num_topics} topics")

# Display topics
print("\nTop words for each topic:")
for idx, topic in lda_model.print_topics(-1, num_words=5):
    print(f"\nTopic {idx}:")
    print(topic)

## 7. Visualize Topics with pyLDAvis

### Interactive Topic Exploration
pyLDAvis provides an interactive visualization where you can:
- See which cell types are most associated with each topic
- Explore topic similarity and overlap
- Understand the relative importance of topics

**Note**: This visualization helps validate that topics are meaningful and interpretable.

In [None]:
# Prepare visualization
pyLDAvis.enable_notebook()
vis = gensimvis.prepare(lda_model, corpus, id2word, sort_topics=False)
vis

## 8. Apply Spatial LDA using scimap

### The Key Step
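Now we use scimap's built-in `spatial_lda` function with the **optimal number of topics** determined from coherence analysis. `spatial_lda` builds its own neighborhoods and fits its own LDA model, so only the topic number carries over. The motifs it finds may therefore differ in detail from the gensim topics inspected above. Using the same radius and phenotype column in both steps keeps the two analyses comparable. This is the main analysis: it assigns each cell a probability distribution over spatial motifs.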

### What Gets Stored
- Results in `adata.uns['spatial_lda']`: Matrix of cells × topics (motif probabilities)
- Each cell gets a probability score for each discovered spatial motif
- Higher probability = cell is more likely to be part of that spatial pattern
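Each row of the stored matrix can be read as one cell's distribution over motifs. The sketch uses a small hypothetical matrix to show two common summaries: each cell's dominant motif and the probability of that motif. The row and column names are made up, so check `adata.uns['spatial_lda'].columns` for the real ones.

```python
import pandas as pd

# Hypothetical cells x motifs probability matrix, shaped like adata.uns['spatial_lda']
motif_probs = pd.DataFrame(
    [[0.70, 0.20, 0.10],
     [0.05, 0.15, 0.80],
     [0.30, 0.60, 0.10]],
    index=['cell_1', 'cell_2', 'cell_3'],
    columns=['Motif_0', 'Motif_1', 'Motif_2'],
)

# Most probable motif per cell, and how strongly the cell belongs to it
dominant = motif_probs.idxmax(axis=1)
confidence = motif_probs.max(axis=1)
print(pd.DataFrame({'dominant_motif': dominant, 'probability': confidence}))
```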

### Why This Matters
Using the data-driven topic number makes the discovered motifs more likely to be:
- **Meaningful**: higher coherence favors motifs whose cell types consistently co-occur
- **Appropriately sized**: neither too few (missing structure) nor too many (fragmented, overfit motifs)
- **Reproducible**: chosen by a fixed, documented procedure with fixed random seeds rather than an arbitrary choice

In [None]:
# Apply scimap's spatial_lda with optimal number of topics
sm.tl.spatial_lda(
    adata,
    x_coordinate='Location_Center_X',
    y_coordinate='Location_Center_Y',
    phenotype='major_celltype',
    method='radius',
    radius=30,
    knn=60,
    imageid='ImageNumber',
    num_motifs=num_topics,  # Use optimal number from coherence analysis
    random_state=0,
    subset=None,
    label='spatial_lda'
)

print(f"\nSpatial LDA completed. Results stored in adata.uns['spatial_lda']")
print(f"Shape: {adata.uns['spatial_lda'].shape}")

## 9. Cluster Spatial Motifs

### Purpose
Group cells with similar spatial motif profiles together. This creates spatial "communities" or "regions" in the tissue.

### What Happens
- K-means clustering on the motif probability matrix
- Cells with similar motif profiles cluster together
- Creates discrete spatial groups (e.g., "spatial-0", "spatial-1", etc.)

### Adjustable
- `n_clusters`: Number of spatial groups to identify (set to 6 below; adjust to your data)
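K-means labels are arbitrary, so naming a spatial group usually starts from its average motif profile. The sketch below clusters a synthetic, hypothetical motif matrix with two groups of cells, each dominated by a different motif. It then summarizes each cluster with `groupby(...).mean()`. In the notebook, the analogous summary would be `adata.uns['spatial_lda'].groupby(adata.obs['spatial_kmeans'].values).mean()`, which assumes the rows of that matrix are in the same order as `adata.obs`.

```python
import numpy as np
import pandas as pd
from sklearn.cluster import MiniBatchKMeans

rng = np.random.default_rng(0)
# Hypothetical motif matrix: 50 cells dominated by Motif_0, 50 dominated by Motif_2
group_a = rng.dirichlet([20, 1, 1], size=50)
group_b = rng.dirichlet([1, 1, 20], size=50)
motif_matrix = pd.DataFrame(np.vstack([group_a, group_b]),
                            columns=['Motif_0', 'Motif_1', 'Motif_2'])

kmeans_demo = MiniBatchKMeans(n_clusters=2, random_state=0, n_init=3).fit(motif_matrix)
demo_labels = pd.Series(kmeans_demo.labels_, index=motif_matrix.index).map(lambda k: f'spatial-{k}')

# Mean motif profile per cluster: the main handle for naming each spatial group
profile = motif_matrix.groupby(demo_labels).mean().round(2)
print(profile)
```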

In [None]:
# Fill NaN values and cluster spatial motifs
adata.uns['spatial_lda'] = adata.uns['spatial_lda'].fillna(0)

# K-means clustering of spatial motifs
n_clusters = 6  # Adjust based on your data
kmeans = MiniBatchKMeans(n_clusters=n_clusters, random_state=0)
kmeans.fit(adata.uns['spatial_lda'])

# Add cluster labels to metadata
cluster_labels = ['spatial-' + str(label) for label in kmeans.labels_]
adata.obs['spatial_kmeans'] = cluster_labels

print(f"Clustered cells into {n_clusters} spatial groups")
print(f"\nCluster distribution:")
print(adata.obs['spatial_kmeans'].value_counts())

## 10. Visualize Spatial Motifs

### Visualizations Provided

1. **Dot Plot**: Shows marker expression across spatial clusters
2. **Stacked Bar Plot**: Shows cell type composition proportions across spatial clusters
3. **Voronoi Plot**: Spatial map showing where each spatial cluster appears in tissue

### Interpretation Tips
- **Spatial clusters** represent regions with distinct cellular organization
- Compare cell type compositions between clusters to understand what makes each region unique
- Voronoi plots show the spatial distribution - are clusters localized or dispersed?
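The composition numbers that the stacked bar plot draws can also be tabulated directly, which is handy for comparing clusters side by side or exporting results. The sketch uses hypothetical annotations. With real data, pass `adata.obs['spatial_kmeans']` and `adata.obs['major_celltype']` to `pd.crosstab`.

```python
import pandas as pd

# Hypothetical per-cell annotations, as they would appear in adata.obs
obs_demo = pd.DataFrame({
    'spatial_kmeans': ['spatial-0'] * 4 + ['spatial-1'] * 4,
    'major_celltype': ['Tumor', 'Tumor', 'Tumor', 'Macrophage',
                       'T_cell', 'B_cell', 'B_cell', 'Tumor'],
})

# Row-normalized counts: fraction of each cell type within each spatial cluster
composition = pd.crosstab(obs_demo['spatial_kmeans'], obs_demo['major_celltype'],
                          normalize='index')
print(composition.round(2))
```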

In [None]:
# Dot plot showing marker expression in each spatial cluster
sc.pl.dotplot(
    adata,
    var_names=adata.var.index,
    groupby='spatial_kmeans',
    dendrogram=False,
    use_raw=False,
    expression_cutoff=0.6,
    standard_scale='var',
    cmap='magma'
)

In [None]:
# Stacked bar plot showing cell type proportions in each spatial cluster
sm.pl.stacked_barplot(
    adata,
    x_axis='spatial_kmeans',
    y_axis='major_celltype',
    method='percent',
    plot_tool='matplotlib',
    figsize=(10, 10),
    matplotlib_cmap='Paired'
)

In [None]:
# Voronoi plot showing spatial distribution of motifs
# Adjust subset to visualize specific images
sc.set_figure_params(dpi=100, dpi_save=200)

sm.pl.voronoi(
    adata,
    color_by='spatial_kmeans',
    colors=None,
    x_coordinate='Location_Center_X',
    y_coordinate='Location_Center_Y',
    imageid='ImageNumber',
    subset=1,  # Change to visualize different images
    voronoi_edge_color='black',
    voronoi_line_width=0.2,
    voronoi_alpha=1,
    overlay_points_categories=None,
    overlay_drop_categories=None,
    overlay_point_size=5,
    overlay_point_alpha=1,
    overlay_point_shape=".",
    plot_legend=True,
    legend_size=6,
    figsize=(15, 15)
)

## 11. Additional Spatial Analyses (Optional)

### Complementary Analyses

These scimap functions provide additional spatial insights:

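1. **Spatial Proximity Score**: Scores how densely the specified cell types occur near one another within each image (or other `score_by` group)
2. **Spatial Distance**: Computes the shortest distance from each cell to cells of every phenotype
3. **Spatial Interaction**: Tests, by permutation, whether pairs of cell types are neighbors more or less often than expected by chance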

**Note**: Adjust cell type names in these sections to match your data.
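Spatial distance measures how far cells of one type sit from cells of another. The brute-force sketch below computes, for a few hypothetical cells, the distance from each cell to the nearest *other* cell of each phenotype. It illustrates the idea independently and is not scimap's implementation. Its all-pairs distance matrix needs O(n²) memory, so it suits only small images.

```python
import numpy as np
import pandas as pd
from scipy.spatial.distance import cdist

def nearest_distance_to_types(coords, phenotypes):
    """Distance from every cell to the nearest other cell of each phenotype (brute force)."""
    dists = cdist(coords, coords)      # all pairwise Euclidean distances
    np.fill_diagonal(dists, np.inf)    # a cell is not its own neighbour
    out = {ptype: dists[:, phenotypes == ptype].min(axis=1) for ptype in np.unique(phenotypes)}
    return pd.DataFrame(out)

# Hypothetical cells; inf means "no other cell of that type exists"
demo_coords = np.array([[0.0, 0.0], [3.0, 4.0], [10.0, 0.0], [0.0, 8.0]])
demo_types = np.array(['Tumor', 'T_cell', 'Tumor', 'B_cell'])
demo_distances = nearest_distance_to_types(demo_coords, demo_types)
print(demo_distances)
```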

In [None]:
# Spatial proximity score
sm.tl.spatial_pscore(
    adata,
    proximity=['B_cell', 'T_cell'],  # Adjust to your cell types of interest (these match the synthetic data)
    score_by='ImageNumber',
    x_coordinate='Location_Center_X',
    y_coordinate='Location_Center_Y',
    phenotype='celltype_detail',  # Adjust to your detailed cell type column
    method='radius',
    radius=20,
    knn=3,
    imageid='ImageNumber',
    subset=None,
    label='spatial_pscore'
)

In [None]:
# Spatial distance analysis
sm.tl.spatial_distance(
    adata,
    x_coordinate='Location_Center_X',
    y_coordinate='Location_Center_Y',
    phenotype='major_celltype',
    subset=None,
    imageid='ImageNumber',
    label='spatial_distance'
)

In [None]:
# Spatial interaction analysis
sm.tl.spatial_interaction(
    adata,
    x_coordinate='Location_Center_X',
    y_coordinate='Location_Center_Y',
    phenotype='major_celltype',
    method='radius',
    radius=30,
    knn=10,
    permutation=1000,
    imageid='ImageNumber',
    subset=None,
    pval_method='zscore',
    label='spatial_interaction'
)

# Visualize spatial interactions
sm.pl.spatial_interaction(
    adata,
    summarize_plot=False,
    row_cluster=False,
    col_cluster=False,
    yticklabels=True,
    p_val=0.05
)

## Summary

### What This Notebook Achieves

1. **Data-Driven Topic Selection**: Coherence analysis determines the optimal number of spatial motifs
2. **Spatial Pattern Discovery**: Identifies recurring cellular organization patterns in tissue
3. **Interpretable Results**: Coherence-based selection favors motifs whose cell types consistently co-occur; biological meaning still needs validation
4. **Comprehensive Analysis**: From topic modeling to spatial visualization

### Key Innovation

**The coherence workflow addresses a practical problem**: instead of guessing how many topics to use, the topic number is chosen from the data. This makes spatial LDA analysis:
- ✅ **Reproducible**: Same data and random seeds → same optimal topic number
- ✅ **Interpretable**: Higher coherence favors clearer, more consistent patterns
- ✅ **Robust**: Less sensitive to arbitrary parameter choices

### Expected Outputs

- Optimal number of topics (from coherence analysis)
- Spatial motif assignments for each cell
- Spatial clusters showing distinct tissue regions
- Visualizations of spatial patterns

### Next Steps

- Interpret discovered spatial motifs biologically
- Compare motifs across different conditions/groups
- Validate findings with domain knowledge
- Use motifs as features for downstream analyses

---

## 📝 Notes for Users

- **Column names**: Update all column references to match your metadata
- **File paths**: Update data loading paths to your data location
- **Parameters**: Adjust radius, topic ranges, cluster numbers based on your data scale
- **Cell types**: Update cell type names in optional analyses sections

---

## 🔗 References

- Scimap documentation: https://scimap.readthedocs.io/
- LDA topic modeling: Blei, Ng & Jordan (2003), "Latent Dirichlet Allocation", JMLR
- Coherence metrics: Röder, Both & Hinneburg (2015), "Exploring the Space of Topic Coherence Measures", WSDM