# In [None]:
# ============================================================
# TaxoClass-style HMTC pipeline (3.2 / 3.3 / 3.4 reflected)
# - 3.2.1: path score based top-down candidate selection
# - 3.2.2: conf vs parent/sibling + median saliency + multi-core set
# - 3.3: GNN class encoder + (log-)bilinear matching + pos/neg masked BCE
# - 3.4: multi-label self-training with hierarchy closure + masked loss
# ============================================================

import os
import re
import csv
import math
import random
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional, Set

import numpy as np
from tqdm import tqdm


# ============================================================
# CONFIG
# ============================================================
SEED = 42
# Seed Python's and NumPy's RNGs at import time; torch is seeded separately
# inside train_core_guided_classifier.
random.seed(SEED)
np.random.seed(SEED)

# Dataset layout (paths are relative to the working directory).
BASE_DIR = "Amazon_products"
CLASSES_PATH   = os.path.join(BASE_DIR, "classes.txt")
HIER_PATH      = os.path.join(BASE_DIR, "class_hierarchy.txt")
KEYWORDS_PATH  = os.path.join(BASE_DIR, "class_related_keywords.txt")
TRAIN_PATH     = os.path.join(BASE_DIR, "train", "train_corpus.txt")
TEST_PATH      = os.path.join(BASE_DIR, "test", "test_corpus.txt")
SUBMISSION_PATH = "submission.csv"

# Class ids are assumed to be 0..NUM_CLASSES-1 throughout the pipeline.
NUM_CLASSES = 531
# Bounds on the number of labels per document enforced by ensure_k_labels.
MIN_LABELS = 2
MAX_LABELS = 3

# -------- Stage-1: Bi-encoder (fast retrieval) ----------
BI_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
BI_BATCH = 256
BI_MAX_TOK_DOC = 256
BI_MAX_TOK_CLASS = 64

# -------- Stage-2: NLI (expensive similarity) ----------
NLI_MODEL_NAME = "textattack/roberta-base-MNLI"
NLI_BATCH = 32
NLI_MAX_LEN = 256
NLI_USE_FP16 = True
MAX_CHARS_DOC_FOR_NLI = 600  # documents are character-truncated before NLI scoring

# -------- 3.2 core mining params (reflect paper) ----------
TOPK_CHILDREN_BI = 50       # bi-encoder shortlist size for children before NLI (speed/quality trade-off)
MAX_LEVELS = 12             # max top-down expansion depth; adjust to the taxonomy depth
MAX_CORES_PER_DOC = 5       # upper bound on the multi-core set size (keeps it from growing too large)
MIN_SIM_ABS = 0.55          # classes whose raw sim is below this are excluded from core candidates (practical stabilizer)

# paper: at level l, for each node choose (l+2) children; and keep (l+1)^2 nodes.
# The implementation below follows this principle, but NLI is only run on the
# candidates already shortlisted by the bi-encoder.

# -------- 3.3 classifier ----------
DOC_MODEL_NAME = "bert-base-uncased"  # paper uses BERT-base
DOC_MAX_TOK = 256
CLASS_NAME_MAX_TOK = 32              # class surface name encoding
TRAIN_EPOCHS = 2
LR = 2e-5
BATCH_SIZE = 12

# Matching function
MATCH_USE_EXP = False  # If True, logits become exp(bilinear), i.e. sigmoid(exp(.)); not recommended: exp(.) > 0 so every probability is > 0.5

# -------- 3.4 self-training ----------
SELF_TRAIN_EPOCHS = 1
PSEUDO_POS_TAU = 0.90        # only labels with probability >= tau become pseudo-positives
PSEUDO_MIN_POS = 1           # fallback minimum number of pseudo-positives per document

# HF controls
HF_TOKEN = os.getenv("HF_TOKEN", None)
HF_CACHE_DIR = os.getenv("HF_HOME", None)
HF_LOCAL_ONLY = False        # set True when running without internet access


# ============================================================
# Utils: IO + taxonomy parsing
# ============================================================
def load_corpus(path: str):
    """Read a tab-separated corpus file.

    Each line holds a product id and its text, separated by the first tab
    character (later tabs stay inside the text). Lines that contain no tab
    at all are skipped.

    Returns:
        (pids, texts): two parallel lists of strings.
    """
    pids, texts = [], []
    with open(path, "r", encoding="utf-8") as handle:
        for raw in handle:
            pid, sep, text = raw.rstrip("\n").partition("\t")
            if not sep:
                # no tab -> malformed record, ignore it
                continue
            pids.append(pid)
            texts.append(text)
    return pids, texts

def load_classes(path: str):
    """Parse the class list into an ``{id: name}`` mapping.

    A line of the form ``<digits><tab or comma><name>`` maps that explicit id
    to the stripped name. Any other non-empty line is keyed by its 0-based
    line index (blank lines still count toward that index) and keeps the whole
    stripped line as its name. Afterwards every id in ``[0, NUM_CLASSES)``
    that is still missing gets the placeholder ``class_<id>``.
    """
    id2name = {}
    with open(path, "r", encoding="utf-8") as fh:
        for lineno, raw in enumerate(fh):
            entry = raw.strip()
            if not entry:
                continue
            head, *tail = re.split(r"[\t,]", entry, maxsplit=1)
            key = head.strip()
            if tail and key.isdigit():
                id2name[int(key)] = tail[0].strip()
            else:
                id2name[lineno] = entry
    # Guarantee that every expected class id has some name.
    for cid in range(NUM_CLASSES):
        id2name.setdefault(cid, f"class_{cid}")
    return id2name

def load_keywords_accumulate(path: str):
    """Collect de-duplicated keyword lists per class id.

    Each line is split on tabs (or on commas when it has no tab). The first
    field must be a class id in ``[0, NUM_CLASSES)``; the remaining fields are
    further split on ``, ; / |``. Purely numeric tokens are dropped. Keywords
    from repeated lines for the same class are appended in first-seen order
    without duplicates.

    Returns:
        dict mapping every class id in ``range(NUM_CLASSES)`` to a list of keywords.
    """
    kw = {cid: [] for cid in range(NUM_CLASSES)}
    separator = re.compile(r"[,;/|]")
    numeric = re.compile(r"[-+]?\d+(\.\d+)?")

    def split_tokens(fields):
        pieces = (piece.strip() for field in fields for piece in separator.split(field))
        return [piece for piece in pieces if piece and not numeric.fullmatch(piece)]

    with open(path, "r", encoding="utf-8") as fh:
        for raw in fh:
            line = raw.strip()
            if not line:
                continue
            fields = line.split("\t")
            if len(fields) == 1:
                fields = line.split(",")
            if len(fields) < 2:
                continue
            head = fields[0].strip()
            if not head.isdigit():
                continue
            cid = int(head)
            if not 0 <= cid < NUM_CLASSES:
                continue

            bucket = kw[cid]
            known = set(bucket)
            extra_fields = [field.strip() for field in fields[1:] if field.strip()]
            for token in split_tokens(extra_fields):
                if token in known:
                    continue
                bucket.append(token)
                known.add(token)
    return kw

def read_edges(path: str, id2name: Dict[int, str]):
    """Read directed ``(a, b)`` edges from the hierarchy file.

    Each line's first two tab/comma-separated fields are resolved to class
    ids: all-digit tokens are parsed as ids, anything else is looked up by
    exact class name. Self-loops, unresolvable endpoints and ids outside
    ``[0, NUM_CLASSES)`` are skipped. Edge direction is kept exactly as in the
    file; orientation is decided later by the caller.
    """
    name2id = {name: cid for cid, name in id2name.items()}

    def to_id(token):
        token = token.strip()
        return int(token) if token.isdigit() else name2id.get(token)

    edges = []
    with open(path, "r", encoding="utf-8") as fh:
        for raw in fh:
            line = raw.strip()
            if not line:
                continue
            fields = re.split(r"[\t,]", line)
            if len(fields) < 2:
                continue
            src, dst = to_id(fields[0]), to_id(fields[1])
            if src is None or dst is None or src == dst:
                continue
            if 0 <= src < NUM_CLASSES and 0 <= dst < NUM_CLASSES:
                edges.append((src, dst))
    return edges

def build_graph(edges, parent_to_child=True):
    """Turn an edge list into parent/child adjacency maps over all classes.

    Args:
        edges: iterable of ``(a, b)`` id pairs.
        parent_to_child: if True, ``a`` is the parent of ``b``; otherwise the
            pair is read as ``(child, parent)``.

    Returns:
        (roots, parents_of, children_of), where roots are the ids with no
        parent and both maps hold sorted id lists for every class id.
    """
    parents_of = {i: set() for i in range(NUM_CLASSES)}
    children_of = {i: set() for i in range(NUM_CLASSES)}
    for a, b in edges:
        parent, child = (a, b) if parent_to_child else (b, a)
        if parent == child:
            continue
        parents_of[child].add(parent)
        children_of[parent].add(child)

    roots = [i for i in range(NUM_CLASSES) if not parents_of[i]]
    return (
        roots,
        {k: sorted(v) for k, v in parents_of.items()},
        {k: sorted(v) for k, v in children_of.items()},
    )

def load_hierarchy_autodetect(path: str, id2name: Dict[int, str]):
    """Load the hierarchy and guess which direction the file's edges point.

    Both orientations are built. The inverted reading (file lists
    child->parent) is chosen when it yields at least one root but fewer roots
    than the as-is reading. Otherwise the file is taken as parent->child.

    Returns:
        (roots, parents_of, children_of, description_of_chosen_orientation)
    """
    edges = read_edges(path, id2name)
    as_is = build_graph(edges, parent_to_child=True)
    inverted = build_graph(edges, parent_to_child=False)
    n_roots_as_is = len(as_is[0])
    n_roots_inverted = len(inverted[0])
    if n_roots_inverted >= 1 and n_roots_inverted < n_roots_as_is:
        return (*inverted, "inverted(child->parent in file)")
    return (*as_is, "as-is(parent->child in file)")

def compute_depths(roots, children_of):
    """Compute each class's BFS depth from the nearest root.

    Roots get depth 0. Every class unreachable from any root is assigned
    ``max_reached_depth + 1`` (or 1 if nothing was reached).

    Returns:
        int32 numpy array of length NUM_CLASSES.
    """
    from collections import deque

    unreached = 10**9
    depth = np.full(NUM_CLASSES, unreached, dtype=np.int32)
    queue = deque(roots)
    for r in roots:
        depth[r] = 0

    while queue:
        node = queue.popleft()
        next_depth = depth[node] + 1
        for child in children_of.get(node, []):
            if depth[child] > next_depth:
                depth[child] = next_depth
                queue.append(child)

    reached = depth < unreached
    max_depth = int(np.max(depth[reached])) if np.any(reached) else 0
    depth[~reached] = max_depth + 1
    return depth

def build_siblings(parents_of: Dict[int, List[int]], children_of: Dict[int, List[int]]) -> Dict[int, List[int]]:
    """Map each class id to the sorted ids of the other children of any of its parents."""
    siblings: Dict[int, List[int]] = {}
    for cid in range(NUM_CLASSES):
        others = {
            sib
            for parent in parents_of.get(cid, [])
            for sib in children_of.get(parent, [])
            if sib != cid
        }
        siblings[cid] = sorted(others)
    return siblings

def get_ancestors_all(cid: int, parents_of: Dict[int, List[int]], max_steps: int = 50) -> List[int]:
    """Return every ancestor of ``cid`` in breadth-first (nearest-first) order.

    All parents are expanded, not just the first one, so DAG-shaped taxonomies
    are handled. ``max_steps`` caps the number of BFS levels as a safety net
    against very deep or cyclic parent maps.

    ``cid`` itself is never reported as its own ancestor, even if the parent
    map contains a cycle that leads back to it.

    Args:
        cid: class id whose ancestors are wanted.
        parents_of: mapping from class id to its parent ids.
        max_steps: maximum number of BFS levels to expand.

    Returns:
        List of ancestor ids, each listed once.
    """
    ancestors: List[int] = []
    # Seed with cid so that a cycle back to the start node is not listed.
    seen: Set[int] = {cid}
    frontier = [cid]
    for _ in range(max_steps):
        if not frontier:
            break
        next_frontier = []
        for node in frontier:
            for parent in parents_of.get(node, []):
                if parent in seen:
                    continue
                seen.add(parent)
                ancestors.append(parent)
                next_frontier.append(parent)
        frontier = next_frontier
    return ancestors

def ensure_k_labels(primary: List[int], parents_of) -> List[int]:
    """Normalize a label list to between MIN_LABELS and MAX_LABELS entries.

    Steps:
      1. Keep ``primary`` labels in order, skipping ``None`` and duplicates,
         stopping once MAX_LABELS are collected.
      2. While fewer than MIN_LABELS, pad with ancestors (up to 10 BFS levels)
         of the labels collected so far, including labels added in this step.
      3. If still short, pad with the lowest unused class ids.

    Returns:
        Sorted list of at most MAX_LABELS class ids.
    """
    chosen: List[int] = []
    seen: Set[int] = set()

    def _add(label):
        if label not in seen:
            chosen.append(label)
            seen.add(label)

    for label in primary:
        if label is None:
            continue
        _add(label)
        if len(chosen) >= MAX_LABELS:
            break

    cursor = 0
    while len(chosen) < MIN_LABELS and cursor < len(chosen):
        for ancestor in get_ancestors_all(chosen[cursor], parents_of, max_steps=10):
            _add(ancestor)
            if len(chosen) >= MIN_LABELS:
                break
        cursor += 1

    if len(chosen) < MIN_LABELS:
        for filler in range(NUM_CLASSES):
            _add(filler)
            if len(chosen) >= MIN_LABELS:
                break

    return sorted(chosen[:MAX_LABELS])

def build_class_texts(id2name, kw, parents_of):
    """Build one descriptive sentence per class for retrieval / entailment.

    Each text has the form ``Category path: A > B > C.``, optionally followed
    by ``Keywords: k1, k2, ... .`` with up to 25 keywords. The path lists the
    class's ancestors (up to 10 BFS levels) farthest-first, ending with the
    class itself.

    Returns:
        List of NUM_CLASSES strings indexed by class id.
    """

    def _describe(cid: int) -> str:
        own_name = id2name[cid]  # raises KeyError for an unknown class id
        ancestors = get_ancestors_all(cid, parents_of, max_steps=10)
        lineage = [id2name[a] for a in reversed(ancestors) if a in id2name] + [own_name]
        path_str = " > ".join(lineage)
        keywords = kw.get(cid, [])[:25]
        if not keywords:
            return f"Category path: {path_str}."
        return f"Category path: {path_str}. Keywords: " + ", ".join(keywords) + "."

    return [_describe(cid) for cid in range(NUM_CLASSES)]


# ============================================================
# Stage-1: Bi-encoder
# ============================================================
class BiEncoder:
    """Thin wrapper around a SentenceTransformer that returns normalized float32 embeddings."""

    def __init__(self, model_name: str, device: str, batch_size: int,
                 token: Optional[str] = None, cache_dir: Optional[str] = None, local_files_only: bool = False):
        try:
            from sentence_transformers import SentenceTransformer
        except ImportError as e:
            raise RuntimeError("pip install -q sentence-transformers") from e

        self.device = device
        self.batch_size = batch_size

        # Forward only the options that were actually supplied.
        optional = {"cache_folder": cache_dir, "token": token}
        load_kwargs = {key: value for key, value in optional.items() if value is not None}
        if local_files_only:
            load_kwargs["local_files_only"] = True

        self.model = SentenceTransformer(model_name, device=device, **load_kwargs)

    def encode(self, texts: List[str], batch_size: Optional[int] = None) -> np.ndarray:
        """Encode ``texts`` into L2-normalized embeddings (one row per text) as float32."""
        effective_batch = batch_size or self.batch_size
        vectors = self.model.encode(
            texts,
            batch_size=effective_batch,
            show_progress_bar=True,
            convert_to_numpy=True,
            normalize_embeddings=True,
        )
        return vectors.astype(np.float32)


# ============================================================
# Stage-2: NLI scorer (sim)
# ============================================================
class EntailmentScorer:
    """Score (premise, hypothesis) pairs with an NLI sequence-classification model.

    ``score`` returns, for each hypothesis, the softmax probability of the
    model's entailment label given the character-truncated premise. The
    pipeline uses it as sim(D, c).
    """

    def __init__(
        self,
        model_name: str,
        device: str = "cuda",
        use_fp16: bool = True,
        max_chars_doc: int = 600,
        token: Optional[str] = None,
        cache_dir: Optional[str] = None,
        local_files_only: bool = False,
        verbose: bool = True,
    ):
        """Load tokenizer and model, move the model to ``device``, and resolve the entailment label id.

        Raises:
            RuntimeError: if ``torch`` / ``transformers`` are not installed.
        """
        try:
            import torch
            from transformers import AutoTokenizer, AutoModelForSequenceClassification
        except ImportError as e:
            raise RuntimeError("pip install -q transformers torch") from e

        self.torch = torch
        self.device = device
        self.max_chars_doc = int(max_chars_doc)
        self.verbose = verbose

        # fp16 is only used on an available CUDA device.
        self.use_fp16 = bool(use_fp16 and device.startswith("cuda") and torch.cuda.is_available())

        common_kwargs = {
            "token": token,
            "cache_dir": cache_dir,
            "local_files_only": local_files_only,
        }

        if self.verbose:
            print(f"[EntailmentScorer] loading: {model_name}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, **common_kwargs)
        model_kwargs = dict(common_kwargs)
        if self.use_fp16:
            model_kwargs["torch_dtype"] = torch.float16
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, **model_kwargs)
        self.model.to(self.device)
        self.model.eval()

        self.entail_id = self._find_entail_id(self.model.config)

        if self.verbose:
            print(f"[EntailmentScorer] entail_id={self.entail_id} | fp16={self.use_fp16}")

    @staticmethod
    def _find_entail_id(cfg) -> int:
        """Return the label index whose name contains "entail", falling back to 2."""
        label2id = getattr(cfg, "label2id", None)
        if isinstance(label2id, dict):
            for name, idx in label2id.items():
                if "entail" in str(name).lower():
                    return int(idx)
        # NOTE(review): checkpoints with generic names (LABEL_0/1/2) land here.
        # Index 2 matches some MNLI checkpoints but not all; verify the label
        # order of the configured NLI model before trusting this fallback.
        return 2

    def _encode_batch(self, premise: str, hyps: list, max_len: int):
        """Tokenize ``premise`` paired with each hypothesis into padded PyTorch tensors."""
        return self.tokenizer(
            [premise] * len(hyps),
            hyps,
            truncation=True,
            padding=True,
            max_length=max_len,
            return_tensors="pt",
        )

    def _entail_probs(self, enc):
        """Run the model on an encoded batch and return entailment probabilities (tensor)."""
        torch = self.torch
        with torch.inference_mode():
            if self.use_fp16:
                # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    logits = self.model(**enc).logits
            else:
                logits = self.model(**enc).logits
            # Softmax in float32 (autocast would also upcast softmax to fp32).
            return torch.softmax(logits.float(), dim=-1)[:, self.entail_id]

    def score(self, premise: str, hypotheses: list, batch_size: int = 16, max_len: int = 256) -> np.ndarray:
        """Return entailment probabilities for ``premise`` against each hypothesis.

        On a CUDA device, a RuntimeError whose message mentions "out of memory"
        or "cuda" halves the batch size and retries the same slice, down to a
        batch size of 1. Other errors, or a failure at batch size 1, are re-raised.

        Returns:
            float32 array with one probability per hypothesis (empty if none).
        """
        torch = self.torch

        if not hypotheses:
            return np.zeros((0,), dtype=np.float32)

        premise = (premise or "")[: self.max_chars_doc]
        bs = int(max(1, batch_size))
        out = []
        i = 0
        while i < len(hypotheses):
            cur = hypotheses[i:i + bs]
            try:
                enc = self._encode_batch(premise, cur, max_len=max_len)
                enc = {k: v.to(self.device, non_blocking=True) for k, v in enc.items()}
                probs = self._entail_probs(enc)
                # Kept inside the try: asynchronous CUDA errors can surface at .cpu().
                out.append(probs.detach().float().cpu().numpy())
                i += bs
            except RuntimeError as e:
                msg = str(e).lower()
                if ("out of memory" in msg or "cuda" in msg) and bs > 1 and self.device.startswith("cuda"):
                    torch.cuda.empty_cache()
                    bs = max(1, bs // 2)
                    continue
                raise

        return np.concatenate(out, axis=0).astype(np.float32)


# ============================================================
# 3.2: Core class mining
# - 3.2.1 path score candidate selection (top-down)
# - 3.2.2 conf vs parent/sibling + median saliency + multi-core set
# ============================================================
@dataclass
class DocMiningResult:
    """Per-document output of core-class mining (Sec. 3.2).

    Instances are created in mine_cores_with_median_saliency with an empty
    ``core_ids``. That list is assigned in a second pass, once the class-wise
    median confidences over the whole corpus are known.
    """

    cand_ids: List[int]                 # C_i^{cand}: candidate class ids visited by the top-down search
    sim_cache: Dict[int, float]         # sim(D, c) (NLI entailment prob) for every class scored for this doc
    conf_map: Dict[int, float]          # conf(D, c) for candidates whose sim passed MIN_SIM_ABS (computed after scoring parent/sib)
    core_ids: List[int]                 # final multi-core set C_i (after median threshold; filled later)

def _hyp(cid: int, id2name: Dict[int, str]) -> str:
    return f"This document is about {id2name[cid]}."

def _topk_by_bi(doc_emb: np.ndarray, cand_ids: List[int], class_embs: np.ndarray, k: int) -> List[int]:
    if not cand_ids:
        return []
    c = np.array(cand_ids, dtype=np.int32)
    sim = class_embs[c] @ doc_emb
    k_eff = min(k, sim.shape[0])
    idx = np.argpartition(-sim, k_eff-1)[:k_eff]
    idx = idx[np.argsort(-sim[idx])]
    return c[idx].tolist()

def core_candidate_selection_path_score(
    doc_text: str,
    doc_emb: np.ndarray,
    class_embs: np.ndarray,
    nli: EntailmentScorer,
    roots: List[int],
    children_of: Dict[int, List[int]],
    id2name: Dict[int, str],
    max_levels: int = 12,
) -> Tuple[List[int], Dict[int, float]]:
    """
    3.2.1: top-down candidate selection driven by path score (ps).

      - ps(virtual root) = 1; the real ``roots`` are treated as the virtual
        root's children, so ps(root) = sim(D, root) for the 2 best roots.
      - For each node in the current frontier at level l:
          * shortlist up to TOPK_CHILDREN_BI children with the bi-encoder,
          * score the shortlist with NLI,
          * keep the (l+2) children with the highest sim.
        A child's path score is max over selecting parents of ps(parent) * sim(child).
      - Among all selected children, the top (l+1)^2 by path score form the next frontier.

    Returns:
      - cand_ids: sorted ids of every class selected at any level (virtual root excluded)
      - sim_cache: NLI-based sim(D, c) for every class that was scored, including
        shortlisted children that were not selected
    """
    sim_cache: Dict[int, float] = {}
    ps: Dict[int, float] = {}  # path score
    visited: Set[int] = set()

    # virtual root handling
    current = list(roots)
    # level 0: treat roots as children of root, choose 2 (since l=0 => (l+2)=2)
    # (this `l` is never read; it is rebound by the loop below)
    l = 0
    need = min(2, len(current))
    if need == 0:
        return [], sim_cache

    # score roots (bi shortlist then NLI)
    root_bi = _topk_by_bi(doc_emb, current, class_embs, k=min(TOPK_CHILDREN_BI, len(current)))
    hyps = [_hyp(c, id2name) for c in root_bi]
    scores = nli.score(doc_text[:MAX_CHARS_DOC_FOR_NLI], hyps, batch_size=NLI_BATCH, max_len=NLI_MAX_LEN)
    for c, s in zip(root_bi, scores.tolist()):
        sim_cache[c] = float(s)

    root_sorted = sorted(root_bi, key=lambda c: sim_cache[c], reverse=True)[:need]
    # ps for selected roots: ps(root)=1 so ps(c)=sim(D,c)
    for c in root_sorted:
        ps[c] = max(ps.get(c, 0.0), sim_cache[c])
        visited.add(c)

    frontier = root_sorted[:]  # selected at level 0

    # levels 0..max_levels-1 expand to level+1
    # NOTE(review): the virtual-root step above already used l=0, and the first
    # iteration here uses l=0 again (2 children per root, keep only 1 node).
    # If the schedule is meant to continue from the root step, this loop should
    # presumably start at l=1. Confirm against the paper before changing.
    for l in range(0, max_levels):
        if not frontier:
            break

        next_candidates: Set[int] = set()
        # For each node in frontier, choose (l+2) best children by sim
        for p in frontier:
            children = children_of.get(p, [])
            if not children:
                continue

            # bi shortlist
            bi_sel = _topk_by_bi(doc_emb, children, class_embs, k=min(TOPK_CHILDREN_BI, len(children)))
            if not bi_sel:
                continue

            # NLI score: one NLI call per frontier node over the whole shortlist,
            # which dominates the cost. A child shared by several parents in a
            # DAG is re-scored each time, overwriting the same cache entry.
            need_children = min(l + 2, len(bi_sel))
            hyps = [_hyp(c, id2name) for c in bi_sel]
            scores = nli.score(doc_text[:MAX_CHARS_DOC_FOR_NLI], hyps, batch_size=NLI_BATCH, max_len=NLI_MAX_LEN)
            for c, s in zip(bi_sel, scores.tolist()):
                sim_cache[c] = float(s)

            best_children = sorted(bi_sel, key=lambda c: sim_cache[c], reverse=True)[:need_children]

            # update ps(child) = max_{parent} ps(parent) * sim(child)
            for c in best_children:
                next_candidates.add(c)
                new_ps = ps.get(p, 0.0) * sim_cache.get(c, 0.0)
                if new_ps > ps.get(c, 0.0):
                    ps[c] = new_ps
                visited.add(c)

        if not next_candidates:
            break

        # Keep top (l+1)^2 by path score among next_candidates
        keep = min((l + 1) ** 2, len(next_candidates))
        frontier = sorted(list(next_candidates), key=lambda c: ps.get(c, 0.0), reverse=True)[:keep]

    cand_ids = sorted(list(visited))
    return cand_ids, sim_cache

def compute_conf_for_doc_candidates(
    doc_text: str,
    cand_ids: List[int],
    sim_cache: Dict[int, float],
    nli: EntailmentScorer,
    id2name: Dict[int, str],
    parents_of: Dict[int, List[int]],
    siblings_of: Dict[int, List[int]],
) -> Dict[int, float]:
    """Compute the confidence score of Sec. 3.2.2 for each candidate class.

      conf(D, c) = sim(D, c) - max_{c' in Par(c) ∪ Sib(c)} sim(D, c')

    A candidate whose own sim is below MIN_SIM_ABS is dropped, which dampens
    noise. Any sim missing from ``sim_cache`` is computed with NLI and written
    back into ``sim_cache``, so the cache is mutated in place.

    Parent/sibling sims are only requested for candidates that survive the
    MIN_SIM_ABS filter, so no NLI calls are spent on comparisons whose result
    would be discarded. Hypotheses are scored in a deterministic order, first
    occurrence first.

    Returns:
        ``{class_id: conf}`` for surviving candidates, in ``cand_ids`` order.
        Classes with no parents or siblings compare against 0.0.
    """

    def _score_missing(ids: List[int]) -> None:
        # De-duplicate while preserving order, then batch-score the missing ones.
        missing = [x for x in dict.fromkeys(ids) if x not in sim_cache]
        if not missing:
            return
        hyps = [_hyp(x, id2name) for x in missing]
        scores = nli.score(doc_text[:MAX_CHARS_DOC_FOR_NLI], hyps, batch_size=NLI_BATCH, max_len=NLI_MAX_LEN)
        for x, s in zip(missing, scores.tolist()):
            sim_cache[x] = float(s)

    # 1) every candidate needs its own sim before it can be filtered
    _score_missing(list(cand_ids))

    # 2) drop candidates whose absolute sim is too low
    kept = [c for c in cand_ids if sim_cache.get(c, 0.0) >= MIN_SIM_ABS]

    # 3) score the comparison set (parents + siblings) of the survivors only
    comps_of: Dict[int, List[int]] = {
        c: list(parents_of.get(c, [])) + list(siblings_of.get(c, [])) for c in kept
    }
    _score_missing([x for comps in comps_of.values() for x in comps])

    # 4) conf = own sim minus the best competing sim
    conf_map: Dict[int, float] = {}
    for c in kept:
        max_comp = max((sim_cache.get(x, 0.0) for x in comps_of[c]), default=0.0)
        conf_map[c] = float(sim_cache.get(c, 0.0) - max_comp)
    return conf_map

def mine_cores_with_median_saliency(
    texts: List[str],
    doc_embs: np.ndarray,
    class_embs: np.ndarray,
    nli: EntailmentScorer,
    roots: List[int],
    parents_of: Dict[int, List[int]],
    children_of: Dict[int, List[int]],
    siblings_of: Dict[int, List[int]],
    id2name: Dict[int, str],
) -> Tuple[List[DocMiningResult], np.ndarray]:
    """Mine a multi-core class set for every document in the corpus (Sec. 3.2).

    Steps:
      1) 3.2.1: top-down path-score search gives candidate ids and a sim cache per doc.
      2) 3.2.2: conf(D, c) is computed for each candidate.
      3) The median conf of each class is taken over all documents.
      4) A candidate becomes a core class when conf > 0 and conf >= its class
         median. At most MAX_CORES_PER_DOC cores are kept per doc, highest conf first.

    Returns:
        results: one DocMiningResult per document, with ``core_ids`` filled in.
        salient_mask: boolean array, True where the document got at least one core.
    """
    per_class_confs: List[List[float]] = [[] for _ in range(NUM_CLASSES)]
    results: List[DocMiningResult] = []

    print("[STEP2] 3.2 core mining: path score candidates -> conf -> class-wise median -> multi-core")
    for doc_text, doc_emb in tqdm(list(zip(texts, doc_embs)), desc="3.2 mining"):
        cand_ids, sim_cache = core_candidate_selection_path_score(
            doc_text=doc_text,
            doc_emb=doc_emb,
            class_embs=class_embs,
            nli=nli,
            roots=roots,
            children_of=children_of,
            id2name=id2name,
            max_levels=MAX_LEVELS,
        )
        conf_map = compute_conf_for_doc_candidates(
            doc_text=doc_text,
            cand_ids=cand_ids,
            sim_cache=sim_cache,
            nli=nli,
            id2name=id2name,
            parents_of=parents_of,
            siblings_of=siblings_of,
        )
        for cid, value in conf_map.items():
            if 0 <= cid < NUM_CLASSES:
                per_class_confs[cid].append(float(value))
        results.append(DocMiningResult(cand_ids=cand_ids, sim_cache=sim_cache, conf_map=conf_map, core_ids=[]))

    # Class-wise median conf; classes never seen keep +inf and so can never be cores.
    class_median = np.full(NUM_CLASSES, np.inf, dtype=np.float32)
    for cid, confs in enumerate(per_class_confs):
        if confs:
            class_median[cid] = float(np.median(np.array(confs, dtype=np.float32)))

    def _select_cores(conf_map: Dict[int, float]) -> List[int]:
        picked = [
            cid for cid, conf in conf_map.items()
            if conf > 0.0 and np.isfinite(class_median[cid]) and conf >= class_median[cid]
        ]
        picked.sort(key=lambda cid: conf_map.get(cid, -1e9), reverse=True)
        return picked[:MAX_CORES_PER_DOC]

    salient = np.zeros(len(results), dtype=bool)
    for pos, res in enumerate(results):
        res.core_ids = _select_cores(res.conf_map)
        salient[pos] = bool(res.core_ids)

    print(f"[STEP2] salient docs (non-empty multi-core): {int(salient.sum())} / {len(salient)}")
    return results, salient


# ============================================================
# 3.3: Core-guided classifier training
# - GNN class encoder
# - bilinear matching (matrix B)
# - pos/neg set design + masked BCE
# ============================================================
def build_pos_neg_sets_for_doc(
    core_ids: List[int],
    parents_of: Dict[int, List[int]],
    children_of: Dict[int, List[int]],
) -> Tuple[List[int], List[int]]:
    r"""Build the positive and negative class sets for one document.

    Paper (Eq. 7):
      C_pos = core ∪ parents(core)
      C_neg = C \ C_pos \ children(core)

    This implementation closes C_pos over *all* ancestors (up to 20 BFS levels),
    not only direct parents. Direct children of core classes are left out of
    both sets, so they are neither positive nor negative.

    Returns:
        (sorted positive ids, ascending negative ids)
    """
    pos: Set[int] = set()
    chd: Set[int] = set()
    for c in core_ids:
        pos.add(c)
        pos.update(get_ancestors_all(c, parents_of, max_steps=20))
        chd.update(children_of.get(c, []))

    excluded = pos | chd
    neg = [i for i in range(NUM_CLASSES) if i not in excluded]
    return sorted(pos), neg

def train_core_guided_classifier(
    train_texts: List[str],
    mining_results: List[DocMiningResult],
    salient_mask: np.ndarray,
    id2name: Dict[int, str],
    parents_of: Dict[int, List[int]],
    children_of: Dict[int, List[int]],
    device: str,
    epochs: int,
    lr: float,
    batch_size: int,
):
    """
    3.3 implementation: train the core-guided document/class matcher.

      - doc encoder: CLS vector of DOC_MODEL_NAME (BERT-base in the default config)
      - class encoder: 2-layer GNN over the taxonomy, initialized from CLS
        embeddings of the class surface names (kept as a fixed buffer)
      - matcher: bilinear with matrix B (+ per-class bias)
      - loss: BCE masked to each doc's pos/neg sets only
      - docs whose core set is empty are skipped (salient_mask)

    Args:
        train_texts: raw training documents, indexed like mining_results.
        mining_results: per-doc mining output; only ``core_ids`` is read here.
        salient_mask: boolean mask selecting the docs used for training.
        id2name: class id -> surface name.
        parents_of / children_of: taxonomy adjacency lists.
        device: torch device string.
        epochs, lr, batch_size: optimization settings (AdamW, grad-norm clip 1.0).

    Returns:
        (model, tokenizer), with the model left in eval mode.

    Raises:
        RuntimeError: if torch/transformers are missing or no salient docs exist.
    """
    try:
        import torch
        import torch.nn as nn
        from torch.utils.data import Dataset, DataLoader
        from transformers import AutoTokenizer, AutoModel
    except ImportError as e:
        raise RuntimeError("pip install -q transformers torch") from e

    torch.manual_seed(SEED)

    # Build training examples (only salient docs)
    idx = np.where(salient_mask)[0].tolist()
    if not idx:
        raise RuntimeError("No salient docs found. Consider relaxing thresholds (MIN_SIM_ABS / MAX_LEVELS / PSEUDO_POS_TAU).")

    ex_texts = [train_texts[i] for i in idx]
    ex_cores = [mining_results[i].core_ids for i in idx]

    # Build pos/neg lists for each example
    ex_pos_neg = []
    for core_ids in ex_cores:
        pos_ids, neg_ids = build_pos_neg_sets_for_doc(core_ids, parents_of, children_of)
        ex_pos_neg.append((pos_ids, neg_ids))

    tokenizer = AutoTokenizer.from_pretrained(DOC_MODEL_NAME, token=HF_TOKEN, cache_dir=HF_CACHE_DIR, local_files_only=HF_LOCAL_ONLY)
    doc_encoder = AutoModel.from_pretrained(DOC_MODEL_NAME, token=HF_TOKEN, cache_dir=HF_CACHE_DIR, local_files_only=HF_LOCAL_ONLY).to(device)

    # ---- class initial features: encode surface names with same transformer (once) ----
    # Computed before fine-tuning and without gradients, so these features stay
    # fixed even though doc_encoder is trained afterwards.
    class_names = [id2name[i] for i in range(NUM_CLASSES)]
    with torch.no_grad():
        doc_encoder.eval()
        feats = []
        for i in range(0, NUM_CLASSES, 64):
            batch = class_names[i:i+64]
            enc = tokenizer(batch, truncation=True, padding=True, max_length=CLASS_NAME_MAX_TOK, return_tensors="pt").to(device)
            out = doc_encoder(**enc).last_hidden_state[:, 0, :]  # CLS
            feats.append(out.detach())
        class_init = torch.cat(feats, dim=0)  # (C, d)
        d_model = class_init.size(-1)

    # ---- build sparse adjacency (undirected + self-loop), row-normalized mean ----
    def build_sparse_row_mean(edges: List[Tuple[int, int]], n: int, device: str):
        # edges list is directed; for undirected, add both ways outside
        rows = []
        cols = []
        for a, b in edges:
            rows.append(a); cols.append(b)
        # self loops
        for i in range(n):
            rows.append(i); cols.append(i)
        idx = torch.tensor([rows, cols], dtype=torch.long, device=device)
        vals = torch.ones(idx.size(1), dtype=torch.float32, device=device)
        # coalesce() sums duplicate (row, col) entries, so repeated edges get weight > 1
        A = torch.sparse_coo_tensor(idx, vals, size=(n, n)).coalesce()
        deg = torch.sparse.sum(A, dim=1).to_dense().clamp(min=1.0)
        # row-mean: D^{-1} A
        inv_deg = 1.0 / deg
        # apply row scaling by multiplying values with inv_deg[row]
        r = A.indices()[0]
        v = A.values() * inv_deg[r]
        A_mean = torch.sparse_coo_tensor(A.indices(), v, size=A.size(), device=device).coalesce()
        return A_mean

    # taxonomy edges: parent<->child
    undirected_edges = []
    for p, chs in children_of.items():
        for c in chs:
            undirected_edges.append((p, c))
            undirected_edges.append((c, p))

    # ego pooling matrix E: each row i averages over {i} ∪ parents(i) ∪ children(i)
    def build_sparse_ego_pool(parents_of, children_of, n: int, device: str):
        rows, cols = [], []
        for i in range(n):
            ego = set([i])
            ego.update(parents_of.get(i, []))
            ego.update(children_of.get(i, []))
            ego = sorted(list(ego))
            for j in ego:
                rows.append(i); cols.append(j)
        idx = torch.tensor([rows, cols], dtype=torch.long, device=device)
        vals = torch.ones(idx.size(1), dtype=torch.float32, device=device)
        E = torch.sparse_coo_tensor(idx, vals, size=(n, n)).coalesce()
        deg = torch.sparse.sum(E, dim=1).to_dense().clamp(min=1.0)
        inv_deg = 1.0 / deg
        r = E.indices()[0]
        v = E.values() * inv_deg[r]
        E_mean = torch.sparse_coo_tensor(E.indices(), v, size=E.size(), device=device).coalesce()
        return E_mean

    A_mean = build_sparse_row_mean(undirected_edges, NUM_CLASSES, device)
    E_mean = build_sparse_ego_pool(parents_of, children_of, NUM_CLASSES, device)

    class ClassGNNEncoder(nn.Module):
        # Produces one vector per class from the fixed name features via two
        # mean-aggregation layers followed by ego-network pooling.
        def __init__(self, init_feats: torch.Tensor, A_mean, E_mean, hidden_dim: int):
            super().__init__()
            self.register_buffer("h0", init_feats)  # fixed init
            # Plain attributes (not buffers): .to() will not move them, but they
            # were already built on `device` above.
            self.A = A_mean
            self.E = E_mean
            self.fc1 = nn.Linear(hidden_dim, hidden_dim, bias=True)
            self.fc2 = nn.Linear(hidden_dim, hidden_dim, bias=True)
            self.act = nn.ReLU()

        def forward(self):
            # simple 2-layer mean-GCN style
            # each layer: ReLU(W (h + A_mean h)), i.e. self features plus neighbor mean
            h = self.h0
            h = self.act(self.fc1(h + torch.sparse.mm(self.A, h)))
            h = self.act(self.fc2(h + torch.sparse.mm(self.A, h)))
            # ego pooling (paper's ego network representation)
            c = torch.sparse.mm(self.E, h)
            return c  # (C, d)

    class CoreGuidedMatcher(nn.Module):
        # logits[b, c] = doc_b^T B class_c + bias_c; B starts as the identity.
        def __init__(self, doc_encoder, class_encoder, d_model: int):
            super().__init__()
            self.doc_encoder = doc_encoder
            self.class_encoder = class_encoder
            self.B = nn.Parameter(torch.eye(d_model))         # bilinear matrix
            self.bias = nn.Parameter(torch.zeros(NUM_CLASSES))
            self.use_exp = MATCH_USE_EXP

        def forward(self, input_ids, attention_mask):
            doc = self.doc_encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]  # (B, d)
            # The class encoder is re-run on every forward so its weights get gradients.
            cls = self.class_encoder()  # (C, d)
            # bilinear: (doc @ B) @ cls^T
            proj = doc @ self.B
            logits = proj @ cls.t() + self.bias
            if self.use_exp:
                logits = torch.exp(logits)  # optional: mimic σ(exp(.)) formulation
            return logits

    class PseudoDataset(Dataset):
        # Yields (text, positive ids, negative ids) per training example.
        def __init__(self, texts: List[str], pos_neg: List[Tuple[List[int], List[int]]]):
            self.texts = texts
            self.pos_neg = pos_neg

        def __len__(self):
            return len(self.texts)

        def __getitem__(self, idx):
            t = self.texts[idx]
            pos_ids, neg_ids = self.pos_neg[idx]
            return t, np.array(pos_ids, dtype=np.int64), np.array(neg_ids, dtype=np.int64)

    def collate(batch):
        # Returns (texts, y, m): y marks positives with 1, and m marks every
        # class in pos ∪ neg. Classes outside both sets are masked out of the loss.
        texts = [b[0] for b in batch]
        pos = [b[1] for b in batch]
        neg = [b[2] for b in batch]
        # build dense target/mask per batch only
        B = len(batch)
        y = np.zeros((B, NUM_CLASSES), dtype=np.float32)
        m = np.zeros((B, NUM_CLASSES), dtype=np.float32)
        for i in range(B):
            if pos[i].size > 0:
                y[i, pos[i]] = 1.0
                m[i, pos[i]] = 1.0
            if neg[i].size > 0:
                # y already 0
                m[i, neg[i]] = 1.0
        return texts, torch.tensor(y), torch.tensor(m)

    ds = PseudoDataset(ex_texts, ex_pos_neg)
    dl = DataLoader(ds, batch_size=batch_size, shuffle=True, drop_last=False, collate_fn=collate)

    class_encoder = ClassGNNEncoder(class_init, A_mean, E_mean, hidden_dim=d_model).to(device)
    model = CoreGuidedMatcher(doc_encoder, class_encoder, d_model=d_model).to(device)

    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    bce = torch.nn.BCEWithLogitsLoss(reduction="none")

    def masked_bce_loss(logits, y, mask):
        # logits,y,mask: (B,C)
        # mean BCE over the masked-in entries of the whole batch
        loss = bce(logits, y) * mask
        denom = mask.sum().clamp(min=1.0)
        return loss.sum() / denom

    print(f"[STEP3] training core-guided matcher | samples={len(ds)} | device={device}")
    model.train()
    for ep in range(epochs):
        total = 0.0
        for texts, y, mask in tqdm(dl, desc=f"train ep{ep+1}/{epochs}"):
            enc = tokenizer(list(texts), truncation=True, padding=True, max_length=DOC_MAX_TOK, return_tensors="pt").to(device)
            y = y.to(device)
            mask = mask.to(device)

            logits = model(enc["input_ids"], enc["attention_mask"])
            loss = masked_bce_loss(logits, y, mask)

            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

            total += float(loss.detach().cpu())
        avg = total / max(1, len(dl))
        print(f"[STEP3] epoch {ep+1}: loss={avg:.4f}")

    model.eval()
    return model, tokenizer


# ============================================================
# Prediction helpers + 3.4 self-training
# ============================================================
def predict_probs(model, tokenizer, texts: List[str], batch_size: int, device: str) -> np.ndarray:
    """Return per-class sigmoid probabilities for every text.

    Args:
        model: classifier called as ``model(input_ids, attention_mask)``; its
            output is treated as one logit per class.
        tokenizer: tokenizer called with ``return_tensors="pt"`` and truncated
            to ``DOC_MAX_TOK`` tokens.
        texts: documents to score, processed in order.
        batch_size: number of documents per forward pass.
        device: device the encoded batch is moved to.

    Returns:
        float32 array with one row per text (rows in input order). Empty
        ``texts`` yields an empty ``(0, NUM_CLASSES)`` array instead of raising.

    The model is switched to eval mode for the duration of the call, so dropout
    is disabled even if the caller forgot. Its previous train/eval mode is
    restored afterwards.
    """
    import torch

    if not texts:
        # np.vstack([]) would raise "need at least one array to concatenate".
        return np.zeros((0, NUM_CLASSES), dtype=np.float32)

    was_training = model.training
    model.eval()
    chunks = []
    try:
        for start in tqdm(range(0, len(texts), batch_size), desc="predict"):
            batch = list(texts[start:start + batch_size])
            enc = tokenizer(batch, truncation=True, padding=True, max_length=DOC_MAX_TOK, return_tensors="pt").to(device)
            with torch.inference_mode():
                logits = model(enc["input_ids"], enc["attention_mask"])
                # Upcast before numpy: numpy has no bfloat16, and fp32 keeps thresholds stable.
                chunks.append(torch.sigmoid(logits.float()).cpu().numpy())
    finally:
        model.train(was_training)
    return np.vstack(chunks)

def self_training_build_examples(
    texts: List[str],
    probs: np.ndarray,
    parents_of: Dict[int, List[int]],
    children_of: Dict[int, List[int]],
    tau: float,
) -> Tuple[List[str], List[Tuple[List[int], List[int]]]]:
    """Build masked multi-label pseudo-labels from model probabilities (3.4).

    For every document (one row of ``probs``):
      - positives: classes with probability ``>= tau``. If fewer than
        ``PSEUDO_MIN_POS`` classes pass, the ``PSEUDO_MIN_POS`` most probable
        classes are used instead (ties broken toward the lower class id,
        matching ``np.argmax``);
      - ancestor closure: every id returned by
        ``get_ancestors_all(c, parents_of, max_steps=20)`` for a seed
        positive ``c`` is added to the positives;
      - negatives: every id in ``range(NUM_CLASSES)`` that is neither a
        positive nor a direct child (per ``children_of``) of a positive.
        Those children appear in neither list, so a masked loss ignores them.

    Args:
        texts: documents, aligned with the rows of ``probs``.
        probs: per-document class probabilities, one row per text.
        parents_of: class id -> parent ids (consumed by ``get_ancestors_all``).
        children_of: class id -> direct child ids.
        tau: probability threshold for selecting positives.

    Returns:
        ``(texts, pos_neg)``: the documents in input order, and for each one a
        tuple ``(sorted positive ids, ascending negative ids)``.

    Raises:
        ValueError: if ``texts`` and ``probs`` have different lengths (``zip``
            would otherwise silently drop the surplus).
    """
    if len(texts) != len(probs):
        raise ValueError(
            f"texts and probs must align: got {len(texts)} texts and {len(probs)} probability rows"
        )

    min_pos = int(PSEUDO_MIN_POS)
    ex_texts: List[str] = []
    ex_pos_neg: List[Tuple[List[int], List[int]]] = []
    for text, p in zip(texts, probs):
        pos = np.where(p >= tau)[0].tolist()
        if len(pos) < min_pos:
            # Take the top-min_pos classes. A single argmax would violate the
            # minimum whenever PSEUDO_MIN_POS > 1. A stable sort keeps argmax's
            # first-occurrence tie-breaking.
            pos = np.argsort(-p, kind="stable")[:min_pos].tolist()

        # Ancestor closure: a document in a class is also in all of its ancestors.
        pos_set = set(pos)
        for c in pos:
            pos_set.update(get_ancestors_all(c, parents_of, max_steps=20))

        # Direct children of any positive (ancestors included) are left unlabelled.
        unlabelled = {k for c in pos_set for k in children_of.get(c, [])}
        excluded = pos_set | unlabelled
        neg = [i for i in range(NUM_CLASSES) if i not in excluded]

        ex_texts.append(text)
        ex_pos_neg.append((sorted(pos_set), neg))
    return ex_texts, ex_pos_neg


# ============================================================
# MAIN
# ============================================================
def _finetune_on_pseudo_labels(
    model,
    tok,
    texts: List[str],
    pos_neg: List[Tuple[List[int], List[int]]],
    device: str,
    epochs: int,
    lr: float,
    batch_size: int,
) -> None:
    """Fine-tune ``model`` in place on masked multi-label pseudo-labels (3.4).

    Each example carries ``(positive_ids, negative_ids)``. Positives get target 1
    and negatives target 0. Every other class is excluded through a 0/1 mask,
    so the loss is BCE averaged over the unmasked entries only. A fresh AdamW
    optimizer is created. The model is trained in train mode and left in eval
    mode on return.

    Raises:
        RuntimeError: if torch cannot be imported.
    """
    try:
        import torch
        from torch.utils.data import Dataset, DataLoader
    except ImportError as e:
        raise RuntimeError("pip install -q transformers torch") from e

    class PseudoDataset2(Dataset):
        """Yields ``(text, positive ids, negative ids)`` with the ids as int64 arrays."""

        def __init__(self, texts: List[str], pos_neg: List[Tuple[List[int], List[int]]]):
            self.texts = texts
            self.pos_neg = pos_neg

        def __len__(self):
            return len(self.texts)

        def __getitem__(self, idx):
            pos, neg = self.pos_neg[idx]
            return self.texts[idx], np.array(pos, dtype=np.int64), np.array(neg, dtype=np.int64)

    def collate2(batch):
        """Build ``(texts, targets, mask)``; targets and mask are (B, NUM_CLASSES) float tensors."""
        batch_texts = [b[0] for b in batch]
        n = len(batch)
        y = np.zeros((n, NUM_CLASSES), dtype=np.float32)
        m = np.zeros((n, NUM_CLASSES), dtype=np.float32)
        for i, (_, pos, neg) in enumerate(batch):
            if pos.size > 0:
                y[i, pos] = 1.0
                m[i, pos] = 1.0
            if neg.size > 0:
                m[i, neg] = 1.0  # target stays 0
        return batch_texts, torch.tensor(y), torch.tensor(m)

    bce = torch.nn.BCEWithLogitsLoss(reduction="none")

    def masked_bce_loss(logits, y, mask):
        loss = bce(logits, y) * mask
        # clamp avoids 0/0 for a batch without any labelled entry
        return loss.sum() / mask.sum().clamp(min=1.0)

    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    ds2 = PseudoDataset2(texts, pos_neg)
    dl2 = DataLoader(ds2, batch_size=batch_size, shuffle=True, drop_last=False, collate_fn=collate2)

    print(f"[STEP4] self-training fine-tune | samples={len(ds2)} | epochs={epochs}")
    model.train()
    for ep in range(epochs):
        total = 0.0
        for batch_texts, y, mask in tqdm(dl2, desc=f"self-train ep{ep+1}/{epochs}"):
            enc = tok(list(batch_texts), truncation=True, padding=True, max_length=DOC_MAX_TOK, return_tensors="pt").to(device)
            y = y.to(device)
            mask = mask.to(device)
            logits = model(enc["input_ids"], enc["attention_mask"])
            loss = masked_bce_loss(logits, y, mask)

            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

            total += float(loss.detach().cpu())
        avg = total / max(1, len(dl2))
        print(f"[STEP4] epoch {ep+1}: loss={avg:.4f}")

    model.eval()


def _write_submission(path: str, pids, probs: np.ndarray, parents_of: Dict[int, List[int]]) -> None:
    """Write an ``id,labels`` CSV with one row per document.

    Each document starts from its ``MAX_LABELS`` highest-probability classes,
    which are passed with ``parents_of`` to ``ensure_k_labels``. The resulting
    labels are joined with commas. All rows are computed before the file is
    opened, so a labelling failure does not truncate an existing file.
    """
    rows = []
    for pid, p in zip(pids, probs):
        top = np.argsort(-p)[:MAX_LABELS].tolist()
        labs = ensure_k_labels(top, parents_of)
        rows.append([pid, ",".join(map(str, labs))])

    with open(path, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["id", "labels"])
        w.writerows(rows)


def main():
    """Run the whole pipeline and write ``SUBMISSION_PATH``.

    Steps:
      1. encode the class texts and all documents (train + test) with the bi-encoder;
      2. mine core classes with the NLI scorer (3.2);
      3. train the core-guided matcher on the train split (3.3);
      4. pseudo-label train + test with that model and fine-tune on those
         pseudo-labels (3.4, transductive);
      5. predict test labels and write the submission CSV.

    Raises:
        FileNotFoundError: if a required input file is missing.
        RuntimeError: raised by the self-training step if torch cannot be imported.
    """
    import gc

    for p in [CLASSES_PATH, HIER_PATH, KEYWORDS_PATH, TRAIN_PATH, TEST_PATH]:
        if not os.path.exists(p):
            raise FileNotFoundError(f"Missing required file: {p}")

    # Best-effort device probe: any failure importing/querying torch falls back to CPU.
    try:
        import torch
        device = "cuda" if torch.cuda.is_available() else "cpu"
    except Exception:
        device = "cpu"
    print(f"[INFO] device: {device}")

    id2name = load_classes(CLASSES_PATH)
    kw = load_keywords_accumulate(KEYWORDS_PATH)
    roots, parents_of, children_of, note = load_hierarchy_autodetect(HIER_PATH, id2name)
    siblings_of = build_siblings(parents_of, children_of)
    depth_arr = compute_depths(roots, children_of)
    print(f"[INFO] taxonomy direction: {note} | roots={len(roots)} | max_depth≈{int(depth_arr.max())}")

    train_pids, train_texts = load_corpus(TRAIN_PATH)
    test_pids, test_texts = load_corpus(TEST_PATH)

    # Step 1: bi-encoder embeddings (docs + classes)
    class_texts = build_class_texts(id2name, kw, parents_of)
    bi = BiEncoder(
        model_name=BI_MODEL_NAME,
        device=device,
        batch_size=BI_BATCH if device == "cuda" else min(BI_BATCH, 128),
        token=HF_TOKEN,
        cache_dir=HF_CACHE_DIR,
        local_files_only=HF_LOCAL_ONLY,
    )
    print("[STEP1] Encoding class texts (bi-encoder)...")
    class_embs = bi.encode(class_texts, batch_size=min(BI_BATCH, 256))  # (C, D)

    print("[STEP1] Encoding documents (bi-encoder)...")
    all_texts = train_texts + test_texts
    doc_embs = bi.encode(all_texts)  # (N, D)

    # Step 2: NLI scorer + 3.2 mining
    nli = EntailmentScorer(
        model_name=NLI_MODEL_NAME,
        device=device,
        use_fp16=NLI_USE_FP16,
        max_chars_doc=MAX_CHARS_DOC_FOR_NLI,
        token=HF_TOKEN,
        cache_dir=HF_CACHE_DIR,
        local_files_only=HF_LOCAL_ONLY,
        verbose=True,
    )

    mining_all, salient_all = mine_cores_with_median_saliency(
        texts=all_texts,
        doc_embs=doc_embs,
        class_embs=class_embs,
        nli=nli,
        roots=roots,
        parents_of=parents_of,
        children_of=children_of,
        siblings_of=siblings_of,
        id2name=id2name,
    )

    # The bi-encoder, NLI scorer and embeddings are not used after mining.
    # Release them so their (GPU) memory is free for classifier training.
    del bi, nli, doc_embs, class_embs
    gc.collect()
    if device == "cuda":
        torch.cuda.empty_cache()

    n_train = len(train_texts)
    mining_train = mining_all[:n_train]
    salient_train = salient_all[:n_train]

    # Step 3: 3.3 core-guided training (masked BCE)
    model, tok = train_core_guided_classifier(
        train_texts=train_texts,
        mining_results=mining_train,
        salient_mask=salient_train,
        id2name=id2name,
        parents_of=parents_of,
        children_of=children_of,
        device=device,
        epochs=TRAIN_EPOCHS,
        lr=LR,
        batch_size=BATCH_SIZE,
    )

    # Step 4: 3.4 self-training (transductive: pseudo-label train + test)
    all_probs = predict_probs(model, tok, all_texts, batch_size=BATCH_SIZE, device=device)
    st_texts, st_pos_neg = self_training_build_examples(
        texts=all_texts,
        probs=all_probs,
        parents_of=parents_of,
        children_of=children_of,
        tau=PSEUDO_POS_TAU,
    )

    # Continue training the same model for SELF_TRAIN_EPOCHS epochs using only
    # the pseudo-labelled pos/neg lists (not the 3.2 mining results), with the
    # same masked BCE loss as 3.3.
    _finetune_on_pseudo_labels(
        model,
        tok,
        st_texts,
        st_pos_neg,
        device=device,
        epochs=SELF_TRAIN_EPOCHS,
        lr=LR,
        batch_size=BATCH_SIZE,
    )

    # Step 5: final prediction on test
    test_probs = predict_probs(model, tok, test_texts, batch_size=BATCH_SIZE, device=device)
    _write_submission(SUBMISSION_PATH, test_pids, test_probs, parents_of)

    print(f"[DONE] saved: {SUBMISSION_PATH} | test={len(test_pids)} | labels={MIN_LABELS}-{MAX_LABELS}")


if __name__ == "__main__":
    main()


Generating dummy predictions: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19658/19658 [00:00<00:00, 190266.11it/s]

Dummy submission file saved to: submission.csv
Total samples: 19658, Classes per sample: 1-3



