In [1]:
"""
Build within-tissue (TS) and cross-tissue (CT) correlation matrices
from CSV expression files (rows=samples/donors, cols=genes).

- Mirrors the logic in your R `AdjacencyFromExpr`, but returns correlations.
- Supports donor aggregation (average duplicate samples by donor key).
- Offers SD-quantile + top-N filtering per tissue.
- CT correlations are computed across *common donors* between tissues.
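- Two CT modes:
  * mode='pairwise' (default): exact pairwise-complete Pearson correlations computed
    with vectorized masked sums (for Spearman, each gene is ranked over all of its
    own non-missing donors rather than re-ranked within each gene pair's overlap).
  * mode='complete': drop donors with any NaN across both tissues, then do fast
    vectorized correlations.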

Optionally, you can raise |corr|^beta to obtain adjacency blocks (TS/CT).

Eden-friendly: the math is pure pandas/numpy (rich is used only for logging and
progress bars); easy to swap in Polars for I/O if you wish.
"""
from __future__ import annotations
import re
from typing import Callable, Dict, Iterable, List, Mapping, Optional, Tuple
import numpy as np
import pandas as pd
from rich import print as rprint
from rich.progress import track
from itertools import combinations
import time



# -------------------------- I/O & preprocessing --------------------------

def _log(msg: str):
    try:
        rprint(msg)
    except Exception:
        print(msg)

def load_expr_csv(
    path: str,
    index_col: int | str = 0,
    sep: str = ",",
    dtype=float,
) -> pd.DataFrame:
    """Load expression with rows=samples (donors), cols=genes.
    Assumes the first column contains sample IDs if `index_col=0`.
    """
    df = pd.read_csv(path, sep=sep, index_col=index_col)
    # force numeric where possible
    df = df.apply(pd.to_numeric, errors="coerce").astype(dtype)
    # Drop all-NaN columns (rare but safer)
    df = df.loc[:, df.notna().any(axis=0)]
    return df


def extract_donor_ids(
    sample_index: Iterable[str],
    regex: str = r"^([^-]+-[^-]+)",
) -> List[str]:
    """Map each sample name to a donor ID using capture group 1 of ``regex``.

    The default pattern mirrors the R code's ``^([^-]+-[^-]+)``: it keeps the
    first two dash-delimited parts of the name. Names are stringified first.
    A name the pattern does not match is returned unchanged as its own donor ID.
    """
    pattern = re.compile(regex)

    def _donor_of(name: str) -> str:
        match = pattern.search(name)
        return match.group(1) if match else name

    return [_donor_of(name) for name in map(str, sample_index)]


def aggregate_by_donor(
    df: pd.DataFrame,
    donor_ids: Optional[Iterable[str]] = None,
    agg: str = "mean",
) -> pd.DataFrame:
    """Collapse rows that belong to the same donor within one tissue.

    Parameters
    ----------
    df : samples x genes
    donor_ids : donor ID per row of ``df`` (same order as ``df.index``). When
                None, IDs are derived with ``extract_donor_ids(df.index)``.
    agg : "mean" (numeric-only mean) or any aggregation pandas' groupby accepts.

    Returns a donors x genes frame cast to float. Donors appear in order of
    first occurrence (``sort=False``).
    """
    ids = extract_donor_ids(df.index) if donor_ids is None else donor_ids
    grouped = df.groupby(pd.Index(ids), sort=False)
    result = grouped.mean(numeric_only=True) if agg == "mean" else grouped.aggregate(agg)
    return result.astype(float)


def filter_genes_by_sd(
    df: pd.DataFrame,
    sd_quantile: float = 0.0,
    max_genes_per_tissue: Optional[int] = None,
) -> pd.DataFrame:
    """Keep the most variable genes (columns) of one tissue.

    1. Drop genes whose sample SD (ddof=1) is below the ``sd_quantile``
       quantile of all gene SDs. Genes with an undefined (NaN) SD are dropped
       too, because comparisons with NaN are False.
    2. If more than ``max_genes_per_tissue`` genes remain, keep only the
       top-N by SD, in descending SD order.
    """
    gene_sd = df.std(axis=0, ddof=1)
    if gene_sd.size:
        cutoff = np.nanquantile(gene_sd.values, sd_quantile)
    else:
        cutoff = np.nan

    kept = df.loc[:, gene_sd[gene_sd >= cutoff].index]

    cap = max_genes_per_tissue
    if cap is None or kept.shape[1] <= cap:
        return kept
    ranked = gene_sd.loc[kept.columns].sort_values(ascending=False)
    return kept.loc[:, ranked.head(cap).index]


def prefix_gene_columns(df: pd.DataFrame, tissue_name: str) -> pd.DataFrame:
    """Return a copy of ``df`` whose column labels are ``"<tissue_name>_<gene>"``.

    The input frame is left untouched.
    """
    renamed = df.copy()
    new_labels = []
    for gene in df.columns:
        new_labels.append(f"{tissue_name}_{gene}")
    renamed.columns = new_labels
    return renamed

# -------------------------- Correlations (TS) --------------------------

def corr_within_tissue(
    df: pd.DataFrame,
    method: str = "pearson",
    absolute: bool = True,
) -> pd.DataFrame:
    """Gene x gene correlation matrix for one tissue (samples x genes input).

    Delegates to pandas' ``DataFrame.corr``, which correlates each gene pair
    over the samples where both genes are observed (pairwise-complete).

    Raises ValueError unless ``method`` is 'pearson' or 'spearman'.
    With ``absolute=True`` the matrix holds |corr|.
    """
    allowed = ("pearson", "spearman")
    if method not in allowed:
        raise ValueError("method must be 'pearson' or 'spearman'")
    corr = df.corr(method=method)
    return corr.abs() if absolute else corr

# -------------------------- Correlations (CT) --------------------------

def _corr_series_pairwise(s: pd.Series, Y: pd.DataFrame, method: str) -> pd.Series:
    # Pairwise-complete correlations of one vector against all columns of Y.
    return Y.apply(lambda col: s.corr(col, method=method), axis=0)

def _pairwise_corr_fast(A: pd.DataFrame, B: pd.DataFrame, method: str = "pearson", absolute: bool = True, block: int = 1024) -> pd.DataFrame:
    """
    Fast, exact pairwise-complete correlation between columns of A and B using
    vectorized masked sums. Handles NaNs per pair; works for Pearson/Spearman.
    A, B: donors × genes
    """
    # Rank for Spearman
    if method == "spearman":
        A = A.rank(axis=0, method="average")
        B = B.rank(axis=0, method="average")

    # Prepare arrays (float64)
    A_vals = A.to_numpy(dtype=np.float64, copy=False)
    B_vals = B.to_numpy(dtype=np.float64, copy=False)
    M_A = np.isfinite(A_vals).astype(np.float64)
    M_B = np.isfinite(B_vals).astype(np.float64)

    # Replace NaNs with 0 for masked arithmetic
    A_vals = np.nan_to_num(A_vals, copy=False)
    B_vals = np.nan_to_num(B_vals, copy=False)

    # Precompute A-side terms once
    A_vals_masked = A_vals * M_A
    A_vals2_masked = (A_vals * A_vals) * M_A

    p = A_vals.shape[1]
    q = B_vals.shape[1]
    out = np.empty((p, q), dtype=np.float64)
    out.fill(np.nan)

    # Process B in column blocks to reduce peak memory
    for j0 in range(0, q, block):
        j1 = min(j0 + block, q)
        Mb = M_B[:, j0:j1]
        Bb = B_vals[:, j0:j1]
        Bb_masked = Bb * Mb
        Bb2_masked = (Bb * Bb) * Mb

        # Pairwise counts n_ij over donor intersection
        n = M_A.T @ Mb  # p × (j1-j0)

        # Pairwise sums over intersection
        sumA = A_vals_masked.T @ Mb               # p × (j1-j0)
        sumB = M_A.T @ Bb_masked                  # p × (j1-j0)
        sumAB = A_vals_masked.T @ Bb_masked       # p × (j1-j0)
        sumA2 = A_vals2_masked.T @ Mb             # p × (j1-j0)
        sumB2 = M_A.T @ Bb2_masked                # p × (j1-j0)

        # Pearson correlation components
        with np.errstate(invalid="ignore", divide="ignore"):
            cov = sumAB - (sumA * sumB) / n
            varA = sumA2 - (sumA * sumA) / n
            varB = sumB2 - (sumB * sumB) / n
            den = np.sqrt(varA * varB)
            Cblk = cov / den

        # Valid only where n >= 3 and variance positive
        valid = (n >= 3) & np.isfinite(Cblk) & (den > 0)
        Cblk[~valid] = np.nan
        out[:, j0:j1] = Cblk

    C = pd.DataFrame(out, index=A.columns, columns=B.columns)
    if absolute:
        C = C.abs()
    return C

# ...existing code...
def corr_cross_tissue(
    A: pd.DataFrame,
    B: pd.DataFrame,
    method: str = "pearson",
    absolute: bool = True,
    mode: str = "pairwise",
    block: int = 256,
) -> pd.DataFrame:
    """Cross-tissue correlation (genes_A x genes_B) over *common donors*.

    Both frames are restricted to the donors in both indexes (in A's order), and
    non-numeric columns are dropped.

    Parameters
    ----------
    A, B : donors x genes matrices
    method : 'pearson' or 'spearman'
    absolute : take |corr|
    mode : 'pairwise' (exact pairwise-complete Pearson via vectorized blocks;
           see `_pairwise_corr_fast` for the Spearman caveat)
           'complete' (drops any donor with a NaN in either tissue, then
           vectorized correlation on the remaining donors)
    block : number of B columns per block in 'pairwise' mode

    Raises
    ------
    ValueError
        On an unknown `method` or `mode`, when the tissues share no donors, or
        (mode='complete') when fewer than 3 fully observed donors remain.
    """
    if method not in {"pearson", "spearman"}:
        raise ValueError("method must be 'pearson' or 'spearman'")
    # Previously any other string silently fell through to 'pairwise'.
    if mode not in {"pairwise", "complete"}:
        raise ValueError("mode must be 'pairwise' or 'complete'")
    common = A.index.intersection(B.index)
    if len(common) == 0:
        raise ValueError("No common donors between the two tissues")
    A = A.loc[common].select_dtypes(include=[np.number])
    B = B.loc[common].select_dtypes(include=[np.number])

    if mode == "pairwise":
        return _pairwise_corr_fast(A, B, method=method, absolute=absolute, block=block)

    # 'complete': keep only donors observed for every gene in both tissues
    mask = A.notna().all(axis=1) & B.notna().all(axis=1)
    A2 = A.loc[mask]
    B2 = B.loc[mask]
    if A2.shape[0] < 3:
        raise ValueError("Too few complete donors after NaN filtering (<3)")
    if method == "spearman":
        A2 = A2.rank(axis=0, method="average")
        B2 = B2.rank(axis=0, method="average")
    # z-score columns; the Pearson matrix is then a scaled inner product
    A0 = ((A2 - A2.mean(axis=0)) / A2.std(axis=0, ddof=1)).to_numpy(dtype=float)
    B0 = ((B2 - B2.mean(axis=0)) / B2.std(axis=0, ddof=1)).to_numpy(dtype=float)
    n = A0.shape[0]
    # Rounding can push |r| slightly above 1; NaN (zero-variance genes) is preserved.
    C = np.clip((A0.T @ B0) / (n - 1), -1.0, 1.0)
    C = pd.DataFrame(C, index=A2.columns, columns=B2.columns)
    if absolute:
        C = C.abs()
    return C


# -------------------------- High-level builder --------------------------

def build_ts_ct_correlations(
    tissue_names: List[str],
    tissue_files: List[str],
    sd_quantile: float = 0.0,
    max_genes_per_tissue: Optional[int] = 5000,
    cor_method: str = "pearson",
    donor_regex: str = r"^([^-]+-[^-]+)",
    aggregate_duplicates: bool = False,
    ct_mode: str = "pairwise",
    rename_gene_columns: bool = False,
    verbose: bool = True,
    show_progress: bool = True,
    return_stats: bool = False,
) -> (
    Tuple[
        Dict[str, pd.DataFrame],
        Dict[str, pd.DataFrame],
        Dict[str, Tuple[pd.DataFrame, pd.DataFrame]],
        Dict[str, pd.DataFrame],
    ]
    | Tuple[
        Dict[str, pd.DataFrame],
        Dict[str, pd.DataFrame],
        Dict[str, Tuple[pd.DataFrame, pd.DataFrame]],
        Dict[str, pd.DataFrame],
        Dict[str, dict],
    ]
):
    """
    Load, filter, optionally aggregate by donor, and compute:
      - TS_expr: per-tissue expression (after filters, donor-aggregation optional)
      - TS_corrs: per-tissue |gene×gene| correlation matrices
      - CT_expr: per pair (Ti||Tj) donor-aligned expression frames (Mi, Mj)
      - CT_corrs: per pair (Ti||Tj) |cross gene×gene| correlations

    Parameters
    ----------
    tissue_names, tissue_files : parallel lists; names must be unique because they
        key every output dict and form the "Ti||Tj" pair keys.
    sd_quantile, max_genes_per_tissue : per-tissue gene filter (filter_genes_by_sd).
    cor_method : 'pearson' or 'spearman' for both TS and CT correlations.
    donor_regex : donor-ID pattern, used only when aggregate_duplicates=True.
    aggregate_duplicates : average rows sharing a donor ID within each tissue.
    ct_mode : passed to corr_cross_tissue as `mode`.
    rename_gene_columns : prefix gene columns with "<tissue>_".
    verbose, show_progress : logging / rich progress bars.
    return_stats : also return a timing/size stats dict.

    Returns
    -------
    TS_expr, TS_corrs, CT_expr, CT_corrs            (return_stats=False)
    TS_expr, TS_corrs, CT_expr, CT_corrs, stats     (return_stats=True)
    where:
      - TS_expr[tissue] : donors×genes
      - TS_corrs[tissue]: genes×genes
      - CT_expr[pair]   : tuple (Mi, Mj) aligned on donors
      - CT_corrs[pair]  : genes_i × genes_j
    Pairs with fewer than 3 common donors get empty (0-row) CT_expr frames and an
    all-NaN CT_corrs frame, and are marked "skipped" in stats.

    Raises
    ------
    AssertionError
        If tissue_names and tissue_files differ in length.
    ValueError
        If tissue_names contains duplicates.
    """
    assert len(tissue_names) == len(tissue_files), "names/files length mismatch"
    # Duplicate names would silently overwrite each other in every dict below
    # and produce a "T||T" pair that correlates a tissue with itself.
    seen: set = set()
    duplicates: List[str] = []
    for t in tissue_names:
        if t in seen and t not in duplicates:
            duplicates.append(t)
        seen.add(t)
    if duplicates:
        raise ValueError(f"tissue_names must be unique; duplicated: {duplicates}")
    T = len(tissue_names)

    stats = {
        "tissue": {},
        "ct_pairs": {},
        "totals": {"ts_total_time_s": 0.0, "ct_total_time_s": 0.0}
    }

    # 1) Load + filter
    raw_expr: Dict[str, pd.DataFrame] = {}
    itr = zip(tissue_names, tissue_files)
    if show_progress:
        itr = track(list(itr), description="Loading & filtering tissues...")

    for name, path in itr:
        t0 = time.perf_counter()
        if verbose:
            _log(f"[bold cyan] Loading[/] {name}: {path}")

        X = load_expr_csv(path)
        n_samples, n_genes = X.shape
        if verbose:
            _log(f"[bold cyan] Loaded[/] {name}: {path} ({n_samples} samples, {n_genes} genes)")

        X = filter_genes_by_sd(X, sd_quantile=sd_quantile, max_genes_per_tissue=max_genes_per_tissue)
        n_samples2, n_genes2 = X.shape
        if verbose:
            _log(f"[bold cyan] Filtered[/] {name}: {path} ({n_samples2} samples, {n_genes2} genes)")

        if rename_gene_columns:
            X = prefix_gene_columns(X, name)
        raw_expr[name] = X
        stats["tissue"][name] = {
            "n_samples_raw": int(n_samples),
            "n_genes_raw": int(n_genes),
            "n_genes_kept": int(n_genes2),
            "load_filter_time_s": round(time.perf_counter() - t0, 3)
        }

    # 2) Aggregate duplicates per donor if requested
    TS_expr: Dict[str, pd.DataFrame] = {}
    itr2 = raw_expr.items()
    if show_progress:
        itr2 = track(list(itr2), description="Aggregating duplicates...")

    for name, X in itr2:
        t0 = time.perf_counter()
        if aggregate_duplicates:
            donors = extract_donor_ids(X.index, regex=donor_regex)
            n_rows_before = X.shape[0]
            Xd = aggregate_by_donor(X, donors)
            n_rows_after = Xd.shape[0]
            if verbose:
                _log(f"[bold cyan] Aggregated[/] {name}: {n_rows_before} -> {n_rows_after} rows")
        else:
            Xd = X
            if verbose:
                _log(f"[bold cyan] Skipped aggregation[/] {name}: {Xd.shape[0]} rows")

        TS_expr[name] = Xd
        stats["tissue"][name].update({
            "n_donors": int(Xd.shape[0]),
            "n_genes_pos": int(Xd.shape[1]),
            "aggregate_time_s": round(time.perf_counter() - t0, 3)
        })

    # 3) Within-tissue correlations
    TS_corrs: Dict[str, pd.DataFrame] = {}
    itr3 = TS_expr.items()
    if show_progress:
        itr3 = track(list(itr3), description="Computing TS correlations")

    ts_total_t0 = time.perf_counter()
    for name, Xd in itr3:
        t0 = time.perf_counter()
        C = corr_within_tissue(Xd, method=cor_method, absolute=True)
        TS_corrs[name] = C
        if verbose:
            _log(f"[bold green]TS corr[/] {name}: genes={C.shape[0]:,} time={time.perf_counter() - t0:.2f}s")
        stats["tissue"][name]["ts_corr_time_s"] = round(time.perf_counter() - t0, 3)
    stats["totals"]["ts_total_time_s"] = round(time.perf_counter() - ts_total_t0, 3)

    # 4) Cross-tissue donor alignment + correlations
    CT_expr: Dict[str, Tuple[pd.DataFrame, pd.DataFrame]] = {}
    CT_corrs: Dict[str, pd.DataFrame] = {}

    if T >= 2:
        # Same order as combinations(range(T), 2) over the names
        pairs = list(combinations(tissue_names, 2))
        itr4 = track(pairs, description="Computing CT correlations") if show_progress else pairs
        ct_total_t0 = time.perf_counter()

        for ni, nj in itr4:
            A = TS_expr[ni]
            B = TS_expr[nj]
            common = A.index.intersection(B.index)
            key = f"{ni}||{nj}"
            if verbose:
                _log(f"[bold yellow]CT pair[/] {key}: common donors={len(common):,}")

            if len(common) < 3:
                # Too few shared donors: keep the pair with empty frames / NaN corrs
                CT_expr[key] = (A.iloc[:0], B.iloc[:0])
                CT_corrs[key] = pd.DataFrame(index=A.columns, columns=B.columns, dtype=float)
                stats["ct_pairs"][key] = {
                    "n_common_donors": int(len(common)),
                    "n_genes_A": int(A.shape[1]),
                    "n_genes_B": int(B.shape[1]),
                    "ct_corr_time_s": 0.0,
                    "skipped": True
                }
                continue

            Mi = A.loc[common]
            Mj = B.loc[common]
            CT_expr[key] = (Mi, Mj)
            t0 = time.perf_counter()
            C = corr_cross_tissue(Mi, Mj, method=cor_method, absolute=True, mode=ct_mode)
            CT_corrs[key] = C
            stats["ct_pairs"][key] = {
                "n_common_donors": int(len(common)),
                "n_genes_A": int(Mi.shape[1]),
                "n_genes_B": int(Mj.shape[1]),
                "ct_corr_time_s": round(time.perf_counter() - t0, 3),
                "skipped": False
            }
            if verbose:
                _log(f"  CT corr {key}: {C.shape[0]:,}×{C.shape[1]:,} time={stats['ct_pairs'][key]['ct_corr_time_s']:.2f}s")

        stats["totals"]["ct_total_time_s"] = round(time.perf_counter() - ct_total_t0, 3)

    if verbose:
        _log(f"[bold white on blue]Totals[/] TS={stats['totals']['ts_total_time_s']:.2f}s, CT={stats['totals']['ct_total_time_s']:.2f}s")

    # Preserve original return signature unless return_stats=True
    if return_stats:
        return TS_expr, TS_corrs, CT_expr, CT_corrs, stats
    return TS_expr, TS_corrs, CT_expr, CT_corrs

# -------------------------- Optional: adjacency from correlations --------------------------

def correlations_to_adjacency(
    TS_corrs: Mapping[str, pd.DataFrame],
    CT_corrs: Mapping[str, pd.DataFrame],
    TS_power_map: Optional[Mapping[str, float]] = None,
    CT_power_map: Optional[Mapping[str, float]] = None,
    default_TS: float = 6.0,
    default_CT: float = 3.0,
) -> Tuple[pd.DataFrame, Dict[str, slice]]:
    """Assemble a full block adjacency (like your R function) from TS/CT correlations.

    Layout: tissues are stacked in the iteration order of `TS_corrs`. Each tissue's
    genes occupy one contiguous slice, in the row order of its TS matrix.

    - TS block of tissue t : TS_corrs[t] ** TS_power_map.get(t, default_TS)
    - CT block "Ti||Tj"    : CT_corrs[key] ** beta, written at (Ti rows, Tj cols),
      and its transpose at (Tj rows, Ti cols). beta is looked up under "Ti||Tj",
      then "Tj||Ti", then falls back to `default_CT`.
    - All other entries (e.g. pairs absent from CT_corrs, or empty CT frames) are 0.

    Blocks are written by position, not by label. The same gene symbol may appear
    in several tissues (the default when gene columns are not tissue-prefixed)
    without one tissue's block overwriting another's. In that case the returned
    frame has duplicate labels; address tissues through `gene_blocks`.

    Returns
    -------
    (A, gene_blocks) where `gene_blocks[tissue]` is the positional slice of that
    tissue's genes in A.

    Raises
    ------
    KeyError
        If a CT key names a tissue missing from TS_corrs, or a CT row/column gene
        is not among that tissue's TS genes.
    """
    tissues = list(TS_corrs.keys())
    ts_beta = TS_power_map or {}
    ct_beta = CT_power_map or {}

    # Build index layout
    gene_order: List[str] = []
    blocks: Dict[str, slice] = {}
    start = 0
    for t in tissues:
        n = TS_corrs[t].shape[0]
        blocks[t] = slice(start, start + n)
        gene_order.extend(TS_corrs[t].index.tolist())
        start += n

    adj = np.zeros((start, start), dtype=float)

    # TS blocks: align columns to the row order so the block lands on the diagonal
    for t in tissues:
        C = TS_corrs[t]
        beta = ts_beta.get(t, default_TS)
        sl = blocks[t]
        adj[sl, sl] = np.power(C.loc[:, C.index].to_numpy(dtype=float), beta)

    def _positions(tissue: str, labels: pd.Index) -> np.ndarray:
        # Absolute positions in `adj` of `labels` inside `tissue`'s block.
        local = TS_corrs[tissue].index.get_indexer(labels)
        if (local < 0).any():
            missing = list(pd.Index(labels)[local < 0][:5])
            raise KeyError(f"genes {missing} not found in TS_corrs[{tissue!r}]")
        return blocks[tissue].start + local

    # CT blocks (symmetric placement)
    for key, C in CT_corrs.items():
        if C.empty:
            continue
        ti, tj = key.split("||", 1)
        beta = ct_beta.get(key, ct_beta.get(f"{tj}||{ti}", default_CT))
        rows = _positions(ti, C.index)
        cols = _positions(tj, C.columns)
        vals = np.power(C.to_numpy(dtype=float), beta)
        adj[np.ix_(rows, cols)] = vals
        adj[np.ix_(cols, rows)] = vals.T

    A = pd.DataFrame(adj, index=gene_order, columns=gene_order)
    return A, blocks



In [3]:
"""
Animate within-tissue (TS) and cross-tissue (CT) correlation DISTRIBUTIONS
as a function of WGCNA power (beta), using Plotly.

Designed to plug into your existing pipeline that produces TS_corrs and CT_corrs
from `build_ts_ct_correlations(...)` you shared.

Key idea:
- We already have |corr| in [0,1]. WGCNA adjacency is a(|corr|) = |corr|**beta.
- We sample (to keep it fast/memory‑safe), then for each beta compute a histogram
  of the transformed values. We render an animated histogram with a slider.

Outputs:
- Two interactive HTML files saved to disk:
  1) <out_html_prefix>__TS.html   (one facet per tissue)
  2) <out_html_prefix>__CT.html   (one facet per tissue‑pair)

Usage (minimal):

    tissues = ["Adipose", "Muscle", "Brain"]
    files   = [
        "/Users/edeneldar/CoExpression_ReProduction/old_outputs/Adipose - Subcutaneous_old.csv",
        "/Users/edeneldar/CoExpression_ReProduction/old_outputs/Muscle - Skeletal_old.csv",
        "/Users/edeneldar/CoExpression_ReProduction/old_outputs/Brain - Cortex_old.csv",
    ]
    TS_expr, TS_corrs, CT_expr, CT_corrs = build_ts_ct_correlations(
        tissue_names=tissues,
        tissue_files=files,
        sd_quantile=0.0,
        max_genes_per_tissue=5000,
        cor_method="pearson",
        ct_mode="pairwise",
        show_progress=False
    )

    from animate_corr_distributions import animate_ts_ct_distributions
    out_TS, out_CT = animate_ts_ct_distributions(
        TS_corrs, CT_corrs,
        betas=list(range(1, 21)),     # 1..20
        sample_per_group=200_000,     # downsample per tissue / pair (keeps files light)
        bins=40,
        density=True,
        seed=42,
        out_html_prefix="corr_beta_anim"
    )
    print("Saved:", out_TS, out_CT)

Notes:
- CT matrices can be huge (|Gi|×|Gj|). Sampling is **strongly** recommended.
- Histogram bin edges are fixed across all frames, so bars stay comparable between β values (the y‑axis range is not fixed and may still rescale from frame to frame).
- If your corr matrices are *not* absolute, set `ABS_INPUT=False` below.

"""
from __future__ import annotations
from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
from pathlib import Path
import numpy as np
import pandas as pd

import plotly.graph_objects as go
from plotly.subplots import make_subplots

# ---------------------------- Config ----------------------------
ABS_INPUT = True  # True: matrices already hold |corr| (build_ts_ct_correlations passes absolute=True); set False for signed matrices so flatten_* applies np.abs

# --------------------- Helpers: flatten & sample -----------------

def _rng(seed: Optional[int]) -> np.random.Generator:
    return np.random.default_rng(seed) if seed is not None else np.random.default_rng()


def flatten_upper_triangle(C: pd.DataFrame, overshoot_tol: float = 1e-9) -> np.ndarray:
    """Return the strict upper triangle (i<j) of a square correlation matrix as a vector.

    When ABS_INPUT is False the values go through abs(); otherwise they are assumed
    to already be |corr|. Non-finite entries are dropped. Values above 1 by at most
    `overshoot_tol` are floating-point rounding (e.g. 1.0000000000000002 for
    identical genes) and are snapped to 1.0 rather than discarded. Any other value
    outside [0, 1] is dropped. Returns float64.
    """
    m = C.to_numpy(dtype=np.float64)
    iu = np.triu_indices(m.shape[0], k=1)
    v = m[iu]
    if not ABS_INPUT:
        v = np.abs(v)
    v = v[np.isfinite(v)]
    v = np.where((v > 1.0) & (v <= 1.0 + overshoot_tol), 1.0, v)
    # Keep in [0,1]
    return v[(v >= 0.0) & (v <= 1.0)]


def flatten_rect(C: pd.DataFrame, overshoot_tol: float = 1e-9) -> np.ndarray:
    """Return **all** values of a rectangular CT matrix as a float64 vector.

    When ABS_INPUT is False the values go through abs() first. Non-finite entries
    are dropped. Values above 1 by at most `overshoot_tol` (floating-point rounding)
    are clipped to 1.0. Any other value outside [0, 1] is dropped.
    """
    v = C.to_numpy(dtype=np.float64).ravel()
    if not ABS_INPUT:
        v = np.abs(v)
    v = v[np.isfinite(v)]
    v = np.where((v > 1.0) & (v <= 1.0 + overshoot_tol), 1.0, v)
    return v[(v >= 0.0) & (v <= 1.0)]


def sample_vec(v: np.ndarray, max_n: Optional[int], seed: Optional[int]) -> np.ndarray:
    """Randomly keep at most `max_n` elements of `v`, sampling without replacement.

    `v` is returned unchanged (same object) when `max_n` is None or `v` is already
    small enough. `seed` fixes the draw; None uses fresh OS entropy.
    """
    needs_sampling = max_n is not None and v.size > max_n
    if not needs_sampling:
        return v
    picks = np.random.default_rng(seed).choice(v.size, size=max_n, replace=False)
    return v[picks]

# ----------------------- Histograms per beta ---------------------

def compute_histograms_for_betas(
    values_by_group: Mapping[str, np.ndarray],
    betas: Sequence[float],
    bins: int = 40,
    density: bool = True,
) -> Tuple[Dict[float, Dict[str, np.ndarray]], np.ndarray]:
    """Histogram `values ** beta` for every (beta, group) over shared bins on [0, 1].

    Returns
    -------
    frame_counts : dict[beta -> dict[group -> bar heights (float, length `bins`)]].
                   Heights are counts, or a probability density
                   (count / (n * bin_width)) when `density` is True. Empty groups
                   yield all-zero bars.
    bin_edges    : the `bins + 1` evenly spaced edges shared by every frame.
    """
    edges = np.linspace(0.0, 1.0, bins + 1)
    widths = np.diff(edges)

    result: Dict[float, Dict[str, np.ndarray]] = {}
    for beta in betas:
        per_group: Dict[str, np.ndarray] = {}
        for group, vals in values_by_group.items():
            if vals.size == 0:
                per_group[group] = np.zeros(bins, dtype=float)
                continue
            transformed = np.power(vals, beta)
            hist, _ = np.histogram(transformed, bins=edges, range=(0.0, 1.0))
            hist = hist.astype(float)
            if density:
                hist = hist / (transformed.size * widths)
            per_group[group] = hist
        result[beta] = per_group
    return result, edges

# ------------------- Animated subplot figure --------------------

def make_animated_hist_subplots(
    frame_counts: Mapping[float, Mapping[str, np.ndarray]],
    bin_edges: np.ndarray,
    subplot_titles: Sequence[str],
    title: str,
    height: int = 400,
    y_label: str = "density",
) -> go.Figure:
    """Create a 1-row subplot figure, one bar histogram per group, animated over beta.

    Parameters
    ----------
    frame_counts : beta -> group -> bar heights (len(bin_edges) - 1 values). A group
        missing from a frame is drawn as all-zero bars.
    bin_edges : shared histogram bin edges; bars are placed at the bin centres.
    subplot_titles : group names; one subplot per name, in this order.
    title : figure title.
    height : figure height in pixels.
    y_label : y-axis title. Defaults to "density" (the previously hard-coded label);
        pass "count" when the heights are raw counts.

    Raises
    ------
    ValueError
        If `frame_counts` contains no frames.
    """
    if not frame_counts:
        raise ValueError("frame_counts is empty: need at least one beta to animate")
    groups = list(subplot_titles)
    n = len(groups)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.0
    zeros = np.zeros_like(bin_centers)

    fig = make_subplots(rows=1, cols=n, shared_yaxes=True, horizontal_spacing=0.06,
                        subplot_titles=tuple(groups))

    # Initial traces at the smallest beta
    betas_sorted = sorted(frame_counts.keys(), key=float)
    b0 = betas_sorted[0]
    for j, g in enumerate(groups, start=1):
        fig.add_trace(
            go.Bar(x=bin_centers, y=frame_counts[b0].get(g, zeros), name=g, showlegend=False),
            row=1, col=j
        )

    # One frame per beta; `traces` pins frame bar k onto subplot trace k
    frames = [
        go.Frame(
            name=f"beta={b}",
            data=[go.Bar(x=bin_centers, y=frame_counts[b].get(g, zeros)) for g in groups],
            traces=list(range(n)),
        )
        for b in betas_sorted
    ]
    fig.update(frames=frames)

    # Slider & play controls
    steps = [
        {
            "args": [[f"beta={b}"], {"frame": {"duration": 0, "redraw": True},
                                      "mode": "immediate", "transition": {"duration": 0}}],
            "label": str(b),
            "method": "animate",
        }
        for b in betas_sorted
    ]

    fig.update_layout(
        title=title,
        height=height,
        bargap=0.05,
        xaxis_title="adjacency = |corr|^beta",
        yaxis_title=y_label,
        updatemenus=[{
            "type": "buttons",
            "showactive": True,
            "x": 1.05,
            "y": 1.15,
            "xanchor": "right",
            "yanchor": "top",
            "buttons": [
                {"label": "▶ Play", "method": "animate",
                 "args": [None, {"frame": {"duration": 300, "redraw": True},
                                   "transition": {"duration": 0},
                                   "fromcurrent": True}]},
                {"label": "⏸ Pause", "method": "animate",
                 "args": [[None], {"frame": {"duration": 0, "redraw": False},
                                    "mode": "immediate"}]}
            ]
        }],
        sliders=[{
            "active": 0,
            "y": -0.08,
            "x": 0.5,
            "len": 0.9,
            "xanchor": "center",
            "yanchor": "top",
            "pad": {"b": 10, "t": 30},
            "steps": steps
        }]
    )

    # Lock axes to [0,1] on x for all subplots
    for i in range(n):
        fig.update_xaxes(range=[0, 1], row=1, col=i + 1)
    return fig

# ------------------- Public API: main function -------------------

def animate_ts_ct_distributions(
    TS_corrs: Mapping[str, pd.DataFrame],
    CT_corrs: Mapping[str, pd.DataFrame],
    betas: Sequence[float] = tuple(range(1, 21)),
    sample_per_group: Optional[int] = 200_000,
    bins: int = 40,
    density: bool = True,
    seed: Optional[int] = 0,
    out_html_prefix: str = "corr_beta_anim",
    height: int = 420,
) -> Tuple[str, str]:
    """Create two animated Plotly histograms (TS and CT) over beta.

    Parameters
    ----------
    TS_corrs : dict[tissue -> square (genes×genes) DataFrame]
    CT_corrs : dict["Ti||Tj" -> rectangular (Gi×Gj) DataFrame]
    betas : iterable of beta values (>=1 recommended); materialized once, so a
            generator is fine
    sample_per_group : downsample limit per tissue/pair (None = no limit)
    bins : histogram bins
    density : normalize to probability density (True) or raw counts (False);
              the y-axis title follows this choice
    seed : RNG seed for reproducible sampling
    out_html_prefix : file prefix for saved HTML files; missing parent directories
                      are created
    height : figure height in pixels

    Returns
    -------
    (ts_html_path, ct_html_path)

    Raises
    ------
    ValueError
        If `betas` is empty.
    """
    # Materialize once: betas is iterated in several places below.
    betas = list(betas)
    if not betas:
        raise ValueError("betas must contain at least one value")
    rng = _rng(seed)
    y_label = "density" if density else "count"

    def _collect(mats: Mapping[str, pd.DataFrame], flatten) -> Dict[str, np.ndarray]:
        # Flatten + downsample each matrix. One seed is drawn per non-empty matrix,
        # in mapping order (TS first, then CT), so results stay reproducible.
        values: Dict[str, np.ndarray] = {}
        for key, C in mats.items():
            if C is None or C.size == 0:
                values[key] = np.array([], dtype=float)
                continue
            v = flatten(C)
            values[key] = sample_vec(v, sample_per_group, seed=rng.integers(0, 2**31 - 1))
        return values

    def _figure(values: Dict[str, np.ndarray], placeholder: str, title: str) -> go.Figure:
        # Histograms over betas -> animated subplot figure (placeholder if no groups)
        frame_counts, edges = compute_histograms_for_betas(
            values, betas=betas, bins=bins, density=density
        )
        groups = list(values.keys())
        if len(groups) == 0:
            groups = [placeholder]
            frame_counts = {b: {placeholder: np.zeros(bins)} for b in betas}
        fig = make_animated_hist_subplots(
            frame_counts, edges, subplot_titles=groups, title=title, height=height
        )
        fig.update_yaxes(title_text=y_label, row=1, col=1)
        return fig

    ts_values = _collect(TS_corrs, flatten_upper_triangle)
    ct_values = _collect(CT_corrs, flatten_rect)

    fig_TS = _figure(ts_values, "No TS", "Within‑tissue adjacency distribution vs β")
    fig_CT = _figure(ct_values, "No CT", "Cross‑tissue adjacency distribution vs β")

    # Save HTML files
    out_TS = f"{out_html_prefix}__TS.html"
    out_CT = f"{out_html_prefix}__CT.html"
    Path(out_TS).parent.mkdir(parents=True, exist_ok=True)
    fig_TS.write_html(out_TS, include_plotlyjs="cdn", auto_play=False)
    fig_CT.write_html(out_CT, include_plotlyjs="cdn", auto_play=False)

    return out_TS, out_CT


# ------------------- Demo (run manually) -------------------
# Kept in a function rather than at cell top level: in a notebook __name__ is
# "__main__", so a script guard would write the demo HTML files on every run.
def _demo_animate_ts_ct_distributions() -> Tuple[str, str]:
    """Self-test with random small matrices (no real data); writes two demo HTML files."""
    rng = np.random.default_rng(0)
    genesA, genesB, genesC = 300, 250, 280
    donors = 80

    def fake_corr(n):
        X = rng.standard_normal((donors, n))
        C = np.corrcoef(X, rowvar=False)
        return pd.DataFrame(np.abs(C),
                            index=[f"g{i}" for i in range(n)],
                            columns=[f"g{i}" for i in range(n)])

    def fake_ct(p, q):
        A = rng.standard_normal((donors, p))
        B = rng.standard_normal((donors, q))
        C = np.corrcoef(A, B, rowvar=False)
        C = C[:p, p:]
        return pd.DataFrame(np.abs(C),
                            index=[f"A{i}" for i in range(p)],
                            columns=[f"B{j}" for j in range(q)])

    TS_corrs = {"Adipose": fake_corr(genesA),
                "Muscle": fake_corr(genesB),
                "Brain":  fake_corr(genesC)}
    CT_corrs = {"Adipose||Muscle": fake_ct(genesA, genesB),
                "Adipose||Brain":  fake_ct(genesA, genesC),
                "Muscle||Brain":   fake_ct(genesB, genesC)}

    out_TS, out_CT = animate_ts_ct_distributions(
        TS_corrs, CT_corrs,
        betas=list(range(1, 11)),
        sample_per_group=50_000,
        bins=30,
        density=True,
        seed=123,
        out_html_prefix="_demo_corr_beta_anim",
        height=420,
    )
    print("Wrote:", out_TS, out_CT)
    return out_TS, out_CT


In [5]:
"""
Animate within-tissue (TS) and cross-tissue (CT) correlation DISTRIBUTIONS
as a function of WGCNA power (beta), using Plotly.

Designed to plug into your existing pipeline that produces TS_corrs and CT_corrs
from `build_ts_ct_correlations(...)` you shared.

Key idea:
- We already have |corr| in [0,1]. WGCNA adjacency is a(|corr|) = |corr|**beta.
- We sample (to keep it fast/memory‑safe), then for each beta compute a histogram
  of the transformed values. We render an animated histogram with a slider.

Outputs:
- Two interactive HTML files saved to disk:
  1) <out_html_prefix>__TS.html   (one facet per tissue)
  2) <out_html_prefix>__CT.html   (one facet per tissue‑pair)

Usage (minimal):

    tissues = ["Adipose", "Muscle", "Brain"]
    files   = [
        "/Users/edeneldar/CoExpression_ReProduction/old_outputs/Adipose - Subcutaneous_old.csv",
        "/Users/edeneldar/CoExpression_ReProduction/old_outputs/Muscle - Skeletal_old.csv",
        "/Users/edeneldar/CoExpression_ReProduction/old_outputs/Brain - Cortex_old.csv",
    ]
    TS_expr, TS_corrs, CT_expr, CT_corrs = build_ts_ct_correlations(
        tissue_names=tissues,
        tissue_files=files,
        sd_quantile=0.0,
        max_genes_per_tissue=5000,
        cor_method="pearson",
        ct_mode="pairwise",
        show_progress=False
    )

    from animate_corr_distributions import animate_ts_ct_distributions
    out_TS, out_CT = animate_ts_ct_distributions(
        TS_corrs, CT_corrs,
        betas=list(range(1, 21)),     # 1..20
        sample_per_group=200_000,     # downsample per tissue / pair (keeps files light)
        bins=40,
        density=True,
        seed=42,
        out_html_prefix="corr_beta_anim"
    )
    print("Saved:", out_TS, out_CT)

Notes:
- CT matrices can be huge (|Gi|×|Gj|). Sampling is **strongly** recommended.
- Histogram bin edges are fixed across all frames, so bars stay comparable between β values (the y‑axis range is not fixed and may still rescale from frame to frame).
- If your corr matrices are *not* absolute, set `ABS_INPUT=False` below.

"""
from __future__ import annotations
from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Tuple
from pathlib import Path
import numpy as np
import pandas as pd

import plotly.graph_objects as go
from plotly.subplots import make_subplots

# ---------------------------- Config ----------------------------
# True  -> TS/CT matrices already hold |corr|, used as-is when flattening.
# False -> matrices are signed; the flatten helpers apply np.abs first.
ABS_INPUT: bool = True

# --------------------- Helpers: flatten & sample -----------------

def _rng(seed: Optional[int]) -> np.random.Generator:
    return np.random.default_rng(seed) if seed is not None else np.random.default_rng()


def flatten_upper_triangle(C: pd.DataFrame, atol: float = 1e-9) -> np.ndarray:
    """Return the strict upper triangle (i < j) of a square correlation matrix as a 1-D vector.

    - Non-finite entries (NaN / inf) are dropped.
    - If ``ABS_INPUT`` is False, ``np.abs`` is applied first. If True, ``C`` is
      assumed to already hold |corr|, and any negative entry is dropped.
    - Values in ``(1, 1 + atol]`` are snapped to 1.0. These are floating-point
      round-off of a perfect correlation and would otherwise be silently lost.
    - Anything else outside [0, 1] is dropped.

    Parameters
    ----------
    C : square DataFrame (genes x genes); only its first ``C.shape[0]`` rows and
        columns are read.
    atol : tolerance above 1.0 that is treated as round-off (default 1e-9).

    Returns
    -------
    float64 vector with every value in [0, 1].
    """
    m = C.to_numpy(copy=False)
    v = m[np.triu_indices(m.shape[0], k=1)]
    if not ABS_INPUT:
        v = np.abs(v)
    v = v[np.isfinite(v)]
    # Keep [0, 1 + atol], then fold the round-off sliver back onto 1.0.
    v = v[(v >= 0.0) & (v <= 1.0 + atol)]
    return np.minimum(v, 1.0).astype(np.float64, copy=False)


def flatten_rect(C: pd.DataFrame, atol: float = 1e-9) -> np.ndarray:
    """Return **all** values of a rectangular CT correlation matrix (Gi x Gj) as a 1-D vector.

    - Non-finite entries (NaN / inf) are dropped.
    - If ``ABS_INPUT`` is False, ``np.abs`` is applied first. If True, ``C`` is
      assumed to already hold |corr|, and negative entries are dropped rather
      than clipped.
    - Values in ``(1, 1 + atol]`` (floating-point round-off) are snapped to 1.0.
    - Any other value outside [0, 1] is dropped.

    Parameters
    ----------
    C : rectangular DataFrame of cross-tissue correlations.
    atol : tolerance above 1.0 that is treated as round-off (default 1e-9).

    Returns
    -------
    float64 vector with every value in [0, 1].
    """
    v = C.to_numpy(copy=False).ravel()
    if not ABS_INPUT:
        v = np.abs(v)
    v = v[np.isfinite(v)]
    v = v[(v >= 0.0) & (v <= 1.0 + atol)]
    return np.minimum(v, 1.0).astype(np.float64, copy=False)


def sample_vec(v: np.ndarray, max_n: Optional[int], seed: Optional[int]) -> np.ndarray:
    """Randomly subsample `v` without replacement down to at most `max_n` entries.

    `v` itself (not a copy) is returned when `max_n` is None or `v` already fits.
    Otherwise `max_n` distinct positions are drawn with a generator seeded by
    `seed`; a `seed` of None means fresh OS entropy.
    """
    keep_everything = max_n is None or v.size <= max_n
    if keep_everything:
        return v
    picks = np.random.default_rng(seed).choice(v.size, size=max_n, replace=False)
    return v[picks]

# ----------------------- Histograms per beta ---------------------

def compute_histograms_for_betas(
    values_by_group: Mapping[str, np.ndarray],
    betas: Sequence[float],
    bins: int = 40,
    density: bool = True,
) -> Tuple[Dict[float, Dict[str, np.ndarray]], np.ndarray]:
    """Histogram ``v ** beta`` for every (beta, group) pair on shared [0, 1] bins.

    Every frame uses the same ``bins`` equal-width edges over [0, 1], so frames
    are directly comparable. A group with no values yields an all-zero row.

    Returns
    -------
    frame_counts : dict[beta -> dict[group -> float heights (len == bins)]]
        Heights are raw counts, or probability densities
        (count / (n * bin_width)) when ``density`` is True.
    bin_edges : np.ndarray of ``bins + 1`` edges spanning [0, 1].
    """
    edges = np.linspace(0.0, 1.0, bins + 1)
    widths = np.diff(edges)

    def _heights(vec: np.ndarray, beta: float) -> np.ndarray:
        if vec.size == 0:
            return np.zeros(bins, dtype=float)
        transformed = np.power(vec, beta)
        raw, _ = np.histogram(transformed, bins=edges, range=(0.0, 1.0))
        out = raw.astype(float)
        if density:
            out = out / (transformed.size * widths)
        return out.astype(float)

    per_beta = {
        beta: {name: _heights(vec, beta) for name, vec in values_by_group.items()}
        for beta in betas
    }
    return per_beta, edges

# ------------------- Animated subplot figure --------------------

def make_animated_hist_subplots(
    frame_counts: Mapping[float, Mapping[str, np.ndarray]],
    bin_edges: np.ndarray,
    subplot_titles: Sequence[str],
    title: str,
    height: int = 400,
    density: bool = True,
) -> go.Figure:
    """Build a 1 x N subplot figure (one bar histogram per group) animated over beta.

    Each key of ``frame_counts`` becomes one animation frame named ``"beta=<b>"``.
    Frames are ordered by ``float(beta)``. Inside each frame, trace ``j`` is the
    group ``subplot_titles[j]`` and is drawn in subplot ``j + 1``. A group that is
    missing from a frame is drawn as all-zero bars.

    Parameters
    ----------
    frame_counts : dict[beta -> dict[group -> per-bin heights]]
    bin_edges : histogram edges (len = n_bins + 1); bars sit at bin centers.
    subplot_titles : group names; one subplot per name, also used as lookup keys.
    title : figure title.
    height : figure height in pixels.
    density : only selects the y-axis label ("density" when True, "count" otherwise).
        Defaults to True, which keeps the original label.

    Raises
    ------
    ValueError
        If ``subplot_titles`` or ``frame_counts`` is empty.
    """
    groups = list(subplot_titles)
    if not groups:
        raise ValueError("subplot_titles must name at least one group")
    betas_sorted = sorted(frame_counts.keys(), key=float)
    if not betas_sorted:
        raise ValueError("frame_counts must contain at least one beta frame")

    n = len(groups)
    bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2.0
    empty = np.zeros_like(bin_centers)

    palette = ["#1f77b4", "#d62728", "#2ca02c", "#9467bd", "#8c564b",
               "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"]
    edge_color = "#2b2b2b"
    bar_opacity = 0.95

    def _bar(group: str, y: np.ndarray, col_idx: int) -> go.Bar:
        # Same color per subplot in every frame, so bars keep their identity while animating.
        return go.Bar(
            x=bin_centers,
            y=y,
            name=group,
            showlegend=False,
            marker=dict(color=palette[col_idx % len(palette)],
                        line=dict(color=edge_color, width=1.2)),
            opacity=bar_opacity,
        )

    fig = make_subplots(rows=1, cols=n, shared_yaxes=True, horizontal_spacing=0.06,
                        subplot_titles=tuple(groups))

    # Initial (static) traces show the first beta.
    b0 = betas_sorted[0]
    for j, g in enumerate(groups):
        fig.add_trace(_bar(g, frame_counts[b0].get(g, empty), j), row=1, col=j + 1)

    # One frame per beta. `traces` pins frame data[j] to figure trace j.
    trace_ids = list(range(n))
    frames = [
        go.Frame(
            name=f"beta={b}",
            data=[_bar(g, frame_counts[b].get(g, empty), j) for j, g in enumerate(groups)],
            traces=trace_ids,
        )
        for b in betas_sorted
    ]
    fig.update(frames=frames)

    # Slider: each step jumps straight to its frame (no tweening).
    steps = [
        {
            "args": [[f"beta={b}"], {"frame": {"duration": 0, "redraw": True},
                                      "mode": "immediate", "transition": {"duration": 0}}],
            "label": str(b),
            "method": "animate",
        }
        for b in betas_sorted
    ]

    fig.update_layout(
        title=title,
        height=height,
        template="plotly_white",
        plot_bgcolor="white",
        bargap=0.05,
        xaxis_title="adjacency = |corr|^beta",
        yaxis_title="density" if density else "count",
        updatemenus=[{
            "type": "buttons",
            "showactive": True,
            "x": 1.05,
            "y": 1.15,
            "xanchor": "right",
            "yanchor": "top",
            "buttons": [
                {"label": "▶ Play", "method": "animate",
                 "args": [None, {"frame": {"duration": 300, "redraw": True},
                                   "transition": {"duration": 0},
                                   "fromcurrent": True}]},
                # Plotly's pause idiom: animate to [None] with a zero-length frame.
                {"label": "⏸ Pause", "method": "animate",
                 "args": [[None], {"frame": {"duration": 0, "redraw": False},
                                    "mode": "immediate"}]}
            ]
        }],
        sliders=[{
            "active": 0,
            "y": -0.08,
            "x": 0.5,
            "len": 0.9,
            "xanchor": "center",
            "yanchor": "top",
            "pad": {"b": 10, "t": 30},
            "steps": steps
        }]
    )

    # Lock x to [0, 1] on every subplot (adjacency values live there) and style both axes.
    fig.update_xaxes(range=[0, 1], showline=True, linewidth=1, linecolor="#2b2b2b",
                     gridcolor="#dddddd")
    fig.update_yaxes(showline=True, linewidth=1, linecolor="#2b2b2b",
                     gridcolor="#dddddd")
    return fig

# ------------------- Public API: main function -------------------

def _sampled_corr_values(
    corrs: Mapping[str, pd.DataFrame],
    flatten,
    sample_per_group: Optional[int],
    rng: np.random.Generator,
) -> Dict[str, np.ndarray]:
    """Flatten every matrix in `corrs` with `flatten` and downsample it.

    None/empty matrices map to an empty array and consume no RNG draw. Non-empty
    ones draw one sub-seed from `rng`, in dict order, so sampling stays reproducible.
    """
    values: Dict[str, np.ndarray] = {}
    for key, C in corrs.items():
        if C is None or C.size == 0:
            values[key] = np.array([], dtype=float)
            continue
        v = flatten(C)
        values[key] = sample_vec(v, sample_per_group, seed=rng.integers(0, 2**31 - 1))
    return values


def animate_ts_ct_distributions(
    TS_corrs: Mapping[str, pd.DataFrame],
    CT_corrs: Mapping[str, pd.DataFrame],
    betas: Sequence[float] = tuple(range(1, 21)),
    sample_per_group: Optional[int] = 200_000,
    bins: int = 40,
    density: bool = True,
    seed: Optional[int] = 0,
    out_html_prefix: str = "corr_beta_anim",
    height: int = 420,
) -> Tuple[str, str]:
    """Create two animated Plotly histograms (TS and CT) over beta and save them as HTML.

    Parameters
    ----------
    TS_corrs : dict[tissue -> square (genes x genes) DataFrame]
    CT_corrs : dict["Ti||Tj" -> rectangular (Gi x Gj) DataFrame]
    betas : beta values to animate (>= 1 recommended). Any iterable is accepted.
        It is materialized once, so generators work too.
    sample_per_group : downsample limit per tissue/pair (None = no limit)
    bins : histogram bins
    density : normalize to probability density (True) or raw counts (False);
        also selects the y-axis label.
    seed : RNG seed for reproducible sampling
    out_html_prefix : file prefix for saved HTML files
    height : figure height in pixels

    Returns
    -------
    (ts_html_path, ct_html_path), i.e. "<prefix>__TS.html" and "<prefix>__CT.html".

    Raises
    ------
    ValueError
        If `betas` is empty.
    """
    betas = list(betas)
    if not betas:
        raise ValueError("betas must contain at least one value")

    rng = _rng(seed)
    # TS sub-seeds are drawn before CT sub-seeds (the order matters for reproducibility).
    ts_values = _sampled_corr_values(TS_corrs, flatten_upper_triangle, sample_per_group, rng)
    ct_values = _sampled_corr_values(CT_corrs, flatten_rect, sample_per_group, rng)

    y_label = "density" if density else "count"

    def _figure(values: Dict[str, np.ndarray], empty_label: str, fig_title: str) -> go.Figure:
        frame_counts, edges = compute_histograms_for_betas(
            values, betas=betas, bins=bins, density=density
        )
        groups = list(values.keys())
        if not groups:
            # Keep the figure renderable with a single all-zero placeholder panel.
            groups = [empty_label]
            frame_counts = {b: {empty_label: np.zeros(bins)} for b in betas}
        fig = make_animated_hist_subplots(
            frame_counts, edges, subplot_titles=groups, title=fig_title, height=height
        )
        # Label the y-axis to match the actual normalization.
        fig.update_layout(yaxis_title=y_label)
        return fig

    fig_TS = _figure(ts_values, "No TS", "Within‑tissue adjacency distribution vs β")
    fig_CT = _figure(ct_values, "No CT", "Cross‑tissue adjacency distribution vs β")

    out_TS = f"{out_html_prefix}__TS.html"
    out_CT = f"{out_html_prefix}__CT.html"
    fig_TS.write_html(out_TS, include_plotlyjs="cdn", auto_play=False)
    fig_CT.write_html(out_CT, include_plotlyjs="cdn", auto_play=False)

    return out_TS, out_CT


# ------------------- If run as a script -------------------
if __name__ == "__main__":
    # Smoke test on synthetic data: |corr| matrices computed from independent Gaussian
    # noise, so no real expression files are needed. Writes two demo HTML animations.
    # NOTE: inside a Jupyter kernel `__name__` is "__main__", so this also runs (and
    # writes "_demo_corr_beta_anim__TS/CT.html" to the CWD) every time the cell executes.
    rng = np.random.default_rng(0)
    genesA, genesB, genesC = 300, 250, 280
    donors = 80

    def _gene_labels(prefix, count):
        return [f"{prefix}{k}" for k in range(count)]

    def fake_corr(n):
        """|Pearson corr| among n random genes observed on `donors` samples."""
        noise = rng.standard_normal((donors, n))
        labels = _gene_labels("g", n)
        return pd.DataFrame(np.abs(np.corrcoef(noise, rowvar=False)),
                            index=labels, columns=labels)

    def fake_ct(p, q):
        """|corr| block between p genes of one fake tissue and q genes of another."""
        left = rng.standard_normal((donors, p))
        right = rng.standard_normal((donors, q))
        stacked = np.corrcoef(left, right, rowvar=False)  # (p+q) x (p+q)
        return pd.DataFrame(np.abs(stacked[:p, p:]),
                            index=_gene_labels("A", p),
                            columns=_gene_labels("B", q))

    # Construction order is kept so the RNG stream (and thus the data) is unchanged.
    TS_corrs = {
        "Adipose": fake_corr(genesA),
        "Muscle": fake_corr(genesB),
        "Brain": fake_corr(genesC),
    }
    CT_corrs = {
        "Adipose||Muscle": fake_ct(genesA, genesB),
        "Adipose||Brain": fake_ct(genesA, genesC),
        "Muscle||Brain": fake_ct(genesB, genesC),
    }

    out_TS, out_CT = animate_ts_ct_distributions(
        TS_corrs,
        CT_corrs,
        betas=list(range(1, 11)),
        sample_per_group=50_000,
        bins=30,
        density=True,
        seed=123,
        out_html_prefix="_demo_corr_beta_anim",
        height=420,
    )
    print("Wrote:", out_TS, out_CT)


# ============ NEW ============
# Build a single HTML report with *multiple datasets*, each showing
# TS & CT animated distributions vs beta.
from typing import Callable

def build_multi_dataset_corr_report(
    datasets_tissues: Mapping[str, Sequence[str]],
    datasets_files: Mapping[str, Mapping[str, str] | Sequence[str]],
    *,
    betas: Sequence[float] = tuple(range(1, 21)),
    sample_per_group: Optional[int] = 200_000,
    bins: int = 40,
    density: bool = True,
    seed: Optional[int] = 42,
    sd_quantile: float = 0.0,
    max_genes_per_tissue: Optional[int] = 5000,
    cor_method: str = "pearson",
    ct_mode: str = "pairwise",
    out_html_path: str = "corr_beta_MULTI_REPORT.html",
    height: int = 420,
    builder_fn: Optional[Callable[..., Tuple[Dict[str, pd.DataFrame], Dict[str, pd.DataFrame], Dict[str, Tuple[pd.DataFrame, pd.DataFrame]], Dict[str, pd.DataFrame]]]] = None,
) -> str:
    """Create a single HTML report with animated TS and CT histograms for **multiple datasets**.

    For each dataset, the report:
    - builds correlations with `builder_fn`;
    - samples the TS upper triangles and the CT blocks;
    - histograms |corr|**beta for every beta;
    - embeds two animated Plotly figures in one section.

    A table of contents at the top links to every dataset section.

    Parameters
    ----------
    datasets_tissues : dict[dataset -> [tissue names]]
    datasets_files   : dict[dataset -> list-of-file-paths **aligned** to tissues
                             OR any Mapping tissue->file]
    betas, sample_per_group, bins, density, seed : animation/histogram params
        (`betas` may be any iterable; `density` also selects the y-axis label)
    sd_quantile, max_genes_per_tissue, cor_method, ct_mode : correlation builder params
    out_html_path : output HTML path for the consolidated report
    height : plot height in pixels
    builder_fn : optional function compatible with build_ts_ct_correlations.
                 If None, `build_ts_ct_correlations` is looked up in globals().

    Returns
    -------
    Path to the written HTML report.

    Raises
    ------
    RuntimeError : no builder function is available.
    KeyError     : a dataset has no entry in `datasets_files`, or a tissue has no file.
    ValueError   : an aligned file list has the wrong length.

    Example
    -------
    datasets_tissues = {
        "DatasetA": ["Adipose", "Muscle", "Brain"],
        "DatasetB": ["Adipose", "Muscle", "Brain"],
    }
    datasets_files = {
        # aligned list:
        "DatasetA": ["/path/Adipose.csv", "/path/Muscle.csv", "/path/Brain.csv"],
        # OR mapping per tissue:
        "DatasetB": {"Adipose": "/path/A.csv", "Muscle": "/path/M.csv", "Brain": "/path/B.csv"},
    }
    build_multi_dataset_corr_report(datasets_tissues, datasets_files)
    """
    import datetime as _dt
    import html as _html
    from collections.abc import Mapping as _AbcMapping
    import plotly.io as pio

    if builder_fn is None:
        try:
            builder_fn = globals()["build_ts_ct_correlations"]
        except KeyError:
            raise RuntimeError(
                "build_ts_ct_correlations not found. Pass builder_fn=... or define it in scope.") from None

    # Materialize once: betas is iterated by the histogram step AND by the empty fallbacks.
    betas = list(betas)
    rng = _rng(seed)
    y_label = "density" if density else "count"

    def _files_for(ds: str, tissues: Sequence[str]) -> List[str]:
        """Resolve the file paths for `tissues` (in order) from either supported layout."""
        fobj = datasets_files[ds]
        # Any Mapping (not only dict) is a tissue->file lookup; list() on it would yield keys.
        if isinstance(fobj, _AbcMapping):
            missing = [t for t in tissues if t not in fobj]
            if missing:
                raise KeyError(f"Dataset '{ds}' missing files for tissues: {missing}")
            return [fobj[t] for t in tissues]
        flist = list(fobj)  # type: ignore[arg-type]
        if len(flist) != len(tissues):
            raise ValueError(
                f"Dataset '{ds}': files list length {len(flist)} != tissues length {len(tissues)}")
        return flist

    html_sections: List[str] = []
    toc_items: List[str] = []
    first_fig = True

    for sec_idx, (ds, tissues) in enumerate(datasets_tissues.items(), start=1):
        if ds not in datasets_files:
            raise KeyError(f"Dataset '{ds}' missing in datasets_files")
        tnames = list(tissues)
        tfiles = _files_for(ds, tnames)

        # Compute correlations
        TS_expr, TS_corrs, CT_expr, CT_corrs = builder_fn(
            tissue_names=tnames,
            tissue_files=tfiles,
            sd_quantile=sd_quantile,
            max_genes_per_tissue=max_genes_per_tissue,
            cor_method=cor_method,
            ct_mode=ct_mode,
            show_progress=False,
            verbose=False,
        )

        # Flatten + sample. A sub-seed is drawn for every group (even empty ones),
        # TS before CT, so results stay reproducible for a given `seed`.
        ts_values: Dict[str, np.ndarray] = {}
        for tname, C in TS_corrs.items():
            v = flatten_upper_triangle(C) if (C is not None and C.size) else np.array([], float)
            ts_values[tname] = sample_vec(v, sample_per_group, seed=rng.integers(0, 2**31-1))

        ct_values: Dict[str, np.ndarray] = {}
        for pair, C in CT_corrs.items():
            v = flatten_rect(C) if (C is not None and C.size) else np.array([], float)
            ct_values[pair] = sample_vec(v, sample_per_group, seed=rng.integers(0, 2**31-1))

        # Histograms + figures (placeholder panel when a side has no groups)
        ts_frame_counts, ts_bins = compute_histograms_for_betas(ts_values, betas, bins=bins, density=density)
        ct_frame_counts, ct_bins = compute_histograms_for_betas(ct_values, betas, bins=bins, density=density)

        ts_groups = list(ts_values.keys()) or ["No TS"]
        if not ts_values:
            ts_frame_counts = {b: {"No TS": np.zeros(bins)} for b in betas}
        ct_groups = list(ct_values.keys()) or ["No CT"]
        if not ct_values:
            ct_frame_counts = {b: {"No CT": np.zeros(bins)} for b in betas}

        fig_TS = make_animated_hist_subplots(ts_frame_counts, ts_bins, ts_groups,
                                             title=f"{ds} — Within‑tissue adjacency vs β", height=height)
        fig_CT = make_animated_hist_subplots(ct_frame_counts, ct_bins, ct_groups,
                                             title=f"{ds} — Cross‑tissue adjacency vs β", height=height)
        fig_TS.update_layout(yaxis_title=y_label)
        fig_CT.update_layout(yaxis_title=y_label)

        # Convert to HTML snippets; plotly.js (CDN) is included only with the first figure.
        ts_html = pio.to_html(fig_TS, include_plotlyjs="cdn" if first_fig else False, full_html=False,
                              auto_play=False)
        ct_html = pio.to_html(fig_CT, include_plotlyjs=False, full_html=False, auto_play=False)
        first_fig = False

        # Section with an anchor id so the table of contents can link to it.
        safe_ds = _html.escape(ds)
        anchor = f"sec_{sec_idx}"
        toc_items.append(f"<li><a href='#{anchor}'>{safe_ds}</a></li>")
        html_sections.append(f"""
<section id="{anchor}">
  <h2 style="margin-top:2rem">Dataset: {safe_ds}</h2>
  <div style="margin: 1rem 0">{ts_html}</div>
  <div style="margin: 1rem 0">{ct_html}</div>
</section>
""")

    # Wrap with a simple document shell
    title = "Correlation Distributions vs β — Multi‑Dataset Report"
    when = _dt.datetime.now().strftime("%Y-%m-%d %H:%M")
    toc = "".join(toc_items)
    sections_html = "".join(html_sections)

    body = f"""
<!doctype html>
<html>
<head>
  <meta charset="utf-8"/>
  <meta name="viewport" content="width=device-width, initial-scale=1"/>
  <title>{_html.escape(title)}</title>
  <style>
    body {{ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, 'Fira Sans', 'Droid Sans', 'Helvetica Neue', Arial, sans-serif; margin: 20px; }}
    h1 {{ font-size: 1.6rem; margin-bottom: 0.2rem; }}
    h2 {{ font-size: 1.25rem; }}
    .meta {{ color: #555; font-size: 0.9rem; margin-bottom: 1rem; }}
    nav ul {{ margin: 0.3rem 0 1rem 1.2rem; padding: 0; }}
  </style>
</head>
<body>
  <h1>{_html.escape(title)}</h1>
  <div class="meta">Generated: {when}</div>
  <nav>
    <strong>Datasets</strong>
    <ul>{toc}</ul>
  </nav>
  {sections_html}
</body>
</html>
"""

    Path(out_html_path).write_text(body, encoding="utf-8")
    return str(out_html_path)


Wrote: _demo_corr_beta_anim__TS.html _demo_corr_beta_anim__CT.html


In [7]:

# Minimal usage example: within- and cross-tissue correlations for the "old" GTEx
# cohort (edit OLD_GTEX_DIR to your machine).
OLD_GTEX_DIR = "/Users/edeneldar/CoExpression_ReProduction/old_outputs"
tissues = ["Adipose", "Muscle", "Brain"]
files = [
    f"{OLD_GTEX_DIR}/{stem}_old.csv"
    for stem in ("Adipose - Subcutaneous", "Muscle - Skeletal", "Brain - Cortex")
]

TS_expr, TS_corrs, CT_expr, CT_corrs = build_ts_ct_correlations(
    tissue_names=tissues,
    tissue_files=files,
    sd_quantile=0.0,
    max_genes_per_tissue=5000,
    cor_method="pearson",
    ct_mode="pairwise",  # or 'complete' for speed (requires complete rows)
    show_progress=False,
)
# Optional: assemble adjacency like the R function
# A, blocks = correlations_to_adjacency(TS_corrs, CT_corrs)
# A.to_parquet("adjacency_from_corr.parquet")


In [11]:
# Requires plotly (install with: pip install plotly).
# Animate the TS / CT adjacency distributions for the GTEx correlations above.
out_TS, out_CT = animate_ts_ct_distributions(
    TS_corrs,
    CT_corrs,
    betas=list(range(1, 21)),    # beta values to animate (1..20)
    sample_per_group=200_000,    # values sampled per tissue / tissue pair
    bins=40,                     # number of histogram bins
    density=True,                # probability density rather than raw counts
    seed=42,
    out_html_prefix="corr_beta_anim",
)
print("Saved:", out_TS, out_CT)

Saved: corr_beta_anim__TS.html corr_beta_anim__CT.html


In [8]:
# ROSMAP: three brain regions. max_genes_per_tissue=500000 is presumably meant
# to keep all genes; TODO confirm it exceeds the gene count.
ROSMAP_DATA_DIR = "/media/psylab-6028/DATA/Eden/CoExpression_ReProduction"
tissues = ["AC", "MF_BA9_BA46", "PCG_BA23"]
rosmap_files = [f"{ROSMAP_DATA_DIR}/ROSMAP_fixed_{region}.csv" for region in tissues]

TS_expr, TS_corrs, CT_expr, CT_corrs = build_ts_ct_correlations(
    tissue_names=tissues,
    tissue_files=rosmap_files,
    sd_quantile=0.0,
    max_genes_per_tissue=500000,
    cor_method="pearson",
    ct_mode="pairwise",  # or 'complete' for speed (requires complete rows)
    show_progress=False,
)

In [9]:
# ROSMAP animation: same settings as GTEx, but with a much finer 1000-bin histogram.
out_TS, out_CT = animate_ts_ct_distributions(
    TS_corrs,
    CT_corrs,
    betas=list(range(1, 21)),
    sample_per_group=200_000,
    bins=1000,
    density=True,
    seed=42,
    out_html_prefix="rosmap_corr_beta_anim_linux",
)
print("Saved:", out_TS, out_CT)

Saved: rosmap_corr_beta_anim_linux__TS.html rosmap_corr_beta_anim_linux__CT.html


In [20]:

PROJECT_DIR = "/Users/edeneldar/CoExpression_ReProduction"
GTEX_DIR = f"{PROJECT_DIR}/old_outputs"
# GTEx file-name stem for each short tissue name.
GTEX_FILE_STEMS = {
    "Adipose": "Adipose - Subcutaneous",
    "Muscle": "Muscle - Skeletal",
    "Brain": "Brain - Cortex",
}

datasets_tissues = {
    "YoungGTEx": ["Adipose", "Muscle", "Brain"],
    "OldGTEx": ["Adipose", "Muscle", "Brain"],
    "ROSMAP": ["AC", "MF_BA9_BA46", "PCG_BA23"],
}

# Files may be given either as a list aligned with the tissue order (YoungGTEx)
# or as a {tissue: path} mapping (OldGTEx, ROSMAP).
datasets_files = {
    "YoungGTEx": [f"{GTEX_DIR}/{GTEX_FILE_STEMS[t]}_young.csv" for t in datasets_tissues["YoungGTEx"]],
    "OldGTEx": {t: f"{GTEX_DIR}/{GTEX_FILE_STEMS[t]}_old.csv" for t in datasets_tissues["OldGTEx"]},
    "ROSMAP": {t: f"{PROJECT_DIR}/ROSMAP_fixed_{t}.csv" for t in datasets_tissues["ROSMAP"]},
}

report_path = build_multi_dataset_corr_report(
    datasets_tissues,
    datasets_files,
    betas=list(range(1, 5)),
    sample_per_group=600_000,
    bins=1000,
    density=True,
    seed=42,
    sd_quantile=0.0,
    max_genes_per_tissue=500000,
    cor_method="pearson",
    ct_mode="pairwise",
    out_html_path="corr_beta_MULTI_REPORT_1000_bins.html",
    height=420,
)

print("Wrote:", report_path)


Wrote: corr_beta_MULTI_REPORT_1000_bins.html


In [11]:
# xWGCNA cluster tables for ROSMAP runs 4 and 5 (tab-separated). The paths are
# specific to the Linux workstation.
ROSMAP_DATA_DIR = "/media/psylab-6028/DATA/Eden/CoExpression_ReProduction"
rosmap_details = pd.read_csv(f"{ROSMAP_DATA_DIR}/xwgcna_rosmap_autobeta_run4_Cluster_details.tsv", sep="\t")
rosmap_details5 = pd.read_csv(f"{ROSMAP_DATA_DIR}/xwgcna_rosmap_autobeta_run5_Cluster_details.txt", sep="\t")

FileNotFoundError: [Errno 2] No such file or directory: '/media/psylab-6028/DATA/Eden/CoExpression_ReProduction/xwgcna_rosmap_autobeta_run4_Cluster_details.tsv'

In [2]:
# xWGCNA cluster tables (run 9) for the young and old GTEx cohorts (tab-separated).
GTEX_RUN_DIR = "/Users/edeneldar/CoExpression_ReProduction"
young_details = pd.read_csv(f"{GTEX_RUN_DIR}/xwgcna_young_original_run9_Clusters_details.txt", sep="\t")
old_details = pd.read_csv(f"{GTEX_RUN_DIR}/xwgcna_old_original_run9_Clusters_details.txt", sep="\t")

In [3]:
# Column Index of the ROSMAP cluster table (DataFrame.keys() returns the columns).
rosmap_details.keys()

Index(['Cluster ID', 'Cluster Size', 'Cluster Type', 'Cluster Tissues', 'AC',
       'MF', 'PCG', 'Dominant Tissue'],
      dtype='object')

In [6]:
def plot_ct_cluster_histograms(
    details: pd.DataFrame,
    dataset_label: str,
    tissues: list[str] | None = None,
    *,
    tissue_col: str = "Tissue",
    count_cols: dict[str, str] | None = None,
    out_dir: str | None = None,
) -> None:
    """
    Save one bar chart per CT cluster showing its gene count per tissue.

    Only rows with ``details["Cluster Type"] == "CT"`` are used. They are grouped by
    "Cluster ID", and each group is written to
    ``<out_dir>/ct_cluster_<dataset_label>_<cluster_id>.png``.

    Counting modes (first that applies):
      1) count_cols given: sum ``details[count_cols[t]]`` per cluster.
      2) every tissue is a column: sum those columns (indicator/count).
      3) tissue_col present: count rows per tissue value.

    Tissue inference when ``tissues`` is None:
      - the keys of count_cols;
      - else the sorted unique values of tissue_col;
      - else every numeric column except the cluster metadata columns
        "Cluster ID", "Cluster Type" and "Cluster Size".

    Parameters
    ----------
    details : cluster table with at least "Cluster ID" and "Cluster Type".
    dataset_label : used in the default output dir and in file names.
    tissues : tissues to plot, in bar order (inferred when None).
    tissue_col : categorical tissue column used by mode 3 / inference.
    count_cols : mapping tissue -> count column (mode 1).
    out_dir : output directory; defaults to ``ct_cluster_histograms_<dataset_label>``.

    Raises
    ------
    KeyError
        Required columns are missing, count_cols lacks a requested tissue or
        names an absent column, or no counting mode applies.
    """
    import os
    import re

    if "Cluster Type" not in details.columns or "Cluster ID" not in details.columns:
        raise KeyError("Expected columns: 'Cluster Type' and 'Cluster ID'")

    ct = details[details["Cluster Type"] == "CT"].copy()
    if ct.empty:
        print("No CT clusters found. Nothing to plot.")
        return

    # Infer tissues if not provided
    if tissues is None:
        if count_cols:
            tissues = list(count_cols.keys())
        elif tissue_col in ct.columns:
            tissues = sorted(map(str, ct[tissue_col].dropna().unique().tolist()))
        else:
            # Heuristic: numeric columns that are not cluster metadata. "Cluster Size"
            # is numeric but is a cluster total, not a tissue, so it must be excluded.
            meta = {"Cluster ID", "Cluster Type", "Cluster Size", tissue_col}
            tissues = [
                c for c in ct.columns
                if c not in meta and pd.api.types.is_numeric_dtype(ct[c])
            ]

    tissues = [str(t) for t in tissues]

    # Strategy A: explicit tissue -> column mapping
    if count_cols is not None:
        col_map = {str(k): v for k, v in count_cols.items()}
        unmapped = [t for t in tissues if t not in col_map]
        if unmapped:
            raise KeyError(f"count_cols has no entry for tissues: {unmapped}")
        missing = [col_map[t] for t in tissues if col_map[t] not in ct.columns]
        if missing:
            raise KeyError(f"Missing count columns in dataframe: {missing}")

        def get_counts(df):
            return [float(df[col_map[t]].sum()) for t in tissues]
    # Strategy B: columns named as tissues exist
    elif all(t in ct.columns for t in tissues):
        def get_counts(df):
            # sum indicator/count columns
            return [float(df[t].sum()) for t in tissues]
    # Strategy C: categorical tissue_col
    elif tissue_col in ct.columns:
        def get_counts(df):
            vc = df[tissue_col].astype(str).value_counts()
            return [float(vc.get(t, 0.0)) for t in tissues]
    else:
        raise KeyError("Cannot infer tissue counts. Provide count_cols=... or a tissue_col present in dataframe.")

    # Import the plotting backend only once there is something to plot.
    import matplotlib.pyplot as plt

    if out_dir is None:
        out_dir = f"ct_cluster_histograms_{dataset_label}"
    os.makedirs(out_dir, exist_ok=True)

    palette = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",
               "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"]
    colors = [palette[i % len(palette)] for i in range(len(tissues))]

    def _safe(s: str) -> str:
        """Filesystem-safe token: non [word . -] runs -> '_', trimmed, max 100 chars."""
        s = str(s)
        s = re.sub(r"[^\w.-]+", "_", s)
        return s.strip("_")[:100]

    for cluster_id, gdf in ct.groupby("Cluster ID"):
        fig, ax = plt.subplots(figsize=(8, 5))
        try:
            ax.bar(tissues, get_counts(gdf), color=colors, edgecolor="#2b2b2b")
            ax.set_title(f"CT Cluster: {cluster_id}")
            ax.set_xlabel("Tissue")
            ax.set_ylabel("Gene Count")
            plt.setp(ax.get_xticklabels(), rotation=30, ha="right")
            fig.tight_layout()
            fname = os.path.join(out_dir, f"ct_cluster_{_safe(dataset_label)}_{_safe(cluster_id)}.png")
            fig.savefig(fname, dpi=160)
        finally:
            # Always release the figure, even if saving fails (avoids leaking figures in loops).
            plt.close(fig)

In [8]:
# `dataset_label` is required (the bare call raised TypeError). Pass the tissue
# columns explicitly so the numeric "Cluster Size" column is not plotted as a tissue.
plot_ct_cluster_histograms(rosmap_details, "rosmap", ["AC", "MF", "PCG"])

In [5]:
# Column Index of the young-GTEx cluster table (DataFrame.keys() returns the columns).
young_details.keys()

Index(['Cluster ID', 'Cluster Size', 'Cluster Type', 'Cluster Tissues',
       'Adipose', 'Brain', 'Muscle', 'Dominant Tissue'],
      dtype='object')

In [8]:
# Bar order for the GTEx CT-cluster histograms. The names match the per-tissue
# columns of young_details / old_details.
tissues = "Brain Adipose Muscle".split()

In [9]:
# One PNG per CT cluster for each GTEx cohort (young first, then old).
for details_df, cohort_label in ((young_details, "young"), (old_details, "old")):
    plot_ct_cluster_histograms(details_df, cohort_label, tissues)
