# Task 3: Explore and Evaluate Perturbed Embeddings

This notebook analyzes the effectiveness of gene perturbations in bringing ALS case cells closer to healthy control cells in the embedding space.

## Analysis Pipeline

1. Load unperturbed embeddings (baseline)
2. Load perturbed embeddings for each gene manipulation
3. Compute similarity metrics (silhouette, distance, kNN purity)
4. Compare perturbations to identify which genes show therapeutic potential
5. Aggregate results by cell type to understand cell-type-specific effects

## Key Metrics

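- **Silhouette Coefficient**: Per-case silhouette computed with ALS/PN as the cluster labels. It compares a case's mean distance to other cases with its mean distance to controls (lower = cases more similar to controls)
- **Distance to Controls**: Mean Euclidean distance from each case to all controls (lower = better)
- **kNN Purity**: Fraction of a case's k nearest neighbors (k = 5, excluding the cell itself) that are controls (higher = better)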

All metrics are computed in the **original embedding space** rather than on the UMAP projection, because UMAP does not faithfully preserve distances. UMAP is computed only for visualization.
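The three metrics move in different directions when cases approach controls. The following minimal, self-contained sketch uses synthetic Gaussian data to check those directions. The data, the shift, and the helper `case_metrics` are hypothetical and not part of the pipeline; the metric definitions mirror the list above.

```python
import numpy as np
from scipy.spatial.distance import cdist
from sklearn.metrics import silhouette_samples
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)

# Hypothetical toy data: 50 PN controls centered at 0, 50 ALS cases centered at 3 (8 dims)
controls = rng.normal(0.0, 1.0, size=(50, 8))
cases = rng.normal(3.0, 1.0, size=(50, 8))
labels = np.array(["PN"] * 50 + ["ALS"] * 50)


def case_metrics(case_emb, k=5):
    """Mean silhouette, mean distance to controls, and mean kNN purity over the cases."""
    emb = np.vstack([controls, case_emb])
    is_case = labels == "ALS"
    sil = silhouette_samples(emb, labels)[is_case].mean()
    dist = cdist(case_emb, controls).mean()
    # Ask for k + 1 neighbors because each case's nearest neighbor is itself
    _, idx = NearestNeighbors(n_neighbors=k + 1).fit(emb).kneighbors(case_emb)
    purity = (labels[idx[:, 1:]] == "PN").mean()
    return sil, dist, purity


before = case_metrics(cases)
# Simulated "perturbation": shift every case 2 units toward the controls on each axis
after = case_metrics(cases - 2.0)
print("before (silhouette, distance, purity):", np.round(before, 3))
print("after  (silhouette, distance, purity):", np.round(after, 3))
```

When the cases move toward the controls, silhouette and distance should drop and purity should not decrease.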

In [1]:
# Import required libraries
import pandas as pd
import anndata as ad
import numpy as np
from sklearn.neighbors import NearestNeighbors

# Load sample metadata (contains condition and cell type information)
metadata = pd.read_csv("data/sample_1000_cells_balanced_obs.csv")

# Load unperturbed (baseline) embeddings for cases and controls
vanilla_embeddings = np.load("data/unperturbed_embeddings.npz")["embeddings"]

# List of all perturbations to analyze
# Includes normalization perturbations and knockouts
perturbations = ["SOD1", "SOD1_ko", "FUS", "FUS_ko", "TARDBP", "TARDBP_ko", 
                 "C9orf72", "C9orf72_ko", "GDNF"]

## Define Metric Functions

We define reusable functions to compute three key metrics that evaluate how well perturbations bring ALS cases closer to healthy controls.
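For reference, the per-case quantities can be written as follows. The notation is introduced here: $x_i$ is the embedding of case cell $i$, $\mathcal{A}$ and $\mathcal{C}$ are the sets of ALS cases and PN controls, and $N_k(i)$ is the set of the $k$ nearest neighbors of $i$, excluding $i$ itself. The silhouette follows scikit-learn's definition. With only two labels, its nearest-other-cluster term $b_i$ equals the mean distance to controls $d_i$, so the silhouette and distance metrics are closely linked.

$$
d_i = \frac{1}{|\mathcal{C}|} \sum_{j \in \mathcal{C}} \lVert x_i - x_j \rVert_2,
\qquad
p_i = \frac{\lvert N_k(i) \cap \mathcal{C} \rvert}{k}
$$

$$
a_i = \frac{1}{|\mathcal{A}| - 1} \sum_{j \in \mathcal{A},\, j \neq i} \lVert x_i - x_j \rVert_2,
\qquad
b_i = d_i,
\qquad
s_i = \frac{b_i - a_i}{\max(a_i, b_i)}
$$

A case with $s_i \le 0$ is, on average, at least as close to the controls as to the other cases.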

In [2]:
# Import libraries for metric computation
import scanpy as sc
import umap
from sklearn.metrics import silhouette_samples
from scipy.spatial.distance import cdist

def compute_case_knn_purity(embeddings, metadata, k=5):
    """
    Compute kNN purity for each case cell only.
    
    Purity = fraction of k-nearest neighbors that are controls (healthy cells).
    Higher purity means the case is surrounded by more controls (better).
    
    Args:
        embeddings: Embedding matrix (n_cells × embedding_dim)
        metadata: DataFrame with 'Condition' column ('ALS' or 'PN')
        k: Number of nearest neighbors to consider
    
    Returns:
        Array of purity scores for each case cell
    """
    labels = metadata['Condition'].to_numpy()
    case_indices = np.flatnonzero(labels == 'ALS')
    
    # Fit kNN on all embeddings (cases + controls)
    knn = NearestNeighbors(n_neighbors=k+1).fit(embeddings)
    
    # Query k+1 neighbors per case cell so that one slot can hold the cell itself
    _, indices = knn.kneighbors(embeddings[case_indices])
    
    purity_scores = []
    for cell_idx, neighbors in zip(case_indices, indices):
        # Drop the query cell by index instead of assuming it is ranked first
        # (exact duplicate embeddings can tie with it), then keep the first k others
        neighbors = neighbors[neighbors != cell_idx][:k]
        # Purity = fraction of neighbors that are controls
        n_control_neighbors = (labels[neighbors] == 'PN').sum()
        purity_scores.append(n_control_neighbors / k)
    
    return np.array(purity_scores)

def compute_case_distances_to_controls(embeddings, metadata):
    """
    Compute mean distance from each case to all controls in embedding space.
    
    Lower distance means the case is closer to the control distribution (better).
    
    Args:
        embeddings: Embedding matrix (n_cells × embedding_dim)
        metadata: DataFrame with 'Condition' column ('ALS' or 'PN')
    
    Returns:
        Array of mean distances for each case cell
    """
    case_mask = metadata['Condition'] == 'ALS'
    control_mask = metadata['Condition'] == 'PN'
    
    case_embeddings = embeddings[case_mask]
    control_embeddings = embeddings[control_mask]
    
    # Compute pairwise distances from cases to controls
    distances = cdist(case_embeddings, control_embeddings, metric='euclidean')
    
    # Mean distance from each case to all controls
    mean_distances = distances.mean(axis=1)
    
    return mean_distances

def compute_case_silhouette_coefficients(embeddings, metadata):
    """
    Compute silhouette coefficient for each case cell only.
    
    Silhouette measures how well a sample fits its own cluster vs other clusters.
    Lower (more negative) values mean cases are more similar to controls (better).
    
    Args:
        embeddings: Embedding matrix (n_cells × embedding_dim)
        metadata: DataFrame with 'Condition' column ('ALS' or 'PN')
    
    Returns:
        Array of silhouette coefficients for each case cell
    """
    # Compute silhouette coefficient for all samples
    all_silhouette = silhouette_samples(embeddings, metadata['Condition'])
    
    # Extract only case cell silhouette coefficients
    case_mask = metadata['Condition'] == 'ALS'
    case_silhouette = all_silhouette[case_mask]
    
    return case_silhouette

def umap_and_compute_metrics(embeddings, metadata):
    """
    Compute UMAP and per-case metrics for a given embedding matrix.
    
    This wrapper function:
    1. Creates AnnData object with embeddings and metadata
    2. Computes UMAP for visualization (not used for metrics)
    3. Computes all three metrics in original embedding space
    
    Args:
        embeddings: Embedding matrix (n_cells × embedding_dim)
        metadata: DataFrame with cell annotations
    
    Returns:
        Tuple: (adata, case_silhouette, case_purity, case_distances)
    """
    # Create an AnnData object to hold the embeddings and metadata
    adata = ad.AnnData(X=embeddings)
    adata.obs = metadata.reset_index(drop=True)

    # Compute UMAP for visualization (optional, not used for metrics)
    sc.pp.neighbors(adata, use_rep='X')
    sc.tl.umap(adata)

    # Compute per-case metrics in ORIGINAL embedding space (not UMAP)
    case_silhouette = compute_case_silhouette_coefficients(embeddings, adata.obs)
    case_purity = compute_case_knn_purity(embeddings, adata.obs, k=5)
    case_distances = compute_case_distances_to_controls(embeddings, adata.obs)
    
    return adata, case_silhouette, case_purity, case_distances



## Compute Baseline Metrics

Calculate metrics for unperturbed embeddings to establish a baseline for comparison.

In [3]:
# Compute baseline metrics for unperturbed embeddings
# These serve as the reference for evaluating perturbation effectiveness

adata_vanilla, silhouette_vanilla, purity_vanilla, distances_vanilla = umap_and_compute_metrics(
    vanilla_embeddings, metadata
)

# Get case cell identifiers from metadata (original row labels, before re-indexing)
case_mask = metadata['Condition'] == 'ALS'
case_metadata = metadata[case_mask]

# Initialize results dictionary with case identifiers
# This will store per-case metrics for all perturbations
results_dict = {
    'Cell_Index': case_metadata.index.tolist(),
}
case_metadata = case_metadata.reset_index(drop=True)

# Add unperturbed metrics as baseline
results_dict['Silhouette_Unperturbed'] = silhouette_vanilla
results_dict['Distance_Unperturbed'] = distances_vanilla
results_dict['Purity_Unperturbed'] = purity_vanilla



## Compute Metrics for All Perturbations

For each perturbation:
1. Load perturbed embeddings (ALS cases only)
2. Replace case embeddings with perturbed versions
3. Compute metrics (comparing perturbed cases to controls)
4. Store results for comparison
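Step 2 only works if the perturbed file lists the ALS cells in the same order as `metadata`. Below is a minimal sketch of the swap on toy arrays. It uses a hypothetical helper, `swap_in_perturbed`, that raises on a row-count mismatch. The helper cannot detect a reordering, so the ordering still has to be guaranteed upstream.

```python
import numpy as np
import pandas as pd

# Hypothetical toy stand-ins for `metadata` and the embedding files
metadata = pd.DataFrame({"Condition": ["PN", "ALS", "PN", "ALS", "ALS"]})
vanilla = np.arange(10, dtype=float).reshape(5, 2)  # 5 cells x 2 dims
perturbed_cases = -np.ones((3, 2))                   # one row per ALS cell, in metadata order


def swap_in_perturbed(vanilla, perturbed_cases, condition, case_label="ALS"):
    """Return a copy of `vanilla` whose case rows are replaced by `perturbed_cases`."""
    case_mask = (condition == case_label).to_numpy()
    n_cases = int(case_mask.sum())
    expected = (n_cases, vanilla.shape[1])
    if perturbed_cases.shape != expected:
        raise ValueError(f"expected shape {expected}, got {perturbed_cases.shape}")
    combined = vanilla.copy()
    # Boolean-mask assignment fills the case rows top to bottom, so row order must match
    combined[case_mask] = perturbed_cases
    return combined


combined = swap_in_perturbed(vanilla, perturbed_cases, metadata["Condition"])
print(combined)
```

Control rows keep their original values, and the input array is left unmodified because the helper works on a copy.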

In [4]:
# Loop through all perturbations and compute metrics

# Store global (aggregated) metrics for summary table
global_results = [{
    'Perturbation': 'Unperturbed',
    'Mean Silhouette': silhouette_vanilla.mean(),
    'Mean Purity': purity_vanilla.mean(),
    'Mean Distance': distances_vanilla.mean()
}]

for perturbation in perturbations:
    # Load perturbed embeddings for this gene manipulation
    perturbed_embeddings = np.load(f"data/perturbed_embeddings_{perturbation}.npz")["embeddings"]
    
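    # Create combined embedding matrix: controls (unchanged) + perturbed cases
    embeddings = vanilla_embeddings.copy()
    case_mask = (metadata['Condition'] == 'ALS').to_numpy()
    
    # Replace vanilla case embeddings with perturbed ones.
    # This assumes the perturbed file lists the ALS cells in the same order as `metadata`;
    # the shape check catches count mismatches but cannot detect a reordering.
    assert perturbed_embeddings.shape == embeddings[case_mask].shape, (
        f"{perturbation}: expected {embeddings[case_mask].shape}, got {perturbed_embeddings.shape}"
    )
    embeddings[case_mask] = perturbed_embeddings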
    
    # Compute metrics for this perturbation
    adata_pert, silhouette_pert, purity_pert, distances_pert = umap_and_compute_metrics(
        embeddings, metadata
    )
    
    # Store per-case metrics in the results dictionary
    results_dict[f'Silhouette_{perturbation}'] = silhouette_pert
    results_dict[f'Distance_{perturbation}'] = distances_pert
    results_dict[f'Purity_{perturbation}'] = purity_pert
    
    # Store global metrics for summary
    global_results.append({
        'Perturbation': perturbation,
        'Mean Silhouette': silhouette_pert.mean(),
        'Mean Purity': purity_pert.mean(),
        'Mean Distance': distances_pert.mean()
    })

# Create per-case results DataFrame
per_case_df = pd.DataFrame(results_dict)

# Add cell type information to per-case results
per_case_df['CellType'] = case_metadata['CellType'].values

# Create global summary DataFrame
global_df = pd.DataFrame(global_results)



## Aggregate Results by Cell Type

Compute cell-type-specific metrics to understand which perturbations work best for different cell types.
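The next cell aggregates the wide per-case table, which has one column per metric and perturbation. An alternative is to reshape to long format with `DataFrame.melt` and compute every delta in one pass. The sketch below demonstrates this on a hypothetical toy table; the cell-type and gene names are placeholders. Unlike the notebook, it uses a single sign convention (perturbed - unperturbed) for all metrics. Under that convention, negative deltas are improvements for distance and silhouette, and positive deltas are improvements for purity. It also assumes that metric names contain no underscore.

```python
import pandas as pd

# Hypothetical toy per-case table in the notebook's wide layout (<Metric>_<Perturbation>)
per_case_df = pd.DataFrame({
    "CellType": ["TypeA", "TypeA", "TypeB", "TypeB"],
    "Distance_Unperturbed": [10.0, 12.0, 8.0, 9.0],
    "Distance_GeneA_ko": [9.0, 11.0, 8.5, 9.5],
    "Distance_GeneB": [10.5, 12.5, 7.0, 8.0],
})

# Wide -> long: one row per (cell, metric/perturbation column)
long_df = per_case_df.melt(id_vars="CellType", var_name="Column", value_name="Value")
# Split only on the first "_" so perturbation names like "GeneA_ko" stay intact
long_df[["Metric", "Perturbation"]] = long_df["Column"].str.split("_", n=1, expand=True)

# Mean per (cell type, metric), one column per perturbation
means = (
    long_df.groupby(["CellType", "Metric", "Perturbation"])["Value"]
    .mean()
    .unstack("Perturbation")
)

# Change relative to the baseline, always perturbed - unperturbed
delta = means.drop(columns="Unperturbed").sub(means["Unperturbed"], axis=0)
print(delta)
```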

In [5]:
# Aggregate results by cell type to identify cell-type-specific effects

# Get all metric column names
silhouette_cols = [col for col in per_case_df.columns if col.startswith('Silhouette_')]
distance_cols = [col for col in per_case_df.columns if col.startswith('Distance_')]
purity_cols = [col for col in per_case_df.columns if col.startswith('Purity_')]

# Group by cell type and compute mean for each metric
celltype_aggregated = per_case_df.groupby('CellType').agg({
    **{col: 'mean' for col in silhouette_cols},
    **{col: 'mean' for col in distance_cols},
    **{col: 'mean' for col in purity_cols}
}).reset_index()

# Reorder columns for better readability
ordered_cols = ['CellType'] + silhouette_cols + distance_cols + purity_cols
celltype_aggregated = celltype_aggregated[ordered_cols]

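# Compute per-cell-type changes relative to the unperturbed baseline.
# Distance and purity deltas are signed so that positive = improvement;
# the silhouette delta is the raw change (perturbed - unperturbed), so negative = improvement.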

# Create improvement dataframes
silhouette_improvements = pd.DataFrame({'CellType': celltype_aggregated['CellType']})
distance_improvements = pd.DataFrame({'CellType': celltype_aggregated['CellType']})
purity_improvements = pd.DataFrame({'CellType': celltype_aggregated['CellType']})

# Calculate improvements for each perturbation
for col in silhouette_cols:
    if col != 'Silhouette_Unperturbed':
        pert_name = col.replace('Silhouette_', '')
        # Lower silhouette = cases closer to controls, so store the raw change
        # (perturbed - unperturbed); negative values indicate improvement
        silhouette_improvements[pert_name] = celltype_aggregated[col] - celltype_aggregated['Silhouette_Unperturbed']

for col in distance_cols:
    if col != 'Distance_Unperturbed':
        pert_name = col.replace('Distance_', '')
        # Lower distance = better, so improvement is (unperturbed - perturbed)
        distance_improvements[pert_name] = celltype_aggregated['Distance_Unperturbed'] - celltype_aggregated[col]

for col in purity_cols:
    if col != 'Purity_Unperturbed':
        pert_name = col.replace('Purity_', '')
        # Higher purity = better, so improvement is (perturbed - unperturbed)
        purity_improvements[pert_name] = celltype_aggregated[col] - celltype_aggregated['Purity_Unperturbed']

# Find best perturbation for each cell type based on distance improvement
perturbation_names = [col for col in distance_improvements.columns if col != 'CellType']
best_pert_per_celltype = []

for idx, row in distance_improvements.iterrows():
    celltype = row['CellType']
    improvements = {pert: row[pert] for pert in perturbation_names}
    
    # Find perturbation with maximum distance improvement
    best_pert = max(improvements, key=improvements.get)
    best_improvement = improvements[best_pert]
    
    best_pert_per_celltype.append({
        'CellType': celltype,
        'Best_Perturbation': best_pert,
        'Distance_Improvement': best_improvement,
        'Silhouette_Improvement': silhouette_improvements.iloc[idx][best_pert],
        'Purity_Improvement': purity_improvements.iloc[idx][best_pert]
    })

best_pert_df = pd.DataFrame(best_pert_per_celltype)

## Results: Best Perturbation per Cell Type

Show which perturbation is most effective for each cell type, judged by the change in mean silhouette coefficient relative to the unperturbed baseline.

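**Note**: A lower silhouette coefficient indicates that cases are closer to controls (better therapeutic effect). The reported "improvement" is perturbed minus unperturbed, so more negative values are better.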
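The loop in the next cell picks, for each cell type, the perturbation with the most negative silhouette change. The same selection can be written with `DataFrame.idxmin(axis=1)`. Here is a sketch on a hypothetical toy table laid out like `silhouette_improvements`; the cell-type and gene names are placeholders.

```python
import pandas as pd

# Hypothetical toy values laid out like `silhouette_improvements`:
# one row per cell type, one column per perturbation, values = perturbed - unperturbed
silhouette_improvements = pd.DataFrame({
    "CellType": ["TypeA", "TypeB", "TypeC"],
    "GeneA_ko": [-0.00005, -0.0002, -0.0003],
    "GeneB": [-0.0001, -0.0001, 0.0001],
})

deltas = silhouette_improvements.set_index("CellType")
best = pd.DataFrame({
    # Most negative change = cases moved closest to controls
    "Best_Perturbation": deltas.idxmin(axis=1),
    "Silhouette_Change": deltas.min(axis=1),
})
print(best)
```

This returns a table instead of printing inside a loop, which makes it easy to join with the distance-based `best_pert_df`.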

In [6]:
# Display best perturbation per cell type based on silhouette coefficient
# Lower silhouette = better (cases more similar to controls)

print("="*120)
print("BEST PERTURBATION PER CELL TYPE (based on Silhouette coefficient improvement)")
print("="*120)

for idx, row in silhouette_improvements.iterrows():
    celltype = row['CellType']
    improvements = {pert: row[pert] for pert in perturbation_names}
    
    # Find perturbation with minimum (most negative) silhouette improvement
    # Negative improvement means silhouette decreased (cases closer to controls)
    best_pert = min(improvements, key=improvements.get)
    best_improvement = improvements[best_pert]
    
    print(f"Cell Type: {celltype}, Best Perturbation (Silhouette): {best_pert}, "
          f"Improvement: {best_improvement:.4f}")

BEST PERTURBATION PER CELL TYPE (based on Silhouette coefficient improvement)
Cell Type: 5HT3aR, Best Perturbation (Silhouette): SOD1_ko, Improvement: -0.0001
Cell Type: Astro, Best Perturbation (Silhouette): FUS, Improvement: -0.0000
Cell Type: Endo, Best Perturbation (Silhouette): SOD1_ko, Improvement: -0.0001
Cell Type: Fibro, Best Perturbation (Silhouette): FUS, Improvement: -0.0000
Cell Type: L2_L3, Best Perturbation (Silhouette): SOD1_ko, Improvement: -0.0002
Cell Type: L3_L5, Best Perturbation (Silhouette): SOD1_ko, Improvement: -0.0001
Cell Type: L4_L5, Best Perturbation (Silhouette): SOD1_ko, Improvement: -0.0002
Cell Type: L4_L6, Best Perturbation (Silhouette): SOD1_ko, Improvement: -0.0002
Cell Type: L5, Best Perturbation (Silhouette): SOD1_ko, Improvement: -0.0002
Cell Type: L5_L6, Best Perturbation (Silhouette): FUS_ko, Improvement: -0.0002
Cell Type: L6, Best Perturbation (Silhouette): SOD1_ko, Improvement: -0.0003
Cell Type: Micro, Best Perturbation (Silhouette): SOD1_ko

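Knocking out SOD1 (`SOD1_ko`) gives the most negative silhouette change in 9 of the 12 cell types shown (FUS is best for Astro and Fibro, and FUS_ko for L5_L6). However, every reported change is at most 0.0003 in magnitude, so the differences between perturbations are very small and this ranking may not be robust.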
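One possible way to judge whether such small changes are distinguishable from zero is a paired Wilcoxon signed-rank test on the per-case values for one cell type and perturbation. This check is not performed in this notebook. The sketch below uses synthetic data; with real data, the inputs would be two `Silhouette_*` columns of `per_case_df` restricted to one cell type. Two caveats apply. Per-case silhouettes are not independent, because each one depends on every other cell. Testing many cell-type and perturbation pairs also calls for a multiple-comparison correction. Treat the p-value as a rough guide only.

```python
import numpy as np
from scipy.stats import wilcoxon

rng = np.random.default_rng(0)

# Hypothetical per-case silhouettes for one cell type, before and after a perturbation
unperturbed = rng.normal(0.10, 0.05, size=40)
perturbed = unperturbed - 0.0002 + rng.normal(0.0, 0.00005, size=40)

# Paired, two-sided test of whether the per-case changes are centered at zero
result = wilcoxon(perturbed, unperturbed)
median_change = np.median(perturbed - unperturbed)
print(f"median change = {median_change:.5f}, p = {result.pvalue:.2e}")
```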