# In [None]:  (Jupyter cell marker left over from notebook export)
import os
import csv
import torch
import torch.nn as nn
import math
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel
from collections import defaultdict

# --- Configuration ---
# Use CUDA when torch can see a GPU; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Tokenization / batching, shared by the BERT and DeBERTa test loaders.
MAX_LEN, BATCH_SIZE = 256, 32
# Size of the label space; hierarchy edges touching ids >= NUM_CLASSES are
# discarded when the hierarchy file is loaded.
NUM_CLASSES = 531

# Pickled model checkpoints to ensemble, and the submission file to write.
PATH_BERT_MODEL = "saved_ensemble/bert_final.pt"
PATH_DEBERTA_MODEL = "saved_ensemble/deberta_final.pt"
OUTPUT_CSV = "submission.csv"

# Dataset layout (paths are relative to the working directory).
BASE_DIR = "Amazon_products"
TEST_CORPUS_PATH, HIERARCHY_PATH = (
    os.path.join(BASE_DIR, "test/test_corpus.txt"),
    os.path.join(BASE_DIR, "class_hierarchy.txt"),
)

# --- [Key fix] Unified model class definitions ---
# The checkpoints are loaded as whole pickled modules (torch.load with
# weights_only=False), which resolves these classes by name and restores
# their attributes directly. Keep class and attribute names unchanged.
class GraphConvolution(nn.Module):
    """Single graph-convolution layer computing ``adj @ (input @ weight) + bias``.

    Args:
        in_features: width of each input node feature vector.
        out_features: width of each output node feature vector.
        bias: when True, add a learnable bias of size ``out_features``;
            otherwise ``self.bias`` is registered as ``None``.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.FloatTensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw weight (and bias) uniformly from [-1/sqrt(out), 1/sqrt(out)]."""
        bound = 1. / math.sqrt(self.weight.size(1))
        with torch.no_grad():
            self.weight.uniform_(-bound, bound)
            if self.bias is not None:
                self.bias.uniform_(-bound, bound)

    def forward(self, input, adj):
        """Propagate node features over the graph.

        Args:
            input: 2-D node features of shape (num_nodes, in_features).
            adj: 2-D matrix whose column count equals num_nodes.

        Returns:
            Tensor of shape (adj.size(0), out_features).
        """
        projected = torch.mm(input, self.weight)
        aggregated = torch.mm(adj, projected)
        if self.bias is None:
            return aggregated
        return aggregated + self.bias

class GNN_Model(nn.Module):
    """Transformer encoder scored against GCN-refined label embeddings.

    Logits are ``doc_embedding @ tanh(gcn(label_embedding, adj_matrix)).T``,
    i.e. one column per row of the GCN output.

    NOTE: checkpoints are restored with ``torch.load(..., weights_only=False)``,
    which rebuilds this object from its pickled attributes without calling
    ``__init__``. Keep the attribute names below unchanged so ``forward`` still
    finds them on previously saved models.

    Args:
        model_name: model id or path passed to ``AutoModel.from_pretrained``.
        num_classes: number of rows of ``label_embedding``.
        hidden_dim: label-embedding width; must equal the encoder's output
            width for the final ``torch.mm`` to be valid.
        adj_matrix: 2-D label adjacency whose column count is ``num_classes``;
            presumably (num_classes, num_classes) — TODO confirm dense/sparse.
        model_type: ``"bert"`` uses ``pooler_output``; any other value uses the
            first token of ``last_hidden_state``.
    """

    def __init__(self, model_name, num_classes, hidden_dim, adj_matrix, model_type="bert"):
        super(GNN_Model, self).__init__()
        self.model_type = model_type
        self.encoder = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(0.1)
        self.label_embedding = nn.Parameter(torch.FloatTensor(num_classes, hidden_dim))
        nn.init.xavier_uniform_(self.label_embedding)
        self.gcn = GraphConvolution(hidden_dim, hidden_dim)
        # Plain attribute (not a parameter/buffer): kept as-is so pickled
        # models keep working; forward() aligns its device/dtype on use.
        self.adj_matrix = adj_matrix

    def forward(self, input_ids, attention_mask):
        """Return logits of shape (batch, num_labels) for one tokenized batch."""
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        if self.model_type == "bert":
            doc_embedding = outputs.pooler_output
        else:  # deberta (any non-"bert" type): hidden state of the first token
            doc_embedding = outputs.last_hidden_state[:, 0, :]

        doc_embedding = self.dropout(doc_embedding)
        # adj_matrix is not tracked by Module.to()/.cuda()/.half()/.double(),
        # so after such a call it would no longer match the GCN inputs and
        # torch.mm would raise. Tensor.to() is a no-op when already matching.
        adj = self.adj_matrix.to(
            device=self.label_embedding.device, dtype=self.label_embedding.dtype
        )
        refined_label = torch.tanh(self.gcn(self.label_embedding, adj))
        return torch.mm(doc_embedding, refined_label.t())

# --- Data loading ---
def load_test_data(path=None):
    r"""Read the test corpus as parallel lists of product ids and texts.

    Each record is ``<id>\t<text>``. Only the first tab separates the two, so
    the text may itself contain tabs. A record with an empty text
    (``"<id>\t"``) is kept with ``""`` so that its id still gets a row in the
    submission. Lines with no tab, or with an empty id, are skipped.

    Args:
        path: corpus file to read; defaults to ``TEST_CORPUS_PATH``.

    Returns:
        ``(pids, texts)``: two lists of equal length, in file order.
    """
    if path is None:
        path = TEST_CORPUS_PATH
    pids, texts = [], []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            # Split before stripping: stripping the whole line first would
            # erase the tab of an "<id>\t" record and drop that id entirely.
            pid, sep, text = line.rstrip("\r\n").partition("\t")
            pid = pid.strip()
            if not sep or not pid:
                continue
            pids.append(pid)
            # Trailing whitespace is dropped, leading whitespace kept (as before).
            texts.append(text.rstrip())
    return pids, texts

def load_hierarchy(path=None, num_classes=None):
    """Load the class hierarchy as a child -> list-of-parents map.

    Each non-blank line starts with two integer ids ``<parent> <child>``
    (extra columns are ignored). Edges are kept only when both ids fall in
    ``[0, num_classes)``. Negative ids are rejected because they would later
    index probability arrays from the end.

    Args:
        path: hierarchy file; defaults to ``HIERARCHY_PATH``.
        num_classes: size of the label space; defaults to ``NUM_CLASSES``.

    Returns:
        ``defaultdict(list)`` mapping child id -> parent ids in file order.
        The map is empty, with a warning printed, when the file does not exist.
    """
    if path is None:
        path = HIERARCHY_PATH
    if num_classes is None:
        num_classes = NUM_CLASSES
    parents = defaultdict(list)
    if not os.path.exists(path):
        # Best-effort as before, but make the silent fallback visible.
        print(f"   - Warning: hierarchy file not found at {path}; using an empty hierarchy.")
        return parents
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.split()
            if len(parts) < 2:
                continue
            p, c = int(parts[0]), int(parts[1])
            if 0 <= p < num_classes and 0 <= c < num_classes:
                parents[c].append(p)
    return parents

class TestDataset(Dataset):
    """Map-style dataset that tokenizes raw texts on demand for inference.

    Each item is a dict with flattened ``input_ids`` and ``attention_mask``
    tensors. The tokenizer is asked to pad and truncate to ``max_len``.

    Args:
        texts: indexable collection of documents; each is passed through ``str()``.
        tokenizer: tokenizer exposing ``encode_plus`` (Hugging Face style).
        max_len: target sequence length.
    """

    def __init__(self, texts, tokenizer, max_len):
        self.texts, self.tokenizer, self.max_len = texts, tokenizer, max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        raw_text = str(self.texts[item])
        encoded = self.tokenizer.encode_plus(
            raw_text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        # return_tensors='pt' presumably yields shape (1, max_len); flatten so
        # the DataLoader stacks items into (batch, max_len).
        return {key: encoded[key].flatten() for key in ('input_ids', 'attention_mask')}

# --- Execution ---
def _collect_ancestors(label, parents_map, max_depth=5):
    """Return every ancestor of ``label`` reachable within ``max_depth`` hops.

    Breadth-first over ALL parents, since a class may have several. The
    ``ancestors`` set also stops revisits if the hierarchy contains a cycle.
    """
    ancestors = set()
    frontier = [label]
    for _ in range(max_depth):
        next_frontier = []
        for node in frontier:
            # .get() rather than [] so lookups do not grow the defaultdict.
            for parent in parents_map.get(node, ()):
                if parent < NUM_CLASSES and parent not in ancestors:
                    ancestors.add(parent)
                    next_frontier.append(parent)
        if not next_frontier:
            break
        frontier = next_frontier
    return ancestors


def _select_labels(sample_probs, parents_map, top_k=10, max_labels=3, min_labels=2):
    """Pick final label ids for one sample, returned in ascending order.

    Candidates are the ``top_k`` most probable classes plus their ancestors.
    The ``max_labels`` highest-scoring candidates are kept, and the list is
    padded from the overall ranking up to ``min_labels`` if needed.
    """
    ranked_ids = sample_probs.argsort()[::-1]
    candidates = set()
    for cid in ranked_ids[:top_k]:
        cid = int(cid)
        candidates.add(cid)
        candidates.update(_collect_ancestors(cid, parents_map))

    # NOTE(review): ancestors are ranked by their own probability, and any
    # class outside the top-k scores no higher than every top-k member, so the
    # ancestor expansion can only change the result through ties. The output
    # is effectively the plain top-`max_labels`. Confirm whether ancestors were
    # meant to be boosted or force-included instead.
    by_score = sorted(candidates, key=lambda c: float(sample_probs[c]), reverse=True)
    final_labels = by_score[:max_labels]

    if len(final_labels) < min_labels:
        for extra in ranked_ids:
            extra = int(extra)
            if extra in final_labels:
                continue
            final_labels.append(extra)
            if len(final_labels) >= min_labels:
                break
    return sorted(final_labels)


def _load_model(path):
    """Load a fully pickled model checkpoint onto DEVICE in eval mode."""
    # SECURITY: weights_only=False unpickles arbitrary Python objects (needed
    # here because whole GNN_Model objects were saved). Only load trusted files.
    model = torch.load(path, map_location=DEVICE, weights_only=False)
    return model.to(DEVICE).eval()


def _write_submission(pids, predictions, path):
    """Write an ``id,labels`` CSV; each row's labels are comma-joined in one field."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["id", "labels"])
        for pid, labels in zip(pids, predictions):
            writer.writerow([pid, ",".join(map(str, labels))])


def run_ensemble():
    """Run the BERT + DeBERTa ensemble on the test corpus and write OUTPUT_CSV.

    Steps: load the corpus and class hierarchy; tokenize with each model's own
    tokenizer; load both pickled checkpoints; average sigmoid probabilities
    (0.4 * BERT + 0.6 * DeBERTa); select labels per document; write the CSV.
    """
    print("1. Loading Data...")
    test_pids, test_texts = load_test_data()
    parents_map = load_hierarchy()

    # Tokenizers (the original author notes use_fast=False is required for DeBERTa).
    print("   - Loading Tokenizers...")
    tokenizer_bert = AutoTokenizer.from_pretrained("bert-base-uncased")
    tokenizer_deberta = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base", use_fast=False)

    # Same texts, same order (shuffle=False): the two loaders stay aligned batch by batch.
    loader_bert = DataLoader(TestDataset(test_texts, tokenizer_bert, MAX_LEN), batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
    loader_deberta = DataLoader(TestDataset(test_texts, tokenizer_deberta, MAX_LEN), batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    print("2. Loading Models...")
    # Both checkpoints are whole GNN_Model objects (weights_only=False required).
    model_bert = _load_model(PATH_BERT_MODEL)
    model_deberta = _load_model(PATH_DEBERTA_MODEL)

    print("3. Ensembling...")
    all_preds = []

    with torch.no_grad():
        for batch_bert, batch_deberta in tqdm(zip(loader_bert, loader_deberta), total=len(loader_bert)):
            logits_bert = model_bert(
                batch_bert['input_ids'].to(DEVICE), batch_bert['attention_mask'].to(DEVICE)
            )
            logits_deberta = model_deberta(
                batch_deberta['input_ids'].to(DEVICE), batch_deberta['attention_mask'].to(DEVICE)
            )

            # Soft voting: weighted average of per-class sigmoid probabilities (DeBERTa weight 0.6).
            probs = (torch.sigmoid(logits_bert) * 0.4) + (torch.sigmoid(logits_deberta) * 0.6)

            # Post-processing: hierarchy-aware candidates, keep 2-3 labels per sample.
            for sample_probs in probs.cpu().numpy():
                all_preds.append(_select_labels(sample_probs, parents_map))

    print(f"4. Saving to {OUTPUT_CSV}...")
    _write_submission(test_pids, all_preds, OUTPUT_CSV)
    print("✅ Ensemble Done!")

# Script entry point: run the full pipeline only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    run_ensemble()