# Task 1: Gene Perturbation Example

This notebook demonstrates the usage of the `PerturbationWorkflow` class to perturb genes in ALS case data.

## Overview

The workflow performs in-silico gene perturbations on single-cell RNA-seq data:
1. Loads ALS patient data
2. Applies fold changes to target genes
3. Generates embeddings using the Geneformer foundation model
4. Saves results for downstream analysis

This is a demonstration with a small subset (10 cells) to validate the workflow.

In [None]:
# Import required libraries
import anndata as ad
import logging
from perturbation_workflow import PerturbationWorkflow
from helical.models.geneformer import Geneformer, GeneformerConfig

# Configure logging to track workflow progress
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

## Step 1: Load and Filter Dataset

We load the full dataset containing both ALS cases and healthy controls, then filter to only ALS cases since we want to perturb disease cells.

In [None]:
# Load the full dataset with ALS cases (Condition='ALS') and healthy controls (Condition='PN')
logging.info("Loading dataset...")
adata = ad.read_h5ad("data/counts_combined_filtered_BA4_sALS_PN.h5ad")
logging.info(f"Dataset loaded with shape: {adata.shape} (cells × genes)")

# Select only ALS cases for perturbation analysis
# Control (PN) cells are dropped from this object; they are never perturbed and can be reloaded from the same file as a reference
adata = adata[adata.obs['Condition'] == 'ALS'].copy()
logging.info(f"Dataset filtered to ALS cases with shape: {adata.shape} (cells × genes)")
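Subsetting with a boolean mask on `obs['Condition']` keeps only the rows (cells) whose label matches. The toy sketch below mirrors that logic with plain NumPy arrays laid out cells × genes, as in AnnData. The labels and counts are made up for illustration.

```python
import numpy as np

# Toy stand-ins for adata.obs['Condition'] and adata.X (cells x genes); values are illustrative only
conditions = np.array(["ALS", "PN", "ALS", "PN", "ALS"])
counts = np.arange(15).reshape(5, 3)

# Boolean mask selecting ALS cells, analogous to adata.obs['Condition'] == 'ALS'
als_mask = conditions == "ALS"
als_counts = counts[als_mask]  # NumPy boolean indexing returns a new array, not a view

print(als_mask.sum(), als_counts.shape)  # 3 (3, 3)
```

AnnData behaves differently: `adata[mask]` returns a view of the parent object, which is why the notebook calls `.copy()` before working with the subset.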

## Step 2: Initialize the Perturbation Workflow

Create a `PerturbationWorkflow` instance with the filtered ALS data.

In [None]:
# Initialize the workflow with the ALS case data
workflow = PerturbationWorkflow(adata)
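`PerturbationWorkflow` comes from a local module (`perturbation_workflow`) whose source is not shown in this notebook. As a rough mental model only, the toy class below sketches what a fold-change perturbation on a cells × genes matrix involves. The class name, its NumPy-based constructor, and the `perturb` method are hypothetical; the real class takes an AnnData object and its API may differ.

```python
import numpy as np


class ToyPerturbationWorkflow:
    """Hypothetical, simplified stand-in for PerturbationWorkflow (not its real implementation)."""

    def __init__(self, counts: np.ndarray, gene_names: list):
        # counts: cells x genes expression matrix; gene_names: one name per column
        if counts.shape[1] != len(gene_names):
            raise ValueError("gene_names must have one entry per column of counts")
        self.counts = counts
        self.gene_index = {name: i for i, name in enumerate(gene_names)}

    def perturb(self, gene_name: str, fold_change: float) -> np.ndarray:
        # Work on a float copy so the stored data stay unperturbed across calls
        perturbed = self.counts.astype(float, copy=True)
        col = self.gene_index[gene_name]  # raises KeyError if the gene is absent
        perturbed[:, col] *= fold_change
        return perturbed


wf = ToyPerturbationWorkflow(np.array([[1, 4], [0, 2]]), ["C9orf72", "SOD1"])
print(wf.perturb("SOD1", 0.5))
```

Copying before multiplying is what lets several perturbations run against the same starting data without compounding.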

## Initialize the Embedding Model

Geneformer converts each cell's expression profile into an embedding vector. The same model instance is reused for every perturbation, so it only needs to be loaded once.

In [None]:
# Initialize Geneformer v2 foundation model for generating embeddings
# This model transforms gene expression data into latent space representations
model_config = GeneformerConfig(
    model_name="gf-12L-104M-i4096",  # 12 layers, 104M parameters, 4096-token input length
    batch_size=24,                    # Cells per forward pass; lower this if GPU memory runs out
    device="cuda:0"                   # Use first CUDA GPU
)
geneformer_v2 = Geneformer(model_config)
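Geneformer does not consume raw counts. Per its published design, each cell is tokenized as a rank value encoding: genes are ordered by their expression normalized by that gene's median across the pretraining corpus, then truncated to the model's input size (the `i4096` in the model name). A fold change therefore reaches the embedding only through how it moves the target gene within that ranking. The sketch below is a simplified illustration with made-up medians; it is not the Geneformer or Helical tokenizer, which also handles gene ID mapping and special tokens.

```python
def rank_value_encode(expr: dict, gene_medians: dict, max_len: int) -> list:
    """Simplified, illustrative rank value encoding (not the real Geneformer tokenizer)."""
    total = sum(expr.values())
    if total == 0:
        return []
    # Normalize by the cell's total counts, then by each gene's corpus-wide median,
    # so genes that are high everywhere (e.g. housekeeping genes) are ranked lower
    scores = {
        gene: (count / total) / gene_medians[gene]
        for gene, count in expr.items()
        if count > 0 and gene in gene_medians
    }
    # Highest normalized expression first, truncated to the model's input size
    ranked = sorted(scores, key=scores.get, reverse=True)
    return ranked[:max_len]


# Made-up counts and medians for four genes in one cell
cell = {"A": 10, "B": 5, "C": 0, "D": 8}
medians = {"A": 2.0, "B": 0.5, "C": 1.0, "D": 4.0}

print(rank_value_encode(cell, medians, max_len=2))       # ['B', 'A']
perturbed = {**cell, "A": cell["A"] * 3.0}               # 3x fold change on gene A
print(rank_value_encode(perturbed, medians, max_len=2))  # ['A', 'B']
```

In this toy cell, tripling gene A moves it from second to first place; a fold change too small to change any gene's rank would leave the token sequence, and therefore the embedding, unchanged.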

## Step 3: Define Perturbations

We define perturbations for ALS-related genes:
- **C9orf72**: Overexpression (fold_change=2.0) - a hexanucleotide repeat expansion in this gene is the most common genetic cause of ALS
- **SOD1**: Downregulation (fold_change=0.5) - mutations in this gene cause a subset of familial ALS cases

Each perturbation will:
1. Multiply gene expression by the fold change
2. Generate embeddings using Geneformer
3. Save embeddings to disk for analysis
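Because the perturbation multiplies expression, cells with zero counts for the target gene are unaffected: 0 × 2.0 is still 0. Single-cell data are sparse, so it is worth checking how many cells actually express C9orf72 or SOD1 before interpreting results. The toy array below is illustrative; on real data the same check could be run on the gene's column of `adata.X`.

```python
import numpy as np

# Toy counts for one gene across 6 cells; single-cell matrices are typically mostly zeros
gene_counts = np.array([0, 3, 0, 0, 1, 0])

fold_change = 2.0
perturbed = gene_counts * fold_change

# Multiplication leaves zero counts at zero, so only expressing cells actually change
n_changed = int(np.count_nonzero(perturbed != gene_counts))
n_expressing = int(np.count_nonzero(gene_counts))
print(f"{n_expressing} of {gene_counts.size} cells express the gene; {n_changed} changed")
```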

In [None]:
# Define batch perturbations as a list of dictionaries
# Each dictionary specifies: gene_name, fold_change, model, and save path
batch_perturbation_with_embed = [
    {
        "gene_name": "C9orf72",      # Gene to perturb
        "fold_change": 2.0,          # 2x overexpression
        "helical_model": geneformer_v2,  # Model for embeddings
        "save_path": "data/perturbed_embeddings_C9orf72.npz"  # Output file
    },
    {
        "gene_name": "SOD1",         # Gene to perturb
        "fold_change": 0.5,          # 50% downregulation
        "helical_model": geneformer_v2,  # Model for embeddings
        "save_path": "data/perturbed_embeddings_SOD1.npz"  # Output file
    }
]
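Each entry triggers a separate embedding pass, so a misspelled gene symbol or a reused `save_path` may only surface after earlier perturbations have already run, or may silently overwrite a result. A pre-flight check like the hypothetical one below can catch these before calling `perturb_batch`. The function is not part of `PerturbationWorkflow`, and the set of keys it treats as required is an assumption based on the dictionaries above. A fold change of 0 is allowed, since it corresponds to an in-silico knockout.

```python
from typing import Iterable, Optional


def validate_perturbations(perturbations: list, known_genes: Optional[Iterable[str]] = None) -> None:
    """Hypothetical pre-flight check for batch perturbation specs."""
    required = {"gene_name", "fold_change", "save_path"}  # assumed required keys
    genes = set(known_genes) if known_genes is not None else None
    seen_paths = set()
    for i, spec in enumerate(perturbations):
        missing = required - set(spec)
        if missing:
            raise ValueError(f"perturbation {i} is missing keys: {sorted(missing)}")
        if genes is not None and spec["gene_name"] not in genes:
            raise ValueError(f"perturbation {i}: gene {spec['gene_name']!r} not found")
        if spec["fold_change"] < 0:
            raise ValueError(f"perturbation {i}: fold_change must be non-negative")
        if spec["save_path"] in seen_paths:
            raise ValueError(f"perturbation {i}: duplicate save_path {spec['save_path']!r}")
        seen_paths.add(spec["save_path"])
```

With the notebook's objects, a call might look like `validate_perturbations(batch_perturbation_with_embed, known_genes=adata.var_names)`, assuming `var_names` holds gene symbols rather than Ensembl IDs.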

## Step 4: Run Perturbations on Test Subset

For this demonstration, we use only the first 10 cells to quickly validate the workflow.

In [None]:
# Select only the first 10 cells to speed up the test
# In production, you would use the full dataset
head_adata = adata[:10, :].copy()

# Re-initialize workflow with the smaller subset
workflow = PerturbationWorkflow(head_adata)
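Taking the first 10 rows is quick, but it inherits whatever order the file is stored in, which often groups cells by sample or cell type. If a more representative test subset matters, a seeded random sample of row positions is a common alternative. The sketch below picks positions with NumPy; `n_cells_total` is a placeholder for `adata.n_obs`.

```python
import numpy as np

n_cells_total = 1000  # placeholder for adata.n_obs
n_test = 10

# Fixed seed so the same "random" cells are chosen on every run
rng = np.random.default_rng(seed=0)
idx = np.sort(rng.choice(n_cells_total, size=n_test, replace=False))

print(idx)  # row positions that could be passed to adata[idx, :].copy()
```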

## Step 5: Execute Batch Perturbations

Run all perturbations sequentially. Each perturbation will:
1. Copy the original data
2. Apply the fold change to the target gene
3. Process through Geneformer to get embeddings
4. Save embeddings as compressed numpy arrays

In [None]:
# Run the batch perturbation workflow
# Returns a list of embedding arrays (one per perturbation)
perturb_batch_results_with_embed = workflow.perturb_batch(batch_perturbation_with_embed)

# Log summary
logging.info(f"Batch perturbation completed. Generated {len(perturb_batch_results_with_embed)} embedding arrays.")
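Each `save_path` ends in `.npz`, NumPy's zipped archive format, which stores one or more named arrays. The key names the workflow uses inside each file are not shown in this notebook, so when reading a file back it is safest to list `.files` first. The sketch below round-trips a toy array through `np.savez_compressed` and `np.load`; the key `embeddings` and the 8-dimensional toy shape are assumptions.

```python
import os
import tempfile

import numpy as np

# Toy embedding matrix: 10 cells x 8 dimensions (real Geneformer embeddings are wider)
embeddings = np.random.default_rng(0).normal(size=(10, 8)).astype(np.float32)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "perturbed_embeddings_example.npz")
    # "embeddings" is a hypothetical key; the workflow may store arrays under other names
    np.savez_compressed(path, embeddings=embeddings)

    # List the stored keys before assuming a name, then read the array into memory
    with np.load(path) as data:
        keys = data.files
        restored = data[keys[0]]

print(keys, restored.shape)  # ['embeddings'] (10, 8)
```

For the real outputs, `np.load("data/perturbed_embeddings_SOD1.npz").files` would reveal the key names the workflow wrote.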

## Step 6: Verify Results

Display the shape of each generated embedding array to confirm successful execution.

In [None]:
# Display results summary
# Each embedding array should have shape (n_cells, embedding_dim)
for i, result in enumerate(perturb_batch_results_with_embed):
    perturbation = batch_perturbation_with_embed[i]
    print(f"Perturbation {i+1} ({perturbation['gene_name']} × {perturbation['fold_change']}): "
          f"Embedding shape = {result.shape}")
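Shape checks confirm the run completed, but measuring a perturbation's effect needs an unperturbed baseline embedded the same way, for example the same 10 cells passed through Geneformer with no fold change applied (a step not included in this notebook). Assuming both arrays hold the same cells in the same row order, per-cell cosine similarity is one simple way to quantify how far each embedding moved. The sketch uses random toy arrays in place of real embeddings; values near 1 indicate little change.

```python
import numpy as np


def per_cell_cosine_similarity(baseline: np.ndarray, perturbed: np.ndarray) -> np.ndarray:
    """Cosine similarity between matched rows of two (n_cells, dim) embedding arrays."""
    if baseline.shape != perturbed.shape:
        raise ValueError(f"shape mismatch: {baseline.shape} vs {perturbed.shape}")
    if not (np.isfinite(baseline).all() and np.isfinite(perturbed).all()):
        raise ValueError("embeddings contain NaN or inf")
    dots = np.sum(baseline * perturbed, axis=1)
    norms = np.linalg.norm(baseline, axis=1) * np.linalg.norm(perturbed, axis=1)
    return dots / norms


rng = np.random.default_rng(0)
base = rng.normal(size=(10, 8))                # toy baseline embeddings
pert = base + 0.1 * rng.normal(size=(10, 8))   # toy "perturbed" embeddings
sims = per_cell_cosine_similarity(base, pert)
print(sims.round(3))
```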