In [None]:
from multiprocessing import Pool, cpu_count

def process_gene(args):
    """
    Score one gene and attach an empirical p-value (pool worker entry point).

    args is a single tuple so the function can be used with ``Pool.map``:
        (gid, vec, edges_i, edges_j, rand_i, rand_j, P_NULL)

    Returns (gene_id, real_score, pval).
    """
    gene_id, expr, nbr_src, nbr_dst, rnd_src, rnd_dst, n_perm = args
    edge_sets = (nbr_src, nbr_dst, rnd_src, rnd_dst)

    # Observed statistic on the unpermuted expression vector.
    observed = spatial_score(expr, *edge_sets)

    # Permutation null; the generator is created unseeded on every call.
    permuted = null_scores(expr, *edge_sets, P=n_perm, rng=np.random.default_rng())

    return gene_id, observed, empirical_pvalue(observed, permuted)

In [None]:
#!/usr/bin/env python3
"""
Full spatial-variability detection pipeline using Null-distribution + FDR.

Steps:
- Detect spatial variability per experiment
- Compute spatial_score (real), null distribution, pvals, qvals
- Label genes spatial based on qval < 0.05
- Merge all experiments
- Train global logistic regression model
- Produce a threshold plot for s2_FSV

Author: ChatGPT
"""

import os
from multiprocessing import Pool, cpu_count
import numpy as np
import pandas as pd
from pathlib import Path
from sklearn.metrics import roc_curve
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score
import matplotlib.pyplot as plt

# ================================================================
# Parameters
# ================================================================
# Number of nearest neighbours linked to each node (default k of build_knn_graph).
K_NEIGHBORS = 10
# Number of random node pairs drawn by sample_random_edges in run_experiment.
RANDOM_EDGE_COUNT = 5000
P_NULL = 1000                   # permutations per gene for null distribution
# Cutoff used in run_experiment when setting the is_spatial_expr label.
FDR_THRESHOLD = 0.05
# Seed for the generator that samples the random edges in run_experiment.
RANDOM_SEED = 42

# ================================================================
# Graph utilities
# ================================================================
from sklearn.neighbors import NearestNeighbors

def build_knn_graph(coords, k=K_NEIGHBORS):
    """
    Build the directed k-nearest-neighbour edge list of a point set.

    Parameters
    ----------
    coords : array-like, one row per node
        Node coordinates.
    k : int
        Neighbours per node. Clamped to ``n_nodes - 1`` so small inputs do not
        make scikit-learn raise.

    Returns
    -------
    (edges_i, edges_j) : two integer arrays of equal length
        ``edges_i[m] -> edges_j[m]`` is an edge from a node to one of its
        nearest neighbours; each node contributes its neighbours in order of
        increasing distance.
    """
    coords = np.asarray(coords)
    n_nodes = coords.shape[0]
    k_eff = min(k, n_nodes - 1)
    if k_eff < 1:
        # Fewer than two nodes (or k <= 0): there is no neighbour to link to.
        return np.empty(0, dtype=int), np.empty(0, dtype=int)

    nbrs = NearestNeighbors(n_neighbors=k_eff, algorithm="auto")
    nbrs.fit(coords)

    # Querying without X returns the neighbours of every fitted point and
    # excludes the point itself. Dropping column 0 of a k+1 query instead is
    # only correct when no two nodes share the same coordinates.
    _, indices = nbrs.kneighbors()

    edges_i = np.repeat(np.arange(n_nodes), k_eff)
    edges_j = indices.ravel()
    return edges_i, edges_j

def sample_random_edges(n_nodes, count, rng):
    """
    Draw `count` random node pairs as a background edge set.

    Both endpoints come from ``rng.integers(0, n_nodes, count)``, drawn one
    array after the other; since they are independent, a pair may link a node
    to itself. Returns the two endpoint arrays as a tuple.
    """
    endpoints = tuple(rng.integers(0, n_nodes, count) for _ in range(2))
    return endpoints

# ================================================================
# Spatial score functions
def spatial_score(vec, ei, ej, ri, rj):
    """
    Ratio of random-pair to neighbour-pair mean squared expression difference.

    `ei`/`ej` index the endpoints of the neighbour edges and `ri`/`rj` those of
    the random edges; NaN differences are ignored by the means.

    Returns 1.0 (neutral, non-spatial) when the neighbour mean is NaN or not
    positive (e.g. constant expression) or when the random mean is NaN;
    otherwise random_mean / (neighbour_mean + 1e-12).
    """
    neighbor_msd = np.nanmean(np.square(vec[ei] - vec[ej]))
    random_msd = np.nanmean(np.square(vec[ri] - vec[rj]))

    neighbor_usable = (not np.isnan(neighbor_msd)) and neighbor_msd > 0
    if not neighbor_usable or np.isnan(random_msd):
        return 1.0

    return random_msd / (neighbor_msd + 1e-12)


def null_scores(vec, ei, ej, ri, rj, P=50, rng=None):
    """
    Permutation null for spatial_score.

    The expression vector is shuffled P times (edge sets stay fixed) and the
    score is recomputed for each shuffle. A fresh unseeded generator is used
    when `rng` is None. Returns the P scores as an array.
    """
    generator = np.random.default_rng() if rng is None else rng
    size = len(vec)

    permuted_scores = [
        spatial_score(vec[generator.permutation(size)], ei, ej, ri, rj)
        for _ in range(P)
    ]
    return np.array(permuted_scores)

def empirical_pvalue(real_score, null_dist):
    """
    Right-tailed empirical p-value of `real_score` against `null_dist`.

    Uses the add-one form (#null >= real + 1) / (len(null) + 1), so the result
    is never exactly zero. A NaN score yields NaN.
    """
    if not np.isnan(real_score):
        at_least_as_extreme = (null_dist >= real_score).sum()
        return (at_least_as_extreme + 1) / (len(null_dist) + 1)
    return np.nan

def bh_fdr(pvals):
    """
    Benjamini-Hochberg FDR correction for p-values.

    Parameters
    ----------
    pvals : array-like of float
        Raw p-values; NaN entries are allowed.

    Returns
    -------
    numpy.ndarray
        q-values in the input order. NaN p-values stay NaN and are excluded
        from both the ranking and the number of tests, so they no longer
        contaminate the adjusted values of the other entries.
    """
    p = np.asarray(pvals, dtype=float)
    q = np.full(p.shape, np.nan)

    valid = ~np.isnan(p)
    n = int(valid.sum())
    if n == 0:
        return q

    p_valid = p[valid]
    order = np.argsort(p_valid)

    # Raw BH adjustment: p_(i) * n / i for the i-th smallest p-value.
    scaled = p_valid[order] * n / np.arange(1, n + 1)

    # Enforce monotonicity from the largest p-value downwards, then cap at 1.
    adjusted = np.minimum.accumulate(scaled[::-1])[::-1]
    adjusted = np.minimum(adjusted, 1.0)

    q_valid = np.empty(n)
    q_valid[order] = adjusted
    q[valid] = q_valid
    return q

# ================================================================
# Per-experiment spatial detection using FDR
# ================================================================
def _load_node_coords(ninfo_path, n_nodes):
    """Return an (n_nodes, 2) float coordinate array, falling back to a synthetic 1D layout."""
    synthetic = np.column_stack([np.arange(n_nodes), np.zeros(n_nodes)])

    if not os.path.exists(ninfo_path):
        print("ninfo.csv not found → using synthetic 1D coords.")
        return synthetic

    ninfo = pd.read_csv(ninfo_path)
    # ninfo must provide x/y and exactly one row per expression column.
    if not ({"x", "y"}.issubset(ninfo.columns) and ninfo.shape[0] == n_nodes):
        print("ninfo.csv present but invalid → using synthetic 1D coords.")
        return synthetic

    coords = ninfo[["x", "y"]].to_numpy(dtype=float)
    # All-zero coordinates or fewer than three distinct points are unusable.
    if np.all(coords == 0) or np.unique(coords, axis=0).shape[0] < 3:
        print("WARNING: ninfo coordinates are degenerate → using synthetic 1D coords.")
        return synthetic

    print("Using x,y from ninfo.csv.")
    return coords


def run_experiment(exp_path):
    """
    Score every gene of one experiment and label it spatial / non-spatial.

    Reads ``somde_result.csv``, ``ndf.csv`` and (optionally) ``ninfo.csv`` from
    `exp_path`, builds the kNN graph plus a random edge set, computes the
    spatial score, a permutation p-value and a BH q-value per gene, and writes
    ``spatial_FDR_results_<exp>.csv`` and ``somde_with_labels_<exp>.csv`` to the
    current directory.

    Parameters
    ----------
    exp_path : pathlib.Path
        Directory of a single experiment.

    Returns
    -------
    (df, merged)
        `df` is the per-gene label table indexed by gene_id; `merged` is the
        SOMDE table with the label columns joined on.

    Raises
    ------
    ValueError
        If ndf.csv and somde_result.csv do not have the same number of rows.
    """
    exp_name = exp_path.name
    print(f"\n=== Processing Experiment {exp_name} ===")

    somde = pd.read_csv(exp_path / "somde_result.csv", index_col=0)
    somde.index = somde.index.astype(str)

    # ndf is read without an index column; its rows are aligned to the SOMDE
    # rows by position, so the row counts must agree.
    ndf = pd.read_csv(exp_path / "ndf.csv", index_col=None)
    if ndf.shape[0] != somde.shape[0]:
        raise ValueError(
            f"{exp_name}: ndf.csv has {ndf.shape[0]} rows but "
            f"somde_result.csv has {somde.shape[0]}"
        )
    ndf.index = somde.index

    print("Loading / building node coordinates ...")
    coords = _load_node_coords(exp_path / "ninfo.csv", ndf.shape[1])

    print("Building spatial graph...")
    edges_i, edges_j = build_knn_graph(coords, K_NEIGHBORS)
    rng = np.random.default_rng(RANDOM_SEED)
    rand_i, rand_j = sample_random_edges(coords.shape[0], RANDOM_EDGE_COUNT, rng)

    print(f"Computing spatial scores + null distributions in parallel (P={P_NULL}) ...")

    # One task per gene. Rows are taken by position so that duplicated gene ids
    # still yield one 1-D vector each (a label lookup would return a 2-D frame).
    gene_ids = somde.index.tolist()
    expression = ndf.to_numpy(dtype=float)
    tasks = [
        (gid, expression[row], edges_i, edges_j, rand_i, rand_j, P_NULL)
        for row, gid in enumerate(gene_ids)
    ]

    n_workers = max(1, cpu_count() - 1)
    print(f"Using {n_workers} workers ...")

    with Pool(n_workers) as pool:
        results = pool.map(process_gene, tasks)

    gene_ids_out = [res[0] for res in results]
    real_scores = np.array([res[1] for res in results], dtype=float)
    pvals = np.array([res[2] for res in results], dtype=float)

    # FDR correction
    qvals = bh_fdr(pvals)

    df = pd.DataFrame({
        "gene_id": gene_ids_out,
        "spatial_score": real_scores,
        "spatial_pval": pvals,
        "spatial_qval": qvals,
    }).set_index("gene_id")

    # The label is an FDR call, so threshold the BH q-value rather than the
    # raw permutation p-value.
    df["is_spatial_expr"] = df["spatial_qval"] < FDR_THRESHOLD
    df["experiment"] = exp_name

    # Save per-experiment label table
    out1 = f"spatial_FDR_results_{exp_name}.csv"
    df.to_csv(out1)
    print(f"Wrote {out1}")

    # Merge with SOMDE features
    somde["gene_id"] = somde.index
    somde["experiment"] = exp_name
    merged = somde.merge(df, on=["gene_id", "experiment"], how="left")

    out2 = f"somde_with_labels_{exp_name}.csv"
    merged.to_csv(out2)
    print(f"Wrote {out2}")

    return df, merged

# ================================================================
# GLOBAL MERGE + LOGISTIC REGRESSION MODEL
# ================================================================
def run_global_model(all_features_file, all_labels_file):
    """
    Train and cross-validate a logistic regression that predicts is_spatial_expr.

    Parameters
    ----------
    all_features_file : str or path
        CSV of SOMDE features for all experiments (must contain ``gene_id``).
    all_labels_file : str or path
        CSV with ``gene_id``, ``spatial_score``, ``spatial_pval``,
        ``spatial_qval`` and ``is_spatial_expr`` (and normally ``experiment``).

    Side effects
    ------------
    Prints 5-fold ROC AUC / PR AUC and writes the fitted coefficients to
    ``GLOBAL_lr_feature_importance.csv``.

    Raises
    ------
    ValueError
        If the labelled data contain fewer than two classes.
    """
    print("\n=== GLOBAL MODEL TRAINING ===")

    labels = pd.read_csv(all_labels_file)
    features = pd.read_csv(all_features_file)

    # Drop accidental index column
    if "Unnamed: 0" in features.columns:
        features = features.drop(columns=["Unnamed: 0"])

    # Drop duplicated columns
    if features.columns.duplicated().any():
        print("Removing duplicate columns:", features.columns[features.columns.duplicated()])
        features = features.loc[:, ~features.columns.duplicated()]

    # Every label-derived column must come from the labels file only. If stale
    # copies stay in `features`, the merge suffixes them (_x/_y), they escape
    # the exclusion list below and leak the target into the model inputs.
    label_cols = ["spatial_score", "spatial_pval", "spatial_qval", "is_spatial_expr"]
    stale = [c for c in label_cols if c in features.columns]
    if stale:
        print(f"Removing {stale} from features (come only from labels)")
        features = features.drop(columns=stale)

    # ensure proper types
    features["gene_id"] = features["gene_id"].astype(str)
    labels["gene_id"] = labels["gene_id"].astype(str)

    # Gene ids repeat across experiments, so join on the experiment as well
    # whenever both tables carry it; joining on gene_id alone pairs every
    # feature row with the labels of all experiments.
    keys = ["gene_id"]
    if "experiment" in features.columns and "experiment" in labels.columns:
        keys.append("experiment")

    df = features.merge(labels[keys + label_cols], on=keys, how="left")

    print(f"Merged global dataset: {df.shape}")

    # Rows without a label cannot be used for training (and cannot be cast to int).
    unlabeled = df["is_spatial_expr"].isna()
    if unlabeled.any():
        print(f"Dropping {int(unlabeled.sum())} rows without a label")
        df = df.loc[~unlabeled]

    # Feature selection: never use identifiers, the label, or anything the
    # label was derived from.
    exclude = {
        "gene_id", "g", "experiment",
        "is_spatial_expr",
        "spatial_score", "spatial_pval", "spatial_qval",
        "pval", "qval",
    }

    X_cols = [c for c in df.columns
              if c not in exclude and pd.api.types.is_numeric_dtype(df[c])]
    X = df[X_cols].copy()

    # Replace all ±∞ with NaN, then fill NaN with 0 (neutral)
    X = X.replace([np.inf, -np.inf], np.nan)
    X = X.fillna(0.0)
    y = df["is_spatial_expr"].astype(bool).astype(int)

    print(f"Training on {len(y)} genes; positives={y.sum()}")

    if y.nunique() < 2:
        raise ValueError(
            "Need both spatial and non-spatial genes to train the global model"
        )

    pipe = Pipeline([
        ("scaler", StandardScaler()),
        ("clf", LogisticRegression(
            penalty="l2",
            solver="liblinear",
            class_weight="balanced",
            max_iter=5000))
    ])

    cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    roc = cross_val_score(pipe, X, y, cv=cv, scoring="roc_auc")
    pr  = cross_val_score(pipe, X, y, cv=cv, scoring="average_precision")

    print(f"ROC AUC: {roc.mean():.4f} ± {roc.std():.3f}")
    print(f"PR  AUC: {pr.mean():.4f} ± {pr.std():.3f}")

    # Final model on all data; coefficients are ranked by absolute value.
    pipe.fit(X, y)
    coefs = pipe.named_steps["clf"].coef_.ravel()
    imp = pd.Series(coefs, index=X_cols).sort_values(key=np.abs, ascending=False)
    imp.to_csv("GLOBAL_lr_feature_importance.csv")

    print("\nTop features:")
    print(imp.head(20))

# ================================================================
# MAIN PIPELINE
# ================================================================
def main():
    """
    Run the per-experiment detection for every subfolder of ``somde_results``,
    write the concatenated label and feature tables, then train the global model.

    Raises
    ------
    ValueError
        If ``somde_results`` contains no experiment directories.
    """
    ROOT = Path("somde_results")

    # iterdir() yields entries in arbitrary order; sort so the concatenated
    # output files have a stable row order from run to run.
    experiments = sorted(d for d in ROOT.iterdir() if d.is_dir())
    if not experiments:
        raise ValueError(f"No experiment directories found under {ROOT}")

    all_labels = []
    all_features = []

    for exp_path in experiments:
        df_labels, df_features = run_experiment(exp_path)
        all_labels.append(df_labels)
        all_features.append(df_features)

    labels_all = pd.concat(all_labels)
    features_all = pd.concat(all_features)

    labels_all.to_csv("ALL_spatial_labels.csv")
    features_all.to_csv("ALL_somde_features_with_labels.csv")

    print("\nWrote ALL_spatial_labels.csv and ALL_somde_features_with_labels.csv")

    run_global_model("ALL_somde_features_with_labels.csv", "ALL_spatial_labels.csv")



In [None]:
# main() starts a multiprocessing Pool. Guard the call so that worker processes
# which re-import this module (spawn start method) do not launch the pipeline
# again. In a notebook cell __name__ is "__main__", so the call still runs.
if __name__ == "__main__":
    main()

In [None]:
# Sanity check: show the column names of the label table written by main().
pd.read_csv("ALL_spatial_labels.csv").columns

---

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# -------------------------------------------------------------------
# CONFIG
# -------------------------------------------------------------------
ROOT = "somde_results"  # folder where experiment subfolders live
# Upper bound on how many genes are taken from each group for plotting.
MAX_GENES_PER_GROUP = 100
CMAP = "coolwarm"       # blue -> red

# -------------------------------------------------------------------
# LOAD MERGED TABLES
# -------------------------------------------------------------------
features = pd.read_csv("ALL_somde_features_with_labels.csv")
labels = pd.read_csv("ALL_spatial_labels.csv")

# Ensure identifiers are strings
features["gene_id"] = features["gene_id"].astype(str)
labels["gene_id"] = labels["gene_id"].astype(str)

# Deduplicate columns if any lingering duplicates
if features.columns.duplicated().any():
    features = features.loc[:, ~features.columns.duplicated()]

# Merge (safe even if labels already inside features): a label column that is
# already present keeps its name and the incoming copy gets the "_y" suffix.
df = features.merge(
    labels[["gene_id", "experiment", "is_spatial_expr"]],
    on=["gene_id", "experiment"],
    how="inner",
    suffixes=("", "_y")
)

# If merge reintroduced duplicated is_spatial_expr columns:
if "is_spatial_expr_y" in df.columns:
    df = df.drop(columns=["is_spatial_expr_y"])

print("Merged shape:", df.shape)


# -------------------------------------------------------------------
# SPLIT INTO GROUPS
# -------------------------------------------------------------------
spatial_genes = df[df["is_spatial_expr"].eq(True)]
nonspatial_genes = df[df["is_spatial_expr"].eq(False)]

# Rank inside each group by the spatial score so that head(N) picks the most
# clear-cut examples: highest scores for spatial genes, lowest for non-spatial.
# (Sorting by is_spatial_expr itself is a no-op, since it is constant within
# each group.) Without a score column the file order is kept.
if "spatial_score" in df.columns:
    spatial_genes = spatial_genes.sort_values(by="spatial_score", ascending=False)
    nonspatial_genes = nonspatial_genes.sort_values(by="spatial_score", ascending=True)

print("Spatial genes:", spatial_genes.shape[0])
print("Non-spatial genes:", nonspatial_genes.shape[0])

# Pick first N from each
genes_to_plot = {
    "Spatial": spatial_genes.head(MAX_GENES_PER_GROUP),
    "NonSpatial": nonspatial_genes.head(MAX_GENES_PER_GROUP)
}


# -------------------------------------------------------------------
# HELPER: LOAD EXPERIMENT DATA
# -------------------------------------------------------------------
def load_experiment(exp_name):
    """
    Load the SOMDE table, expression matrix and node coordinates of one experiment.

    Parameters
    ----------
    exp_name : str
        Name of the experiment subfolder under ROOT.

    Returns
    -------
    (somde, ndf, coords)
        `somde` is indexed by gene id (as str); `ndf` has the same index, one
        row per gene; `coords` has one (x, y) row per column of `ndf`.

    Raises
    ------
    ValueError
        If ndf.csv and somde_result.csv have different numbers of rows.
    """
    exp_path = os.path.join(ROOT, exp_name)
    ninfo_path = os.path.join(exp_path, "ninfo.csv")

    somde = pd.read_csv(os.path.join(exp_path, "somde_result.csv"), index_col=0)
    somde.index = somde.index.astype(str)

    # Load ndf without assuming its index is meaningful; rows are aligned to
    # somde by position, so check the counts explicitly (an assert would be
    # stripped under -O).
    ndf = pd.read_csv(os.path.join(exp_path, "ndf.csv"), index_col=None)
    if ndf.shape[0] != somde.shape[0]:
        raise ValueError("ndf and somde row mismatch")
    ndf.index = somde.index

    # Default: synthetic 1D layout with one point per expression column.
    n_nodes = ndf.shape[1]
    coords = np.column_stack([np.arange(n_nodes), np.zeros(n_nodes)])

    if os.path.exists(ninfo_path):
        ninfo = pd.read_csv(ninfo_path)
        # Same validity rule as run_experiment: x/y present and exactly one
        # coordinate row per node; otherwise the plot vectors would not line up.
        if {"x", "y"}.issubset(ninfo.columns) and ninfo.shape[0] == n_nodes:
            candidate = ninfo[["x", "y"]].to_numpy(float)
            if np.all(candidate == 0) or np.unique(candidate, axis=0).shape[0] < 3:
                print(f"WARNING: degenerate coordinates in {exp_name}, using synthetic 1D coords")
            else:
                coords = candidate

    return somde, ndf, coords


# -------------------------------------------------------------------
# HELPER: GET EXPRESSION VECTOR FOR GENE
# -------------------------------------------------------------------
def get_gene_expression_vector(gid, somde_df, ndf_df):
    """
    Return the expression values of gene `gid` as a float array.

    `gid` must be present in the index of `somde_df`; the values are then read
    from the row of `ndf_df` carrying the same label. Raises ValueError when
    the gene is not in `somde_df`.
    """
    gene_is_known = gid in somde_df.index
    if gene_is_known:
        expression_row = ndf_df.loc[gid]
        return expression_row.to_numpy(float)

    raise ValueError(f"Gene {gid} not found in somde_result!")


# -------------------------------------------------------------------
# HELPER: PLOT FUNCTION
# -------------------------------------------------------------------
def plot_spatial_expression(gene_name, expression_vec, coords, group_name, experiment_name):
    """
    Scatter-plot one gene's expression over the node coordinates and save it.

    Parameters
    ----------
    gene_name : str
        Label used in the title and the file name.
    expression_vec : array-like
        One value per row of `coords`, mapped to colour.
    coords : array with columns (x, y)
    group_name, experiment_name : str
        Used in the title and the file name.

    The figure is written to ``out/plot_<group>_<experiment>_<gene>.png``; the
    ``out`` directory is created when missing.
    """
    fig, ax = plt.subplots(figsize=(6, 5))

    sc = ax.scatter(
        coords[:, 0], coords[:, 1],
        c=expression_vec,
        cmap=CMAP,
        s=60,
        edgecolor="k"
    )
    fig.colorbar(sc, ax=ax, label="Expression level")

    ax.set_title(f"{group_name} gene: {gene_name}\nExperiment: {experiment_name}")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    fig.tight_layout()

    # Save per-gene figure; savefig does not create missing directories.
    os.makedirs("out", exist_ok=True)
    outname = f"out/plot_{group_name}_{experiment_name}_{gene_name}.png"
    fig.savefig(outname, dpi=150)
    plt.close(fig)
    print("Saved:", outname)


# -------------------------------------------------------------------
# MAIN LOOP: VISUALIZATION
# -------------------------------------------------------------------
# Each experiment is read from disk once and reused for all of its genes;
# reloading three CSV files per plotted gene is pure overhead.
loaded_experiments = {}

for group_name, gdf in genes_to_plot.items():
    print(f"\n=== Plotting {group_name} genes ===")
    for _, row in gdf.iterrows():
        gid = row["gene_id"]
        gsym = row["g"]
        exp = row["experiment"]

        # Load correct experiment data (cached per experiment name)
        if exp not in loaded_experiments:
            loaded_experiments[exp] = load_experiment(exp)
        somde, ndf, coords = loaded_experiments[exp]

        # Expression vector
        vec = get_gene_expression_vector(gid, somde, ndf)

        # Plot
        plot_spatial_expression(
            gene_name=gsym,
            expression_vec=vec,
            coords=coords,
            group_name=group_name,
            experiment_name=exp
        )

print("\nDone plotting!")
