# In [None]:
import os
import csv
import torch
import torch.nn as nn
import math
import numpy as np
import pandas as pd
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel
from collections import defaultdict

# --- 1. Settings ---
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hugging Face model id used to load the tokenizer in run_inference.
MODEL_NAME = "microsoft/deberta-v3-base"
# Every text is padded/truncated to this many tokens.
MAX_LEN = 256
# Batch size of the inference DataLoader.
BATCH_SIZE = 32
# Exclusive upper bound applied to class ids read from the hierarchy file.
NUM_CLASSES = 531

# Pickled model file loaded by run_inference.
INPUT_MODEL_PATH = "saved_model_deberta_gnn/best_model_deberta.pt"
# Submission file written by run_inference.
OUTPUT_CSV = "submission.csv"

# Paths
BASE_DIR = "Amazon_products"
TEST_CORPUS_PATH = os.path.join(BASE_DIR, "test/test_corpus.txt")
HIERARCHY_PATH = os.path.join(BASE_DIR, "class_hierarchy.txt")

# --- 2. Model class definitions (must match the training code) ---
class GraphConvolution(nn.Module):
    """Graph-convolution layer computing ``adj @ (input @ weight) [+ bias]``.

    Args:
        in_features: Number of columns of the node-feature matrix.
        out_features: Number of output features per node.
        bias: When true, a learnable bias of size ``out_features`` is added.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        weight_storage = torch.empty((in_features, out_features), dtype=torch.float32)
        self.weight = nn.Parameter(weight_storage)
        if not bias:
            # Keep "bias" registered (as None) so attribute access still works.
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.empty((out_features,), dtype=torch.float32))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw weight and bias uniformly from ``[-1/sqrt(out), 1/sqrt(out)]``."""
        bound = 1. / math.sqrt(self.weight.size(1))
        with torch.no_grad():
            self.weight.uniform_(-bound, bound)
            if self.bias is not None:
                self.bias.uniform_(-bound, bound)

    def forward(self, input, adj):
        """Project the node features, then aggregate them through ``adj``.

        Args:
            input: 2-D node-feature matrix with ``in_features`` columns.
            adj: 2-D matrix multiplied on the left of the projected features.

        Returns:
            The aggregated features, with the bias added when one exists.
        """
        projected = input.mm(self.weight)
        aggregated = adj.mm(projected)
        if self.bias is None:
            return aggregated
        return aggregated + self.bias

class DebertaGCN(nn.Module):
    """Text encoder whose class logits come from GCN-refined label embeddings.

    Attribute names (``encoder``, ``dropout``, ``label_embedding``, ``gcn``,
    ``adj_matrix``) are unchanged so that checkpoints pickled from the
    training code can still be loaded with this class definition.

    Args:
        model_name: Identifier passed to ``AutoModel.from_pretrained``.
        num_classes: Number of label embeddings (rows of ``label_embedding``).
        hidden_dim: Width of each label embedding; it has to equal the
            encoder's hidden size because the two are multiplied in
            ``forward``.
        adj_matrix: Label-graph matrix passed to the GCN layer as ``adj``.
    """

    def __init__(self, model_name, num_classes, hidden_dim, adj_matrix):
        super(DebertaGCN, self).__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(0.1)
        # torch.FloatTensor(n, d) / torch.empty return uninitialised memory,
        # which may contain NaN/inf; give a freshly built model defined values.
        self.label_embedding = nn.Parameter(torch.empty(num_classes, hidden_dim))
        nn.init.xavier_uniform_(self.label_embedding)
        self.gcn = GraphConvolution(hidden_dim, hidden_dim)
        # Registered as a buffer so .to()/.cuda()/.half() move it together
        # with the parameters; persistent=False keeps state_dict keys as-is.
        self.register_buffer("adj_matrix", adj_matrix, persistent=False)

    def forward(self, input_ids, attention_mask):
        """Return one logit per class for every document in the batch.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.

        Returns:
            A ``(batch, num_classes)`` tensor of raw (pre-sigmoid) logits.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        last_hidden_state = outputs.last_hidden_state
        # The first token position is used as the document representation.
        doc_embedding = last_hidden_state[:, 0, :]
        doc_embedding = self.dropout(doc_embedding)

        # Propagate label embeddings over the label graph, then squash them.
        refined_label_embedding = self.gcn(self.label_embedding, self.adj_matrix)
        refined_label_embedding = torch.tanh(refined_label_embedding)

        # Dot product between each document and each refined label embedding.
        logits = torch.mm(doc_embedding, refined_label_embedding.t())
        return logits

# --- 3. Data loading ---
def load_test_corpus(path):
    """Read a tab-separated ``id<TAB>text`` corpus file.

    Lines without a tab, or with an empty id, are skipped. A record whose
    text is empty (``id<TAB>``) is kept with ``""`` as its text: stripping
    the whole line first would remove the trailing tab and silently drop the
    id, leaving it without a prediction in the output.

    Args:
        path: Path of the UTF-8 corpus file.

    Returns:
        A ``(pids, texts)`` pair of equally long lists of strings. Ids have
        surrounding whitespace removed; texts have trailing whitespace
        removed.
    """
    pids, texts = [], []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            # Only the first tab separates id from text; later tabs are text.
            pid, sep, text = line.rstrip("\r\n").partition("\t")
            pid = pid.strip()
            if not sep or not pid:
                continue
            pids.append(pid)
            texts.append(text.rstrip())
    return pids, texts

def load_hierarchy(path, num_classes=None):
    """Build a child -> parents map from a ``parent child`` edge-list file.

    Each line is split on whitespace; the first two fields are parsed as the
    parent id and the child id. Lines with fewer than two fields are skipped.
    Edges with an id outside ``[0, num_classes)`` are ignored; the lower
    bound matters because a negative id would otherwise pass the check and
    later index arrays from the end.

    Args:
        path: Path of the hierarchy file. A missing file yields an empty map.
        num_classes: Exclusive upper bound for valid class ids. Defaults to
            the module-level ``NUM_CLASSES``.

    Returns:
        A ``defaultdict(list)`` mapping each child id to its parent ids, in
        file order.

    Raises:
        ValueError: If one of the first two fields of a line is not an int.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES

    parents = defaultdict(list)
    if not os.path.exists(path):
        return parents

    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.split()
            if len(parts) < 2:
                continue
            parent, child = int(parts[0]), int(parts[1])
            if 0 <= parent < num_classes and 0 <= child < num_classes:
                parents[child].append(parent)
    return parents

class TestDataset(Dataset):
    """Dataset that tokenizes one raw text per item for inference.

    Args:
        texts: Indexable collection of texts; each item is passed through
            ``str()`` before tokenization.
        tokenizer: Callable tokenizer (e.g. a Hugging Face tokenizer).
        max_len: Length every sequence is padded/truncated to.
    """

    def __init__(self, texts, tokenizer, max_len):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of texts."""
        return len(self.texts)

    def __getitem__(self, item):
        """Tokenize the text at index ``item``.

        Returns:
            A dict with ``'input_ids'`` and ``'attention_mask'`` tensors,
            each with the leading batch dimension of size 1 removed.
        """
        text = str(self.texts[item])
        # Call the tokenizer directly: ``encode_plus`` is deprecated in
        # favour of ``__call__``, which accepts the same arguments.
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        # return_tensors='pt' yields (1, max_len); drop the batch dimension
        # so the DataLoader can stack items into (batch, max_len).
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
        }

# --- 4. Inference entry point ---
def _ancestors_within(class_id, parents_map, max_depth=5):
    """Return every ancestor of ``class_id`` reachable in ``max_depth`` hops.

    All parents of every visited node are followed, so classes with several
    parents contribute each of their ancestor chains.
    """
    found = set()
    frontier = [int(class_id)]
    for _ in range(max_depth):
        next_frontier = []
        for node in frontier:
            # .get() avoids inserting keys into a defaultdict.
            for parent in parents_map.get(node, ()):
                if parent not in found:
                    found.add(parent)
                    next_frontier.append(parent)
        if not next_frontier:
            break
        frontier = next_frontier
    return found


def _select_labels(sample_probs, parents_map, top_k=10, max_labels=3, max_depth=5):
    """Pick the final label ids for one sample, sorted by ascending id.

    Candidates are the ``top_k`` most probable classes plus their ancestors;
    the ``max_labels`` most probable candidates are returned.
    """
    num_scores = len(sample_probs)
    top_indices = sample_probs.argsort()[::-1][:top_k]

    candidates = set()
    for cid in top_indices:
        cid = int(cid)
        candidates.add(cid)
        for ancestor in _ancestors_within(cid, parents_map, max_depth):
            # Guard the probability lookup below against out-of-range ids.
            if 0 <= ancestor < num_scores:
                candidates.add(ancestor)

    # NOTE(review): candidates are ranked by their raw probability, and the
    # top_k classes are always candidates, so the ancestor expansion cannot
    # change which classes end up in the top `max_labels`. Confirm whether a
    # hierarchy-aware score was intended here.
    ranked = sorted(candidates, key=lambda c: float(sample_probs[c]), reverse=True)
    # No fallback is needed: every top_k class is already a candidate, so
    # fewer than `max_labels` results only happens when fewer classes exist.
    return sorted(ranked[:max_labels])


def run_inference():
    """Predict labels for the test corpus and write them to ``OUTPUT_CSV``.

    Steps: load the test corpus and the class hierarchy, tokenize the texts,
    load the pickled model from ``INPUT_MODEL_PATH``, run it batch by batch,
    select labels per sample, and write an ``id,labels`` CSV where
    ``labels`` is a comma-joined list of class ids.

    Returns:
        None. If the model file does not exist, an error is printed and the
        function returns without writing anything.
    """
    print("1. Loading Data...")
    test_pids, test_texts = load_test_corpus(TEST_CORPUS_PATH)
    parents_map = load_hierarchy(HIERARCHY_PATH)

    # Key fix from the original author: use_fast=False is passed to avoid a
    # tokenizer error.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)

    dataset = TestDataset(test_texts, tokenizer, MAX_LEN)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    print("2. Loading Trained GNN Model...")
    if not os.path.exists(INPUT_MODEL_PATH):
        print(f"❌ Error: Model not found at {INPUT_MODEL_PATH}")
        return

    # The checkpoint is a fully pickled model object, hence weights_only=False.
    # SECURITY NOTE: this unpickles arbitrary Python objects; only load
    # checkpoint files that come from a trusted source.
    model = torch.load(INPUT_MODEL_PATH, map_location=DEVICE, weights_only=False)
    model.to(DEVICE)
    model.eval()

    all_final_labels = []
    print(f"3. Predicting on {len(test_texts)} samples...")

    with torch.no_grad():
        for batch in tqdm(loader, desc="DeBERTa-GNN Inference"):
            input_ids = batch['input_ids'].to(DEVICE)
            attention_mask = batch['attention_mask'].to(DEVICE)

            logits = model(input_ids, attention_mask)
            probs = torch.sigmoid(logits).cpu().numpy()

            # Post-processing: choose the labels for every sample in the batch.
            all_final_labels.extend(
                _select_labels(sample_probs, parents_map) for sample_probs in probs
            )

    print(f"4. Saving to {OUTPUT_CSV}...")
    with open(OUTPUT_CSV, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["id", "labels"])
        writer.writerows(
            [pid, ",".join(map(str, labels))]
            for pid, labels in zip(test_pids, all_final_labels)
        )

    print("✅ Done! Ready to submit.")

# Run inference only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run_inference()