### 1. Imports + paths del repo + outputs

In [None]:
from pathlib import Path
import json
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

from src.paths import project_paths

print("Scanpy:", sc.__version__)

# Resolve the repository layout from the notebook's working directory.
P = project_paths(Path.cwd())
PROJECT_ROOT, CONFIG_DIR, DATA_DIR, RESULTS_DIR, FIGURES_DIR = [
    P[_key] for _key in ("PROJECT_ROOT", "CONFIG_DIR", "DATA_DIR", "RESULTS_DIR", "FIGURES_DIR")
]

# Destinations for this notebook's summary tables and figures (created if missing).
OUT_SUMMARY = RESULTS_DIR / "summary_tables" / "composition_boxplots"
OUT_FIG = FIGURES_DIR / "composition"
for _out_dir in (OUT_SUMMARY, OUT_FIG):
    _out_dir.mkdir(parents=True, exist_ok=True)

for _label, _value in (
    ("PROJECT_ROOT:", PROJECT_ROOT),
    ("RESULTS_DIR :", RESULTS_DIR),
    ("FIGURES_DIR :", FIGURES_DIR),
    ("OUT_SUMMARY :", OUT_SUMMARY),
    ("OUT_FIG     :", OUT_FIG),
):
    print(_label, _value)

### 2. Leer config

In [None]:
def _strip_yaml_scalar(value: str) -> str:
    """Return the scalar text of *value*, without outer quotes or an inline comment."""
    if len(value) >= 2 and value[0] in "\"'":
        quote = value[0]
        end = value.find(quote, 1)
        if end != -1:
            # Quoted scalar: keep what is between the quotes and drop anything
            # after the closing quote (typically an inline comment).
            return value[1:end]
    # Unquoted scalar: as in YAML, '#' only starts a comment at the start of the
    # value or when preceded by whitespace (so 'http://x#frag' is kept intact).
    for i, ch in enumerate(value):
        if ch == "#" and (i == 0 or value[i - 1].isspace()):
            return value[:i].rstrip()
    return value


def load_simple_yaml(path: Path) -> dict:
    """Parse a flat ``key: value`` YAML file without a YAML dependency.

    Only a minimal subset of YAML is supported: one ``key: value`` pair per
    line, split on the first ``:``. Nesting is not understood (indented keys
    are stored as top-level keys) and every value is returned as a string.

    - Blank lines, full-line ``#`` comments and lines without ``:`` are skipped.
    - Inline comments (``value  # note``) are removed from unquoted values.
    - One pair of matching outer quotes (``"..."`` or ``'...'``) is removed.
    - Keys whose value is empty (``key:`` or ``key: ""``) are omitted, so that
      ``cfg.get(key, default)`` falls back to the caller's default instead of
      returning ``""``.

    Args:
        path: YAML file to read (decoded as UTF-8).

    Returns:
        Mapping from key to its string value.
    """
    cfg = {}
    for raw_line in path.read_text(encoding="utf-8").splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or ":" not in line:
            continue
        key, value = line.split(":", 1)
        value = _strip_yaml_scalar(value.strip())
        if value:
            cfg[key.strip()] = value
    return cfg

def _cfg_value(key: str, default):
    """Return ``CFG[key]``, or *default* when the key is missing or empty.

    An empty config entry (e.g. ``disease_key:``) would otherwise override the
    fallback with ``""`` and surface later as a confusing missing-column or
    missing-file error.
    """
    value = CFG.get(key)
    return value if value else default


cfg_path = CONFIG_DIR / "config.yaml"
CFG = load_simple_yaml(cfg_path) if cfg_path.exists() else {}

# Main input (output of NB10)
IN_NAME = _cfg_value("main_filtered_for_analysis_h5ad_filename", "TFM_CIRRHOSIS_main_filtered_for_analysis.h5ad")
IN_PATH = RESULTS_DIR / IN_NAME

print("IN_PATH:", IN_PATH)
if not IN_PATH.exists():
    raise FileNotFoundError(f"No existe IN_PATH:\n{IN_PATH}")

# obs column keys, with fallbacks when not set in the config
PATIENT_KEY = _cfg_value("patient_id_key", "patientID")
LEVEL1R_KEY = _cfg_value("level1_refined_key", "Level1_refined")
LEVEL2_KEY = _cfg_value("level2_key", "Level2")

# disease/group column: from the config if set; otherwise the first available
# candidate is chosen below once the obs columns are known.
DISEASE_KEY = _cfg_value("disease_key", None)

# Backed (read-only) mode: only obs is needed here, so X is not loaded into memory.
adata_b = sc.read_h5ad(IN_PATH, backed="r")
print("Loaded (backed):", adata_b)

try:
    # Resolve the disease key if the config did not provide one
    if DISEASE_KEY is None:
        for cand in ["disease", "condition", "group"]:
            if cand in adata_b.obs.columns:
                DISEASE_KEY = cand
                break
    if DISEASE_KEY is None:
        raise KeyError("No encuentro ninguna columna tipo disease/condition/group en adata.obs.")

    needed = [PATIENT_KEY, DISEASE_KEY, LEVEL1R_KEY, LEVEL2_KEY]
    missing = [c for c in needed if c not in adata_b.obs.columns]
    if missing:
        raise KeyError(f"Faltan columnas en obs: {missing}")

    if "doublet_like" in adata_b.obs.columns:
        print("doublet_like True (debería ser 0):", int(adata_b.obs["doublet_like"].sum()))

    # Extract ONLY obs (an in-memory copy that survives closing the file)
    obs_all = adata_b.obs[needed].copy()

finally:
    try:
        adata_b.file.close()
    except Exception as exc:
        # Best-effort close: report it, but never mask an error raised above.
        print(f"[WARN] Could not close backed file: {exc}")

# Light cleanup. astype(str) turns missing values into the literal string "nan",
# which would then act as a real patient/group/cell type; make that visible.
for _col in (PATIENT_KEY, DISEASE_KEY, LEVEL1R_KEY):
    _n_missing = int(obs_all[_col].isna().sum())
    if _n_missing:
        print(f"[WARN] {_n_missing} missing values in '{_col}' will become the string 'nan'.")
    obs_all[_col] = obs_all[_col].astype(str)

print("Keys usados:")
print("  PATIENT_KEY:", PATIENT_KEY)
print("  DISEASE_KEY:", DISEASE_KEY)
print("  LEVEL1R_KEY :", LEVEL1R_KEY)
print("  LEVEL2_KEY  :", LEVEL2_KEY)

print("n_obs:", obs_all.shape[0])
print("disease levels:", sorted(obs_all[DISEASE_KEY].unique())[:20])

### 3. Level1_refined: composición + CSVs + boxplots

In [None]:
obs = obs_all[[PATIENT_KEY, DISEASE_KEY, LEVEL1R_KEY]].copy()

# RBC-out: RBCs are expected to be gone already; drop any leftovers here.
n_rbc = int((obs[LEVEL1R_KEY] == "RBC").sum())
print("\nRBC en Level1_refined (debería ser 0):", n_rbc)
if n_rbc > 0:
    print("[WARN] Eliminando filas RBC para composición (post-RBC-out).")
    obs = obs.loc[obs[LEVEL1R_KEY] != "RBC"].copy()

# Unique patients per group
patients_per_group = (
    obs.drop_duplicates([PATIENT_KEY, DISEASE_KEY])
       .groupby(DISEASE_KEY)[PATIENT_KEY].nunique()
       .sort_values(ascending=False)
)
print("\nPacientes únicos por grupo:")
print(patients_per_group)

# Per-patient counts: one row per (patient, group), one column per cell type
counts = (obs.groupby([PATIENT_KEY, DISEASE_KEY, LEVEL1R_KEY])
            .size()
            .unstack(fill_value=0)
            .reset_index())

celltype_cols = [c for c in counts.columns if c not in [PATIENT_KEY, DISEASE_KEY]]

# Extra RBC-out guard: drop a residual "RBC" column if one survived
if "RBC" in celltype_cols:
    print("[WARN] Columna 'RBC' detectada en counts. Eliminándola (post-RBC-out).")
    counts = counts.drop(columns=["RBC"])
    celltype_cols = [c for c in celltype_cols if c != "RBC"]

counts["total_cells_patient"] = counts[celltype_cols].sum(axis=1)

# Per-patient proportions; a zero total becomes NaN instead of dividing by zero
props = counts.copy()
den = props["total_cells_patient"].replace(0, np.nan)
for c in celltype_cols:
    props[c] = props[c] / den

counts_csv = OUT_SUMMARY / "cell_counts_Level1refined_by_patient.csv"
props_csv  = OUT_SUMMARY / "cell_proportions_Level1refined_by_patient.csv"
counts.to_csv(counts_csv, index=False)
props.to_csv(props_csv, index=False)

print("\nSaved:", counts_csv)
print("Saved:", props_csv)

# Long format for plotting (undefined proportions dropped)
long = props.melt(
    id_vars=[PATIENT_KEY, DISEASE_KEY, "total_cells_patient"],
    value_vars=celltype_cols,
    var_name="celltype",
    value_name="proportion",
).dropna(subset=["proportion"])

desired_order = ["B", "Plasma", "pDC", "T", "NK", "Mono", "DC", "HSCs"]
present_order = [x for x in desired_order if x in celltype_cols] + [x for x in celltype_cols if x not in desired_order]
long["celltype"] = pd.Categorical(long["celltype"], categories=present_order, ordered=True)

diseases = sorted(long[DISEASE_KEY].unique())
print("\nDisease levels:", diseases)

cycle = plt.rcParams["axes.prop_cycle"].by_key().get("color", ["C0","C1","C2","C3"])
colors = {dis: cycle[i % len(cycle)] for i, dis in enumerate(diseases)}
print("Color map:", colors)

fig, ax = plt.subplots(figsize=(14, 6))

n_groups = len(diseases)
base_positions = np.arange(len(present_order))
# The boxes of one cell type are spread over [-width, +width] around its tick.
width = 0.35 if n_groups == 2 else 0.25
offsets = np.linspace(-width, width, n_groups) if n_groups > 1 else np.array([0.0])
# Cap the box width at 80% of the gap between neighbouring boxes: with the fixed
# spread above, boxes of 0.8 * width overlap from 4 groups on. For <= 3 groups
# the cap is inactive and the layout is unchanged.
spacing = 2 * width / (n_groups - 1) if n_groups > 1 else width
box_width = min(width * 0.8, spacing * 0.8)

# One box per (cell type, group), in cell-type-major order
all_positions = []
all_data = []
all_colors = []
for i, ct in enumerate(present_order):
    for j, dis in enumerate(diseases):
        vals = long.loc[(long["celltype"] == ct) & (long[DISEASE_KEY] == dis), "proportion"].values
        all_data.append(vals)
        all_positions.append(base_positions[i] + offsets[j])
        all_colors.append(colors[dis])

ax.boxplot(
    all_data,
    positions=all_positions,
    widths=box_width,
    showfliers=False,
    patch_artist=False,
)

# Overlay each patient as a point with a small, reproducible horizontal jitter
rng = np.random.default_rng(0)
for vals, x0, point_color in zip(all_data, all_positions, all_colors):
    if len(vals) > 0:
        jitter = rng.normal(0, box_width * 0.1, size=len(vals))
        ax.scatter(
            np.full_like(vals, x0, dtype=float) + jitter,
            vals,
            s=12,
            alpha=0.7,
            color=point_color,
        )

ax.set_xticks(base_positions)
ax.set_xticklabels(present_order, rotation=45, ha="right")
ax.set_ylabel("Proportion of cells per patient")
ax.set_title("Cell-type composition (Level1_refined) by patient and group")
ax.set_ylim(0, max(0.05, float(long["proportion"].max()) * 1.15))

legend_handles = [
    Line2D([0],[0], marker="s", linestyle="None", markersize=8, label=dis,
           markerfacecolor=colors[dis], markeredgecolor=colors[dis])
    for dis in diseases
]
ax.legend(handles=legend_handles, title=DISEASE_KEY, frameon=False, loc="upper right")

fig.tight_layout()
out_png = OUT_FIG / "Fig1D_Composition_Level1refined_boxplots.png"
fig.savefig(out_png, dpi=300)
plt.close(fig)
print("\nSaved figure:", out_png)
print("\n[OK] Composición Level1_refined terminada.")

### 4. Level2_final: composición + CSVs + boxplots

In [None]:
# Level2_final_map.json: relabelling Level2 -> Level2_final, looked up in known locations
candidate_maps = [
    RESULTS_DIR / "summary_tables" / "conv_t_other_cleanup" / "Level2_final_map.json",
    RESULTS_DIR / "summary_tables" / "Level2_final_map.json",
]
MAP_PATH = next((p for p in candidate_maps if p.exists()), None)

print("\nMAP_PATH:", MAP_PATH)
if MAP_PATH is None:
    raise FileNotFoundError(
        "No encuentro Level2_final_map.json.\nProbé:\n" + "\n".join([f"- {x}" for x in candidate_maps])
    )

with open(MAP_PATH, "r", encoding="utf-8") as f:
    level2_map = json.load(f)

# Series.replace below only acts as a label mapping when given a dict; fail
# clearly if the JSON holds anything else (list, string, ...).
if not isinstance(level2_map, dict):
    raise TypeError(
        f"{MAP_PATH} must contain a JSON object (Level2 -> Level2_final), "
        f"got {type(level2_map).__name__}."
    )

# Plot order by lineage block (includes DC3)
order_by_group = {
    "B":     ["B_Naive", "B_Memory", "B_Activated", "B_Atypical", "B_Other"],
    "Plasma":["Plasma"],
    "pDC":   ["pDC"],
    "T":     ["CD4_Naive","CD4_Memory","CD8_Naive","CD8_Effector_Cytotoxic","Treg","MAIT","GammaDelta_T","Proliferative_T","Exhausted_T"],
    "NK":    ["NK"],
    "Mono":  ["Classical_Mono","NonClassical_Mono","ISG_Myeloid","MonoDC_Other"],
    "DC":    ["cDC1","cDC2","DC3","DC4","aDC"],
    "HSCs":  ["HSCs"],
}

def group_of_l2(l2: str) -> str:
    """Return the lineage block of a Level2_final label, or "Other" if unlisted."""
    for g, l2_list in order_by_group.items():
        if l2 in l2_list:
            return g
    return "Other"

obs2 = obs_all[[PATIENT_KEY, DISEASE_KEY, LEVEL2_KEY]].copy()

# Map without converting to str first, so missing values stay NaN (not "nan")
l2_obj = obs2[LEVEL2_KEY].astype("object")
obs2["Level2_final"] = l2_obj.replace(level2_map).astype("object")

# RBC-out
if (obs2["Level2_final"].astype(str) == "RBC").any():
    print("[WARN] Hay RBC en Level2_final. Se excluirá de la composición.")
    obs2 = obs2.loc[obs2["Level2_final"].astype(str) != "RBC"].copy()

present_l2 = sorted(obs2["Level2_final"].dropna().astype(str).unique().tolist())

# Known labels in block order first, then any unlisted ones alphabetically
level2_order = []
for g, l2_list in order_by_group.items():
    for l2 in l2_list:
        if l2 in present_l2:
            level2_order.append(l2)

extras = [x for x in present_l2 if x not in level2_order]
level2_order = level2_order + sorted(extras)

# "<block> | <label>" per cell. group_of_l2 is a linear scan, so evaluate it once
# per distinct label instead of once per cell.
l2_labels = obs2["Level2_final"].astype(str)
l2_plot_map = {l2: f"{group_of_l2(l2)} | {l2}" for l2 in l2_labels.unique()}
obs2["Level2_plot"] = l2_labels.map(l2_plot_map)
level2_plot_order = [f"{group_of_l2(l2)} | {l2}" for l2 in level2_order]
obs2["Level2_plot"] = pd.Categorical(obs2["Level2_plot"], categories=level2_plot_order, ordered=True)

# Per-patient counts (groupby drops cells whose Level2_final is missing)
counts2 = (obs2.groupby([PATIENT_KEY, DISEASE_KEY, "Level2_final"])
             .size()
             .unstack(fill_value=0)
             .reset_index())

celltype_cols2 = [c for c in counts2.columns if c not in [PATIENT_KEY, DISEASE_KEY]]
counts2["total_cells_patient"] = counts2[celltype_cols2].sum(axis=1)

# Per-patient proportions; a zero total becomes NaN instead of dividing by zero
props2 = counts2.copy()
den2 = props2["total_cells_patient"].replace(0, np.nan)
for c in celltype_cols2:
    props2[c] = props2[c] / den2

counts2_csv = OUT_SUMMARY / "cell_counts_Level2final_by_patient.csv"
props2_csv  = OUT_SUMMARY / "cell_proportions_Level2final_by_patient.csv"
counts2.to_csv(counts2_csv, index=False)
props2.to_csv(props2_csv, index=False)

print("Saved:", counts2_csv)
print("Saved:", props2_csv)

# Long format for plotting
long2 = props2.melt(
    id_vars=[PATIENT_KEY, DISEASE_KEY, "total_cells_patient"],
    value_vars=celltype_cols2,
    var_name="Level2_final",
    value_name="proportion",
).dropna(subset=["proportion"])

long2["celltype"] = long2["Level2_final"].astype(str).map(lambda l2: f"{group_of_l2(l2)} | {l2}")
long2["celltype"] = pd.Categorical(long2["celltype"], categories=level2_plot_order, ordered=True)

diseases2 = sorted(long2[DISEASE_KEY].unique())
cycle = plt.rcParams["axes.prop_cycle"].by_key().get("color", ["C0","C1","C2","C3"])
colors2 = {dis: cycle[i % len(cycle)] for i, dis in enumerate(diseases2)}
print("Disease levels:", diseases2)

fig, ax = plt.subplots(figsize=(22, 6))

n_groups = len(diseases2)
base_positions = np.arange(len(level2_plot_order))
# The boxes of one cell type are spread over [-width, +width] around its tick.
width = 0.35 if n_groups == 2 else 0.25
offsets = np.linspace(-width, width, n_groups) if n_groups > 1 else np.array([0.0])
# Cap the box width at 80% of the gap between neighbouring boxes: with the fixed
# spread above, boxes of 0.8 * width overlap from 4 groups on. For <= 3 groups
# the cap is inactive and the layout is unchanged.
spacing = 2 * width / (n_groups - 1) if n_groups > 1 else width
box_width = min(width * 0.8, spacing * 0.8)

# One box per (cell type, group), in cell-type-major order
all_positions = []
all_data = []
all_colors = []
for i, ct in enumerate(level2_plot_order):
    for j, dis in enumerate(diseases2):
        vals = long2.loc[(long2["celltype"] == ct) & (long2[DISEASE_KEY] == dis), "proportion"].values
        all_data.append(vals)
        all_positions.append(base_positions[i] + offsets[j])
        all_colors.append(colors2[dis])

ax.boxplot(
    all_data,
    positions=all_positions,
    widths=box_width,
    showfliers=False,
    patch_artist=False,
)

# Overlay each patient as a point with a small, reproducible horizontal jitter
rng = np.random.default_rng(0)
for vals, x0, point_color in zip(all_data, all_positions, all_colors):
    if len(vals) > 0:
        jitter = rng.normal(0, box_width * 0.1, size=len(vals))
        ax.scatter(
            np.full_like(vals, x0, dtype=float) + jitter,
            vals,
            s=10,
            alpha=0.7,
            color=point_color,
        )

ax.set_xticks(base_positions)
ax.set_xticklabels(level2_plot_order, rotation=90, ha="center", fontsize=6)
ax.set_ylabel("Proportion of cells per patient")
ax.set_title("Cell-type composition (Level2_final) by patient and group")
ax.set_ylim(0, max(0.05, float(long2["proportion"].max()) * 1.15))

legend_handles = [
    Line2D([0],[0], marker="s", linestyle="None", markersize=8, label=dis,
           markerfacecolor=colors2[dis], markeredgecolor=colors2[dis])
    for dis in diseases2
]
ax.legend(handles=legend_handles, title=DISEASE_KEY, frameon=False, loc="upper right")

fig.tight_layout()
out_png2 = OUT_FIG / "Fig2A_Composition_Level2final_boxplots.png"
fig.savefig(out_png2, dpi=300)
plt.close(fig)

print("Saved figure:", out_png2)
print("\n[OK] Composición Level2_final terminada.")