# In [None]:
# coding: utf-8
"""
Vertex AI Express mode (API Key) + RoBERTa embedding 기반 hierarchical multi-label 분류
- Endpoint: https://aiplatform.googleapis.com/v1/publishers/google/models/{model}:generateContent?key=API_KEY
  (Express mode는 project/location 없이 global endpoint 사용)

조건:
  1) gemini-2.0-flash-lite-001 + class/class_related_keywords로 class description 생성 (약 54회 호출)
  2) FacebookAI/roberta-base 임베딩: class description & (train+test) 리뷰 텍스트
  3) leaf node only cosine similarity로 core class 결정. threshold 이하(애매)면 LLM batch(10개)로 판정
  4) 최대 3개의 hierarchical class: parent들과 cosine sim 비교해 greedy로 채택, sim 급감하면 중단
  5) test_corpus 결과를 submission.csv ["id","labels"]로 저장
"""

import os
import re
import csv
import json
import time
import math
import random
import requests
import numpy as np
from tqdm import tqdm

import torch
from transformers import AutoTokenizer, AutoModel


# =========================
# User Config
# =========================
# Vertex AI Express mode API key, read from the environment (VERTEX_API_KEY, falling
# back to GOOGLE_API_KEY). Never hard-code the key: anything committed to a notebook
# or repository is effectively public.
# NOTE(review): the key that was previously hard-coded on this line must be treated
# as leaked. Revoke and rotate it in the Google Cloud console.
API_KEY = os.environ.get("VERTEX_API_KEY") or os.environ.get("GOOGLE_API_KEY", "")
GEMINI_MODEL = "gemini-2.0-flash-lite-001"  # model mandated by the task requirements

ROOT_DIR = "Amazon_products"
CLASSES_PATH = os.path.join(ROOT_DIR, "classes.txt")
HIER_PATH = os.path.join(ROOT_DIR, "class_hierarchy.txt")
KW_PATH = os.path.join(ROOT_DIR, "class_related_keywords.txt")
TRAIN_CORPUS_PATH = os.path.join(ROOT_DIR, "train", "train_corpus.txt")
TEST_CORPUS_PATH = os.path.join(ROOT_DIR, "test", "test_corpus.txt")

# Cache directory for descriptions / embeddings; created at import time.
ARTIFACT_DIR = os.path.join(ROOT_DIR, "_artifacts")
os.makedirs(ARTIFACT_DIR, exist_ok=True)

CLASS_DESC_JSON = os.path.join(ARTIFACT_DIR, "class_descriptions.json")
CLASS_EMB_NPY = os.path.join(ARTIFACT_DIR, "class_embeddings.npy")
REVIEW_EMB_NPY = os.path.join(ARTIFACT_DIR, "review_embeddings.npy")
REVIEW_META_JSONL = os.path.join(ARTIFACT_DIR, "review_meta.jsonl")

SUBMISSION_PATH = "submission.csv"

# API usage guard (checked before every generateContent request)
API_CALL_LIMIT = 1000
API_CALL_COUNT = 0

# Classes per description request (531 classes -> about 54 calls)
DESC_CHUNK_SIZE = 10

# Embedding model
EMB_MODEL_NAME = "FacebookAI/roberta-base"
MAX_SEQ_LEN = 256
EMB_BATCH_SIZE = 64

# LLM fallback batching
TOPK_CANDIDATES_FOR_LLM = 5
LLM_BATCH_SIZE = 10

# Ambiguity-threshold calibration (uses both train and test reviews)
PERCENTILE_FOR_THRESHOLD = 15
THRESHOLD_FLOOR = 0.20

# Greedy parent selection
PARENT_MIN_SIM = 0.15
DROP_ABS_DELTA = 0.10      # absolute similarity drop that stops the climb
DROP_REL_RATIO = 0.25      # relative drop (a decrease of 25% or more counts as a sharp drop)


# =========================
# Reproducibility
# =========================
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs (plus every CUDA device when CUDA is available)."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed(42)


# =========================
# Helpers
# =========================
def is_int_str(s: str) -> bool:
    """Return True when *s*, ignoring surrounding whitespace, is one or more digit characters.

    Uses the regex class \\d, so non-ASCII decimal digits also count; signs,
    decimal points and empty strings do not.
    """
    stripped = s.strip()
    return re.fullmatch(r"\d+", stripped) is not None


def safe_json_load(path: str, default):
    """Parse the UTF-8 JSON file at *path*; return *default* when the file does not exist."""
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as fh:
            return json.load(fh)
    return default


def safe_json_save(obj, path: str):
    """Atomically write *obj* to *path* as pretty-printed UTF-8 JSON.

    The JSON is first written to a sibling temporary file and then moved over
    *path* with os.replace. An interruption or a serialization error mid-write
    therefore never leaves a truncated file behind; a half-written cache would
    make the next json.load fail. Non-ASCII text is kept as-is
    (ensure_ascii=False).

    Raises:
        Whatever json.dump / the file system raises. In that case the previous
        contents of *path* are left untouched.
    """
    tmp_path = f"{path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(obj, f, ensure_ascii=False, indent=2)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, path)
    finally:
        # The temp file only still exists if something failed before os.replace.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


def normalize_rows(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Divide each row of *x* by its L2 norm.

    Norms below *eps* are replaced by *eps*, so all-zero rows stay zero instead
    of becoming NaN.
    """
    row_norms = np.linalg.norm(x, axis=1, keepdims=True)
    safe_norms = np.maximum(row_norms, eps)
    return x / safe_norms


# =========================
# Load taxonomy files
# =========================
def load_classes(path: str):
    """
    Read the class list into {class_id: class_name}.

    Accepted line formats (checked in this order):
      - id<TAB>name
      - id name...        (first single space separates id and name)
      - name only         (the 0-based line number is used as the id)
    Blank lines are skipped, but they still advance the line number used by the
    name-only format.
    """
    def looks_like_id(token):
        return re.fullmatch(r"\d+", token.strip()) is not None

    id2name = {}
    with open(path, "r", encoding="utf-8") as fh:
        for line_no, raw in enumerate(fh):
            text = raw.strip()
            if not text:
                continue

            tab_split = re.split(r"\t+", text, maxsplit=1)
            space_split = text.split(" ", 1)
            if len(tab_split) == 2 and looks_like_id(tab_split[0]):
                key, label = int(tab_split[0]), tab_split[1].strip()
            elif len(space_split) == 2 and looks_like_id(space_split[0]):
                key, label = int(space_split[0]), space_split[1].strip()
            else:
                key, label = line_no, text

            id2name[key] = label
    return id2name


def load_keywords(path: str):
    """
    Read the class-keyword file into {class_id: [keyword, ...]}.

    Supported line formats:
      - id<TAB>kw1,kw2,...
      - id: kw1, kw2 ...
      - id<TAB>kw1 kw2 ...
    The id is separated from the keyword list by the FIRST run of ':' and/or tab
    characters only. Colons inside keywords (e.g. "3:1", "16:9") are therefore
    preserved rather than turned into separators. The keyword part is split on
    commas if it contains any, otherwise on whitespace; empty pieces are dropped.
    Lines whose part before the separator is not a digit string are skipped.
    """
    id2kws = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # Split once on the first ':'/tab run. Replacing every ':' would mangle
            # keywords that themselves contain ':'.
            parts = re.split(r"[:\t]+", line, maxsplit=1)
            if len(parts) != 2 or re.fullmatch(r"\d+", parts[0].strip()) is None:
                continue
            cid = int(parts[0])
            raw = parts[1].strip()
            if "," in raw:
                kws = [k.strip() for k in raw.split(",") if k.strip()]
            else:
                kws = [k.strip() for k in raw.split() if k.strip()]
            id2kws[cid] = kws
    return id2kws


def load_hierarchy(path: str):
    """
    Build the class hierarchy from class_hierarchy.txt.

    Each non-blank line is split on runs of tabs, commas, or spaces. The first two
    fields form an (a, b) pair of class ids; lines whose first two fields are not
    both digit strings are skipped. The file does not say whether a pair means
    parent->child or child->parent, so both orientations are built and the more
    plausible one is returned:
      1. an orientation without a cycle beats one with a cycle;
      2. if both orientations contain a cycle, (a = parent, b = child) is returned;
      3. otherwise the orientation with fewer roots wins (tie -> a = parent).

    Returns:
        (parent_of, children_of):
          parent_of   -- {child_id: parent_id}; only the FIRST parent seen for a
                         child is kept (single-parent assumption).
          children_of -- {parent_id: set of child_ids}; every pair is recorded.
    """
    with open(path, "r", encoding="utf-8") as f:
        lines = [ln.strip() for ln in f if ln.strip()]

    pairs = []
    for line in lines:
        parts = re.split(r"[\t, ]+", line)
        if len(parts) < 2:
            continue
        a, b = parts[0], parts[1]
        if is_int_str(a) and is_int_str(b):
            pairs.append((int(a), int(b)))

    def build(p2c: bool):
        # p2c=True reads each pair as (parent, child); p2c=False reads it as (child, parent).
        parent_of = {}
        children_of = {}
        for x, y in pairs:
            p, c = (x, y) if p2c else (y, x)
            children_of.setdefault(p, set()).add(c)
            parent_of.setdefault(c, p)  # single-parent assumption: keep the first parent seen
        return parent_of, children_of

    def has_cycle(parent_of):
        # DFS along child -> parent links. Meeting a node that is still on the current
        # path ("visiting") means the parent chain loops back on itself. Nodes in
        # "seen" were fully explored and are known not to lead into a cycle.
        seen, visiting = set(), set()

        def dfs(n):
            if n in visiting:
                return True
            if n in seen:
                return False
            visiting.add(n)
            p = parent_of.get(n)
            if p is not None and dfs(p):
                return True
            visiting.remove(n)
            seen.add(n)
            return False

        for n in list(parent_of.keys()):
            if dfs(n):
                return True
        return False

    def count_roots(parent_of, children_of):
        # A root is any node mentioned in either mapping that has no parent.
        nodes = set(parent_of.keys()) | set(children_of.keys())
        return sum(1 for n in nodes if n not in parent_of)

    p1, c1 = build(True)
    p2, c2 = build(False)
    cyc1, cyc2 = has_cycle(p1), has_cycle(p2)
    r1, r2 = count_roots(p1, c1), count_roots(p2, c2)

    # Prefer the acyclic orientation; if both are cyclic, fall back to (a = parent).
    if (not cyc1) and cyc2:
        return p1, c1
    if (not cyc2) and cyc1:
        return p2, c2
    if cyc1 and cyc2:
        return p1, c1
    # Both acyclic: fewer roots is the more tree-like reading (ties -> a = parent).
    return (p1, c1) if r1 <= r2 else (p2, c2)


def get_leaf_nodes(all_class_ids, children_of):
    """Return the ids from *all_class_ids* that have no (or an empty set of) children, sorted ascending."""
    leaves = [cid for cid in all_class_ids if len(children_of.get(cid, ())) == 0]
    return sorted(leaves)


# =========================
# Vertex AI Express mode (API key) non-stream generateContent
# =========================
def vertex_generate_content(prompt: str, temperature: float = 0.2, max_output_tokens: int = 2048) -> str:
    """
    Call Gemini generateContent (non-streaming) through Vertex AI Express mode.

    Endpoint:
      POST https://aiplatform.googleapis.com/v1/publishers/google/models/{GEMINI_MODEL}:generateContent
    The API key is sent in the ``x-goog-api-key`` header rather than the ``?key=``
    query parameter. This keeps it out of URLs echoed by requests' exception
    messages and logs.

    Transient failures are retried, up to 6 attempts with exponential backoff
    (2s, x1.7 each time; no sleep after the last attempt). Retried failures are
    HTTP 429/500/502/503/504, connection errors and timeouts. Only successful
    (HTTP 200) calls increment API_CALL_COUNT.

    Returns:
        The first candidate's concatenated text (see _extract_text_from_generatecontent).

    Raises:
        RuntimeError: if the call budget is already exhausted, on a non-retryable
            HTTP status, or when every attempt failed.
    """
    global API_CALL_COUNT
    if API_CALL_COUNT >= API_CALL_LIMIT:
        raise RuntimeError(f"API_CALL_LIMIT exceeded ({API_CALL_LIMIT}).")

    url = (
        "https://aiplatform.googleapis.com/v1/publishers/google/models/"
        f"{GEMINI_MODEL}:generateContent"
    )

    payload = {
        "contents": [{"role": "user", "parts": [{"text": prompt}]}],
        "generationConfig": {
            "temperature": temperature,
            "maxOutputTokens": max_output_tokens,
        },
    }

    headers = {"Content-Type": "application/json", "x-goog-api-key": API_KEY}
    retryable_status = (429, 500, 502, 503, 504)

    max_attempts = 6
    backoff = 2.0
    last_error = "no attempt made"
    for attempt in range(max_attempts):
        try:
            r = requests.post(url, headers=headers, json=payload, timeout=90)
        except requests.RequestException as exc:
            # Network-level failure (DNS, reset, timeout, ...): treat as transient.
            last_error = f"{type(exc).__name__}: {exc}"
        else:
            if r.status_code == 200:
                API_CALL_COUNT += 1
                data = r.json()
                return _extract_text_from_generatecontent(data)
            if r.status_code not in retryable_status:
                raise RuntimeError(f"Vertex generateContent error: {r.status_code} {r.text}")
            last_error = f"HTTP {r.status_code}"

        if attempt < max_attempts - 1:
            time.sleep(backoff)
            backoff *= 1.7

    raise RuntimeError(f"Vertex generateContent failed after retries. Last error: {last_error}")


def _extract_text_from_generatecontent(data: dict) -> str:
    """
    Non-stream response에서 candidates[0].content.parts[].text를 합침
    """
    try:
        cands = data.get("candidates", [])
        if not cands:
            return json.dumps(data, ensure_ascii=False)
        parts = cands[0].get("content", {}).get("parts", [])
        out = []
        for p in parts:
            t = p.get("text")
            if t:
                out.append(t)
        return "".join(out).strip()
    except Exception:
        return json.dumps(data, ensure_ascii=False)


# =========================
# 1) Class description generation (cached)
# =========================
def build_desc_prompt(chunk_items):
    """
    Build the prompt asking the LLM for short English descriptions of taxonomy classes.

    *chunk_items* are dicts with ``class_id`` and ``class_name`` and optional
    ``keywords`` (at most 25 are sent) and ``parent_name``. Each class is
    serialised as one JSON line after a fixed instruction header.
    """
    sample = {"class_id": 0, "description": "1-3 sentence English description of the category."}
    header = [
        "You generate short English descriptions for product taxonomy classes.",
        "Return STRICT JSON only. No markdown. No commentary.",
        "Output must be a JSON array; each item: {class_id:int, description:string}.",
        "Example: " + json.dumps(sample),
        "",
        "Classes:",
    ]
    class_lines = [
        json.dumps(
            {
                "class_id": item["class_id"],
                "class_name": item["class_name"],
                "related_keywords": item.get("keywords", [])[:25],
                "parent_name": item.get("parent_name", None),
            },
            ensure_ascii=False,
        )
        for item in chunk_items
    ]
    return "\n".join(header + class_lines)


def parse_desc_response(text: str):
    """
    Parse the description-generation answer into ``{class_id: description}``.

    The JSON array used is the first one that decodes cleanly, starting at some
    ``[`` in *text*. Anything after that array is ignored: closing code fences,
    trailing notes, stray brackets. A greedy ``[...]`` regex would instead span
    to the LAST ``]`` in the text and fail on such trailing content.

    Items are skipped, rather than failing the whole chunk, when they are not
    objects, lack a ``description``, or have a ``class_id`` that is missing or
    not int-convertible.

    Raises:
        ValueError: if no JSON array can be decoded from *text*.
    """
    decoder = json.JSONDecoder()
    arr = None
    start = text.find("[")
    while start != -1:
        try:
            value, _end = decoder.raw_decode(text, start)
        except ValueError:
            value = None
        if isinstance(value, list):
            arr = value
            break
        start = text.find("[", start + 1)
    if arr is None:
        raise ValueError(f"Description response is not JSON array. Raw:\n{text[:400]}")

    out = {}
    for obj in arr:
        if not isinstance(obj, dict) or obj.get("description") is None:
            continue
        try:
            cid = int(obj["class_id"])
        except (KeyError, TypeError, ValueError):
            continue
        out[cid] = str(obj["description"]).strip()
    return out


def generate_class_descriptions(id2name, id2kws, parent_of):
    """
    Return ``{class_id: {"name", "keywords", "description"}}``, backed by a JSON cache.

    Entries already stored in CLASS_DESC_JSON are reused. Every class in *id2name*
    without an entry is sent to the LLM in chunks of DESC_CHUNK_SIZE. The cache is
    rewritten after each chunk, so an interrupted run resumes where it stopped.

    Robustness:
      * Only descriptions for ids that were in the request chunk are accepted. An id
        the model invents or mis-copies would otherwise be cached for the wrong class.
      * Classes the model silently leaves out get one more request in a second pass.
        Any still missing afterwards are reported. They are not cached, so the next
        run asks for them again.

    Raises:
        RuntimeError / ValueError from vertex_generate_content / parse_desc_response.
        Progress up to the failing chunk is already saved.
    """
    existing = safe_json_load(CLASS_DESC_JSON, default={})
    existing_norm = {}
    for k, v in existing.items():
        try:
            existing_norm[int(k)] = v
        except (TypeError, ValueError):
            continue  # ignore cache keys that are not class ids

    def parent_name(cid):
        p = parent_of.get(cid)
        return id2name.get(p) if p is not None else None

    def missing_ids():
        return [cid for cid in sorted(id2name.keys()) if cid not in existing_norm]

    # Pass 0 requests every missing class; pass 1 re-requests only what the model omitted.
    for _pass in range(2):
        missing = missing_ids()
        if not missing:
            break

        chunks = [missing[i:i + DESC_CHUNK_SIZE] for i in range(0, len(missing), DESC_CHUNK_SIZE)]
        for chunk in tqdm(chunks, desc="Generating class descriptions (generateContent)"):
            requested = set(chunk)
            chunk_items = [
                {
                    "class_id": cid,
                    "class_name": id2name.get(cid, f"class_{cid}"),
                    "keywords": id2kws.get(cid, []),
                    "parent_name": parent_name(cid),
                }
                for cid in chunk
            ]

            prompt = build_desc_prompt(chunk_items)
            resp = vertex_generate_content(prompt, temperature=0.2, max_output_tokens=2048)
            parsed = parse_desc_response(resp)

            for cid, desc in parsed.items():
                if cid not in requested:
                    continue  # answer for a class we did not ask about in this chunk
                existing_norm[cid] = {
                    "name": id2name.get(cid, f"class_{cid}"),
                    "keywords": id2kws.get(cid, []),
                    "description": desc,
                }

            safe_json_save({str(k): v for k, v in existing_norm.items()}, CLASS_DESC_JSON)

    still_missing = missing_ids()
    if still_missing:
        print(f"[Warn] {len(still_missing)} classes still have no description "
              f"(e.g. {still_missing[:10]}); they will be requested again on the next run.")

    return existing_norm


# =========================
# 2) RoBERTa embeddings
# =========================
@torch.no_grad()
def mean_pool(last_hidden_state, attention_mask):
    """
    Average token vectors over the positions where *attention_mask* is non-zero.

    Token vectors are summed over dim 1 (presumably the sequence axis of a
    (batch, seq, hidden) tensor; TODO confirm). The divisor is clamped to at least
    1e-9, so a row with an all-zero mask yields zeros instead of NaN.
    """
    weights = attention_mask.unsqueeze(-1).type_as(last_hidden_state)
    token_count = weights.sum(dim=1).clamp(min=1e-9)
    weighted_sum = (last_hidden_state * weights).sum(dim=1)
    return weighted_sum / token_count


@torch.no_grad()
def embed_texts_roberta(texts, tokenizer, model, device, batch_size=64, max_len=256):
    """
    Encode *texts* into L2-normalised, mean-pooled transformer embeddings.

    The whole function runs under torch.no_grad(). Only inference is needed, and
    without it ``model(**enc)`` records an autograd graph for every batch, wasting
    (GPU) memory.

    Args:
        texts: list of strings to embed.
        tokenizer / model: a Hugging Face tokenizer and encoder model.
        device: device the tokenised batch is moved to (should match *model*).
        batch_size: number of texts per forward pass.
        max_len: truncation length in tokens.

    Returns:
        np.ndarray with one row per text, in input order. For an empty *texts* the
        result is a (0, hidden_size) float32 array instead of np.vstack's error.
    """
    embs = []
    for i in tqdm(range(0, len(texts), batch_size), desc="Embedding"):
        batch = texts[i:i + batch_size]
        enc = tokenizer(batch, padding=True, truncation=True, max_length=max_len, return_tensors="pt")
        enc = {k: v.to(device) for k, v in enc.items()}
        out = model(**enc)
        pooled = mean_pool(out.last_hidden_state, enc["attention_mask"])
        pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
        embs.append(pooled.cpu().numpy())

    if not embs:
        hidden = getattr(getattr(model, "config", None), "hidden_size", 0)
        return np.zeros((0, hidden), dtype=np.float32)
    return np.vstack(embs)


def load_corpus_any(path: str, split_name: str):
    """
    Read a tab-separated corpus file into ``[{"split", "pid", "text"}, ...]``.

    The first column is the id and the LAST column is the text, so both
    ``id<TAB>text`` and wider ``id<TAB>...<TAB>text`` rows work. Lines that are
    blank or have no tab are skipped. Fields are kept as strings.
    """
    records = []
    with open(path, "r", encoding="utf-8") as fh:
        for raw in fh:
            raw = raw.rstrip("\n")
            if not raw.strip():
                continue
            fields = raw.split("\t")
            if len(fields) < 2:
                continue
            records.append({"split": split_name, "pid": fields[0], "text": fields[-1]})
    return records


def build_or_load_review_store():
    """
    Return ``(train_rows, test_rows, all_rows)`` for the review corpora.

    Without REVIEW_META_JSONL, the train and test corpora are parsed and
    ``train + test`` is written there, one JSON object per line. With the file
    present, rows are read back from it in file order. Rows whose split is
    "train" go to train_rows and every other row goes to test_rows. The row
    dicts are shared between all_rows and the split lists.
    """
    if not os.path.exists(REVIEW_META_JSONL):
        train_rows = load_corpus_any(TRAIN_CORPUS_PATH, "train")
        test_rows = load_corpus_any(TEST_CORPUS_PATH, "test")
        all_rows = train_rows + test_rows
        with open(REVIEW_META_JSONL, "w", encoding="utf-8") as sink:
            sink.writelines(json.dumps(row, ensure_ascii=False) + "\n" for row in all_rows)
        return train_rows, test_rows, all_rows

    with open(REVIEW_META_JSONL, "r", encoding="utf-8") as src:
        all_rows = [json.loads(raw) for raw in src]
    train_rows = [row for row in all_rows if row.get("split") == "train"]
    test_rows = [row for row in all_rows if row.get("split") != "train"]
    return train_rows, test_rows, all_rows


def build_or_load_embeddings(class_desc_map, all_reviews):
    """
    Return ``(class_emb, review_emb)``, both row-L2-normalised, reusing .npy caches when valid.

    * Class rows are indexed by class id and there are ``max(class_desc_map) + 1``
      of them. Row ``cid`` embeds
      "Class: {name}. Keywords: {first 20 keywords}. Description: {description}",
      or just "class_{cid}" when *class_desc_map* has no entry for that id.
    * Review rows follow the order of *all_reviews* (their ``"text"`` field).

    A cache file is reused only if its row count matches what would be built now.
    Otherwise it is rebuilt, because a stale cache would silently misalign rows
    with ids. NOTE: changed description *text* with an unchanged row count is not
    detected.
    The RoBERTa tokenizer/model is loaded lazily, only when something actually has
    to be embedded.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    encoder = []  # [tokenizer, model] once loaded

    def get_encoder():
        # Load the tokenizer and model on first use only.
        if not encoder:
            tokenizer = AutoTokenizer.from_pretrained(EMB_MODEL_NAME)
            model = AutoModel.from_pretrained(EMB_MODEL_NAME).to(device)
            model.eval()
            encoder.extend([tokenizer, model])
        return encoder[0], encoder[1]

    def load_cached(path, expected_rows):
        # Return the cached array, or None if it is absent or has the wrong row count.
        if not os.path.exists(path):
            return None
        arr = np.load(path)
        if expected_rows is not None and arr.shape[0] != expected_rows:
            print(f"[Cache] {path} has {arr.shape[0]} rows, expected {expected_rows}; rebuilding.")
            return None
        return arr

    expected_class_rows = (max(class_desc_map.keys()) + 1) if class_desc_map else None
    class_emb = load_cached(CLASS_EMB_NPY, expected_class_rows)
    if class_emb is None:
        max_id = max(class_desc_map.keys())
        class_texts = []
        for cid in range(max_id + 1):
            info = class_desc_map.get(cid)
            if info is None:
                class_texts.append(f"class_{cid}")
            else:
                nm = info.get("name", f"class_{cid}")
                kws = ", ".join(info.get("keywords", [])[:20])
                desc = info.get("description", "")
                class_texts.append(f"Class: {nm}. Keywords: {kws}. Description: {desc}")

        tokenizer, model = get_encoder()
        class_emb = embed_texts_roberta(
            class_texts, tokenizer, model, device,
            batch_size=EMB_BATCH_SIZE, max_len=MAX_SEQ_LEN
        )
        class_emb = normalize_rows(class_emb)
        np.save(CLASS_EMB_NPY, class_emb)

    review_emb = load_cached(REVIEW_EMB_NPY, len(all_reviews))
    if review_emb is None:
        texts = [r["text"] for r in all_reviews]
        tokenizer, model = get_encoder()
        review_emb = embed_texts_roberta(
            texts, tokenizer, model, device,
            batch_size=EMB_BATCH_SIZE, max_len=MAX_SEQ_LEN
        )
        review_emb = normalize_rows(review_emb)
        np.save(REVIEW_EMB_NPY, review_emb)

    return normalize_rows(class_emb), normalize_rows(review_emb)


# =========================
# 3) Core leaf selection + LLM fallback
# =========================
def calibrate_threshold(leaf_emb: np.ndarray, review_emb_all: np.ndarray) -> float:
    """
    Derive the ambiguity threshold from the data.

    For every review, take its highest dot-product similarity to any leaf row
    (processed in blocks of 2048 reviews). Return the PERCENTILE_FOR_THRESHOLD-th
    percentile of those maxima, but never less than THRESHOLD_FLOOR.
    """
    block = 2048
    leaf_cols = leaf_emb.T
    per_review_best = np.concatenate(
        [
            (review_emb_all[start:start + block] @ leaf_cols).max(axis=1)
            for start in range(0, review_emb_all.shape[0], block)
        ],
        axis=0,
    )
    percentile_value = float(np.percentile(per_review_best, PERCENTILE_FOR_THRESHOLD))
    return max(percentile_value, THRESHOLD_FLOOR)


def build_llm_batch_prompt(batch_items, id2name, class_desc_map):
    """
    Build the prompt asking the LLM to pick one leaf class per review.

    Each item of *batch_items* has ``idx``, ``text`` and ``candidates`` (a list of
    ``(class_id, similarity)`` pairs). It becomes one JSON line listing the
    candidates with name, description and a 4-decimal similarity hint. Class names
    come from *id2name*, then the description map's ``name``, then "class_{id}".
    """
    def describe(cid, score):
        meta = class_desc_map.get(cid, {})
        return {
            "class_id": cid,
            "class_name": id2name.get(cid, meta.get("name", f"class_{cid}")),
            "description": meta.get("description", ""),
            "similarity_hint": round(float(score), 4),
        }

    prompt_lines = [
        "You classify each review into ONE best leaf class among provided candidates.",
        "Return STRICT JSON only. Output must be a JSON array of objects: {idx:int, chosen_class_id:int}.",
        "No markdown. No explanations.",
        "",
    ]
    for entry in batch_items:
        described = [describe(cid, score) for cid, score in entry["candidates"]]
        record = {"idx": entry["idx"], "review_text": entry["text"], "candidates": described}
        prompt_lines.append(json.dumps(record, ensure_ascii=False))
    return "\n".join(prompt_lines)


def parse_llm_choice_response(text: str):
    """
    Parse the LLM fallback answer into ``{review_idx: chosen_class_id}``.

    The JSON array used is the first one that decodes cleanly, starting at some
    ``[`` in *text*. Anything after it is ignored: closing code fences, trailing
    notes with brackets. A greedy ``[...]`` regex would instead span to the LAST
    ``]`` and fail on such trailing text.

    Items are skipped, so one bad entry does not discard the whole batch, when they
    are not objects or lack int-convertible ``idx`` / ``chosen_class_id``.

    Raises:
        ValueError: if no JSON array can be decoded from *text*.
    """
    decoder = json.JSONDecoder()
    arr = None
    start = text.find("[")
    while start != -1:
        try:
            value, _end = decoder.raw_decode(text, start)
        except ValueError:
            value = None
        if isinstance(value, list):
            arr = value
            break
        start = text.find("[", start + 1)
    if arr is None:
        raise ValueError(f"LLM choice response is not JSON array. Raw:\n{text[:400]}")

    out = {}
    for obj in arr:
        if not isinstance(obj, dict):
            continue
        try:
            out[int(obj["idx"])] = int(obj["chosen_class_id"])
        except (KeyError, TypeError, ValueError):
            continue
    return out


def infer_core_leaf_classes(leaf_ids, leaf_emb, review_emb_all, all_reviews, id2name, class_desc_map,
                            llm_splits=None):
    """
    Choose one core leaf class per review, asking the LLM about low-confidence reviews.

    1. Cosine top-1: each review row is dotted with every leaf row. This equals
       cosine similarity when both are L2-normalised, as main() ensures.
    2. Reviews whose best similarity is below calibrate_threshold(...) are ambiguous.
    3. Ambiguous reviews go to the LLM in batches of LLM_BATCH_SIZE, each with its
       top TOPK_CANDIDATES_FOR_LLM leaf candidates. Test-split reviews are sent
       first, so a limited call budget is spent where labels are actually written.
       An LLM answer is applied only if it names a review of that batch AND one of
       the candidates offered for it; otherwise cosine top-1 is kept. A batch that
       fails (API error, unparseable answer) also keeps cosine top-1 instead of
       aborting the run.

    Args:
        leaf_ids: class ids of the leaf rows in *leaf_emb* (same order).
        leaf_emb: leaf embedding matrix, one row per leaf.
        review_emb_all: review embedding matrix, one row per entry of *all_reviews*.
        all_reviews: review dicts with ``"text"`` and optionally ``"split"``.
        id2name, class_desc_map: used to describe candidates in the LLM prompt.
        llm_splits: optional iterable of split names. When given, only ambiguous
            reviews whose ``"split"`` is in it are sent to the LLM. ``None`` (the
            default) considers every split.

    Returns:
        dict mapping review row index -> core leaf class id.
    """
    leaf_arr = np.asarray(leaf_ids, dtype=np.int64)
    leaf_t = leaf_emb.T
    n = review_emb_all.shape[0]
    best_leaf = np.empty(n, dtype=np.int32)
    best_sim = np.empty(n, dtype=np.float32)

    bs = 2048
    for i in tqdm(range(0, n, bs), desc="Core leaf by cosine (top1)"):
        sims = review_emb_all[i:i + bs] @ leaf_t
        best_leaf[i:i + bs] = leaf_arr[sims.argmax(axis=1)]
        best_sim[i:i + bs] = sims.max(axis=1)

    thr = calibrate_threshold(leaf_emb, review_emb_all)
    print(f"[Threshold] ambiguity threshold={thr:.4f} (percentile={PERCENTILE_FOR_THRESHOLD}, floor={THRESHOLD_FLOOR})")

    ambiguous_idx = np.where(best_sim < thr)[0].tolist()
    print(f"[Ambiguous] {len(ambiguous_idx)} / {n} below threshold -> LLM fallback (batched {LLM_BATCH_SIZE}).")
    if llm_splits is not None:
        allowed_splits = set(llm_splits)
        ambiguous_idx = [i for i in ambiguous_idx if all_reviews[i].get("split") in allowed_splits]
        print(f"[Ambiguous] {len(ambiguous_idx)} kept for LLM after split filter.")
    # Test rows first (stable sort keeps index order inside each group). If the budget
    # runs out, the dropped reviews are then non-test ones; main() only writes test labels.
    ambiguous_idx.sort(key=lambda i: all_reviews[i].get("split") != "test")

    core_leaf_for_idx = {i: int(best_leaf[i]) for i in range(n)}

    if not ambiguous_idx:
        return core_leaf_for_idx

    remaining = API_CALL_LIMIT - API_CALL_COUNT
    needed_calls = math.ceil(len(ambiguous_idx) / LLM_BATCH_SIZE)
    if needed_calls > remaining:
        print(f"[Warn] LLM fallback needs {needed_calls} calls but only {remaining} remain. "
              f"Remaining ambiguous will keep cosine-top1 result.")
        ambiguous_idx = ambiguous_idx[:remaining * LLM_BATCH_SIZE]

    # argpartition needs k <= number of leaves.
    k = min(TOPK_CANDIDATES_FOR_LLM, len(leaf_ids))
    batches = [ambiguous_idx[i:i + LLM_BATCH_SIZE] for i in range(0, len(ambiguous_idx), LLM_BATCH_SIZE)]
    for b in tqdm(batches, desc="LLM fallback (generateContent)"):
        if API_CALL_COUNT >= API_CALL_LIMIT:
            print("[Warn] API call limit reached; remaining ambiguous reviews keep cosine-top1 result.")
            break

        batch_items = []
        offered = {}  # review idx -> set of candidate class ids shown to the LLM
        for idx in b:
            sims = review_emb_all[idx] @ leaf_t
            topk_pos = np.argpartition(-sims, k - 1)[:k]
            topk_pos = topk_pos[np.argsort(-sims[topk_pos])]
            cands = [(int(leaf_arr[p]), float(sims[p])) for p in topk_pos]
            offered[idx] = {cid for cid, _ in cands}
            batch_items.append({"idx": idx, "text": all_reviews[idx]["text"], "candidates": cands})

        prompt = build_llm_batch_prompt(batch_items, id2name, class_desc_map)
        try:
            resp = vertex_generate_content(prompt, temperature=0.0, max_output_tokens=2048)
            chosen_map = parse_llm_choice_response(resp)
        except (RuntimeError, ValueError, requests.RequestException) as exc:
            # Print only the exception type for network errors: their message may echo the request URL.
            detail = (type(exc).__name__ if isinstance(exc, requests.RequestException)
                      else f"{type(exc).__name__}: {str(exc)[:200]}")
            print(f"[Warn] LLM fallback batch failed ({detail}); keeping cosine-top1 for {len(b)} reviews.")
            continue

        for i2, cid in chosen_map.items():
            i2, cid = int(i2), int(cid)
            # Ignore answers for reviews outside this batch or classes that were not offered.
            if cid in offered.get(i2, ()):
                core_leaf_for_idx[i2] = cid

    return core_leaf_for_idx


# =========================
# 4) Greedy parent selection (max 3 labels)
# =========================
def select_hierarchical_labels(review_vec_norm, core_leaf_cid, parent_of, class_emb_all):
    """
    Return up to 3 labels: the core leaf plus greedily accepted ancestors.

    Starting at the core leaf, the parent chain is climbed one level at a time,
    using dot-product similarity (cosine for normalised inputs). A parent is
    accepted only if all of the following hold:
      * its similarity to the review is at least PARENT_MIN_SIM;
      * the similarity drop from the last accepted node is below DROP_ABS_DELTA;
      * the relative drop is below DROP_REL_RATIO.
    The climb stops at the first rejected parent, at a root, or once 3 labels
    are held.

    Returns:
        Sorted list of unique class ids (at most 3).
    """
    node = int(core_leaf_cid)
    picked = [node]
    node_sim = float(review_vec_norm @ class_emb_all[node])

    while len(picked) < 3:
        parent = parent_of.get(node)
        if parent is None:
            break
        parent = int(parent)
        parent_sim = float(review_vec_norm @ class_emb_all[parent])

        drop = node_sim - parent_sim
        if (
            parent_sim < PARENT_MIN_SIM
            or drop >= DROP_ABS_DELTA
            or drop / max(node_sim, 1e-6) >= DROP_REL_RATIO
        ):
            break

        picked.append(parent)
        node, node_sim = parent, parent_sim

    return sorted(set(picked))[:3]


# =========================
# 5) Submission writer
# =========================
def write_submission(test_rows, all_rows, review_emb_all, leaf_ids, leaf_emb, pid_to_labels, out_path):
    """
    Write the submission CSV with header ``["id", "labels"]``, one row per test review.

    ``labels`` is the comma-joined list from *pid_to_labels*. When a pid has no (or
    empty) labels, a single fallback label is written instead. The fallback is the
    leaf with the highest dot-product similarity for the FIRST test row of
    *all_rows* that has that pid. If no such row exists, the label field is empty.

    The pid -> row-index lookup is built once, lazily, on the first fallback. This
    avoids rescanning *all_rows* for every missing pid (O(N_test * N_all)).
    """
    test_row_index = None  # pid -> index of its first test row in all_rows

    with open(out_path, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["id", "labels"])
        for r in test_rows:
            pid = r["pid"]
            labels = pid_to_labels.get(pid, [])
            if not labels:
                # Fallback: guarantee at least one label.
                if test_row_index is None:
                    test_row_index = {}
                    for i, rr in enumerate(all_rows):
                        if rr.get("split") == "test":
                            test_row_index.setdefault(rr.get("pid"), i)
                idx = test_row_index.get(pid)
                if idx is not None:
                    sims = review_emb_all[idx] @ leaf_emb.T
                    labels = [int(leaf_ids[int(np.argmax(sims))])]
            w.writerow([pid, ",".join(map(str, labels))])


# =========================
# Main
# =========================
def main():
    """
    Run the whole pipeline and write SUBMISSION_PATH.

    1. Load the taxonomy (names, keywords, hierarchy) and find the leaf classes.
    2. Generate or load LLM class descriptions.
    3. Load train+test reviews and embed classes and reviews (cached).
    4. Pick a core leaf per review (cosine top-1 with LLM fallback).
    5. For test reviews, add greedily selected ancestors (at most 3 labels total).
    6. Write the submission CSV and report API usage.
    """
    id2name = load_classes(CLASSES_PATH)
    id2kws = load_keywords(KW_PATH)
    parent_of, children_of = load_hierarchy(HIER_PATH)

    class_ids = sorted(id2name.keys())
    leaf_ids = get_leaf_nodes(class_ids, children_of)
    print(f"[Taxonomy] classes={len(class_ids)}, leaf={len(leaf_ids)}")

    # Step 1: class descriptions (cached on disk)
    class_desc_map = generate_class_descriptions(id2name, id2kws, parent_of)
    print(f"[Descriptions] ready={len(class_desc_map)} | API used={API_CALL_COUNT}/{API_CALL_LIMIT}")

    # Both train and test reviews are loaded and embedded
    train_rows, test_rows, all_rows = build_or_load_review_store()
    print(f"[Corpus] train={len(train_rows)}, test={len(test_rows)}, all={len(all_rows)}")

    # Step 2: embeddings (cached), normalised once more defensively
    raw_class_emb, raw_review_emb = build_or_load_embeddings(class_desc_map, all_rows)
    class_emb_all = normalize_rows(raw_class_emb)
    review_emb_all = normalize_rows(raw_review_emb)
    leaf_emb = class_emb_all[leaf_ids]

    # Step 3: core leaf per review (+ LLM fallback for ambiguous ones)
    core_leaf_for_idx = infer_core_leaf_classes(
        leaf_ids, leaf_emb, review_emb_all, all_rows, id2name, class_desc_map
    )
    print(f"[Core leaf] done | API used={API_CALL_COUNT}/{API_CALL_LIMIT}")

    # Step 4: hierarchical labels, test split only
    pid_to_labels = {}
    for row_no, row in tqdm(list(enumerate(all_rows)), desc="Hierarchical labeling"):
        if row.get("split") == "test":
            pid_to_labels[row["pid"]] = select_hierarchical_labels(
                review_emb_all[row_no], core_leaf_for_idx[row_no], parent_of, class_emb_all
            )

    # Step 5: write the submission
    write_submission(test_rows, all_rows, review_emb_all, leaf_ids, leaf_emb, pid_to_labels, SUBMISSION_PATH)
    print(f"[Done] submission saved: {SUBMISSION_PATH}")
    print(f"[API usage] total calls: {API_CALL_COUNT}/{API_CALL_LIMIT}")


if __name__ == "__main__":
    main()
