# Spatial Transcriptomics Analysis with Transformer-GNN

## 1. Environment Setup

In [1]:
# System configuration
# NOTE: these variables are assigned before TensorFlow is imported in the next
# cell; they are presumably read when the library / CUDA runtime initialises,
# so keep this cell first.
import os
# Raise the TensorFlow native (C++) log threshold; '2' is documented to hide
# INFO and WARNING messages.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# Expose only the CUDA device with index 0 to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

## 2. Package Imports

In [2]:
import numpy as np
import tensorflow as tf
import scanpy as sc
import optuna
import matplotlib.pyplot as plt
import seaborn as sns
import scipy.sparse 
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from tensorflow.keras import layers
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

2025-05-03 10:00:09.903553: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746284409.947798  350800 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746284409.960659  350800 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1746284410.057993  350800 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746284410.058008  350800 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746284410.058009  350800 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.

## 3. Data Loading & Preprocessing

In [3]:
def load_and_preprocess_data(file_path, n_top_genes=2000, target_cells=1_500_000):
    """Load an ``.h5ad`` file and return a stratified train/validation split.

    Pipeline: read the AnnData object fully into memory, drop cells with 20%
    or more mitochondrial counts, subsample to at most ``target_cells`` cells,
    drop clusters with fewer than two cells, keep the ``n_top_genes`` highly
    variable genes and densify the expression matrix.

    Args:
        file_path: Path of the ``.h5ad`` file to read.
        n_top_genes: Number of highly variable genes to keep.
        target_cells: Upper bound on the number of cells kept after QC.

    Returns:
        ``[X_train, X_val, y_train, y_val]`` as produced by
        ``train_test_split`` (15% validation, stratified by cluster).
        ``X_*`` are dense expression matrices restricted to the selected
        genes; ``y_*`` are contiguous integer class ids in
        ``0 .. n_clusters - 1``.

    NOTE(review): only expression values are returned. ``SpatialTransformerGNN``
    in this file treats the last two input columns as coordinates, but no
    coordinate columns are appended here -- confirm where they should come from.
    """
    # 1. Load data fully in memory (no 'backed' mode).
    adata = sc.read_h5ad(file_path)

    # 2. Drop cells dominated by mitochondrial counts.
    # NOTE(review): the prefix match is case-sensitive ("MT-"); if the gene
    # names use a different case (e.g. "mt-") nothing is flagged and this
    # filter does not remove any cell on mitochondrial grounds -- confirm
    # against the dataset.
    adata.var["mt"] = adata.var_names.str.startswith("MT-")
    sc.pp.calculate_qc_metrics(adata, qc_vars=["mt"], percent_top=None, inplace=True)
    adata = adata[adata.obs["pct_counts_mt"] < 20, :].copy()  # Force in-memory

    # 3. Subsample to at most target_cells cells.
    if adata.n_obs > target_cells:
        sc.pp.subsample(adata, n_obs=target_cells, random_state=42, copy=False)

    # 4. Remove clusters with a single cell: a stratified split needs at
    #    least two members per class.
    cluster_counts = adata.obs['cluster'].value_counts()
    valid_clusters = cluster_counts[cluster_counts >= 2].index
    adata = adata[adata.obs['cluster'].isin(valid_clusters)].copy()

    # 5. Ensure CSR sparse format.
    if not isinstance(adata.X, scipy.sparse.csr_matrix):
        adata.X = scipy.sparse.csr_matrix(adata.X)

    # 6. Highly variable gene selection.
    sc.pp.filter_genes(adata, min_counts=1)
    sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes, flavor="seurat_v3")
    adata = adata[:, adata.var["highly_variable"]]

    # 7. Densify. issparse() also recognises scipy sparse *arrays*, which an
    #    isinstance check against spmatrix can miss.
    X_dense = adata.X.toarray() if scipy.sparse.issparse(adata.X) else np.asarray(adata.X)

    # 8. Encode labels once. Dropping unused categories keeps the codes
    #    contiguous after the filtering above, so that the number of distinct
    #    labels equals max(label) + 1 (required by a softmax classifier with
    #    sparse integer targets). astype('category') also covers a
    #    non-categorical 'cluster' column.
    labels = (
        adata.obs['cluster']
        .astype('category')
        .cat.remove_unused_categories()
        .cat.codes
        .to_numpy()
    )

    return train_test_split(
        X_dense,
        labels,
        test_size=0.15,
        stratify=labels,
        random_state=42
    )

## 4. Model Architecture

In [12]:
class PositionalEncoding(layers.Layer):
    """Project coordinates to a ``d_model``-wide embedding.

    The embedding is a single ReLU dense projection of the coordinates,
    followed by dropout.

    Args:
        d_model: Width of the produced embedding.
        dropout_rate: Dropout rate applied to the embedding.
    """

    def __init__(self, d_model, dropout_rate=0.1):
        super().__init__()
        self.dropout = layers.Dropout(dropout_rate)
        self.dense = layers.Dense(d_model, activation='relu')

    def call(self, coords):
        """Return the dropout-regularised dense projection of ``coords``."""
        return self.dropout(self.dense(coords))

In [None]:
class TransformerGNNBlock(layers.Layer):
    """Post-norm transformer encoder block (self-attention + feed-forward).

    Args:
        units: Width of the block's input and output (last axis).
        num_heads: Number of attention heads.
        dropout_rate: Dropout rate used after attention and inside/after the
            feed-forward network.
    """

    def __init__(self, units, num_heads, dropout_rate):
        super().__init__()
        self.att = layers.MultiHeadAttention(
            num_heads=num_heads,
            # Integer division: the per-head width is rounded down when
            # ``units`` is not a multiple of ``num_heads``.
            key_dim=units // num_heads
        )
        self.ffn = tf.keras.Sequential([
            layers.Dense(units*4, activation='gelu'),
            layers.Dropout(dropout_rate),
            layers.Dense(units)  # Back to the input width for the residual sum
        ])
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(dropout_rate)
        self.dropout2 = layers.Dropout(dropout_rate)

    def call(self, inputs):
        """Apply the block.

        Returns:
            A pair ``(output, attn_output)`` where ``attn_output`` is the
            attention sub-layer output after dropout (not the attention
            weights).
        """
        # Self-attention sub-layer with residual connection.
        attended = self.dropout1(self.att(query=inputs, value=inputs))
        normed_attention = self.layernorm1(inputs + attended)

        # Feed-forward sub-layer with residual connection.
        transformed = self.dropout2(self.ffn(normed_attention))
        block_output = self.layernorm2(normed_attention + transformed)
        return block_output, attended

In [None]:
class SpatialTransformerGNN(tf.keras.Model):
    """Classifier combining projected features with a coordinate embedding.

    Each input row is expected to hold ``input_dim - 2`` feature values
    followed by 2 coordinate values. Features are projected to
    ``units - pos_dim`` channels, coordinates to ``pos_dim`` channels; the two
    are concatenated into one ``units``-wide token (sequence length 1), passed
    through two ``TransformerGNNBlock`` layers and classified with a softmax
    layer.

    Args:
        input_dim: Total number of input columns (features + 2 coordinates).
        num_classes: Number of output classes.
        units: Width of the token fed to the transformer blocks.
        num_heads: Number of attention heads per block.
        dropout_rate: Dropout rate used throughout the model.
        pos_dim: Width of the coordinate embedding (default 128).

    Raises:
        ValueError: If ``input_dim`` leaves no feature columns, or if
            ``units`` leaves no channels for the feature projection.
    """

    def __init__(self, input_dim, num_classes, units, num_heads, dropout_rate,
                 pos_dim=128):
        super().__init__()
        if input_dim <= 2:
            raise ValueError(
                f"input_dim must be > 2 (features + 2 coordinates), got {input_dim}"
            )
        if units <= pos_dim:
            # units - pos_dim channels are reserved for the features; zero or
            # fewer would silently drop (or fail to build) the feature branch.
            raise ValueError(
                f"units ({units}) must be greater than pos_dim ({pos_dim})"
            )

        self.feature_dim = input_dim - 2  # Last 2 columns are coordinates

        # Feature projection to (units - pos_dim) channels
        self.feature_proj = tf.keras.Sequential([
            layers.Dense(units - pos_dim, activation='relu'),
            layers.Dropout(dropout_rate)
        ])

        # Coordinate embedding with pos_dim channels
        self.pos_encoder = PositionalEncoding(pos_dim, dropout_rate)

        # Transformer blocks operate on the concatenated, units-wide token
        self.transformer_block1 = TransformerGNNBlock(units, num_heads, dropout_rate)
        self.transformer_block2 = TransformerGNNBlock(units, num_heads, dropout_rate)

        # Output layers
        self.flatten = layers.Flatten()
        self.classifier = layers.Dense(num_classes, activation='softmax')

    def call(self, inputs):
        """Return class probabilities for a batch of ``inputs`` rows."""
        # Split inputs into feature columns and the trailing coordinate columns
        features = inputs[:, :self.feature_dim]
        coords = inputs[:, self.feature_dim:]

        x = self.feature_proj(features)
        pos_emb = self.pos_encoder(coords)

        # Concatenate while both tensors are still rank 2, THEN add the
        # sequence axis. Concatenating a rank-2 tensor with a rank-3 one
        # (as the previous code did) fails because tf.concat requires
        # equal ranks.
        x = tf.concat([x, pos_emb], axis=-1)
        x = tf.expand_dims(x, 1)

        # Transformer blocks
        x, _ = self.transformer_block1(x)
        x, _ = self.transformer_block2(x)

        # Classification
        x = self.flatten(x)
        return self.classifier(x)

## 5. Training Configuration

In [None]:
def configure_training(model, learning_rate):
    """Compile ``model`` in place and return the callbacks to pass to ``fit``.

    The model is compiled with Adam (given ``learning_rate``, weight decay
    1e-4), sparse categorical cross-entropy and the accuracy metric.

    Args:
        model: Keras model to compile.
        learning_rate: Learning rate handed to the optimizer.

    Returns:
        A list with an ``EarlyStopping`` callback (patience 10, restoring the
        best weights) and a ``ReduceLROnPlateau`` callback (factor 0.5,
        patience 5).
    """
    model.compile(
        optimizer=tf.keras.optimizers.Adam(
            learning_rate=learning_rate,
            weight_decay=1e-4
        ),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy']
    )

    stop_early = EarlyStopping(patience=10, restore_best_weights=True)
    shrink_lr = ReduceLROnPlateau(factor=0.5, patience=5)
    return [stop_early, shrink_lr]

## 6. Optuna Hyperparameter Optimization

In [16]:
def create_optuna_study(seed=None):
    """Create the Optuna study used for hyperparameter search.

    The study maximises its objective, samples with TPE and prunes with a
    median pruner that waits 10 steps before it may prune a trial.

    Args:
        seed: Optional seed for the TPE sampler. ``None`` (the default) keeps
            the previous unseeded behaviour; pass an integer to make the
            sequence of suggested hyperparameters reproducible.

    Returns:
        The newly created ``optuna`` study.
    """
    return optuna.create_study(
        direction='maximize',
        sampler=optuna.samplers.TPESampler(seed=seed),
        pruner=optuna.pruners.MedianPruner(n_warmup_steps=10)
    )

def objective(trial, X_train, y_train, X_val, y_val):
    """Optuna objective: train one model and return its best validation accuracy.

    Args:
        trial: Optuna trial used to sample hyperparameters and to report
            intermediate results.
        X_train, y_train: Training inputs and integer class labels.
        X_val, y_val: Validation inputs and integer class labels.

    Returns:
        The maximum ``val_accuracy`` observed during training.

    Raises:
        optuna.TrialPruned: When the study's pruner decides to stop the trial.
    """
    # Drop Keras global state left behind by earlier trials so that memory
    # does not accumulate over the study.
    tf.keras.backend.clear_session()

    # Model architecture parameters
    model_params = {
        'input_dim': X_train.shape[1],
        # Largest label + 1 rather than the count of distinct labels: this
        # stays valid even if the integer codes have gaps.
        'num_classes': int(max(np.max(y_train), np.max(y_val))) + 1,
        # Lower bound is 192, not 128: the model reserves 128 channels for
        # the positional encoding, so units=128 would leave none for features.
        'units': trial.suggest_int('units', 192, 512, step=64),
        'num_heads': trial.suggest_int('num_heads', 2, 8),
        'dropout_rate': trial.suggest_float('dropout_rate', 0.1, 0.5, step=0.1)
    }

    # Training parameters (separate from model params)
    training_params = {
        'learning_rate': trial.suggest_float('learning_rate', 1e-5, 1e-3, log=True),
        'batch_size': trial.suggest_categorical('batch_size', [128, 256, 512])
    }

    class _ReportToOptuna(tf.keras.callbacks.Callback):
        """Report val_accuracy after each epoch so the pruner can act."""

        def on_epoch_end(self, epoch, logs=None):
            val_accuracy = (logs or {}).get('val_accuracy')
            if val_accuracy is None:
                return
            trial.report(float(val_accuracy), step=epoch)
            if trial.should_prune():
                raise optuna.TrialPruned()

    # Initialize model with only architecture parameters
    model = SpatialTransformerGNN(**model_params)

    # Compile the model and collect the training callbacks
    callbacks = configure_training(model, training_params['learning_rate'])
    history = model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        epochs=100,
        batch_size=training_params['batch_size'],
        callbacks=[*callbacks, _ReportToOptuna()],
        verbose=0
    )

    return max(history.history['val_accuracy'])

## 7. Visualization & Analysis

In [17]:
def plot_training_history(history):
    """Draw accuracy and loss curves (train vs. validation) side by side.

    Args:
        history: Object whose ``history`` attribute maps the keys
            ``'accuracy'``, ``'val_accuracy'``, ``'loss'`` and ``'val_loss'``
            to per-epoch values.
    """
    curves = history.history
    # (train key, validation key, train label, validation label, title, y label)
    panels = (
        ('accuracy', 'val_accuracy', 'Train Accuracy', 'Validation Accuracy',
         'Training History', 'Accuracy'),
        ('loss', 'val_loss', 'Train Loss', 'Validation Loss',
         'Loss Curves', 'Loss'),
    )

    plt.figure(figsize=(12, 4))
    for position, panel in enumerate(panels, start=1):
        train_key, val_key, train_label, val_label, title, y_label = panel
        plt.subplot(1, 2, position)
        plt.plot(curves[train_key], label=train_label)
        plt.plot(curves[val_key], label=val_label)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()

    plt.tight_layout()
    plt.show()

## 8. Main Execution Flow

In [None]:
if __name__ == "__main__":
    # Everything below stays inside the guard: the final-training code uses
    # `study`, which only exists when the search has run, so leaving it at
    # module level would raise NameError when this file is imported.

    # Load and preprocess data
    X_train, X_val, y_train, y_val = load_and_preprocess_data(
        '/home/frankfurt/LDL/data/abc_atlas/Zhuang-ABCA-1-merged-final-CLEAN-v2.h5ad',
        n_top_genes=2000,
        target_cells=1_500_000
    )

    # Hyperparameter optimization; the lambda closes over the data splits
    study = create_optuna_study()
    study.optimize(
        lambda trial: objective(trial, X_train, y_train, X_val, y_val),
        n_trials=50,
        timeout=3600
    )

    # Train final model with the best hyperparameters found
    best_params = study.best_params

    # Separate model params from training params
    model_params = {
        'input_dim': X_train.shape[1],
        # Largest label + 1 stays valid even if the integer codes have gaps.
        'num_classes': int(max(np.max(y_train), np.max(y_val))) + 1,
        'units': best_params['units'],
        'num_heads': best_params['num_heads'],
        'dropout_rate': best_params['dropout_rate']
    }

    training_params = {
        'learning_rate': best_params['learning_rate'],
        'batch_size': best_params['batch_size']
    }

    # Initialize final model
    final_model = SpatialTransformerGNN(**model_params)

    # Compile explicitly before fit() instead of relying on argument
    # evaluation order inside the fit() call.
    final_callbacks = configure_training(
        final_model, training_params['learning_rate']
    )

    history = final_model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        epochs=100,
        batch_size=training_params['batch_size'],
        callbacks=final_callbacks
    )

    # Evaluation and visualization
    # NOTE(review): the validation split was also used to select the
    # hyperparameters, so this score is an optimistic estimate.
    plot_training_history(history)
    final_model.evaluate(X_val, y_val, verbose=2)