In [None]:
!pip -q install scanpy anndata scikit-bio pingouin scikit-learn matplotlib pandas numpy scipy

import pandas as pd
import numpy as np
import gzip
import scipy.io
from scipy.sparse import csr_matrix
import scanpy as sc
from sklearn.decomposition import TruncatedSVD
from sklearn.metrics import pairwise_distances
from skbio.stats.distance import mantel, DistanceMatrix
from scipy.stats import spearmanr
import pingouin as pg
import matplotlib.pyplot as plt
import os, json

from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:

# File paths
# Every input lives in one Google Drive folder (Drive is mounted at /content/drive above).
base_path = "/content/drive/MyDrive/dataset-gene-embed/"
# Logical input name -> absolute path; the loader functions look files up by these keys.
files = {
    'barcodes': base_path + "GSE133344_filtered_barcodes.tsv.gz",  # read with gzip, one barcode per line
    'cell_identities': base_path + "GSE133344_filtered_cell_identities.csv.gz",  # gzip CSV with a 'cell_barcode' column
    'matrix': base_path + "GSE133344_filtered_matrix.mtx",  # Matrix Market file, transposed to cells x genes on load
    'genes': base_path + "GSE133344_filtered_genes.tsv",  # headerless TSV: ensembl_id, gene_symbol
    'generain': base_path + "GeneRAIN-vec.200d.txt"  # text embeddings, first line is metadata
}

# Load GSE133344 data
def load_gse133344_data():
    """Load the GSE133344 dataset and assemble it into an AnnData object.

    Reads the gene table, barcode list, cell-identity table and count matrix
    from the paths in the module-level ``files`` mapping, then keeps only the
    cells present in both the count matrix and the cell-identity table.

    Returns:
        AnnData with cells as rows and genes as columns. ``obs`` carries the
        columns of the cell-identity table (joined on ``cell_barcode``) and
        ``var`` carries ``ensembl_id``. Cells stay in barcode-file order.
    """
    print("Loading GSE133344 dataset...")

    # Load genes
    genes = pd.read_csv(files['genes'], sep='\t', header=None, names=['ensembl_id', 'gene_symbol'])

    # Load barcodes
    with gzip.open(files['barcodes'], 'rt') as f:
        barcodes = [line.strip() for line in f]

    # Load cell identities
    cell_identities = pd.read_csv(files['cell_identities'], compression='gzip')

    # Load expression matrix
    matrix = scipy.io.mmread(files['matrix']).T.tocsr()  # Transpose to cells x genes

    # Create AnnData object
    adata = sc.AnnData(X=matrix)
    adata.obs_names = barcodes
    adata.var_names = genes['gene_symbol'].values
    adata.var['ensembl_id'] = genes['ensembl_id'].values
    # Gene symbols repeat in this file (AnnData emits a duplicate-var-names
    # warning); make them unique so symbol-based column lookups are unambiguous.
    adata.var_names_make_unique()

    # Add cell annotations (align by barcode)
    cell_identities = cell_identities.set_index('cell_barcode')
    annotated = set(cell_identities.index)
    # Filter in barcode-file order. list(set & set) would yield an order that
    # changes between interpreter runs because of string hash randomisation.
    common_cells = [barcode for barcode in barcodes if barcode in annotated]
    adata = adata[common_cells].copy()
    adata.obs = adata.obs.join(cell_identities.loc[common_cells])

    print(f"Loaded dataset: {adata.n_obs} cells x {adata.n_vars} genes")
    print(f"Perturbation guides: {adata.obs['guide_identity'].nunique()}")

    return adata

# Parse combinatorial targets
def parse_combinatorial_targets(guide_identity):
    """
    Return the distinct target genes named in a guide identity string.

    Format: GENE1_GENE2__GENE1_GENE2 or GENE1_NegCtrl0__GENE1_NegCtrl0

    Only the text before the first '__' is inspected. It is split on '_',
    tokens containing 'NegCtrl' and empty tokens are dropped, and repeated
    names are removed while keeping first-seen order. A missing value
    (NaN/None) yields an empty list.
    """
    if pd.isna(guide_identity):
        return []

    prefix, _, _ = guide_identity.partition('__')
    kept = [token for token in prefix.split('_') if token and 'NegCtrl' not in token]

    # dict keys keep insertion order, so this is an order-preserving de-duplication
    return list(dict.fromkeys(kept))

def categorize_perturbations(adata):
    """Annotate every cell with its parsed targets and a perturbation class.

    Adds three columns to ``adata.obs``: ``parsed_targets`` (the list returned
    by ``parse_combinatorial_targets`` for the cell's ``guide_identity``),
    ``n_targets`` (its length) and ``perturbation_type`` ('control', 'single',
    'dual' or 'multi' for 0, 1, 2 or more than 2 targets; 'unknown' otherwise).
    Prints the class counts and returns the same ``adata`` object.
    """
    adata.obs['parsed_targets'] = adata.obs['guide_identity'].apply(parse_combinatorial_targets)
    adata.obs['n_targets'] = adata.obs['parsed_targets'].apply(len)

    # Start from the fallback label, then overwrite it class by class.
    adata.obs['perturbation_type'] = 'unknown'
    target_counts = adata.obs['n_targets']
    labelled_masks = (
        ('control', target_counts == 0),
        ('single', target_counts == 1),
        ('dual', target_counts == 2),
        ('multi', target_counts > 2),
    )
    for label, mask in labelled_masks:
        adata.obs.loc[mask, 'perturbation_type'] = label

    print("Perturbation type distribution:")
    print(adata.obs['perturbation_type'].value_counts())

    return adata

# Load GeneRAIN embeddings
def load_generain_embeddings(file_path):
    """Read a GeneRAIN embedding text file into a gene-indexed DataFrame.

    The first line is treated as metadata: it is printed and not parsed.
    Every later line is split on whitespace; the first token is the gene name
    and the remaining tokens are converted to floats. Lines with fewer than
    two tokens are ignored.

    Returns:
        DataFrame indexed by gene name with columns ``generain_0``,
        ``generain_1``, ...
    """
    print(f"Loading GeneRAIN embeddings from {file_path}...")

    with open(file_path, 'r') as handle:
        # Skip metadata line
        header = handle.readline().strip()
        print(f"Metadata: {header}")

        tokenised = (raw.strip().split() for raw in handle)
        rows = [fields for fields in tokenised if len(fields) >= 2]

    names = [fields[0] for fields in rows]
    vectors = [[float(value) for value in fields[1:]] for fields in rows]

    df = pd.DataFrame(vectors, index=names)
    df.columns = [f"generain_{i}" for i in range(df.shape[1])]

    print(f"Loaded GeneRAIN: {df.shape[0]} genes, {df.shape[1]} dimensions")
    return df

# Create combined embeddings for gene combinations
def create_combined_embeddings(target_combinations, generain_embeddings, method='average'):
    """
    Create embeddings for gene combinations.
    Methods: 'average', 'concatenate', 'element_wise_product'

    Args:
        target_combinations: mapping of combination name -> list of target
            gene names (an empty list denotes a control).
        generain_embeddings: DataFrame indexed by gene name, one embedding
            vector per row.
        method: how the embeddings of a multi-gene combination are merged.

    Returns:
        DataFrame indexed by combination name. Rows are D-dimensional, or
        2*D-dimensional for 'concatenate' (zero-padded when fewer than two
        genes are available, truncated to the first two otherwise).
        Combinations whose targets are all missing from the embedding table
        are omitted.

    Raises:
        ValueError: if ``method`` is not one of the supported names.
    """
    supported_methods = ('average', 'concatenate', 'element_wise_product')
    if method not in supported_methods:
        # Previously an unknown method silently dropped every multi-gene row.
        raise ValueError(f"Unknown method {method!r}; expected one of {supported_methods}")

    dim = generain_embeddings.shape[1]
    max_genes = 2  # Assume max 2 genes per combination
    combined_embeddings = {}

    for combo_name, gene_list in target_combinations.items():
        available_genes = [g for g in gene_list if g in generain_embeddings.index]

        if len(gene_list) > 0 and not available_genes:
            continue  # No genes found in embeddings

        gene_embeds = [generain_embeddings.loc[g].values for g in available_genes]

        if method == 'concatenate':
            # Every row (control, single, dual) gets the same 2*D width so the
            # resulting frame has no NaN padding.
            slots = gene_embeds[:max_genes]
            slots.extend(np.zeros(dim) for _ in range(max_genes - len(slots)))
            combined_embeddings[combo_name] = np.concatenate(slots)
        elif not gene_embeds:
            # Control - use zero vector
            combined_embeddings[combo_name] = np.zeros(dim)
        elif len(gene_embeds) == 1:
            # Single available gene: its embedding is the result for any operator
            combined_embeddings[combo_name] = gene_embeds[0]
        elif method == 'average':
            combined_embeddings[combo_name] = np.mean(gene_embeds, axis=0)
        else:  # 'element_wise_product'
            combined_embeddings[combo_name] = np.prod(gene_embeds, axis=0)

    # Convert to DataFrame
    combo_embedding_df = pd.DataFrame.from_dict(combined_embeddings, orient='index')
    print(f"Created combined embeddings: {combo_embedding_df.shape[0]} combinations, {combo_embedding_df.shape[1]} dimensions")

    return combo_embedding_df

# Pseudobulk aggregation for combinations
def make_combinatorial_pseudobulk(adata, min_cells=50):
    """Sum counts over all cells that share a guide identity.

    Groups ``adata.obs`` by ``guide_identity``, skips groups with fewer than
    ``min_cells`` cells, and sums ``adata.X`` over the remaining groups.

    Returns:
        (pseudobulk_df, meta_df): summed counts indexed by guide identity with
        one column per ``adata.var_names`` entry, and a metadata frame with
        ``guide_identity``, ``n_cells``, ``targets`` and ``perturbation_type``
        (the latter two taken from the group's first cell).
    """
    cells_by_guide = adata.obs.groupby('guide_identity').indices

    kept_guides = []
    summed_rows = []
    meta_data = []

    for guide, cell_idx in cells_by_guide.items():
        n_cells = len(cell_idx)
        if n_cells < min_cells:
            continue

        # Sparse sums come back 2-D; flatten to one row per guide.
        summed_rows.append(np.array(adata.X[cell_idx].sum(axis=0)).ravel())
        kept_guides.append(guide)

        # Target annotations are derived from the guide, so the first cell stands in for the group.
        representative = adata.obs_names[cell_idx[0]]
        meta_data.append({
            'guide_identity': guide,
            'n_cells': n_cells,
            'targets': adata.obs.loc[representative, 'parsed_targets'],
            'perturbation_type': adata.obs.loc[representative, 'perturbation_type']
        })

    pseudobulk_df = pd.DataFrame(summed_rows, columns=adata.var_names)
    pseudobulk_df.index = kept_guides
    meta_df = pd.DataFrame(meta_data)

    print(f"Created pseudobulk for {len(pseudobulk_df)} combinations")
    return pseudobulk_df, meta_df

def compute_combinatorial_effects(pseudobulk_df, meta_df):
    """Compute per-gene z-scored perturbation effects relative to controls.

    Args:
        pseudobulk_df: summed counts, one row per guide identity and one
            column per gene.
        meta_df: metadata with ``guide_identity`` and ``perturbation_type``
            columns; rows typed 'control' define the reference profile.

    Returns:
        DataFrame shaped like ``pseudobulk_df`` holding log1p(CPM) minus the
        mean control log1p(CPM), standardised per gene (sample standard
        deviation). Genes with no variation across rows, and any other
        non-finite result, are reported as 0 rather than NaN/inf.

    Raises:
        ValueError: if ``meta_df`` contains no control rows.
    """
    # Normalize to CPM and log-transform
    lib_sizes = pseudobulk_df.sum(axis=1)
    cpm = pseudobulk_df.div(lib_sizes, axis=0) * 1e6
    logcpm = np.log1p(cpm)

    # Identify controls
    is_control = meta_df['perturbation_type'] == 'control'
    control_guides = meta_df.loc[is_control, 'guide_identity'].tolist()

    if not control_guides:
        raise ValueError("No control perturbations found")

    # Compute mean control profile
    mean_control = logcpm.loc[control_guides].mean(axis=0)

    # Compute effects (perturbation - control)
    effects = logcpm.subtract(mean_control, axis=1)

    # Z-score standardize each gene
    centered = effects.subtract(effects.mean(axis=0), axis=1)
    spread = centered.std(axis=0)
    # A gene that never varies has std 0; dividing would give 0/0 = NaN.
    spread = spread.replace(0, 1.0)
    effects_zscore = centered.divide(spread, axis=1)
    effects_zscore = effects_zscore.replace([np.inf, -np.inf], 0).fillna(0)

    print(f"Computed effects for {len(effects_zscore)} combinations")
    return effects_zscore

# Distance calculations and statistics (same as before)
def pairwise_cosine(M):
    """Return the pairwise cosine-distance matrix of ``M``.

    If ``M`` contains any NaN, NaNs are replaced by 0.0 before the distances
    are computed; otherwise ``M`` is passed through untouched.
    """
    cleaned = np.nan_to_num(M, nan=0.0) if np.isnan(M).any() else M
    return pairwise_distances(cleaned, metric='cosine')

def mantel_and_spearman(D1, D2, n_perms=999):
    """Correlate two square distance matrices.

    Args:
        D1, D2: square distance matrices of identical shape.
        n_perms: number of Mantel permutations; 0 (or less) skips the Mantel
            test.

    Returns:
        (rho, p, r_mantel, p_mantel). ``rho``/``p`` are the Spearman
        correlation and p-value over the strict upper triangles. When the
        Mantel test is skipped or fails, the Spearman values are returned in
        its place (a failure is reported on stdout).
    """
    if np.isnan(D1).any() or np.isnan(D2).any():
        # NaN distances are treated as distance 1.0 in both matrices
        D1 = np.nan_to_num(D1, nan=1.0)
        D2 = np.nan_to_num(D2, nan=1.0)

    iu = np.triu_indices_from(D1, k=1)
    rho, p = spearmanr(D1[iu], D2[iu])

    if n_perms <= 0:
        return rho, p, rho, p

    try:
        ids = [f"X{i}" for i in range(D1.shape[0])]
        m1, m2 = DistanceMatrix(D1, ids=ids), DistanceMatrix(D2, ids=ids)
        r_m, p_m, _ = mantel(m1, m2, method='spearman', permutations=n_perms)
    except Exception as e:
        # Was a bare `except:`; that also trapped KeyboardInterrupt/SystemExit
        # and hid the fact that the Mantel values are a Spearman fallback.
        print(f"Mantel test failed: {e}")
        return rho, p, rho, p
    return rho, p, r_m, p_m

def retrieval_metrics(Dg, Dp, ks=(5,10,20)):
    """Mean overlap of the k nearest neighbours under two distance matrices.

    For every row, the k smallest entries of ``Dg`` and of ``Dp`` (diagonal
    excluded by setting it to +inf on copies) are compared; the score is the
    size of their intersection divided by k.

    Returns:
        dict mapping each k in ``ks`` to the mean score over all rows.
    """
    emb_dist = Dg.copy()
    pert_dist = Dp.copy()
    for matrix in (emb_dist, pert_dist):
        np.fill_diagonal(matrix, np.inf)

    scores = {k: [] for k in ks}
    for i in range(Dg.shape[0]):
        by_embedding = np.argsort(emb_dist[i])
        by_perturbation = np.argsort(pert_dist[i])
        for k in ks:
            shared = set(by_embedding[:k]).intersection(by_perturbation[:k])
            scores[k].append(len(shared) / k)

    return {k: float(np.mean(values)) for k, values in scores.items()}

# Plotting functions
def plot_distance_scatter(Dg, Dp, title, out="figs/combinatorial_scatter.png"):
    """Scatter the upper-triangle distances of two matrices and save the figure.

    Args:
        Dg: square distance matrix plotted on the x axis.
        Dp: square distance matrix plotted on the y axis.
        title: first line of the plot title; the Spearman correlation of the
            plotted points is appended on a second line.
        out: output image path. Its parent directory is created if needed.
    """
    # Create the directory of `out` itself; a hard-coded "figs" broke any
    # caller that passed a path in a different, not-yet-existing folder.
    out_dir = os.path.dirname(out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    iu = np.triu_indices_from(Dg, k=1)
    x, y = Dg[iu], Dp[iu]

    rho, p = spearmanr(x, y)
    plt.figure(figsize=(6,5))
    plt.scatter(x, y, s=8, alpha=0.6)
    plt.title(f"{title}\nSpearman ρ={rho:.3f}, p={p:.2e}")
    plt.xlabel("Combined Embedding Distance")
    plt.ylabel("Combinatorial Perturbation Distance")
    plt.tight_layout()
    plt.savefig(out, dpi=200)
    plt.close()

# Main analysis function
def run_combinatorial_analysis():
    """Run the full combinatorial perturbation analysis.

    Loads the dataset and GeneRAIN embeddings, builds per-guide pseudobulk
    effects and averaged target embeddings, and compares the two cosine
    distance matrices (Spearman, Mantel, precision@k). Writes
    ``results/combinatorial_metrics.json`` and a scatter plot.

    Returns:
        dict of metrics, or None when fewer than 10 guides have both an
        effect profile and an embedding.
    """
    os.makedirs("results", exist_ok=True)

    print("=== COMBINATORIAL PERTURBATION ANALYSIS ===\n")

    # Load data
    adata = load_gse133344_data()
    adata = categorize_perturbations(adata)

    # Load GeneRAIN embeddings
    generain_emb = load_generain_embeddings(files['generain'])

    # Create pseudobulk
    pseudobulk_df, meta_df = make_combinatorial_pseudobulk(adata, min_cells=50)

    # Compute effects
    effects_df = compute_combinatorial_effects(pseudobulk_df, meta_df)

    # Create target combination mapping
    target_combinations = dict(zip(meta_df['guide_identity'], meta_df['targets']))

    # Create combined embeddings for the combinations we have data for
    available_combinations = {
        g: target_combinations[g] for g in effects_df.index if g in target_combinations
    }

    # NOTE(review): control guides stay in the comparison; create_combined_embeddings
    # gives them an all-zero embedding. Confirm that this is intended.
    combined_embeddings = create_combined_embeddings(available_combinations, generain_emb, method='average')

    # Align datasets in effects_df row order. list(set & set) produced an
    # order that changed between runs (string hash randomisation), so the
    # matrices were not reproducible row-for-row.
    embedded_guides = set(combined_embeddings.index)
    common_guides = [g for g in effects_df.index if g in embedded_guides]
    print(f"Common perturbation combinations: {len(common_guides)}")

    if len(common_guides) < 10:
        print("ERROR: Too few combinations for analysis")
        return None

    effects_aligned = effects_df.loc[common_guides]
    embeddings_aligned = combined_embeddings.loc[common_guides]

    # Compute distance matrices
    Dp = pairwise_cosine(effects_aligned.values)
    Dg = pairwise_cosine(embeddings_aligned.values)

    # Statistical analysis
    rho, p, r_mantel, p_mantel = mantel_and_spearman(Dg, Dp, n_perms=999)
    retrieval_results = retrieval_metrics(Dg, Dp)

    # Results
    results = {
        'analysis_type': 'combinatorial',
        'n_combinations': len(common_guides),
        'embedding_method': 'average',
        'spearman_rho': float(rho),
        'spearman_p': float(p),
        'mantel_r': float(r_mantel),
        'mantel_p': float(p_mantel),
        'precision_at': retrieval_results
    }

    # Save results
    with open("results/combinatorial_metrics.json", "w") as f:
        json.dump(results, f, indent=2)

    # Plot
    plot_distance_scatter(Dg, Dp, "Combinatorial Perturbation Analysis")

    # Print results
    print("\n=== COMBINATORIAL ANALYSIS RESULTS ===")
    print(f"Combinations analyzed: {len(common_guides)}")
    print(f"Correlation: ρ = {rho:.3f}, p = {p:.2e}")
    print(f"Mantel test: r = {r_mantel:.3f}, p = {p_mantel:.2e}")
    print(f"Precision@10: {retrieval_results[10]:.3f}")
    print(f"Random baseline: {10/(len(common_guides)-1):.3f}")

    return results

# Run analysis
# Executes the whole pipeline when the cell runs; `results` is the metrics dict,
# or None when too few combinations could be aligned.
results = run_combinatorial_analysis()

=== COMBINATORIAL PERTURBATION ANALYSIS ===

Loading GSE133344 dataset...


  utils.warn_names_duplicates("var")


Loaded dataset: 111445 cells x 33694 genes
Perturbation guides: 290
Perturbation type distribution:
perturbation_type
single     57831
dual       41759
control    11855
Name: count, dtype: int64
Loading GeneRAIN embeddings from /content/drive/MyDrive/dataset-gene-embed/GeneRAIN-vec.200d.txt...
Metadata: 31769 200
Loaded GeneRAIN: 31769 genes, 200 dimensions
Created pseudobulk for 290 combinations
Computed effects for 290 combinations
Created combined embeddings: 284 combinations, 200 dimensions
Common perturbation combinations: 284

=== COMBINATORIAL ANALYSIS RESULTS ===
Combinations analyzed: 284
Correlation: ρ = 0.081, p = 1.76e-59
Mantel test: r = 0.081, p = 1.00e-03
Precision@10: 0.328
Random baseline: 0.035


In [None]:
!pip -q install scanpy anndata scikit-bio pingouin scikit-learn matplotlib pandas numpy scipy

import pandas as pd
import numpy as np
import gzip
import scipy.io
from scipy.sparse import csr_matrix
import scanpy as sc
from sklearn.decomposition import TruncatedSVD
from sklearn.metrics import pairwise_distances
from skbio.stats.distance import mantel, DistanceMatrix
from scipy.stats import spearmanr
import pingouin as pg
import matplotlib.pyplot as plt
import os, json
from scipy import sparse

from google.colab import drive
drive.mount('/content/drive')

# File paths
# Every input lives in one Google Drive folder (Drive is mounted at /content/drive above).
base_path = "/content/drive/MyDrive/dataset-gene-embed/"
# Logical input name -> absolute path; the loader functions look files up by these keys.
files = {
    'barcodes': base_path + "GSE133344_filtered_barcodes.tsv.gz",  # read with gzip, one barcode per line
    'cell_identities': base_path + "GSE133344_filtered_cell_identities.csv.gz",  # gzip CSV with a 'cell_barcode' column
    'matrix': base_path + "GSE133344_filtered_matrix.mtx",  # Matrix Market file, transposed to cells x genes on load
    'genes': base_path + "GSE133344_filtered_genes.tsv",  # headerless TSV: ensembl_id, gene_symbol
    'generain': base_path + "GeneRAIN-vec.200d.txt"  # text embeddings, first line is metadata
}

# Load GSE133344 data
def load_gse133344_data():
    """Load the GSE133344 dataset and assemble it into an AnnData object.

    Reads the gene table, barcode list, cell-identity table and count matrix
    from the paths in the module-level ``files`` mapping, then keeps only the
    cells present in both the count matrix and the cell-identity table.

    Returns:
        AnnData with cells as rows and genes as columns. ``obs`` carries the
        columns of the cell-identity table (joined on ``cell_barcode``) and
        ``var`` carries ``ensembl_id``. Cells stay in barcode-file order.
    """
    print("Loading GSE133344 dataset...")

    # Load genes
    genes = pd.read_csv(files['genes'], sep='\t', header=None, names=['ensembl_id', 'gene_symbol'])

    # Load barcodes
    with gzip.open(files['barcodes'], 'rt') as f:
        barcodes = [line.strip() for line in f]

    # Load cell identities
    cell_identities = pd.read_csv(files['cell_identities'], compression='gzip')

    # Load expression matrix
    matrix = scipy.io.mmread(files['matrix']).T.tocsr()  # Transpose to cells x genes

    # Create AnnData object
    adata = sc.AnnData(X=matrix)
    adata.obs_names = barcodes
    adata.var_names = genes['gene_symbol'].values
    adata.var['ensembl_id'] = genes['ensembl_id'].values
    adata.var_names_make_unique()  # Fix duplicate names

    # Add cell annotations (align by barcode)
    cell_identities = cell_identities.set_index('cell_barcode')
    annotated = set(cell_identities.index)
    # Filter in barcode-file order. list(set & set) would yield an order that
    # changes between interpreter runs because of string hash randomisation.
    common_cells = [barcode for barcode in barcodes if barcode in annotated]
    adata = adata[common_cells].copy()
    adata.obs = adata.obs.join(cell_identities.loc[common_cells])

    print(f"Loaded dataset: {adata.n_obs} cells x {adata.n_vars} genes")
    print(f"Perturbation guides: {adata.obs['guide_identity'].nunique()}")
    print(f"Gemgroups: {sorted(adata.obs['gemgroup'].unique())}")

    return adata

# Parse combinatorial targets
def parse_combinatorial_targets(guide_identity):
    """Parse a guide identity into its ordered list of distinct target genes.

    Only the part before the first '__' is used. It is split on '_'; empty
    pieces and pieces containing 'NegCtrl' are discarded, and each remaining
    name is kept once, in first-seen order. Missing values give ``[]``.
    """
    if pd.isna(guide_identity):
        return []

    head = guide_identity.split('__', 1)[0]

    seen = set()
    ordered_targets = []
    for piece in head.split('_'):
        # Skip blanks, control placeholders and names already recorded
        if not piece or 'NegCtrl' in piece or piece in seen:
            continue
        seen.add(piece)
        ordered_targets.append(piece)

    return ordered_targets

def categorize_perturbations(adata):
    """Label each cell as control, single, dual or multi perturbation.

    Writes ``parsed_targets`` (output of ``parse_combinatorial_targets`` on
    ``guide_identity``), ``n_targets`` (its length) and ``perturbation_type``
    into ``adata.obs``, prints the distribution of types, and returns the
    same ``adata``. Cells matching none of the rules keep the label 'unknown'.
    """
    adata.obs['parsed_targets'] = adata.obs['guide_identity'].apply(parse_combinatorial_targets)
    adata.obs['n_targets'] = adata.obs['parsed_targets'].apply(len)

    # Default label first; each rule below then overwrites its own subset.
    adata.obs['perturbation_type'] = 'unknown'
    n_targets = adata.obs['n_targets']
    rules = [
        (n_targets == 0, 'control'),
        (n_targets == 1, 'single'),
        (n_targets == 2, 'dual'),
        (n_targets > 2, 'multi'),
    ]
    for selector, type_name in rules:
        adata.obs.loc[selector, 'perturbation_type'] = type_name

    print("Perturbation type distribution:")
    print(adata.obs['perturbation_type'].value_counts())

    return adata

# Load GeneRAIN embeddings
def load_generain_embeddings(file_path):
    """Parse a whitespace-separated GeneRAIN embedding file.

    The first line is printed as metadata and otherwise ignored. Each later
    line contributes one row: first token = gene name, remaining tokens =
    float components. Lines with fewer than two tokens are skipped.

    Returns:
        DataFrame indexed by gene name; columns are named ``generain_<i>``.
    """
    print(f"Loading GeneRAIN embeddings from {file_path}...")

    symbols = []
    matrix = []

    with open(file_path, 'r') as source:
        # Skip metadata line
        metadata = source.readline().strip()
        print(f"Metadata: {metadata}")

        for raw_line in source:
            tokens = raw_line.strip().split()
            if len(tokens) >= 2:
                symbols.append(tokens[0])
                matrix.append(list(map(float, tokens[1:])))

    df = pd.DataFrame(matrix, index=symbols)
    df.columns = [f"generain_{i}" for i in range(df.shape[1])]

    print(f"Loaded GeneRAIN: {df.shape[0]} genes, {df.shape[1]} dimensions")
    return df

# IMPROVED: Gemgroup-matched pseudobulk
def make_gemgroup_matched_combinatorial_pseudobulk(adata, min_cells=20):
    """Create one pseudobulk profile per (guide identity, gemgroup) pair.

    Groups ``adata.obs`` by ``guide_identity`` and ``gemgroup``, skips groups
    with fewer than ``min_cells`` cells, and sums ``adata.X`` over the rest.

    Returns:
        (pseudobulk_df, meta_df): summed counts with a
        (guide_identity, gemgroup) MultiIndex and one column per
        ``adata.var_names`` entry, and a metadata frame with columns
        ``guide_identity``, ``gemgroup``, ``n_cells``, ``targets`` and
        ``perturbation_type``. Both are empty (but well-formed) when no group
        reaches ``min_cells``.
    """
    print("Creating gemgroup-matched pseudobulks for combinatorial analysis...")

    # Group by guide_identity AND gemgroup
    groups = adata.obs.groupby(['guide_identity', 'gemgroup']).indices

    pseudobulk_data = []
    meta_data = []

    X = adata.X

    for (guide, gemgroup), cell_idx in groups.items():
        if len(cell_idx) < min_cells:
            continue

        # Sum counts for this guide-gemgroup combination. asarray + ravel
        # flattens both the 2-D result of a sparse sum and a dense 1-D sum.
        summed_counts = np.asarray(X[cell_idx].sum(axis=0)).ravel()

        first_cell = adata.obs_names[cell_idx[0]]
        pseudobulk_data.append(summed_counts)
        meta_data.append({
            'guide_identity': guide,
            'gemgroup': gemgroup,
            'n_cells': len(cell_idx),
            'targets': adata.obs.loc[first_cell, 'parsed_targets'],
            'perturbation_type': adata.obs.loc[first_cell, 'perturbation_type']
        })

    # Explicit columns keep the frame well-formed when nothing passed the
    # min_cells filter; the column lookups below used to raise KeyError then.
    meta_columns = ['guide_identity', 'gemgroup', 'n_cells', 'targets', 'perturbation_type']
    meta_df = pd.DataFrame(meta_data, columns=meta_columns)

    pseudobulk_df = pd.DataFrame(pseudobulk_data, columns=adata.var_names)
    pseudobulk_df.index = pd.MultiIndex.from_frame(meta_df[['guide_identity', 'gemgroup']])

    print(f"Created pseudobulk for {len(pseudobulk_df)} guide-gemgroup combinations")
    print(f"Gemgroups with data: {meta_df['gemgroup'].nunique()}")
    print(f"Control combinations: {(meta_df['perturbation_type'] == 'control').sum()}")

    return pseudobulk_df, meta_df

# IMPROVED: Gemgroup-matched effect computation
def compute_gemgroup_matched_combinatorial_effects(pseudobulk_df, meta_df):
    """Compute z-scored guide effects against controls from the same gemgroup.

    Args:
        pseudobulk_df: summed counts indexed by a (guide, gemgroup)
            MultiIndex, one column per gene.
        meta_df: metadata with ``guide_identity`` and ``perturbation_type``
            columns; rows typed 'control' identify the control guides.

    Returns:
        DataFrame with one row per non-control guide (sorted by guide name)
        and one column per gene. Each row is the guide's log1p(CPM) minus the
        mean control log1p(CPM) of the same gemgroup, averaged over the
        guide's gemgroups, then standardised per gene (population standard
        deviation; constant genes and non-finite values become 0).

    Raises:
        ValueError: if there are no control guides, or no effect could be
            computed for any guide.
    """
    print("Computing gemgroup-matched effects for combinatorial analysis...")

    # Normalize to CPM and log-transform
    lib_sizes = pseudobulk_df.sum(axis=1)
    cpm = pseudobulk_df.div(lib_sizes, axis=0) * 1e6
    logcpm = np.log1p(cpm)

    # Identify control guides
    control_guides = set(meta_df[meta_df['perturbation_type'] == 'control']['guide_identity'])

    if len(control_guides) == 0:
        raise ValueError("No control perturbations found")

    print(f"Found {len(control_guides)} control guides")

    # Get unique guides (excluding controls for effects)
    all_guides = sorted(set(meta_df['guide_identity']))
    target_guides = [g for g in all_guides if g not in control_guides]
    # Fixed (sorted) control order: iterating the set directly made the float
    # summation order, and so the last bits of the result, vary between runs.
    ordered_controls = [g for g in all_guides if g in control_guides]

    print(f"Computing effects for {len(target_guides)} target guides")

    # The control mean depends only on the gemgroup, so compute it once per
    # gemgroup instead of once per (guide, gemgroup).
    control_mean_by_gemgroup = {}

    def _matched_control_mean(gemgroup):
        """Return the mean control profile for ``gemgroup``, or None if it has no control."""
        if gemgroup in control_mean_by_gemgroup:
            return control_mean_by_gemgroup[gemgroup]

        profiles = []
        for control_guide in ordered_controls:
            try:
                profiles.append(logcpm.loc[(control_guide, gemgroup)])
            except KeyError:
                continue  # this control guide has no profile in this gemgroup

        if not profiles:
            mean_profile = None
        elif len(profiles) == 1:
            mean_profile = profiles[0]
        else:
            mean_profile = pd.concat(profiles, axis=1).mean(axis=1)

        control_mean_by_gemgroup[gemgroup] = mean_profile
        return mean_profile

    effects = {}

    for guide in target_guides:
        try:
            guide_profiles = logcpm.loc[guide]
            if isinstance(guide_profiles, pd.Series):
                guide_profiles = pd.DataFrame([guide_profiles.values],
                                            index=[guide_profiles.name],
                                            columns=logcpm.columns)

            guide_effects = []

            # For each gemgroup where this guide appears
            for gemgroup in guide_profiles.index:
                control_mean = _matched_control_mean(gemgroup)
                if control_mean is None:
                    print(f"Warning: No control found for guide {guide} in gemgroup {gemgroup}")
                    continue

                # Compute effect: target - matched control
                guide_effects.append(guide_profiles.loc[gemgroup] - control_mean)

            if guide_effects:
                # Average effects across gemgroups
                if len(guide_effects) == 1:
                    mean_effect = guide_effects[0]
                else:
                    mean_effect = pd.concat(guide_effects, axis=1).mean(axis=1)
                effects[guide] = mean_effect.values

        except KeyError:
            # Best effort: a guide listed in meta_df but absent from the
            # pseudobulk index is reported and skipped.
            print(f"Warning: Could not process guide {guide}")
            continue

    if len(effects) == 0:
        raise ValueError("No effects computed")

    # Create effects DataFrame
    effects_df = pd.DataFrame.from_dict(effects, orient='index', columns=logcpm.columns)

    # Z-score standardize each gene (with improved handling)
    effects_mean = effects_df.mean(axis=0)
    effects_std = effects_df.std(axis=0, ddof=0)
    effects_std = effects_std.replace(0, 1.0)  # Avoid division by zero

    effects_zscore = effects_df.subtract(effects_mean, axis=1).divide(effects_std, axis=1)
    effects_zscore = effects_zscore.replace([np.inf, -np.inf], 0).fillna(0)

    print(f"Computed effects for {len(effects_zscore)} guides")
    print(f"Effect range: [{effects_zscore.values.min():.3f}, {effects_zscore.values.max():.3f}]")

    return effects_zscore

# IMPROVED: Multiple combination operators
def create_combined_embeddings_multi_method(target_combinations, generain_embeddings, method='average'):
    """Create combination embeddings with a selectable combination operator.

    Args:
        target_combinations: mapping of combination name -> list of target
            gene names (an empty list denotes a control).
        generain_embeddings: DataFrame indexed by gene name, one embedding
            vector per row.
        method: 'average', 'sum', 'hadamard' (element-wise product) or
            'concatenate'.

    Returns:
        DataFrame indexed by combination name. Rows are D-dimensional, or
        2*D-dimensional for 'concatenate' (zero-padded when fewer than two
        genes are available, truncated to the first two otherwise).
        Combinations whose targets are all missing from the embedding table
        are omitted.

    Raises:
        ValueError: if ``method`` is not one of the supported names.
    """
    reducers = {'average': np.mean, 'sum': np.sum, 'hadamard': np.prod}
    if method != 'concatenate' and method not in reducers:
        # Previously an unknown method silently dropped every multi-gene row.
        raise ValueError(
            f"Unknown method {method!r}; expected one of {sorted(reducers) + ['concatenate']}")

    dim = generain_embeddings.shape[1]
    combined_embeddings = {}

    for combo_name, gene_list in target_combinations.items():
        available_genes = [g for g in gene_list if g in generain_embeddings.index]

        if len(gene_list) > 0 and not available_genes:
            continue  # Skip if no genes found

        if method == 'concatenate':
            # Pad to consistent length (assume max 2 genes) for EVERY row,
            # including controls and singles, so the frame has no NaN padding.
            slots = [generain_embeddings.loc[g].values for g in available_genes[:2]]
            slots.extend(np.zeros(dim) for _ in range(2 - len(slots)))
            combined_embeddings[combo_name] = np.concatenate(slots)
        elif not available_genes:
            # Control - use zero vector
            combined_embeddings[combo_name] = np.zeros(dim)
        elif len(available_genes) == 1:
            # Single available gene: its embedding is the result for any operator
            combined_embeddings[combo_name] = generain_embeddings.loc[available_genes[0]].values
        else:
            # Multiple genes - combine embeddings
            gene_embeds = np.stack([generain_embeddings.loc[g].values for g in available_genes])
            combined_embeddings[combo_name] = reducers[method](gene_embeds, axis=0)

    # Convert to DataFrame
    combo_embedding_df = pd.DataFrame.from_dict(combined_embeddings, orient='index')
    print(f"Created {method} embeddings: {combo_embedding_df.shape[0]} combinations, {combo_embedding_df.shape[1]} dimensions")

    return combo_embedding_df

def test_combination_operators(target_combinations, generain_embeddings):
    """Build combination embeddings once per operator.

    Runs ``create_combined_embeddings_multi_method`` with 'average', 'sum' and
    'hadamard' and returns a dict mapping each method name to its embedding
    DataFrame.
    """
    print("Testing different combination operators...")

    embeddings_by_method = {}
    for operator_name in ('average', 'sum', 'hadamard'):
        print(f"  Creating {operator_name} combinations...")
        embeddings_by_method[operator_name] = create_combined_embeddings_multi_method(
            target_combinations, generain_embeddings, method=operator_name)

    return embeddings_by_method

# Distance calculations with error handling
def pairwise_cosine(M):
    """Compute cosine distances with comprehensive error handling.

    Args:
        M: 2-D array-like, one vector per row. It is not modified.

    Returns:
        Square matrix of pairwise cosine distances with a zero diagonal.
        NaN/+inf/-inf in ``M`` are replaced by 0/1/-1 first, and zero-norm
        rows receive small random noise (so results for such rows are not
        reproducible bit-for-bit). If the distance computation itself fails,
        the failure is printed and a matrix with 0 on the diagonal and 1
        elsewhere is returned.
    """
    # Work on a float copy: integer input cannot take the float noise below,
    # and the caller's array must stay untouched.
    M = np.nan_to_num(np.asarray(M, dtype=float), nan=0.0, posinf=1.0, neginf=-1.0)

    # Check for zero-norm rows
    row_norms = np.linalg.norm(M, axis=1)
    zero_norm_mask = row_norms == 0

    if zero_norm_mask.any():
        print(f"Warning: {zero_norm_mask.sum()} zero-norm rows detected, adding small noise")
        M[zero_norm_mask] += np.random.normal(0, 1e-6, size=(zero_norm_mask.sum(), M.shape[1]))

    try:
        distances = pairwise_distances(M, metric='cosine')
        np.fill_diagonal(distances, 0.0)
        distances = np.nan_to_num(distances, nan=1.0, posinf=1.0, neginf=0.0)
        return distances
    except Exception as e:
        print(f"Distance computation failed: {e}")
        n = M.shape[0]
        # Fallback must still be a distance matrix: 0 to itself, 1 (the value
        # also used for undefined distances above) to everything else.
        # np.eye(n) had this inverted.
        return 1.0 - np.eye(n)

# Statistical analysis functions
def mantel_and_spearman(D1, D2, n_perms=999):
    """Spearman correlation and Mantel test between two distance matrices.

    Returns ``(rho, p, r_mantel, p_mantel)``. NaNs are replaced by 1.0, the
    strict upper triangles are compared, and non-finite pairs are dropped
    before the Spearman correlation. With no usable pairs the result is
    ``(0.0, 1.0, 0.0, 1.0)``. When ``n_perms`` is not positive, or the Mantel
    test raises, the Spearman values are returned in the Mantel slots.
    """
    contains_nan = np.isnan(D1).any() or np.isnan(D2).any()
    if contains_nan:
        D1, D2 = np.nan_to_num(D1, nan=1.0), np.nan_to_num(D2, nan=1.0)

    upper = np.triu_indices_from(D1, k=1)
    first, second = D1[upper], D2[upper]

    # Remove any remaining invalid values
    usable = np.isfinite(first) & np.isfinite(second)
    if not usable.all():
        print(f"Removing {(~usable).sum()} invalid distance pairs")
        first = first[usable]
        second = second[usable]

    if len(first) == 0:
        return 0.0, 1.0, 0.0, 1.0

    rho, p = spearmanr(first, second)
    spearman_only = (rho, p, rho, p)

    if not n_perms > 0:
        return spearman_only

    try:
        labels = [f"X{i}" for i in range(D1.shape[0])]
        left = DistanceMatrix(D1, ids=labels)
        right = DistanceMatrix(D2, ids=labels)
        r_m, p_m, _ = mantel(left, right, method='spearman', permutations=n_perms)
        return rho, p, r_m, p_m
    except Exception as e:
        print(f"Mantel test failed: {e}")
        return spearman_only

def retrieval_metrics(Dg, Dp, ks=(5,10,20)):
    """Compute neighbour-retrieval metrics between two distance matrices.

    For each row, neighbours are ranked by ``Dg`` and by ``Dp`` (diagonals set
    to +inf on copies). ``precision_at[k]`` is the mean fraction of the top-k
    sets that coincide. ``auprc`` is the mean average precision of the ``Dg``
    ranking against the top-``max(ks)`` neighbours under ``Dp``.

    Returns:
        {"precision_at": {k: float}, "auprc": float}
    """
    emb_dist, pert_dist = Dg.copy(), Dp.copy()
    np.fill_diagonal(emb_dist, np.inf)
    np.fill_diagonal(pert_dist, np.inf)

    overlap_scores = {k: [] for k in ks}
    average_precisions = []

    for i in range(Dg.shape[0]):
        by_embedding = np.argsort(emb_dist[i])
        by_perturbation = np.argsort(pert_dist[i])

        for k in ks:
            shared = set(by_embedding[:k]).intersection(by_perturbation[:k])
            overlap_scores[k].append(len(shared) / k)

        # Average precision calculation
        K = max(ks)
        relevant = set(by_perturbation[:K])
        found = 0
        precision_total = 0.0

        for position, candidate in enumerate(by_embedding, start=1):
            if candidate not in relevant:
                continue
            found += 1
            precision_total += found / position
            if found == K:
                break  # every relevant neighbour has been retrieved

        denominator = K if K > 0 else 1
        average_precisions.append(precision_total / denominator)

    return {
        "precision_at": {k: float(np.mean(values)) for k, values in overlap_scores.items()},
        "auprc": float(np.mean(average_precisions))
    }

# Plotting functions
def plot_distance_scatter(Dg, Dp, title, out="figs/improved_combinatorial_scatter.png"):
    """Plot distance correlation scatter and save it to ``out``.

    Args:
        Dg: square distance matrix plotted on the x axis.
        Dp: square distance matrix plotted on the y axis.
        title: first line of the plot title; the Spearman correlation of the
            plotted points is appended on a second line.
        out: output image path. Its parent directory is created if needed.

    When there are more than 20000 upper-triangle pairs, a random sample of
    20000 is plotted, and the correlation in the title is computed on that
    sample.
    """
    # Create the directory of `out` itself; a hard-coded "figs" broke any
    # caller that passed a path in a different, not-yet-existing folder.
    out_dir = os.path.dirname(out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    iu = np.triu_indices_from(Dg, k=1)
    x, y = Dg[iu], Dp[iu]

    # Sample for plotting if too many points
    if len(x) > 20000:
        idx = np.random.choice(len(x), 20000, replace=False)
        x, y = x[idx], y[idx]

    rho, p = spearmanr(x, y)
    plt.figure(figsize=(6,5))
    plt.scatter(x, y, s=8, alpha=0.6)
    plt.title(f"{title}\nSpearman ρ={rho:.3f}, p={p:.2e}")
    plt.xlabel("Combined Embedding Distance")
    plt.ylabel("Combinatorial Perturbation Distance")
    plt.tight_layout()
    plt.savefig(out, dpi=200)
    plt.close()

def plot_method_comparison(results_by_method, out="figs/method_comparison.png"):
    """Plot comparison of different combination methods as two bar charts.

    Args:
        results_by_method: mapping of method name -> result dict providing
            ``'spearman_rho'`` and ``'precision_at'`` (indexable by 10).
        out: output image path. Its parent directory is created if needed.
    """
    # Create the directory of `out` itself; a hard-coded "figs" broke any
    # caller that passed a path in a different, not-yet-existing folder.
    out_dir = os.path.dirname(out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    methods = list(results_by_method.keys())
    rhos = [results_by_method[m]['spearman_rho'] for m in methods]
    precisions = [results_by_method[m]['precision_at'][10] for m in methods]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

    # Correlation comparison
    ax1.bar(methods, rhos)
    ax1.set_ylabel('Spearman ρ')
    ax1.set_title('Correlation by Method')
    ax1.tick_params(axis='x', rotation=45)

    # Precision comparison
    ax2.bar(methods, precisions)
    ax2.set_ylabel('Precision@10')
    ax2.set_title('Retrieval by Method')
    ax2.tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plt.savefig(out, dpi=200)
    plt.close()

# Main improved analysis function
def run_improved_combinatorial_analysis():
    """Run the combinatorial analysis end to end and return its metrics.

    Pipeline, as implemented below:
      1. Load the dataset and the GeneRAIN embeddings via the module's
         loader helpers.
      2. Build pseudobulk profiles and effects with the
         ``*_gemgroup_matched_combinatorial_*`` helpers (``min_cells=20``).
      3. Build combined embeddings for every operator returned by
         ``test_combination_operators``.
      4. For each operator with at least 10 guides shared with the effects
         table, compute cosine distance matrices, Spearman/Mantel statistics
         (999 permutations) and retrieval metrics.
      5. Write all per-method metrics to
         ``results/improved_combinatorial_metrics.json``, plot the method
         with the highest Spearman rho, and plot a method comparison when
         more than one method was evaluated.

    Returns:
        dict with keys ``'analysis_type'``, ``'improvements_applied'``,
        ``'best_combination_method'`` (``None`` if no method was evaluated)
        and ``'results_by_method'``.
    """
    # Reference values printed in the final "comparison to original" summary.
    original_rho = 0.081
    original_p_at_10 = 0.328

    os.makedirs("results", exist_ok=True)
    os.makedirs("figs", exist_ok=True)

    print("=== IMPROVED COMBINATORIAL ANALYSIS ===\n")

    # Load data
    adata = load_gse133344_data()
    adata = categorize_perturbations(adata)
    generain_emb = load_generain_embeddings(files['generain'])

    # IMPROVEMENT 1: Gemgroup-matched pseudobulk and effects
    pseudobulk_df, meta_df = make_gemgroup_matched_combinatorial_pseudobulk(adata, min_cells=20)
    effects_df = compute_gemgroup_matched_combinatorial_effects(pseudobulk_df, meta_df)

    # Map guide -> targets; for a repeated guide the last row wins.
    target_combinations = dict(zip(meta_df['guide_identity'], meta_df['targets']))

    # Filter to available guides in effects
    available_combinations = {
        guide: target_combinations[guide]
        for guide in effects_df.index.tolist()
        if guide in target_combinations
    }

    print(f"Available combinations for analysis: {len(available_combinations)}")

    # IMPROVEMENT 2: Test different combination operators
    embedding_results = test_combination_operators(available_combinations, generain_emb)

    results_by_method = {}
    best_method = None
    # Start below every attainable rho so that even rho == -1 can be selected;
    # a NaN rho never compares greater and therefore never becomes "best".
    best_rho = -np.inf
    # (Dg, Dp) of the current best method, kept so the final scatter plot
    # does not have to recompute the distance matrices.
    best_distances = None

    for method, combined_embeddings in embedding_results.items():
        print(f"\n--- Testing {method} combination method ---")

        # Align datasets. Sorting makes the row order independent of
        # per-process string hash randomisation, so results are reproducible.
        common_guides = sorted(set(effects_df.index) & set(combined_embeddings.index))
        print(f"Common guides for {method}: {len(common_guides)}")

        if len(common_guides) < 10:
            print(f"Too few guides for {method}, skipping")
            continue

        effects_aligned = effects_df.loc[common_guides]
        embeddings_aligned = combined_embeddings.loc[common_guides]

        # Compute distances
        Dp = pairwise_cosine(effects_aligned.values)
        Dg = pairwise_cosine(embeddings_aligned.values)

        # Statistics
        rho, p, r_mantel, p_mantel = mantel_and_spearman(Dg, Dp, n_perms=999)
        retrieval_results = retrieval_metrics(Dg, Dp)

        results_by_method[method] = {
            'n_combinations': len(common_guides),
            'spearman_rho': float(rho),
            'spearman_p': float(p),
            'mantel_r': float(r_mantel),
            'mantel_p': float(p_mantel),
            'precision_at': retrieval_results['precision_at'],
            'auprc': retrieval_results['auprc']
        }

        print(f"Correlation: ρ = {rho:.3f}, p = {p:.2e}")
        print(f"Mantel test: r = {r_mantel:.3f}, p = {p_mantel:.2e}")
        print(f"Precision@10: {retrieval_results['precision_at'][10]:.3f}")
        print(f"Random baseline: {10/(len(common_guides)-1):.3f}")

        # Track best method
        if rho > best_rho:
            best_rho = rho
            best_method = method
            best_distances = (Dg, Dp)

    # Save comprehensive results
    final_results = {
        'analysis_type': 'improved_combinatorial',
        'improvements_applied': [
            'gemgroup_matched_controls',
            'multiple_combination_operators'
        ],
        'best_combination_method': best_method,
        'results_by_method': results_by_method
    }

    with open("results/improved_combinatorial_metrics.json", "w") as f:
        json.dump(final_results, f, indent=2)

    # Generate plots
    if best_distances is not None:
        Dg_best, Dp_best = best_distances
        plot_distance_scatter(Dg_best, Dp_best,
                              f"Improved Combinatorial Analysis ({best_method})")

    # Plot method comparison
    if len(results_by_method) > 1:
        plot_method_comparison(results_by_method)

    # Print final results
    print("\n=== IMPROVED COMBINATORIAL RESULTS ===")
    print(f"Best combination method: {best_method}")

    best_res = results_by_method.get(best_method)
    if best_res is not None:
        best_p_at_10 = best_res['precision_at'][10]
        print(f"Combinations analyzed: {best_res['n_combinations']}")
        print(f"Correlation: ρ = {best_res['spearman_rho']:.3f}, p = {best_res['spearman_p']:.2e}")
        print(f"Mantel test: r = {best_res['mantel_r']:.3f}, p = {best_res['mantel_p']:.2e}")
        print(f"Precision@10: {best_p_at_10:.3f}")
        print(f"vs random: {10/(best_res['n_combinations']-1):.3f}")
        print(f"AUPRC: {best_res['auprc']:.3f}")

        # Show improvement over original
        print("\nComparison to original analysis:")
        print(f"  Original: ρ = {original_rho:.3f}, P@10 = {original_p_at_10:.3f}")
        print(f"  Improved: ρ = {best_res['spearman_rho']:.3f}, P@10 = {best_p_at_10:.3f}")
        print(f"  Correlation improvement: {(best_res['spearman_rho']/original_rho - 1)*100:+.1f}%")
        print(f"  Retrieval improvement: {(best_p_at_10/original_p_at_10 - 1)*100:+.1f}%")

    return final_results

# Run the improved analysis
# NOTE: the two messages below are only informational; the analysis is
# started immediately afterwards in this same cell, and `results` receives
# the dict returned by run_improved_combinatorial_analysis().
print("Ready to run improved combinatorial analysis.")
print("Execute: results = run_improved_combinatorial_analysis()")
results = run_improved_combinatorial_analysis()

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m2.1/2.1 MB[0m [31m91.4 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m51.7 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/169.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m169.9/169.9 kB[0m [31m15.6 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/9.7 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━[0m [32m7.3/9.7 MB[0m [31m223.6 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m9.7/9.7 MB[0m [31m206.2 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━