In [None]:
# ============================================================
# scCODA (Level1_refined) — post-filtrado, por paciente
# scCODA 0.1.9 API: dat.from_pandas(df, covariate_columns=[...])
# RBC-out: si aparece columna RBC, se elimina aquí (robusto)
#
# Outputs principales:
# - summary_tables_final/scCODA_Level1refined_summary_ref-*.txt
# - summary_tables_final/scCODA_Level1refined_credible_effects_ref-*.csv
# - figures_final/Fig2B_scCODA_Level1refined_results_ref-*.png   <-- panel “resultados” para Figura 2
#
# Outputs secundarios (mantengo por compatibilidad):
# - figures_final/Fig1E_scCODA_Level1refined_boxplots_ref-*.png  <-- composición (input) boxplot+puntos
# - summary_tables_final/QA_scCODA_boxplot_stats_ref-*.csv
# ============================================================

from pathlib import Path
import io
import contextlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sccoda.util import cell_composition_data as dat
from sccoda.util.comp_ana import CompositionalAnalysis

# -----------------------------
# Paths: the search for the project root starts from the directory the
# notebook is executed from.
# -----------------------------
NOTEBOOK_DIR: Path = Path.cwd()

def find_project_root(start: Path) -> Path:
    """Return the nearest directory, ``start`` itself or one of its ancestors,
    that contains a ``data_processed`` entry.

    Args:
        start: directory where the upward search begins.

    Returns:
        The first matching directory, checked from ``start`` towards the
        filesystem root.

    Raises:
        FileNotFoundError: if no directory on that path contains
            ``data_processed``.
    """
    match = next(
        (
            candidate
            for candidate in (start, *start.parents)
            if (candidate / "data_processed").exists()
        ),
        None,
    )
    if match is None:
        raise FileNotFoundError(f"No encuentro 'data_processed' subiendo desde: {start}")
    return match

# Resolve the project layout and make sure the output folders exist.
PROJECT_ROOT = find_project_root(NOTEBOOK_DIR)
OUT_SUMMARY = PROJECT_ROOT / "summary_tables_final"
OUT_FIG     = PROJECT_ROOT / "figures_final"
for _out_dir in (OUT_SUMMARY, OUT_FIG):
    _out_dir.mkdir(exist_ok=True)

# Per-patient counts table produced by an earlier step of the pipeline.
counts_csv = OUT_SUMMARY / "cell_counts_Level1refined_by_patient.csv"

for _label, _value in (
    ("PROJECT_ROOT:", PROJECT_ROOT),
    ("counts_csv  :", counts_csv),
    ("OUT_SUMMARY :", OUT_SUMMARY),
    ("OUT_FIG     :", OUT_FIG),
):
    print(_label, _value)

# Fail early with a helpful hint if the counts were never generated.
if not counts_csv.exists():
    raise FileNotFoundError(f"No existe counts_csv:\n{counts_csv}\n¿Has generado los counts por paciente?")

# -----------------------------
# Load counts
# -----------------------------
df = pd.read_csv(counts_csv)

needed_cols = {"patientID", "disease"}
if not needed_cols.issubset(set(df.columns)):
    raise KeyError(f"Faltan columnas requeridas en counts_csv: {sorted(list(needed_cols - set(df.columns)))}")

# Keep only: patientID + disease + cell-type count columns.
celltype_cols = [c for c in df.columns if c not in ["patientID", "disease", "total_cells_patient"]]

# --- Robust RBC removal: drop an RBC column if one is still present ---
if "RBC" in celltype_cols:
    print("[WARN] Columna RBC encontrada en counts_csv. Eliminándola (post-RBC-out).")
    celltype_cols = [c for c in celltype_cols if c != "RBC"]

if len(celltype_cols) == 0:
    raise RuntimeError("No hay columnas de cell types en counts_csv (tras filtrar columnas base).")

df_sccoda = df[["patientID", "disease"] + celltype_cols].copy()

# Each row must be a distinct patient; duplicates would silently double-count
# the patient once patientID becomes the index below.
dup_ids = df_sccoda.loc[df_sccoda["patientID"].duplicated(), "patientID"].astype(str).unique().tolist()
if dup_ids:
    raise ValueError(f"patientID duplicados en counts_csv (cada fila debe ser un paciente): {dup_ids}")

# A missing covariate would become a spurious "nan" group after astype(str).
if df_sccoda["disease"].isna().any():
    n_missing = int(df_sccoda["disease"].isna().sum())
    raise ValueError(f"{n_missing} paciente(s) sin valor en 'disease' (NaN). Corrige counts_csv antes de scCODA.")

# Counts must be finite, integer-valued and >= 0. astype(int) alone would fail
# obscurely on NaN and silently truncate fractional values.
for c in celltype_cols:
    values = pd.to_numeric(df_sccoda[c], errors="raise")
    if not np.isfinite(values).all():
        raise ValueError(f"Counts faltantes o infinitos (NaN/inf) en columna {c}.")
    if (values != np.round(values)).any():
        raise ValueError(f"Counts no enteros en columna {c} (se truncarían al convertir a int).")
    df_sccoda[c] = values.astype(int)
    if (df_sccoda[c] < 0).any():
        raise ValueError(f"Counts negativos detectados en columna {c} (esto no debería ocurrir).")

# Drop all-zero columns (they carry no information and may upset scCODA).
zero_cols = [c for c in celltype_cols if int(df_sccoda[c].sum()) == 0]
if zero_cols:
    print("[WARN] Columnas con suma 0 (se eliminan):", zero_cols)
    celltype_cols = [c for c in celltype_cols if c not in zero_cols]
    df_sccoda = df_sccoda[["patientID", "disease"] + celltype_cols].copy()

# IMPORTANT: patientID must not stay a column, otherwise scCODA would treat it
# as a cell type.
df_sccoda = df_sccoda.set_index("patientID")

# Normalise the covariate to str.
df_sccoda["disease"] = df_sccoda["disease"].astype(str)

print("\nDataFrame para scCODA (head):")
print(df_sccoda.head())
print("Shape:", df_sccoda.shape)
print("Cell types:", celltype_cols)
print("\nDisease counts:")
print(df_sccoda["disease"].value_counts())

if df_sccoda["disease"].nunique() < 2:
    raise RuntimeError("La covariable 'disease' solo tiene 1 nivel. scCODA no puede comparar grupos.")

# -----------------------------
# Explicit, reproducible choice of the scCODA reference cell type.
# Policy: among cell types present (count > 0) in >= 90% of patients, pick the
# one whose per-patient proportion has the lowest coefficient of variation.
# If there are no candidates, or the CV is undefined for all of them, fall back
# to scCODA's "automatic" reference.
# -----------------------------
counts_mat = df_sccoda[celltype_cols].astype(float)
tot = counts_mat.sum(axis=1).replace(0, np.nan)  # patients with 0 cells -> NaN proportions
props_mat = counts_mat.div(tot, axis=0)

present_frac = (counts_mat > 0).mean(axis=0)
candidates = present_frac[present_frac >= 0.90].index.tolist()

REFERENCE_CELL_TYPE = "automatic"
ref_reason = "fallback_automatic (no candidates >=90% present)"

if candidates:
    means = props_mat[candidates].mean(axis=0)
    stds  = props_mat[candidates].std(axis=0)
    cv = (stds / means.replace(0, np.nan)).replace([np.inf, -np.inf], np.nan).dropna()

    if cv.empty:
        ref_reason = "fallback_automatic (CV undefined)"
    else:
        # idxmin returns the first label among ties, giving an explicit tie-break.
        REFERENCE_CELL_TYPE = cv.idxmin()
        ref_reason = f"min_CV_among_present>=90% (chosen={REFERENCE_CELL_TYPE})"

# ref_tag is embedded in output file names: replace spaces and path separators
# so a cell-type name such as "A/B" cannot point into a non-existent subfolder.
ref_tag = str(REFERENCE_CELL_TYPE)
for _unsafe in (" ", "/", "\\"):
    ref_tag = ref_tag.replace(_unsafe, "_")

print("\n=== Reference selection ===")
print("Candidates (present >=90% patients):", candidates)
print("REFERENCE_CELL_TYPE:", REFERENCE_CELL_TYPE)
print("ref_tag:", ref_tag)
print("Reason:", ref_reason)

# -----------------------------
# Build the scCODA data object and fit the model with HMC.
# -----------------------------
# NOTE(review): these are well below what I believe are scCODA's own
# sample_hmc defaults (num_results=20000, num_burnin=5000). Confirm the chain is
# adequate (e.g. acceptance rate reported by summary()).
# NOTE(review): no random seed is set in this cell, so repeated runs may give
# slightly different posteriors.
NUM_RESULTS = 2000
NUM_BURNIN  = 1000

data = dat.from_pandas(df_sccoda, covariate_columns=["disease"])
print("\nscCODA data:", data)

model = CompositionalAnalysis(data, formula="disease", reference_cell_type=REFERENCE_CELL_TYPE)
result = model.sample_hmc(num_results=NUM_RESULTS, num_burnin=NUM_BURNIN)

# -----------------------------
# Persist the summary and the credible effects.
# In scCODA 0.1.9 summary() prints rather than returning text, so stdout is
# captured into a buffer.
# -----------------------------
summary_buffer = io.StringIO()
with contextlib.redirect_stdout(summary_buffer):
    _ = result.summary()
summary_txt = summary_buffer.getvalue()

# credible_effects() may come back as a Series: always work on a DataFrame copy,
# and normalise the boolean column name to "credible" (0.1.9 typically returns
# a single column called "Final Parameter").
df_credible = result.credible_effects()
if isinstance(df_credible, pd.Series):
    df_credible = df_credible.to_frame()
elif not isinstance(df_credible, pd.DataFrame):
    df_credible = pd.DataFrame(df_credible)
df_credible = df_credible.copy()

if df_credible.shape[1] == 1:
    only_col = df_credible.columns[0]
    if str(df_credible[only_col].dtype) in ("bool", "boolean"):
        df_credible = df_credible.rename(columns={only_col: "credible"})

has_final_param = "Final Parameter" in df_credible.columns
if has_final_param and "credible" not in df_credible.columns:
    df_credible = df_credible.rename(columns={"Final Parameter": "credible"})

print("\n=== credible_effects() (head) ===")
print(df_credible.head(20))
print("credible_effects type:", type(df_credible))
print("credible_effects shape:", df_credible.shape)
print("credible_effects columns:", list(df_credible.columns))

out_summary_txt  = OUT_SUMMARY / f"scCODA_Level1refined_summary_ref-{ref_tag}.txt"
out_credible_csv = OUT_SUMMARY / f"scCODA_Level1refined_credible_effects_ref-{ref_tag}.csv"

out_summary_txt.write_text(summary_txt, encoding="utf-8")
df_credible.to_csv(out_credible_csv, index=True)

print("\nSaved:", out_summary_txt)
print("Saved:", out_credible_csv)

# ============================================================
# PANEL “RESULTADOS” (para Figura 2): tabla + mensaje
# -> guarda: Fig2B_scCODA_Level1refined_results_ref-*.png
# ============================================================

def _infer_credible_column(df):
    cols = list(getattr(df, "columns", []))
    # 1) nombre explícito
    if "credible" in cols:
        return "credible"
    # 2) heurística: primera columna booleana
    for c in cols:
        if str(df[c].dtype) in ("bool", "boolean"):
            return c
    # 3) fallback por nombre típico scCODA 0.1.9
    for c in cols:
        if str(c).lower().strip() in ("final parameter", "final_parameter"):
            return c
    return None

cred_col = _infer_credible_column(df_credible)
n_rows = int(df_credible.shape[0])

# Number of rows flagged as credible; None when it cannot be determined.
n_credible = None
if cred_col is not None:
    try:
        flag_values = pd.to_numeric(df_credible[cred_col], errors="coerce")
        n_credible = int(flag_values.fillna(False).astype(bool).sum())
    except Exception as exc:  # best-effort: the panel is still rendered without the count
        print(f"[WARN] No se pudo contar credible effects en la columna {cred_col!r}: {exc!r}")
        n_credible = None

# Grow the figure with the number of table rows so the table does not run into
# the header text. The previous fixed height of 4.5 in is kept as the minimum.
fig_height = max(4.5, 2.0 + 0.22 * (n_rows + 1))
fig, ax = plt.subplots(figsize=(12, fig_height))
ax.axis("off")

lines = [
    f"scCODA results — Level1_refined (reference = {REFERENCE_CELL_TYPE})",
    "Input: counts per patient; covariate: disease",
]
if n_credible is None:
    lines.append("Credible effects: (column not found in credible_effects output)")
else:
    lines.append(f"Credible effects detected: {n_credible} / {n_rows}")

ax.text(0.01, 0.98, "\n".join(lines), va="top", ha="left", fontsize=11)

# Table: bring the index (cell type, or covariate + cell type for a MultiIndex)
# back as regular columns. Only a flat, unnamed index gets the "celltype" name.
tbl = df_credible.copy()
if not isinstance(tbl.index, pd.MultiIndex) and tbl.index.name is None:
    tbl.index.name = "celltype"
tbl = tbl.reset_index()

# Light formatting so the table fits: 3 significant digits for numeric,
# non-boolean columns. The pandas dtype predicates also accept extension dtypes
# such as "boolean" or "Int64", which np.issubdtype cannot interpret (TypeError).
is_numeric = pd.api.types.is_numeric_dtype
is_bool = pd.api.types.is_bool_dtype
tbl_disp = tbl.copy()
for c in tbl_disp.columns:
    if is_numeric(tbl_disp[c]) and not is_bool(tbl_disp[c]):
        tbl_disp[c] = tbl_disp[c].map(lambda x: f"{x:.3g}" if pd.notnull(x) else "")

table = ax.table(
    cellText=tbl_disp.values,
    colLabels=tbl_disp.columns.tolist(),
    loc="lower left",
    cellLoc="left",
)

table.auto_set_font_size(False)
table.set_fontsize(8)
table.scale(1, 1.2)

out_png_results = OUT_FIG / f"Fig2B_scCODA_Level1refined_results_ref-{ref_tag}.png"
fig.savefig(out_png_results, dpi=300, bbox_inches="tight")
plt.close(fig)

print("\nSaved results panel:", out_png_results)

# ============================================================
# scCODA-style figure: boxplot + per-patient dots (from proportions)
# -> saves: Fig1E_scCODA_Level1refined_boxplots_ref-*.png
# ============================================================
from matplotlib.lines import Line2D

plot_df = props_mat.copy()
plot_df["disease"] = df_sccoda["disease"].astype(str)
plot_df["patientID"] = plot_df.index.astype(str)

long = plot_df.melt(
    id_vars=["patientID", "disease"],
    value_vars=celltype_cols,
    var_name="celltype",
    value_name="proportion",
)

# Preferred lineage order; any other cell types follow in their original order.
desired_order = ["B", "Plasma", "pDC", "T", "NK", "Mono", "DC", "HSCs"]
present_order = [x for x in desired_order if x in celltype_cols] + [x for x in celltype_cols if x not in desired_order]
long["celltype"] = pd.Categorical(long["celltype"], categories=present_order, ordered=True)

diseases = sorted(long["disease"].unique())
print("\nDisease levels (plot):", diseases)

cycle = plt.rcParams['axes.prop_cycle'].by_key().get('color', ['C0', 'C1', 'C2', 'C3'])
colors = {dis: cycle[i % len(cycle)] for i, dis in enumerate(diseases)}
print("Color map:", colors)


def _credible_flags_by_celltype(cred_df):
    """Map each cell type (as str) to True if any of its rows is credible.

    The flag column is located with _infer_credible_column, the same lookup the
    results panel uses. Both a flat index of cell types and a MultiIndex
    (e.g. covariate x cell type) are supported. For a MultiIndex the
    "Cell Type" level is used when present, otherwise the innermost level, and
    flags are OR-ed across the other levels. Missing values count as False.
    Returns {} when no flag column is found.
    """
    flag_col = _infer_credible_column(cred_df)
    if flag_col is None:
        return {}
    idx = cred_df.index
    if isinstance(idx, pd.MultiIndex):
        level = "Cell Type" if "Cell Type" in idx.names else idx.nlevels - 1
        celltypes = idx.get_level_values(level)
    else:
        celltypes = idx
    flags = {}
    for ct, val in zip(celltypes, cred_df[flag_col]):
        key = str(ct)
        flags[key] = flags.get(key, False) or (bool(val) if pd.notna(val) else False)
    return flags


try:
    credible_flag = _credible_flags_by_celltype(df_credible)
except Exception as exc:  # best-effort: the figure is still useful without asterisks
    print(f"[WARN] No se pudieron extraer los flags de credibilidad: {exc!r}")
    credible_flag = {}

xticklabels = [f"{ct} *" if credible_flag.get(ct, False) else ct for ct in present_order]

fig, ax = plt.subplots(figsize=(14, 6))

n_groups = len(diseases)
base_positions = np.arange(len(present_order))
width = 0.35 if n_groups == 2 else 0.25
offsets = np.linspace(-width, width, n_groups)
# Cap the box width by the spacing between offsets so boxes do not overlap with
# more than 3 disease levels. For 2 or 3 levels this equals width * 0.8.
spacing = (2 * width / (n_groups - 1)) if n_groups > 1 else width
box_width = min(width * 0.8, spacing * 0.8)

all_positions = []
all_data = []
for i, ct in enumerate(present_order):
    for j, dis in enumerate(diseases):
        mask = (long["celltype"] == ct) & (long["disease"] == dis)
        # Drop NaN proportions (patients with 0 total cells), as the QA table
        # does. NaN would otherwise turn the whole box's statistics into NaN.
        all_data.append(long.loc[mask, "proportion"].dropna().values)
        all_positions.append(base_positions[i] + offsets[j])

_ = ax.boxplot(
    all_data,
    positions=all_positions,
    widths=box_width,
    showfliers=False,
    patch_artist=False,
)

# Overlay one jittered dot per patient, coloured by disease (fixed seed).
rng = np.random.default_rng(0)
k = 0
for i, ct in enumerate(present_order):
    for j, dis in enumerate(diseases):
        vals = all_data[k]
        x0 = all_positions[k]
        jitter = rng.normal(0, width * 0.08, size=len(vals))
        ax.scatter(
            np.full_like(vals, x0, dtype=float) + jitter,
            vals,
            s=12,
            alpha=0.7,
            color=colors[dis],
        )
        k += 1

ax.set_xticks(base_positions)
ax.set_xticklabels(xticklabels, rotation=45, ha="right")
ax.set_ylabel("Proportion of cells per patient")
ax.set_title(f"scCODA composition input (Level1_refined) — ref={REFERENCE_CELL_TYPE}  (* credible effect)")
ax.set_ylim(0, max(0.05, float(long["proportion"].max()) * 1.15))

legend_handles = [
    Line2D([0], [0],
           marker='s', linestyle='None', markersize=8,
           label=dis,
           markerfacecolor=colors[dis],
           markeredgecolor=colors[dis])
    for dis in diseases
]
ax.legend(handles=legend_handles, title="disease", frameon=False, loc="upper right")

fig.tight_layout()

out_png_boxplot = OUT_FIG / f"Fig1E_scCODA_Level1refined_boxplots_ref-{ref_tag}.png"
fig.savefig(out_png_boxplot, dpi=300)
plt.close(fig)

print("\nSaved scCODA boxplot (input composition):", out_png_boxplot)

# -----------------------------
# Numeric QA of the plotted distributions: one row per (cell type, disease)
# with n, mean, median and interquartile bounds of the per-patient proportions
# (NaN proportions excluded), plus the credible-effect flag when available.
# -----------------------------
qa_columns = [
    "celltype", "disease", "n_patients",
    "mean_prop", "median_prop", "q25_prop", "q75_prop",
    "credible_effect_flag_if_available",
]

qa_rows = []
for ct in present_order:
    ct_flag = credible_flag.get(ct, False)
    for dis in diseases:
        selection = (long["celltype"] == ct) & (long["disease"] == dis)
        vals = long.loc[selection, "proportion"].dropna().values
        if len(vals):
            stats = [
                int(len(vals)),
                float(np.mean(vals)),
                float(np.median(vals)),
                float(np.quantile(vals, 0.25)),
                float(np.quantile(vals, 0.75)),
            ]
            flag = bool(ct_flag)
        else:
            stats = [0, np.nan, np.nan, np.nan, np.nan]
            flag = ct_flag
        qa_rows.append([ct, dis, *stats, flag])

qa_df = pd.DataFrame(qa_rows, columns=qa_columns)

qa_path = OUT_SUMMARY / f"QA_scCODA_boxplot_stats_ref-{ref_tag}.csv"
qa_df.to_csv(qa_path, index=False)
print("Saved QA table:", qa_path)

print("\n[OK] scCODA Level1_refined completado (summary + credible + results panel + boxplot + QA).")



“A nivel de linaje (Level1_refined), scCODA no identificó cambios composicionales creíbles entre grupos.”