In [None]:
import os
import re
import csv
import random
import numpy as np
from tqdm import tqdm

# ------------------------
# CONFIG
# ------------------------
# Seed both the stdlib and NumPy random number generators.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# Input files, all located under BASE_DIR.
BASE_DIR = "Amazon_products"
CLASSES_PATH   = os.path.join(BASE_DIR, "classes.txt")
HIER_PATH      = os.path.join(BASE_DIR, "class_hierarchy.txt")
KEYWORDS_PATH  = os.path.join(BASE_DIR, "class_related_keywords.txt")
TRAIN_PATH     = os.path.join(BASE_DIR, "train", "train_corpus.txt")
TEST_PATH      = os.path.join(BASE_DIR, "test", "test_corpus.txt")

# Output CSV written by main().
SUBMISSION_PATH = "submission.csv"

# Class ids are the integers 0 .. NUM_CLASSES - 1.
NUM_CLASSES = 531
# Lower / upper bound on the number of labels emitted per document.
MIN_LABELS = 2
MAX_LABELS = 3

# Sentence-BERT model (a fast default with decent quality)
SBERT_NAME = "sentence-transformers/all-MiniLM-L6-v2"
BATCH_SIZE_ENC = 64        # 128-256 is feasible on a GPU
MAX_CHARS_DOC = 1500       # truncate overly long documents (speed / noise)
DEPTH_ALPHA = 0.03         # weight that slightly favors more specific (deeper) classes


# ------------------------
# I/O
# ------------------------
def load_corpus(path):
    """Read a tab-separated corpus file into parallel id / text lists.

    Each line is expected to look like ``<id>\\t<text>``. Only the first tab
    separates the id from the text, so the text itself may contain tabs.
    Lines without any tab are skipped.

    Returns:
        A ``(pids, texts)`` tuple of two equally long lists of strings.
    """
    pids = []
    texts = []
    with open(path, "r", encoding="utf-8") as handle:
        for raw in handle:
            doc_id, tab, body = raw.rstrip("\n").partition("\t")
            if not tab:
                # No tab on this line: not a valid "<id>\t<text>" record.
                continue
            pids.append(doc_id)
            texts.append(body)
    return pids, texts

def load_classes(path):
    """Load the class-id -> class-name mapping from ``path``.

    A non-blank line of the form ``<digits><tab or comma><name>`` maps the
    numeric id to the name. Any other non-blank line is stored whole under
    its zero-based line index in the file. Ids in ``range(NUM_CLASSES)`` that
    are still missing afterwards get the placeholder name ``class_<id>``.

    Returns:
        dict mapping int class id to str class name.
    """
    names = {}
    with open(path, "r", encoding="utf-8") as handle:
        for line_index, raw in enumerate(handle):
            entry = raw.strip()
            if not entry:
                continue
            fields = re.split(r"[\t,]", entry, maxsplit=1)
            head = fields[0].strip()
            if len(fields) == 2 and head.isdigit():
                names[int(head)] = fields[1].strip()
            else:
                # No explicit numeric id: fall back to the line position.
                names[line_index] = entry

    # Guarantee an entry for every expected class id.
    for cid in range(NUM_CLASSES):
        names.setdefault(cid, f"class_{cid}")
    return names

def load_keywords_accumulate(path):
    """Collect de-duplicated keywords per class id from ``path``.

    Each usable line starts with a numeric class id followed by one or more
    fields. Fields are separated by tabs, or by commas when the line contains
    no tab. Every field is further split on ``, ; / |``. Purely numeric tokens
    are dropped. Keywords from several lines with the same id are accumulated
    in first-seen order.

    Returns:
        dict mapping every id in ``range(NUM_CLASSES)`` to a list of keywords
        (possibly empty).
    """
    number_pattern = re.compile(r"[-+]?\d+(\.\d+)?")
    keywords = {cid: [] for cid in range(NUM_CLASSES)}

    def tokenize(fields):
        pieces = []
        for field in fields:
            for piece in re.split(r"[,;/|]", field):
                piece = piece.strip()
                if piece:
                    pieces.append(piece)
        # Drop numeric tokens (weights are often mixed in with keywords).
        return [p for p in pieces if number_pattern.fullmatch(p) is None]

    with open(path, "r", encoding="utf-8") as handle:
        for raw in handle:
            entry = raw.strip()
            if not entry:
                continue

            columns = entry.split("\t")
            if len(columns) == 1:
                columns = entry.split(",")
            if len(columns) < 2:
                continue

            head = columns[0].strip()
            if not head.isdigit():
                continue
            cid = int(head)
            if not (0 <= cid < NUM_CLASSES):
                continue

            fields = [col.strip() for col in columns[1:] if col.strip()]
            bucket = keywords[cid]
            known = set(bucket)
            for token in tokenize(fields):
                if token in known:
                    continue
                bucket.append(token)
                known.add(token)
    return keywords

def read_edges(path, id2name):
    """Parse hierarchy edges from ``path`` as a list of ``(a, b)`` id pairs.

    Each non-blank line is split on tabs or commas; the first two fields are
    the edge endpoints. An endpoint may be a numeric id or a class name that
    is looked up in ``id2name``. Lines with fewer than two fields, unknown
    names, self-loops, or ids outside ``range(NUM_CLASSES)`` are skipped.
    The pairs are returned in file order; no direction is imposed here.
    """
    id_by_name = {name: cid for cid, name in id2name.items()}

    def resolve(token):
        token = token.strip()
        return int(token) if token.isdigit() else id_by_name.get(token, None)

    def in_range(cid):
        return 0 <= cid < NUM_CLASSES

    edges = []
    with open(path, "r", encoding="utf-8") as handle:
        for raw in handle:
            entry = raw.strip()
            if not entry:
                continue
            fields = re.split(r"[\t,]", entry)
            if len(fields) < 2:
                continue
            first = resolve(fields[0])
            second = resolve(fields[1])
            if first is None or second is None or first == second:
                continue
            if in_range(first) and in_range(second):
                edges.append((first, second))
    return edges

def build_graph(edges, parent_to_child=True, num_classes=None):
    """Build parent / child adjacency maps from a list of edges.

    Args:
        edges: iterable of ``(a, b)`` integer id pairs.
        parent_to_child: when True each pair is read as ``(parent, child)``;
            when False it is read as ``(child, parent)``.
        num_classes: number of node ids (``0 .. num_classes - 1``). Defaults
            to the module-level ``NUM_CLASSES`` when omitted.

    Returns:
        ``(roots, parents_of, children_of)`` where ``roots`` is the ascending
        list of ids that have no parent, and the two dicts map every id to a
        sorted list of its parents / children.

    Self-loops and edges with an endpoint outside ``range(num_classes)`` are
    ignored.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES

    parents_of = {i: set() for i in range(num_classes)}
    children_of = {i: set() for i in range(num_classes)}

    for a, b in edges:
        p, c = (a, b) if parent_to_child else (b, a)
        if p == c:
            continue
        if not (0 <= p < num_classes and 0 <= c < num_classes):
            # Skip edges that refer to ids outside the known class range
            # instead of failing with a KeyError.
            continue
        parents_of[c].add(p)
        children_of[p].add(c)

    roots = [i for i in range(num_classes) if not parents_of[i]]
    # Sorted lists make downstream "first parent" choices deterministic.
    parents_of = {k: sorted(v) for k, v in parents_of.items()}
    children_of = {k: sorted(v) for k, v in children_of.items()}
    return roots, parents_of, children_of

def load_hierarchy_autodetect(path, id2name):
    """Load the hierarchy and guess which way the edges in the file point.

    The edges are interpreted both as parent->child and as child->parent.
    A taxonomy usually has few roots, so the inverted reading is chosen only
    when it yields at least one root and strictly fewer roots than the as-is
    reading.

    Returns:
        ``(roots, parents_of, children_of, note)`` where ``note`` is a short
        string describing the direction that was selected.
    """
    edges = read_edges(path, id2name)
    as_is = build_graph(edges, parent_to_child=True)
    inverted = build_graph(edges, parent_to_child=False)

    roots_as_is = len(as_is[0])
    roots_inverted = len(inverted[0])
    if 1 <= roots_inverted < roots_as_is:
        return (*inverted, "inverted(child->parent in file)")
    return (*as_is, "as-is(parent->child in file)")

def compute_depths(roots, children_of, num_classes=None):
    """Compute the minimum depth of every class with a breadth-first search.

    Args:
        roots: iterable of root ids; each gets depth 0.
        children_of: dict mapping an id to an iterable of child ids. Ids that
            are missing from the dict are treated as having no children.
        num_classes: number of node ids (``0 .. num_classes - 1``). Defaults
            to the module-level ``NUM_CLASSES`` when omitted.

    Returns:
        ``np.ndarray`` of dtype int32 and length ``num_classes``. Nodes that
        cannot be reached from any root get ``max_reached_depth + 1`` (which
        is 1 when nothing is reachable).
    """
    from collections import deque

    if num_classes is None:
        num_classes = NUM_CLASSES

    unreached = 10**9  # sentinel that still fits into int32
    depth = np.full(num_classes, unreached, dtype=np.int32)

    queue = deque()
    for r in roots:
        depth[r] = 0
        queue.append(r)

    while queue:
        u = queue.popleft()
        for v in children_of.get(u, []):
            # Only relax when a strictly shorter path is found; this also
            # guarantees termination on graphs that contain cycles.
            if depth[v] > depth[u] + 1:
                depth[v] = depth[u] + 1
                queue.append(v)

    # Replace the sentinel of unreachable nodes with "one deeper than the
    # deepest reachable node" so it does not distort later arithmetic.
    reached = depth < unreached
    maxd = int(depth[reached].max()) if reached.any() else 0
    depth[~reached] = maxd + 1
    return depth

def get_ancestors(cid, parents_of, max_steps=10):
    """Walk up the hierarchy from ``cid`` and return the visited ancestors.

    At every step the first entry of ``parents_of[node]`` is followed. The
    walk stops at a node without parents or after ``max_steps`` steps.

    Returns:
        List of ancestor ids ordered from nearest to farthest; ``cid`` itself
        is not prepended.
    """
    lineage = []
    node = cid
    for _ in range(max_steps):
        candidates = parents_of.get(node, [])
        if candidates:
            # Deterministic choice: always follow the first listed parent.
            node = candidates[0]
            lineage.append(node)
        else:
            break
    return lineage

def cores_to_labels(cores, parents_of):
    """Expand core classes into a final label set of MIN_LABELS..MAX_LABELS ids.

    Steps:
      1. Start with the core classes themselves.
      2. Add ancestors of each core (up to 5 steps up) until ``MAX_LABELS``
         distinct labels are collected.
      3. If fewer than ``MIN_LABELS`` labels were found, add further ancestors
         of the first core (up to 10 steps up).
      4. If that is still not enough, pad with the smallest unused class ids.

    Labels are de-duplicated as they are added, so an ancestor that is
    already present (for example when one core is the parent of another)
    does not use up one of the ``MAX_LABELS`` slots.

    Args:
        cores: list of core class ids for one document.
        parents_of: dict mapping a class id to its list of parent ids.

    Returns:
        Ascending list of at most ``MAX_LABELS`` distinct class ids.
    """
    labels = []
    seen = set()

    def _add(label):
        # Append only labels that are not present yet (keeps first-seen order).
        if label in seen:
            return
        seen.add(label)
        labels.append(label)

    for c in cores:
        _add(c)

    for c in cores:
        if len(labels) >= MAX_LABELS:
            break
        for a in get_ancestors(c, parents_of, max_steps=5):
            if len(labels) >= MAX_LABELS:
                break
            _add(a)

    # Ensure the lower bound, first by climbing further up from the top core.
    if len(labels) < MIN_LABELS and cores:
        for a in get_ancestors(cores[0], parents_of, max_steps=10):
            _add(a)
            if len(labels) >= MIN_LABELS:
                break

    # Last resort: pad with the smallest class ids that are not used yet.
    if len(labels) < MIN_LABELS:
        for k in range(NUM_CLASSES):
            _add(k)
            if len(labels) >= MIN_LABELS:
                break

    return sorted(labels[:MAX_LABELS])

def build_class_texts(id2name, kw):
    """Build one natural-language description per class for embedding.

    The class name is wrapped in a short sentence and, when keywords exist,
    followed by up to 40 of them. The sentence-like template is meant to help
    the sentence-embedding model match documents to classes.

    Returns:
        List of ``NUM_CLASSES`` strings, indexed by class id.
    """
    descriptions = []
    for cid in range(NUM_CLASSES):
        label = id2name[cid]
        top_keywords = kw.get(cid, [])[:40]
        if top_keywords:
            description = f"This document is about {label}. Keywords: " + ", ".join(top_keywords)
        else:
            description = f"This document is about {label}."
        descriptions.append(description)
    return descriptions

def l2_normalize(x, eps=1e-12):
    """Scale every row of the 2-D array ``x`` to unit L2 norm.

    Row norms are floored at ``eps`` before dividing, so all-zero rows stay
    zero instead of producing a division by zero.
    """
    row_norms = np.linalg.norm(x, axis=1, keepdims=True)
    safe_norms = np.maximum(row_norms, eps)
    return np.divide(x, safe_norms)


# ------------------------
# Core mining using embeddings
# ------------------------
def mine_cores_with_embeddings(doc_emb, class_emb, roots, parents_of, children_of, depth, top_m=60):
    """Pick one or two "core" classes per document from embedding similarity.

    For every document the ``top_m`` most similar classes are taken as
    candidates and ranked by ``sim + DEPTH_ALPHA * depth``. Each non-root
    candidate gets a confidence ``sim[c] - max(sim[parent], best sibling sim)``
    using its first listed parent. The candidate with the highest positive
    confidence becomes the first core; one more candidate is added when its
    similarity is at least 92% of the first core's. When no candidate has a
    positive confidence, the best non-root candidate by depth-weighted
    similarity is used, and class 0 if there is none.

    Args:
        doc_emb: (N, D) document embeddings, expected to be L2-normalized.
        class_emb: (C, D) class embeddings, expected to be L2-normalized.
        roots: iterable of root class ids (never chosen as cores).
        parents_of: dict mapping class id -> list of parent ids.
        children_of: dict mapping class id -> list of child ids.
        depth: integer array of length C with the depth of each class.
        top_m: number of most similar classes considered per document.
            Using only the top candidates keeps broad classes from dominating.
            Capped at C, so small taxonomies no longer make
            ``np.argpartition`` fail.

    Returns:
        list[list[int]]: one list of 1-2 core class ids per document.

    Raises:
        ValueError: if ``top_m`` is smaller than 1.
    """
    if top_m < 1:
        raise ValueError("top_m must be >= 1")

    root_set = set(roots)
    num_classes = class_emb.shape[0]

    # Computing a per-class median confidence risks collapsing predictions
    # again, so this relies on "conf > 0" + "top candidates" + a fallback.
    doc_cores = []

    for i in tqdm(range(doc_emb.shape[0]), desc="Core mining (embedding)"):
        sim = doc_emb[i] @ class_emb.T  # (C,)

        # Candidates: only the top_m most similar classes.
        if top_m < num_classes:
            top_idx = np.argpartition(-sim, top_m)[:top_m]
        else:
            # argpartition requires kth < C; with few classes take them all.
            top_idx = np.arange(num_classes)

        # Depth weighting (prefer specific classes): sim' = sim + alpha * depth
        sim2 = sim[top_idx] + DEPTH_ALPHA * depth[top_idx].astype(np.float32)
        order = top_idx[np.argsort(-sim2)]

        # Confidence relative to the parent and the siblings.
        cand = []
        for c in order:
            if c in root_set:
                continue
            plist = parents_of.get(c, [])
            if not plist:
                continue
            p = plist[0]
            parent_sim = sim[p]

            sib_max = -1e9
            for s in children_of.get(p, []):
                if s == c:
                    continue
                sib_max = max(sib_max, sim[s])

            conf = float(sim[c] - max(parent_sim, sib_max))
            cand.append((c, conf, float(sim[c]), int(depth[c])))

        # Core selection: positive confidence, highest confidence first,
        # depth as the tie-breaker.
        cand_pos = [x for x in cand if x[1] > 0]
        cand_pos.sort(key=lambda x: (x[1], x[3]), reverse=True)

        if cand_pos:
            cores = [cand_pos[0][0]]
            # Add a second core when it is "similar enough" to the first one.
            for x in cand_pos[1:]:
                if x[2] >= 0.92 * cand_pos[0][2]:
                    cores.append(x[0])
                    break
            doc_cores.append(cores[:2])
        else:
            # Fallback: best non-root candidate by depth-weighted similarity.
            best = None
            best_score = -1e9
            for c in order:
                if c in root_set:
                    continue
                score = float(sim[c] + DEPTH_ALPHA * depth[c])
                if score > best_score:
                    best_score = score
                    best = c
            doc_cores.append([best] if best is not None else [0])

    return doc_cores


def main():
    """Run the zero-shot labelling pipeline and write the submission CSV.

    Loads the class names, keywords and hierarchy, embeds the class
    descriptions and the test documents with Sentence-BERT, mines core
    classes for every test document, expands them to 2-3 labels, and writes
    ``SUBMISSION_PATH`` with the columns ``id`` and ``labels``.

    Raises:
        FileNotFoundError: if one of the required input files is missing.
        RuntimeError: if ``sentence-transformers`` is not installed.
    """
    # The train corpus is not required: labels are mined directly on the test
    # documents, so loading and encoding the train set would only cost time.
    for p in [CLASSES_PATH, HIER_PATH, KEYWORDS_PATH, TEST_PATH]:
        if not os.path.exists(p):
            raise FileNotFoundError(f"Missing required file: {p}")

    # Load resources
    id2name = load_classes(CLASSES_PATH)
    kw = load_keywords_accumulate(KEYWORDS_PATH)
    roots, parents_of, children_of, note = load_hierarchy_autodetect(HIER_PATH, id2name)

    print(f"[INFO] taxonomy direction: {note}")
    print(f"[INFO] roots: {len(roots)}")

    depth = compute_depths(roots, children_of)

    test_pids, test_texts = load_corpus(TEST_PATH)

    # Shorten docs to reduce noise / speed up encoding.
    test_texts = [t[:MAX_CHARS_DOC] for t in test_texts]

    class_texts = build_class_texts(id2name, kw)

    # Encode with SentenceTransformer (imported lazily so a missing package
    # produces an actionable message).
    try:
        from sentence_transformers import SentenceTransformer
    except ImportError as e:
        raise RuntimeError(
            "sentence-transformers is not installed. Run:\n"
            "  pip install -q sentence-transformers\n"
            "and re-run."
        ) from e

    model = SentenceTransformer(SBERT_NAME)

    print("[INFO] Encoding class texts...")
    class_emb = model.encode(
        class_texts,
        batch_size=BATCH_SIZE_ENC,
        show_progress_bar=True,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )

    print("[INFO] Encoding test docs...")
    test_emb = model.encode(
        test_texts,
        batch_size=BATCH_SIZE_ENC,
        show_progress_bar=True,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )

    # --- Core mining on TEST directly (zero-shot style) ---
    # This is the safest setup for first checking whether the prediction
    # collapse has stopped.
    test_cores = mine_cores_with_embeddings(
        doc_emb=test_emb,
        class_emb=class_emb,
        roots=roots,
        parents_of=parents_of,
        children_of=children_of,
        depth=depth
    )

    # Build final labels (2~3)
    test_labels = [cores_to_labels(cores, parents_of) for cores in test_cores]

    # Diagnostics: most frequent predicted label combinations.
    from collections import Counter
    freq = Counter(tuple(x) for x in test_labels)
    print("[INFO] top-10 predicted label patterns:", freq.most_common(10))

    # Write submission
    with open(SUBMISSION_PATH, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["id", "labels"])
        for pid, labs in zip(test_pids, test_labels):
            w.writerow([pid, ",".join(map(str, labs))])

    print(f"[DONE] saved: {SUBMISSION_PATH}")
    print(f"[DONE] test samples: {len(test_pids)} | labels per sample: {MIN_LABELS}-{MAX_LABELS}")


if __name__ == "__main__":
    main()


Generating dummy predictions: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19658/19658 [00:00<00:00, 190266.11it/s]

Dummy submission file saved to: submission.csv
Total samples: 19658, Classes per sample: 1-3



