In [None]:
# ============================================================
# PSEUDOBULK (por paciente) + DE Healthy vs Cirrhosis + VOLCANO
# (EXTENDIDO: permite DE por Level2_final para subpoblaciones seleccionadas)
#
# - Usa el objeto FILTRADO: TFM_CIRRHOSIS_main_filtered_for_analysis.h5ad (RBC-out)
# - Memory-safe: lee en backed="r" y procesa genes por CHUNKS
#
# Modo recomendado para el paso "7) DE a nivel Level2":
#   ANALYSIS_LEVEL = "Level2_final"
#
# Outputs (por cada grupo analizado):
#   * summary_tables_final/pseudobulk_{ANALYSIS_LEVEL}_{tag}_mean_log1p_10k.csv
#   * summary_tables_final/DE_pseudobulk_{ANALYSIS_LEVEL}_{tag}_Healthy_vs_Cirrhosis.csv
#   * figures_final/Volcano_pseudobulk_{ANALYSIS_LEVEL}_{tag}_Healthy_vs_Cirrhosis.png
#   * ..._FOR_FIGURE.csv + ..._FOR_FIGURE.png (plot con TODOS los genes + highlights + labels TOP20)
#   * tablas top10 agregadas (FOR_FIGURE / FOR_REPORT / FOR_REPORT_LINFO_CLEAN)
# ============================================================

from pathlib import Path
import json
import re
from typing import Optional, List, Dict, Tuple

import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt

# stats
try:
    from scipy.stats import ttest_ind
    HAVE_SCIPY = True
except Exception:
    HAVE_SCIPY = False


# -----------------------------
# Paths
# -----------------------------
# Current working directory of the running process; used below as the starting
# point when searching upwards for the project root.
NOTEBOOK_DIR = Path.cwd()

def find_project_root(start: Path) -> Path:
    """Return the closest directory at or above *start* that contains 'data_processed'.

    The search visits *start* first and then each of its parents, nearest first.

    Raises:
        FileNotFoundError: if no visited directory contains 'data_processed'.
    """
    search_path = (start, *start.parents)
    root = next(
        (folder for folder in search_path if (folder / "data_processed").exists()),
        None,
    )
    if root is None:
        raise FileNotFoundError(f"No encuentro 'data_processed' subiendo desde: {start}")
    return root

# Locate the project root (the directory holding 'data_processed') and derive
# the input file and output directories from it.
PROJECT_ROOT = find_project_root(NOTEBOOK_DIR)
DATA_PROCESSED = PROJECT_ROOT / "data_processed"
IN_PATH = DATA_PROCESSED / "TFM_CIRRHOSIS_main_filtered_for_analysis.h5ad"

OUT_SUMMARY = PROJECT_ROOT / "summary_tables_final"
OUT_FIG     = PROJECT_ROOT / "figures_final"
# Create the output directories when missing; existing directories are left as is.
OUT_SUMMARY.mkdir(exist_ok=True)
OUT_FIG.mkdir(exist_ok=True)

# Level2_final map (Conv_T_other -> CD4_Memory, etc.)
MAP_PATH = OUT_SUMMARY / "Level2_final_map.json"

# Echo the resolved locations so a wrong project root is easy to spot.
print("PROJECT_ROOT:", PROJECT_ROOT)
print("IN_PATH     :", IN_PATH)
print("OUT_SUMMARY :", OUT_SUMMARY)
print("OUT_FIG     :", OUT_FIG)
print("MAP_PATH    :", MAP_PATH)


# -----------------------------
# Parameters
# -----------------------------
# Name of the AnnData layer that is averaged per patient to build the pseudobulk.
LAYER = "log1p_10k"

# >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
# KEY CHANGE (Step 7): analysis by Level2_final
#   - "Level2_final": runs DE on the selected Level2_final subtypes
#   - "Level1_refined": legacy mode (if you need it)
# <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
ANALYSIS_LEVEL = "Level2_final"   # "Level2_final" (recommended) or "Level1_refined"

# Targets to run (when ANALYSIS_LEVEL == "Level2_final")
LEVEL2_FINAL_TO_RUN = [
    # B / Plasma
    "B_Naive", "B_Memory", "B_Activated", "B_Atypical", "Plasma",
    # T
    "CD4_Naive", "CD4_Memory", "CD8_Naive", "CD8_Effector_Cytotoxic", "Treg", "MAIT", "GammaDelta_T",
    # NK
    "NK",
    # Mono / DC
    "Classical_Mono", "NonClassical_Mono", "ISG_Myeloid",
    "cDC1", "cDC2", "DC4", "aDC",
    # pDC
    "pDC",
]

# Legacy targets (when ANALYSIS_LEVEL == "Level1_refined")
LEVEL1_TO_RUN = ["T", "Mono", "NK", "B", "DC"]

# Genes
GENE_MODE = "HVG"     # "HVG" recommended
MAX_GENES = 2000      # upper bound on the number of genes analysed
GENE_CHUNK = 200      # number of genes read from the backed file per slice

# Statistics / thresholds
PSEUDOCOUNT = 1e-6    # added to both group means before taking the log2 ratio
ALPHA_FDR = 0.05

# Patient filtering per group (subtype/celltype)
MIN_CELLS_PER_PATIENT = 20
MIN_PATIENTS_PER_GROUP = 4          # min patients per disease
MIN_PATIENTS_PER_TARGET = 6         # min total patients with enough cells for that target

# Robust RBC-out (in case RBC shows up in obs by accident)
EXCLUDE_LEVEL1REFINED = {"RBC"}
EXCLUDE_LEVEL2 = {"RBC"}

# "FOR_FIGURE" volcano (highlighting/labels)
MIN_ABS_LOG2FC = 0.5
MIN_MEAN_LOG1P = 0.20
MIN_FRAC_PATIENTS = 0.50

TOP_N_LABEL = 20

# Optional blacklist for lymphoid targets: ONLY affects what is highlighted/labelled (not the background)
AMBIENT_MYELOID = {"S100A8","S100A9","S100A12","LYZ","LST1","TREM1","CXCL8"}

# Expected fold-change column name (avoids KeyErrors caused by old CSVs)
FC_COL = "log2FC_Healthy_vs_Cirrhosis"


# -----------------------------
# Helpers
# -----------------------------
def bh_fdr(pvals: np.ndarray) -> np.ndarray:
    """Benjamini-Hochberg FDR; returns q-values in the input order.

    NaN p-values are excluded from the correction (they neither count as tests
    nor contaminate the other q-values) and are returned as NaN. For NaN-free
    input the result is the standard BH step-up adjustment.

    Args:
        pvals: 1-D array-like of p-values.

    Returns:
        Float array with the same shape as ``pvals`` holding q-values clipped
        to [0, 1]; positions whose p-value is NaN stay NaN.
    """
    p = np.asarray(pvals, dtype=float)
    out = np.full(p.shape, np.nan, dtype=float)

    # Without this mask a single NaN would propagate through the cumulative
    # minimum below and turn every q-value into NaN.
    tested = ~np.isnan(p)
    n = int(tested.sum())
    if n == 0:
        return out

    p_tested = p[tested]
    order = np.argsort(p_tested)
    q = p_tested[order] * n / np.arange(1, n + 1)
    # Enforce monotonicity from the largest p-value downwards (step-up procedure).
    q = np.minimum.accumulate(q[::-1])[::-1]

    adjusted = np.empty(n, dtype=float)
    adjusted[order] = np.clip(q, 0, 1)
    out[tested] = adjusted
    return out

def safe_tag(name: str) -> str:
    """Turn *name* into a string that is safe to embed in a file name.

    Surrounding whitespace is stripped, spaces become underscores, every run of
    characters other than letters, digits, '_', '-' and '.' becomes a single
    underscore, and consecutive underscores are collapsed.
    """
    underscored = str(name).strip().replace(" ", "_")
    sanitized = re.sub(r"[^A-Za-z0-9_\-\.]+", "_", underscored)
    return re.sub(r"_+", "_", sanitized)

def get_gene_list(adata_b) -> List[str]:
    """Select the gene names to analyse, capped at MAX_GENES.

    When GENE_MODE is "HVG" and ``adata_b.var`` has a 'highly_variable' column,
    the flagged genes are returned (first MAX_GENES of them, in var order). If
    that flag selects nothing, or HVG mode does not apply, the first MAX_GENES
    entries of ``adata_b.var_names`` are returned instead.
    """
    hvg_available = GENE_MODE == "HVG" and "highly_variable" in adata_b.var.columns
    if not hvg_available:
        return adata_b.var_names[:MAX_GENES].tolist()

    hvg_flags = adata_b.var["highly_variable"].values.astype(bool)
    hvg_names = adata_b.var_names[hvg_flags].tolist()
    if len(hvg_names) > 0:
        return hvg_names[:MAX_GENES]

    # The column exists but flags no gene: fall back to the leading genes.
    print("[WARN] highly_variable existe pero vacío; uso primeras MAX_GENES.")
    return adata_b.var_names[:MAX_GENES].tolist()

def ensure_fc_col(df: pd.DataFrame, where: str = "") -> pd.DataFrame:
    """Return *df* with a fold-change column named FC_COL.

    - If FC_COL is already present, *df* is returned unchanged.
    - If exactly one column looks like a log fold-change (its lower-cased name
      contains 'log2fc' or 'logfc', or both 'fold' and 'log'), a renamed copy
      is returned and a warning is printed.
    - Otherwise a KeyError listing the available columns is raised.

    Args:
        df: table expected to carry the fold-change column.
        where: free-text context included in the warning / error message.
    """
    if FC_COL in df.columns:
        return df

    def _looks_like_fold_change(col) -> bool:
        # Typical alternative spellings found in older CSV exports.
        lowered = str(col).lower()
        return "log2fc" in lowered or "logfc" in lowered or ("fold" in lowered and "log" in lowered)

    candidates = [col for col in df.columns if _looks_like_fold_change(col)]

    if len(candidates) != 1:
        raise KeyError(f"[{where}] No encuentro '{FC_COL}'. Columnas={df.columns.tolist()}")

    print(f"[WARN] {where}: renombrando columna '{candidates[0]}' -> '{FC_COL}'")
    return df.rename(columns={candidates[0]: FC_COL})

def volcano_plot_base(df_de: pd.DataFrame, out_png: Path, title: str, alpha_fdr: float = 0.05):
    """Draw a plain (single-colour) volcano plot and save it as a PNG.

    Args:
        df_de: DE table with an 'FDR' column and the fold-change column FC_COL
            (or a single recognisable alternative, see ensure_fc_col).
        out_png: destination path of the image.
        title: plot title.
        alpha_fdr: FDR threshold drawn as a horizontal dashed line.

    Raises:
        KeyError: if no fold-change column can be identified.
    """
    df_de = ensure_fc_col(df_de, where=f"volcano_plot_base({out_png.name})")

    x = df_de[FC_COL].to_numpy(dtype=float)
    q = df_de["FDR"].to_numpy(dtype=float)
    # Clip so that FDR == 0 does not produce an infinite y value.
    y = -np.log10(np.clip(q, 1e-300, 1.0))

    fig, ax = plt.subplots(figsize=(7.5, 5.5))
    try:
        ax.scatter(x, y, s=10, alpha=0.8)
        ax.axhline(-np.log10(alpha_fdr), linestyle="--", linewidth=1)
        ax.axvline(0, linestyle=":", linewidth=1)
        ax.set_xlabel("log2FC (Healthy vs Cirrhosis) [pseudobulk per patient]")
        ax.set_ylabel("-log10(FDR)")
        ax.set_title(title)
        # Act on this figure explicitly instead of pyplot's "current figure".
        fig.tight_layout()
        fig.savefig(out_png, dpi=300)
    finally:
        # Always release the figure, even when drawing or saving fails;
        # this function is called once per target inside a loop.
        plt.close(fig)

def volcano_for_figure(
    df_all: pd.DataFrame,
    out_png: Path,
    title: str,
    alpha_fdr: float = 0.05,
    highlight_mask: Optional[np.ndarray] = None,
    genes_to_label: Optional[List[str]] = None,
    abs_log2fc_line: Optional[float] = None,
):
    """Draw a coloured volcano plot of all genes and save it as a PNG.

    - plots ALL genes in df_all
    - colours by status (significant up / significant down / not significant)
    - optionally outlines the rows selected by highlight_mask
    - optionally annotates the genes listed in genes_to_label

    Args:
        df_all: table with columns 'gene', 'FDR' and FC_COL (or a single
            recognisable alternative, see ensure_fc_col).
        out_png: destination path of the image.
        title: plot title.
        alpha_fdr: FDR threshold for significance and for the horizontal line.
        highlight_mask: boolean mask aligned with the rows of df_all.
        genes_to_label: gene names to annotate.
        abs_log2fc_line: if positive, draws vertical lines at +/- this value.

    Raises:
        KeyError: if no fold-change column can be identified.
    """
    df_all = ensure_fc_col(df_all, where=f"volcano_for_figure({out_png.name})")

    x = df_all[FC_COL].to_numpy(dtype=float)
    q = df_all["FDR"].to_numpy(dtype=float)
    # Clip so that FDR == 0 does not produce an infinite y value.
    y = -np.log10(np.clip(q, 1e-300, 1.0))

    sig = q < alpha_fdr
    up = sig & (x > 0)      # positive log2FC: higher in Healthy
    down = sig & (x < 0)    # negative log2FC: higher in Cirrhosis
    ns = ~sig

    cycle = plt.rcParams["axes.prop_cycle"].by_key().get("color", ["C0","C1","C2","C3"])
    col_down = cycle[0 % len(cycle)]
    col_up   = cycle[1 % len(cycle)]
    col_ns   = "0.75"

    fig, ax = plt.subplots(figsize=(7.8, 5.8))
    try:
        ax.scatter(x[ns], y[ns], s=10, alpha=0.6, c=col_ns, edgecolors="none", label="Not significant")
        ax.scatter(x[down], y[down], s=12, alpha=0.85, c=col_down, edgecolors="none",
                   label=f"FDR<{alpha_fdr} (Higher in Cirrhosis)")
        ax.scatter(x[up], y[up], s=12, alpha=0.85, c=col_up, edgecolors="none",
                   label=f"FDR<{alpha_fdr} (Higher in Healthy)")

        if highlight_mask is not None:
            hm = np.asarray(highlight_mask, dtype=bool)
            # Only outline points that can actually be drawn.
            hm = hm & np.isfinite(x) & np.isfinite(y)
            ax.scatter(
                x[hm], y[hm],
                s=26, alpha=0.95,
                facecolors="none", edgecolors="k", linewidths=0.7,
                label="Highlighted (filters)"
            )

        ax.axhline(-np.log10(alpha_fdr), linestyle="--", linewidth=1)
        ax.axvline(0, linestyle=":", linewidth=1)
        if abs_log2fc_line is not None and abs_log2fc_line > 0:
            ax.axvline(+abs_log2fc_line, linestyle="--", linewidth=0.8)
            ax.axvline(-abs_log2fc_line, linestyle="--", linewidth=0.8)

        ax.set_xlabel("log2FC (Healthy vs Cirrhosis) [pseudobulk per patient]")
        ax.set_ylabel("-log10(FDR)")
        ax.set_title(title)

        if genes_to_label is not None and len(genes_to_label) > 0:
            wanted = {str(g) for g in genes_to_label}
            df_lab = df_all[df_all["gene"].astype(str).isin(wanted)].copy()
            df_lab = df_lab.sort_values("FDR", ascending=True)

            # Cycle through four label positions to reduce overlaps.
            offsets = [(6, 6), (6, -10), (-18, 6), (-18, -10)]
            for k, (_, r) in enumerate(df_lab.iterrows()):
                gx = float(r[FC_COL])
                gy = -np.log10(max(float(r["FDR"]), 1e-300))
                ox, oy = offsets[k % len(offsets)]
                ax.annotate(
                    str(r["gene"]),
                    (gx, gy),
                    textcoords="offset points",
                    xytext=(ox, oy),
                    # Labels placed to the left must be right-aligned, otherwise
                    # the text runs back over the point it annotates.
                    ha="right" if ox < 0 else "left",
                    fontsize=8,
                    arrowprops=dict(arrowstyle="-", lw=0.4, alpha=0.6),
                )

        ax.legend(frameon=False, fontsize=8, loc="upper right")
        # Act on this figure explicitly instead of pyplot's "current figure".
        fig.tight_layout()
        fig.savefig(out_png, dpi=300)
    finally:
        # Always release the figure, even when drawing or saving fails;
        # this function is called once per target inside a loop.
        plt.close(fig)


# ============================================================
# 0) Load the Level2_final map
# ============================================================
# The map is only needed in Level2_final mode; legacy mode uses an empty map.
if ANALYSIS_LEVEL != "Level2_final":
    level2_map = {}
elif not MAP_PATH.exists():
    raise FileNotFoundError(
        f"No existe MAP_PATH:\n{MAP_PATH}\n"
        "Este notebook requiere Level2_final_map.json (Conv_T_other -> CD4_Memory, etc.)."
    )
else:
    with open(MAP_PATH, "r", encoding="utf-8") as f:
        level2_map = json.load(f)


# ============================================================
# 1) Load object (backed) + checks + prepare obs/patients/genes
# ============================================================
# Open the file in backed read-only mode.
adata_b = sc.read_h5ad(IN_PATH, backed="r")

# obs columns the analysis needs; 'Level2' is only required in Level2_final mode.
required_obs = ["patientID", "disease", "Level1_refined"]
if ANALYSIS_LEVEL == "Level2_final":
    required_obs.append("Level2")

# Fail early on a missing column, closing the backed file before raising.
for col in required_obs:
    if col not in adata_b.obs.columns:
        adata_b.file.close()
        raise KeyError(f"Falta columna en obs: {col}")

if LAYER not in adata_b.layers.keys():
    adata_b.file.close()
    raise KeyError(f"No existe layer '{LAYER}' en adata.layers. Layers: {list(adata_b.layers.keys())}")

# Work on an in-memory copy of just the required obs columns (order kept, duplicates dropped).
obs_cols = list(dict.fromkeys(required_obs))
obs = adata_b.obs[obs_cols].copy()

# Normalise the label columns to plain strings so later comparisons are string-vs-string.
obs["patientID"] = obs["patientID"].astype(str)
obs["disease"] = obs["disease"].astype(str)
obs["Level1_refined"] = obs["Level1_refined"].astype(str)

if ANALYSIS_LEVEL == "Level2_final":
    obs["Level2"] = obs["Level2"].astype(str)
    # Values found as keys in level2_map are replaced by their mapped label; other values are kept.
    obs["Level2_final"] = obs["Level2"].replace(level2_map).astype(str)

# Robust RBC-out
if (obs["Level1_refined"].isin(EXCLUDE_LEVEL1REFINED)).any():
    print("[WARN] Detectado Level1_refined=RBC en obs. Se excluirá del análisis (RBC-out).")

# Boolean mask (one entry per cell) of the cells that remain eligible for the analysis.
valid_mask_global = ~obs["Level1_refined"].isin(EXCLUDE_LEVEL1REFINED)

if ANALYSIS_LEVEL == "Level2_final":
    if (obs["Level2_final"].isin(EXCLUDE_LEVEL2)).any():
        print("[WARN] Detectado Level2_final=RBC en obs. Se excluirá del análisis (RBC-out).")
    valid_mask_global = valid_mask_global & (~obs["Level2_final"].isin(EXCLUDE_LEVEL2))
# Patients and groups: one row per (patientID, disease) pair among the valid cells.
patient_meta = obs.loc[valid_mask_global].drop_duplicates(["patientID", "disease"]).set_index("patientID")

# A patient carrying more than one disease label would appear twice in the index;
# `patient_meta.loc[pid, "disease"]` would then return a Series instead of a label
# and silently corrupt the group assignment downstream. Stop with a clear error.
ambiguous_patients = patient_meta.index[patient_meta.index.duplicated()].unique().tolist()
if ambiguous_patients:
    adata_b.file.close()
    raise ValueError(
        f"Pacientes con más de una etiqueta 'disease': {ambiguous_patients}"
    )

patients = patient_meta.index.tolist()

print("\nPacientes totales:", len(patients))
print(patient_meta["disease"].value_counts())

# Genes to analyse (HVG selection or leading genes, capped at MAX_GENES).
genes = get_gene_list(adata_b)
print("\nGenes usados:", len(genes), f"(mode={GENE_MODE}, MAX_GENES={MAX_GENES})")

# Targets to run and the obs column they are looked up in.
target_col, targets = (
    ("Level2_final", list(LEVEL2_FINAL_TO_RUN))
    if ANALYSIS_LEVEL == "Level2_final"
    else ("Level1_refined", list(LEVEL1_TO_RUN))
)

print("\nANALYSIS_LEVEL:", ANALYSIS_LEVEL)
print("target_col    :", target_col)
print("n targets (raw):", len(targets))

# ============================================================
# 1.05) Keep only the targets that actually exist in the object
# ============================================================
present_targets = set(obs.loc[valid_mask_global, target_col].dropna().astype(str))

# Split the requested targets into present / missing, preserving the requested order.
missing_targets = []
found_targets = []
for requested in targets:
    bucket = found_targets if requested in present_targets else missing_targets
    bucket.append(requested)
targets = found_targets

print("Targets presentes tras filtrar:", len(targets))
if missing_targets:
    print("[INFO] Targets no presentes (se omiten):", missing_targets)

if not targets:
    # Nothing to analyse: release the backed file before aborting.
    adata_b.file.close()
    raise RuntimeError("Tras filtrar, no queda ningún target presente. Revisa target list / columnas.")

# ============================================================
# 1.1) (optional) map used to decide the lymphoid blacklist (Level2_final only)
# ============================================================
order_by_group = {
    "B": ["B_Naive", "B_Memory", "B_Activated", "B_Atypical", "B_Other"],
    "Plasma": ["Plasma"],
    "pDC": ["pDC"],
    "T": [
        "CD4_Naive", "CD4_Memory", "CD8_Naive", "CD8_Effector_Cytotoxic", "Treg",
        "MAIT", "GammaDelta_T", "Proliferative_T", "Exhausted_T",
    ],
    "NK": ["NK"],
    "Mono": ["Classical_Mono", "NonClassical_Mono", "ISG_Myeloid", "MonoDC_Other"],
    "DC": ["cDC1", "cDC2", "DC4", "aDC"],
    "HSCs": ["HSCs"],
}


def group_of_l2(l2: str) -> str:
    """Return the broad group whose member list contains *l2*, or "Other" if none does."""
    return next(
        (group for group, members in order_by_group.items() if l2 in members),
        "Other",
    )


APPLY_BLACKLIST_FOR_GROUPS = {"T", "NK", "B"}  # lymphoid groups


# ============================================================
# 2) Pseudobulk + DE + (base) volcano loop
# ============================================================
# One entry per target: (target, number of patients kept, status string).
results_summary: List[Tuple[str, int, str]] = []

if not HAVE_SCIPY:
    # Without SciPy no test is run and every p-value is set to 1.0 below;
    # say so explicitly instead of silently reporting "no significant genes".
    print("[WARN] SciPy no disponible: los p-valores se fijan a 1.0 (sin test estadístico).")

# Loop-invariant per-cell arrays, extracted once instead of once per patient.
valid_cells = np.asarray(valid_mask_global, dtype=bool)
cell_patient_ids = obs["patientID"].to_numpy()
cell_targets = obs[target_col].to_numpy()

try:
    for target in targets:
        print("\n==============================")
        print(f"{ANALYSIS_LEVEL} target:", target)

        mask_target = valid_cells & (cell_targets == target)

        # Row positions (ascending) of this target's cells, per patient.
        idx_by_patient: Dict[str, np.ndarray] = {}
        n_cells_by_patient: Dict[str, int] = {}
        for pid in patients:
            idx = np.flatnonzero(mask_target & (cell_patient_ids == pid))
            idx_by_patient[pid] = idx
            n_cells_by_patient[pid] = int(idx.size)

        keep_pids = [pid for pid, n in n_cells_by_patient.items() if n >= MIN_CELLS_PER_PATIENT]
        print(f"Pacientes con >={MIN_CELLS_PER_PATIENT} células:", len(keep_pids), "/", len(patients))

        if len(keep_pids) < MIN_PATIENTS_PER_TARGET:
            print("[SKIP] Muy pocos pacientes con suficientes células para este target.")
            results_summary.append((str(target), len(keep_pids), "SKIP_low_n"))
            continue

        # Pseudobulk matrix: patients x genes, mean of LAYER over the patient's cells.
        pb = np.full((len(keep_pids), len(genes)), np.nan, dtype=np.float32)
        diseases: List[str] = []

        for i, pid in enumerate(keep_pids):
            idx_cells = idx_by_patient[pid]
            n = idx_cells.size
            diseases.append(patient_meta.loc[pid, "disease"])

            sums = np.zeros(len(genes), dtype=np.float64)

            # Read the backed layer in gene chunks to bound memory use.
            for start in range(0, len(genes), GENE_CHUNK):
                gchunk = genes[start:start + GENE_CHUNK]
                view = adata_b[idx_cells, gchunk]
                X = view.layers[LAYER]

                try:
                    chunk_sum = np.asarray(X.sum(axis=0)).ravel()
                except Exception:
                    # Best-effort fallback for matrix types whose .sum(axis=0) fails.
                    chunk_sum = np.sum(np.asarray(X), axis=0)

                sums[start:start + len(gchunk)] = chunk_sum

            pb[i, :] = (sums / float(n)).astype(np.float32)

        diseases = np.array(diseases, dtype=str)
        keep_pids = np.array(keep_pids, dtype=str)

        # Pseudobulk DataFrame (index = patientID, first column = disease).
        df_pb = pd.DataFrame(pb, index=keep_pids, columns=genes)
        df_pb.insert(0, "disease", diseases)

        tag = safe_tag(target)
        out_pb_csv = OUT_SUMMARY / f"pseudobulk_{ANALYSIS_LEVEL}_{tag}_mean_{LAYER}.csv"
        df_pb.to_csv(out_pb_csv)
        print("Saved pseudobulk:", out_pb_csv)

        # DE (patient-level)
        A = df_pb["disease"].values == "Cirrhosis"
        B = df_pb["disease"].values == "Healthy"
        nA, nB = int(A.sum()), int(B.sum())
        print("n patients Cirrhosis:", nA, "| Healthy:", nB)

        if nA < MIN_PATIENTS_PER_GROUP or nB < MIN_PATIENTS_PER_GROUP:
            print("[SKIP] Muy pocos pacientes por grupo para DE en este target.")
            results_summary.append((str(target), len(keep_pids), "SKIP_group_n"))
            continue

        # log2FC on an approximate linear scale: expm1(mean log1p)
        X_lin = np.expm1(df_pb[genes].values.astype(np.float64))
        meanA = np.nanmean(X_lin[A, :], axis=0)  # Cirrhosis
        meanB = np.nanmean(X_lin[B, :], axis=0)  # Healthy
        log2fc = np.log2((meanB + PSEUDOCOUNT) / (meanA + PSEUDOCOUNT))  # Healthy vs Cirrhosis

        # p-values: Welch t-test on the per-patient pseudobulk values (log1p scale)
        X_log = df_pb[genes].values.astype(np.float64)
        if HAVE_SCIPY:
            _, pvals = ttest_ind(X_log[B, :], X_log[A, :], axis=0, equal_var=False, nan_policy="omit")
            # Undefined tests (e.g. zero variance in both groups) count as non-significant.
            pvals = np.nan_to_num(pvals, nan=1.0, posinf=1.0, neginf=1.0)
        else:
            pvals = np.ones_like(log2fc, dtype=float)

        fdr = bh_fdr(pvals)

        df_de = pd.DataFrame({
            "gene": genes,
            "mean_lin_Cirrhosis": meanA,
            "mean_lin_Healthy": meanB,
            FC_COL: log2fc,
            "pval": pvals,
            "FDR": fdr,
        }).sort_values("FDR", ascending=True)

        out_de_csv = OUT_SUMMARY / f"DE_pseudobulk_{ANALYSIS_LEVEL}_{tag}_Healthy_vs_Cirrhosis.csv"
        df_de.to_csv(out_de_csv, index=False)
        print("Saved DE:", out_de_csv)

        out_png = OUT_FIG / f"Volcano_pseudobulk_{ANALYSIS_LEVEL}_{tag}_Healthy_vs_Cirrhosis.png"
        volcano_plot_base(
            df_de,
            out_png,
            title=f"{ANALYSIS_LEVEL}={target}: Healthy vs Cirrhosis (pseudobulk, per patient)",
            alpha_fdr=ALPHA_FDR
        )
        print("Saved volcano (base):", out_png)

        n_sig = int((df_de["FDR"].values < ALPHA_FDR).sum())
        results_summary.append((str(target), len(keep_pids), f"OK_sigFDR<{ALPHA_FDR}:{n_sig}"))
finally:
    # Close the backed file even when a target fails, so the HDF5 handle is not
    # left open in the kernel.
    adata_b.file.close()

print("\n=== RESUMEN (base) ===")
for row in results_summary:
    print(row)

print("\n[OK] Pseudobulk + DE + volcano (base) terminado.")


# ============================================================
# 3) "FOR_FIGURE" volcanos + filtered CSVs (built from the base outputs)
#    - df_plot (filtered) is saved as CSV
#    - the plot uses the COMPLETE df_f (all genes)
#    - labels the TOP 20 (balanced up/down)
# ============================================================
print("\n[SEC3] Generando volcanos FOR_FIGURE (all genes + highlights + TOP labels) ...")

sec3_summary: List[Tuple[str, int, int]] = []  # (target, n_highlight, n_labels)
# ONLY the targets whose FOR_FIGURE outputs are produced in this run.
# Each target must appear exactly once: SEC4 iterates over this list, so a
# duplicated entry would duplicate rows in the aggregated top-10 tables.
targets_for_figure_generated: List[str] = []

for target in targets:
    tag = safe_tag(target)
    de_path = OUT_SUMMARY / f"DE_pseudobulk_{ANALYSIS_LEVEL}_{tag}_Healthy_vs_Cirrhosis.csv"
    pb_path = OUT_SUMMARY / f"pseudobulk_{ANALYSIS_LEVEL}_{tag}_mean_{LAYER}.csv"

    if not de_path.exists() or not pb_path.exists():
        print("[SKIP] faltan archivos para", target)
        continue

    df_de = pd.read_csv(de_path)
    try:
        df_de = ensure_fc_col(df_de, where=str(de_path))
    except KeyError as e:
        print("[WARN] CSV DE incompatible, salto target:", target)
        print(" ", e)
        continue

    df_pb = pd.read_csv(pb_path, index_col=0)  # patientID as index

    gene_cols = [c for c in df_pb.columns if c != "disease"]
    X = df_pb[gene_cols].astype(float).values

    # Per-gene expression summaries across patients, used as highlight filters.
    mean_log1p = X.mean(axis=0)
    frac_pat = (X > 0.1).mean(axis=0)

    df_f = df_de.merge(
        pd.DataFrame({"gene": gene_cols, "mean_log1p": mean_log1p, "frac_patients": frac_pat}),
        on="gene",
        how="left",
    )

    # ensure fold-change column
    df_f = ensure_fc_col(df_f, where=f"merge(df_de, df_pb) target={target}")

    keep = (
        (df_f["FDR"] < ALPHA_FDR) &
        (df_f[FC_COL].abs() >= MIN_ABS_LOG2FC) &
        (df_f["mean_log1p"] >= MIN_MEAN_LOG1P) &
        (df_f["frac_patients"] >= MIN_FRAC_PATIENTS)
    )

    df_plot = df_f.loc[keep].copy()

    # The blacklist ONLY affects what is highlighted/labelled (not the background).
    keep2 = keep.copy()
    if ANALYSIS_LEVEL == "Level1_refined":
        apply_blacklist = (str(target) in {"T", "NK", "B"})
    else:
        apply_blacklist = (group_of_l2(str(target)) in APPLY_BLACKLIST_FOR_GROUPS)

    if apply_blacklist:
        keep2 = keep2 & (~df_f["gene"].isin(AMBIENT_MYELOID))
        df_plot = df_plot[~df_plot["gene"].isin(AMBIENT_MYELOID)].copy()

    df_plot = df_plot.sort_values("FDR", ascending=True)

    out_csv = OUT_SUMMARY / f"DE_pseudobulk_{ANALYSIS_LEVEL}_{tag}_Healthy_vs_Cirrhosis_FOR_FIGURE.csv"
    df_plot.to_csv(out_csv, index=False)

    # -------- select genes to label (TOP 20) --------
    cand = df_f.loc[keep2].copy().sort_values("FDR", ascending=True)

    n_each = TOP_N_LABEL // 2
    cand_up = cand[cand[FC_COL] > 0].head(n_each)
    cand_dn = cand[cand[FC_COL] < 0].head(n_each)

    genes_to_label = pd.concat([cand_up, cand_dn], axis=0)["gene"].astype(str).tolist()

    # If one direction has fewer than n_each genes, fill up with the next best candidates.
    if len(genes_to_label) < TOP_N_LABEL:
        extra = cand[~cand["gene"].astype(str).isin(genes_to_label)].head(TOP_N_LABEL - len(genes_to_label))
        genes_to_label += extra["gene"].astype(str).tolist()

    out_png = OUT_FIG / f"Volcano_pseudobulk_{ANALYSIS_LEVEL}_{tag}_Healthy_vs_Cirrhosis_FOR_FIGURE.png"
    volcano_for_figure(
        df_all=df_f,  # <- ALL genes
        out_png=out_png,
        title=f"{ANALYSIS_LEVEL}={target}: Healthy vs Cirrhosis (pseudobulk) — volcano (all genes)",
        alpha_fdr=ALPHA_FDR,
        highlight_mask=keep2.to_numpy(),
        genes_to_label=genes_to_label,
        abs_log2fc_line=MIN_ABS_LOG2FC,
    )

    n_high = int(np.sum(keep2))
    sec3_summary.append((str(target), n_high, len(genes_to_label)))
    # Register the target once, after both the CSV and the PNG were written.
    targets_for_figure_generated.append(str(target))

    print(f"{target}: genes destacados (keep2)={n_high} | genes etiquetados={len(genes_to_label)}")
    print("  Saved:", out_png)
    print("  Saved:", out_csv)

print("\n[OK] Volcanos 'FOR_FIGURE' generados.")
print("Targets generados en ESTA corrida (FOR_FIGURE):", targets_for_figure_generated)


# ============================================================
# 4) Top-10 tables (FOR_FIGURE and FOR_REPORT)
#    - For each target, uses ITS FOR_FIGURE CSV (GENERATED IN THIS RUN)
#    - Current goal: DIAGNOSE the KeyError (find which CSV is malformed)
# ============================================================
print("\n[SEC4] Generando tablas top10 agregadas ...")

# Name of the fold-change column expected in every FOR_FIGURE CSV read below.
FC_COL = "log2FC_Healthy_vs_Cirrhosis"

HOUSEKEEPING = {
    "GAPDH", "ACTB", "ACTG1", "B2M", "MALAT1", "EEF1A1",
    "RPLP0", "RPSA", "TMSB10", "FTH1", "FTL",
}


def is_bad_gene(g: str) -> bool:
    """Return True for genes excluded from report tables.

    A gene is excluded when it is listed in HOUSEKEEPING or when its symbol
    starts with "MT-", "RPL" or "RPS" (case-sensitive).
    """
    symbol = str(g)
    return symbol in HOUSEKEEPING or symbol.startswith(("MT-", "RPL", "RPS"))

# Accumulators for the three aggregated tables (one row per target/direction/gene).
rows_fig: List[List[object]] = []
rows_rep: List[List[object]] = []
rows_rep_linfo: List[List[object]] = []


def _top10_rows(target, frame: pd.DataFrame) -> List[List[object]]:
    """Return up to 10 rows per direction from *frame*, Healthy first, then Cirrhosis.

    *frame* is expected to be sorted by FDR already, so ``head(10)`` picks the
    genes with the smallest FDR in each direction.
    """
    selections = (
        ("Higher_in_Healthy", frame[frame[FC_COL] > 0].head(10)),
        ("Higher_in_Cirrhosis", frame[frame[FC_COL] < 0].head(10)),
    )
    rows: List[List[object]] = []
    for direction, subset in selections:
        for _, r in subset.iterrows():
            rows.append([str(target), direction, r["gene"], r[FC_COL], r["FDR"]])
    return rows


# USE ONLY the CSVs generated in SEC3 (when that list exists)
try:
    targets_iter = targets_for_figure_generated
    used_generated_list = True
except NameError:
    targets_iter = targets
    used_generated_list = False

# Process every target once: a repeated entry in the list would otherwise
# produce duplicated rows in the aggregated tables. Order is preserved.
targets_iter = list(dict.fromkeys(str(t) for t in targets_iter))

if used_generated_list:
    print("[SEC4] usando targets_for_figure_generated:", len(targets_iter))
else:
    print("[SEC4][WARN] targets_for_figure_generated NO existe -> usando targets:", len(targets_iter))

for target in targets_iter:
    tag = safe_tag(target)
    path = OUT_SUMMARY / f"DE_pseudobulk_{ANALYSIS_LEVEL}_{tag}_Healthy_vs_Cirrhosis_FOR_FIGURE.csv"

    print("\n[SEC4] target =", target)
    print("[SEC4] path   =", path)

    if not path.exists():
        print("[SEC4][MISSING] No existe el FOR_FIGURE CSV -> skip")
        continue

    df = pd.read_csv(path)

    if df.shape[0] == 0:
        print("[WARN] FOR_FIGURE CSV vacío (0 genes). Lo salto:", target)
        continue

    print("[SEC4] columns =", df.columns.tolist())
    print("[SEC4] columns_repr =", [repr(c) for c in df.columns])
    print(df.head(2))

    # ---- KEY DIAGNOSTIC ----
    if FC_COL not in df.columns:
        print(f"\n[SEC4][ERROR REAL] Este CSV NO tiene la columna {FC_COL}. ESTE ES EL CULPABLE.")
        print("CSV culpable:", path)

        base_path = OUT_SUMMARY / f"DE_pseudobulk_{ANALYSIS_LEVEL}_{tag}_Healthy_vs_Cirrhosis.csv"
        print("DE base esperado:", base_path)
        if base_path.exists():
            df_base = pd.read_csv(base_path, nrows=2)
            print("[SEC4] base DE columns =", df_base.columns.tolist())
        else:
            print("[SEC4] base DE NO existe.")

        # Stop here on purpose so the offending CSV is the last thing reported.
        raise KeyError(f"Falta {FC_COL} en {path}")

    # From here on: regular SEC4 processing (only when the CSV is well-formed).
    df = df.sort_values("FDR", ascending=True)
    df["gene"] = df["gene"].astype(str)

    # FOR_FIGURE top10
    rows_fig.extend(_top10_rows(target, df))

    # FOR_REPORT top10 (drops housekeeping/ribosomal/MT genes)
    mask_keep = ~df["gene"].map(is_bad_gene).astype(bool)
    df_rep = df.loc[mask_keep].copy()
    rows_rep.extend(_top10_rows(target, df_rep))

    # FOR_REPORT_LINFO_CLEAN (extra: drop "ambient myeloid" genes in lymphoid targets)
    df_rep_l = df_rep.copy()
    if ANALYSIS_LEVEL == "Level1_refined":
        apply_blacklist = (str(target) in {"T", "NK", "B"})
    else:
        apply_blacklist = (group_of_l2(str(target)) in APPLY_BLACKLIST_FOR_GROUPS)

    if apply_blacklist:
        df_rep_l = df_rep_l[~df_rep_l["gene"].isin(AMBIENT_MYELOID)].copy()

    rows_rep_linfo.extend(_top10_rows(target, df_rep_l))

# Assemble and write the three aggregated tables (same column layout for all).
table_columns = [ANALYSIS_LEVEL, "direction", "gene", FC_COL, "FDR"]
df_fig = pd.DataFrame(rows_fig, columns=table_columns)
df_rep = pd.DataFrame(rows_rep, columns=table_columns)
df_rep_l = pd.DataFrame(rows_rep_linfo, columns=table_columns)

out_fig = OUT_SUMMARY / f"DE_pseudobulk_{ANALYSIS_LEVEL}_FOR_FIGURE_top10_by_target.csv"
out_rep = OUT_SUMMARY / f"DE_pseudobulk_{ANALYSIS_LEVEL}_FOR_REPORT_top10_by_target.csv"
out_rep_l = OUT_SUMMARY / f"DE_pseudobulk_{ANALYSIS_LEVEL}_FOR_REPORT_LINFO_CLEAN_top10_by_target.csv"

df_fig.to_csv(out_fig, index=False)
df_rep.to_csv(out_rep, index=False)
df_rep_l.to_csv(out_rep_l, index=False)

print("\nSaved:", out_fig)
print("Saved:", out_rep)
print("Saved:", out_rep_l)
print("\n[OK] Tablas top10 generadas.")

In [None]:
# Debug helper: read only the first rows of the DE CSV at `de_path` and show its
# column names. `de_path` must already be defined in the session.
df_de = pd.read_csv(de_path, nrows=5)
print("DE columns:", df_de.columns.tolist())


In [None]:
# ============================================================
# PSEUDOBULK (por paciente) + DE Healthy vs Cirrhosis + VOLCANO
# (EXTENDIDO: permite DE por Level2_final para subpoblaciones seleccionadas)
#
# - Usa el objeto FILTRADO: TFM_CIRRHOSIS_main_filtered_for_analysis.h5ad (RBC-out)
# - Memory-safe: lee en backed="r" y procesa genes por CHUNKS
#
# Modo recomendado para el paso "7) DE a nivel Level2":
#   ANALYSIS_LEVEL = "Level2_final"
#
# Outputs (por cada grupo analizado):
#   * summary_tables_final/pseudobulk_{ANALYSIS_LEVEL}_{tag}_mean_log1p_10k.csv
#   * summary_tables_final/DE_pseudobulk_{ANALYSIS_LEVEL}_{tag}_Healthy_vs_Cirrhosis.csv
#   * figures_final/Volcano_pseudobulk_{ANALYSIS_LEVEL}_{tag}_Healthy_vs_Cirrhosis.png
#   * ..._FOR_FIGURE.csv + ..._FOR_FIGURE.png (plot con TODOS los genes + highlights + labels TOP20)
#   * tablas top10 agregadas (FOR_FIGURE / FOR_REPORT / FOR_REPORT_LINFO_CLEAN)
# ============================================================

from pathlib import Path
import json
import re
from typing import Optional, List, Dict, Tuple

import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt

# stats
try:
    from scipy.stats import ttest_ind
    HAVE_SCIPY = True
except Exception:
    HAVE_SCIPY = False


# -----------------------------
# Paths
# -----------------------------
# Current working directory of the running process; used below as the starting
# point when searching upwards for the project root.
NOTEBOOK_DIR = Path.cwd()

def find_project_root(start: Path) -> Path:
    """Return the closest directory at or above *start* that contains 'data_processed'.

    The search visits *start* first and then each of its parents, nearest first.

    Raises:
        FileNotFoundError: if no visited directory contains 'data_processed'.
    """
    search_path = (start, *start.parents)
    root = next(
        (folder for folder in search_path if (folder / "data_processed").exists()),
        None,
    )
    if root is None:
        raise FileNotFoundError(f"No encuentro 'data_processed' subiendo desde: {start}")
    return root

# Locate the project root (the directory holding 'data_processed') and derive
# the input file and output directories from it.
PROJECT_ROOT = find_project_root(NOTEBOOK_DIR)
DATA_PROCESSED = PROJECT_ROOT / "data_processed"
IN_PATH = DATA_PROCESSED / "TFM_CIRRHOSIS_main_filtered_for_analysis.h5ad"

OUT_SUMMARY = PROJECT_ROOT / "summary_tables_final"
OUT_FIG     = PROJECT_ROOT / "figures_final"
# Create the output directories when missing; existing directories are left as is.
OUT_SUMMARY.mkdir(exist_ok=True)
OUT_FIG.mkdir(exist_ok=True)

# Level2_final map (Conv_T_other -> CD4_Memory, etc.)
MAP_PATH = OUT_SUMMARY / "Level2_final_map.json"

# Echo the resolved locations so a wrong project root is easy to spot.
print("PROJECT_ROOT:", PROJECT_ROOT)
print("IN_PATH     :", IN_PATH)
print("OUT_SUMMARY :", OUT_SUMMARY)
print("OUT_FIG     :", OUT_FIG)
print("MAP_PATH    :", MAP_PATH)

from pathlib import Path
import datetime as dt


def _mtime_stamp(path: Path) -> str:
    """Format the file's last-modification time as 'YYYY-MM-DD HH:MM:SS' (local time)."""
    return dt.datetime.fromtimestamp(path.stat().st_mtime).strftime("%Y-%m-%d %H:%M:%S")


print("CWD:", Path.cwd())

# Reuses the variables defined earlier in the session.
print("PROJECT_ROOT:", PROJECT_ROOT)
print("OUT_FIG     :", OUT_FIG)

# List ALL FOR_FIGURE volcanos found in OUT_FIG (first 30, sorted by path).
vols = sorted(OUT_FIG.glob("Volcano_pseudobulk_Level2_final_*_FOR_FIGURE.png"))
print("\n[OUT_FIG] N volcanos FOR_FIGURE:", len(vols))
for p in vols[:30]:
    print(_mtime_stamp(p), "-", p.name)

# Also look under Documents (a typical "wrong" project root).
docs = Path(r"D:\Users\Coni\Documents")
vols_docs = sorted(docs.glob("figures_final/Volcano_pseudobulk_Level2_final_*_FOR_FIGURE.png"))
print("\n[Documents/figures_final] N volcanos FOR_FIGURE:", len(vols_docs))
for p in vols_docs[:30]:
    print(_mtime_stamp(p), "-", p)


In [None]:
# ============================================================
# MINI-SUMMARY para Results 3.3 (DE pseudobulk) — ligero y “pegable”
# - Lee CSVs DE completos si existen; si no, cae a *_FOR_FIGURE.csv
# - Saca: #DE genes (FDR<0.05), top genes por dirección, presencia de “stress genes”
# - Genera un TXT/MD compacto en summary_tables_final/
# ============================================================

from pathlib import Path
import pandas as pd
import numpy as np

# --- locate PROJECT_ROOT ---
# Walk up from the working directory (the directory itself first, then each
# ancestor) and take the first folder that contains "summary_tables_final".
# If none is found the working directory is kept, as before.
PROJECT_ROOT = next(
    (
        p
        for p in (Path.cwd(), *Path.cwd().parents)
        if (p / "summary_tables_final").exists()
    ),
    Path.cwd(),
)

SUM_DIR = PROJECT_ROOT / "summary_tables_final"
OUT_TXT = SUM_DIR / "Fig3_DE_Level2final_compact_summary.txt"

# Annotation level and column names used to locate and read the DE tables.
ANALYSIS_LEVEL = "Level2_final"
FC_COL = "log2FC_Healthy_vs_Cirrhosis"
FDR_COL = "FDR"

# >>> Choose the panels (cell populations) to summarise here <<<
TARGETS = [
    "Classical_Mono",
    "NK",
    "CD4_Memory",
    "CD4_Naive",
    "CD8_Naive",
    "B_Naive",
]

# "Recurrent activation/stress" genes from the legacy text (adjust as needed).
KEY_GENES = [
    "JUN", "DUSP1", "TNFAIP3", "KLF6", "CD69",
    "CXCL8", "NFKB1", "NAMPT", "VIM", "SRGN",
    "CSF1R", "CD86", "ENTPD1",
]

FDR_THR = 0.05   # significance cut-off applied to the FDR column
TOP_N = 10       # top genes per direction (ranked by FDR)
ROUND_FC = 3     # decimals shown for log2FC
ROUND_FDR = 3    # significant digits shown for FDR

def _to_num(s):
    # robusto a coma decimal (si algún CSV lo trae así)
    return pd.to_numeric(s.astype(str).str.replace(",", ".", regex=False), errors="coerce")

def load_de(target: str):
    """Load the pseudobulk DE table for one target population.

    The full DE CSV is preferred; when it is absent the ``*_FOR_FIGURE.csv``
    file is used instead.

    Parameters
    ----------
    target : str
        Population tag embedded in the CSV file name.

    Returns
    -------
    tuple
        ``(df, file_name, mode)`` where ``df`` holds the columns ``gene``,
        ``FC_COL``, ``FDR_COL`` (plus ``pval`` when the file has it) with
        rows lacking a numeric fold change or FDR removed, and ``mode`` is
        ``"FULL"`` or ``"FOR_FIGURE_FALLBACK"``. ``(None, None, None)`` when
        neither CSV exists.

    Raises
    ------
    KeyError
        If the chosen CSV lacks any of the required columns.
    """
    stem = f"DE_pseudobulk_{ANALYSIS_LEVEL}_{target}_Healthy_vs_Cirrhosis"
    full = SUM_DIR / f"{stem}.csv"
    fig = SUM_DIR / f"{stem}_FOR_FIGURE.csv"
    if full.exists():
        path, mode = full, "FULL"
    elif fig.exists():
        path, mode = fig, "FOR_FIGURE_FALLBACK"
    else:
        return None, None, None

    # Header-only read: validates the schema (and reports every column found)
    # without loading the table itself.
    header = pd.read_csv(path, nrows=0).columns.tolist()
    missing = {"gene", FC_COL, FDR_COL} - set(header)
    if missing:
        raise KeyError(f"{path.name}: faltan columnas {missing}. Cols={header}")

    # Load only the columns the summary uses; "pval" is optional.
    keep = ["gene", FC_COL, FDR_COL] + [c for c in ["pval"] if c in header]
    df = pd.read_csv(path, usecols=keep)[keep].copy()
    df["gene"] = df["gene"].astype(str)

    df[FC_COL] = _to_num(df[FC_COL])
    df[FDR_COL] = _to_num(df[FDR_COL])

    # Rows without a usable fold change or FDR cannot be ranked or counted.
    df = df.dropna(subset=[FC_COL, FDR_COL])
    return df, path.name, mode

def top_table(df, direction: str, n=TOP_N):
    """Format the ``n`` lowest-FDR genes for one direction of change.

    ``direction == "Higher_in_Cirrhosis"`` selects rows with NEGATIVE log2FC
    (the fold change is Healthy vs Cirrhosis); any other value selects rows
    with positive log2FC. Returns a list of strings shaped like
    ``GENE(log2FC,FDR=value)``, empty when no row qualifies.
    """
    fold_change = df[FC_COL]
    wanted = fold_change < 0 if direction == "Higher_in_Cirrhosis" else fold_change > 0
    ranked = df[wanted].sort_values(FDR_COL, ascending=True).head(n)
    return [
        f"{gene}({fc:.{ROUND_FC}f},FDR={fdr:.{ROUND_FDR}g})"
        for gene, fc, fdr in zip(ranked["gene"], ranked[FC_COL], ranked[FDR_COL])
    ]

# Report header: provenance plus the configuration used for this run.
lines = [
    "FIG3 — DE pseudobulk Level2_final (Healthy vs Cirrhosis)",
    f"PROJECT_ROOT: {PROJECT_ROOT}",
    f"SUM_DIR     : {SUM_DIR}",
    f"Targets     : {', '.join(TARGETS)}",
    f"Key genes   : {', '.join(KEY_GENES)}",
    "",
]

# Number of panels whose TOP list contains each gene, per direction
# (feeds the "recurrent genes" section).
freq_cirr = {}
freq_heal = {}

for t in TARGETS:
    df, fname, mode = load_de(t)
    if df is None:
        lines.append(f"[MISSING] {t}: no encuentro CSV FULL ni FOR_FIGURE")
        lines.append("")
        continue

    # log2FC is Healthy vs Cirrhosis, so a negative value means higher in Cirrhosis.
    sig = df[df[FDR_COL] < FDR_THR]
    n_sig = int(sig.shape[0])
    n_sig_cirr = int((sig[FC_COL] < 0).sum())
    n_sig_heal = int((sig[FC_COL] > 0).sum())

    # Already capped at TOP_N entries, so the lists are used as-is below.
    top_cirr = top_table(df, "Higher_in_Cirrhosis", TOP_N)
    top_heal = top_table(df, "Higher_in_Healthy", TOP_N)

    # Entries look like "GENE(fc,FDR=x)". Split on the LAST "(" so a gene
    # identifier that itself contains "(" is not truncated.
    for g in [x.rsplit("(", 1)[0] for x in top_cirr]:
        freq_cirr[g] = freq_cirr.get(g, 0) + 1
    for g in [x.rsplit("(", 1)[0] for x in top_heal]:
        freq_heal[g] = freq_heal.get(g, 0) + 1

    present_keys = [g for g in KEY_GENES if (df["gene"] == g).any()]
    # Tag each key gene present in the table with its direction, taken from
    # the sign of the fold change on its lowest-FDR row.
    key_dir = []
    for g in present_keys:
        rmin = df.loc[df["gene"] == g].sort_values(FDR_COL).head(1)
        if rmin.shape[0] == 1:
            fc = float(rmin[FC_COL].iloc[0])
            dd = "Cirrhosis↑" if fc < 0 else "Healthy↑"
            key_dir.append(f"{g}:{dd}")

    lines.append(f"=== {t} ===")
    lines.append(f"source: {fname} ({mode})")
    lines.append(f"sig(FDR<{FDR_THR}): {n_sig}  | Cirrhosis↑: {n_sig_cirr}  | Healthy↑: {n_sig_heal}")
    lines.append("top Cirrhosis↑ (log2FC<0): " + (", ".join(top_cirr) if top_cirr else "NA"))
    lines.append("top Healthy↑   (log2FC>0): " + (", ".join(top_heal) if top_heal else "NA"))
    lines.append("key genes hit: " + (", ".join(key_dir) if key_dir else "none"))
    lines.append("")

# recurrent summary
def top_freq(d, k=15):
    """Return the ``k`` highest-count ``(name, count)`` pairs of ``d``.

    Pairs are ordered by descending count; equal counts are ordered by name.
    """
    def _rank(pair):
        name, count = pair
        return (-count, name)

    ordered = sorted(d.items(), key=_rank)
    return ordered[:k]

# Cross-panel section: genes that show up in the TOP lists of several panels.
# The panel count comes from TARGETS so the header stays correct when the
# target list is edited; empty tables print "NA" like the per-target lines.
lines.append(f"=== Recurrent genes across the {len(TARGETS)} panels (from TOP lists) ===")
lines.append(
    "Cirrhosis↑ most frequent: "
    + (", ".join([f"{g}×{n}" for g, n in top_freq(freq_cirr)]) or "NA")
)
lines.append(
    "Healthy↑ most frequent  : "
    + (", ".join([f"{g}×{n}" for g, n in top_freq(freq_heal)]) or "NA")
)
lines.append("")

# Persist the full report; only the first 80 lines are echoed on screen.
OUT_TXT.write_text("\n".join(lines), encoding="utf-8")
print("\n".join(lines[:80]))
print("\n[OK] Guardado resumen compacto en:", OUT_TXT)
