https://scgpt.readthedocs.io/en/latest/tutorial_perturbation.html
https://huggingface.co/ctheodoris/Geneformer/blob/main/examples/in_silico_perturbation.ipynb

In [1]:
import os
import tempfile
import scanpy as sc
import scvi
import seaborn as sns
import torch
from rich import print


Global seed set to 0


In [2]:
# Read the multi-injury atlas from disk and print its AnnData summary.
adata = sc.read_h5ad(
    filename="/cis/net/r41/data/iessien1/Multi_Injury_Atlas.h5ad",
)
print(adata)

In [3]:
# Report how many distinct cell type annotations the atlas contains.
n_annotations = adata.obs['finalannotationv1'].nunique(dropna=False)
print(f"Number of unique cell type annotations: {n_annotations}")

# Temporary "<annotation>_<condition>" label, used to inspect the strata.
adata.obs['strat_temp'] = (
    adata.obs['finalannotationv1'].astype(str)
    + "_"
    + adata.obs['Condition'].astype(str)
)
strat_counts = adata.obs['strat_temp'].value_counts()
n_strat = adata.obs['strat_temp'].nunique(dropna=False)
print(f"Number of unique cell type-condition combinations: {n_strat}")
print(strat_counts)

# List the distinct values of each label column in turn:
# annotations, conditions, then the combined stratification label.
for heading, column in (
    ("\nAll unique cell type annotations:", 'finalannotationv1'),
    ("\nAll unique conditions:", 'Condition'),
    ("\nAll unique cell type-condition combinations:", 'strat_temp'),
):
    print(heading)
    print(adata.obs[column].unique())

# Spread of the per-stratum cell counts.
print("\nSummary statistics for stratification counts:")
print(f"Min count: {strat_counts.min()}")
print(f"Max count: {strat_counts.max()}")
print(f"Mean count: {strat_counts.mean():.2f}")
print(f"Median count: {strat_counts.median()}")

In [None]:
import scanpy as sc
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from scanpy import tl
# Build the combined cell-type/condition stratification label and drop strata
# that are too small to stratify on (a class with a single member triggers
# "The least populated class in y has only 1 member" during splitting).
n_annotations = adata.obs['finalannotationv1'].nunique(dropna=False)
print(f"Number of unique cell type annotations: {n_annotations}")

# Temporary "<annotation>_<condition>" label used as the stratification key.
adata.obs['strat_temp'] = (
    adata.obs['finalannotationv1'].astype(str)
    + "_"
    + adata.obs['Condition'].astype(str)
)
strat_counts = adata.obs['strat_temp'].value_counts()
n_strat = adata.obs['strat_temp'].nunique(dropna=False)
print(f"Number of unique cell type-condition combinations: {n_strat}")
print(strat_counts)

# Keep only the cells whose stratum has at least two members.
valid_strata = strat_counts.loc[strat_counts.ge(2)].index
valid_cells = adata.obs['strat_temp'].isin(valid_strata)
adata_filtered = adata[valid_cells, :].copy()
print(f"Number of filtered cell type-condition combinations: {len(valid_strata)}")

# Genes whose perturbation is analysed in the cells below.
target_genes = [
    'MYC', 'AKT1', 'CCND1', 'STAT3', 'HSPA5',
    'JUN', 'FOS', 'COX4I1', 'HIF1A', 'HSPA9',
]



In [None]:
# Total number of cell type-condition strata in the unfiltered atlas.
n_strat = adata.obs['strat_temp'].nunique(dropna=False)

# Number of strata with at least 2 members (displayed as the cell's output).
len(valid_strata)

In [None]:

def eval_perturb(adata_filtered, target_genes, cell_type_key='finalannotationv1', condition_key='Condition'):
    """
    Split cells into stratified train/validation/test sets for perturbation analysis.

    Cells are stratified on a combined "<cell type>_<condition>" label. Strata
    with fewer than 2 cells are dropped before splitting. The split is roughly
    70% train, 20% validation and 10% test, with a fixed random seed.

    Args:
        adata_filtered: AnnData object containing single-cell data. It is not
            modified; the stratification label is only written to the copies
            that are returned.
        target_genes: List of genes to analyze for perturbation. Only used to
            warn about genes that are absent from ``adata_filtered.var_names``.
        cell_type_key: Key in adata.obs for cell type annotations
        condition_key: Key in adata.obs for condition annotations

    Returns:
        Tuple ``(train_data, val_data, test_data)`` of AnnData copies. Each one
        carries the stratification label in ``obs['strat_temp']``.

    Raises:
        ValueError: If no stratum has at least 2 cells, so nothing can be split.
    """
    # Warn about requested genes that are absent from the dataset.
    missing_genes = [gene for gene in target_genes if gene not in adata_filtered.var_names]
    if missing_genes:
        print(f"Warning: {len(missing_genes)} genes not found in dataset")

    # Build the stratification label without touching the caller's object.
    obs = adata_filtered.obs
    strat_all = obs[cell_type_key].astype(str) + "_" + obs[condition_key].astype(str)

    # A stratified split needs at least 2 members per class.
    strat_counts = strat_all.value_counts()
    valid_strata = strat_counts[strat_counts >= 2].index
    keep = strat_all.isin(valid_strata).to_numpy()
    if not keep.any():
        raise ValueError("No cell type-condition stratum has at least 2 cells; cannot split")

    adata_filtered = adata_filtered[keep].copy()
    adata_filtered.obs['strat_temp'] = strat_all[keep].to_numpy()
    strat = adata_filtered.obs['strat_temp']

    # Display stratification information
    print(f"Stratification counts before splitting (filtered to {len(valid_strata)} valid strata):")
    print(strat.value_counts().head(10))

    # First split into train and temp (val+test combined)
    train_idx, temp_idx = train_test_split(
        np.arange(adata_filtered.n_obs),
        test_size=0.3,  # 30% for validation and test combined
        random_state=42,
        stratify=strat
    )

    # temp_idx holds positions, so the labels must be looked up with .iloc;
    # plain [] on this Series is label-based because its index is the cell names.
    temp_strat = strat.iloc[temp_idx]

    # Small strata can end up with a single cell in the hold-out part, which a
    # stratified split rejects. Fall back to a plain random split in that case.
    if temp_strat.value_counts().min() >= 2:
        second_stratify = temp_strat
    else:
        print("Warning: some strata have a single held-out cell; "
              "validation/test split is not stratified")
        second_stratify = None

    # Then split temp into validation and test
    val_idx, test_idx = train_test_split(
        temp_idx,
        test_size=0.33,  # 1/3 of the 30% (10% of total) for test
        random_state=42,
        stratify=second_stratify
    )

    # Create train, validation, and test datasets
    train_data = adata_filtered[train_idx].copy()
    val_data = adata_filtered[val_idx].copy()
    test_data = adata_filtered[test_idx].copy()

    # Show the largest strata in each split as a quick sanity check
    print("\nStratification distribution in splits:")
    print(f"Training data ({train_data.n_obs} cells):")
    print(train_data.obs['strat_temp'].value_counts().head(5))
    print(f"\nValidation data ({val_data.n_obs} cells):")
    print(val_data.obs['strat_temp'].value_counts().head(5))
    print(f"\nTest data ({test_data.n_obs} cells):")
    print(test_data.obs['strat_temp'].value_counts().head(5))

    return train_data, val_data, test_data


# Split the filtered atlas into train, validation, and test AnnData objects
# for the perturbation analysis.
train_data, val_data, test_data = eval_perturb(
    adata_filtered=adata_filtered,
    target_genes=target_genes,
)


In [None]:

# Define a dataset class for gene perturbation
class GenePerturbationDataset(torch.utils.data.Dataset):
    """Torch dataset that turns each cell of an AnnData object into model inputs.

    For every cell, the names of the genes with non-zero expression are passed
    to ``tokenizer`` as ONE pre-split sequence, padded/truncated to
    ``max_length``. The cell's full expression vector is returned alongside.

    Args:
        adata: AnnData-like object exposing ``X``, ``var_names`` and ``n_obs``.
        tokenizer: Callable with the Hugging Face tokenizer call signature.
        max_length: Fixed token length of every returned sequence.
    """

    def __init__(self, adata, tokenizer, max_length=2048):
        self.adata = adata
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Keep a dense (cells x genes) matrix so rows can be indexed directly.
        # NOTE: a sparse X is densified in full here, which can be memory-heavy.
        if isinstance(adata.X, np.ndarray):
            self.expression = adata.X
        else:
            self.expression = adata.X.toarray()

        # Gene names in column order, plus the reverse lookup name -> column.
        self.gene_names = list(adata.var_names)
        self.gene_to_idx = {gene: idx for idx, gene in enumerate(self.gene_names)}

    def __len__(self):
        return self.adata.n_obs

    def __getitem__(self, idx):
        # Expression vector for this cell
        expr = self.expression[idx]

        # Only expressed genes (non-zero values) are tokenized
        expressed_genes = [self.gene_names[i] for i in np.flatnonzero(expr > 0)]

        # is_split_into_words=True makes the tokenizer treat the list as one
        # sequence of words. Without it a list of strings is a batch of separate
        # sequences, giving an (n_genes, max_length) tensor whose first
        # dimension varies per cell and cannot be stacked into a batch.
        tokens = self.tokenizer(
            expressed_genes,
            is_split_into_words=True,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        # squeeze(0) drops only the batch dimension added by return_tensors="pt".
        return {
            "input_ids": tokens.input_ids.squeeze(0),
            "attention_mask": tokens.attention_mask.squeeze(0),
            "expression": torch.tensor(expr, dtype=torch.float32),
        }


In [None]:
import os
from typing import Tuple

import numpy as np
import optuna
import pytorch_lightning as pl
import scanpy as sc
import torch
import torch.nn as nn
from geneformer import EmbExtractor, InSilicoPerturber  # package name is lowercase
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer

# Genes whose perturbation effects are analysed by the model code below.
target_genes = [
    'MYC', 'AKT1', 'CCND1', 'STAT3', 'HSPA5',
    'JUN', 'FOS', 'COX4I1', 'HIF1A', 'HSPA9',
]

def train_and_evaluate_model(
    model_name,
    train_data,
    val_data,
    cell_states_to_model=None,
    test_data=None,
    epochs=50,
    output_base_dir="/cis/net/r41/data/iessien1/",
    n_trials=10,
) -> Tuple[sc.AnnData, float, float]:
    """Train model and evaluate predictions on both validation and holdout sets.

    The model is fine-tuned to regress the expression of the module-level
    ``target_genes`` from the tokenized cell. Hyperparameters (learning rate,
    weight decay, batch size) are tuned with Optuna on the validation loss, the
    best setting is retrained, and the result is saved under ``output_base_dir``.
    When ``cell_states_to_model`` is given, an in silico perturbation run is
    started on the saved model as well.

    Args:
        model_name: Only "Geneformer" is supported.
        train_data: AnnData used for training.
        val_data: AnnData used for early stopping and model selection.
        cell_states_to_model: Optional cell-state specification passed through
            to ``EmbExtractor`` / ``InSilicoPerturber``; ``None`` skips that step.
        test_data: Optional AnnData evaluated once after training.
        epochs: Maximum number of training epochs per run.
        output_base_dir: Directory for checkpoints, hyperparameters and results.
        n_trials: Number of Optuna trials.

    Returns:
        ``(train_data, val_loss, test_loss)``. ``test_loss`` is 0.0 when no
        ``test_data`` was given. On ANY error the message is printed and
        ``(None, inf, inf)`` is returned instead of raising.
    """
    try:
        # Function-scope imports keep the block self-contained.
        import json
        from transformers import EarlyStoppingCallback

        if model_name != "Geneformer":
            raise ValueError(f"Model {model_name} not supported")

        os.makedirs(output_base_dir, exist_ok=True)

        # Initialize tokenizer and datasets
        tokenizer = AutoTokenizer.from_pretrained("ctheodoris/Geneformer")
        train_dataset = GenePerturbationDataset(train_data, tokenizer)
        val_dataset = GenePerturbationDataset(val_data, tokenizer)
        test_dataset = GenePerturbationDataset(test_data, tokenizer) if test_data is not None else None

        # The head has one output per target gene, so the labels must be the
        # expression of exactly those genes rather than the full expression
        # vector (whose length would not match the model output).
        label_idx = [
            train_dataset.gene_to_idx[gene]
            for gene in target_genes
            if gene in train_dataset.gene_to_idx
        ]
        if not label_idx:
            raise ValueError("None of the target genes are present in the training data")
        num_labels = len(label_idx)

        def data_collator(features):
            """Stack per-cell items into a batch; labels are target-gene expression."""
            return {
                "input_ids": torch.stack([f["input_ids"] for f in features]),
                "attention_mask": torch.stack([f["attention_mask"] for f in features]),
                "labels": torch.stack([f["expression"][label_idx] for f in features]),
            }

        def build_trainer(output_dir, learning_rate, weight_decay, batch_size):
            """Create a fresh model and a Trainer with early stopping on eval loss."""
            model = AutoModelForSequenceClassification.from_pretrained(
                "ctheodoris/Geneformer",
                num_labels=num_labels,
                # Continuous expression targets: use the regression (MSE) loss.
                problem_type="regression",
            )
            training_args = TrainingArguments(
                output_dir=output_dir,
                learning_rate=learning_rate,
                weight_decay=weight_decay,
                per_device_train_batch_size=batch_size,
                per_device_eval_batch_size=batch_size,
                num_train_epochs=epochs,
                evaluation_strategy="epoch",
                save_strategy="epoch",
                load_best_model_at_end=True,
                metric_for_best_model="eval_loss",
                greater_is_better=False,
                save_total_limit=3,
            )
            return Trainer(
                model=model,
                args=training_args,
                train_dataset=train_dataset,
                eval_dataset=val_dataset,  # validation set drives model selection
                data_collator=data_collator,
                # Early stopping is a Trainer callback, not a TrainingArguments
                # field: stop after 5 evaluations without improvement.
                callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
            )

        def objective(trial):
            """Optuna objective: validation loss for one sampled configuration."""
            trial_trainer = build_trainer(
                os.path.join(output_base_dir, f"trial_{trial.number}"),
                learning_rate=trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True),
                weight_decay=trial.suggest_float("weight_decay", 1e-5, 1e-2, log=True),
                batch_size=trial.suggest_categorical("batch_size", [16, 32, 64]),
            )
            trial_trainer.train()
            return trial_trainer.evaluate()["eval_loss"]

        # Run hyperparameter optimization using the validation set
        study = optuna.create_study(direction="minimize")
        study.optimize(objective, n_trials=n_trials)

        best_params = study.best_params
        print(f"Best hyperparameters: {best_params}")

        # Save hyperparameters to file
        with open(os.path.join(output_base_dir, "best_hyperparameters.json"), "w") as f:
            json.dump(best_params, f, indent=4)

        # Train the final model with the best hyperparameters
        trainer = build_trainer(
            os.path.join(output_base_dir, "best_model"),
            learning_rate=best_params["learning_rate"],
            weight_decay=best_params["weight_decay"],
            batch_size=best_params["batch_size"],
        )
        trainer.train()

        # Save the model and tokenizer locally
        model_path = os.path.join(output_base_dir, "finetuned_model")
        trainer.model.save_pretrained(model_path)
        tokenizer.save_pretrained(os.path.join(output_base_dir, "finetuned_tokenizer"))

        # Evaluate on the validation set, then on the optional test set
        val_loss = trainer.evaluate()["eval_loss"]
        test_loss = 0.0
        if test_dataset is not None:
            test_loss = trainer.evaluate(eval_dataset=test_dataset)["eval_loss"]

        # For in silico perturbation analysis
        if cell_states_to_model is not None:
            # NOTE(review): the rest of this notebook uses 'finalannotationv1'
            # as the annotation column - confirm that obs has a "cell_type" column.
            filter_data_dict = {"cell_type": train_data.obs["cell_type"].unique().tolist()}

            # First obtain start, goal, and alt embedding positions using EmbExtractor
            embex = EmbExtractor(
                model_type="CellClassifier",  # using the fine-tuned model saved above
                num_classes=num_labels,
                filter_data=filter_data_dict,
                max_ncells=1000,
                emb_layer=0,
                summary_stat="exact_mean",
                forward_batch_size=256,
                nproc=16
            )

            # NOTE(review): nothing in this function writes to "input_data";
            # presumably a tokenized dataset must already exist there - confirm.
            input_data_path = os.path.join(output_base_dir, "input_data")
            state_embs_dir = os.path.join(output_base_dir, "state_embs")
            os.makedirs(state_embs_dir, exist_ok=True)

            state_embs_dict = embex.get_state_embs(
                cell_states_to_model,
                model_path,
                input_data_path,
                state_embs_dir,
                "state_embs"
            )

            # Setup in silico perturber for gene perturbation analysis
            # NOTE(review): target_genes holds gene symbols; check whether
            # genes_to_perturb expects Ensembl IDs instead.
            perturber = InSilicoPerturber(
                perturb_type="overexpress",  # or "delete" for knockouts
                perturb_rank_shift=None,
                genes_to_perturb=target_genes,
                combos=0,
                anchor_gene=None,
                model_type="CellClassifier",
                num_classes=num_labels,
                emb_mode="cell",
                cell_emb_style="mean_pool",
                filter_data=filter_data_dict,
                cell_states_to_model=cell_states_to_model,
                state_embs_dict=state_embs_dict,
                max_ncells=2000,
                emb_layer=0,
                forward_batch_size=400,
                nproc=16
            )

            # Run perturbation analysis
            isp_output_dir = os.path.join(output_base_dir, "isp_output")
            perturber.perturb_data(
                model_path,
                input_data_path,
                isp_output_dir,
                "perturbation_results"
            )

            # Record where the perturbation results were written
            train_data.uns["perturbation_results"] = isp_output_dir

        return train_data, val_loss, test_loss
    except Exception as e:
        # Deliberate best-effort contract: report the failure and return
        # sentinel values instead of propagating the exception.
        print(f"Error in training: {e}")
        return None, float('inf'), float('inf')



In [None]:
# Analyze perturbation results using InSilicoPerturberStats
import pandas as pd
from geneformer import InSilicoPerturberStats  # the package name is lowercase

output_base_dir = "/cis/net/r41/data/iessien1/"

# Define the cell states to model (same as used in the perturber)
# This should match the cell_states_to_model used in InSilicoPerturber
# NOTE(review): the rest of this notebook uses 'finalannotationv1' as the
# annotation column - confirm that obs has a "cell_type" column.
cell_types = train_data.obs["cell_type"].unique().tolist()
if len(cell_types) < 2:
    raise ValueError("Need at least two cell types to define a start and a goal state")
cell_states_to_model = {
    "state_key": "cell_type",  # column in metadata that defines cell states
    "start_state": cell_types[0],  # first cell type as starting state
    "goal_state": cell_types[1],   # second cell type as goal state
    "alt_states": cell_types[2:]   # remaining cell types as alternatives
}

# Initialize the stats analyzer
# NOTE(review): target_genes holds gene symbols; check whether
# genes_perturbed expects Ensembl IDs instead.
ispstats = InSilicoPerturberStats(
    mode="goal_state_shift",  # analyze shifts toward goal state
    genes_perturbed=target_genes,  # same genes used in perturbation
    combos=0,  # no combinations (matches perturber setting)
    anchor_gene=None,  # no anchor gene (matches perturber setting)
    cell_states_to_model=cell_states_to_model  # same states as in perturber
)

# Process the perturbation results
isp_output_dir = os.path.join(output_base_dir, "isp_output")
isp_stats_output_dir = os.path.join(output_base_dir, "isp_stats")
os.makedirs(isp_stats_output_dir, exist_ok=True)

# Extract data from intermediate files and generate final statistics
stats_prefix = "perturbation_stats"
ispstats.get_stats(
    isp_output_dir,  # directory with perturbation results
    None,  # no null-distribution directory (not needed for goal_state_shift)
    isp_stats_output_dir,  # where to save the stats
    stats_prefix  # prefix for output files
)

# Print path to results
print(f"Perturbation statistics saved to: {isp_stats_output_dir}")

# Load and display top genes that shift cell state.
# The stats are presumably written as "<prefix>.csv"; the previously assumed
# "<prefix>_goal_state_shift.csv" is kept as a fallback - TODO confirm.
candidate_files = [
    os.path.join(isp_stats_output_dir, f"{stats_prefix}.csv"),
    os.path.join(isp_stats_output_dir, f"{stats_prefix}_goal_state_shift.csv"),
]
stats_file = next((path for path in candidate_files if os.path.exists(path)), None)

if stats_file is None:
    # Say so explicitly instead of silently printing nothing.
    print(f"No stats file found in: {isp_stats_output_dir}")
else:
    try:
        stats_df = pd.read_csv(stats_file)
    except (OSError, ValueError) as e:
        print(f"Could not load stats file: {e}")
    else:
        # Prefer the goal-state shift column when present (largest shift first);
        # otherwise fall back to the original cosine-similarity ordering.
        if "Shift_to_goal_end" in stats_df.columns:
            sort_column, ascending = "Shift_to_goal_end", False
        elif "Cosine_sim_mean" in stats_df.columns:
            sort_column, ascending = "Cosine_sim_mean", True
        else:
            sort_column = None

        if sort_column is None:
            print(f"Could not load stats file: no known effect column in {stats_file}")
        else:
            # Display top 10 genes with largest effect
            print("Top 10 genes with largest effect on cell state shift:")
            print(stats_df.sort_values(by=sort_column, ascending=ascending).head(10))