In [None]:
!pip -q install scanpy anndata scikit-bio pingouin scikit-learn matplotlib pandas numpy scipy

import scanpy as sc
import pandas as pd
import numpy as np
from scipy import sparse
from sklearn.decomposition import TruncatedSVD
from sklearn.metrics import pairwise_distances
from skbio.stats.distance import mantel, DistanceMatrix
from scipy.stats import spearmanr
import pingouin as pg
import matplotlib.pyplot as plt
import os, json

from google.colab import drive
drive.mount('/content/drive')

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━[0m [32m1.4/2.1 MB[0m [31m42.5 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m40.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m169.9/169.9 kB[0m [31m20.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.7/9.7 MB[0m [31m52.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m204.4/204.4 kB[0m [31m24.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.2/58.2 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.3/12.3 MB[0m [31m54.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
!pip -q install scanpy anndata scikit-bio pingouin scikit-learn matplotlib pandas numpy scipy

import pandas as pd
import numpy as np
import gzip
import scipy.io
from scipy.sparse import csr_matrix
import scanpy as sc
from sklearn.decomposition import TruncatedSVD
from sklearn.metrics import pairwise_distances
from skbio.stats.distance import mantel, DistanceMatrix
from scipy.stats import spearmanr
import pingouin as pg
import matplotlib.pyplot as plt
import os, json

from google.colab import drive
drive.mount('/content/drive')

# Locations of the GSE133344 inputs and the GeneRAIN vectors on the mounted Drive
base_path = "/content/drive/MyDrive/dataset-gene-embed/"
files = dict(
    barcodes=f"{base_path}GSE133344_filtered_barcodes.tsv.gz",
    cell_identities=f"{base_path}GSE133344_filtered_cell_identities.csv.gz",
    matrix=f"{base_path}GSE133344_filtered_matrix.mtx",
    genes=f"{base_path}GSE133344_filtered_genes.tsv",
    generain=f"{base_path}GeneRAIN-vec.200d.txt",
)

# Load GSE133344 data
def load_gse133344_data():
    """Load the GSE133344 files listed in ``files`` into an AnnData object.

    Reads the gene table, the barcode list, the cell-identity table and the
    Matrix Market count matrix, builds a cells x genes AnnData, and attaches
    the cell-identity columns to ``obs`` for every barcode present in both
    the matrix and the identity table.

    Returns:
        AnnData restricted to the shared barcodes, kept in their original
        matrix order, with unique ``var_names`` and an ``ensembl_id`` column
        in ``var``.
    """
    print("Loading GSE133344 dataset...")

    # Gene table: two unnamed tab-separated columns
    genes = pd.read_csv(files['genes'], sep='\t', header=None, names=['ensembl_id', 'gene_symbol'])

    # One barcode per line
    with gzip.open(files['barcodes'], 'rt') as f:
        barcodes = [line.strip() for line in f]

    cell_identities = pd.read_csv(files['cell_identities'], compression='gzip')

    # Transpose so that rows are cells and columns are genes
    matrix = scipy.io.mmread(files['matrix']).T.tocsr()

    adata = sc.AnnData(X=matrix)
    adata.obs_names = barcodes
    adata.var_names = genes['gene_symbol'].values
    adata.var['ensembl_id'] = genes['ensembl_id'].values
    # Gene symbols repeat in this table; label-based lookups downstream need
    # unique names, otherwise one symbol selects several columns.
    adata.var_names_make_unique()

    # Align annotations by barcode. A boolean mask (rather than a set
    # intersection) keeps the cell order deterministic across runs.
    cell_identities = cell_identities.set_index('cell_barcode')
    in_both = adata.obs_names.isin(cell_identities.index)
    adata = adata[in_both].copy()
    adata.obs = adata.obs.join(cell_identities.loc[adata.obs_names])

    print(f"Loaded dataset: {adata.n_obs} cells x {adata.n_vars} genes")
    print(f"Perturbation guides: {adata.obs['guide_identity'].nunique()}")

    return adata

# Parse and filter for single-gene perturbations
def parse_single_gene_targets(guide_identity):
    """Classify a guide identity string as a single gene, a control, or neither.

    Only the text before the first ``'__'`` is inspected. It is split on
    ``'_'``; empty pieces and pieces containing ``'NegCtrl'`` are ignored.

    Returns:
        The gene name when exactly one piece remains, ``'CONTROL'`` when none
        remain, and ``None`` for missing input or more than one piece.
    """
    if pd.isna(guide_identity):
        return None

    prefix = guide_identity.partition('__')[0]
    gene_names = [piece for piece in prefix.split('_')
                  if piece and 'NegCtrl' not in piece]

    n_genes = len(gene_names)
    if n_genes == 0:
        return 'CONTROL'  # nothing but negative-control guides
    if n_genes == 1:
        return gene_names[0]
    return None  # more than one targeted gene: not a single-gene perturbation

def filter_single_gene_data(adata):
    """Keep only control cells and cells carrying a single-gene perturbation.

    Adds an ``obs['single_target']`` column to the AnnData that was passed in
    (derived from ``obs['guide_identity']``), prints summary counts, and
    returns a copy holding only the cells whose target is not missing.
    """
    print("Filtering for single-gene perturbations...")

    # Label every cell; multi-gene and missing guides come back as None
    adata.obs['single_target'] = adata.obs['guide_identity'].apply(parse_single_gene_targets)

    labelled = adata.obs['single_target'].notna()
    adata_filtered = adata[labelled].copy()
    print(f"After filtering: {adata_filtered.n_obs} cells")

    # Cells per target, largest first
    per_target = adata_filtered.obs['single_target'].value_counts()
    print(f"Number of unique targets: {len(per_target)}")
    print("Top 10 targets:")
    print(per_target.head(10))

    n_control = (adata_filtered.obs['single_target'] == 'CONTROL').sum()
    n_perturbed = adata_filtered.n_obs - n_control
    print(f"Control cells: {n_control}")
    print(f"Perturbed cells: {n_perturbed}")

    return adata_filtered

# Load GeneRAIN embeddings (same as combinatorial)
def load_generain_embeddings(file_path):
    """Load GeneRAIN vectors from a whitespace-separated text file.

    The first line is treated as metadata and only echoed. Every later line
    with at least two fields is read as ``<gene> <v0> <v1> ...``; shorter
    lines are skipped.

    Args:
        file_path: Path of the embedding text file.

    Returns:
        DataFrame indexed by gene name with columns ``generain_0`` ...
        ``generain_{d-1}``. When a gene name occurs more than once only its
        first vector is kept, so label lookups return one row per gene.

    Raises:
        ValueError: If a vector field cannot be parsed as a float.
    """
    print(f"Loading GeneRAIN embeddings from {file_path}...")

    genes = []
    embeddings = []

    with open(file_path, 'r') as f:
        # Metadata line, not a vector
        first_line = f.readline().strip()
        print(f"Metadata: {first_line}")

        for line in f:
            parts = line.split()
            if len(parts) < 2:
                continue

            genes.append(parts[0])
            embeddings.append([float(x) for x in parts[1:]])

    df = pd.DataFrame(embeddings, index=genes)
    df.columns = [f"generain_{i}" for i in range(df.shape[1])]

    # Repeated names would make `.loc[gene]` return several rows and break
    # the one-row-per-target alignment done by callers.
    repeated = df.index.duplicated(keep="first")
    if repeated.any():
        df = df[~repeated]

    print(f"Loaded GeneRAIN: {df.shape[0]} genes, {df.shape[1]} dimensions")
    return df

# Enhanced gene mapping with aliases
def create_gene_alias_mapping():
    """Return the lookup table from target names to embedding gene symbols.

    Most entries map a symbol to itself; the only real aliases are
    ``E2A -> TCF3``, ``HEB -> TCF12`` and ``PU1 -> SPI1``.
    """
    # Symbols that map to themselves
    mapping = {symbol: symbol for symbol in (
        'KLF1', 'TBX3', 'TBX2', 'CEBPE',
        'RUNX1T1', 'ETS2', 'CNN1', 'SLC4A1',
        'UBASH3B', 'OSR2', 'ARID1A', 'BCORL1',
        'FOSB', 'SET', 'BAK1', 'FOXA3', 'FOXL2',
        'TP73', 'HES7', 'IRF1',
        # Transcription factors
        'SOX6', 'GATA1', 'TAL1', 'LMO2', 'LDB1',
    )}

    # Alternative names resolved to another symbol
    mapping.update({'E2A': 'TCF3', 'HEB': 'TCF12'})

    mapping.update({symbol: symbol for symbol in (
        'TCF3', 'TCF12', 'MYB', 'RUNX1',
        'FLI1', 'ERG', 'ELK1', 'ELK4',
        'ETS1', 'SPI1', 'SPIB',
    )})
    mapping['PU1'] = 'SPI1'

    return mapping

def robust_gene_mapping(target_genes, embedding_genes):
    """Map perturbation target names onto gene names in the embedding.

    Each target is resolved by the first rule that succeeds:

    1. exact match with an embedding gene,
    2. alias table from ``create_gene_alias_mapping``,
    3. case-insensitive match.

    Args:
        target_genes: Re-iterable, sized collection of target names.
        embedding_genes: Re-iterable collection of embedding gene names.

    Returns:
        dict of ``target -> embedding gene`` holding only the targets that
        could be resolved, in the order they appear in ``target_genes``.
    """
    alias_map = create_gene_alias_mapping()

    # Sets/dicts give O(1) membership; the embedding list has tens of
    # thousands of entries, so scanning it once per target is wasteful.
    embedding_set = set(embedding_genes)
    embed_upper = {g.upper(): g for g in embedding_genes}

    target_to_embed = {}
    n_direct = n_alias = n_case = 0

    for target in target_genes:
        if target in target_to_embed:
            continue  # repeated target name, already resolved

        alias = alias_map.get(target)
        if target in embedding_set:
            target_to_embed[target] = target
            n_direct += 1
        elif alias is not None and alias in embedding_set:
            target_to_embed[target] = alias
            n_alias += 1
        elif target.upper() in embed_upper:
            target_to_embed[target] = embed_upper[target.upper()]
            n_case += 1

    # Each counter reports how many targets were resolved by that rule, so
    # the three add up to "Total mapped".
    print(f"Gene mapping results:")
    print(f"  - Direct matches: {n_direct}")
    print(f"  - Alias matches: {n_alias}")
    print(f"  - Case matches: {n_case}")
    print(f"  - Total mapped: {len(target_to_embed)}/{len(target_genes)}")

    unmapped = set(target_genes) - set(target_to_embed.keys())
    if unmapped:
        print(f"  - Unmapped: {sorted(list(unmapped))[:10]}")

    return target_to_embed

# Pseudobulk aggregation
def make_single_gene_pseudobulk(adata, min_cells=30):
    """Sum counts over all cells sharing the same ``obs['single_target']``.

    Targets with fewer than ``min_cells`` cells are dropped first.

    Returns:
        ``(pseudobulk_df, meta_df)``: summed counts with one row per target
        (index = target name, columns = ``var_names``), and a table with
        ``target``, ``n_cells`` and ``is_control`` for each row.
    """
    cells_per_target = adata.obs['single_target'].value_counts()

    # Targets with enough cells to aggregate
    sufficient_targets = cells_per_target[cells_per_target >= min_cells].index.tolist()
    enough_cells = adata.obs['single_target'].isin(sufficient_targets)
    adata_sufficient = adata[enough_cells].copy()

    print(f"Targets with ≥{min_cells} cells: {len(sufficient_targets)}")

    # target -> positional indices of its cells
    groups = adata_sufficient.obs.groupby('single_target').indices
    counts = adata_sufficient.X

    pseudobulk_data = [
        np.array(counts[cell_idx].sum(axis=0)).ravel()
        for cell_idx in groups.values()
    ]
    meta_data = [
        {
            'target': target,
            'n_cells': len(cell_idx),
            'is_control': target == 'CONTROL'
        }
        for target, cell_idx in groups.items()
    ]

    pseudobulk_df = pd.DataFrame(pseudobulk_data, columns=adata_sufficient.var_names)
    pseudobulk_df.index = [entry['target'] for entry in meta_data]
    meta_df = pd.DataFrame(meta_data)

    print(f"Created pseudobulk for {len(pseudobulk_df)} targets")
    return pseudobulk_df, meta_df

def compute_single_gene_effects(pseudobulk_df, meta_df):
    """Turn pseudobulk counts into per-gene z-scored effects against CONTROL.

    Rows are scaled to counts per million and ``log1p``-transformed, the
    ``'CONTROL'`` row is subtracted from every other row, and each gene
    (column) is standardised across targets.

    Args:
        pseudobulk_df: Counts with one row per target, including a
            ``'CONTROL'`` row.
        meta_df: Accepted for call compatibility; not read here.

    Returns:
        DataFrame (non-control targets x genes) of z-scores with NaN and
        infinite values replaced by 0.
    """
    # log1p(CPM) per pseudobulk profile
    depth = pseudobulk_df.sum(axis=1)
    logcpm = np.log1p(pseudobulk_df.div(depth, axis=0) * 1e6)

    baseline = logcpm.loc['CONTROL']

    # Difference from the control profile for every non-control row
    deltas = {
        target: (logcpm.loc[target] - baseline).values
        for target in pseudobulk_df.index
        if target != 'CONTROL'
    }
    effects_df = pd.DataFrame.from_dict(deltas, orient='index', columns=pseudobulk_df.columns)

    # Per-gene standardisation; a zero spread is replaced by 1 so constant
    # genes come out as 0 instead of dividing by zero
    center = effects_df.mean(axis=0)
    spread = effects_df.std(axis=0, ddof=0).replace(0, 1.0)
    effects_zscore = (effects_df - center) / spread

    # Remaining non-finite entries are neutralised
    effects_zscore = effects_zscore.replace([np.inf, -np.inf], 0).fillna(0)

    print(f"Computed effects for {len(effects_zscore)} targets")
    print(f"Data range: [{effects_zscore.values.min():.3f}, {effects_zscore.values.max():.3f}]")

    return effects_zscore

# Coexpression baseline
def coexpression_embedding(adata, targets, n_components=200):
    """Build a coexpression embedding for the target genes via truncated SVD.

    The cells x target-genes expression block is mean-centred per gene and
    decomposed; each gene's embedding is its loading on the components.

    Args:
        adata: AnnData whose ``var_names`` contain the target genes.
        targets: Collection of gene names to embed.
        n_components: Upper bound on the embedding dimension.

    Returns:
        DataFrame indexed by the target genes found in ``adata.var_names``
        (in ``var_names`` order) with columns ``coexpr_0`` ...

    Raises:
        ValueError: If none of ``targets`` occurs in ``adata.var_names``.
    """
    wanted = set(targets)
    target_gene_idx = [i for i, gene in enumerate(adata.var_names) if gene in wanted]
    if not target_gene_idx:
        raise ValueError("None of the requested targets are present in adata.var_names")
    target_genes = [adata.var_names[i] for i in target_gene_idx]

    target_expression = adata.X[:, target_gene_idx]

    # Densify (only the target columns) and mean-centre each gene
    X = target_expression.toarray() if hasattr(target_expression, 'toarray') else target_expression
    X = X - X.mean(axis=0)

    # Handle potential numerical issues
    X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)

    # Bound by the columns actually selected: that count differs from
    # len(targets) whenever a target is absent from var_names.
    n_components_actual = min(n_components, X.shape[1] - 1, X.shape[0] - 1)
    if n_components_actual <= 0:
        n_components_actual = 1

    svd = TruncatedSVD(n_components=n_components_actual, random_state=42)
    try:
        svd.fit(X)
        gene_embeddings = svd.components_.T
    except Exception as e:
        print(f"SVD failed, using random embeddings: {e}")
        # One row per selected gene so the frame below can always be built
        gene_embeddings = np.random.normal(0, 1, (len(target_genes), n_components_actual))

    return pd.DataFrame(gene_embeddings, index=target_genes,
                       columns=[f"coexpr_{i}" for i in range(gene_embeddings.shape[1])])

# Distance calculations and statistics
def pairwise_cosine(M):
    """Return the pairwise cosine-distance matrix between the rows of ``M``.

    NaN entries are replaced by 0 and +/-inf by +/-1 before the computation.
    Rows with zero norm get a tiny random perturbation so their cosine is
    defined. The input array is never modified.

    Args:
        M: 2-D array-like, one vector per row.

    Returns:
        Square float array with a zero diagonal. If the distance computation
        raises, a neutral matrix (0 on the diagonal, 1 elsewhere) is returned.
    """
    # Float copy: integer input would otherwise reject the in-place noise below
    M = np.nan_to_num(np.asarray(M, dtype=float), nan=0.0, posinf=1.0, neginf=-1.0)

    # Zero-norm rows would cause a division by zero in the cosine
    row_norms = np.linalg.norm(M, axis=1)
    zero_norm_mask = row_norms == 0

    if zero_norm_mask.any():
        print(f"Warning: {zero_norm_mask.sum()} zero-norm rows detected, adding small noise")
        M[zero_norm_mask] += np.random.normal(0, 1e-6, size=(zero_norm_mask.sum(), M.shape[1]))

    try:
        distances = pairwise_distances(M, metric='cosine')
        # Force an exact zero diagonal and clean up residual numerical issues
        np.fill_diagonal(distances, 0.0)
        distances = np.nan_to_num(distances, nan=1.0, posinf=1.0, neginf=0.0)
        return distances
    except Exception as e:
        print(f"Distance computation failed: {e}")
        # An identity matrix is not a distance matrix (self-distance 1,
        # everything else 0); the neutral fallback is its complement.
        n = M.shape[0]
        return 1.0 - np.eye(n)

def mantel_and_spearman(D1, D2, n_perms=999):
    """Correlate two distance matrices with Spearman and a Mantel test.

    Args:
        D1, D2: Square distance matrices of the same shape.
        n_perms: Number of Mantel permutations; ``0`` or less skips the
            Mantel test.

    Returns:
        ``(rho, p, r_mantel, p_mantel)``. ``rho``/``p`` come from Spearman on
        the upper-triangle entries. When the Mantel test is skipped or fails,
        the last two values repeat ``rho`` and ``p``.
    """
    if np.isnan(D1).any() or np.isnan(D2).any():
        D1 = np.nan_to_num(D1, nan=1.0)
        D2 = np.nan_to_num(D2, nan=1.0)

    # Each unordered pair once, diagonal excluded
    iu = np.triu_indices_from(D1, k=1)
    rho, p = spearmanr(D1[iu], D2[iu])

    if n_perms > 0:
        try:
            ids = [f"X{i}" for i in range(D1.shape[0])]
            m1, m2 = DistanceMatrix(D1, ids=ids), DistanceMatrix(D2, ids=ids)
            r_m, p_m, _ = mantel(m1, m2, method='spearman', permutations=n_perms)
            return rho, p, r_m, p_m
        except Exception as e:
            # Make the substitution visible: the values returned below are
            # NOT permutation-based.
            print(f"Mantel test failed ({e}); returning Spearman values in its place")
    return rho, p, rho, p

def partial_corr_vectorized(Dg, Dp, Dc):
    """Spearman partial correlation of ``Dg`` and ``Dp`` controlling for ``Dc``.

    The upper-triangle entries of the three square matrices are flattened
    into columns; rows containing NaN are dropped before the test.

    Returns:
        ``(r, p)`` as floats. ``(0.0, 1.0)`` is returned when no pairs remain
        or when the computation raises.
    """
    iu = np.triu_indices_from(Dg, k=1)
    df = pd.DataFrame({"Dg": Dg[iu], "Dp": Dp[iu], "Dc": Dc[iu]})
    df = df.dropna()

    if len(df) == 0:
        return 0.0, 1.0

    try:
        res = pg.partial_corr(data=df, x="Dg", y="Dp", covar="Dc", method="spearman")
        # The result is a one-row table; take the scalar explicitly rather
        # than calling float() on a Series.
        return float(res["r"].iloc[0]), float(res["p-val"].iloc[0])
    except Exception as e:
        print(f"Partial correlation failed ({e}); returning r=0, p=1")
        return 0.0, 1.0

def retrieval_metrics(Dg, Dp, ks=(5,10,20)):
    """Mean overlap between nearest neighbours in two distance matrices.

    For every row the ``k`` closest other items under ``Dg`` are compared
    with the ``k`` closest under ``Dp``; the shared fraction is averaged over
    rows.

    Returns:
        dict mapping each ``k`` in ``ks`` to the mean overlap fraction.
    """
    embed_d, pert_d = Dg.copy(), Dp.copy()
    # An infinite self-distance pushes each item to the end of its own ranking
    np.fill_diagonal(embed_d, np.inf)
    np.fill_diagonal(pert_d, np.inf)

    order_g = np.argsort(embed_d, axis=1)
    order_p = np.argsort(pert_d, axis=1)

    summary = {}
    for k in ks:
        shared = [
            len(set(row_g[:k]) & set(row_p[:k])) / k
            for row_g, row_p in zip(order_g, order_p)
        ]
        summary[k] = float(np.mean(shared))
    return summary

# Plotting
def plot_distance_scatter(Dg, Dp, title, out="figs/single_gene_scatter.png"):
    """Save a scatter plot of pairwise embedding vs. perturbation distances.

    Args:
        Dg: Square matrix of gene-embedding distances (x axis).
        Dp: Square matrix of perturbation distances (y axis).
        title: First line of the plot title; the Spearman statistics are
            appended on a second line.
        out: Destination image path. Its parent directory is created when it
            does not exist yet.
    """
    # Create the directory the file is actually written to, not a fixed one
    out_dir = os.path.dirname(out)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # One point per unordered pair (upper triangle, diagonal excluded)
    iu = np.triu_indices_from(Dg, k=1)
    x, y = Dg[iu], Dp[iu]

    rho, p = spearmanr(x, y)
    plt.figure(figsize=(6,5))
    plt.scatter(x, y, s=8, alpha=0.6)
    plt.title(f"{title}\nSpearman ρ={rho:.3f}, p={p:.2e}")
    plt.xlabel("Gene Embedding Distance")
    plt.ylabel("Single-Gene Perturbation Distance")
    plt.tight_layout()
    plt.savefig(out, dpi=200)
    plt.close()

# Main analysis function
def run_single_gene_analysis():
    """Run the single-gene perturbation analysis end to end.

    Steps, in order: load and filter the dataset, load the GeneRAIN
    embeddings, build per-target pseudobulk profiles and z-scored effects,
    map targets onto embedding genes, build a coexpression baseline, compute
    cosine-distance matrices for the three representations, and compare them
    (Spearman/Mantel, partial correlation, neighbour overlap).

    Side effects: creates ``results/``, writes
    ``results/single_gene_metrics.json``, saves the scatter plot through
    ``plot_distance_scatter`` and prints a summary.

    Returns:
        The results dict that was written to JSON, or ``None`` when fewer
        than 10 targets could be mapped to the embedding.
    """
    os.makedirs("results", exist_ok=True)

    print("=== SINGLE-GENE PERTURBATION ANALYSIS ===\n")

    # Load the dataset and keep control + single-gene cells
    adata = load_gse133344_data()
    adata_filtered = filter_single_gene_data(adata)

    # Load GeneRAIN embeddings
    generain_emb = load_generain_embeddings(files['generain'])

    # One summed-count profile per target (targets with < 30 cells dropped)
    pseudobulk_df, meta_df = make_single_gene_pseudobulk(adata_filtered, min_cells=30)

    # Z-scored log-CPM difference from the CONTROL profile
    effects_df = compute_single_gene_effects(pseudobulk_df, meta_df)

    # Rows of effects_df are the non-control targets
    targets = effects_df.index.tolist()

    # target name -> embedding gene name
    gene_mapping = robust_gene_mapping(targets, generain_emb.index.tolist())

    # Only mapped targets take part in the comparison
    mapped_targets = list(gene_mapping.keys())
    print(f"Targets mapped to GeneRAIN: {len(mapped_targets)}")

    if len(mapped_targets) < 10:
        print("ERROR: Too few targets mapped for analysis")
        return None

    # Same row order (mapped_targets) for effects and embeddings
    effects_aligned = effects_df.loc[mapped_targets]
    embeddings_aligned = generain_emb.loc[[gene_mapping[t] for t in mapped_targets]]

    # Coexpression baseline, re-ordered to mapped_targets.
    # NOTE(review): .loc raises KeyError if a mapped target is missing from the
    # baseline's index (i.e. not among adata_filtered.var_names) -- confirm
    # every target is also a measured gene.
    coexpr_emb = coexpression_embedding(adata_filtered, mapped_targets, n_components=200)
    coexpr_aligned = coexpr_emb.loc[mapped_targets]

    # Cosine-distance matrices: perturbation effects, GeneRAIN, coexpression
    Dp = pairwise_cosine(effects_aligned.values)
    Dg = pairwise_cosine(embeddings_aligned.values)
    Dc = pairwise_cosine(coexpr_aligned.values)

    # Embedding vs. perturbation (with Mantel permutations), baseline vs.
    # perturbation (n_perms=0 skips Mantel), and the partial correlation
    # controlling for the baseline
    rho, p, r_mantel, p_mantel = mantel_and_spearman(Dg, Dp, n_perms=999)
    rho_coexpr, p_coexpr, _, _ = mantel_and_spearman(Dc, Dp, n_perms=0)
    part_rho, part_p = partial_corr_vectorized(Dg, Dp, Dc)

    retrieval_results = retrieval_metrics(Dg, Dp)

    # Collect everything in a JSON-friendly structure
    results = {
        'analysis_type': 'single_gene',
        'n_targets': len(mapped_targets),
        'generain_embedding': {
            'spearman_rho': float(rho),
            'spearman_p': float(p),
            'mantel_r': float(r_mantel),
            'mantel_p': float(p_mantel),
            'precision_at': retrieval_results
        },
        'coexpression_baseline': {
            'spearman_rho': float(rho_coexpr),
            'spearman_p': float(p_coexpr)
        },
        'partial_correlation': {
            'rho': float(part_rho),
            'p': float(part_p)
        }
    }

    # Save results (json.dump writes the integer k keys of precision_at as strings)
    with open("results/single_gene_metrics.json", "w") as f:
        json.dump(results, f, indent=2)

    # Scatter of embedding vs. perturbation distances
    plot_distance_scatter(Dg, Dp, "Single-Gene Perturbation Analysis")

    # Console summary
    print("\n=== SINGLE-GENE ANALYSIS RESULTS ===")
    print(f"Targets analyzed: {len(mapped_targets)}")
    print(f"GeneRAIN correlation: ρ = {rho:.3f}, p = {p:.2e}")
    print(f"Coexpression baseline: ρ = {rho_coexpr:.3f}, p = {p_coexpr:.2e}")
    print(f"Partial correlation: ρ = {part_rho:.3f}, p = {part_p:.2e}")
    print(f"Precision@10: {retrieval_results[10]:.3f}")
    # Expected overlap when 10 neighbours are drawn from the other n-1 targets
    print(f"Random baseline: {10/(len(mapped_targets)-1):.3f}")

    return results

# Run the analysis when the cell executes
results = run_single_gene_analysis()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
=== SINGLE-GENE PERTURBATION ANALYSIS ===

Loading GSE133344 dataset...


  utils.warn_names_duplicates("var")


Loaded dataset: 111445 cells x 33694 genes
Perturbation guides: 290
Filtering for single-gene perturbations...


  utils.warn_names_duplicates("var")


After filtering: 69686 cells
Number of unique targets: 106
Top 10 targets:
single_target
CONTROL    11855
KLF1        1960
BAK1        1457
CEBPE       1233
UBASH3B     1202
ETS2        1201
OSR2        1003
SLC4A1      1000
SET          986
ELMSAN1      937
Name: count, dtype: int64
Control cells: 11855
Perturbed cells: 57831
Loading GeneRAIN embeddings from /content/drive/MyDrive/dataset-gene-embed/GeneRAIN-vec.200d.txt...
Metadata: 31769 200
Loaded GeneRAIN: 31769 genes, 200 dimensions


  utils.warn_names_duplicates("var")


Targets with ≥30 cells: 106
Created pseudobulk for 106 targets
Computed effects for 105 targets
Data range: [-8.100, 10.198]
Gene mapping results:
  - Direct matches: 101
  - Alias matches: 21
  - Case matches: 0
  - Total mapped: 101/105
  - Unmapped: ['C19orf26', 'C3orf72', 'ELMSAN1', 'KIAA1804']
Targets mapped to GeneRAIN: 101


  return float(res["r"]), float(res["p-val"])



=== SINGLE-GENE ANALYSIS RESULTS ===
Targets analyzed: 101
GeneRAIN correlation: ρ = 0.049, p = 4.82e-04
Coexpression baseline: ρ = -0.016, p = 2.41e-01
Partial correlation: ρ = 0.049, p = 4.88e-04
Precision@10: 0.122
Random baseline: 0.100


In [None]:
# =========================
# 0) Installs (Colab)
# =========================
!pip -q install scanpy anndata scikit-bio pingouin scikit-learn matplotlib pandas numpy scipy

# =========================
# 1) Imports & paths
# =========================
import os, json, gzip
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.io
from scipy import sparse
from sklearn.decomposition import TruncatedSVD
from sklearn.metrics import pairwise_distances
from scipy.stats import spearmanr
from skbio.stats.distance import mantel, DistanceMatrix
import pingouin as pg
import matplotlib.pyplot as plt

from google.colab import drive
drive.mount('/content/drive')

# ---- Adjust BASE (and the file names below) to match your Drive layout ----
BASE = "/content/drive/MyDrive/dataset-gene-embed"
FILES = {
    key: f"{BASE}/{filename}"
    for key, filename in (
        ("barcodes", "GSE133344_filtered_barcodes.tsv.gz"),
        ("cell_id", "GSE133344_filtered_cell_identities.csv.gz"),
        # if you kept the matrix gzipped, point this at the .mtx.gz file instead
        ("matrix", "GSE133344_filtered_matrix.mtx"),
        ("genes", "GSE133344_filtered_genes.tsv"),
        ("generain", "GeneRAIN-vec.200d.txt"),
    )
}
# Output folders used by the analysis cells
os.makedirs("figs", exist_ok=True)
os.makedirs("results", exist_ok=True)

# =========================
# 2) I/O & parsing helpers
# =========================
def load_gse133344():
    """
    Load the GSE133344 filtered matrix (paths from FILES) into an AnnData object.

    Builds a cells x genes AnnData, makes var_names unique, and joins the
    cell-identity table onto obs for every barcode present in both the matrix
    and the identity table. Cells keep their original matrix order, so the
    result is identical from run to run.

    Returns the AnnData restricted to the shared barcodes. Expects a
    'gemgroup' column in the identity table (used for the summary print).
    """
    print("Loading GSE133344...")
    genes = pd.read_csv(FILES["genes"], sep="\t", header=None, names=["ensembl_id","gene_symbol"])
    # barcodes: one per line
    with gzip.open(FILES["barcodes"], "rt") as f:
        barcodes = [ln.strip() for ln in f]
    # cell identities
    cell_id = pd.read_csv(FILES["cell_id"], compression="gzip")
    # expression matrix, transposed to cells x genes
    X = scipy.io.mmread(FILES["matrix"]).T.tocsr()

    ad = sc.AnnData(X=X)
    ad.obs_names = barcodes
    ad.var_names = genes["gene_symbol"].values
    ad.var["ensembl_id"] = genes["ensembl_id"].values
    ad.var_names_make_unique()

    cell_id = cell_id.set_index("cell_barcode")
    # A boolean mask instead of list(set(...)): set iteration order depends on
    # string hashing and would shuffle the cells differently on every run.
    shared = ad.obs_names.isin(cell_id.index)
    ad = ad[shared].copy()
    ad.obs = ad.obs.join(cell_id.loc[ad.obs_names])

    ggs = sorted(pd.Series(ad.obs["gemgroup"]).dropna().unique().tolist())
    print(f"Loaded: {ad.n_obs} cells × {ad.n_vars} genes | gemgroups: {ggs}")
    return ad

def parse_single_target(guide_identity: str):
    """
    Classify a guide identity string.

    Looks only at the text before the first "__", splits it on "_" and
    ignores empty pieces and pieces containing "NegCtrl". Returns the gene
    symbol if exactly one piece is left, 'CONTROL' if none are left, and None
    for missing input or several pieces.
    """
    if pd.isna(guide_identity):
        return None
    head = guide_identity.partition("__")[0]
    gene_tokens = [tok for tok in head.split("_") if tok and "NegCtrl" not in tok]
    if not gene_tokens:
        return "CONTROL"
    return gene_tokens[0] if len(gene_tokens) == 1 else None

def filter_single_gene(ad):
    """
    Return a copy of `ad` holding only CONTROL and single-gene cells.

    Works on a copy, so the AnnData passed in is left untouched. The copy
    gains obs['single_target'], derived from obs['guide_identity']; cells
    whose target is missing are dropped.
    """
    ad = ad.copy()
    ad.obs["single_target"] = ad.obs["guide_identity"].apply(parse_single_target)
    ad = ad[ad.obs["single_target"].notna()].copy()

    n_controls = int((ad.obs['single_target']=='CONTROL').sum())
    n_unique = ad.obs['single_target'].nunique()
    print(f"After single-gene filter: {ad.n_obs} cells (controls={n_controls})")
    print(f"Unique targets (incl CONTROL): {n_unique}")
    return ad

# =========================
# 3) Pseudobulk & effects
# =========================
def pseudobulk_target_gemgroup(ad, min_cells=20):
    """
    Sum counts per (single_target, gemgroup) and keep groups with >= min_cells cells.

    `ad` needs obs columns 'single_target' and 'gemgroup', a count matrix
    `X` (sparse or dense) and `var_names`.

    Returns (pb, meta_df):
      pb      - summed counts, rows indexed by a (target, gemgroup) MultiIndex,
                columns = var_names
      meta_df - one row per kept group with columns target, gemgroup, n_cells

    Both are empty, with the same columns and index names, when no group
    reaches min_cells.
    """
    assert "single_target" in ad.obs.columns, "call filter_single_gene() first"
    assert "gemgroup" in ad.obs.columns, "gemgroup not present in obs"
    groups = ad.obs.groupby(["single_target","gemgroup"]).indices

    rows, meta = [], []
    X = ad.X
    for (t, gg), idx in groups.items():
        if len(idx) < min_cells:
            continue
        # asarray + ravel flattens both sparse (1 x genes matrix) and dense sums
        rows.append(np.asarray(X[idx].sum(axis=0)).ravel())
        meta.append({"target": t, "gemgroup": gg, "n_cells": len(idx)})

    pb = pd.DataFrame(rows, columns=ad.var_names)
    # Built from plain lists so an empty result still gets a 2-level index
    pb.index = pd.MultiIndex.from_arrays(
        [[m["target"] for m in meta], [m["gemgroup"] for m in meta]],
        names=["target", "gemgroup"],
    )
    meta_df = pd.DataFrame(meta, columns=["target", "gemgroup", "n_cells"])
    print(f"Pseudobulk: {pb.shape[0]} (target,gemgroup) profiles")
    return pb, meta_df

def effects_matched_controls(pb: pd.DataFrame):
    """
    Batch-matched perturbation effects, z-scored per gene.

    `pb` is indexed by (target, gemgroup). Each profile is converted to
    log1p(CPM); for every non-CONTROL target the CONTROL profile of the same
    gemgroup is subtracted, gemgroups lacking a CONTROL profile are skipped,
    and the remaining differences are averaged. Targets without any matched
    gemgroup are left out.

    Returns E_p (targets × genes) with NaN/inf replaced by 0.
    """
    # log1p-CPM per pseudobulk profile
    depth = pb.sum(axis=1)
    logcpm = np.log1p(pb.div(depth, axis=0) * 1e6)

    per_target = {}
    for target in logcpm.index.unique(level=0):
        if target == "CONTROL":
            continue

        profiles = logcpm.loc[target]
        if isinstance(profiles, pd.Series):
            profiles = profiles.to_frame().T

        deltas = []
        for batch in profiles.index:
            try:
                control = logcpm.loc[("CONTROL", batch)]
            except KeyError:
                continue  # no control profile for this gemgroup
            if not isinstance(control, pd.Series):
                control = control.mean(axis=0)
            deltas.append(profiles.loc[batch] - control)

        if not deltas:
            continue
        per_target[target] = pd.concat(deltas, axis=1).mean(axis=1).values

    E_p = pd.DataFrame.from_dict(per_target, orient="index", columns=logcpm.columns)

    # per-gene z-score; a zero spread is replaced by 1 so constant genes give 0
    center = E_p.mean(0)
    spread = E_p.std(0, ddof=0).replace(0, 1.0)
    E_p = ((E_p - center) / spread).replace([np.inf, -np.inf], 0.0).fillna(0.0)
    print(f"Computed matched effects for {E_p.shape[0]} targets")
    return E_p

def qc_on_target_direction(E_p: pd.DataFrame):
    """
    Sanity check on the perturbation direction.

    For every target that is also a column of E_p, reads the effect of the
    target on itself and reports the median of those values (under CRISPRa
    the targeted gene is expected to go up). Returns the median as a float,
    or None when no target appears among the columns.
    """
    self_effects = [E_p.loc[t, t] for t in E_p.index if t in E_p.columns]
    if not self_effects:
        print("[QC] No targets present among genes to check self-effect (ok).")
        return None

    med = float(np.median(self_effects))
    print(f"[QC] On-target self-effect median (CRISPRa should be ≥ 0): {med:.3f}")
    return med

# =========================
# 4) Embeddings: pretrained & coexpression baseline
# =========================
def load_generain_embeddings(path: str):
    """
    Read a GeneRAIN-vec text file into a DataFrame.

    The first line (header, e.g. "N D") is skipped. Every later line with at
    least two whitespace-separated fields is read as: gene v0 v1 ... The
    result is indexed by gene (index name "gene") with columns
    generain_0..generain_{D-1}; for a repeated gene only the first row is kept.
    """
    print(f"Loading GeneRAIN vectors from: {path}")
    with open(path, "r") as handle:
        handle.readline()  # header: "31769 200" etc.
        records = [ln.split() for ln in handle]

    # Lines with fewer than two fields carry no vector
    records = [rec for rec in records if len(rec) >= 2]
    names = [rec[0] for rec in records]
    vectors = [[float(v) for v in rec[1:]] for rec in records]

    table = pd.DataFrame(vectors, index=pd.Index(names, name="gene"))
    table.columns = [f"generain_{i}" for i in range(table.shape[1])]

    repeated = table.index.duplicated(keep="first")
    if repeated.any():
        table = table[~repeated]
    print(f"GeneRAIN loaded: {table.shape[0]} genes × {table.shape[1]} dims")
    return table

def restrict_to_hvgs(ad, E_p: pd.DataFrame, n_top=2000):
    """
    Restrict the effect matrix to highly variable genes.

    Runs scanpy's highly_variable_genes (flavor "seurat_v3", n_top genes) on
    a copy of `ad` and returns the columns of E_p flagged as highly variable,
    in E_p's column order.
    """
    scratch = ad.copy()
    sc.pp.highly_variable_genes(scratch, n_top_genes=n_top, flavor="seurat_v3", inplace=True)

    flagged = scratch.var["highly_variable"].values
    hvgs = set(scratch.var_names[flagged])

    kept = [gene for gene in E_p.columns if gene in hvgs]
    print(f"HVG restriction: using {len(kept)} HVGs out of {E_p.shape[1]} genes")
    return E_p.loc[:, kept]

def coexpr_baseline_hvg_union_targets(ad, target_genes, n_components=200, n_hvg=3000):
    """
    Coexpression embedding for the target genes.

    Selects the union of the n_hvg highly variable genes (scanpy,
    flavor "seurat_v3") and the target genes, mean-centres that
    cells × genes block, fits a TruncatedSVD on it and uses the component
    loadings as gene embeddings. Returns the rows for the targets found in
    the selection, in the order of `target_genes`.
    """
    print("Building coexpression baseline on HVGs ∪ targets...")
    scratch = ad.copy()
    sc.pp.highly_variable_genes(scratch, n_top_genes=n_hvg, flavor="seurat_v3", inplace=True)

    gene_index = scratch.var_names
    is_hvg = scratch.var["highly_variable"].to_numpy(dtype=bool)
    is_target = gene_index.isin(list(target_genes))
    selected = is_hvg | is_target

    # Dense, column-centred expression block for the selected genes
    block = ad.X[:, selected]
    dense = block.toarray() if sparse.issparse(block) else np.asarray(block)
    centered = np.nan_to_num(dense - dense.mean(axis=0))

    # At most n_components, below both matrix dimensions, but never under 2
    n_comp = max(int(min(n_components, centered.shape[1]-1, centered.shape[0]-1)), 2)
    svd = TruncatedSVD(n_components=n_comp, random_state=42).fit(centered)

    emb = pd.DataFrame(
        svd.components_.T,
        index=gene_index[selected],
        columns=[f"coexpr_{i}" for i in range(n_comp)],
    )
    avail = [gene for gene in target_genes if gene in emb.index]
    print(f"Coexpression baseline coverage: {len(avail)}/{len(target_genes)} targets")
    return emb.loc[avail]

# =========================
# 5) Distances & statistics
# =========================
def _prepare_matrix(M):
    M = np.nan_to_num(M, nan=0.0, posinf=0.0, neginf=0.0)
    norms = np.linalg.norm(M, axis=1)
    zero = norms == 0
    if zero.any():
        M[zero] += np.random.normal(0, 1e-6, size=(zero.sum(), M.shape[1]))
    return M

def distance_cosine(M):
    """Pairwise cosine distances between the rows of M (zero diagonal, NaN -> 0)."""
    prepared = _prepare_matrix(M.copy())
    D = pairwise_distances(prepared, metric="cosine")
    np.fill_diagonal(D, 0.0)
    return np.nan_to_num(D)

def distance_correlation(M):
    """Pairwise (1 - Pearson correlation) between the rows of M (zero diagonal, NaN -> 0)."""
    prepared = _prepare_matrix(M.copy())
    D = 1 - np.corrcoef(prepared)
    np.fill_diagonal(D, 0.0)
    return np.nan_to_num(D)

def mantel_and_spearman(D1, D2, n_perms=999):
    """
    Spearman correlation of the upper-triangle entries plus a Mantel test.

    D1 and D2 are square distance matrices of the same shape; n_perms is the
    number of Mantel permutations. Returns (rho, p, r_mantel, p_mantel) as
    floats. If the Mantel test raises, a message is printed and
    r_mantel/p_mantel repeat the Spearman values.
    """
    iu = np.triu_indices_from(D1, k=1)
    rho, p = spearmanr(D1[iu], D2[iu])
    try:
        ids = [f"X{i}" for i in range(D1.shape[0])]
        m1, m2 = DistanceMatrix(D1, ids=ids), DistanceMatrix(D2, ids=ids)
        r_m, p_m, _ = mantel(m1, m2, method="spearman", permutations=n_perms)
    except Exception as exc:
        # Report the substitution: the values used below are not permutation-based
        print(f"[warn] Mantel test failed ({exc}); using Spearman values in its place")
        r_m, p_m = rho, p
    return float(rho), float(p), float(r_m), float(p_m)

def partial_corr_vectorized(Dg, Dp, Dc):
    """
    Spearman partial correlation between Dg and Dp, controlling for Dc.

    Uses the upper-triangle entries of the three square matrices and drops
    rows containing NaN. Returns (r, p); (0.0, 1.0) when no pairs remain or
    the computation raises.
    """
    upper = np.triu_indices_from(Dg, 1)
    pairs = pd.DataFrame(
        {label: mat[upper] for label, mat in (("Dg", Dg), ("Dp", Dp), ("Dc", Dc))}
    ).dropna()
    if len(pairs) == 0:
        return 0.0, 1.0

    try:
        stats = pg.partial_corr(data=pairs, x="Dg", y="Dp", covar="Dc", method="spearman")
        r_val = float(stats["r"].iloc[0])
        p_val = float(stats["p-val"].iloc[0])
    except Exception:
        return 0.0, 1.0
    return r_val, p_val

def retrieval_metrics(Dg, Dp, ks=(5,10,20)):
    """Neighbour-retrieval agreement between two distance matrices.

    For every row, the nearest neighbours under ``Dg`` are compared with the
    nearest neighbours under ``Dp`` (self excluded by setting the diagonal of
    working copies to +inf).

    Args:
        Dg: Distance matrix used as the ranking ("prediction").
        Dp: Distance matrix that defines the relevant set ("truth").
        ks: Neighbourhood sizes for Precision@k.

    Returns:
        ``{"precision_at": {k: mean overlap fraction}, "auprc": mean AP}``
        where AP treats the ``max(ks)`` nearest ``Dp`` neighbours as the
        relevant set and walks the ``Dg`` ranking.
    """
    masked_g = Dg.copy()
    masked_p = Dp.copy()
    np.fill_diagonal(masked_g, np.inf)
    np.fill_diagonal(masked_p, np.inf)

    per_k = {k: [] for k in ks}
    ap_scores = []
    for i in range(Dg.shape[0]):
        order_g = np.argsort(masked_g[i])
        order_p = np.argsort(masked_p[i])

        for k in ks:
            shared = set(order_g[:k]).intersection(set(order_p[:k]))
            per_k[k].append(len(shared) / k)

        # Average precision against the top-K neighbours under Dp.
        top_k = max(ks)
        relevant = set(order_p[:top_k])
        found = 0
        running = 0.0
        for rank, candidate in enumerate(order_g, start=1):
            if candidate not in relevant:
                continue
            found += 1
            running += found / rank
            if found == top_k:
                break
        ap_scores.append(running / (top_k if top_k > 0 else 1))

    return {
        "precision_at": {k: float(np.mean(scores)) for k, scores in per_k.items()},
        "auprc": float(np.mean(ap_scores)),
    }

def label_shuffle_p(Dg, Dp, n=1000, rng=42):
    """
    Label-shuffle permutation test for the Spearman correlation of Dg vs Dp.

    Dg is left untouched; in each of ``n`` rounds the rows and columns of Dp
    are permuted together and the upper-triangle Spearman rho is recomputed.
    Returns ``(observed_rho, p)`` where p is the one-sided, add-one-smoothed
    fraction of null rhos that are >= the observed rho.
    """
    rs = np.random.RandomState(rng)
    upper = np.triu_indices_from(Dg, 1)
    g_vec = Dg[upper]
    observed = spearmanr(g_vec, Dp[upper])[0]

    def shuffled_rho():
        order = rs.permutation(Dp.shape[0])
        relabelled = Dp[np.ix_(order, order)]
        return spearmanr(g_vec, relabelled[upper])[0]

    null_rhos = np.asarray([shuffled_rho() for _ in range(n)])
    p_value = (np.sum(null_rhos >= observed) + 1) / (n + 1)
    return float(observed), float(p_value)

def bootstrap_p_at_k(Dg, Dp, k=10, B=1000, rng=42):
    """
    Bootstrap mean and 95% percentile CI for Precision@k across targets.

    Precision@k is computed once per target on the full matrices (self
    excluded), then the per-target values are resampled with replacement
    ``B`` times.

    Resampling rows/columns of the distance matrices themselves would create
    duplicate targets whose mutual distance is 0 in both matrices; those
    duplicates always retrieve each other and inflate the estimate.

    Args:
        Dg, Dp: Square distance matrices over the same targets.
        k: Neighbourhood size.
        B: Number of bootstrap resamples.
        rng: Seed for the local ``RandomState``.

    Returns:
        ``(mean_of_bootstrap_means, (ci_low, ci_high))`` as floats.
    """
    n = Dg.shape[0]
    # Work on float copies so the callers' matrices keep their diagonals.
    masked_g = np.array(Dg, dtype=float)
    masked_p = np.array(Dp, dtype=float)
    np.fill_diagonal(masked_g, np.inf)
    np.fill_diagonal(masked_p, np.inf)

    per_target = np.empty(n, dtype=float)
    for i in range(n):
        top_g = set(np.argsort(masked_g[i])[:k])
        top_p = set(np.argsort(masked_p[i])[:k])
        per_target[i] = len(top_g & top_p) / k

    RS = np.random.RandomState(rng)
    vals = [per_target[RS.choice(n, n, replace=True)].mean() for _ in range(B)]
    lo, hi = np.percentile(vals, [2.5, 97.5])
    return float(np.mean(vals)), (float(lo), float(hi))

# =========================
# 6) Runner (single gene)
# =========================
def run_improved_single_gene_analysis(min_cells_pb=20, n_hvg_effects=2000,
                                      coexpr_hvg=3000, n_perm_mantel=999,
                                      ks=(5,10,20)):
    """Run the single-gene embedding-vs-perturbation comparison end to end.

    Steps, in order: load and filter the dataset, build (target, gemgroup)
    pseudobulks and matched-control effects, restrict effects to HVGs, load
    the GeneRAIN vectors, build the coexpression baseline, align the three
    matrices on their common targets, and score cosine and correlation
    distance geometries.

    Args:
        min_cells_pb: Passed to ``pseudobulk_target_gemgroup`` as ``min_cells``.
        n_hvg_effects: Passed to ``restrict_to_hvgs`` as ``n_top``.
        coexpr_hvg: Passed to the coexpression baseline as ``n_hvg``.
        n_perm_mantel: Permutations for the embedding-vs-perturbation Mantel test.
        ks: Neighbourhood sizes for Precision@k.

    Returns:
        Dict with ``analysis_type``, ``best_metric`` and ``results_by_metric``,
        or ``None`` when fewer than 30 targets survive alignment.

    Side effects:
        Prints progress, and writes ``results/improved_single_gene_metrics.json``
        and ``figs/single_gene_dist_scatter.png`` (creating both directories
        if they do not exist).
    """
    ad = load_gse133344()
    ad = filter_single_gene(ad)

    # (target,gemgroup) pseudobulk & matched-control effects
    pb, meta = pseudobulk_target_gemgroup(ad, min_cells=min_cells_pb)
    E_p = effects_matched_controls(pb)
    qc_on_target_direction(E_p)

    # Restrict effects to HVGs for distance geometry
    E_p_hvg = restrict_to_hvgs(ad, E_p, n_top=n_hvg_effects)

    # Load pretrained gene vectors & map targets
    E_g_pre = load_generain_embeddings(FILES["generain"])
    targets = E_p_hvg.index.tolist()  # gene symbols (single targets)
    mapped_pre = [t for t in targets if t in E_g_pre.index]
    print(f"GeneRAIN mapped: {len(mapped_pre)}/{len(targets)}")

    # Coexpression baseline on HVGs ∪ targets, then align all three
    E_g_co_all = coexpr_baseline_hvg_union_targets(ad, mapped_pre, n_components=200, n_hvg=coexpr_hvg)
    final_targets = sorted(set(mapped_pre) & set(E_g_co_all.index) & set(E_p_hvg.index))
    print(f"Final aligned targets across (E_p, GeneRAIN, coexpr): {len(final_targets)}")
    if len(final_targets) < 30:
        print("ERROR: too few targets after alignment for stable statistics.")
        return None

    # Align matrices
    EpA = E_p_hvg.loc[final_targets]
    EgA = E_g_pre.loc[final_targets]
    EcA = E_g_co_all.loc[final_targets]
    # Alignment sanity check
    assert (EpA.index == EgA.index).all() and (EpA.index == EcA.index).all(), "Target index misaligned"
    print("Aligned targets:", EpA.shape[0])

    # Distances: evaluate cosine vs correlation; choose best by Spearman ρ
    Dg_cos, Dg_cor = distance_cosine(EgA.values),      distance_correlation(EgA.values)
    Dp_cos, Dp_cor = distance_cosine(EpA.values),      distance_correlation(EpA.values)
    Dc_cos, Dc_cor = distance_cosine(EcA.values),      distance_correlation(EcA.values)

    combos = {
        "cosine":      (Dg_cos, Dp_cos, Dc_cos),
        "correlation": (Dg_cor, Dp_cor, Dc_cor),
    }
    results = {}
    best_key, best_rho = None, -np.inf

    # k shown in the progress line: 10 when it was requested, otherwise the
    # largest requested k (indexing precision_at[10] blindly raises KeyError
    # for a custom `ks` that does not contain 10).
    report_k = 10 if 10 in ks else max(ks)

    for key, (Dg, Dp, Dc) in combos.items():
        rho, p, r_m, p_m = mantel_and_spearman(Dg, Dp, n_perms=n_perm_mantel)
        rho_c, p_c, _, _ = mantel_and_spearman(Dc, Dp, n_perms=0)
        prho, pp = partial_corr_vectorized(Dg, Dp, Dc)
        ret = retrieval_metrics(Dg, Dp, ks=ks)
        obs, p_lbl = label_shuffle_p(Dg, Dp, n=1000)
        p10_mean, p10_ci = bootstrap_p_at_k(Dg, Dp, k=10, B=1000)

        results[key] = {
            "spearman_rho": rho, "spearman_p": p,
            "mantel_r": r_m, "mantel_p": p_m,
            "label_shuffle_rho": obs, "label_shuffle_p": p_lbl,
            "coexpr_rho": rho_c, "coexpr_p": p_c,
            "partial_rho": prho, "partial_p": pp,
            "precision_at": ret["precision_at"], "auprc": ret["auprc"],
            "p10_boot_mean": p10_mean, "p10_boot_ci": p10_ci,
            "n_targets": len(final_targets)
        }
        print(f"\n[{key.upper()}] ρ={rho:.3f} (p={p:.1e}); Mantel p={p_m:.2e}; "
              f"Partial ρ|coexpr={prho:.3f} (p={pp:.1e}); "
              f"P@{report_k}={ret['precision_at'][report_k]:.3f}")

        if rho > best_rho:
            best_rho, best_key = rho, key

    # Save results
    out = {
        "analysis_type": "single_gene_improved",
        "best_metric": best_key,
        "results_by_metric": results
    }
    # Make sure the output directories exist even if this function is run
    # without the notebook's setup cell.
    os.makedirs("results", exist_ok=True)
    os.makedirs("figs", exist_ok=True)
    with open("results/improved_single_gene_metrics.json", "w") as f:
        json.dump(out, f, indent=2)

    # Quick figure: distance–distance scatter for best metric
    key = best_key
    Dg, Dp, _ = combos[key]
    iu = np.triu_indices_from(Dg, 1)
    x, y = Dg[iu], Dp[iu]
    r, pval = spearmanr(x, y)
    plt.figure(figsize=(5,5))
    plt.scatter(x, y, s=6, alpha=0.5)
    plt.title(f"GeneRAIN vs Perturbation ({key})\nSpearman ρ={r:.2f}, p={pval:.1e}")
    plt.xlabel("Gene embedding distance"); plt.ylabel("Perturbation response distance")
    plt.tight_layout(); plt.savefig("figs/single_gene_dist_scatter.png", dpi=200); plt.close()

    print(f"\nBest metric: {best_key} | Results saved to results/improved_single_gene_metrics.json")
    return out

# =========================
# 7) Run
# =========================
# Run the single-gene analysis with its default settings. The bare `res` on
# the last line makes the notebook display the returned value (the summary
# dict, or None when too few targets align).
res = run_improved_single_gene_analysis()
res


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Loading GSE133344...
Loaded: 111445 cells × 33694 genes | gemgroups: [1, 2, 3, 4, 5, 6, 7, 8]
After single-gene filter: 69686 cells (controls=11855)
Unique targets (incl CONTROL): 106
Pseudobulk: 831 (target,gemgroup) profiles
Computed matched effects for 105 targets
[QC] On-target self-effect median (CRISPRa should be ≥ 0): 5.953
HVG restriction: using 2000 HVGs out of 33694 genes
Loading GeneRAIN vectors from: /content/drive/MyDrive/dataset-gene-embed/GeneRAIN-vec.200d.txt
GeneRAIN loaded: 31769 genes × 200 dims
GeneRAIN mapped: 101/105
Building coexpression baseline on HVGs ∪ targets...
Coexpression baseline coverage: 101/101 targets
Final aligned targets across (E_p, GeneRAIN, coexpr): 101
Aligned targets: 101

[COSINE] ρ=0.042 (p=2.9e-03); Mantel p=9.00e-03; Partial ρ|coexpr=0.036 (p=9.5e-03); P@10=0.131

[CORRELATION] ρ=0.043 (p=2.3e-03); Mantel p=2.26e

{'analysis_type': 'single_gene_improved',
 'best_metric': 'correlation',
 'results_by_metric': {'cosine': {'spearman_rho': 0.04185842204470368,
   'spearman_p': 0.0029282878320825837,
   'mantel_r': 0.041858422044703684,
   'mantel_p': 0.009,
   'label_shuffle_rho': 0.04185842204470368,
   'label_shuffle_p': 0.005994005994005994,
   'coexpr_rho': 0.032884280334120726,
   'coexpr_p': 0.01944330919657188,
   'partial_rho': 0.03647629603149489,
   'partial_p': 0.00953942545908451,
   'precision_at': {5: 0.07524752475247526,
    10: 0.1306930693069307,
    20: 0.21138613861386132},
   'auprc': 0.2460943188987724,
   'p10_boot_mean': 0.20321683168316837,
   'p10_boot_ci': (0.16732673267326734, 0.24856435643564354),
   'n_targets': 101},
  'correlation': {'spearman_rho': 0.04295693020919127,
   'spearman_p': 0.00226326633868648,
   'mantel_r': 0.04295693020919127,
   'mantel_p': 0.00226326633868648,
   'label_shuffle_rho': 0.04295693020919127,
   'label_shuffle_p': 0.004995004995004995,
   '

In [None]:
# =========================
# Enhanced Gene Embedding Validation Pipeline
# =========================

!pip -q install scanpy anndata scikit-bio pingouin scikit-learn matplotlib pandas numpy scipy gprofiler-official

import os, json, gzip
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.io
from scipy import sparse
from sklearn.decomposition import TruncatedSVD
from sklearn.metrics import pairwise_distances
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import cross_val_score
from scipy.stats import spearmanr, pearsonr
from skbio.stats.distance import mantel, DistanceMatrix
import pingouin as pg
import matplotlib.pyplot as plt
from gprofiler import GProfiler

from google.colab import drive
drive.mount('/content/drive')

# ---- Edit these paths for your Drive layout ----
# BASE is the Drive folder holding all input files; FILES maps a short key to
# each input path and is read by the loader functions below.
BASE = "/content/drive/MyDrive/dataset-gene-embed"
FILES = {
    "barcodes": f"{BASE}/GSE133344_filtered_barcodes.tsv.gz",
    "cell_id": f"{BASE}/GSE133344_filtered_cell_identities.csv.gz",
    "matrix":   f"{BASE}/GSE133344_filtered_matrix.mtx",
    "genes":    f"{BASE}/GSE133344_filtered_genes.tsv",
    "generain": f"{BASE}/GeneRAIN-vec.200d.txt",
}
# Output directories (relative to the current working directory) for the JSON
# results and the figures written by the analysis functions.
os.makedirs("results", exist_ok=True)
os.makedirs("figs", exist_ok=True)

# =========================
# Core Data Loading Functions
# =========================

def load_gse133344():
    """Load Norman 2019 (GSE133344) filtered matrix into an AnnData object.

    Reads the gene table, barcode list, cell-identity table and Matrix Market
    count matrix from the paths in ``FILES``. The matrix is transposed and
    stored as CSR; rows are labelled with the barcodes and columns with gene
    symbols (made unique). Only cells that also appear in the cell-identity
    table are kept, and that table's columns are joined onto ``obs``.

    Returns:
        The assembled AnnData, with cells in their original file order.
    """
    print("Loading GSE133344...")
    genes = pd.read_csv(FILES["genes"], sep="\t", header=None, names=["ensembl_id","gene_symbol"])

    with gzip.open(FILES["barcodes"], "rt") as f:
        barcodes = [ln.strip() for ln in f]

    cell_id = pd.read_csv(FILES["cell_id"], compression="gzip")
    X = scipy.io.mmread(FILES["matrix"]).T.tocsr()

    ad = sc.AnnData(X=X)
    ad.obs_names = barcodes
    ad.var_names = genes["gene_symbol"].values
    ad.var["ensembl_id"] = genes["ensembl_id"].values
    ad.var_names_make_unique()

    cell_id = cell_id.set_index("cell_barcode")
    # Select with a boolean mask rather than list(set & set): set iteration
    # order depends on string hashing, which made the cell order differ
    # between runs.
    keep = ad.obs_names.isin(cell_id.index)
    ad = ad[keep].copy()
    ad.obs = ad.obs.join(cell_id.loc[ad.obs_names])

    print(f"Loaded: {ad.n_obs} cells × {ad.n_vars} genes | gemgroups: {sorted(ad.obs['gemgroup'].unique())}")
    return ad

def parse_combinatorial_targets(guide_identity):
    """Return the distinct target names encoded in a guide identity string.

    Only the text before the first ``'__'`` is used. It is split on ``'_'``;
    empty pieces and pieces containing ``'NegCtrl'`` are discarded, and
    repeated names are kept once, in first-seen order. A missing value
    (NaN/None) yields an empty list.
    """
    if pd.isna(guide_identity):
        return []

    prefix, _, _ = guide_identity.partition('__')
    kept = (piece for piece in prefix.split('_') if piece and 'NegCtrl' not in piece)
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(kept))

def categorize_perturbations(ad):
    """Annotate ``ad.obs`` with parsed targets and a perturbation category.

    Adds three columns: ``parsed_targets`` (list from
    ``parse_combinatorial_targets``), ``n_targets`` (its length) and
    ``perturbation_type`` ('control' for 0 targets, 'single' for 1, 'dual'
    for 2, 'multi' for more, 'unknown' otherwise). Prints the category counts
    and returns the same ``ad`` object.
    """
    def label_for(count):
        if count == 0:
            return 'control'
        if count == 1:
            return 'single'
        if count == 2:
            return 'dual'
        if count > 2:
            return 'multi'
        return 'unknown'

    ad.obs['parsed_targets'] = ad.obs['guide_identity'].apply(parse_combinatorial_targets)
    ad.obs['n_targets'] = ad.obs['parsed_targets'].apply(len)
    ad.obs['perturbation_type'] = ad.obs['n_targets'].apply(label_for)

    print("Perturbation type distribution:")
    print(ad.obs['perturbation_type'].value_counts())

    return ad

def load_generain_embeddings(path: str):
    """Read a whitespace-separated word2vec-style text file of gene vectors.

    The first line is treated as a header and skipped. Every later line with
    at least two fields contributes one row: the first field is the gene name
    and the remaining fields are parsed as floats. If a gene name occurs more
    than once, only its first row is kept.

    Returns:
        DataFrame indexed by gene (index name ``"gene"``) with columns
        ``generain_0 .. generain_{d-1}``.
    """
    print(f"Loading GeneRAIN vectors from: {path}")
    names = []
    vectors = []
    with open(path, "r") as handle:
        handle.readline()  # header line is not used
        for raw in handle:
            fields = raw.split()
            if len(fields) >= 2:
                names.append(fields[0])
                vectors.append(list(map(float, fields[1:])))

    table = pd.DataFrame(vectors, index=pd.Index(names, name="gene"))
    table.columns = [f"generain_{i}" for i in range(table.shape[1])]
    repeated = table.index.duplicated(keep="first")
    if repeated.any():
        table = table[~repeated]
    print(f"GeneRAIN loaded: {table.shape[0]} genes × {table.shape[1]} dims")
    return table

# =========================
# Enhanced Pseudobulk & Effects
# =========================

def make_gemgroup_matched_pseudobulk(ad, min_cells=20):
    """Sum counts per (guide_identity, gemgroup) group of cells.

    Groups with fewer than ``min_cells`` cells are skipped. For each kept
    group the rows of ``ad.X`` are summed per gene, and the group's
    ``parsed_targets`` / ``perturbation_type`` are taken from its first cell.

    Returns:
        ``(pseudobulk_df, meta_df)``: summed counts (columns ``ad.var_names``,
        MultiIndex of guide_identity and gemgroup) and a per-group metadata
        frame with ``guide_identity``, ``gemgroup``, ``n_cells``, ``targets``
        and ``perturbation_type``.
    """
    print("Creating gemgroup-matched pseudobulks...")

    counts = ad.X
    counts_are_sparse = sparse.issparse(counts)
    grouped = ad.obs.groupby(['guide_identity', 'gemgroup']).indices

    profiles = []
    records = []
    for (guide, gemgroup), rows in grouped.items():
        n_cells = len(rows)
        if n_cells < min_cells:
            continue

        total = counts[rows].sum(axis=0)
        if counts_are_sparse:
            # Sparse sums come back as a 1 x n_genes matrix; flatten to 1-D.
            total = np.array(total).ravel()
        profiles.append(total)

        first_cell = ad.obs_names[rows[0]]
        records.append({
            'guide_identity': guide,
            'gemgroup': gemgroup,
            'n_cells': n_cells,
            'targets': ad.obs.loc[first_cell, 'parsed_targets'],
            'perturbation_type': ad.obs.loc[first_cell, 'perturbation_type'],
        })

    meta_df = pd.DataFrame(records)
    pseudobulk_df = pd.DataFrame(profiles, columns=ad.var_names)
    pseudobulk_df.index = pd.MultiIndex.from_frame(meta_df[['guide_identity', 'gemgroup']])

    print(f"Created pseudobulk for {len(pseudobulk_df)} guide-gemgroup combinations")
    return pseudobulk_df, meta_df

def compute_gemgroup_matched_effects(pseudobulk_df, meta_df, effect_threshold=None):
    """Per-guide expression effects relative to same-gemgroup controls.

    Pseudobulk counts are converted to log1p(CPM). For each non-control guide
    and each gemgroup it was profiled in, the mean profile of the control
    guides present in that gemgroup is subtracted; the differences are then
    averaged over gemgroups. Gemgroups without any control profile are
    skipped, and guides left with no difference are omitted.

    The resulting guide x gene matrix is z-scored per gene across guides
    (zero standard deviations treated as 1; non-finite values set to 0).

    Args:
        pseudobulk_df: Counts indexed by (guide_identity, gemgroup).
        meta_df: Per-profile metadata with ``guide_identity`` and
            ``perturbation_type`` columns.
        effect_threshold: If given, keep only guides whose largest absolute
            z-score exceeds this value.

    Returns:
        DataFrame of z-scored effects, one row per retained guide.
    """
    print("Computing gemgroup-matched effects...")

    depth = pseudobulk_df.sum(axis=1)
    logcpm = np.log1p(pseudobulk_df.div(depth, axis=0) * 1e6)

    is_control = meta_df['perturbation_type'] == 'control'
    control_guides = set(meta_df[is_control]['guide_identity'])
    print(f"Found {len(control_guides)} control guides")

    target_guides = [
        g for g in sorted(set(meta_df['guide_identity'])) if g not in control_guides
    ]

    # The control baseline depends only on the gemgroup, so compute it once
    # per gemgroup and reuse it for every guide.
    baseline_cache = {}

    def control_baseline(gemgroup):
        if gemgroup not in baseline_cache:
            found = []
            for ctrl in control_guides:
                try:
                    found.append(logcpm.loc[(ctrl, gemgroup)])
                except KeyError:
                    continue
            if not found:
                baseline = None
            elif len(found) == 1:
                baseline = found[0]
            else:
                baseline = pd.concat(found, axis=1).mean(axis=1)
            baseline_cache[gemgroup] = baseline
        return baseline_cache[gemgroup]

    effects = {}
    for guide in target_guides:
        try:
            profiles = logcpm.loc[guide]
            if isinstance(profiles, pd.Series):
                profiles = pd.DataFrame([profiles.values],
                                        index=[profiles.name],
                                        columns=logcpm.columns)

            per_batch = []
            for gemgroup in profiles.index:
                baseline = control_baseline(gemgroup)
                if baseline is not None:
                    per_batch.append(profiles.loc[gemgroup] - baseline)

            if not per_batch:
                continue
            if len(per_batch) == 1:
                mean_effect = per_batch[0]
            else:
                mean_effect = pd.concat(per_batch, axis=1).mean(axis=1)
            effects[guide] = mean_effect.values
        except KeyError:
            continue

    effects_df = pd.DataFrame.from_dict(effects, orient='index', columns=logcpm.columns)

    # Z-score standardize each gene across guides
    gene_mean = effects_df.mean(axis=0)
    gene_std = effects_df.std(axis=0, ddof=0).replace(0, 1.0)
    effects_zscore = effects_df.subtract(gene_mean, axis=1).divide(gene_std, axis=1)
    effects_zscore = effects_zscore.replace([np.inf, -np.inf], 0).fillna(0)

    # Optional: keep only guides with a strong effect on at least one gene
    if effect_threshold is not None:
        strong = np.abs(effects_zscore).max(axis=1) > effect_threshold
        effects_zscore = effects_zscore[strong]
        print(f"Filtered to {len(effects_zscore)} guides with strong effects (|z| > {effect_threshold})")

    print(f"Computed effects for {len(effects_zscore)} guides")
    return effects_zscore

# =========================
# Enhanced Embedding Processing
# =========================

def create_random_baseline_embeddings(target_combinations, embedding_dim=200):
    """Create random embeddings as a baseline.

    Args:
        target_combinations: Mapping ``name -> list of genes``.
        embedding_dim: Length of each generated vector.

    Returns:
        DataFrame with one row per combination: a zero vector for an empty
        gene list, one standard-normal vector for a single gene, and the mean
        of one standard-normal vector per gene otherwise.

    The draws come from a private ``RandomState(42)``. This yields the same
    numbers as seeding the global NumPy RNG with 42, but no longer resets the
    global RNG state as a side effect of calling this function.
    """
    rng = np.random.RandomState(42)
    random_embeddings = {}

    for combo_name, gene_list in target_combinations.items():
        n_genes = len(gene_list)
        if n_genes == 0:
            # Control - use zero vector
            vector = np.zeros(embedding_dim)
        elif n_genes == 1:
            # Single gene - random vector
            vector = rng.normal(0, 1, embedding_dim)
        else:
            # Multiple genes - average of random vectors
            vector = np.mean(rng.normal(0, 1, (n_genes, embedding_dim)), axis=0)
        random_embeddings[combo_name] = vector

    random_df = pd.DataFrame.from_dict(random_embeddings, orient='index')
    print(f"Created random baseline: {random_df.shape[0]} combinations, {random_df.shape[1]} dimensions")
    return random_df

def create_combined_embeddings_enhanced(target_combinations, generain_embeddings, method='average'):
    """Build one embedding vector per target combination.

    Args:
        target_combinations: Mapping ``name -> list of genes``.
        generain_embeddings: DataFrame of gene vectors indexed by gene.
        method: How to combine several genes: ``'average'``, ``'sum'`` or
            ``'hadamard'`` (element-wise product). Single-gene combinations
            use the gene's vector unchanged.

    Returns:
        DataFrame with one row per combination that could be embedded. Empty
        gene lists map to a zero vector; combinations whose genes are all
        missing from ``generain_embeddings`` are omitted. Multi-gene
        combinations use whichever of their genes are available.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
            Previously an unknown method silently dropped every multi-gene
            combination.
    """
    combiners = {'average': np.mean, 'sum': np.sum, 'hadamard': np.prod}
    if method not in combiners:
        raise ValueError(
            f"Unknown method {method!r}; expected one of {sorted(combiners)}"
        )
    combine = combiners[method]

    combined_embeddings = {}
    missing_genes = set()

    for combo_name, gene_list in target_combinations.items():
        if len(gene_list) == 0:
            combined_embeddings[combo_name] = np.zeros(generain_embeddings.shape[1])
            continue

        available_genes = [g for g in gene_list if g in generain_embeddings.index]
        missing_genes.update(set(gene_list) - set(available_genes))
        if not available_genes:
            continue

        if len(gene_list) == 1:
            combined_embeddings[combo_name] = generain_embeddings.loc[available_genes[0]].values
        else:
            gene_embeds = np.stack([generain_embeddings.loc[g].values for g in available_genes])
            combined_embeddings[combo_name] = combine(gene_embeds, axis=0)

    if missing_genes:
        print(f"Missing genes in embeddings: {len(missing_genes)} unique genes")

    combo_df = pd.DataFrame.from_dict(combined_embeddings, orient='index')
    print(f"Created {method} embeddings: {combo_df.shape[0]} combinations")
    return combo_df

# =========================
# Enhanced Statistical Analysis
# =========================

def pairwise_cosine_robust(M):
    """Cosine distance between rows of ``M`` with defensive clean-up.

    NaN/+inf/-inf entries are mapped to 0/1/-1 in a copy of ``M``. Rows with
    zero norm get tiny Gaussian noise (sigma = 1e-6, global NumPy RNG) so
    their cosine distance is defined. In the result the diagonal is set to 0
    and NaN/+inf/-inf are mapped to 1/1/0. If the distance computation
    raises, the error is printed and an identity matrix is returned.
    """
    cleaned = np.nan_to_num(M, nan=0.0, posinf=1.0, neginf=-1.0)

    degenerate = np.linalg.norm(cleaned, axis=1) == 0
    if degenerate.any():
        jitter = np.random.normal(0, 1e-6, size=(degenerate.sum(), cleaned.shape[1]))
        cleaned[degenerate] += jitter

    try:
        dist = pairwise_distances(cleaned, metric='cosine')
        np.fill_diagonal(dist, 0.0)
        return np.nan_to_num(dist, nan=1.0, posinf=1.0, neginf=0.0)
    except Exception as err:
        print(f"Distance computation failed: {err}")
        return np.eye(cleaned.shape[0])

def cross_validate_by_gemgroup(effects_df, meta_df, embeddings_df, train_gemgroups, test_gemgroups):
    """Correlate embedding and effect distances for guides seen in both splits.

    A guide is kept when ``meta_df`` lists it in at least one train gemgroup
    and at least one test gemgroup, and it has a row in ``embeddings_df``.

    NOTE(review): the gemgroup lists are only used to select guides; the
    distances are computed from ``effects_df`` as passed in, not from
    per-split effects.

    Returns:
        ``None`` when fewer than 20 guides qualify; otherwise a dict with
        ``n_guides``, ``spearman_rho`` and ``spearman_p`` for the
        upper-triangle Spearman correlation of the two cosine distance
        matrices.
    """
    print(f"Cross-validation: train on gemgroups {train_gemgroups}, test on {test_gemgroups}")

    train_pool = set(train_gemgroups)
    test_pool = set(test_gemgroups)
    in_train = set()
    in_test = set()
    for guide in effects_df.index:
        batches = set(meta_df.loc[meta_df['guide_identity'] == guide, 'gemgroup'])
        if batches & train_pool:
            in_train.add(guide)
        if batches & test_pool:
            in_test.add(guide)

    common_guides = list(in_train & in_test & set(embeddings_df.index))
    print(f"Guides in both train/test: {len(common_guides)}")
    if len(common_guides) < 20:
        return None

    # Effect distances first, then embedding distances (same order as the
    # random-jitter draws in the distance helper would see them).
    Dp = pairwise_cosine_robust(effects_df.loc[common_guides].values)
    Dg = pairwise_cosine_robust(embeddings_df.loc[common_guides].values)

    upper = np.triu_indices_from(Dg, k=1)
    rho, p = spearmanr(Dg[upper], Dp[upper])

    return {
        'n_guides': len(common_guides),
        'spearman_rho': float(rho),
        'spearman_p': float(p),
    }

def compute_enhanced_metrics(Dg, Dp, n_permutations=999):
    """Spearman, Mantel and Precision@k agreement of two distance matrices.

    Args:
        Dg: Embedding distance matrix (square, zero diagonal).
        Dp: Perturbation-effect distance matrix over the same items.
        n_permutations: Permutations for the Mantel test.

    Returns:
        Dict with ``spearman_rho``/``spearman_p`` (upper triangles),
        ``mantel_r``/``mantel_p``, ``precision_at`` for k in (5, 10, 20) and
        the matching ``random_baseline`` ``k / (n - 1)``. If the Mantel test
        cannot be run a ``RuntimeWarning`` is issued and the Spearman values
        are reported in its place. Precision@k is NaN when the matrix has no
        more than k rows.
    """
    import warnings

    results = {}
    top_ks = (5, 10, 20)

    # Basic correlation
    iu = np.triu_indices_from(Dg, k=1)
    dist_g, dist_p = Dg[iu], Dp[iu]

    rho, p = spearmanr(dist_g, dist_p)
    results['spearman_rho'] = float(rho)
    results['spearman_p'] = float(p)

    # Mantel test
    try:
        ids = [f"X{i}" for i in range(Dg.shape[0])]
        # DistanceMatrix validates symmetry; averaging with the transpose
        # removes floating-point asymmetries and is a no-op otherwise.
        m1 = DistanceMatrix((Dg + Dg.T) / 2.0, ids=ids)
        m2 = DistanceMatrix((Dp + Dp.T) / 2.0, ids=ids)
        r_m, p_m, _ = mantel(m1, m2, method='spearman', permutations=n_permutations)
        results['mantel_r'] = float(r_m)
        results['mantel_p'] = float(p_m)
    except Exception as exc:
        # Best-effort fallback, but visible: the reported p is then the
        # parametric Spearman p, not a permutation p-value.
        warnings.warn(
            f"Mantel test failed ({exc!r}); reporting Spearman rho/p instead",
            RuntimeWarning,
        )
        results['mantel_r'] = float(rho)
        results['mantel_p'] = float(p)

    # Retrieval metrics (self excluded via an infinite diagonal on copies)
    n = Dg.shape[0]
    Dg2, Dp2 = Dg.copy(), Dp.copy()
    np.fill_diagonal(Dg2, np.inf)
    np.fill_diagonal(Dp2, np.inf)

    precisions = {k: [] for k in top_ks}
    for i in range(n):
        rank_g = np.argsort(Dg2[i])
        rank_p = np.argsort(Dp2[i])

        for k in top_ks:
            if n > k:
                overlap = len(set(rank_g[:k]) & set(rank_p[:k]))
                precisions[k].append(overlap / k)

    # An empty list (n <= k) is reported as NaN explicitly rather than via
    # np.mean([]) and its RuntimeWarning.
    results['precision_at'] = {
        k: float(np.mean(v)) if v else float('nan') for k, v in precisions.items()
    }
    results['random_baseline'] = {k: float(k/(n-1)) for k in top_ks}

    return results

# =========================
# Biological Analysis Functions
# =========================

def analyze_pathway_enrichment(high_concordance_pairs, low_concordance_pairs):
    """Query g:Profiler for terms enriched in two gene lists.

    Each list is submitted (organism ``hsapiens``; sources GO:BP, KEGG,
    REACTOME) only when it has more than 5 entries; otherwise an empty
    DataFrame stands in for its result.

    Returns:
        ``(high_enrichment, low_enrichment)`` DataFrames. If anything raises,
        the error is printed and two empty DataFrames are returned.
    """
    print("Analyzing pathway enrichment...")

    try:
        profiler = GProfiler(return_dataframe=True)

        def query_or_none(genes):
            # Lists of 5 or fewer entries are not submitted.
            if len(genes) <= 5:
                return None
            return profiler.profile(
                organism='hsapiens',
                query=genes,
                sources=['GO:BP', 'KEGG', 'REACTOME']
            )

        # High concordance genes
        high_enrichment = query_or_none(high_concordance_pairs)
        if high_enrichment is None:
            high_enrichment = pd.DataFrame()
        else:
            print(f"High concordance enrichment: {len(high_enrichment)} terms")

        # Low concordance genes
        low_enrichment = query_or_none(low_concordance_pairs)
        if low_enrichment is None:
            low_enrichment = pd.DataFrame()
        else:
            print(f"Low concordance enrichment: {len(low_enrichment)} terms")

        return high_enrichment, low_enrichment

    except Exception as e:
        print(f"Pathway analysis failed: {e}")
        return pd.DataFrame(), pd.DataFrame()

def identify_concordant_gene_pairs(Dg, Dp, guide_names, percentile_threshold=90):
    """Identify item pairs with high/low embedding-perturbation concordance.

    Every upper-triangle pair is ranked by its distance in ``Dg`` and in
    ``Dp``; concordance is the negative absolute rank difference. Pairs at or
    above the ``percentile_threshold`` percentile of concordance are "high",
    pairs at or below the ``100 - percentile_threshold`` percentile are "low".

    Args:
        Dg, Dp: Square distance matrices over the same items.
        guide_names: Name of each row/column, in matrix order.
        percentile_threshold: Percentile used for both cut-offs.

    Returns:
        ``(high_genes, low_genes, high_pairs, low_pairs)``. The two name lists
        hold each name once, sorted, so that their order (and any slice taken
        by the caller) is reproducible; the pair lists are in upper-triangle
        order.
    """
    iu = np.triu_indices_from(Dg, k=1)
    dist_g, dist_p = Dg[iu], Dp[iu]

    # Compute concordance (negative of absolute difference in ranks);
    # argsort of argsort turns values into 0-based ranks.
    rank_g = np.argsort(np.argsort(dist_g))
    rank_p = np.argsort(np.argsort(dist_p))
    concordance = -np.abs(rank_g - rank_p)

    # Get high and low concordance pairs
    high_threshold = np.percentile(concordance, percentile_threshold)
    low_threshold = np.percentile(concordance, 100 - percentile_threshold)

    high_idx = np.where(concordance >= high_threshold)[0]
    low_idx = np.where(concordance <= low_threshold)[0]

    # Extract pairs of names
    row_idx, col_idx = iu
    high_pairs = [(guide_names[row_idx[i]], guide_names[col_idx[i]]) for i in high_idx]
    low_pairs = [(guide_names[row_idx[i]], guide_names[col_idx[i]]) for i in low_idx]

    # Flatten to unique names; sorted because plain set order is not stable
    # across runs for strings.
    high_genes = sorted({g for pair in high_pairs for g in pair})
    low_genes = sorted({g for pair in low_pairs for g in pair})

    print(f"High concordance pairs: {len(high_pairs)} ({len(high_genes)} unique genes)")
    print(f"Low concordance pairs: {len(low_pairs)} ({len(low_genes)} unique genes)")

    return high_genes, low_genes, high_pairs, low_pairs

def supervised_prediction_analysis(Dg, Dp):
    """Regress perturbation distances on embedding distances.

    The upper-triangle entries of ``Dg`` form a single-feature design matrix
    and the matching entries of ``Dp`` are the targets. A 100-tree random
    forest (seed 42) is scored with 5-fold cross-validated R², then refit on
    all pairs to report the training R².

    Returns:
        Dict with ``cv_r2_mean``, ``cv_r2_std`` and ``train_r2``.
    """
    print("Running supervised prediction analysis...")

    upper = np.triu_indices_from(Dg, k=1)
    features = Dg[upper].reshape(-1, 1)
    targets = Dp[upper]

    forest = RandomForestRegressor(n_estimators=100, random_state=42)
    fold_scores = cross_val_score(forest, features, targets, cv=5, scoring='r2')

    # Refit on every pair for the in-sample score.
    forest.fit(features, targets)

    results = {}
    results['cv_r2_mean'] = float(np.mean(fold_scores))
    results['cv_r2_std'] = float(np.std(fold_scores))
    results['train_r2'] = float(forest.score(features, targets))

    print(f"Supervised prediction R²: {results['cv_r2_mean']:.4f} ± {results['cv_r2_std']:.4f}")
    return results

# =========================
# Main Enhanced Analysis
# =========================

def run_enhanced_analysis():
    """Run comprehensive enhanced analysis.

    Pipeline: load the dataset and the GeneRAIN vectors, categorise
    perturbations, build gemgroup-matched pseudobulks and effects (all guides
    and guides with |z| > 2.0), build 'average', 'sum' and random-baseline
    embeddings per guide, and compare embedding vs effect cosine distances for
    every (effect set, embedding) pair. Also runs the gemgroup-based check,
    a pathway-enrichment query for the 'average' / strong-effects case, and
    writes a summary plot.

    Returns:
        Dict with ``analysis_type``, ``best_method``, ``improvements_applied``
        and ``results``; the same dict is written to
        ``results/enhanced_analysis_results.json``.
    """
    print("=== ENHANCED GENE EMBEDDING VALIDATION ===\n")

    # Load data
    adata = load_gse133344()
    adata = categorize_perturbations(adata)
    generain_emb = load_generain_embeddings(FILES['generain'])

    # Create pseudobulk and effects with optional filtering
    pseudobulk_df, meta_df = make_gemgroup_matched_pseudobulk(adata, min_cells=20)

    # Test with and without effect filtering
    effects_all = compute_gemgroup_matched_effects(pseudobulk_df, meta_df, effect_threshold=None)
    effects_strong = compute_gemgroup_matched_effects(pseudobulk_df, meta_df, effect_threshold=2.0)

    # Create target combinations mapping (guide identity -> target gene list)
    target_combinations = {}
    for _, row in meta_df.iterrows():
        guide = row['guide_identity']
        targets = row['targets']
        if guide in effects_all.index:  # Only include guides with computed effects
            target_combinations[guide] = targets

    print(f"Target combinations for analysis: {len(target_combinations)}")

    # Create embeddings with different methods
    embedding_methods = ['average', 'sum']
    embedding_results = {}

    for method in embedding_methods:
        embedding_results[method] = create_combined_embeddings_enhanced(
            target_combinations, generain_emb, method=method)

    # Create random baseline
    embedding_results['random'] = create_random_baseline_embeddings(target_combinations)

    # Analysis results storage
    analysis_results = {
        'all_effects': {},
        'strong_effects': {},
        'cross_validation': {},
        'biological_analysis': {}
    }

    # Test each combination of effects and embeddings
    for effect_name, effects_df in [('all', effects_all), ('strong', effects_strong)]:
        print(f"\n--- Analysis with {effect_name} effects ({len(effects_df)} guides) ---")

        effect_results = {}

        for emb_method, embeddings_df in embedding_results.items():
            print(f"\nTesting {emb_method} embeddings...")

            # Align datasets; sorted so the row order is the same on every run
            common_guides = sorted(set(effects_df.index) & set(embeddings_df.index))
            print(f"Common guides: {len(common_guides)}")

            if len(common_guides) < 20:
                continue

            effects_aligned = effects_df.loc[common_guides]
            embeddings_aligned = embeddings_df.loc[common_guides]

            # Compute distances
            Dp = pairwise_cosine_robust(effects_aligned.values)
            Dg = pairwise_cosine_robust(embeddings_aligned.values)

            # Enhanced metrics
            metrics = compute_enhanced_metrics(Dg, Dp, n_permutations=999)
            metrics['n_guides'] = len(common_guides)

            # Supervised prediction (only for non-random)
            if emb_method != 'random':
                pred_results = supervised_prediction_analysis(Dg, Dp)
                metrics.update(pred_results)

            effect_results[emb_method] = metrics

            print(f"  Correlation: ρ = {metrics['spearman_rho']:.3f}, p = {metrics['spearman_p']:.2e}")
            print(f"  Precision@10: {metrics['precision_at'][10]:.3f} vs {metrics['random_baseline'][10]:.3f} random")

            # Biological analysis (only for the 'average' embedding on strong effects)
            if emb_method == 'average' and effect_name == 'strong':
                try:
                    high_guides, low_guides, high_pairs, low_pairs = identify_concordant_gene_pairs(
                        Dg, Dp, common_guides, percentile_threshold=90)

                    # The names returned above are guide identity strings, not
                    # gene symbols. Map them to their parsed target genes
                    # before querying the enrichment service, which expects
                    # gene identifiers.
                    high_genes = sorted(
                        {t for g in high_guides for t in target_combinations.get(g, [])})
                    low_genes = sorted(
                        {t for g in low_guides for t in target_combinations.get(g, [])})

                    high_enrichment, low_enrichment = analyze_pathway_enrichment(high_genes, low_genes)

                    analysis_results['biological_analysis'] = {
                        'high_concordance_genes': high_genes[:20],  # First 20 for storage
                        'low_concordance_genes': low_genes[:20],
                        'high_enrichment_terms': len(high_enrichment),
                        'low_enrichment_terms': len(low_enrichment)
                    }
                except Exception as e:
                    print(f"Biological analysis failed: {e}")

        analysis_results[f'{effect_name}_effects'] = effect_results

    # Cross-validation analysis
    print(f"\n--- Cross-validation by gemgroups ---")
    train_gemgroups = [1, 2, 3, 4, 5, 6]
    test_gemgroups = [7, 8]

    cv_results = {}
    for emb_method, embeddings_df in embedding_results.items():
        if emb_method == 'random':
            continue

        cv_result = cross_validate_by_gemgroup(
            effects_all, meta_df, embeddings_df, train_gemgroups, test_gemgroups)

        if cv_result:
            cv_results[emb_method] = cv_result
            print(f"  {emb_method}: ρ = {cv_result['spearman_rho']:.3f}, p = {cv_result['spearman_p']:.2e}")

    analysis_results['cross_validation'] = cv_results

    # Find best method overall (highest Spearman rho on strong effects)
    best_method = None
    best_rho = -1

    for method, results in analysis_results['strong_effects'].items():
        if method != 'random' and results['spearman_rho'] > best_rho:
            best_rho = results['spearman_rho']
            best_method = method

    # Generate summary plot
    create_summary_plots(analysis_results, best_method)

    # Save comprehensive results
    final_results = {
        'analysis_type': 'enhanced_comprehensive',
        'best_method': best_method,
        'improvements_applied': [
            'effect_magnitude_filtering',
            'cross_validation_by_gemgroups',
            'multiple_embedding_methods',
            'random_baseline_comparison',
            'biological_pathway_analysis',
            'supervised_prediction_analysis'
        ],
        'results': analysis_results
    }

    with open("results/enhanced_analysis_results.json", "w") as f:
        json.dump(final_results, f, indent=2)

    # Print summary
    print(f"\n=== ENHANCED ANALYSIS SUMMARY ===")
    print(f"Best embedding method: {best_method}")

    if best_method and best_method in analysis_results['strong_effects']:
        best_results = analysis_results['strong_effects'][best_method]
        print(f"Strong effects analysis:")
        print(f"  Guides analyzed: {best_results['n_guides']}")
        print(f"  Correlation: ρ = {best_results['spearman_rho']:.3f}, p = {best_results['spearman_p']:.2e}")
        print(f"  Precision@10: {best_results['precision_at'][10]:.3f} vs {best_results['random_baseline'][10]:.3f} random")
        print(f"  Improvement over random: {best_results['precision_at'][10] / best_results['random_baseline'][10]:.1f}x")

        if 'cv_r2_mean' in best_results:
            print(f"  Supervised prediction R²: {best_results['cv_r2_mean']:.4f}")

    if cv_results and best_method in cv_results:
        print(f"Cross-validation:")
        print(f"  ρ = {cv_results[best_method]['spearman_rho']:.3f}, p = {cv_results[best_method]['spearman_p']:.2e}")

    return final_results

def create_summary_plots(results, best_method):
    """Draw a 2x2 grid of summary bar charts and save it as a PNG.

    Panels (each is drawn only when its input is present in ``results``):
      * top-left / top-right: Spearman rho and precision@10 for every method
        under ``results['strong_effects']``;
      * bottom-left: rho of ``best_method`` under ``'all_effects'`` versus
        ``'strong_effects'``;
      * bottom-right: rho for every method under ``'cross_validation'``.

    The figure is written to ``figs/enhanced_analysis_summary.png`` and the
    current figure is closed afterwards. Nothing is returned.
    """
    _, axes = plt.subplots(2, 2, figsize=(12, 10))

    def _labelled_bars(ax, labels, heights, title, ylabel, rotate=False):
        # One bar chart with title / y-label, optionally tilting x tick labels.
        ax.bar(labels, heights)
        ax.set_title(title)
        ax.set_ylabel(ylabel)
        if rotate:
            ax.tick_params(axis='x', rotation=45)

    # Panels 1 and 2: per-method comparison on the strong-effect subset
    if 'strong_effects' in results:
        strong = results['strong_effects']
        method_names = list(strong)
        method_rhos = [strong[name]['spearman_rho'] for name in method_names]
        method_p10 = [strong[name]['precision_at'][10] for name in method_names]

        _labelled_bars(axes[0, 0], method_names, method_rhos,
                       'Correlation by Method', 'Spearman ρ', rotate=True)
        _labelled_bars(axes[0, 1], method_names, method_p10,
                       'Precision@10 by Method', 'Precision@10', rotate=True)

    # Panel 3: effect of filtering, shown for the best method only
    in_all = best_method in results.get('all_effects', {}) if best_method else False
    if in_all and best_method in results.get('strong_effects', {}):
        filter_rhos = [
            results[section][best_method]['spearman_rho']
            for section in ('all_effects', 'strong_effects')
        ]
        _labelled_bars(axes[1, 0], ['All Effects', 'Strong Effects'], filter_rhos,
                       f'Effect Filtering Impact ({best_method})', 'Spearman ρ')

    # Panel 4: cross-validation correlations per method
    if 'cross_validation' in results and results['cross_validation']:
        cv = results['cross_validation']
        cv_names = list(cv.keys())
        cv_values = [cv[name]['spearman_rho'] for name in cv_names]
        _labelled_bars(axes[1, 1], cv_names, cv_values,
                       'Cross-Validation Results', 'Spearman ρ', rotate=True)

    plt.tight_layout()
    plt.savefig("figs/enhanced_analysis_summary.png", dpi=200, bbox_inches='tight')
    plt.close()

# =========================
# Run Enhanced Analysis
# =========================

# Announce that the definitions above are loaded and show how to call the entry point.
print("Enhanced analysis pipeline ready.")
print("Execute: results = run_enhanced_analysis()")

# The analysis runs immediately when this cell executes; comment the next line
# out to only load the definitions.
results = run_enhanced_analysis()

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━[0m [32m1.7/2.1 MB[0m [31m52.2 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m36.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m169.9/169.9 kB[0m [31m18.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.7/9.7 MB[0m [31m88.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m204.4/204.4 kB[0m [31m20.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m58.2/58.2 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.3/12.3 MB[0m [31m76.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# =========================
# COMPLETE, CORRECTED, SINGLE-CELL PIPELINE
# =========================

!pip -q install scanpy anndata scikit-bio pingouin scikit-learn matplotlib pandas numpy scipy statsmodels

import os, json, gzip, warnings
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
import scanpy as sc
import scipy.io
from scipy import sparse
from scipy.stats import spearmanr, mannwhitneyu, fisher_exact
from skbio.stats.distance import mantel, DistanceMatrix
from sklearn.metrics import pairwise_distances
from sklearn.decomposition import PCA
from sklearn.cross_decomposition import CCA
from scipy.linalg import orthogonal_procrustes
from statsmodels.stats.multitest import multipletests
import matplotlib.pyplot as plt
import seaborn as sns

# ---------- Colab Drive ----------
from google.colab import drive
drive.mount('/content/drive')

# ---------- Paths ----------
# Root folder on the mounted Google Drive that holds every input file.
BASE = "/content/drive/MyDrive/dataset-gene-embed"
# Logical name -> absolute path of each input used by this cell.
FILES = {
    "barcodes": f"{BASE}/GSE133344_filtered_barcodes.tsv.gz",
    "cell_id": f"{BASE}/GSE133344_filtered_cell_identities.csv.gz",
    "matrix":   f"{BASE}/GSE133344_filtered_matrix.mtx",
    "genes":    f"{BASE}/GSE133344_filtered_genes.tsv",
    "generain": f"{BASE}/GeneRAIN-vec.200d.txt",
}
# Output folders (relative to the working directory) for JSON results and figures;
# exist_ok=True makes re-running the cell safe.
os.makedirs("results", exist_ok=True)
os.makedirs("figs", exist_ok=True)

# =========================
# Utilities (serialization, plotting helpers)
# =========================

def make_jsonable(obj):
    """Recursively convert NumPy/Pandas objects to JSON-safe types.

    Conversions:
      * NumPy integer / floating / boolean scalars -> the matching Python scalar
        (``np.bool_`` is not a ``bool`` subclass, so it needs explicit handling;
        otherwise it would end up as its ``repr`` string);
      * ``np.ndarray`` -> nested lists;
      * ``pd.Series`` / ``pd.Index`` -> list of converted values;
      * ``pd.DataFrame`` -> dict with ``__dataframe__``, ``columns``, ``index``
        and ``data`` entries;
      * ``dict`` -> dict with converted values; NumPy scalar keys are unwrapped
        because ``json`` does not apply its ``default`` hook to keys;
      * ``list`` / ``tuple`` -> list of converted values.

    Anything else is replaced by its ``repr`` (or the name of its type if
    ``repr`` itself fails), so the function never raises for unknown objects.
    """
    import numpy as _np
    import pandas as _pd

    numpy_scalars = (_np.integer, _np.floating, _np.bool_)

    # Checked before the builtin test: np.float64 subclasses float and would
    # otherwise be returned still wrapped.
    if isinstance(obj, numpy_scalars):
        return obj.item()
    if isinstance(obj, (int, float, str, bool)) or obj is None:
        return obj
    if isinstance(obj, _np.ndarray):
        return obj.tolist()
    if isinstance(obj, (_pd.Series, _pd.Index)):
        return [make_jsonable(v) for v in obj.astype(object).tolist()]
    if isinstance(obj, _pd.DataFrame):
        return {"__dataframe__": True, "columns": obj.columns.tolist(),
                "index": obj.index.astype(str).tolist(),
                "data": make_jsonable(obj.astype(object).values.tolist())}
    if isinstance(obj, dict):
        return {
            (k.item() if isinstance(k, numpy_scalars) else k): make_jsonable(v)
            for k, v in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [make_jsonable(v) for v in obj]
    # fallback: keep the payload serializable even for unknown objects
    try:
        return obj.__repr__()
    except Exception:
        return str(type(obj))

def save_json(path, payload):
    """Write ``payload`` to ``path`` as indented JSON and report the location.

    Objects the ``json`` encoder cannot handle natively are passed through
    ``make_jsonable`` via the encoder's ``default`` hook.
    """
    with open(path, "w") as handle:
        json.dump(payload, fp=handle, default=make_jsonable, indent=2)
    print(f"Saved JSON: {path}")

# =========================
# 1) Load data
# =========================

def load_gse133344():
    """Load the GSE133344 count matrix and attach per-cell identity metadata.

    Reads the files listed in ``FILES`` (genes table, gzipped barcodes, gzipped
    cell-identity CSV and the Matrix Market count matrix), builds an AnnData
    object with cells as rows, and keeps only cells that also appear in the
    cell-identity table (matched on its ``cell_barcode`` column).

    Returns:
        AnnData whose ``obs`` carries the cell-identity columns.
    """
    print("Loading GSE133344...")
    genes = pd.read_csv(FILES["genes"], sep="\t", header=None, names=["ensembl_id","gene_symbol"])
    with gzip.open(FILES["barcodes"], "rt") as f:
        barcodes = [ln.strip() for ln in f]
    cell_id = pd.read_csv(FILES["cell_id"], compression="gzip")
    # The matrix file is transposed on load so that rows are cells.
    X = scipy.io.mmread(FILES["matrix"]).T.tocsr()

    ad = sc.AnnData(X=X)
    ad.obs_names = barcodes
    ad.var_names = genes["gene_symbol"].values
    ad.var["ensembl_id"] = genes["ensembl_id"].values
    ad.var_names_make_unique()

    cell_id = cell_id.set_index("cell_barcode")
    # Keep the barcode file's order. Intersecting through set() would give a
    # cell order that changes from run to run (string hashing is randomized),
    # which makes downstream results harder to reproduce.
    common = ad.obs_names[ad.obs_names.isin(cell_id.index)]
    ad = ad[common].copy()
    ad.obs = ad.obs.join(cell_id.loc[common])

    print(f"Loaded: {ad.n_obs} cells × {ad.n_vars} genes | gemgroups: {sorted(ad.obs['gemgroup'].unique())}")
    return ad

def parse_combinatorial_targets(guide_identity):
    """Extract the distinct target names encoded in a guide identity string.

    Only the text before the first ``'__'`` is used. It is split on ``'_'``;
    empty pieces and pieces containing ``'NegCtrl'`` are dropped, and repeated
    names are collapsed while keeping first-seen order.

    Returns:
        list of target names; an empty list for missing (NaN/None) input or
        when nothing but negative controls is present.
    """
    if pd.isna(guide_identity):
        return []
    prefix = guide_identity.partition('__')[0]
    kept = (piece for piece in prefix.split('_') if piece and 'NegCtrl' not in piece)
    # dict.fromkeys removes duplicates and preserves insertion order
    return list(dict.fromkeys(kept))

def categorize_perturbations(ad):
    """Annotate ``ad.obs`` with parsed targets and a perturbation category.

    Adds three columns, derived from ``guide_identity``:
      * ``parsed_targets``    - list returned by ``parse_combinatorial_targets``;
      * ``n_targets``         - length of that list;
      * ``perturbation_type`` - 'control' (0 targets), 'single' (1), 'dual' (2),
        'multi' (more than 2); 'unknown' is the initial placeholder.

    Prints the category counts and returns the same ``ad`` object.
    """
    ad.obs['parsed_targets'] = ad.obs['guide_identity'].apply(parse_combinatorial_targets)
    ad.obs['n_targets'] = ad.obs['parsed_targets'].apply(len)
    ad.obs['perturbation_type'] = 'unknown'

    target_counts = ad.obs['n_targets']
    labelled_masks = (
        ('control', target_counts == 0),
        ('single', target_counts == 1),
        ('dual', target_counts == 2),
        ('multi', target_counts > 2),
    )
    for label, selector in labelled_masks:
        ad.obs.loc[selector, 'perturbation_type'] = label

    print("Perturbation type distribution:")
    print(ad.obs['perturbation_type'].value_counts())
    return ad

def load_generain_embeddings(path: str):
    """Read a whitespace-delimited embedding text file into a DataFrame.

    The first line is treated as a header and discarded. Every following line
    is expected to be ``<gene> <v0> <v1> ...``; lines with fewer than two
    tokens are skipped. When a gene occurs more than once, only its first row
    is kept.

    Returns:
        DataFrame indexed by gene (index name ``gene``) with columns
        ``generain_0 ... generain_{d-1}``.
    """
    print(f"Loading GeneRAIN vectors from: {path}")
    names = []
    vectors = []
    with open(path, "r") as handle:
        handle.readline()  # header
        for raw in handle:
            tokens = raw.strip().split()
            if len(tokens) >= 2:
                names.append(tokens[0])
                vectors.append([float(tok) for tok in tokens[1:]])

    table = pd.DataFrame(vectors, index=pd.Index(names, name="gene"))
    table.columns = [f"generain_{dim}" for dim in range(table.shape[1])]

    repeated = table.index.duplicated(keep="first")
    if repeated.any():
        table = table[~repeated]

    print(f"GeneRAIN loaded: {table.shape[0]} genes × {table.shape[1]} dims")
    return table

# =========================
# 2) Pseudobulk & effects
# =========================

def make_gemgroup_matched_combinatorial_pseudobulk(ad, min_cells=20):
    """Sum counts over the cells of each (guide_identity, gemgroup) pair.

    Args:
        ad: object exposing ``X`` (dense or SciPy sparse, cells x genes),
            ``obs`` (with ``guide_identity``, ``gemgroup``, ``parsed_targets``
            and ``perturbation_type`` columns), ``obs_names`` and ``var_names``.
        min_cells: pairs with fewer cells than this are skipped.

    Returns:
        (pseudobulk, meta): ``pseudobulk`` has one row of summed counts per
        retained pair, indexed by a (guide_identity, gemgroup) MultiIndex;
        ``meta`` has one row per retained pair in the same order, with the
        pair's cell count plus the targets and perturbation type taken from
        the first cell of the group.
    """
    print("Creating gemgroup-matched pseudobulks...")
    counts = ad.X
    counts_are_sparse = sparse.issparse(counts)
    cell_groups = ad.obs.groupby(['guide_identity', 'gemgroup']).indices

    profiles = []
    records = []
    for (guide, gemgroup), members in cell_groups.items():
        n_members = len(members)
        if n_members < min_cells:
            continue
        total = counts[members].sum(axis=0)
        if counts_are_sparse:
            # sparse sums come back as a 1 x genes matrix; flatten to a vector
            total = np.array(total).ravel()
        profiles.append(total)

        # per-group annotations are read from the group's first cell
        first_cell = ad.obs_names[members[0]]
        records.append({
            'guide_identity': guide,
            'gemgroup': gemgroup,
            'n_cells': n_members,
            'targets': ad.obs.loc[first_cell, 'parsed_targets'],
            'perturbation_type': ad.obs.loc[first_cell, 'perturbation_type'],
        })

    meta_df = pd.DataFrame(records)
    pb = pd.DataFrame(profiles, columns=ad.var_names)
    pb.index = pd.MultiIndex.from_frame(meta_df[['guide_identity', 'gemgroup']])
    print(f"Created pseudobulk for {len(pb)} guide-gemgroup combinations")
    return pb, meta_df

def compute_gemgroup_matched_combinatorial_effects(pseudobulk_df, meta_df):
    """Compute z-scored per-guide expression effects against gemgroup-matched controls.

    Steps:
      1. Normalise each pseudobulk row to counts-per-million and apply log1p.
      2. For every non-control guide and each gemgroup it was observed in,
         subtract the mean log-CPM of the control guides from that same
         gemgroup; gemgroups without any control are skipped.
      3. Average the differences across gemgroups for each guide.
      4. Z-score every gene across guides (population std; zero std is treated
         as 1) and replace non-finite values by 0.

    Args:
        pseudobulk_df: summed counts indexed by a two-level
            (guide_identity, gemgroup) MultiIndex, genes as columns.
        meta_df: table with ``guide_identity`` and ``perturbation_type``
            columns; rows whose type is ``'control'`` define the controls.

    Returns:
        DataFrame of z-scored effects, one row per guide that had at least one
        gemgroup with matched controls, rows ordered by guide name.
    """
    print("Computing gemgroup-matched effects...")
    lib = pseudobulk_df.sum(axis=1)
    cpm = pseudobulk_df.div(lib, axis=0) * 1e6
    logcpm = np.log1p(cpm)

    is_control = meta_df['perturbation_type'] == 'control'
    # Sorted so that controls are always averaged in the same order; iterating
    # a set of strings would vary from run to run.
    control_guides = sorted(set(meta_df.loc[is_control, 'guide_identity']))
    print(f"Found {len(control_guides)} control guides")
    control_lookup = set(control_guides)
    target_guides = [g for g in sorted(set(meta_df['guide_identity']))
                     if g not in control_lookup]

    guides_present = set(logcpm.index.get_level_values(0))
    ctrl_mean_cache = {}

    def control_mean(gemgroup):
        # Mean control profile for one gemgroup (None when it has no control).
        # Cached: the value depends only on the gemgroup, not on the guide.
        if gemgroup not in ctrl_mean_cache:
            rows = [logcpm.loc[(cg, gemgroup)] for cg in control_guides
                    if (cg, gemgroup) in logcpm.index]
            if not rows:
                ctrl_mean_cache[gemgroup] = None
            elif len(rows) == 1:
                ctrl_mean_cache[gemgroup] = rows[0]
            else:
                ctrl_mean_cache[gemgroup] = pd.concat(rows, axis=1).mean(axis=1)
        return ctrl_mean_cache[gemgroup]

    effects = {}
    for guide in target_guides:
        if guide not in guides_present:
            # listed in meta_df but absent from the pseudobulk table
            continue
        guide_rows = logcpm.loc[guide]  # rows of this guide, indexed by gemgroup
        per_group = []
        for gemgroup, profile in guide_rows.iterrows():
            ctrl = control_mean(gemgroup)
            if ctrl is not None:
                per_group.append(profile - ctrl)
        if not per_group:
            continue
        if len(per_group) > 1:
            mean_effect = pd.concat(per_group, axis=1).mean(axis=1)
        else:
            mean_effect = per_group[0]
        effects[guide] = mean_effect.values

    eff = pd.DataFrame.from_dict(effects, orient='index', columns=logcpm.columns)
    mu, sd = eff.mean(axis=0), eff.std(axis=0, ddof=0).replace(0, 1.0)
    eff_z = eff.subtract(mu, axis=1).divide(sd, axis=1)
    eff_z = eff_z.replace([np.inf, -np.inf], 0).fillna(0)
    print(f"Computed effects for {len(eff_z)} guides")
    return eff_z

# =========================
# 3) Embeddings (combine genes for combos)
# =========================

def create_combined_embeddings(target_combinations, generain_embeddings, method='average'):
    """Build one embedding vector per target combination.

    Args:
        target_combinations: mapping of combination name -> list of gene names.
        generain_embeddings: DataFrame of gene embeddings indexed by gene name.
        method: ``'sum'`` adds the gene vectors, ``'hadamard'`` multiplies them
            element-wise; any other value averages them.

    Behaviour per combination:
      * empty gene list -> all-zero vector;
      * genes missing from the embedding table are ignored (and counted in the
        printed "missing" message);
      * if none of the genes is available, the combination is left out.

    Returns:
        DataFrame with one row per retained combination.
    """
    reducers = {'sum': np.sum, 'hadamard': np.prod}
    reduce_rows = reducers.get(method, np.mean)
    n_dims = generain_embeddings.shape[1]

    vectors = {}
    unmatched = set()
    for combo, genes in target_combinations.items():
        if len(genes) == 0:
            vectors[combo] = np.zeros(n_dims)
            continue
        found = [gene for gene in genes if gene in generain_embeddings.index]
        unmatched.update(set(genes) - set(found))
        if found:
            rows = np.stack([generain_embeddings.loc[gene].values for gene in found])
            vectors[combo] = reduce_rows(rows, axis=0)

    if unmatched:
        print(f"Missing genes in embeddings: {len(unmatched)} unique genes")
    table = pd.DataFrame.from_dict(vectors, orient='index')
    print(f"Created {method} embeddings: {table.shape[0]} combinations")
    return table

# =========================
# 4) Distance + metrics
# =========================

def pairwise_cosine(M):
    """Return the pairwise cosine-distance matrix between the rows of ``M``.

    Non-finite entries are replaced first (NaN -> 0, +inf -> 1, -inf -> -1).
    All-zero rows have no direction, so they receive a tiny Gaussian jitter
    (std 1e-6) before distances are computed. The jitter comes from a
    fixed-seed generator so repeated calls on the same input give identical
    distances. The diagonal is forced to 0 and any non-finite distance is
    replaced (NaN/+inf -> 1, -inf -> 0).

    Args:
        M: 2-D array-like (rows = items). It is converted to a float array and
            copied, so the caller's data is never modified and integer input
            is accepted.

    Returns:
        Square ndarray of cosine distances.
    """
    M = np.nan_to_num(np.asarray(M, dtype=float), nan=0.0, posinf=1.0, neginf=-1.0)
    zero_rows = np.linalg.norm(M, axis=1) == 0
    if zero_rows.any():
        # Seeded generator instead of the global np.random state: reproducible
        # output, and the caller's random stream is left untouched.
        rng = np.random.default_rng(0)
        M[zero_rows] += rng.normal(0, 1e-6, size=(int(zero_rows.sum()), M.shape[1]))
    D = pairwise_distances(M, metric='cosine')
    np.fill_diagonal(D, 0.0)
    return np.nan_to_num(D, nan=1.0, posinf=1.0, neginf=0.0)

def mantel_and_spearman(D1, D2, n_perms=999):
    """Compare two square distance matrices by rank correlation.

    The Spearman correlation is computed over the strict upper triangles,
    using only pairs where both distances are finite. A Mantel test (Spearman
    method, ``n_perms`` permutations) is then attempted on the full matrices;
    if it raises for any reason, the plain Spearman result is reported in its
    place.

    Returns:
        ``(rho, p, mantel_r, mantel_p)`` as floats; ``(0.0, 1.0, 0.0, 1.0)``
        when no finite pair is available.
    """
    upper = np.triu_indices_from(D1, k=1)
    first, second = D1[upper], D2[upper]
    usable = np.isfinite(first) & np.isfinite(second)
    first = first[usable]
    second = second[usable]
    if first.size == 0:
        return 0.0, 1.0, 0.0, 1.0

    rho, pval = spearmanr(first, second)
    try:
        labels = [f"X{pos}" for pos in range(D1.shape[0])]
        left = DistanceMatrix(D1, ids=labels)
        right = DistanceMatrix(D2, ids=labels)
        mantel_r, mantel_p, _ = mantel(left, right, method='spearman',
                                       permutations=n_perms)
    except Exception:
        # best-effort: fall back to the plain Spearman statistics
        mantel_r, mantel_p = rho, pval
    return float(rho), float(pval), float(mantel_r), float(mantel_p)

def retrieval_metrics(Dg, Dp, ks=(5,10,20)):
    """Measure how well nearest neighbours agree between two distance matrices.

    For every item, its ``k`` nearest neighbours (self excluded) are taken
    from each matrix and the fraction shared between the two neighbour sets
    is recorded; the per-item fractions are averaged for each ``k``. A ``k``
    is only evaluated when the number of items exceeds it; otherwise its mean
    is taken over an empty list.

    Returns:
        ``{"precision_at": {k: mean_overlap_fraction}}``.
    """
    n_items = Dg.shape[0]
    masked_g = Dg.copy()
    masked_p = Dp.copy()
    # push the self-distance to the end of every ranking
    np.fill_diagonal(masked_g, np.inf)
    np.fill_diagonal(masked_p, np.inf)

    ranking_g = [np.argsort(masked_g[row]) for row in range(n_items)]
    ranking_p = [np.argsort(masked_p[row]) for row in range(n_items)]

    precision_at = {}
    for k in ks:
        overlaps = []
        if n_items > k:
            overlaps = [
                len(set(order_g[:k]) & set(order_p[:k])) / k
                for order_g, order_p in zip(ranking_g, ranking_p)
            ]
        precision_at[k] = float(np.mean(overlaps))
    return {"precision_at": precision_at}

# =========================
# 5) Biology helpers (families, per-gene, plots)
# =========================

def load_curated_gene_families():
    """Return the hard-coded gene-family sets used for enrichment checks.

    Returns:
        dict with keys ``'TF'``, ``'kinase'``, ``'chromatin'`` and
        ``'metabolic'``, each mapping to a set of upper-case gene symbols.
    """
    # Curated, concise sets—enough for enrichment checks
    symbols_by_family = {
        'TF': (
            "JUN FOS MYC MAX TP53 STAT1 STAT3 NFKB1 RELA E2F1 E2F2 GATA1 GATA2 "
            "KLF1 CEBPA RUNX1 ETS1 ETS2 FOXA1 FOXO1 FOXO3 LEF1 TCF7"
        ),
        'kinase': (
            "AKT1 AKT2 PIK3CA PIK3CB MTOR MAPK1 MAPK3 MAPK8 MAPK14 MAP2K1 MAP2K2 "
            "CDK1 CDK2 CDK4 CDK6 GSK3B SRC JAK1 JAK2 EGFR"
        ),
        'chromatin': (
            "HDAC1 HDAC2 SIRT1 EP300 CREBBP KAT2A KAT2B BRD4 CHD1 SMARCA4 ARID1A"
        ),
        'metabolic': (
            "PFKM PKM LDHA G6PD GAPDH ENO1 HK2 FASN ACACA CPT1A ACLY SDHA"
        ),
    }
    return {family: set(symbols.split()) for family, symbols in symbols_by_family.items()}

def assign_gene_families(genes, fams):
    """Label every gene with the curated family it belongs to.

    Args:
        genes: list of gene names (matched case-insensitively via ``upper()``).
        fams: mapping with ``'TF'``, ``'kinase'``, ``'chromatin'`` and
            ``'metabolic'`` sets of upper-case symbols.

    Returns:
        DataFrame indexed by ``genes`` with a single ``gene_family`` column.
        Genes in no family are labelled ``'other'``; when a gene is in several
        sets, the family listed last below wins.
    """
    # Ordered from lowest to highest precedence (later entries override).
    family_sets = [
        (fams['TF'], 'transcription_factor'),
        (fams['kinase'], 'kinase'),
        (fams['chromatin'], 'chromatin_regulator'),
        (fams['metabolic'], 'metabolic_enzyme'),
    ]

    def family_of(gene):
        symbol = gene.upper()
        label = 'other'
        for members, family_label in family_sets:
            if symbol in members:
                label = family_label
        return label

    df = pd.DataFrame(index=genes)
    df['gene_family'] = [family_of(gene) for gene in genes]
    return df

def compute_per_gene_concordance(Dg, Dp, gene_names):
    """Correlate, gene by gene, its distances to all other genes in two spaces.

    For gene ``i`` the Spearman correlation is taken between row ``i`` of
    ``Dg`` and row ``i`` of ``Dp`` with the self-distance removed. An undefined
    correlation (NaN) is recorded as rho = 0, p = 1. P-values are adjusted with
    the Benjamini-Hochberg procedure.

    Returns:
        DataFrame indexed by gene with columns ``rho``, ``p`` and ``q``,
        sorted by ``rho`` in descending order.
    """
    print("Computing per-gene concordance...")
    n_genes = len(gene_names)
    positions = np.arange(n_genes)

    records = []
    for pos, gene in enumerate(gene_names):
        others = positions != pos  # every column except the gene itself
        rho, pval = spearmanr(Dg[pos, others], Dp[pos, others])
        if np.isnan(rho):
            rho, pval = 0.0, 1.0
        records.append((gene, float(rho), float(pval)))

    table = pd.DataFrame(records, columns=['gene','rho','p']).set_index('gene')
    table['q'] = multipletests(table['p'], method='fdr_bh')[1]
    table = table.sort_values('rho', ascending=False)

    mean_rho = table['rho'].mean()
    max_rho = table['rho'].max()
    n_significant = (table['q']<0.05).sum()
    print(f"Per-gene: mean ρ={mean_rho:.3f}, max ρ={max_rho:.3f}, sig(q<0.05)={n_significant}")
    return table

def create_procrustes_alignment_plots(embeddings_df, effects_df, label="all"):
    """Plot how well 2-D PCA views of embeddings and effects line up.

    Both tables are reduced to two principal components (separate PCA fits).
    The embedding coordinates are then rotated onto the perturbation
    coordinates with an orthogonal Procrustes fit. The figure has three
    panels: aligned-vs-perturbation scatter for PC1 and for PC2 (with a
    dashed linear fit and the Spearman correlation in the title) and an
    overlay of both 2-D point clouds.

    Args:
        embeddings_df, effects_df: DataFrames with matching row order.
        label: suffix of the output file ``figs/procrustes_<label>.png``.

    Returns:
        dict with the explained-variance ratios of both PCA fits
        (``pc_var_emb``, ``pc_var_pert``) and the saved ``plot_path``.
    """
    print("Creating Procrustes alignment visualization...")
    emb_pca = PCA(n_components=2, random_state=42)
    pert_pca = PCA(n_components=2, random_state=42)
    emb_2d = emb_pca.fit_transform(embeddings_df.values)
    pert_2d = pert_pca.fit_transform(effects_df.values)
    rotation, _ = orthogonal_procrustes(emb_2d, pert_2d)
    aligned = emb_2d @ rotation

    _, axes = plt.subplots(1, 3, figsize=(15, 4.5))

    # Panels 1-2: one scatter per principal component
    for comp in (0, 1):
        ax = axes[comp]
        ttl = f"PC{comp + 1}"
        xs = aligned[:, comp]
        ys = pert_2d[:, comp]
        rho, pval = spearmanr(xs, ys)
        ax.scatter(xs, ys, s=25, alpha=0.7)
        line_coeffs = np.polyfit(xs, ys, 1)
        ax.plot(xs, np.poly1d(line_coeffs)(xs), "r--", lw=1)
        ax.set_title(f'{ttl} Alignment\nρ = {rho:.3f}, p = {pval:.2e}')
        ax.set_xlabel(f'Embedding {ttl} (aligned)')
        ax.set_ylabel(f'Perturbation {ttl}')
        ax.grid(True, alpha=0.3)

    # Panel 3: both point clouds in the shared 2-D space
    overlay = axes[2]
    overlay.scatter(aligned[:, 0], aligned[:, 1], s=20, alpha=0.6, label='Embedding (aligned)')
    overlay.scatter(pert_2d[:, 0], pert_2d[:, 1], s=20, alpha=0.6, label='Perturbation', marker='s')
    overlay.legend()
    overlay.set_title('2D Space Overlay')
    overlay.grid(True, alpha=0.3)

    plt.tight_layout()
    outpath = f"figs/procrustes_{label}.png"
    plt.savefig(outpath, dpi=200, bbox_inches='tight')
    plt.close()

    return {
        "pc_var_emb": emb_pca.explained_variance_ratio_,
        "pc_var_pert": pert_pca.explained_variance_ratio_,
        "plot_path": outpath,
    }

# =========================
# 6) Main analysis (global + per-gene + CCA)
# =========================

def run_complete_biological_analysis():
    """Run the full embedding-vs-perturbation analysis and save a JSON summary.

    Pipeline:
      1. Load the dataset, categorise perturbations and load gene embeddings.
      2. Build gemgroup-matched pseudobulks and z-scored guide effects.
      3. Build 'average' and 'sum' combination embeddings, score each against
         the perturbation distances (Spearman, Mantel, precision@k) and keep
         the method with the highest Spearman rho. A method is only scored
         when at least 20 guides are shared with the effects table.
      4. Report the global correlation for the best method.
      5. If at least 20 single-target guides are available: per-gene
         concordance, family enrichment in the top quartile (Fisher's exact
         test), a 2-component CCA, Procrustes figures and a top-20 bar plot.

    The summary is written to ``results/complete_biological_analysis.json``
    in both cases.

    Returns:
        ``(final, detailed)``: ``final`` is the JSON payload; ``detailed``
        holds the aligned tables and distance matrices (plus the single-gene
        tables, per-gene results and family labels when step 5 ran).

    Raises:
        ValueError: if no embedding method could be scored in step 3.
    """
    print("="*70)
    print("COMPLETE BIOLOGICAL GENE EMBEDDING ANALYSIS")
    print("="*70)

    # Load
    adata = load_gse133344()
    adata = categorize_perturbations(adata)
    generain = load_generain_embeddings(FILES['generain'])

    # Pseudobulk + effects
    pb, meta = make_gemgroup_matched_combinatorial_pseudobulk(adata, min_cells=20)
    effects = compute_gemgroup_matched_combinatorial_effects(pb, meta)

    # Target mapping: guide -> list of targets, for guides that have an effect row
    combos = {}
    for g, t in zip(meta['guide_identity'], meta['targets']):
        if g in effects.index:
            combos[g] = t
    print(f"Target combinations for analysis: {len(combos)}")

    # Embeddings (avg & sum)
    emb_by = {m: create_combined_embeddings(combos, generain, method=m) for m in ['average','sum']}

    # Choose best by Spearman with perturbation distances
    best_method, best_stat = None, -np.inf
    method_stats = {}
    for m, emb_df in emb_by.items():
        common = sorted(set(effects.index) & set(emb_df.index))
        print(f"\n--- Testing {m} ---\nCommon guides: {len(common)}")
        if len(common) < 20:
            continue
        Eff = effects.loc[common]
        Emb = emb_df.loc[common]
        Dp, Dg = pairwise_cosine(Eff.values), pairwise_cosine(Emb.values)
        rho, p, r_m, p_m = mantel_and_spearman(Dg, Dp, n_perms=999)
        retrieval = retrieval_metrics(Dg, Dp)
        method_stats[m] = {"n": len(common), "rho": rho, "p": p,
                           "mantel_r": r_m, "mantel_p": p_m,
                           "precision_at": retrieval['precision_at']}
        print(f"Correlation: ρ={rho:.3f}, p={p:.2e} | P@10={retrieval['precision_at'].get(10, np.nan):.3f}")
        if rho > best_stat:
            best_stat, best_method = rho, m

    print(f"\nBest embedding method: {best_method}")
    if best_method is None:
        # Without this guard the lookup below fails with an opaque KeyError(None).
        raise ValueError(
            "No embedding method could be evaluated: each method needs at least "
            "20 guides shared between the effects and the embeddings, and a "
            "non-NaN Spearman correlation."
        )

    # Align best for global reporting
    best_emb = emb_by[best_method]
    common_all = sorted(set(effects.index) & set(best_emb.index))
    Eff_all, Emb_all = effects.loc[common_all], best_emb.loc[common_all]
    Dp_all, Dg_all = pairwise_cosine(Eff_all.values), pairwise_cosine(Emb_all.values)
    iu = np.triu_indices_from(Dg_all, k=1)
    rho_global, p_global = spearmanr(Dg_all[iu], Dp_all[iu])
    print(f"Overall correlation: ρ = {rho_global:.3f}, p = {p_global:.2e}")

    # Shared by both outcomes (with / without the single-gene analysis)
    final = {
        "analysis_type": "complete_biological_analysis",
        "best_method": best_method,
        "overall": {"rho": float(rho_global), "p": float(p_global),
                    "n_combinations": len(common_all),
                    "method_stats": method_stats},
    }
    detailed = {
        "effects_aligned": Eff_all, "embeddings_aligned": Emb_all,
        "Dg_all": Dg_all, "Dp_all": Dp_all,
    }

    # ---------- Per-gene analysis on SINGLE-GENE guides ----------
    single_guides = [g for g in common_all if len(parse_combinatorial_targets(g))==1]
    if len(single_guides) < 20:
        print("Not enough single-gene guides for per-gene analysis (need ≥20). Saving global only.")
        save_json("results/complete_biological_analysis.json", final)
        return final, detailed

    # map guide->gene name
    gene_names = [parse_combinatorial_targets(g)[0] for g in single_guides]
    Emb_sg = pd.DataFrame(Emb_all.loc[single_guides].values, index=gene_names, columns=Emb_all.columns)
    Eff_sg = pd.DataFrame(Eff_all.loc[single_guides].values, index=gene_names, columns=Eff_all.columns)
    # remove duplicates (both tables share the same index, so they stay aligned)
    Emb_sg = Emb_sg[~Emb_sg.index.duplicated(keep='first')]
    Eff_sg = Eff_sg[~Eff_sg.index.duplicated(keep='first')]
    Dg_sg, Dp_sg = pairwise_cosine(Emb_sg.values), pairwise_cosine(Eff_sg.values)
    per_gene = compute_per_gene_concordance(Dg_sg, Dp_sg, Emb_sg.index.tolist())

    # gene families (meaningful now)
    fams = load_curated_gene_families()
    fam_labels = assign_gene_families(Emb_sg.index.tolist(), fams)

    # enrichment among top 25%
    thr = np.percentile(per_gene['rho'], 75)
    top_genes = per_gene[per_gene['rho']>=thr].index
    print(f"\nHigh-concordance genes (top 25%): {len(top_genes)}")
    # family enrichment quick check: 2x2 table of (in family?) x (in top set?)
    top_set = set(top_genes)
    for fam_name in ['transcription_factor','kinase','chromatin_regulator','metabolic_enzyme']:
        fam_set = set(fam_labels[fam_labels['gene_family']==fam_name].index)
        a = len(top_set & fam_set)
        b = len(fam_set) - a
        c = len(top_set - fam_set)
        d = len(Emb_sg.index) - a - b - c
        if a+b >= 3 and c+d >= 3:
            OR, p = fisher_exact([[a,b],[c,d]], alternative='greater')
            print(f"  {fam_name}: {a}/{a+b} high; OR={OR:.2f}, p={p:.3f}")

    # CCA (2 comps)
    cca = CCA(n_components=2, max_iter=500)
    Xc, Yc = cca.fit_transform(Emb_sg.values, Eff_sg.values)
    cca_r1, cca_p1 = spearmanr(Xc[:,0], Yc[:,0])
    cca_r2, cca_p2 = spearmanr(Xc[:,1], Yc[:,1])
    print(f"\nCCA: CC1 ρ={cca_r1:.3f} (p={cca_p1:.2e}), CC2 ρ={cca_r2:.3f} (p={cca_p2:.2e})")

    # Visuals
    align_all = create_procrustes_alignment_plots(Emb_all, Eff_all, label="all_guides")
    align_sg  = create_procrustes_alignment_plots(Emb_sg, Eff_sg, label="single_genes")

    # Save top concordant barplot (family-colored)
    topN = per_gene.head(20).index
    colors = {'transcription_factor':'#e74c3c','kinase':'#3498db','chromatin_regulator':'#2ecc71','metabolic_enzyme':'#f39c12','other':'#95a5a6'}
    fam_for_top = fam_labels.loc[topN, 'gene_family']
    plt.figure(figsize=(10,8))
    plt.barh(range(len(topN)), per_gene.loc[topN,'rho'].values,
             color=[colors[f] for f in fam_for_top])
    plt.yticks(range(len(topN)), topN)
    plt.xlabel('Per-gene concordance (ρ)')
    plt.title('Top 20 Most Concordant Genes (single-gene subset)')
    plt.grid(axis='x', alpha=0.3)
    plt.tight_layout()
    plt.savefig("figs/top20_per_gene_concordance.png", dpi=200, bbox_inches='tight')
    plt.close()

    # Package results
    final["single_gene"] = {
        "n_genes": len(Emb_sg),
        "per_gene_summary": {
            "mean_rho": float(per_gene['rho'].mean()),
            "max_rho": float(per_gene['rho'].max()),
            "n_sig_q<0.05": int((per_gene['q']<0.05).sum())
        },
        "top10": per_gene.head(10)
    }
    final["cca"] = {"cc1_spearman": float(cca_r1), "cc1_p": float(cca_p1),
                    "cc2_spearman": float(cca_r2), "cc2_p": float(cca_p2)}
    final["alignment_plots"] = {"all_guides": align_all, "single_genes": align_sg}

    save_json("results/complete_biological_analysis.json", final)

    # Return detailed data (for any follow-ups)
    detailed.update({
        "effects_single": Eff_sg, "embeddings_single": Emb_sg,
        "Dg_single": Dg_sg, "Dp_single": Dp_sg,
        "per_gene_df": per_gene, "family_labels": fam_labels,
    })
    return final, detailed

# =========================
# 7) (Optional) interaction & pathway extras based on single-gene output
# =========================

def cca_overlay_plot(Emb_sg, Eff_sg, per_gene_df, top_n=10):
    """Fit a 2-component CCA and save two diagnostic figures.

    Figure 1 (``figs/cca_overlay.png``): embedding-vs-perturbation scatter for
    each canonical component (dashed linear fit, Spearman correlation in the
    title) plus an overlay of both sets of canonical scores.

    Figure 2 (``figs/top_genes_cca_labeled.png``): the first canonical
    component with the first ``top_n`` genes of ``per_gene_df`` highlighted
    and labelled.

    Args:
        Emb_sg, Eff_sg: DataFrames with matching rows; the index of ``Emb_sg``
            supplies the gene names.
        per_gene_df: table whose index lists genes; its first ``top_n`` rows
            are the ones labelled.
        top_n: number of genes to label.
    """
    model = CCA(n_components=2, max_iter=500)
    emb_scores, pert_scores = model.fit_transform(Emb_sg.values, Eff_sg.values)
    component_stats = [
        spearmanr(emb_scores[:, comp], pert_scores[:, comp]) for comp in (0, 1)
    ]

    # ---- Figure 1: per-component scatter + overlay ----
    _, axes = plt.subplots(1,3, figsize=(15,4.5))
    for comp, (rho, pval) in enumerate(component_stats):
        ax = axes[comp]
        ttl = f"CC{comp + 1}"
        xs = emb_scores[:, comp]
        ys = pert_scores[:, comp]
        ax.scatter(xs, ys, s=25, alpha=0.6)
        line_coeffs = np.polyfit(xs, ys, 1)
        ax.plot(xs, np.poly1d(line_coeffs)(xs), "r--", lw=1)
        ax.set_title(f"{ttl}: ρ={rho:.3f}, p={pval:.2e}")
        ax.grid(True, alpha=0.3)

    overlay = axes[2]
    overlay.scatter(emb_scores[:, 0], emb_scores[:, 1], s=20, alpha=0.6, label="Embedding (CCA)")
    overlay.scatter(pert_scores[:, 0], pert_scores[:, 1], s=20, alpha=0.6, label="Perturbation (CCA)", marker="s")
    overlay.legend()
    overlay.set_title("2D CCA Overlay")
    overlay.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig("figs/cca_overlay.png", dpi=200, bbox_inches='tight')
    plt.close()

    # ---- Figure 2: label top genes on CC1 ----
    gene_order = Emb_sg.index.tolist()
    highlighted = per_gene_df.head(top_n).index
    plt.figure(figsize=(8,6))
    plt.scatter(emb_scores[:, 0], pert_scores[:, 0], s=15, alpha=0.4, color='lightgray')
    label_box = dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.7)
    for gene in highlighted:
        row = gene_order.index(gene)
        x_val, y_val = emb_scores[row, 0], pert_scores[row, 0]
        plt.scatter(x_val, y_val, s=50, color='crimson')
        plt.text(x_val, y_val, gene, fontsize=8, bbox=label_box)
    plt.xlabel("Embedding CC1")
    plt.ylabel("Perturbation CC1")
    plt.title(f"Top {top_n} Concordant Genes (CCA1)")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig("figs/top_genes_cca_labeled.png", dpi=200, bbox_inches='tight')
    plt.close()

# =========================
# 8) Run
# =========================

# Execute the whole pipeline as soon as this cell runs.
print("Pipeline loaded. Running end-to-end...")
final_results, detailed_data = run_complete_biological_analysis()

# Optional: richer CCA overlay with labels (runs only if single-gene data returned)
# 'embeddings_single' is only present in the detailed output when the
# single-gene branch of run_complete_biological_analysis was taken.
if 'embeddings_single' in detailed_data:
    cca_overlay_plot(detailed_data['embeddings_single'],
                     detailed_data['effects_single'],
                     detailed_data['per_gene_df'],
                     top_n=10)

# Summary of the expected outputs.
# NOTE: the Procrustes and top-20 figures are also written only by the
# single-gene branch of run_complete_biological_analysis, so they are missing
# when that branch is skipped; the messages below are printed regardless.
print("\nDone. Key files:")
print(" - results/complete_biological_analysis.json")
print(" - figs/procrustes_all_guides.png, figs/procrustes_single_genes.png")
print(" - figs/top20_per_gene_concordance.png")
print(" - figs/cca_overlay.png, figs/top_genes_cca_labeled.png (if single-gene analysis ran)")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Pipeline loaded. Running end-to-end...
COMPLETE BIOLOGICAL GENE EMBEDDING ANALYSIS
Loading GSE133344...
Loaded: 111445 cells × 33694 genes | gemgroups: [np.int64(1), np.int64(2), np.int64(3), np.int64(4), np.int64(5), np.int64(6), np.int64(7), np.int64(8)]
Perturbation type distribution:
perturbation_type
single     57831
dual       41759
control    11855
Name: count, dtype: int64
Loading GeneRAIN vectors from: /content/drive/MyDrive/dataset-gene-embed/GeneRAIN-vec.200d.txt
GeneRAIN loaded: 31769 genes × 200 dims
Creating gemgroup-matched pseudobulks...
Created pseudobulk for 2037 guide-gemgroup combinations
Computing gemgroup-matched effects...
Found 4 control guides
Computed effects for 268 guides
Target combinations for analysis: 268
Missing genes in embeddings: 4 unique genes
Created average embeddings: 262 combinations
Missing genes in embeddings: 4 uniq