In [None]:
# Retry the merge with explicit suffix handling and pre-drop duplicate join columns from 'ind'.
import pandas as pd
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
# from caas_jupyter_tools import display_dataframe_to_user

# Reload from previous step outputs
base_out = Path("outputs/modeling_table.csv")
ind_path = Path("outputs/indicator_scores.csv")

modeling_table = pd.read_csv(base_out, parse_dates=["date"])
ind = pd.read_csv(ind_path)

# Pick the join key shared by both tables, preferring sample_id over subject_id.
join_keys = []
if "sample_id" in modeling_table.columns and "sample_id" in ind.columns:
    join_keys = ["sample_id"]
elif "subject_id" in modeling_table.columns and "subject_id" in ind.columns:
    join_keys = ["subject_id"]
else:
    # Fail early with a clear message: an empty key list cannot be merged on.
    raise ValueError(
        "Cannot attach indicators: neither 'sample_id' nor 'subject_id' is "
        f"present in both {base_out} and {ind_path}."
    )

# Keep only the join key and the indicator columns from 'ind'.
indicator_cols = [c for c in ind.columns if c in ["F_fiber","F_fermented","A_recent","A_any"]]
keep_cols = join_keys + indicator_cols
# Rows without a key cannot be attributed to any sample/subject, and pandas
# matches null keys on both sides to each other during a merge, so drop them.
# drop_duplicates keeps the first row per key so the m:1 validation below holds.
# NOTE(review): if 'ind' holds several differing rows per key, only the first
# one survives - confirm this is the intended choice.
ind_slim = (
    ind[keep_cols]
    .dropna(subset=join_keys)
    .drop_duplicates(subset=join_keys)
    .copy()
)

# Indicator columns that already exist in the modeling table would otherwise
# come out of the merge as '<name>_x' / '<name>_y', and the canonical '<name>'
# column would then be re-created empty by the ordering step below. Suffix only
# the incoming side and fold it back into the canonical column.
overlap_cols = [c for c in indicator_cols if c in modeling_table.columns]
merged = modeling_table.merge(
    ind_slim, on=join_keys, how="left", validate="m:1", suffixes=("", "_ind")
)
for c in overlap_cols:
    incoming = merged.pop(f"{c}_ind")
    # Values from the indicator file win; pre-existing values fill its gaps.
    merged[c] = incoming.combine_first(merged[c])

# Put the canonical columns first (creating any that are missing as all-NaN),
# followed by every remaining column in its existing order.
ordered_cols = ["subject_id","sample_id","date","t_days","butyrate","H_proxy","F_fiber","F_fermented","A_recent","A_any","butyrate_raw"]
for c in ordered_cols:
    if c not in merged.columns:
        merged[c] = np.nan
merged = merged[ordered_cols + [c for c in merged.columns if c not in ordered_cols]]

merged_out = Path("outputs/modeling_table_with_indicators.csv")
merged.to_csv(merged_out, index=False)
# display_dataframe_to_user("Modeling table (with indicators) preview", merged.head(100))

# Timeline figures, one per subject, for the six subjects with the most rows.
subj_counts = merged.groupby("subject_id").size().sort_values(ascending=False)
top_subjects = list(filter(pd.notna, subj_counts.index[:6]))

# (column, marker, linestyle) for each diet score drawn on the secondary axis.
diet_series = (
    ("F_fiber", "x", "-"),
    ("F_fermented", "^", "--"),
)

for sid in top_subjects:
    sub = merged.loc[merged["subject_id"] == sid].sort_values("date")
    if sub["date"].isna().all():
        # Without a single usable date there is nothing to put on a time axis.
        continue

    fig, ax1 = plt.subplots(figsize=(10, 4.2))

    # Primary axis: butyrate over time.
    ax1.plot(sub["date"], sub["butyrate"], marker="o", label="Butyrate")
    ax1.set_title(f"Subject {sid}: Butyrate vs Diet & Antibiotics")
    ax1.set_xlabel("Date")
    ax1.set_ylabel("Butyrate (instrument units)")

    # Secondary axis: diet scores sharing the same date axis.
    ax2 = ax1.twinx()
    for column, marker, linestyle in diet_series:
        if column in sub.columns:
            ax2.plot(sub["date"], sub[column], marker=marker, linestyle=linestyle, label=column)
    ax2.set_ylabel("Diet scores (0-1)")

    # Dashed vertical marker at every dated row whose A_recent value is positive
    # (missing A_recent values are treated as 0).
    if "A_recent" in sub.columns:
        recent_flags = sub["A_recent"].fillna(0)
        for when, flag in zip(sub["date"], recent_flags):
            if pd.isna(when) or not flag > 0:
                continue
            ax1.axvline(when, alpha=0.3, linestyle="--")

    ax1.legend(loc="upper left")
    ax2.legend(loc="upper right")
    plt.tight_layout()
    plt.show()

# Cell result: where the merged table was written and which subjects were plotted.
{"merged_modeling_table": str(merged_out), "subjects_plotted": top_subjects}
