In [1]:
# ============================================================
# COMPOSICIÓN (Level1_refined) + BOXPLOTS por paciente y grupo
# + COMPOSICIÓN (Level2_final) + BOXPLOTS
# - usa SOLO obs (no carga X) -> seguro con RAM
# - genera CSVs + figuras .png (300 dpi)
# - post-RBC-out: si aparece RBC por cualquier motivo, se elimina aquí
# ============================================================

from pathlib import Path
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import json

# Directory the notebook/script is executed from; the project root is searched upward from here.
NOTEBOOK_DIR = Path.cwd()

def find_project_root(start: Path) -> Path:
    """Return the first of ``start`` and its ancestors that contains ``data_processed``.

    Candidates are checked in order: ``start`` itself, then each entry of
    ``start.parents`` (closest first). A candidate matches when
    ``candidate / "data_processed"`` exists.

    Raises:
        FileNotFoundError: if none of the candidates contains ``data_processed``.
    """
    candidates = (start, *start.parents)
    for candidate in candidates:
        marker = candidate / "data_processed"
        if marker.exists():
            return candidate
    raise FileNotFoundError(f"No encuentro 'data_processed' subiendo desde: {start}")

# Resolve the project layout relative to where the notebook is running.
PROJECT_ROOT = find_project_root(NOTEBOOK_DIR)
DATA_PROCESSED = PROJECT_ROOT / "data_processed"
IN_PATH = DATA_PROCESSED / "TFM_CIRRHOSIS_main_filtered_for_analysis.h5ad"

# Output folders for summary tables and figures; created if they do not exist yet.
OUT_SUMMARY = PROJECT_ROOT / "summary_tables_final"
OUT_FIG = PROJECT_ROOT / "figures_final"
for out_dir in (OUT_SUMMARY, OUT_FIG):
    out_dir.mkdir(exist_ok=True)

for label, value in (
    ("PROJECT_ROOT:", PROJECT_ROOT),
    ("IN_PATH     :", IN_PATH),
    ("OUT_SUMMARY :", OUT_SUMMARY),
    ("OUT_FIG     :", OUT_FIG),
):
    print(label, value)

# --- Open the .h5ad in backed read-only mode so the expression matrix X is NOT loaded ---
adata_b = sc.read_h5ad(IN_PATH, backed="r")

# Only these obs columns are needed downstream.
needed = ["patientID", "disease", "Level1_refined", "Level2"]
try:
    print("Loaded:", adata_b)

    # Minimal schema check: fail early if a required obs column is missing.
    missing = [c for c in needed if c not in adata_b.obs.columns]
    if missing:
        raise KeyError(f"Faltan columnas en obs: {missing}")

    # Sanity check: the number of doublet-like cells is expected to be 0 at this stage.
    if "doublet_like" in adata_b.obs.columns:
        print("doublet_like True (debería ser 0):", int(adata_b.obs["doublet_like"].sum()))

    # --- Copy ONLY obs (lightweight) so the backed file can be released ---
    obs_all = adata_b.obs[needed].copy()
finally:
    # Always release the backed HDF5 handle, even if a check above raised.
    adata_b.file.close()

# Light cleanup: normalise identifiers/labels to plain strings. "Level2" is
# intentionally left as-is so missing values stay NaN (not the string "nan")
# for the Level2_final mapping further down.
for col in ("patientID", "disease", "Level1_refined"):
    obs_all[col] = obs_all[col].astype(str)

# ============================================================
# 1) Level1_refined: composition + CSVs + boxplot
# ============================================================
obs = obs_all[["patientID", "disease", "Level1_refined"]].copy()

# RBC-out safeguard: red blood cells must not enter the composition; drop any that remain.
is_rbc = obs["Level1_refined"] == "RBC"
n_rbc = int(is_rbc.sum())
print("\nRBC en Level1_refined (debería ser 0):", n_rbc)
if n_rbc > 0:
    print("[WARN] Eliminando filas RBC para composición (post-RBC-out).")
    obs = obs.loc[~is_rbc].copy()

# Number of distinct patients in each disease group (largest group first).
patients_per_group = (
    obs[["patientID", "disease"]]
    .drop_duplicates()
    .groupby("disease")["patientID"]
    .nunique()
    .sort_values(ascending=False)
)
print("\nPacientes únicos por grupo (disease):")
print(patients_per_group)

# Per-patient cell counts: one row per (patientID, disease), one column per cell type.
# Cell types a patient does not have are filled with 0.
counts = (
    obs.groupby(["patientID", "disease", "Level1_refined"])
    .size()
    .unstack(fill_value=0)
    .reset_index()
)

id_cols = ["patientID", "disease"]
celltype_cols = [c for c in counts.columns if c not in id_cols]

# Extra RBC-out guard: should an "RBC" column still be present, remove it.
if "RBC" in celltype_cols:
    print("[WARN] Columna 'RBC' detectada en counts. Eliminándola (post-RBC-out).")
    counts = counts.drop(columns=["RBC"])
    celltype_cols = [c for c in celltype_cols if c != "RBC"]

counts["total_cells_patient"] = counts[celltype_cols].sum(axis=1)

# Proportions: each cell-type count divided by that patient's total.
props = counts.copy()
props[celltype_cols] = props[celltype_cols].div(props["total_cells_patient"], axis=0)

counts_csv = OUT_SUMMARY / "cell_counts_Level1refined_by_patient.csv"
props_csv = OUT_SUMMARY / "cell_proportions_Level1refined_by_patient.csv"
for frame, path in ((counts, counts_csv), (props, props_csv)):
    frame.to_csv(path, index=False)

print("\nSaved:", counts_csv)
print("Saved:", props_csv)

# --- Plot Level1_refined: one box per (cell type, disease) + jittered per-patient points ---
long = props.melt(
    id_vars=["patientID", "disease", "total_cells_patient"],
    value_vars=celltype_cols,
    var_name="celltype",
    value_name="proportion",
)

# Preferred x-axis order; cell types not listed here are appended in their existing order.
desired_order = ["B", "Plasma", "pDC", "T", "NK", "Mono", "DC", "HSCs"]
present_order = [x for x in desired_order if x in celltype_cols] + [x for x in celltype_cols if x not in desired_order]
long["celltype"] = pd.Categorical(long["celltype"], categories=present_order, ordered=True)

diseases = sorted(long["disease"].unique())
print("\nDisease levels:", diseases)

# One colour per disease, cycling through the active matplotlib colour cycle.
cycle = plt.rcParams['axes.prop_cycle'].by_key().get('color', ['C0','C1','C2','C3'])
colors = {dis: cycle[i % len(cycle)] for i, dis in enumerate(diseases)}
print("Color map:", colors)

fig, ax = plt.subplots(figsize=(14, 6))

n_groups = len(diseases)
base_positions = np.arange(len(present_order))
width = 0.35 if n_groups == 2 else 0.25
# Spread the disease boxes symmetrically around each tick. np.linspace(-width, width, 1)
# would put a lone group at -width (off its tick), and a fixed box width overlaps
# neighbouring boxes once there are 4+ groups, so the width is capped by the spacing.
# For 2 and 3 groups, positions and widths equal linspace(-width, width, n) / width*0.8.
if n_groups > 1:
    step = 2 * width / (n_groups - 1)
    offsets = (np.arange(n_groups) - (n_groups - 1) / 2) * step
    box_width = min(width * 0.8, step * 0.8)
else:
    offsets = np.zeros(1)
    box_width = width * 0.8

# One array of per-patient proportions per (cell type, disease) box, in plotting order.
all_positions = []
all_data = []
all_colors = []
for i, ct in enumerate(present_order):
    for j, dis in enumerate(diseases):
        mask = (long["celltype"] == ct) & (long["disease"] == dis)
        all_data.append(long.loc[mask, "proportion"].values)
        all_positions.append(base_positions[i] + offsets[j])
        all_colors.append(colors[dis])

ax.boxplot(
    all_data,
    positions=all_positions,
    widths=box_width,
    showfliers=False,  # individual patients are drawn as points below
    patch_artist=False,
)

# Overlay each patient as a jittered point; the fixed seed keeps the figure reproducible.
rng = np.random.default_rng(0)
for vals, x0, color in zip(all_data, all_positions, all_colors):
    jitter = rng.normal(0, width * 0.08, size=len(vals))
    ax.scatter(
        np.full_like(vals, x0, dtype=float) + jitter,
        vals,
        s=12,
        alpha=0.7,
        color=color,
    )

ax.set_xticks(base_positions)
ax.set_xticklabels(present_order, rotation=45, ha="right")
ax.set_ylabel("Proportion of cells per patient")
ax.set_title("Cell-type composition (Level1_refined) by patient and disease")
# 15% headroom above the largest proportion, with a minimum upper limit of 0.05.
ax.set_ylim(0, max(0.05, float(long["proportion"].max()) * 1.15))

# Proxy legend handles: one square marker per disease, matching the scatter colours.
legend_handles = [
    Line2D([0], [0], marker='s', linestyle='None', markersize=8, label=dis,
           markerfacecolor=colors[dis], markeredgecolor=colors[dis])
    for dis in diseases
]
ax.legend(handles=legend_handles, title="disease", frameon=False, loc="upper right")

fig.tight_layout()
out_png = OUT_FIG / "Fig1D_Composition_Level1refined_boxplots.png"
fig.savefig(out_png, dpi=300)
plt.close(fig)
print("\nSaved figure:", out_png)
print("\n[OK] Composición Level1_refined terminada.")


# ============================================================
# 2) Level2_final: composition + CSVs + boxplot
# ============================================================
# Relabelling table for raw Level2 annotations. It is later passed to
# Series.replace, so presumably a JSON object {raw_label: final_label} — confirm.
MAP_PATH = OUT_SUMMARY / "Level2_final_map.json"
print("\nMAP_PATH:", MAP_PATH)
if not MAP_PATH.exists():
    raise FileNotFoundError(f"No existe {MAP_PATH}. Necesitas Level2_final_map.json en summary_tables_final/.")

level2_map = json.loads(MAP_PATH.read_text(encoding="utf-8"))

# Display order of Level2_final labels, grouped under their Level1 lineage key.
order_by_group = {
    "B": ["B_Naive", "B_Memory", "B_Activated", "B_Atypical", "B_Other"],
    "Plasma": ["Plasma"],
    "pDC": ["pDC"],
    "T": [
        "CD4_Naive",
        "CD4_Memory",
        "CD8_Naive",
        "CD8_Effector_Cytotoxic",
        "Treg",
        "MAIT",
        "GammaDelta_T",
        "Proliferative_T",
        "Exhausted_T",
    ],
    "NK": ["NK"],
    "Mono": ["Classical_Mono", "NonClassical_Mono", "ISG_Myeloid", "MonoDC_Other"],
    "DC": ["cDC1", "cDC2", "DC4", "aDC"],
    "HSCs": ["HSCs"],
}

def group_of_l2(l2: str) -> str:
    """Return the key of the first ``order_by_group`` entry whose list contains ``l2``.

    Groups are scanned in the dict's insertion order; a label that appears in
    no group is reported as ``"Other"``.
    """
    return next(
        (group for group, members in order_by_group.items() if l2 in members),
        "Other",
    )

obs2 = obs_all[["patientID", "disease", "Level2"]].copy()

# Map raw Level2 -> Level2_final on object dtype WITHOUT converting to str first,
# so missing labels stay NaN instead of becoming the string "nan".
# Series.replace leaves labels that are not keys of level2_map unchanged.
l2_obj = obs2["Level2"].astype("object")
obs2["Level2_final"] = l2_obj.replace(level2_map).astype("object")

# RBC-out safeguard: exclude red blood cells from the composition if any remain.
if (obs2["Level2_final"].astype(str) == "RBC").any():
    print("[WARN] Hay RBC en Level2_final. Se excluirá de la composición.")
    obs2 = obs2.loc[obs2["Level2_final"].astype(str) != "RBC"].copy()

# Cells whose Level2_final is NaN are dropped by the per-patient groupby
# (pandas groupby defaults to dropna=True), so they would vanish silently from
# both the counts and the per-patient totals. Make that exclusion visible.
n_missing_l2 = int(obs2["Level2_final"].isna().sum())
if n_missing_l2 > 0:
    print(f"[WARN] {n_missing_l2} células con Level2_final NaN: se excluyen de la composición.")

# Labels actually present, sorted alphabetically (NaN excluded).
present_l2 = sorted(obs2["Level2_final"].dropna().astype(str).unique().tolist())
present_set = set(present_l2)

# Plot order: known labels in order_by_group order, then any extra labels alphabetically.
level2_order = [l2 for l2_list in order_by_group.values() for l2 in l2_list if l2 in present_set]
extras = [x for x in present_l2 if x not in level2_order]
level2_order = level2_order + sorted(extras)

# "<group> | <label>" display names as an ordered categorical.
obs2["Level2_plot"] = obs2["Level2_final"].astype(str).map(lambda l2: f"{group_of_l2(l2)} | {l2}")
level2_plot_order = [f"{group_of_l2(l2)} | {l2}" for l2 in level2_order]
obs2["Level2_plot"] = pd.Categorical(obs2["Level2_plot"], categories=level2_plot_order, ordered=True)

# Per-patient cell counts for every Level2_final label (absent combinations -> 0).
counts2 = (
    obs2.groupby(["patientID", "disease", "Level2_final"])
    .size()
    .unstack(fill_value=0)
    .reset_index()
)

celltype_cols2 = [c for c in counts2.columns if c not in ("patientID", "disease")]
counts2["total_cells_patient"] = counts2[celltype_cols2].sum(axis=1)

# Proportions: each label's count divided by that patient's total.
props2 = counts2.copy()
props2[celltype_cols2] = props2[celltype_cols2].div(props2["total_cells_patient"], axis=0)

counts2_csv = OUT_SUMMARY / "cell_counts_Level2final_by_patient.csv"
props2_csv = OUT_SUMMARY / "cell_proportions_Level2final_by_patient.csv"
for frame, path in ((counts2, counts2_csv), (props2, props2_csv)):
    frame.to_csv(path, index=False)

print("Saved:", counts2_csv)
print("Saved:", props2_csv)

# --- Plot Level2_final: one box per (label, disease) + jittered per-patient points ---
long2 = props2.melt(
    id_vars=["patientID", "disease", "total_cells_patient"],
    value_vars=celltype_cols2,
    var_name="Level2_final",
    value_name="proportion",
)
# x-axis labels "<group> | <label>", ordered as in level2_plot_order.
long2["celltype"] = long2["Level2_final"].astype(str).map(lambda l2: f"{group_of_l2(l2)} | {l2}")
long2["celltype"] = pd.Categorical(long2["celltype"], categories=level2_plot_order, ordered=True)

diseases2 = sorted(long2["disease"].unique())
# One colour per disease, cycling through the active matplotlib colour cycle.
cycle = plt.rcParams['axes.prop_cycle'].by_key().get('color', ['C0','C1','C2','C3'])
colors2 = {dis: cycle[i % len(cycle)] for i, dis in enumerate(diseases2)}
print("Disease levels:", diseases2)

fig, ax = plt.subplots(figsize=(22, 6))

n_groups = len(diseases2)
base_positions = np.arange(len(level2_plot_order))
width = 0.35 if n_groups == 2 else 0.25
# Spread the disease boxes symmetrically around each tick. np.linspace(-width, width, 1)
# would put a lone group at -width (off its tick), and a fixed box width overlaps
# neighbouring boxes once there are 4+ groups, so the width is capped by the spacing.
# For 2 and 3 groups, positions and widths equal linspace(-width, width, n) / width*0.8.
if n_groups > 1:
    step = 2 * width / (n_groups - 1)
    offsets = (np.arange(n_groups) - (n_groups - 1) / 2) * step
    box_width = min(width * 0.8, step * 0.8)
else:
    offsets = np.zeros(1)
    box_width = width * 0.8

# One array of per-patient proportions per (label, disease) box, in plotting order.
all_positions = []
all_data = []
all_colors = []
for i, ct in enumerate(level2_plot_order):
    for j, dis in enumerate(diseases2):
        mask = (long2["celltype"] == ct) & (long2["disease"] == dis)
        all_data.append(long2.loc[mask, "proportion"].values)
        all_positions.append(base_positions[i] + offsets[j])
        all_colors.append(colors2[dis])

ax.boxplot(
    all_data,
    positions=all_positions,
    widths=box_width,
    showfliers=False,  # individual patients are drawn as points below
    patch_artist=False,
)

# Overlay each patient as a jittered point; the fixed seed keeps the figure reproducible.
rng = np.random.default_rng(0)
for vals, x0, color in zip(all_data, all_positions, all_colors):
    jitter = rng.normal(0, width * 0.08, size=len(vals))
    ax.scatter(
        np.full_like(vals, x0, dtype=float) + jitter,
        vals,
        s=10,
        alpha=0.7,
        color=color,
    )

ax.set_xticks(base_positions)
ax.set_xticklabels(level2_plot_order, rotation=90, ha="center", fontsize=6)
ax.set_ylabel("Proportion of cells per patient")
ax.set_title("Cell-type composition (Level2_final) by patient and disease")
# 15% headroom above the largest proportion, with a minimum upper limit of 0.05.
ax.set_ylim(0, max(0.05, float(long2["proportion"].max()) * 1.15))

# Proxy legend handles: one square marker per disease, matching the scatter colours.
legend_handles = [
    Line2D([0], [0], marker='s', linestyle='None', markersize=8, label=dis,
           markerfacecolor=colors2[dis], markeredgecolor=colors2[dis])
    for dis in diseases2
]
ax.legend(handles=legend_handles, title="disease", frameon=False, loc="upper right")

fig.tight_layout()
out_png2 = OUT_FIG / "Fig2A_Composition_Level2final_boxplots.png"
fig.savefig(out_png2, dpi=300)
plt.close(fig)

print("Saved figure:", out_png2)
print("\n[OK] Composición Level2_final terminada.")


PROJECT_ROOT: D:\Users\Coni\Documents\TFM_CirrhosIS
IN_PATH     : D:\Users\Coni\Documents\TFM_CirrhosIS\data_processed\TFM_CIRRHOSIS_main_filtered_for_analysis.h5ad
OUT_SUMMARY : D:\Users\Coni\Documents\TFM_CirrhosIS\summary_tables_final
OUT_FIG     : D:\Users\Coni\Documents\TFM_CirrhosIS\figures_final
Loaded: AnnData object with n_obs × n_vars = 220637 × 38606 backed at 'D:\\Users\\Coni\\Documents\\TFM_CirrhosIS\\data_processed\\TFM_CIRRHOSIS_main_filtered_for_analysis.h5ad'
    obs: 'orig.ident', 'nCount_RNA', 'nFeature_RNA', 'gem_id', 'patientID', 'age', 'sex', 'diagnostic', 'disease', 'disease_classification', 'disease_status', 'disease_grade', 'alternative_classification', 'comorbidity', 'sample_collection', 'scrublet_doublet_scores', 'scrublet_predicted_doublet', 'total_counts_from_X', 'n_genes_from_X', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', '