# Attention Pooling for Cell Classification

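This notebook compares two ways of pooling per-gene ProteinBERT embeddings into a single cell embedding for healthy-vs-cancer classification: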
1. **Baseline**: Weighted sum (expression × ProteinBERT) → LogReg
2. **Attention Pooling**: Learned attention (sees embedding + expression) → Classifier

In [None]:
import sys
from pathlib import Path

PROJECT_ROOT = Path.cwd().parent
sys.path.insert(0, str(PROJECT_ROOT / "src"))

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import cellxgene_census
import scanpy as sc
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, adjusted_rand_score, silhouette_score
from sklearn.cluster import KMeans
from sklearn.linear_model import LogisticRegression
from umap import UMAP

from model.cell_embeddings import load_gene_embeddings
from model.attention_pooling import AttentionPooling, CellDataset, compute_baseline_embeddings
from preprocess_data.config import OUTPUT_FILES, CENSUS_QUERY_PARAMS

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 1. Load Gene Embeddings

In [2]:
# Load the gene -> embedding lookup from the path configured in OUTPUT_FILES.
embedding_path = OUTPUT_FILES["gene_to_embedding"]
gene_to_embedding = load_gene_embeddings(embedding_path)
# NOTE(review): "512-dim" is hard-coded in the message, not measured from the data.
print(f"Loaded {len(gene_to_embedding)} gene embeddings (512-dim)")

Loaded 19294 gene embeddings (512-dim)


## 2. Load Data from Census

In [3]:
# Sampling budget: number of cells drawn per condition (healthy / cancer) for each split.
N_TRAIN = 500
N_TEST = 200

# Pull the two obs filters and the min-genes QC threshold out of the shared query config.
healthy_filter, cancer_filter, min_genes = (
    CENSUS_QUERY_PARAMS[key]
    for key in ("healthy_filter", "cancer_filter", "min_genes")
)

print(f"Healthy filter: {healthy_filter}")
print(f"Cancer filter: {cancer_filter}")

Healthy filter: tissue_general == 'breast' and disease == 'normal' and assay == "10x 3' v3" and is_primary_data == True
Cancer filter: tissue_general == 'breast' and disease == 'breast cancer' and assay == "10x 3' v3" and is_primary_data == True


In [4]:
# Query cell IDs from Census
print("Querying cell IDs from Census...")

# NOTE(review): "stable" is a moving alias (the library warns about this at run time).
# If the release is pinned, pin it in every open_soma() call of this notebook at once,
# so that the IDs sampled here are fetched from the same release.
with cellxgene_census.open_soma(census_version="stable") as census:
    human_obs = census["census_data"]["homo_sapiens"].obs

    def _read_joinids(value_filter):
        """Return a DataFrame holding the soma_joinid of every cell matching ``value_filter``."""
        return human_obs.read(
            value_filter=value_filter,
            column_names=["soma_joinid"]
        ).concat().to_pandas()

    healthy_obs = _read_joinids(healthy_filter)
    cancer_obs = _read_joinids(cancer_filter)

print(f"Available: {len(healthy_obs)} healthy, {len(cancer_obs)} cancer cells")

# Split into train and test
# sample() is called without random_state, so the draw relies on the global NumPy seed set here.
np.random.seed(42)
n_per_condition = N_TRAIN + N_TEST
healthy_ids = healthy_obs["soma_joinid"].sample(n=n_per_condition).tolist()
cancer_ids = cancer_obs["soma_joinid"].sample(n=n_per_condition).tolist()

# The first N_TRAIN sampled IDs go to training, the remaining N_TEST to testing.
train_healthy_ids, test_healthy_ids = healthy_ids[:N_TRAIN], healthy_ids[N_TRAIN:]
train_cancer_ids, test_cancer_ids = cancer_ids[:N_TRAIN], cancer_ids[N_TRAIN:]

print(f"Train: {len(train_healthy_ids)} healthy + {len(train_cancer_ids)} cancer")
print(f"Test: {len(test_healthy_ids)} healthy + {len(test_cancer_ids)} cancer")

Querying cell IDs from Census...


The "stable" release is currently 2025-01-30. Specify 'census_version="2025-01-30"' in future calls to open_soma() to ensure data consistency.


Available: 1726582 healthy, 34164 cancer cells
Train: 500 healthy + 500 cancer
Test: 200 healthy + 200 cancer


In [None]:
# Fetch expression data
print("Fetching expression data...")


def _fetch_raw_counts(census, joinids):
    """Fetch the raw RNA matrix for the cells in ``joinids`` from an open Census handle."""
    return cellxgene_census.get_anndata(
        census=census, organism="Homo sapiens",
        measurement_name="RNA", X_name="raw",
        obs_coords=joinids
    )


# NOTE(review): "stable" is a moving alias; if the release is pinned, pin it in every
# open_soma() call of this notebook at once so IDs and expression come from one release.
with cellxgene_census.open_soma(census_version="stable") as census:
    # Training data
    train_healthy_adata = _fetch_raw_counts(census, train_healthy_ids)
    train_cancer_adata = _fetch_raw_counts(census, train_cancer_ids)

    # Test data
    test_healthy_adata = _fetch_raw_counts(census, test_healthy_ids)
    test_cancer_adata = _fetch_raw_counts(census, test_cancer_ids)

print(f"Train healthy: {train_healthy_adata.shape}")
print(f"Train cancer: {train_cancer_adata.shape}")
print(f"Test healthy: {test_healthy_adata.shape}")
print(f"Test cancer: {test_cancer_adata.shape}")

Fetching expression data...


The "stable" release is currently 2025-01-30. Specify 'census_version="2025-01-30"' in future calls to open_soma() to ensure data consistency.


In [None]:
# Apply QC and fix gene names
def preprocess_adata(adata, min_genes=200):
    """Apply cell-level QC and re-index genes by their ``feature_name``.

    Args:
        adata: AnnData object to process; it is modified and then returned.
        min_genes: Passed to ``sc.pp.filter_cells`` as the minimum number of
            genes a cell must have to be kept.

    Returns:
        The same ``adata`` object, with ``var_names`` replaced by the values
        of ``adata.var['feature_name']``.
    """
    # The return value of filter_cells is not used: the filtering is applied to `adata` itself.
    sc.pp.filter_cells(adata, min_genes=min_genes)
    gene_names = adata.var['feature_name'].values
    adata.var_names = gene_names
    return adata

# Run the same QC / gene renaming on all four splits, in the original order.
(
    train_healthy_adata,
    train_cancer_adata,
    test_healthy_adata,
    test_cancer_adata,
) = (
    preprocess_adata(split_adata, min_genes)
    for split_adata in (
        train_healthy_adata,
        train_cancer_adata,
        test_healthy_adata,
        test_cancer_adata,
    )
)

# Report how many cells survived the min-genes filter in each split.
print(f"After QC:")
print(f"  Train: {train_healthy_adata.n_obs} healthy + {train_cancer_adata.n_obs} cancer")
print(f"  Test: {test_healthy_adata.n_obs} healthy + {test_cancer_adata.n_obs} cancer")

## 3. Define Attention Pooling Model

In [None]:
# Model imported from src/model/attention_pooling.py
print("AttentionPooling model loaded from src/model/attention_pooling.py")
# Build the model with its default constructor arguments and move it to the selected device.
model = AttentionPooling()
model = model.to(device)
# Total number of elements across all parameter tensors.
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

## 4. Create Dataset and DataLoader

In [None]:
# CellDataset imported from src/model/attention_pooling.py
# Nothing is constructed in this cell; it only prints a confirmation message.
print("CellDataset loaded from src/model/attention_pooling.py")

In [None]:
# Create datasets
# Binary targets: 0.0 for every healthy cell, 1.0 for every cancer cell (post-QC counts).
train_labels_healthy = np.zeros(train_healthy_adata.n_obs)
train_labels_cancer = np.ones(train_cancer_adata.n_obs)
test_labels_healthy = np.zeros(test_healthy_adata.n_obs)
test_labels_cancer = np.ones(test_cancer_adata.n_obs)

print("Creating training dataset...")
train_adatas = [train_healthy_adata, train_cancer_adata]
train_label_arrays = [train_labels_healthy, train_labels_cancer]
train_dataset = CellDataset(train_adatas, train_label_arrays, gene_to_embedding)

print("Creating test dataset...")
test_adatas = [test_healthy_adata, test_cancer_adata]
test_label_arrays = [test_labels_healthy, test_labels_cancer]
test_dataset = CellDataset(test_adatas, test_label_arrays, gene_to_embedding)

print(f"Train: {len(train_dataset)} cells")
print(f"Test: {len(test_dataset)} cells")

# DataLoaders
# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False)

## 5. Train Attention Pooling Model

In [None]:
# Seed torch before building the model so that weight initialisation and the
# shuffle order of train_loader are repeatable under Restart & Run All.
torch.manual_seed(42)

# Fresh model for training, with Adam and binary cross-entropy on the predicted probabilities.
model = AttentionPooling().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.BCELoss()

n_epochs = 50
train_losses = []  # mean batch loss per epoch, plotted in the next cell

print("Training Attention Pooling model...")
print("=" * 50)

for epoch in range(n_epochs):
    model.train()
    total_loss = 0.0

    for batch in train_loader:
        embeddings = batch['embeddings'].to(device)
        expression = batch['expression'].to(device)
        mask = batch['mask'].to(device)
        labels = batch['label'].to(device)

        # Only the first of the model's three outputs (the prediction) is needed for the loss.
        pred, _, _ = model(embeddings, expression, mask)
        # Flatten with reshape(-1) instead of squeeze(): squeeze() would also drop the
        # batch axis when the last batch holds a single cell, and BCELoss then rejects
        # the 0-d prediction against the 1-element label tensor.
        loss = criterion(pred.reshape(-1), labels.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    train_losses.append(avg_loss)

    # Log every 10th epoch plus the final one.
    if epoch % 10 == 0 or epoch == n_epochs - 1:
        print(f"Epoch {epoch:3d}: Loss = {avg_loss:.4f}")

print("=" * 50)
print("Training complete!")

In [None]:
# Plot training loss
# One point per epoch: the mean batch loss collected during training.
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(train_losses)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Loss')
ax.grid(True, alpha=0.3)
plt.show()

## 6. Save Model Weights

In [None]:
# Save model weights
# Only the state_dict is stored, under <PROJECT_ROOT>/models/attention_pooling/.
weights_path = PROJECT_ROOT / "models" / "attention_pooling" / "attention_pooling.pt"
weights_dir = weights_path.parent
weights_dir.mkdir(parents=True, exist_ok=True)

torch.save(model.state_dict(), weights_path)
print(f"Model weights saved to: {weights_path}")

## 7. Evaluate on Test Set

In [None]:
def evaluate_model(model, dataloader, device):
    """Evaluate attention pooling model and return predictions + embeddings.

    Args:
        model: Module called as ``model(embeddings, expression, mask)`` that
            returns a 3-tuple; the first item is used as the per-cell
            probability and the second as the pooled cell embedding.
        dataloader: Iterable of batches. Each batch is a mapping with the keys
            ``'embeddings'``, ``'expression'``, ``'mask'`` and ``'label'``.
        device: Device the model inputs are moved to before the forward pass.

    Returns:
        Tuple ``(labels, preds, probs, cell_embeddings)`` of NumPy arrays.
        ``labels``, ``preds`` and ``probs`` are 1-D with one entry per cell;
        ``preds`` is ``probs > 0.5`` cast to int; ``cell_embeddings`` is the
        row-wise stack of every batch's embeddings.
    """
    model.eval()
    all_preds = []
    all_probs = []
    all_embeddings = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            embeddings = batch['embeddings'].to(device)
            expression = batch['expression'].to(device)
            mask = batch['mask'].to(device)
            labels = batch['label']

            probs, cell_emb, _ = model(embeddings, expression, mask)
            # reshape(-1) keeps a 1-D array even for a batch of one cell. squeeze()
            # would yield a 0-d array there, whose tolist() is a bare float that
            # list.extend() cannot consume.
            probs = probs.reshape(-1).cpu().numpy()
            preds = (probs > 0.5).astype(int)

            all_probs.extend(probs.tolist())
            all_preds.extend(preds.tolist())
            all_embeddings.append(cell_emb.cpu().numpy())
            all_labels.extend(labels.reshape(-1).cpu().numpy().tolist())

    return (
        np.array(all_labels),
        np.array(all_preds),
        np.array(all_probs),
        np.vstack(all_embeddings)
    )

# Evaluate attention model
print("Evaluating Attention Pooling on test set...")
# evaluate_model returns (labels, predictions, probabilities, cell embeddings) in that order.
attn_results = evaluate_model(model, test_loader, device)
attn_labels, attn_preds, attn_probs, attn_embeddings = attn_results
print(f"Test samples: {len(attn_labels)}")

## 8. Compute Baseline (Weighted Sum + LogReg)

In [None]:
# compute_baseline_embeddings imported from src/model/attention_pooling.py
print("Computing baseline embeddings (weighted sum)...")
train_baseline = compute_baseline_embeddings(train_dataset)
test_baseline = compute_baseline_embeddings(test_dataset)
# Each call returns (labels, embeddings), in that order.
train_labels, train_baseline_emb = train_baseline
test_labels, test_baseline_emb = test_baseline

print(f"Train baseline: {train_baseline_emb.shape}")
print(f"Test baseline: {test_baseline_emb.shape}")

# Train LogReg on baseline embeddings
print("\nTraining LogReg on baseline embeddings...")
logreg = LogisticRegression(max_iter=1000, random_state=42)
logreg.fit(X=train_baseline_emb, y=train_labels)

# Predict
# Hard labels first, then the probability column at index 1 of predict_proba.
baseline_preds = logreg.predict(test_baseline_emb)
baseline_class_probs = logreg.predict_proba(test_baseline_emb)
baseline_probs = baseline_class_probs[:, 1]
print("Baseline evaluation complete.")

## 9. Compare Metrics

In [None]:
def get_all_metrics(embeddings, y_true, y_pred, y_probs):
    """Score one method on both classification quality and embedding separability.

    Args:
        embeddings: Per-cell embedding matrix used for the clustering metrics.
        y_true: Ground-truth labels.
        y_pred: Predicted hard labels.
        y_probs: Predicted scores passed to ``roc_auc_score``.

    Returns:
        Dict with the keys ``"Accuracy"``, ``"Macro F1"``, ``"AUC"``, ``"ARI"``
        and ``"ASW"``, in that order.
    """
    # ARI compares an unsupervised 2-cluster KMeans partition of the embeddings with y_true.
    clusterer = KMeans(n_clusters=2, random_state=42, n_init=10)
    cluster_ids = clusterer.fit(embeddings).labels_

    return {
        "Accuracy": accuracy_score(y_true, y_pred),
        "Macro F1": f1_score(y_true, y_pred, average='macro'),
        "AUC": roc_auc_score(y_true, y_probs),
        "ARI": adjusted_rand_score(y_true, cluster_ids),
        # ASW is the silhouette of the embeddings when grouped by the true labels.
        "ASW": silhouette_score(embeddings, y_true),
    }

In [None]:
# Compute metrics for both methods
print("=" * 70)
print("COMPARISON: BASELINE vs ATTENTION POOLING")
print("=" * 70)

baseline_metrics = get_all_metrics(test_baseline_emb, test_labels, baseline_preds, baseline_probs)
attention_metrics = get_all_metrics(attn_embeddings, attn_labels, attn_preds, attn_probs)

# Create comparison table
print(f"\n{'Metric':<15} {'Weighted Sum':<15} {'Attention':<15} {'Diff':<10}")
print("-" * 55)
for metric in baseline_metrics:
    b = baseline_metrics[metric]
    a = attention_metrics[metric]
    diff = a - b  # positive means attention pooling scored higher
    sign = "+" if diff > 0 else ""
    print(f"{metric:<15} {b:<15.3f} {a:<15.3f} {sign}{diff:.3f}")

# Save comparison table
comparison_df = pd.DataFrame({
    'Metric': list(baseline_metrics.keys()),
    'Weighted Sum (Baseline)': list(baseline_metrics.values()),
    'Attention Pooling': list(attention_metrics.values())
})
comparison_df['Improvement'] = comparison_df['Attention Pooling'] - comparison_df['Weighted Sum (Baseline)']

# Build the output path once and make sure its directory exists, so that to_csv
# does not fail on a checkout where data/ has not been created yet.
comparison_path = PROJECT_ROOT / "data" / "attention_comparison.csv"
comparison_path.parent.mkdir(parents=True, exist_ok=True)
comparison_df.to_csv(comparison_path, index=False)
print(f"\nComparison saved to: {comparison_path}")

## 10. UMAP Visualization (The "Money Shot")

In [None]:
# UMAP for both embedding methods
print("Computing UMAP projections...")

# Same UMAP settings and seed for both methods so the two panels are comparable.
umap_baseline = UMAP(n_neighbors=15, min_dist=0.1, random_state=42).fit_transform(test_baseline_emb)
umap_attention = UMAP(n_neighbors=15, min_dist=0.1, random_state=42).fit_transform(attn_embeddings)

# Plot side by side
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

colors = ['#2ecc71', '#e74c3c']
labels_names = ['Healthy', 'Cancer']  # index = class label (0 = healthy, 1 = cancer)

# One entry per panel: (axis, 2-D coordinates, true labels, title, metrics dict).
panels = [
    (axes[0], umap_baseline, test_labels, 'Baseline (Weighted Sum)', baseline_metrics),
    (axes[1], umap_attention, attn_labels, 'Attention Pooling', attention_metrics),
]

for ax, coords, cell_labels, title, metrics in panels:
    for class_id, (color, name) in enumerate(zip(colors, labels_names)):
        # Named `is_class` rather than `mask` so the batch `mask` tensor name is not rebound.
        is_class = cell_labels == class_id
        ax.scatter(coords[is_class, 0], coords[is_class, 1], c=color, label=name, alpha=0.7, s=50)
    ax.set_xlabel('UMAP 1')
    ax.set_ylabel('UMAP 2')
    ax.set_title(f'{title}\nAcc={metrics["Accuracy"]:.2%}, ARI={metrics["ARI"]:.3f}')
    ax.legend()

plt.tight_layout()

# Build the output path once and make sure its directory exists before saving.
umap_plot_path = PROJECT_ROOT / "data" / "attention_umap_comparison.png"
umap_plot_path.parent.mkdir(parents=True, exist_ok=True)
plt.savefig(umap_plot_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"UMAP plot saved to: {umap_plot_path}")

## Summary

| Method | Key Idea |
|--------|----------|
| **Baseline** | attention = expression (fixed weights) |
| **Attention Pooling** | attention = f(embedding, expression) (learned) |

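Because its weights depend on both the gene embedding and the expression level, the attention model can in principle learn rules such as: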
- "TP53 matters even when lowly expressed"
- "Ignore housekeeping genes even if highly expressed"