In [None]:
"""balanced_sampling_150.py – create a *150-example* balanced dataset and split into train/val.

Steps:
1. Load source records (with EuroVoc descriptors).
2. Determine the smallest per-label cap that yields ≥ 150 distinct examples with
   the greedy balanced sampler (rare labels first).
3. If more than 150 are selected, randomly drop extras to reach 150 exactly.
4. Split roughly 2/3 : 1/3 into train/val (see TRAIN_RATIO) using a deficit-based
   multi-label stratified split.
5. Write JSONL files for both splits and plot the combined label distribution.
"""

from __future__ import annotations

import json
import random
from collections import Counter, defaultdict
from pathlib import Path
from typing import List


In [None]:

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Input file read by load_records() and the two JSONL files written by write_split().
DATA_PATH = Path("data/dev_with_descriptors.jsonl")
OUT_TRAIN = Path("data/build_hour_train.jsonl")
OUT_val = Path("data/build_hour_val.jsonl")

# Seed the module-level RNG once; the greedy sampler and the trimming/shuffling
# steps below all draw from it, so cell execution order affects the result.
SEED = 1
random.seed(SEED)

TARGET_TOTAL = 150  # total examples across train+val
TRAIN_RATIO = 0.67  # fraction sent to train; rest to val
LABEL_FIELD = "level_1"  # key looked up inside each record's "eurovoc_concepts" dict


In [None]:
# ---------------------------------------------------------------------------
# Utilities (re-used from previous script)
# ---------------------------------------------------------------------------

def load_records(path: Path) -> List[dict]:
    """Load a JSON Lines file into a list of parsed objects.

    Blank or whitespace-only lines (for example a trailing empty line at the
    end of the file) are skipped instead of being handed to ``json.loads``,
    which would raise on them.

    Args:
        path: Location of the UTF-8 encoded JSONL file.

    Returns:
        One parsed object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        OSError: If the file cannot be opened.
    """
    with path.open("r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


def gather_label_indices(records: List[dict], label_field: str) -> dict[str, List[int]]:
    """Build an inverted index from label to the positions of records carrying it.

    For every record, the labels are read from
    ``record["eurovoc_concepts"][label_field]``; records missing either key
    contribute nothing. Positions are appended in record order.

    Args:
        records: Parsed source records.
        label_field: Key inside each record's ``"eurovoc_concepts"`` mapping.

    Returns:
        A ``defaultdict(list)`` mapping each label to the record indices that
        list it.
    """
    index: dict[str, List[int]] = defaultdict(list)
    for position, record in enumerate(records):
        concepts = record.get("eurovoc_concepts", {})
        for label in concepts.get(label_field, []):
            index[label].append(position)
    return index


def greedy_balanced_sampling(
    records: List[dict],
    label_field: str,
    cap: int,
    already_selected: set[int] | None = None,
) -> set[int]:
    """Greedily pick records so that no label is used ``cap`` times or more before a pick.

    Labels are visited from rarest to most frequent (by number of records that
    carry them). For each label still below ``cap``, its not-yet-chosen records
    are shuffled with the module-level ``random`` RNG and accepted one by one,
    but only when *every* label on the record is still below ``cap``.

    Args:
        records: Parsed source records.
        label_field: Key inside each record's ``"eurovoc_concepts"`` mapping.
        cap: Per-label usage limit checked before each acceptance.
        already_selected: Indices chosen in earlier rounds; their labels count
            towards the cap. The set itself is not modified.

    Returns:
        Only the indices newly chosen in this call.
    """

    def labels_of(position: int) -> list:
        return records[position].get("eurovoc_concepts", {}).get(label_field, [])

    baseline = set() if already_selected is None else already_selected
    index = gather_label_indices(records, label_field)

    # Start the usage tally from whatever earlier rounds already picked.
    usage = Counter()
    for position in baseline:
        usage.update(labels_of(position))

    chosen: set[int] = set(baseline)
    rarest_first = sorted(index, key=lambda label: len(index[label]))

    for label in rarest_first:
        if usage[label] >= cap:
            continue
        pool = [position for position in index[label] if position not in chosen]
        random.shuffle(pool)
        for position in pool:
            item_labels = labels_of(position)
            # Reject the record if any of its labels has already hit the cap.
            if any(usage[other] >= cap for other in item_labels):
                continue
            chosen.add(position)
            usage.update(item_labels)
            if usage[label] >= cap:
                break
    return chosen - baseline


def write_split(
    path: Path,
    indices: List[int],
    records: List[dict],
    label_field: str = "level_1",
) -> None:
    """Write the selected records to ``path`` as JSON Lines.

    Each line is a JSON object (the original note says this matches the schema
    used in exploration.py):
        { "item": { "id": "<seq>", "text_input": "…", "reference_answer": [..] } }
    where <seq> is a zero-based index within the current split file,
    ``text_input`` is ``record["text"]["en"]`` (empty string if absent) and
    ``reference_answer`` is ``record["eurovoc_concepts"][label_field]`` (empty
    list if absent).

    Args:
        path: Output file; missing parent directories are created.
        indices: Positions into ``records`` to write, in output order.
        records: Parsed source records.
        label_field: Key inside ``"eurovoc_concepts"`` to export. Defaults to
            ``"level_1"``, the value that was previously hard-coded.

    Raises:
        IndexError: If an entry of ``indices`` is out of range for ``records``.
        OSError: If the file or its parent directory cannot be created.
    """
    # A fresh checkout may not have the output directory yet.
    path.parent.mkdir(parents=True, exist_ok=True)

    with path.open("w", encoding="utf-8") as f_out:
        for new_id, idx in enumerate(indices):
            rec = records[idx]
            labels = rec.get("eurovoc_concepts", {}).get(label_field, [])
            obj = {
                "item": {
                    "id": str(new_id),
                    "text_input": rec.get("text", {}).get("en", ""),
                    "reference_answer": labels,
                }
            }
            # ensure_ascii=False keeps non-ASCII text readable in the output file.
            f_out.write(json.dumps(obj, ensure_ascii=False) + "\n")
    print(f"Wrote {len(indices)} records → {path}")


import plotly.graph_objects as go

def plot_distribution(name: str, indices: List[int], records: List[dict], label_field: str) -> None:
    """Display a horizontal bar chart of label frequencies for the given records.

    Args:
        name: Prefix for the chart title.
        indices: Positions into ``records`` to include in the tally.
        records: Parsed source records.
        label_field: Key inside each record's ``"eurovoc_concepts"`` mapping.
    """
    # Tally every label occurrence across the chosen records.
    tally = Counter()
    for position in indices:
        concepts = records[position].get("eurovoc_concepts", {})
        tally.update(concepts.get(label_field, []))

    # Order the bars by frequency, lowest count first (ties keep tally order).
    ranked = sorted(tally.items(), key=lambda pair: pair[1])
    labels = [label for label, _ in ranked]
    frequencies = [freq for _, freq in ranked]

    # Single-colour bars on a dark theme.
    bars = go.Bar(
        x=frequencies,
        y=labels,
        orientation="h",
        marker={"color": "#0071cf"},
    )
    fig = go.Figure(bars)

    fig.update_layout(
        title=f"{name} – {label_field} frequencies",
        xaxis_title="Frequency",
        yaxis_title="",
        template="plotly_dark",
        height=700,
        width=900,
        yaxis={"automargin": True},
        font={"family": "Inter, sans-serif", "color": "#FFFFFF"},
        plot_bgcolor="#111111",
        paper_bgcolor="#111111",
    )

    fig.show()


In [None]:

# ---------------------------------------------------------------------------
# Interactive workflow
# ---------------------------------------------------------------------------

# 1. Load data
# `records` is the full source list; every index used later in this notebook
# (selection, split, output) is a position into it.
records = load_records(DATA_PATH)
# Label -> record positions; used here only to report the number of distinct labels.
per_label = gather_label_indices(records, LABEL_FIELD)
print(f"Loaded {len(records):,} records with {len(per_label)} unique {LABEL_FIELD} labels")


In [None]:
# 2. Grow the selection by raising the per-label cap one step at a time until
#    at least TARGET_TOTAL records are chosen.
selected: set[int] = set()
cap = 0

while True:
    cap += 1
    new = greedy_balanced_sampling(records, LABEL_FIELD, cap, already_selected=selected)
    selected |= new
    print(f"Cap={cap:<2}  => selected so far: {len(selected)}")
    if len(selected) >= TARGET_TOTAL:
        break
    if cap > 50:
        # Safety stop: give up raising the cap even though the target was not reached.
        break

# The last round may overshoot; keep a random subset of exactly TARGET_TOTAL.
if len(selected) > TARGET_TOTAL:
    surplus = len(selected) - TARGET_TOTAL
    selected = set(random.sample(list(selected), TARGET_TOTAL))
    print(f"Trimmed {surplus} excess examples to hit {TARGET_TOTAL}")

# Sort first so the shuffle starts from a fixed order, then randomise.
selected_list = list(selected)
selected_list.sort()
random.shuffle(selected_list)


In [None]:
# 3. Train / val split using custom quota-based stratification

def deficit_stratified_split(
    indices: list[int],
    train_ratio: float,
    *,
    source_records: list[dict] | None = None,
    label_field: str | None = None,
    seed: int | None = None,
) -> tuple[list[int], list[int]]:
    """Symmetric deficit-based multi-label stratification.

    Items are visited in a seeded random order. Each item goes to the split
    (train or val) whose per-label targets are currently further from being
    met for that item's labels; ties go to train. Once a split reaches its
    target size, every remaining item goes to the other split, so for
    ``0 <= train_ratio <= 1`` the sizes are exactly ``round(n * train_ratio)``
    and the remainder.

    Args:
        indices: Record positions to distribute.
        train_ratio: Fraction of items (and of each label's occurrences)
            targeted for train.
        source_records: Records the indices refer to. Defaults to the
            module-level ``records``.
        label_field: Key inside each record's ``"eurovoc_concepts"`` mapping.
            Defaults to the module-level ``LABEL_FIELD``.
        seed: Seed for the private shuffle RNG. Defaults to the module-level
            ``SEED``.

    Returns:
        ``(train, val)`` — two sorted lists of indices.
    """
    # Fall back to the notebook-level globals only when no override is given,
    # so existing two-argument calls behave exactly as before.
    recs = records if source_records is None else source_records
    field = LABEL_FIELD if label_field is None else label_field
    rng_local = random.Random(SEED if seed is None else seed)

    total = len(indices)
    train_target_size = round(total * train_ratio)
    val_target_size = total - train_target_size

    # Look up each item's labels once; they are needed for totals and deficits.
    labels_by_idx = {
        idx: recs[idx].get("eurovoc_concepts", {}).get(field, []) for idx in indices
    }

    # Per-label totals over the whole selection.
    total_counts: Counter[str] = Counter()
    for idx in indices:
        total_counts.update(labels_by_idx[idx])

    # Desired per-label counts in each split.
    target_train = {lbl: round(cnt * train_ratio) for lbl, cnt in total_counts.items()}
    target_val = {lbl: cnt - target_train[lbl] for lbl, cnt in total_counts.items()}

    train_counts: Counter[str] = Counter()
    val_counts: Counter[str] = Counter()
    train: list[int] = []
    val: list[int] = []

    shuffled = indices.copy()
    rng_local.shuffle(shuffled)

    for idx in shuffled:
        labels = labels_by_idx[idx]

        if len(train) >= train_target_size:
            # Train is full: everything left must go to val.
            choose_train = False
        elif len(val) >= val_target_size:
            # Val is full: everything left must go to train.
            choose_train = True
        else:
            # Summed remaining need for this item's labels in each split.
            train_def = sum(max(0, target_train[lbl] - train_counts[lbl]) for lbl in labels)
            val_def = sum(max(0, target_val[lbl] - val_counts[lbl]) for lbl in labels)
            # Ties go to train (train is known to have room in this branch).
            choose_train = train_def >= val_def

        if choose_train:
            train.append(idx)
            train_counts.update(labels)
        else:
            val.append(idx)
            val_counts.update(labels)

    return sorted(train), sorted(val)

# Split the shuffled selection; the function also reads the notebook-level
# `records`, `LABEL_FIELD` and `SEED`.
train_indices, val_indices = deficit_stratified_split(selected_list, TRAIN_RATIO)

# `cap` is the value left by the sampling loop in the earlier cell.
print(
    f"Final sizes – Train: {len(train_indices)}, val: {len(val_indices)} (cap used = {cap})"
)


In [None]:
# 4. Write splits
# Each call overwrites its target file; ids inside each file restart at 0.
write_split(OUT_TRAIN, train_indices, records)
write_split(OUT_val, val_indices, records)

In [None]:
# 5. Combined distribution (train + val)
# Only the full selection is plotted here; the individual splits are not charted separately.
plot_distribution("All", sorted(selected_list), records, LABEL_FIELD)