# In [None]:
# ============================================================
# TaxoClass End-to-End Pipeline (Step 1 -> Step 4) + submission.csv
# - Writes submission.csv in the same format as the Kaggle baseline
# - Implements:
#   Step 1: RoBERTa-MNLI entailment similarity sim(D,c) (checkpoint set by NLI_MODEL_NAME)
#   Step 2: Core class mining (top-down candidates + median-conf filtering)
#   Step 3: Core-guided classifier training (BERT doc encoder + GNN class encoder + matching)
#   Step 4: Multi-label self-training (Q refinement + KL objective)
# ============================================================

import os
import csv
import math
import random
import pickle
import statistics
from dataclasses import dataclass
from collections import defaultdict, deque
from typing import Dict, List, Set, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModel
from torch.optim import AdamW


# -------------------------
# Reproducibility
# -------------------------
# Seed the Python, NumPy and PyTorch (CPU + every CUDA device) RNGs. This does
# not by itself force deterministic CUDA/cuDNN kernels.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)

# -------------------------
# Paths (given dataset layout)
# -------------------------
DATA_DIR = "Amazon_products"
TRAIN_DIR = os.path.join(DATA_DIR, "train")
TEST_DIR  = os.path.join(DATA_DIR, "test")

# Corpus files hold one "pid<TAB>text" record per line (parsed by load_corpus).
TRAIN_CORPUS_PATH = os.path.join(TRAIN_DIR, "train_corpus.txt")  # pid \t text
TEST_CORPUS_PATH  = os.path.join(TEST_DIR, "test_corpus.txt")    # pid \t text

# Class names (load_classes) and "parent child" id pairs (load_taxonomy_edges).
CLASSES_PATH   = os.path.join(DATA_DIR, "classes.txt")
HIERARCHY_PATH = os.path.join(DATA_DIR, "class_hierarchy.txt")

SUBMISSION_PATH = "submission.csv"

# Kaggle submission constraints (same values as the baseline): total number of
# classes and the minimum/maximum number of labels per document.
NUM_CLASSES = 531
MIN_LABELS  = 1
MAX_LABELS  = 3

# -------------------------
# Config (adjust for speed)
# -------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Step 1 entailment scorer (EntailmentScorer).
# NOTE(review): EntailmentScorer reads the entailment column from the model
# config's label names and falls back to index 2 when no label is named
# "ENTAILMENT". Confirm that index 2 really is entailment for this checkpoint;
# MNLI checkpoints do not all share one label order.
NLI_MODEL_NAME = "textattack/roberta-base-MNLI"
# Hypothesis template; "{}" is replaced by the class name.
NLI_TEMPLATE = "this product is about {}."
# (premise, hypothesis) pairs per forward pass.
NLI_BATCH_SIZE = 32
NLI_MAX_LENGTH = 256
# fp16 autocast for the NLI model; only takes effect when the device is "cuda".
NLI_FP16 = True

# Step 2 core mining (CoreClassMiner).
# Children kept (and expanded further) per expanded node in the top-down search.
TOPK_PER_PARENT = 1
# Maximum number of taxonomy levels explored below the pseudo-root.
MAX_DEPTH = 4
# Children whose entailment score falls below this are neither candidates nor expanded.
MIN_SIM_TO_EXPAND = 0.2

# Step 3 core-guided training
DOC_MODEL_NAME = "bert-base-uncased"
DOC_MAX_LENGTH = 256
STEP3_BATCH_SIZE = 4
STEP3_EPOCHS = 1
STEP3_LR = 2e-5
# Max negative classes per document; the subset is sampled once when the dataset is built.
STEP3_NEG_SAMPLE_SIZE = 128
# Number of GCN layers in the class encoder.
GNN_LAYERS = 2
# Hop radius of the ego graph whose mean hidden state represents each class.
EGO_K = 1
# Paper form is sigmoid(exp(score)). Practical alternative is sigmoid(score).
# Note: sigmoid(exp(x)) > 0.5 for every x, so with the exp form every class
# probability is above 0.5.
USE_EXP_IN_MATCHING = False

# Step 4 self-training
STEP4_BATCH_SIZE = 8
STEP4_EPOCHS = 2
STEP4_LR = 1e-5
# Class-wise sums for Q are recomputed at epochs where (epoch - 1) % value == 0
# (1 = every epoch; 0 = computed only once, before the first epoch).
Q_UPDATE_EVERY_EPOCH = 1  # recompute class-wise sums each epoch
# fp16 autocast + GradScaler in Steps 3/4; only used when the device is "cuda".
USE_AMP = True

# Prediction -> labels (presumably forwarded to probs_to_labels by main; verify).
PRED_THRESHOLD = None   # e.g., 0.5; if None, use pure topK
TOPK_PRED = MAX_LABELS  # cap label count
ENSURE_MIN_LABELS = MIN_LABELS

# Optional caching (recommended to avoid re-mining/retraining each run).
# NOTE: main() reuses a cached artifact whenever its file exists; the files are
# not keyed by the settings above, so delete CACHE_DIR after changing the
# Step 1-3 configuration. The directory is created at import time.
CACHE_DIR = "taxoclass_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
CORE_CACHE_PATH = os.path.join(CACHE_DIR, "pid2core.pkl")
MEDIAN_CACHE_PATH = os.path.join(CACHE_DIR, "median_conf.pkl")
STEP3_MODEL_PATH = os.path.join(CACHE_DIR, "step3_model.pt")
STEP4_MODEL_PATH = os.path.join(CACHE_DIR, "step4_model.pt")

# ============================================================
# IO helpers
# ============================================================
def load_corpus(path: str) -> Dict[str, str]:
    """
    Read a tab-separated corpus file into a ``{pid: text}`` mapping.

    Each line is split at its first tab: the part before it is the pid and
    everything after it (further tabs included) is the text. Lines with no
    tab are ignored; if a pid repeats, its last occurrence wins.
    """
    corpus: Dict[str, str] = {}
    with open(path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            pid, separator, text = raw_line.rstrip("\n").partition("\t")
            if not separator:
                continue
            corpus[pid] = text
    return corpus

def load_classes(path: str) -> List[str]:
    """
    Load class surface names, indexed by class id.

    Two line formats are accepted:
      * ``name``         -- the i-th non-empty line is class id i;
      * ``id<TAB>name``  -- an explicit integer id precedes the name.
    If every non-empty line carries an explicit id, the names are returned
    ordered by id, and the ids must be exactly 0..N-1 because callers treat
    the list index as the class id. Otherwise names are returned in file
    order. Blank lines are ignored and do not consume an id.

    Raises:
      ValueError: if explicit ids are present but are not exactly 0..N-1.
    """
    entries: List[Tuple[Optional[int], str]] = []
    with open(path, "r", encoding="utf-8") as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            head, sep, tail = line.partition("\t")
            head, tail = head.strip(), tail.strip()
            if sep and head.isdecimal() and tail:
                # "id<TAB>name": keep only the name so the numeric id never
                # leaks into NLI hypotheses or class-name embeddings.
                entries.append((int(head), tail))
            else:
                entries.append((None, line))

    if entries and all(cid is not None for cid, _ in entries):
        entries.sort(key=lambda entry: entry[0])
        ids = [cid for cid, _ in entries]
        if ids != list(range(len(ids))):
            raise ValueError(f"{path}: explicit class ids must be exactly 0..{len(ids) - 1}")

    return [name for _, name in entries]

def load_taxonomy_edges(path: str) -> Tuple[Dict[int, List[int]], Dict[int, List[int]]]:
    """
    Parse a parent/child edge list into two adjacency maps.

    Every non-blank line holds ``parent child`` as integers. A line that
    contains a tab is split on tabs; any other line is split on whitespace.
    Extra columns are ignored and lines with fewer than two fields are
    skipped.

    Returns:
      (parent2children, child2parents): each maps an id to the list of its
      neighbours in file order (duplicate edges are kept).
    """
    def _field_rows(handle):
        # Yield the split fields of every non-blank line.
        for raw_line in handle:
            stripped = raw_line.strip()
            if stripped:
                yield stripped.split("\t") if "\t" in stripped else stripped.split()

    down = defaultdict(list)
    up = defaultdict(list)
    with open(path, "r", encoding="utf-8") as handle:
        for fields in _field_rows(handle):
            if len(fields) >= 2:
                parent, child = int(fields[0]), int(fields[1])
                down[parent].append(child)
                up[child].append(parent)
    return dict(down), dict(up)

# ============================================================
# Step 1: Document–Class Similarity (NLI entailment)
# ============================================================
class EntailmentScorer:
    """
    Step 1 similarity: sim(D, c) = P(D entails template(class_name)).

    Runs an MNLI sequence-classification checkpoint (``model_name``, default
    roberta-large-mnli) with the document as premise and the filled template
    as hypothesis. Scores are cached per pid as ``{pid: {class_id: sim}}``.
    The cache is keyed by pid only, so a pid must always be scored with the
    same text.
    """

    # Entailment column used when neither the caller nor the model config
    # identifies it. Correct for roberta-large-mnli
    # (CONTRADICTION=0, NEUTRAL=1, ENTAILMENT=2); other MNLI checkpoints may
    # order their labels differently.
    DEFAULT_ENTAILMENT_IDX = 2

    def __init__(
        self,
        model_name: str = "roberta-large-mnli",
        template: str = "this product is about {}.",
        device: str = "cuda",
        batch_size: int = 32,
        max_length: int = 256,
        use_fp16: bool = True,
        entailment_idx: Optional[int] = None,
    ):
        """
        Args:
          model_name: Hugging Face id of an MNLI sequence-classification model.
          template: hypothesis template; ``{}`` is replaced by the class name.
          device: torch device string ("cpu", "cuda", "cuda:N", ...).
          batch_size: hypotheses scored per forward pass.
          max_length: truncation length for the (premise, hypothesis) pair.
          use_fp16: run the NLI model under fp16 autocast (CUDA devices only).
          entailment_idx: explicit logit column of the entailment class. When
            None, it is read from the model config's label names, falling
            back to DEFAULT_ENTAILMENT_IDX with a warning.
        """
        self.model_name = model_name
        self.template = template
        self.batch_size = batch_size
        self.max_length = max_length
        self.device = device
        # Accept both "cuda" and "cuda:N".
        self.use_fp16 = use_fp16 and str(device).startswith("cuda")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
        self.model.eval()

        self._cache: Dict[str, Dict[int, float]] = {}

        config_label2id = getattr(self.model.config, "label2id", None) or {}
        self.entailment_idx = self._resolve_entailment_idx(config_label2id, entailment_idx)

    @staticmethod
    def _resolve_entailment_idx(label2id: Dict[str, int], override: Optional[int] = None) -> int:
        """
        Return the logit column holding P(entailment).

        Priority: explicit ``override`` > a label named "entailment" (any
        case) in ``label2id`` > DEFAULT_ENTAILMENT_IDX. Generic names such as
        LABEL_0..LABEL_2 say nothing about label order, so the fallback
        prints a warning instead of failing silently.
        """
        if override is not None:
            return int(override)
        normalized = {str(name).upper(): int(idx) for name, idx in label2id.items()}
        if "ENTAILMENT" in normalized:
            return normalized["ENTAILMENT"]
        print(
            "[EntailmentScorer] WARNING: model config has no 'ENTAILMENT' label "
            f"(labels: {sorted(normalized)}); falling back to index "
            f"{EntailmentScorer.DEFAULT_ENTAILMENT_IDX}. Pass entailment_idx explicitly "
            "if this checkpoint uses a different label order."
        )
        return EntailmentScorer.DEFAULT_ENTAILMENT_IDX

    def _hypothesis(self, class_name: str) -> str:
        """Fill the hypothesis template with ``class_name``."""
        return self.template.format(class_name)

    @torch.no_grad()
    def score_pid_to_class_ids(
        self,
        pid: str,
        doc_text: str,
        class_ids: List[int],
        id2name: List[str]
    ) -> Dict[int, float]:
        """
        Return ``{class_id: sim(doc, class)}`` for every id in ``class_ids``.

        Only (pid, class) pairs that are not cached yet are sent to the
        model, in batches of ``batch_size``; a repeated id is scored once.
        """
        cached = self._cache.setdefault(pid, {})
        # Deduplicate while preserving order so each pair is scored only once.
        missing = list(dict.fromkeys(cid for cid in class_ids if cid not in cached))

        for start in range(0, len(missing), self.batch_size):
            batch_cids = missing[start:start + self.batch_size]
            premises = [doc_text] * len(batch_cids)
            hypotheses = [self._hypothesis(id2name[cid]) for cid in batch_cids]

            enc = self.tokenizer(
                premises, hypotheses,
                truncation=True, padding=True,
                max_length=self.max_length,
                return_tensors="pt"
            ).to(self.device)

            if self.use_fp16:
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    logits = self.model(**enc).logits
            else:
                logits = self.model(**enc).logits

            # Softmax in float32 so fp16 logits do not lose probability precision.
            probs = torch.softmax(logits.float(), dim=-1)[:, self.entailment_idx]
            for cid, s in zip(batch_cids, probs.cpu().tolist()):
                cached[cid] = float(s)

        return {cid: cached[cid] for cid in class_ids}

# ============================================================
# Taxonomy wrapper (supports pseudo-root)
# ============================================================
class Taxonomy:
    """
    Class hierarchy extended with an artificial root.

    Every id in ``range(num_classes)`` receives (possibly empty) parent and
    child lists. Classes without any parent are hung beneath ``PSEUDO_ROOT``
    (-1): ``children(PSEUDO_ROOT)`` lists the top level, and each top-level
    class reports ``PSEUDO_ROOT`` as its parent. Siblings are all real
    classes sharing at least one parent (the pseudo-root included).
    """
    PSEUDO_ROOT = -1

    def __init__(self, parent2children: Dict[int, List[int]], child2parents: Dict[int, List[int]], num_classes: int):
        self.num_classes = num_classes
        # Private copies so later edits never touch the caller's lists.
        self.parent2children = dict((node, list(kids)) for node, kids in parent2children.items())
        self.child2parents = dict((node, list(ups)) for node, ups in child2parents.items())

        for node in range(num_classes):
            self.parent2children.setdefault(node, [])
            self.child2parents.setdefault(node, [])

        roots = list(filter(lambda node: not self.child2parents[node], range(num_classes)))
        self.parent2children[self.PSEUDO_ROOT] = roots
        for node in roots:
            self.child2parents[node].append(self.PSEUDO_ROOT)

        self._siblings = self._build_siblings()

    def parents(self, cid: int) -> List[int]:
        """Parents of ``cid`` ([PSEUDO_ROOT] for top-level classes, [] if unknown)."""
        return self.child2parents.get(cid, [])

    def children(self, cid: int) -> List[int]:
        """Children of ``cid`` ([] if it has none or is unknown)."""
        return self.parent2children.get(cid, [])

    def siblings(self, cid: int) -> List[int]:
        """Sorted real classes sharing a parent with ``cid`` ([] if unknown)."""
        return self._siblings.get(cid, [])

    def _build_siblings(self) -> Dict[int, List[int]]:
        valid = range(self.num_classes)
        table = {}
        for node in valid:
            peers = {
                other
                for up in self.parents(node)
                for other in self.children(up)
                if other != node and other != self.PSEUDO_ROOT and other in valid
            }
            table[node] = sorted(peers)
        return table

# ============================================================
# Step 2: Document Core Class Mining
# ============================================================
class CoreClassMiner:
    """
    Step 2: document core class mining.

    For one document:
      1. ``select_candidates_for_doc`` walks the taxonomy top-down from the
         pseudo-root and keeps the best-scoring children of every expanded node.
      2. ``compute_conf_for_doc_candidates`` assigns each candidate
         conf(D, c) = sim(D, c) - max sim(D, c') over c's real parents and siblings.
    ``mine_core_classes_over_corpus`` runs both steps over a corpus. Per
    document, it keeps the candidates whose confidence reaches the median
    confidence of their class.
    """
    def __init__(
        self,
        taxonomy: Taxonomy,
        entailment_scorer: EntailmentScorer,
        class_names: List[str],
        topk_per_parent: int = 2,
        max_depth: int = 6,
        min_sim_to_expand: float = 0.0,
    ):
        """
        Args:
          taxonomy: class hierarchy with a pseudo-root.
          entailment_scorer: provides sim(D, c) via score_pid_to_class_ids.
          class_names: class_id -> surface name (passed on to the scorer).
          topk_per_parent: children kept (and expanded) per expanded node.
          max_depth: maximum number of levels expanded below the pseudo-root.
          min_sim_to_expand: children with sim below this are discarded
            (neither candidates nor expanded).
        """
        self.tax = taxonomy
        self.scorer = entailment_scorer
        self.class_names = class_names
        self.topk_per_parent = topk_per_parent
        self.max_depth = max_depth
        self.min_sim_to_expand = min_sim_to_expand

    def select_candidates_for_doc(self, pid: str, doc_text: str) -> Set[int]:
        """
        Top-down candidate selection for one document.

        Path score: ps(root) = 1 and ps(c) = (best ps among c's parents that
        already have one) * sim(D, c). At each level, every frontier node's
        children are scored. The ``topk_per_parent`` children with the
        highest ps become candidates and form the next frontier. The search
        stops after ``max_depth`` levels or when the frontier is empty.

        Returns:
          all candidate class ids kept at any level.
        """
        visited: Set[int] = set()
        # Path score per node; the pseudo-root starts at 1.0.
        ps: Dict[int, float] = {self.tax.PSEUDO_ROOT: 1.0}

        frontier = [self.tax.PSEUDO_ROOT]
        depth = 0

        while frontier and depth < self.max_depth:
            next_frontier = []
            depth += 1

            for u in frontier:
                children = [c for c in self.tax.children(u) if c != self.tax.PSEUDO_ROOT]
                if not children:
                    continue

                # Served from the scorer's per-pid cache when already computed.
                sim_map = self.scorer.score_pid_to_class_ids(pid, doc_text, children, self.class_names)

                scored = []
                for c in children:
                    sim_c = sim_map[c]
                    if sim_c < self.min_sim_to_expand:
                        continue

                    # Multi-parent classes: use the best path score among the
                    # parents that already have one.
                    best_parent_ps = 0.0
                    for p in self.tax.parents(c):
                        if p in ps:
                            best_parent_ps = max(best_parent_ps, ps[p])
                    # Fallback to the node being expanded; only reached when no
                    # parent has a positive path score recorded.
                    if best_parent_ps == 0.0:
                        best_parent_ps = ps.get(u, 0.0)

                    ps_c = best_parent_ps * sim_c
                    ps[c] = max(ps.get(c, 0.0), ps_c)
                    scored.append((c, ps[c]))

                if not scored:
                    continue

                # Keep the top-k children of u by path score.
                scored.sort(key=lambda x: x[1], reverse=True)
                keep = [c for c, _ in scored[: self.topk_per_parent]]

                # NOTE: a class kept under several frontier parents is appended
                # to next_frontier more than once and re-expanded; the repeated
                # scorer calls are answered from its cache.
                for c in keep:
                    visited.add(c)
                    next_frontier.append(c)

            frontier = next_frontier

        return visited

    def _conf_required_class_ids(self, c: int) -> Set[int]:
        """Class ids whose similarity conf(D, c) needs: c, its real parents and its siblings."""
        req = {c}
        for p in self.tax.parents(c):
            if p != self.tax.PSEUDO_ROOT and 0 <= p < self.tax.num_classes:
                req.add(p)
        for s in self.tax.siblings(c):
            req.add(s)
        return req

    def compute_conf_for_doc_candidates(self, pid: str, doc_text: str, candidates: Set[int]) -> Dict[int, float]:
        """
        Return ``{c: conf(D, c)}`` for every candidate, where
        conf(D, c) = sim(D, c) - max(sim(D, c') for c' in real parents and siblings of c).
        The max is 0.0 when c has neither. All required similarities are
        requested from the scorer in a single call.
        """
        if not candidates:
            return {}

        needed = set()
        for c in candidates:
            needed |= self._conf_required_class_ids(c)
        needed_list = sorted(list(needed))

        sim_map = self.scorer.score_pid_to_class_ids(pid, doc_text, needed_list, self.class_names)

        conf_map = {}
        for c in candidates:
            sim_c = sim_map[c]
            # Competitors: real parents (pseudo-root excluded) plus siblings.
            comp = []
            for p in self.tax.parents(c):
                if p != self.tax.PSEUDO_ROOT and 0 <= p < self.tax.num_classes:
                    comp.append(p)
            comp.extend(self.tax.siblings(c))
            max_other = max((sim_map[x] for x in comp), default=0.0)
            conf_map[c] = float(sim_c - max_other)
        return conf_map

    def mine_core_classes_over_corpus(
        self,
        pid2text: Dict[str, str],
        show_progress: bool = True,
    ) -> Tuple[Dict[str, Set[int]], Dict[int, float]]:
        """
        Mine core classes for every document in ``pid2text``.

        Pass 1 computes candidates and their confidences per document. A
        class's threshold is the median confidence over the documents where
        it was a candidate. Classes that were never proposed get +inf, so
        they can never be core. Pass 2 keeps a candidate when its confidence
        is >= its class threshold.

        Returns:
          pid2core: {pid: set(core classes)}
          median_conf: {class_id: median threshold}
        """
        class_conf_pool: Dict[int, List[float]] = defaultdict(list)
        pid2conf: Dict[str, Dict[int, float]] = {}

        # Pass 1: candidates + confidences for each document.
        iterator = tqdm(pid2text.items(), desc="Step2-Pass1 (cand/conf)") if show_progress else pid2text.items()
        for pid, text in iterator:
            cand = self.select_candidates_for_doc(pid, text)
            conf_map = self.compute_conf_for_doc_candidates(pid, text, cand)
            pid2conf[pid] = conf_map
            for c, v in conf_map.items():
                class_conf_pool[c].append(v)

        # Per-class median threshold (inf when the class was never a candidate).
        median_conf: Dict[int, float] = {}
        for c in range(self.tax.num_classes):
            vals = class_conf_pool.get(c, [])
            median_conf[c] = float(statistics.median(vals)) if vals else float("inf")

        # Pass 2: keep candidates reaching their class threshold.
        pid2core: Dict[str, Set[int]] = {}
        iterator2 = tqdm(pid2conf.items(), desc="Step2-Pass2 (core)") if show_progress else pid2conf.items()
        for pid, conf_map in iterator2:
            core = {c for c, v in conf_map.items() if v >= median_conf[c]}
            pid2core[pid] = core

        return pid2core, median_conf

# ============================================================
# Step 3: Core Class Guided Classifier Training
# ============================================================
def build_pos_neg_sets_for_doc(core: Set[int], tax: Taxonomy, num_classes: int) -> Tuple[Set[int], Set[int]]:
    """
    Derive Step-3 training labels for one document from its core classes.

    positive = core classes plus their direct (real) parents.
    negative = every class that is neither positive nor a direct child of a
    core class; those children stay unlabeled.
    An empty core yields two empty sets.
    """
    if not core:
        return set(), set()

    def _is_real(node: int) -> bool:
        return node != tax.PSEUDO_ROOT and 0 <= node < num_classes

    parents_of_core = {up for c in core for up in tax.parents(c) if _is_real(up)}
    children_of_core = {kid for c in core for kid in tax.children(c) if _is_real(kid)}

    positives = set(core) | parents_of_core
    negatives = set(range(num_classes)).difference(positives, children_of_core)
    return positives, negatives

class CoreGuidedTrainDataset(Dataset):
    """
    Step-3 training examples: one ``(text, positives, negatives)`` triple per
    document that has at least one core class.

    Negatives are subsampled once, at construction, to at most
    ``neg_sample_size`` classes using an RNG seeded with ``seed``; ``None``
    keeps all of them. Each item is a dict with the text, a dense 0/1
    target ``y`` and a 0/1 ``mask`` marking the classes that enter the loss.
    """

    def __init__(
        self,
        pid2text: Dict[str, str],
        pid2core: Dict[str, Set[int]],
        tax: Taxonomy,
        num_classes: int,
        neg_sample_size: Optional[int] = 256,
        seed: int = 42,
    ):
        sampler = random.Random(seed)
        examples = []
        for pid, text in pid2text.items():
            core = pid2core.get(pid, set())
            if not core:
                continue
            positives, negatives = build_pos_neg_sets_for_doc(core, tax, num_classes)
            needs_subsample = neg_sample_size is not None and len(negatives) > neg_sample_size
            if needs_subsample:
                # Sort first so the sample depends only on the seed.
                negatives = set(sampler.sample(sorted(negatives), neg_sample_size))
            examples.append((text, positives, negatives))
        self.items = examples
        self.num_classes = num_classes

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx: int):
        text, positives, negatives = self.items[idx]
        target = torch.zeros(self.num_classes, dtype=torch.float32)
        weight = torch.zeros(self.num_classes, dtype=torch.float32)
        if positives:
            pos_index = torch.tensor(list(positives), dtype=torch.long)
            target[pos_index] = 1.0
            weight[pos_index] = 1.0
        if negatives:
            neg_index = torch.tensor(list(negatives), dtype=torch.long)
            target[neg_index] = 0.0
            weight[neg_index] = 1.0
        return {"text": text, "y": target, "mask": weight}

@dataclass
class TrainBatch:
    """Collated Step-3 minibatch built by the collate function from make_collate_fn."""

    # Token ids from the document tokenizer, padded to the longest text in the batch.
    input_ids: torch.Tensor
    # Attention mask returned by the tokenizer together with input_ids.
    attention_mask: torch.Tensor
    # Stacked 0/1 targets, one row per document in the batch.
    y: torch.Tensor
    # Stacked 0/1 loss weights: 1 where the class is a labeled positive or a
    # sampled negative for that document, 0 where it is ignored.
    mask: torch.Tensor

def make_collate_fn(tokenizer, max_length: int = 256):
    """
    Build a DataLoader ``collate_fn`` for CoreGuidedTrainDataset items.

    The returned callable tokenizes all texts of the batch together
    (truncated to ``max_length``, padded to the longest) and stacks the
    per-item ``y`` / ``mask`` vectors row-wise into a TrainBatch.
    """
    def _collate(examples: List[dict]) -> TrainBatch:
        encoded = tokenizer(
            [example["text"] for example in examples],
            truncation=True,
            padding=True,
            max_length=max_length,
            return_tensors="pt",
        )
        targets = torch.stack([example["y"] for example in examples], dim=0)
        loss_mask = torch.stack([example["mask"] for example in examples], dim=0)
        return TrainBatch(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            y=targets,
            mask=loss_mask,
        )

    return _collate

class GCNLayer(nn.Module):
    """
    One graph-convolution layer without bias: ``relu(A_norm @ (H W))``.

    Both operands of ``torch.sparse.mm`` are cast to float32. Their dtypes
    then agree even when the projection ran under fp16 autocast or the
    adjacency was built in another precision.
    """

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=False)

    def forward(self, H: torch.Tensor, A_norm: torch.Tensor) -> torch.Tensor:
        """
        Args:
          H: node features, [num_nodes, in_dim].
          A_norm: sparse normalized adjacency, [num_nodes, num_nodes].
        Returns:
          float32 tensor of shape [num_nodes, out_dim].
        """
        HW = self.W(H).float()
        # Multiply with the float32 copy. The original computed this cast but
        # then passed the uncast A_norm to sparse.mm.
        A = A_norm.float()
        return F.relu(torch.sparse.mm(A, HW))

def build_normalized_adjacency(
    tax: Taxonomy,
    num_classes: int,
    add_self_loops: bool = True,
    undirected: bool = True,
    device: str = "cpu"
) -> torch.Tensor:
    """
    Build the normalized sparse adjacency ``D^-1/2 A D^-1/2`` of the class graph.

    Edges come from ``tax.parent2children`` restricted to real class ids;
    the pseudo-root and out-of-range ids are dropped. With ``undirected``
    every edge is mirrored, and with ``add_self_loops`` each class gets an
    (i, i) entry. Degrees are row counts, clamped to at least 1.

    Returns:
      coalesced sparse COO float32 tensor of shape [num_classes, num_classes].
    """
    valid = range(num_classes)
    edge_set = set()
    for parent, kids in tax.parent2children.items():
        if parent == tax.PSEUDO_ROOT or parent not in valid:
            continue
        for kid in kids:
            if kid == tax.PSEUDO_ROOT or kid not in valid:
                continue
            edge_set.add((parent, kid))
            if undirected:
                edge_set.add((kid, parent))
    if add_self_loops:
        edge_set.update((node, node) for node in valid)

    pairs = list(edge_set)
    src = torch.tensor([a for a, _ in pairs], dtype=torch.long, device=device)
    dst = torch.tensor([b for _, b in pairs], dtype=torch.long, device=device)

    degree = torch.bincount(src, minlength=num_classes).to(torch.float32).clamp(min=1.0)
    weights = 1.0 / torch.sqrt(degree[src] * degree[dst])

    indices = torch.stack((src, dst), dim=0)
    return torch.sparse_coo_tensor(
        indices, weights, (num_classes, num_classes), device=device
    ).coalesce()

def compute_k_hop_ego_sets(tax: Taxonomy, num_classes: int, k: int = 1) -> List[List[int]]:
    """
    For every class, list the classes within ``k`` undirected hops (itself
    included), sorted ascending.

    The pseudo-root and out-of-range ids are ignored, so top-level classes
    are not linked to each other through the root.
    """
    neighbours: List[List[int]] = [[] for _ in range(num_classes)]
    for parent, kids in tax.parent2children.items():
        if parent == tax.PSEUDO_ROOT or not 0 <= parent < num_classes:
            continue
        for kid in filter(lambda node: 0 <= node < num_classes, kids):
            neighbours[parent].append(kid)
            neighbours[kid].append(parent)

    def _ball(center: int) -> List[int]:
        # Breadth-first expansion, one ring of newly reached nodes per hop.
        reached = {center}
        ring = {center}
        for _ in range(k):
            ring = {v for u in ring for v in neighbours[u]} - reached
            if not ring:
                break
            reached |= ring
        return sorted(reached)

    return [_ball(center) for center in range(num_classes)]

class ClassEncoderGNN(nn.Module):
    """
    Step-3 class encoder: GCN over the taxonomy followed by ego-graph pooling.

    Initial node features ``H0`` are the mean input-embedding vectors of each
    class name's tokens, or a zero vector when a name tokenizes to nothing.
    They are stored as a frozen, non-persistent buffer. ``forward`` runs the
    GCN layers over the normalized taxonomy adjacency. It then represents
    class j by the mean hidden state of its ``ego_k``-hop neighbourhood,
    computed for all classes at once with one precomputed sparse averaging
    matrix.
    """

    def __init__(
        self,
        class_names: List[str],
        bert_tokenizer: AutoTokenizer,
        bert_embedding_weight: torch.Tensor,
        tax: Taxonomy,
        num_classes: int,
        hidden_dim: int = 768,
        gnn_layers: int = 2,
        ego_k: int = 1,
        device: str = "cpu",
    ):
        super().__init__()
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim
        self.device = device
        self.tokenizer = bert_tokenizer

        with torch.no_grad():
            H0 = []
            for name in class_names:
                toks = self.tokenizer(name, add_special_tokens=False, return_tensors="pt")["input_ids"].squeeze(0)
                if toks.numel() == 0:
                    vec = torch.zeros(hidden_dim)
                else:
                    vec = bert_embedding_weight[toks].mean(dim=0).cpu()
                H0.append(vec)
            H0 = torch.stack(H0, dim=0)
        self.register_buffer("H0", H0.to(device), persistent=False)

        self.register_buffer("A_norm", build_normalized_adjacency(tax, num_classes, device=device), persistent=False)
        self.ego_sets = compute_k_hop_ego_sets(tax, num_classes, k=ego_k)
        # Row j holds 1/|ego(j)| at every member of ego(j), so ego_avg @ H is
        # the per-class ego mean. This replaces a per-class Python loop.
        self.register_buffer(
            "ego_avg",
            self._build_ego_average_matrix(self.ego_sets, num_classes, device),
            persistent=False,
        )

        self.layers = nn.ModuleList([GCNLayer(hidden_dim, hidden_dim) for _ in range(gnn_layers)])

    @staticmethod
    def _build_ego_average_matrix(ego_sets: List[List[int]], num_classes: int, device: str = "cpu") -> torch.Tensor:
        """Sparse [num_classes, num_classes] float32 matrix averaging each class's ego set."""
        rows: List[int] = []
        cols: List[int] = []
        vals: List[float] = []
        for center, members in enumerate(ego_sets):
            weight = 1.0 / len(members)  # ego sets always contain the center
            for member in members:
                rows.append(center)
                cols.append(member)
                vals.append(weight)
        return torch.sparse_coo_tensor(
            torch.tensor([rows, cols], dtype=torch.long),
            torch.tensor(vals, dtype=torch.float32),
            size=(num_classes, num_classes),
        ).coalesce().to(device)

    def forward(self) -> torch.Tensor:
        """Return class embeddings, float32 [num_classes, hidden_dim]."""
        H = self.H0
        for layer in self.layers:
            H = layer(H, self.A_norm)
        return torch.sparse.mm(self.ego_avg, H.float())

class TaxoClassModel(nn.Module):
    """
    Document-class matching model shared by Step 3 and Step 4.

    For document i and class j:
        score_ij = D_i . (C_j B)
        p_ij     = sigmoid(score_ij)          (default)
                 = sigmoid(exp(score_ij))     (use_exp_in_matching=True)
    D_i is the first-token hidden state of the document encoder, C_j the
    ClassEncoderGNN embedding, and B a learned hidden x hidden matrix
    (Xavier-initialised).
    """
    def __init__(
        self,
        num_classes: int,
        class_names: List[str],
        tax: Taxonomy,
        doc_model_name: str = "bert-base-uncased",
        gnn_layers: int = 2,
        ego_k: int = 1,
        use_exp_in_matching: bool = False,
        device: str = "cpu",
    ):
        super().__init__()
        self.num_classes = num_classes
        self.use_exp_in_matching = use_exp_in_matching
        self.device = device

        self.doc_tokenizer = AutoTokenizer.from_pretrained(doc_model_name, use_fast=True)
        self.doc_encoder = AutoModel.from_pretrained(doc_model_name).to(device)

        hidden_dim = self.doc_encoder.config.hidden_size
        self.B = nn.Parameter(torch.empty(hidden_dim, hidden_dim))
        nn.init.xavier_uniform_(self.B)

        # Class-name features for the GNN come from the document encoder's
        # (detached) input word-embedding table.
        bert_emb = self.doc_encoder.get_input_embeddings().weight.detach()
        self.class_encoder = ClassEncoderGNN(
            class_names=class_names,
            bert_tokenizer=self.doc_tokenizer,
            bert_embedding_weight=bert_emb,
            tax=tax,
            num_classes=num_classes,
            hidden_dim=hidden_dim,
            gnn_layers=gnn_layers,
            ego_k=ego_k,
            device=device,
        )

    def encode_docs(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return the first-token hidden state of each document, [batch, hidden]."""
        out = self.doc_encoder(input_ids=input_ids, attention_mask=attention_mask)
        return out.last_hidden_state[:, 0, :]

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return float32 class probabilities of shape [batch, num_classes]."""
        D = self.encode_docs(input_ids, attention_mask)  # [B, dim]; may be fp16 under AMP

        # Run the class encoder AND the matching with autocast disabled:
        # - the GCN's sparse.mm must not receive Half inputs;
        # - under autocast a matmul returns fp16 even for float32 operands, and
        #   the sigmoid of fp16 scores saturates to exactly 0/1, which turns the
        #   downstream log-based losses into inf/NaN.
        # Checking D.is_cuda (instead of self.device == "cuda") also covers "cuda:N".
        autocast_device = "cuda" if D.is_cuda else "cpu"
        with torch.autocast(device_type=autocast_device, enabled=False):
            C = self.class_encoder().float()                  # [N, dim]
            scores = D.float() @ (C @ self.B.float()).t()     # [B, N]
            if self.use_exp_in_matching:
                return torch.sigmoid(torch.exp(scores))
            return torch.sigmoid(scores)


def core_guided_loss(p: torch.Tensor, y: torch.Tensor, mask: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Masked binary cross-entropy used in Step 3.

        loss = mean_i [ -sum_j mask_ij * (y_ij log p_ij + (1 - y_ij) log(1 - p_ij))
                        / max(sum_j mask_ij, 1) ]

    Args:
      p: predicted probabilities, [batch, num_classes].
      y: 0/1 targets, same shape.
      mask: 1 where class j is labeled (positive or negative) for doc i, else 0.
      eps: clamp margin. It is raised to the float32 machine epsilon when
        smaller, because 1 - 1e-8 rounds to exactly 1.0 in float32 (and fp16).
        A saturated p would then give log(1 - p) = -inf and NaN via 0 * inf.
    Returns:
      scalar float32 loss.
    """
    p = p.float()
    y = y.float()
    mask = mask.float()
    margin = max(eps, torch.finfo(torch.float32).eps)
    p = torch.clamp(p, margin, 1.0 - margin)
    per_label = y * torch.log(p) + (1.0 - y) * torch.log1p(-p)
    per_label = per_label * mask
    denom = torch.clamp(mask.sum(dim=1), min=1.0)
    loss_per_doc = -per_label.sum(dim=1) / denom
    return loss_per_doc.mean()

def train_step3(
    pid2text_train: Dict[str, str],
    pid2core: Dict[str, Set[int]],
    class_names: List[str],
    tax: Taxonomy,
    device: str,
) -> TaxoClassModel:
    """
    Step 3: train a fresh TaxoClassModel on core-class-derived labels.

    Documents with a non-empty core set become examples (CoreGuidedTrainDataset)
    and the objective is the masked BCE ``core_guided_loss``. With USE_AMP on
    a CUDA device, the forward pass runs under fp16 autocast with a
    GradScaler. The loss itself is always computed in float32, outside
    autocast.

    Raises:
      ValueError: if no training document has a non-empty core class set.
    """
    model = TaxoClassModel(
        num_classes=len(class_names),
        class_names=class_names,
        tax=tax,
        doc_model_name=DOC_MODEL_NAME,
        gnn_layers=GNN_LAYERS,
        ego_k=EGO_K,
        use_exp_in_matching=USE_EXP_IN_MATCHING,
        device=device,
    ).to(device)

    dataset = CoreGuidedTrainDataset(
        pid2text=pid2text_train,
        pid2core=pid2core,
        tax=tax,
        num_classes=len(class_names),
        neg_sample_size=STEP3_NEG_SAMPLE_SIZE,
        seed=42,
    )
    if len(dataset) == 0:
        # DataLoader(shuffle=True) would otherwise fail with an opaque sampler error.
        raise ValueError("Step 3: no training document has a non-empty core class set.")
    loader = DataLoader(
        dataset,
        batch_size=STEP3_BATCH_SIZE,
        shuffle=True,
        collate_fn=make_collate_fn(model.doc_tokenizer, max_length=DOC_MAX_LENGTH),
    )

    optimizer = AdamW(model.parameters(), lr=STEP3_LR, weight_decay=0.01)
    # str(...).startswith so that "cuda:N" devices also get AMP.
    use_amp = USE_AMP and str(device).startswith("cuda")
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    model.train()
    for ep in range(1, STEP3_EPOCHS + 1):
        pbar = tqdm(loader, desc=f"Step3 (epoch {ep}/{STEP3_EPOCHS})")
        for batch in pbar:
            input_ids = batch.input_ids.to(device)
            attention_mask = batch.attention_mask.to(device)
            y = batch.y.to(device)
            m = batch.mask.to(device)

            optimizer.zero_grad(set_to_none=True)
            # Only the forward pass runs under autocast; the log-based loss is
            # evaluated in float32 outside it.
            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
                p = model(input_ids, attention_mask)
            loss = core_guided_loss(p.float(), y, m)

            # A disabled GradScaler is a pass-through: scale() returns the loss
            # unchanged, step() calls optimizer.step(), update() does nothing.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            pbar.set_postfix(loss=float(loss.detach().cpu()))
    return model

# ============================================================
# Step 4: Multi-label Self-Training
# ============================================================
class UnlabeledTextDataset(Dataset):
    """
    Map-style dataset over a ``{pid: text}`` mapping.

    Items are ``{"pid": ..., "text": ...}`` records, in the mapping's
    iteration order.
    """

    def __init__(self, pid2text: Dict[str, str]):
        records = list(pid2text.items())
        self.pids = [pid for pid, _ in records]
        self.texts = [text for _, text in records]

    def __len__(self):
        return len(self.pids)

    def __getitem__(self, idx: int):
        return dict(pid=self.pids[idx], text=self.texts[idx])

@dataclass
class UnlabeledBatch:
    """Collated Step-4 minibatch built by the collate function from make_unlabeled_collate_fn (pids are not carried)."""

    # Token ids from the document tokenizer, padded to the longest text in the batch.
    input_ids: torch.Tensor
    # Attention mask returned by the tokenizer together with input_ids.
    attention_mask: torch.Tensor

def make_unlabeled_collate_fn(tokenizer, max_length: int = 256):
    """
    Build a DataLoader ``collate_fn`` for UnlabeledTextDataset records.

    The returned callable tokenizes the batch texts (truncated to
    ``max_length``, padded to the longest) and wraps ids and mask in an
    UnlabeledBatch; pids are dropped.
    """
    def _collate(records: List[dict]) -> UnlabeledBatch:
        encoded = tokenizer(
            [record["text"] for record in records],
            truncation=True,
            padding=True,
            max_length=max_length,
            return_tensors="pt",
        )
        return UnlabeledBatch(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
        )

    return _collate

@torch.no_grad()
def compute_classwise_sums(model: TaxoClassModel, loader: DataLoader, device: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Accumulate per-class prediction mass over every document in ``loader``:
      sum_p[j]     = sum_i p_ij
      sum_1mp[j]   = sum_i (1 - p_ij)
    These are the normalizers of the Q refinement formula (refine_q).

    The model is switched to eval mode and left there. Predictions are
    accumulated in float32, so half-precision outputs cannot overflow over a
    large corpus. Both sums are clamped to >= 1e-8 so they are safe divisors.

    Raises:
      ValueError: if ``loader`` yields no batches (the sums would be undefined).
    """
    model.eval()
    sum_p: Optional[torch.Tensor] = None
    sum_1mp: Optional[torch.Tensor] = None

    for batch in tqdm(loader, desc="Step4-Q update (sums)", leave=False):
        input_ids = batch.input_ids.to(device)
        attention_mask = batch.attention_mask.to(device)
        p = model(input_ids, attention_mask).float()  # [B, C]
        batch_p = p.sum(dim=0)
        batch_1mp = (1.0 - p).sum(dim=0)
        if sum_p is None:
            sum_p, sum_1mp = batch_p, batch_1mp
        else:
            sum_p = sum_p + batch_p
            sum_1mp = sum_1mp + batch_1mp

    if sum_p is None:
        raise ValueError("compute_classwise_sums: loader produced no batches")

    # avoid division by zero
    sum_p = torch.clamp(sum_p, min=1e-8)
    sum_1mp = torch.clamp(sum_1mp, min=1e-8)
    return sum_p.detach(), sum_1mp.detach()

def refine_q(p: torch.Tensor, sum_p: torch.Tensor, sum_1mp: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Paper formula:
      q_ij = (p_ij^2 / sum_i p_ij) / ((p_ij^2 / sum_i p_ij) + ((1-p_ij)^2 / sum_i (1-p_ij)))
    p: [B,C], sum_p: [C], sum_1mp: [C]

    Computed in float32. The clamp margin on p and q is raised to the float32
    machine epsilon when ``eps`` is smaller: 1 - 1e-8 rounds to exactly 1.0
    in float32, so a saturated q would otherwise reach 1.0 and make the KL
    objective NaN. ``eps`` itself still guards the denominator.
    """
    p = p.float()
    margin = max(eps, torch.finfo(torch.float32).eps)
    p = torch.clamp(p, margin, 1.0 - margin)
    a = (p * p) / sum_p
    b = ((1.0 - p) * (1.0 - p)) / sum_1mp
    q = a / (a + b + eps)
    return torch.clamp(q, margin, 1.0 - margin)

def bernoulli_kl(q: torch.Tensor, p: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    KL(Bern(q) || Bern(p)) = q log(q/p) + (1-q) log((1-q)/(1-p)), averaged
    over all elements.

    Computed in float32. The clamp margin is raised to the float32 machine
    epsilon when ``eps`` is smaller. With 1 - 1e-8 (which rounds to 1.0),
    saturated inputs give (1-q) * log(0/0) = NaN.
    """
    q = q.float()
    p = p.float()
    margin = max(eps, torch.finfo(torch.float32).eps)
    q = torch.clamp(q, margin, 1.0 - margin)
    p = torch.clamp(p, margin, 1.0 - margin)
    return (q * torch.log(q / p) + (1.0 - q) * torch.log((1.0 - q) / (1.0 - p))).mean()

def self_train_step4(
    model: TaxoClassModel,
    pid2text_unlabeled: Dict[str, str],
    device: str,
) -> TaxoClassModel:
    """
    Step 4: multi-label self-training of ``model`` on unlabeled documents.

    At the epochs selected by Q_UPDATE_EVERY_EPOCH (and always before the
    first one), the class-wise prediction sums over the whole unlabeled set
    are recomputed. Every batch then builds the sharpened target
    ``q = refine_q(p)`` and minimizes KL(Bern(q) || Bern(p)). The target is
    detached: it is a fixed pseudo-label, and gradients must flow only
    through ``p``. The forward pass uses fp16 autocast when USE_AMP is set
    and the device is CUDA; q and the loss are computed in float32.

    Raises:
      ValueError: if ``pid2text_unlabeled`` is empty.
    """
    dataset = UnlabeledTextDataset(pid2text_unlabeled)
    if len(dataset) == 0:
        # DataLoader(shuffle=True) would otherwise fail with an opaque sampler error.
        raise ValueError("Step 4: no unlabeled documents to self-train on.")

    collate = make_unlabeled_collate_fn(model.doc_tokenizer, max_length=DOC_MAX_LENGTH)
    loader = DataLoader(
        dataset,
        batch_size=STEP4_BATCH_SIZE,
        shuffle=True,
        collate_fn=collate,
    )
    # Fixed-order loader used only to recompute the class-wise sums.
    eval_loader = DataLoader(
        dataset,
        batch_size=STEP4_BATCH_SIZE,
        shuffle=False,
        collate_fn=collate,
    )

    optimizer = AdamW(model.parameters(), lr=STEP4_LR, weight_decay=0.01)
    # str(...).startswith so that "cuda:N" devices also get AMP.
    use_amp = USE_AMP and str(device).startswith("cuda")
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    sum_p, sum_1mp = None, None

    for ep in range(1, STEP4_EPOCHS + 1):
        refresh = Q_UPDATE_EVERY_EPOCH and ((ep - 1) % Q_UPDATE_EVERY_EPOCH == 0)
        if sum_p is None or refresh:
            # recompute class-wise sums on current model
            sum_p, sum_1mp = compute_classwise_sums(model, eval_loader, device)
            sum_p, sum_1mp = sum_p.to(device), sum_1mp.to(device)

        model.train()
        pbar = tqdm(loader, desc=f"Step4 (epoch {ep}/{STEP4_EPOCHS})")
        for batch in pbar:
            input_ids = batch.input_ids.to(device)
            attention_mask = batch.attention_mask.to(device)

            optimizer.zero_grad(set_to_none=True)

            with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=use_amp):
                p = model(input_ids, attention_mask)
            p = p.float()
            # Target distribution: a constant for this step (no gradient through q).
            q = refine_q(p.detach(), sum_p, sum_1mp)
            loss = bernoulli_kl(q, p)

            # A disabled GradScaler is a pass-through: scale() returns the loss
            # unchanged, step() calls optimizer.step(), update() does nothing.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            pbar.set_postfix(loss=float(loss.detach().cpu()))

    return model

# ============================================================
# Inference + label selection
# ============================================================
@torch.no_grad()
def predict_proba(
    model: TaxoClassModel,
    pid2text: Dict[str, str],
    device: str,
    batch_size: int = 16,
    max_length: int = 256,
) -> Tuple[List[str], np.ndarray]:
    """
    Score every document against every class with ``model`` in eval mode.

    Returns:
      (pids, probs): pids in ``pid2text`` iteration order, and a NumPy array
      of shape [len(pids), num_classes] whose row i belongs to pids[i]. An
      empty input yields a (0, num_classes) float32 array.
    """
    model.eval()
    pids = list(pid2text)
    texts = [pid2text[pid] for pid in pids]

    batch_probs = []
    for offset in tqdm(range(0, len(texts), batch_size), desc="Predicting"):
        encoded = model.doc_tokenizer(
            texts[offset:offset + batch_size],
            truncation=True,
            padding=True,
            max_length=max_length,
            return_tensors="pt",
        ).to(device)
        scores = model(encoded["input_ids"], encoded["attention_mask"])
        batch_probs.append(scores.detach().cpu().numpy())

    if not batch_probs:
        return pids, np.zeros((0, model.num_classes), dtype=np.float32)
    return pids, np.vstack(batch_probs)

def probs_to_labels(
    probs_row: np.ndarray,
    topk: int = 3,
    threshold: Optional[float] = None,
    ensure_min: int = 1
) -> List[int]:
    """
    Turn one row of class probabilities into an ascending list of class ids.

    * ``threshold`` given: keep classes with prob >= threshold, most probable
      first, capped at ``topk`` (no cap when ``topk`` is None). If fewer than
      ``ensure_min`` survive, fall back to the ``ensure_min`` most probable.
    * ``threshold`` None: take the ``max(ensure_min, topk)`` most probable.
    """
    ranking = np.argsort(-probs_row)  # class ids, most probable first
    if threshold is None:
        picked = ranking[:max(ensure_min, topk)]
    else:
        picked = [i for i in ranking if probs_row[i] >= threshold]
        if topk is not None:
            picked = picked[:topk]
        if len(picked) < ensure_min:
            picked = ranking[:ensure_min]
    return sorted(int(i) for i in picked)

# ============================================================
# Main pipeline
# ============================================================
def _build_taxo_model(class_names: List[str], tax: "Taxonomy") -> "TaxoClassModel":
    """Construct an untrained TaxoClassModel from the module-level config, on DEVICE."""
    return TaxoClassModel(
        num_classes=len(class_names),
        class_names=class_names,
        tax=tax,
        doc_model_name=DOC_MODEL_NAME,
        gnn_layers=GNN_LAYERS,
        ego_k=EGO_K,
        use_exp_in_matching=USE_EXP_IN_MATCHING,
        device=DEVICE,
    ).to(DEVICE)


def _load_taxo_model(path: str, class_names: List[str], tax: "Taxonomy", step_name: str) -> "TaxoClassModel":
    """Build a fresh model and restore the state_dict saved at ``path``."""
    model = _build_taxo_model(class_names, tax)
    model.load_state_dict(torch.load(path, map_location=DEVICE))
    print(f"[Cache] Loaded {step_name} model from {path}")
    return model


def _load_or_mine_core_classes(pid2text_train: Dict[str, str], class_names: List[str], tax: "Taxonomy"):
    """Steps 1+2: return ``(pid2core, median_conf)``.

    The result is read from the pickle cache when both cache files exist.
    Otherwise it is mined with the NLI entailment scorer and then cached.
    """
    if os.path.exists(CORE_CACHE_PATH) and os.path.exists(MEDIAN_CACHE_PATH):
        # NOTE: pickle.load can execute arbitrary code; only load caches this script wrote.
        with open(CORE_CACHE_PATH, "rb") as f:
            pid2core = pickle.load(f)
        with open(MEDIAN_CACHE_PATH, "rb") as f:
            median_conf = pickle.load(f)
        print(f"[Cache] Loaded pid2core ({len(pid2core)}) and median_conf from {CACHE_DIR}")
        return pid2core, median_conf

    entail_scorer = EntailmentScorer(
        model_name=NLI_MODEL_NAME,
        template=NLI_TEMPLATE,
        device=DEVICE,
        batch_size=NLI_BATCH_SIZE,
        max_length=NLI_MAX_LENGTH,
        use_fp16=NLI_FP16,
    )
    miner = CoreClassMiner(
        taxonomy=tax,
        entailment_scorer=entail_scorer,
        class_names=class_names,
        topk_per_parent=TOPK_PER_PARENT,
        max_depth=MAX_DEPTH,
        min_sim_to_expand=MIN_SIM_TO_EXPAND,
    )
    pid2core, median_conf = miner.mine_core_classes_over_corpus(pid2text_train, show_progress=True)

    # The NLI model is only needed for mining. Drop our references now so it is
    # not kept alive (presumably on DEVICE) while Steps 3-4 train; previously
    # these were main() locals and lived until the very end of the run.
    del miner, entail_scorer
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    with open(CORE_CACHE_PATH, "wb") as f:
        pickle.dump(pid2core, f)
    with open(MEDIAN_CACHE_PATH, "wb") as f:
        pickle.dump(median_conf, f)
    print(f"[Cache] Saved pid2core and median_conf to {CACHE_DIR}")
    return pid2core, median_conf


def _load_or_train_step3(pid2text_train: Dict[str, str], class_names: List[str], tax: "Taxonomy") -> "TaxoClassModel":
    """Step 3: core-guided training. Mines core classes only when training is needed."""
    if os.path.exists(STEP3_MODEL_PATH):
        return _load_taxo_model(STEP3_MODEL_PATH, class_names, tax, "Step 3")

    pid2core, _median_conf = _load_or_mine_core_classes(pid2text_train, class_names, tax)
    model = train_step3(
        pid2text_train=pid2text_train,
        pid2core=pid2core,
        class_names=class_names,
        tax=tax,
        device=DEVICE,
    )
    torch.save(model.state_dict(), STEP3_MODEL_PATH)
    print(f"[Cache] Saved Step 3 model to {STEP3_MODEL_PATH}")
    return model


def _load_or_train_step4(pid2text_train: Dict[str, str], class_names: List[str], tax: "Taxonomy") -> "TaxoClassModel":
    """Step 4: multi-label self-training. Obtains the Step 3 model only when training is needed."""
    if os.path.exists(STEP4_MODEL_PATH):
        return _load_taxo_model(STEP4_MODEL_PATH, class_names, tax, "Step 4")

    step3_model = _load_or_train_step3(pid2text_train, class_names, tax)
    model = self_train_step4(step3_model, pid2text_train, device=DEVICE)
    torch.save(model.state_dict(), STEP4_MODEL_PATH)
    print(f"[Cache] Saved Step 4 model to {STEP4_MODEL_PATH}")
    return model


def _submission_label_bounds() -> Tuple[int, int]:
    """Clamp (TOPK_PRED, ENSURE_MIN_LABELS) so every row has MIN_LABELS..MAX_LABELS labels.

    ``probs_to_labels`` returns between ``ensure_min`` and ``topk`` labels
    (exactly ``topk`` without a threshold). Bounding both values keeps every
    submission row inside the Kaggle constraints.
    """
    topk = MAX_LABELS if TOPK_PRED is None else min(TOPK_PRED, MAX_LABELS)
    topk = max(topk, MIN_LABELS)
    ensure_min = min(max(ENSURE_MIN_LABELS, MIN_LABELS), topk)
    if (topk, ensure_min) != (TOPK_PRED, ENSURE_MIN_LABELS):
        print(
            f"[Warn] Clamped label bounds to Kaggle limits: "
            f"topk {TOPK_PRED} -> {topk}, ensure_min {ENSURE_MIN_LABELS} -> {ensure_min}"
        )
    return topk, ensure_min


def main():
    """Run TaxoClass Steps 1-4 (reusing cached artifacts) and write submission.csv.

    Each stage is loaded from its cache file when present and computed
    otherwise. Earlier stages are computed only when a later stage actually
    needs them: core-class mining is skipped when the Step 3 checkpoint
    exists, and Step 3 is skipped when the Step 4 checkpoint exists.
    """
    # Load data
    pid2text_train = load_corpus(TRAIN_CORPUS_PATH)
    pid2text_test = load_corpus(TEST_CORPUS_PATH)
    class_names = load_classes(CLASSES_PATH)
    if len(class_names) != NUM_CLASSES:
        # Explicit raise instead of `assert`, which `python -O` strips.
        raise ValueError(f"Expected {NUM_CLASSES} classes, got {len(class_names)}")

    parent2children, child2parents = load_taxonomy_edges(HIERARCHY_PATH)
    tax = Taxonomy(parent2children, child2parents, num_classes=len(class_names))

    # -------------------------
    # Steps 2-4 (Step 1's entailment scorer is only built if mining must run)
    # -------------------------
    step4_model = _load_or_train_step4(pid2text_train, class_names, tax)

    # -------------------------
    # Predict on test set
    # -------------------------
    pid_list_test, probs_test = predict_proba(
        step4_model,
        pid2text_test,
        device=DEVICE,
        batch_size=16,
        max_length=DOC_MAX_LENGTH,
    )

    # Convert probs -> label lists, respecting the MIN_LABELS..MAX_LABELS limits
    topk, ensure_min = _submission_label_bounds()
    all_pids, all_labels = [], []
    for pid, prob_row in zip(pid_list_test, probs_test):
        labels = probs_to_labels(
            prob_row,
            topk=topk,
            threshold=PRED_THRESHOLD,
            ensure_min=ensure_min
        )
        all_pids.append(pid)
        all_labels.append(labels)

    # ============================================================
    # --- Save submission file ---
    # (Same structure/format as your baseline; only labels source changed.)
    # ============================================================
    with open(SUBMISSION_PATH, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["id", "labels"])
        for pid, labels in zip(all_pids, all_labels):
            writer.writerow([pid, ",".join(map(str, labels))])

    print(f"Submission file saved to: {SUBMISSION_PATH}")
    print(f"Total samples: {len(all_pids)}, Classes per sample: {MIN_LABELS}-{MAX_LABELS}")

# Run the full pipeline only when executed directly, not when this module is imported.
if __name__ == "__main__":
    main()


Generating dummy predictions: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 19658/19658 [00:00<00:00, 190266.11it/s]

Dummy submission file saved to: submission.csv
Total samples: 19658, Classes per sample: 1-3



