# Stage 01 — Prepare canonical items table (parquet-first)

This notebook creates the canonical `items.parquet` used by all downstream stages.

Outputs:
- `exports/stage_01_prepare/items.parquet`

Key fields:
- `item_id` (stable, deterministic)
- `source`, `split`, `label`
- `text` (optional; default template avoids label leakage)
- `image_path`, `width`, `height`, `mpp`

**Quality focus**
- No missing `image_path` or `item_id` values.
- Unique `item_id` across all concatenated datasets.
- Label and source/dataset distributions are plotted before any embedding is computed.


In [None]:
# --- Colab-first setup ---
import os, sys, time
from pathlib import Path

# Rebuild switch: the build cell reloads an existing items.parquet unless this is True.
FORCE_REBUILD = False
# NOTE(review): FAST_MODE and EDA_LEVEL are not read by any cell visible in this
# notebook — presumably consumed elsewhere; confirm before relying on them.
FAST_MODE = True
EDA_LEVEL = "core"

# NOTE(review): the plot cells in this notebook call save_and_display without
# consulting these two flags — confirm whether they are still honoured anywhere.
SHOW_PLOTS = True
SAVE_PLOTS = True

# Folder scanned recursively for pipeline_config.yaml when HISTO_PROJECT_ROOT is unset.
DRIVE_SEARCH_BASE = "/content/drive/MyDrive"

def _is_colab() -> bool:
    """Report whether the ``google.colab`` package is already loaded in this interpreter."""
    loaded_modules = sys.modules
    return "google.colab" in loaded_modules

# Mount Google Drive only when the Colab runtime is detected; skipped everywhere else.
if _is_colab():
    from google.colab import drive  # type: ignore
    # Mount point is the parent of the default DRIVE_SEARCH_BASE ("/content/drive/MyDrive").
    drive.mount("/content/drive")

def _resolve_project_root() -> Path:
    """Locate the project directory.

    Lookup order:
      1. The ``HISTO_PROJECT_ROOT`` environment variable, if set and the path exists.
      2. Under ``DRIVE_SEARCH_BASE``: the most recently modified directory that holds
         both ``pipeline_config.yaml`` and ``label_taxonomy.yaml``.
      3. The current working directory or one of its nearest ancestors (ten
         directories checked in total) that holds ``pipeline_config.yaml``.

    Raises:
        FileNotFoundError: if none of the three lookups finds a project root.
    """
    env_root = os.environ.get("HISTO_PROJECT_ROOT")
    if env_root and Path(env_root).exists():
        return Path(env_root)

    drive_root = Path(DRIVE_SEARCH_BASE)
    drive_hits = []
    if drive_root.exists():
        drive_hits = [
            cfg_file.parent
            for cfg_file in drive_root.glob("**/pipeline_config.yaml")
            if (cfg_file.parent / "label_taxonomy.yaml").exists()
        ]
    if drive_hits:
        # max() returns the first of equally recent directories, exactly like
        # taking element 0 of a stable descending sort.
        return max(drive_hits, key=lambda d: d.stat().st_mtime)

    here = Path.cwd()
    for folder in [here, *here.parents][:10]:
        if (folder / "pipeline_config.yaml").exists():
            return folder
    raise FileNotFoundError("Could not resolve PROJECT_ROOT. Set HISTO_PROJECT_ROOT env var.")

# Resolve the project once and put it first on sys.path so its packages can be imported.
PROJECT_ROOT = _resolve_project_root()
sys.path.insert(0, str(PROJECT_ROOT))
print("PROJECT_ROOT:", PROJECT_ROOT)

# Install deps
import subprocess
subprocess.check_call([sys.executable, "-m", "pip", "-q", "install", "-r", str(PROJECT_ROOT / "requirements.txt")])

# Imported only after the pip step above has run.
import yaml

_config_path = PROJECT_ROOT / "pipeline_config.yaml"
# safe_load returns None for an empty document; fall back to an empty mapping so the
# `.get` lookups in this and later cells keep working. Read as UTF-8 explicitly so the
# result does not depend on the platform's default encoding.
cfg = yaml.safe_load(_config_path.read_text(encoding="utf-8")) or {}
if not isinstance(cfg, dict):
    raise TypeError(
        f"{_config_path} must contain a YAML mapping at the top level, got {type(cfg).__name__}"
    )

# `or {}` also covers sections that are present in the YAML but empty (parsed as None),
# which `cfg.get(name, {})` would pass through and then crash on.
_paths_cfg = cfg.get("paths") or {}
_project_cfg = cfg.get("project") or {}

EXPORTS_DIR = PROJECT_ROOT / str(_paths_cfg.get("exports_dir", "exports"))
RAW_DIR = PROJECT_ROOT / str(_paths_cfg.get("raw_dir", "data/raw"))
STAGING_DIR = PROJECT_ROOT / str(_paths_cfg.get("staging_dir", "data/staging"))

EXPORTS_DIR.mkdir(parents=True, exist_ok=True)
RAW_DIR.mkdir(parents=True, exist_ok=True)
STAGING_DIR.mkdir(parents=True, exist_ok=True)

SAFE_MODE = bool(_project_cfg.get("safe_mode", True))
SEED = int(_project_cfg.get("seed", 1337))

print("SAFE_MODE:", SAFE_MODE)


In [None]:
# --- Stage paths + registries ---
from pathlib import Path
import pandas as pd

from histo_cartography.viz import ensure_dir, save_and_display, register_plot, display_image
from histo_cartography.artifact_registry import register_artifact, append_stage_manifest
from histo_cartography.critic import run_critic, write_critic_report, critic_result_table, critic_issues_table

# Every Stage 01 output lives under this directory.
stage_dir = EXPORTS_DIR / "stage_01_prepare"

# Create the three sub-directories in the same order: plots, qa, eda.
plots_dir, qa_dir, eda_dir = (
    ensure_dir(stage_dir / sub_name) for sub_name in ("plots", "qa", "eda")
)

# Canonical output table, plus the in-memory list that the plot cells append to.
items_path = stage_dir / "items.parquet"
viz_records = []

print(f"stage_dir: {stage_dir}")


## PEEP — Config sanity

In [None]:
# PEEP — show data config
# Read the `data` section once and echo the settings that drive this stage.
data_cfg = cfg.get("data", {})

# A non-empty `dataset_keys` entry wins; otherwise wrap the single `dataset_key` in a list.
dataset_keys = data_cfg.get("dataset_keys")
if not dataset_keys:
    dataset_keys = [data_cfg.get("dataset_key", "CRC_VAL_HE_7K")]
split = str(data_cfg.get("split", "val"))

# Safe mode reads `max_items_safe` (default 512); otherwise `max_items_full` (default None).
if SAFE_MODE:
    max_items = data_cfg.get("max_items_safe", 512)
else:
    max_items = data_cfg.get("max_items_full", None)

use_text_modality = bool(data_cfg.get("use_text_modality", False))
text_template_version = str(data_cfg.get("text_template_version", "v2_no_label"))

print(f"dataset_keys: {dataset_keys}")
print(f"split: {split}")
print(f"max_items: {max_items}")
print(f"use_text_modality: {use_text_modality}")
print(f"text_template_version: {text_template_version}")


## Stage logic — Build items.parquet (idempotent)

In [None]:
# --- Build items.parquet ---
from pathlib import Path
import pandas as pd

from histo_cartography import datasets
from histo_cartography.exports import save_parquet

# Wall-clock start; the elapsed time is printed below and reused by the registry cell.
t0 = time.time()

# `or {}` also covers sections that are present in the YAML but empty (parsed as None).
data_cfg = cfg.get("data") or {}
download_cfg = data_cfg.get("download") or {}

dataset_keys = data_cfg.get("dataset_keys") or [data_cfg.get("dataset_key", "CRC_VAL_HE_7K")]
if isinstance(dataset_keys, str):
    # A bare string would otherwise be iterated one character at a time by the loop below.
    dataset_keys = [dataset_keys]
split = str(data_cfg.get("split", "val"))

verify_md5 = bool(download_cfg.get("verify_md5", True))
allow_large = bool(download_cfg.get("allow_large", False))
max_items = data_cfg.get("max_items_safe", 512) if SAFE_MODE else data_cfg.get("max_items_full", None)
overwrite = bool(data_cfg.get("force_reextract", False))

# Text config (avoid label leakage by default)
use_text_modality = bool(data_cfg.get("use_text_modality", False))
text_template_version = str(data_cfg.get("text_template_version", "v2_no_label"))

if items_path.exists() and not FORCE_REBUILD:
    # Reuse the cached table; set FORCE_REBUILD = True to regenerate it.
    items = pd.read_parquet(items_path)
    print(f"✅ Loaded existing items.parquet: {items.shape}")
else:
    # Loop-invariant, so parse it once instead of once per dataset.
    mpp = float(data_cfg.get("mpp", 0.5))

    parts = []
    for dk in dataset_keys:
        # images_dir is returned by the helper but not used further in this cell.
        items_df, images_dir = datasets.prepare_dataset_to_staging(
            dk,
            raw_dir=RAW_DIR,
            staging_dir=STAGING_DIR,
            split=split,
            safe_mode=SAFE_MODE,
            max_items=max_items,
            seed=SEED,
            overwrite=overwrite,
            verify_md5=verify_md5,
            allow_large=allow_large,
            mpp=mpp,
            use_text_modality=use_text_modality,
            text_template_version=text_template_version,
        )
        parts.append(items_df)

    items = pd.concat(parts, ignore_index=True) if parts else pd.DataFrame()

    # Explicit raises rather than `assert`: assertions are removed when Python runs
    # with -O, and the uniqueness check below already reports failures via ValueError.
    if len(items) == 0:
        raise ValueError("items is empty. Check dataset_key and download settings.")
    if items["image_path"].isna().any():
        raise ValueError("items has missing image_path values.")
    if items["item_id"].isna().any():
        raise ValueError("items has missing item_id values.")

    # Uniqueness is checked on the concatenated table, i.e. across all datasets.
    dup = int(items["item_id"].duplicated().sum())
    if dup:
        raise ValueError(f"item_id not unique: duplicates={dup}")

    # Written only after every check above has passed.
    save_parquet(items, items_path)

runtime_sec = time.time() - t0
print("runtime_sec:", round(runtime_sec, 2))
items.head()


## CHECKPOINT — Items health gates

In [None]:
# CHECKPOINT: critic on items
from IPython.display import display

# Run the critic over the items table for the "checkpoint_items" gate.
crit_items = run_critic(
    df=items,
    stage="stage_01_prepare",
    gate="checkpoint_items",
    id_col="item_id",
    key_nonnull_cols=["item_id", "image_path"],
    required_cols=[
        "item_id",
        "source",
        "split",
        "label",
        "text",
        "image_path",
        "width",
        "height",
        "mpp",
    ],
    # Safe mode lowers the minimum row requirement from 100 to 10.
    min_rows=10 if SAFE_MODE else 100,
)

# Persist the report under qa/, then show the summary followed by up to 50 issues.
write_critic_report(crit_items, qa_dir / "critic_checkpoint_items.json")
_critic_summary = critic_result_table(crit_items)
display(_critic_summary)
_critic_issues = critic_issues_table(crit_items)
display(_critic_issues.head(50))


In [None]:
# Register artifact + stage manifest
schema_version = str(cfg.get("project", {}).get("schema_version", "0.1.0"))

# Run statistics passed unchanged to both the artifact registry and the stage manifest.
_run_stats = {
    "schema_version": schema_version,
    "warnings_count": int(crit_items.warnings_count),
    "fails_count": int(crit_items.fails_count),
    "runtime_sec": float(runtime_sec),
}

# Record the items table itself.
register_artifact(
    project_root=PROJECT_ROOT,
    stage="stage_01_prepare",
    artifact="items",
    path=items_path,
    inputs=[],
    df=items,
    notes="canonical items table",
    **_run_stats,
)

# Record the stage run as a whole; Stage 01 declares no inputs and one output.
append_stage_manifest(
    project_root=PROJECT_ROOT,
    stage="stage_01_prepare",
    inputs=[],
    outputs=[items_path],
    notes="stage 01 run summary",
    **_run_stats,
)


## POST — EDA (one plot per cell)

In [None]:
# POST plot 1 — Missingness
from histo_cartography.eda_reports import plot_missingness

# Build the missingness figure, save and show it, then record it in the viz index list.
out_path = plots_dir / "missingness_top25.png"
fig = plot_missingness(
    items,
    top_k=25,
    title="Stage 01: items missingness (top 25)",
)
save_and_display(fig, out_path)
register_plot(
    viz_records,
    stage="stage_01_prepare",
    plot_id="missingness_top25",
    title="Items missingness (top 25)",
    path=out_path,
    tags=["post", "missingness"],
    is_core=True,
)


In [None]:
# POST plot 2 — Source/dataset distribution
import matplotlib.pyplot as plt

# Item count per source; values are cast to str before counting.
vc = items["source"].astype(str).value_counts()

# Same bar chart, drawn through the object-oriented Matplotlib interface.
fig, ax = plt.subplots(figsize=(8, 4))
ax.bar(vc.index.astype(str), vc.values)
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
ax.set_title("Source/dataset distribution")
ax.set_ylabel("n items")
fig.tight_layout()

out_path = plots_dir / "source_distribution.png"
save_and_display(fig, out_path)
register_plot(
    viz_records,
    stage="stage_01_prepare",
    plot_id="source_distribution",
    title="Source/dataset distribution",
    path=out_path,
    tags=["post", "distribution", "dataset"],
    is_core=True,
)


In [None]:
# POST plot 3 — Label distribution (top 30)
import matplotlib.pyplot as plt

# Item count for the 30 most frequent labels; values are cast to str before counting.
vc = items["label"].astype(str).value_counts().head(30)

# Same bar chart, drawn through the object-oriented Matplotlib interface.
fig, ax = plt.subplots(figsize=(8, 4))
ax.bar(vc.index.astype(str), vc.values)
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
ax.set_title("Label distribution (top 30)")
ax.set_ylabel("n items")
fig.tight_layout()

out_path = plots_dir / "label_distribution_top30.png"
save_and_display(fig, out_path)
register_plot(
    viz_records,
    stage="stage_01_prepare",
    plot_id="label_distribution_top30",
    title="Label distribution (top 30)",
    path=out_path,
    tags=["post", "distribution", "label"],
    is_core=True,
)


In [None]:
# POST plot 4 — Image resolution scatter
import matplotlib.pyplot as plt

# One point per item: width on the x axis, height on the y axis (both cast to float).
fig, ax = plt.subplots(figsize=(5, 5))
ax.scatter(
    items["width"].astype(float),
    items["height"].astype(float),
    s=10,
    alpha=0.5,
)
ax.set_xlabel("width")
ax.set_ylabel("height")
ax.set_title("Image resolution scatter")
fig.tight_layout()

out_path = plots_dir / "image_resolution_scatter.png"
save_and_display(fig, out_path)
register_plot(
    viz_records,
    stage="stage_01_prepare",
    plot_id="image_resolution_scatter",
    title="Image resolution scatter",
    path=out_path,
    tags=["post", "images"],
    is_core=True,
)


In [None]:
# POST plot 5 — Text length histogram
import matplotlib.pyplot as plt

# Character count per item. `astype(str)` alone would turn a missing value into the
# literal "nan"/"None" and count it as 3 or 4 characters, so rows whose text is missing
# are reset to length 0 afterwards.
text_col = items["text"]
lens = text_col.astype(str).str.len().where(text_col.notna(), 0)

fig = plt.figure(figsize=(7, 4))
plt.hist(lens, bins=30, edgecolor="black")
plt.title("Text length histogram")
plt.xlabel("chars")
plt.ylabel("count")
plt.tight_layout()

out_path = plots_dir / "text_length_hist.png"
save_and_display(fig, out_path)
register_plot(viz_records, stage="stage_01_prepare", plot_id="text_length_hist", title="Text length histogram", path=out_path, tags=["post","text"], is_core=True)


In [None]:
# POST plot 6 — Quick montage sample across labels (glass-box)
from histo_cartography.image_viz import montage_sample

# Write a 49-image sample montage (seeded with SEED), show it, and record it in the viz index list.
out_path = plots_dir / "montage_sample_stage01.png"
montage_sample(
    items,
    out_path=out_path,
    image_col="image_path",
    n=49,
    random_state=SEED,
)
display_image(out_path)
register_plot(
    viz_records,
    stage="stage_01_prepare",
    plot_id="montage_sample_stage01",
    title="Sample montage (Stage 01)",
    path=out_path,
    tags=["post", "montage"],
    is_core=True,
)


In [None]:
# Write viz index (parquet + csv)
from IPython.display import display
from histo_cartography.viz import write_viz_index, viz_records_to_df

# Persist the collected plot records as parquet plus a CSV twin next to it.
viz_index_path = stage_dir / "viz_index.parquet"
write_viz_index(
    viz_records,
    out_parquet=viz_index_path,
    out_csv=viz_index_path.with_suffix(".csv"),
)

# Preview at most the first 60 records.
_viz_df = viz_records_to_df(viz_records)
display(_viz_df.head(60))
print(f"✅ wrote viz_index: {viz_index_path}")


## Next actions
- Run Stage 02 to compute embeddings.
- If you enabled `use_text_modality=True`, ensure `text_template_version` does not leak labels unless intended.