In [None]:
import os
import csv
import torch
import torch.nn as nn
import math
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel
from collections import defaultdict
import transformers.models.deberta_v2.modeling_deberta_v2 as deberta_v2_mod

# Inject `StableDropout` into the DeBERTa-v2 modelling module when the installed
# transformers release does not expose it. Newer releases are likely to ship it
# under the name `DebertaV2StableDropout`; if that is missing as well, plain
# `torch.nn.Dropout` is used instead (per the original author this is fine for
# ordinary inference).
# NOTE(review): presumably needed so checkpoints pickled against an older
# transformers release can still be unpickled -- confirm.
if not hasattr(deberta_v2_mod, 'StableDropout'):
    deberta_v2_mod.StableDropout = getattr(
        deberta_v2_mod, 'DebertaV2StableDropout', torch.nn.Dropout
    )

# --- Configuration ---
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_LEN = 256  # tokenizer max_length; inputs are padded/truncated to this size
BATCH_SIZE = 32  # batch size of both inference DataLoaders
NUM_CLASSES = 531  # class ids must be below this value to be kept from the hierarchy

# Checkpoint paths -- verify them before running. Each file is passed to
# torch.load and the returned object is used directly as the model.
PATH_BERT_MODEL = "saved_model_gnn/best_model_gnn.pt"  # BERT-based model
PATH_DEBERTA_MODEL = "saved_model_deberta/best_model_deberta.pt"  # DeBERTa-based model

# Dataset locations and the output file written by run_ensemble().
BASE_DIR = "../Amazon_products"
TEST_CORPUS_PATH = os.path.join(BASE_DIR, "test/test_corpus.txt")
HIERARCHY_PATH = os.path.join(BASE_DIR, "class_hierarchy.txt")
OUTPUT_CSV = "final_submission.csv"

# --- 모델 클래스 정의 ---
class GraphConvolution(nn.Module):
    """Graph-convolution layer computing ``adj @ (input @ weight)`` plus an optional bias.

    Args:
        in_features: Number of columns of the input feature matrix.
        out_features: Number of columns of the produced feature matrix.
        bias: When true (default) a learnable bias vector of length
            ``out_features`` is added to the output.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        # torch.FloatTensor(...) only allocates storage; the values are filled
        # in by reset_parameters() below.
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        if not bias:
            # Register the name so `self.bias` exists and evaluates to None.
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.FloatTensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Fill weight (and bias, if present) uniformly in ``[-b, b)``, ``b = 1/sqrt(out_features)``."""
        bound = 1.0 / math.sqrt(self.weight.shape[1])
        # Weight first, then bias, so the random draws happen in a fixed order.
        for param in (self.weight, self.bias):
            if param is not None:
                param.data.uniform_(-bound, bound)

    def forward(self, input, adj):
        """Project ``input`` with the weight matrix, then mix the rows with ``adj``.

        Args:
            input: 2-D feature matrix with ``in_features`` columns.
            adj: 2-D matrix multiplied from the left onto the projected features
                (the graph adjacency -- any normalisation is the caller's job).

        Returns:
            ``adj @ (input @ weight)``, plus the bias when one is configured.
        """
        aggregated = torch.mm(adj, torch.mm(input, self.weight))
        if self.bias is None:
            return aggregated
        return aggregated + self.bias


class BertGCN(nn.Module):
    """Encoder loaded via ``AutoModel`` whose pooled output is scored against GCN-refined label embeddings.

    Args:
        model_name: Identifier passed to ``AutoModel.from_pretrained``.
        num_classes: Number of rows of the label-embedding matrix.
        hidden_dim: Width of the label embeddings and of the GCN layer; it has
            to match the encoder's pooled output width for the final product.
        adj_matrix: Matrix handed to the GCN layer as ``adj``. It is stored as a
            plain attribute (not a buffer), so ``.to(device)`` does not move it.
    """

    def __init__(self, model_name, num_classes, hidden_dim, adj_matrix):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(0.1)
        # NOTE(review): allocated without initialisation and never initialised
        # in this class -- values are expected to come from a checkpoint.
        self.label_embedding = nn.Parameter(torch.FloatTensor(num_classes, hidden_dim))
        self.gcn = GraphConvolution(hidden_dim, hidden_dim)
        self.adj_matrix = adj_matrix

    def forward(self, input_ids, attention_mask):
        """Return one score per (document, label) pair.

        The encoder's ``pooler_output`` goes through dropout and is multiplied
        with the transposed ``tanh(gcn(label_embedding, adj_matrix))``.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.dropout(encoded.pooler_output)
        label_vectors = torch.tanh(self.gcn(self.label_embedding, self.adj_matrix))
        scores = torch.mm(pooled, label_vectors.t())
        return scores


class DebertaGCN(nn.Module):
    """Encoder loaded via ``AutoModel`` whose first-token state is scored against GCN-refined label embeddings.

    Args:
        model_name: Identifier passed to ``AutoModel.from_pretrained``.
        num_classes: Number of rows of the label-embedding matrix.
        hidden_dim: Width of the label embeddings and of the GCN layer; it has
            to match the encoder's hidden width for the final product.
        adj_matrix: Matrix handed to the GCN layer as ``adj``. It is stored as a
            plain attribute (not a buffer), so ``.to(device)`` does not move it.
    """

    def __init__(self, model_name, num_classes, hidden_dim, adj_matrix):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(0.1)
        # NOTE(review): allocated without initialisation and never initialised
        # in this class -- values are expected to come from a checkpoint.
        self.label_embedding = nn.Parameter(torch.FloatTensor(num_classes, hidden_dim))
        self.gcn = GraphConvolution(hidden_dim, hidden_dim)
        self.adj_matrix = adj_matrix

    def forward(self, input_ids, attention_mask):
        """Return one score per (document, label) pair.

        The hidden state at sequence position 0 serves as the document vector;
        after dropout it is multiplied with the transposed
        ``tanh(gcn(label_embedding, adj_matrix))``.
        """
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        first_token = encoded.last_hidden_state[:, 0, :]
        doc_vectors = self.dropout(first_token)
        label_vectors = torch.tanh(self.gcn(self.label_embedding, self.adj_matrix))
        scores = torch.mm(doc_vectors, label_vectors.t())
        return scores


# --- 데이터셋 ---
def load_test_data(path=None):
    """Read the test corpus and return parallel lists of ids and texts.

    Each non-blank line is expected to look like ``<id>\\t<text>``; only the
    first tab separates the id from the text, later tabs stay in the text.

    A record whose text is empty (``"<id>\\t"`` or just ``"<id>"``) is kept with
    an empty string as text. Previously such records were dropped silently,
    which left their ids out of the final submission.

    Args:
        path: Corpus file to read. Defaults to ``TEST_CORPUS_PATH``.

    Returns:
        ``(pids, texts)`` -- two lists of strings of equal length, in file order.
    """
    if path is None:
        path = TEST_CORPUS_PATH

    pids, texts = [], []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            record = line.strip()
            # Skip blank lines and records that have no id in front of the tab.
            if not record or line.startswith("\t"):
                continue
            pid, _, text = record.partition("\t")
            pids.append(pid)
            texts.append(text)
    return pids, texts


def load_hierarchy(path=None, num_classes=None):
    """Load the class hierarchy as a ``child -> [parents]`` mapping.

    Every line with at least two whitespace-separated tokens is read as
    ``<parent> <child>`` (integer ids). An edge is kept only when both ids lie
    in ``[0, num_classes)``; the lower bound matters because a negative id
    would otherwise be accepted and later index probabilities from the end.

    Args:
        path: Hierarchy file to read. Defaults to ``HIERARCHY_PATH``.
        num_classes: Exclusive upper bound for valid ids. Defaults to
            ``NUM_CLASSES``.

    Returns:
        A ``defaultdict(list)`` mapping each child id to its parent ids in file
        order. It is empty when the file does not exist.

    Raises:
        ValueError: If one of the first two tokens of a line is not an integer.
    """
    if path is None:
        path = HIERARCHY_PATH
    if num_classes is None:
        num_classes = NUM_CLASSES

    parents = defaultdict(list)
    if not os.path.exists(path):
        # Still best-effort, but no longer silent: without the hierarchy the
        # ancestor expansion during post-processing does nothing.
        print(f"   ! Hierarchy file not found: {path} (continuing without ancestors)")
        return parents

    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.split()
            if len(parts) < 2:
                continue
            p, c = int(parts[0]), int(parts[1])
            if 0 <= p < num_classes and 0 <= c < num_classes:
                parents[c].append(p)
    return parents


class TestDataset(Dataset):
    """Dataset that tokenizes one raw text per item for inference.

    Args:
        texts: Indexable collection of texts; each item is converted with
            ``str()`` before tokenization.
        tokenizer: Callable tokenizer (e.g. a Hugging Face tokenizer) accepting
            the keyword arguments used in ``__getitem__``.
        max_len: Length every sequence is padded or truncated to.
    """

    # The class name starts with "Test"; this keeps pytest from collecting it.
    __test__ = False

    def __init__(self, texts, tokenizer, max_len):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of texts."""
        return len(self.texts)

    def __getitem__(self, item):
        """Tokenize the text at index ``item``.

        Returns:
            A dict with 1-D ``input_ids`` and ``attention_mask`` tensors.
        """
        # Call the tokenizer directly: `encode_plus` is deprecated in favour of
        # `__call__`, which takes the same arguments for a single text.
        encoding = self.tokenizer(
            str(self.texts[item]),
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        # return_tensors='pt' yields a leading batch axis of size 1; flatten it
        # away so the DataLoader can stack the items into a batch.
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
        }


# --- 실행 ---
def _select_labels(sample_probs, parents_map, top_k=10, max_depth=5,
                   max_labels=3, min_labels=2):
    """Pick the final label ids for one sample from its class probabilities.

    The ``top_k`` most probable classes are taken as seeds and every ancestor
    reachable within ``max_depth`` parent hops is added as a candidate. All
    parents of a node are followed (the earlier loop continued only from the
    last parent listed, so other ancestor chains were lost for multi-parent
    nodes). Candidates are ranked by their own probability and the best
    ``max_labels`` are returned in ascending id order.

    Args:
        sample_probs: 1-D NumPy array of per-class probabilities.
        parents_map: Mapping ``child id -> list of parent ids``.
        top_k: Number of highest-probability classes used as seeds.
        max_depth: Maximum number of parent hops followed from each seed.
        max_labels: Maximum number of labels returned.
        min_labels: Minimum number of labels; missing ones are filled from the
            overall probability ranking.

    Returns:
        Sorted list of integer class ids.
    """
    # Ancestors must be valid class ids and valid indices into sample_probs.
    num_classes = min(NUM_CLASSES, len(sample_probs))
    ranked = sample_probs.argsort()[::-1]  # class ids, most probable first

    scores = {}
    for seed in ranked[:top_k]:
        seed = int(seed)
        scores[seed] = float(sample_probs[seed])

        frontier, seen = [seed], {seed}
        for _ in range(max_depth):
            next_frontier = []
            for node in frontier:
                for pid in parents_map.get(node, ()):
                    if pid in seen or not 0 <= pid < num_classes:
                        continue
                    seen.add(pid)
                    scores.setdefault(pid, float(sample_probs[pid]))
                    next_frontier.append(pid)
            if not next_frontier:
                break
            frontier = next_frontier

    labels = sorted(scores, key=scores.get, reverse=True)[:max_labels]

    # Top up from the plain ranking if too few candidates were found.
    if len(labels) < min_labels:
        for cid in ranked:
            cid = int(cid)
            if cid not in labels:
                labels.append(cid)
                if len(labels) >= min_labels:
                    break

    return sorted(labels)


def run_ensemble():
    """Run BERT + DeBERTa soft-voting inference and write the submission CSV.

    Steps: load the test corpus and class hierarchy, tokenize the texts once
    per model, load both checkpoints, average the sigmoid outputs with weights
    0.4 (BERT) and 0.6 (DeBERTa), post-process every sample with
    ``_select_labels`` and write ``id,labels`` rows to ``OUTPUT_CSV``.
    """
    print("1. Loading Data...")
    test_pids, test_texts = load_test_data()
    parents_map = load_hierarchy()

    # Each model needs its own tokenizer.
    print("   - Loading BERT Tokenizer...")
    tokenizer_bert = AutoTokenizer.from_pretrained("bert-base-uncased")

    print("   - Loading DeBERTa Tokenizer...")
    tokenizer_deberta = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base", use_fast=False)

    # Separate datasets/loaders over the same texts; shuffle=False keeps both
    # loaders aligned with each other and with test_pids.
    dataset_bert = TestDataset(test_texts, tokenizer_bert, MAX_LEN)
    loader_bert = DataLoader(dataset_bert, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    dataset_deberta = TestDataset(test_texts, tokenizer_deberta, MAX_LEN)
    loader_deberta = DataLoader(dataset_deberta, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    print("2. Loading Models...")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # The checkpoints hold whole pickled modules, so it is required here --
    # only load files from a trusted source.
    model_bert = torch.load(PATH_BERT_MODEL, map_location=DEVICE, weights_only=False)
    model_bert.to(DEVICE).eval()

    model_deberta = torch.load(PATH_DEBERTA_MODEL, map_location=DEVICE, weights_only=False)
    model_deberta.to(DEVICE).eval()

    print("3. Ensembling...")
    all_preds = []
    bert_weight, deberta_weight = 0.4, 0.6  # soft-voting weights

    # Walk both loaders in lockstep so each pair of batches covers the same texts.
    with torch.no_grad():
        for batch_bert, batch_deberta in tqdm(zip(loader_bert, loader_deberta), total=len(loader_bert)):
            logits_bert = model_bert(
                batch_bert['input_ids'].to(DEVICE),
                batch_bert['attention_mask'].to(DEVICE),
            )
            logits_deberta = model_deberta(
                batch_deberta['input_ids'].to(DEVICE),
                batch_deberta['attention_mask'].to(DEVICE),
            )

            # Soft voting over per-class sigmoid probabilities.
            probs = (torch.sigmoid(logits_bert) * bert_weight) + (torch.sigmoid(logits_deberta) * deberta_weight)

            # Post-processing: hierarchy-aware label selection per sample.
            for sample_probs in probs.cpu().numpy():
                all_preds.append(_select_labels(sample_probs, parents_map))

    print(f"4. Saving to {OUTPUT_CSV}...")
    with open(OUTPUT_CSV, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["id", "labels"])
        for pid, labels in zip(test_pids, all_preds):
            writer.writerow([pid, ",".join(map(str, labels))])

    print("✅ Ensemble Done!")


# Run the ensemble only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run_ensemble()

1. Loading Data...
   - Loading BERT Tokenizer...
   - Loading DeBERTa Tokenizer...
2. Loading Models...
3. Ensembling...
