In [None]:
import os
import shutil
import tempfile
import pytest
import numpy as np
import pandas as pd
import scipy.sparse as sp
from pathlib import Path
from anndata import AnnData
import scanpy as sc

from multiomic_transformer.pipeline.config import get_paths
from multiomic_transformer.pipeline.io_utils import (
    ensure_dir,
    write_parquet_safe,
    checkpoint_exists,
    write_done_flag,
    parquet_exists,
    read_parquet_safely,
)
from multiomic_transformer.pipeline.qc_and_pseudobulk import filter_and_qc, pseudo_bulk, run_qc_and_pseudobulk
from multiomic_transformer.pipeline.peak_gene_mapping import run_peak_gene_mapping


In [13]:
import argparse
import logging
import numpy as np
import pandas as pd
from pathlib import Path
from scipy.stats import spearmanr, pearsonr

from multiomic_transformer.pipeline.io_utils import (
    ensure_dir,
    write_parquet_safe,
    checkpoint_exists,
    write_done_flag,
    StageTimer,
)
from multiomic_transformer.pipeline.config import get_paths


# =====================================================================
# Utility: compute correlations across pseudobulk samples
# =====================================================================
def compute_tf_tg_correlations(expr_df: pd.DataFrame, tfs: list[str], tgs: list[str]) -> pd.DataFrame:
    """
    Compute Pearson and Spearman correlations between TFs and TGs
    across pseudobulk samples.

    Args:
        expr_df : DataFrame (genes × samples)
        tfs : list of TF gene names present in expr_df.index
        tgs : list of TG gene names present in expr_df.index

    Returns:
        pd.DataFrame with columns:
          ["TF", "TG", "pearson_corr", "spearman_corr"]
    """
    tf_tg_corr = []
    expr_df = expr_df.loc[expr_df.index.intersection(tfs + tgs)]

    # Convert to float32 for efficiency
    expr_mat = expr_df.astype(np.float32)

    for tf in tfs:
        if tf not in expr_mat.index:
            continue
        tf_values = expr_mat.loc[tf].values
        for tg in tgs:
            if tg not in expr_mat.index:
                continue
            tg_values = expr_mat.loc[tg].values
            if np.all(tf_values == 0) or np.all(tg_values == 0):
                pear, spear = np.nan, np.nan
            else:
                pear, _ = pearsonr(tf_values, tg_values)
                spear, _ = spearmanr(tf_values, tg_values)
            tf_tg_corr.append((tf, tg, pear, spear))

    return pd.DataFrame(tf_tg_corr, columns=["TF", "TG", "pearson_corr", "spearman_corr"])


# =====================================================================
# Main function: integrate features
# =====================================================================
def run_tf_tg_feature_construction(
    pseudobulk_file: Path,
    reg_potential_file: Path,
    peak_gene_links_file: Path,
    output_file: Path,
    force: bool = False,
):
    """
    Main entrypoint for Stage 3 TF–TG feature construction.
    """
    if checkpoint_exists(output_file) and not force:
        print(f"[SKIP] {output_file} already exists.")
        return

    with StageTimer("TF–TG Feature Construction"):

        # ---------------------------------------------------------------
        # 1. Load inputs
        # ---------------------------------------------------------------
        print(f"Loading pseudobulk expression: {pseudobulk_file}")
        expr_df = pd.read_parquet(pseudobulk_file)
        if expr_df.shape[0] < expr_df.shape[1]:
            # ensure genes × samples
            print("Transposing pseudobulk expression to [genes × samples]")
            expr_df = expr_df.T

        print(f"Loading regulatory potential: {reg_potential_file}")
        tf_tg_reg = pd.read_parquet(reg_potential_file)

        print(f"Loading peak–gene links: {peak_gene_links_file}")
        peak_gene_links = pd.read_parquet(peak_gene_links_file)

        # ---------------------------------------------------------------
        # 2. Derive TF and TG sets
        # ---------------------------------------------------------------
        tfs = sorted(tf_tg_reg["TF"].unique().tolist())
        tgs = sorted(tf_tg_reg["TG"].unique().tolist())
        print(f"Found {len(tfs)} TFs and {len(tgs)} TGs in regulatory potential file")

        # ---------------------------------------------------------------
        # 3. Mean expression features (across pseudobulk samples)
        # ---------------------------------------------------------------
        common_tfs = [tf for tf in tfs if tf in expr_df.index]
        common_tgs = [tg for tg in tgs if tg in expr_df.index]

        mean_tf_expr = expr_df.loc[common_tfs].mean(axis=1).rename("mean_tf_expr")
        mean_tg_expr = expr_df.loc[common_tgs].mean(axis=1).rename("mean_tg_expr")

        mean_tf_expr = mean_tf_expr.reset_index().rename(columns={"index": "TF"})
        mean_tg_expr = mean_tg_expr.reset_index().rename(columns={"index": "TG"})

        # ---------------------------------------------------------------
        # 4. Correlation features (across pseudobulk samples)
        # ---------------------------------------------------------------
        print("Computing TF–TG correlations across pseudobulk samples")
        corr_df = compute_tf_tg_correlations(expr_df, common_tfs, common_tgs)
        print(f"Correlation matrix computed for {len(corr_df):,} TF–TG pairs")

        # ---------------------------------------------------------------
        # 5. Merge all feature sources
        # ---------------------------------------------------------------
        merged = tf_tg_reg.merge(mean_tf_expr, on="TF", how="left")
        merged = merged.merge(mean_tg_expr, on="TG", how="left")
        merged = merged.merge(corr_df, on=["TF", "TG"], how="left")

        # Merge in distance-based features if available
        if {"peak_id", "TSS_dist"}.issubset(peak_gene_links.columns):
            # normalize TG naming
            if "TG" not in peak_gene_links.columns and "gene_id" in peak_gene_links.columns:
                peak_gene_links = peak_gene_links.rename(columns={"gene_id": "TG"})

            dist_df = (
                peak_gene_links[["peak_id", "TG", "TSS_dist"]]
                .groupby("TG", as_index=False)
                .agg(TSS_dist=("TSS_dist", "mean"))
            )
            merged = merged.merge(dist_df, on="TG", how="left")
            merged["neg_log_tss_dist"] = -np.log1p(merged["TSS_dist"].fillna(0))
        else:
            print("No TSS distance column found — skipping distance features.")

        # ---------------------------------------------------------------
        # 6. Derived features
        # ---------------------------------------------------------------
        merged["expr_product"] = merged["mean_tf_expr"] * merged["mean_tg_expr"]
        merged["log_reg_pot"] = np.log1p(merged.get("reg_potential", 0))
        merged["motif_present"] = (merged.get("motif_density", 0) > 0).astype(int)

        # ---------------------------------------------------------------
        # 7. Save output
        # ---------------------------------------------------------------
        if merged is None or merged.empty:
            print("No TF–TG features generated; skipping Parquet write.")
            return

        try:
            ensure_dir(output_file.parent)
            write_parquet_safe(merged, output_file)
            print(f"[DONE] Stage 3 complete → {output_file} ({merged.shape[0]:,} rows)")
        except Exception as e:
            logging.error(f"Failed to write TF–TG features: {e}")
            raise

In [14]:
def test_run_tf_tg_feature_construction(tmp_path):
    """
    Validate Stage 3 TF–TG feature integration on synthetic data.

    This test verifies:
      • Correct merging of expression, regulatory potential, and distance data
      • Creation of expected feature columns
      • Proper .done checkpoint writing
      • Correlation values within [-1, 1]
      • No NaN or infinite values in key numerical fields
    """


    # --------------------------------------------------------------
    # 1. Create minimal synthetic inputs (guaranteed TF–TG overlap)
    # --------------------------------------------------------------
    # Expression for both TFs and TGs (ensure non-constant variance)
    expr_df = pd.DataFrame(
        {
            "S1": [0.1, 0.5, 0.3, 0.7],
            "S2": [0.2, 0.6, 0.4, 0.8],
            "S3": [0.3, 0.7, 0.5, 0.9],
        },
        index=["TF1", "TF2", "TG1", "TG2"]
    )
    pseudobulk_file = tmp_path / "pseudobulk_expr.parquet"
    expr_df.to_parquet(pseudobulk_file)


    # Ensure TF/TG names overlap with expression index
    tf_tg_reg = pd.DataFrame({
        "TF": ["TF1", "TF1", "TF2", "TF2"],
        "TG": ["TG1", "TG2", "TG1", "TG2"],   # matches expr_df.index
        "reg_potential": [0.2, 0.5, 0.3, 0.1],
        "motif_density": [3, 0, 5, 2],
    })
    reg_potential_file = tmp_path / "tf_tg_regulatory_potential.parquet"
    tf_tg_reg.to_parquet(reg_potential_file)

    peak_gene_links = pd.DataFrame({
        "peak_id": ["p1", "p2", "p3", "p4"],
        "TG": ["TG1", "TG1", "TG2", "TG2"],
        "TSS_dist": [1000, 50000, 10000, 8000]
    })
    peak_gene_links_file = tmp_path / "peak_gene_links.parquet"
    peak_gene_links.to_parquet(peak_gene_links_file)

    output_file = tmp_path / "tf_tg_features.parquet"

    # --------------------------------------------------------------
    # 2. Run feature construction
    # --------------------------------------------------------------
    run_tf_tg_feature_construction(
        pseudobulk_file=pseudobulk_file,
        reg_potential_file=reg_potential_file,
        peak_gene_links_file=peak_gene_links_file,
        output_file=output_file,
        force=True,
    )

tmp_path = Path("/gpfs/Labs/Uzun/SCRIPTS/PROJECTS/2024.SINGLE_CELL_GRN_INFERENCE.MOELLER/outputs/testing")
test_run_tf_tg_feature_construction(tmp_path)

Loading pseudobulk expression: /gpfs/Labs/Uzun/SCRIPTS/PROJECTS/2024.SINGLE_CELL_GRN_INFERENCE.MOELLER/outputs/testing/pseudobulk_expr.parquet
Loading regulatory potential: /gpfs/Labs/Uzun/SCRIPTS/PROJECTS/2024.SINGLE_CELL_GRN_INFERENCE.MOELLER/outputs/testing/tf_tg_regulatory_potential.parquet
Loading peak–gene links: /gpfs/Labs/Uzun/SCRIPTS/PROJECTS/2024.SINGLE_CELL_GRN_INFERENCE.MOELLER/outputs/testing/peak_gene_links.parquet
Found 2 TFs and 2 TGs in regulatory potential file
Computing TF–TG correlations across pseudobulk samples
Correlation matrix computed for 4 TF–TG pairs
[DONE] Stage 3 complete → /gpfs/Labs/Uzun/SCRIPTS/PROJECTS/2024.SINGLE_CELL_GRN_INFERENCE.MOELLER/outputs/testing/tf_tg_features.parquet (4 rows)


In [15]:
df_out = pd.read_parquet(tmp_path / "tf_tg_features.parquet")