# In [None]:
import os
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer, AutoModel
import math
from collections import defaultdict

# --- 1. 설정 ---
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_LEN = 256  # max_length handed to the tokenizer (padding / truncation target)
BATCH_SIZE = 32  # DataLoader batch size; also used to map a batch index back to sample indices
NUM_CLASSES = 531  # presumably the label count; not referenced elsewhere in the visible code
TOP_K_SAMPLES = 1000  # number of samples to extract

# Paths
BASE_DIR = "Amazon_products"
TEST_CORPUS_PATH = os.path.join(BASE_DIR, "test/test_corpus.txt")
CLASSES_PATH = os.path.join(BASE_DIR, "classes.txt")

# Model checkpoint paths and the output file
PATH_BERT_MODEL = "saved_model_gnn/best_model_gnn.pt"
PATH_DEBERTA_MODEL = "saved_model_deberta_gnn/best_model_deberta.pt"
OUTPUT_CSV = "hard_samples_for_llm.csv"

# --- 2. 모델 클래스 정의 ---
class GraphConvolution(nn.Module):
    """Single graph-convolution layer computing ``adj @ (input @ weight) [+ bias]``.

    ``weight`` has shape ``(in_features, out_features)``; ``bias`` (when
    enabled) has shape ``(out_features,)``. Both are drawn uniformly from
    ``[-1/sqrt(out_features), 1/sqrt(out_features)]``.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(out_features))
        else:
            # Register the name so ``self.bias`` exists and evaluates to None.
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Fill weight (and bias, if present) with the uniform initialisation."""
        bound = 1. / math.sqrt(self.weight.size(1))
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input, adj):
        """Project ``input`` with the weight matrix, then mix rows through ``adj``.

        Both arguments are passed to ``torch.mm`` and therefore must be 2-D:
        ``input`` with ``in_features`` columns, ``adj`` with as many columns
        as ``input`` has rows.
        """
        projected = torch.mm(input, self.weight)
        propagated = torch.mm(adj, projected)
        if self.bias is None:
            return propagated
        return propagated + self.bias

class BertGCN(nn.Module):
    """Text encoder whose outputs are scored against GCN-refined label embeddings.

    Args:
        model_name: Identifier passed to ``AutoModel.from_pretrained``.
        num_classes: Number of rows in the learnable label-embedding table.
        hidden_dim: Width of each label embedding. ``forward`` multiplies the
            pooled encoder output with the label embeddings, so this must
            equal the encoder's pooled output width.
        adj_matrix: Label adjacency matrix handed to the GCN layer; must be
            usable as the first operand of ``torch.mm`` with a
            ``(num_classes, hidden_dim)`` matrix.
    """

    def __init__(self, model_name, num_classes, hidden_dim, adj_matrix):
        super(BertGCN, self).__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(0.1)
        # torch.FloatTensor(rows, cols) only allocates memory; without an
        # explicit init the parameter starts from arbitrary values (possibly
        # NaN/Inf). Give it a well-defined starting point.
        self.label_embedding = nn.Parameter(torch.FloatTensor(num_classes, hidden_dim))
        nn.init.xavier_uniform_(self.label_embedding)
        self.gcn = GraphConvolution(hidden_dim, hidden_dim)
        # Kept as a plain attribute (not a buffer) so state_dict keys are unchanged.
        self.adj_matrix = adj_matrix

    def forward(self, input_ids, attention_mask):
        """Return one logit per (document, label) pair.

        The result is ``dropout(pooler_output) @ tanh(GCN(label_embedding, adj)).T``.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        doc_embedding = self.dropout(outputs.pooler_output)
        # adj_matrix is not a parameter or buffer, so Module.to() never moves
        # it; align it with the label embeddings here (no-op when it is
        # already on the same device).
        adj = self.adj_matrix.to(self.label_embedding.device)
        refined_label_embedding = torch.tanh(self.gcn(self.label_embedding, adj))
        return torch.mm(doc_embedding, refined_label_embedding.t())

class DebertaGCN(nn.Module):
    """Text encoder (first-token representation) scored against GCN-refined label embeddings.

    Args:
        model_name: Identifier passed to ``AutoModel.from_pretrained``.
        num_classes: Number of rows in the learnable label-embedding table.
        hidden_dim: Width of each label embedding. ``forward`` multiplies the
            first-token hidden state with the label embeddings, so this must
            equal the encoder's hidden size.
        adj_matrix: Label adjacency matrix handed to the GCN layer; must be
            usable as the first operand of ``torch.mm`` with a
            ``(num_classes, hidden_dim)`` matrix.
    """

    def __init__(self, model_name, num_classes, hidden_dim, adj_matrix):
        super(DebertaGCN, self).__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(0.1)
        # torch.FloatTensor(rows, cols) only allocates memory; without an
        # explicit init the parameter starts from arbitrary values (possibly
        # NaN/Inf). Give it a well-defined starting point.
        self.label_embedding = nn.Parameter(torch.FloatTensor(num_classes, hidden_dim))
        nn.init.xavier_uniform_(self.label_embedding)
        self.gcn = GraphConvolution(hidden_dim, hidden_dim)
        # Kept as a plain attribute (not a buffer) so state_dict keys are unchanged.
        self.adj_matrix = adj_matrix

    def forward(self, input_ids, attention_mask):
        """Return one logit per (document, label) pair.

        The document vector is the hidden state at sequence position 0 of the
        encoder's last layer, passed through dropout.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        doc_embedding = self.dropout(outputs.last_hidden_state[:, 0, :])
        # adj_matrix is not a parameter or buffer, so Module.to() never moves
        # it; align it with the label embeddings here (no-op when it is
        # already on the same device).
        adj = self.adj_matrix.to(self.label_embedding.device)
        refined_label_embedding = torch.tanh(self.gcn(self.label_embedding, adj))
        return torch.mm(doc_embedding, refined_label_embedding.t())

# --- 3. 유틸리티 ---
def load_test_data(path=None):
    """Read a tab-separated ``<pid>\\t<text>`` corpus into two parallel lists.

    Args:
        path: File to read. When omitted, the module-level
            ``TEST_CORPUS_PATH`` is used, so existing zero-argument calls
            behave as before.

    Returns:
        A ``(pids, texts)`` tuple of equally long lists of strings. Each line
        is stripped of surrounding whitespace and split on the first tab;
        lines that do not yield exactly two fields are skipped.
    """
    if path is None:
        path = TEST_CORPUS_PATH
    pids, texts = [], []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.strip().split("\t", 1)
            if len(parts) != 2:
                continue
            pid, text = parts
            pids.append(pid)
            texts.append(text)
    return pids, texts

def load_class_names(path):
    """Build an ``{int class id: class name}`` mapping from a tab-separated file.

    Each line is stripped and split on the first tab; lines that do not yield
    exactly two fields are ignored. The id field is converted with ``int``
    (a non-numeric id raises ``ValueError``) and the name is stripped again.
    """
    mapping = {}
    with open(path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            fields = raw_line.strip().split("\t", 1)
            if len(fields) != 2:
                continue
            class_id, class_name = fields
            mapping[int(class_id)] = class_name.strip()
    return mapping

class TestDataset(Dataset):
    """Dataset that tokenizes one text per index on demand.

    ``tokenizer`` must provide ``encode_plus``; every text is converted with
    ``str`` and padded/truncated to ``max_len``.
    """

    def __init__(self, texts, tokenizer, max_len):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        """Return flattened ``input_ids`` and ``attention_mask`` tensors for one text."""
        raw_text = str(self.texts[item])
        encoded = self.tokenizer.encode_plus(
            raw_text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        # The tokenizer output is flattened so each sample is 1-D before batching.
        return {key: encoded[key].flatten() for key in ('input_ids', 'attention_mask')}

# --- 4. 메인 실행 ---
def find_hard_samples():
    """Score every test document with the two-model ensemble and save the most uncertain ones.

    Steps: load the test corpus and class names, run both checkpoints over
    the same texts in lockstep, blend their sigmoid outputs (0.4 / 0.6),
    compute ``-sum(p * log(p + 1e-9))`` per document, and write the
    ``TOP_K_SAMPLES`` highest-scoring rows to ``OUTPUT_CSV`` with the columns
    ``pid``, ``text``, ``entropy`` and ``candidates``. Returns ``None``.
    """
    print("1. Loading Data...")
    test_pids, test_texts = load_test_data()
    id2name = load_class_names(CLASSES_PATH)

    # One tokenizer per encoder; the original note marks use_fast=False as
    # required, and it is passed for the DeBERTa tokenizer only.
    tokenizer_bert = AutoTokenizer.from_pretrained("bert-base-uncased")
    tokenizer_deberta = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base", use_fast=False)

    def build_loader(tokenizer):
        # shuffle=False keeps batch order aligned with test_pids / test_texts.
        dataset = TestDataset(test_texts, tokenizer, MAX_LEN)
        return DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

    loader_bert = build_loader(tokenizer_bert)
    loader_deberta = build_loader(tokenizer_deberta)

    print("2. Loading Models...")

    def load_for_inference(checkpoint_path):
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load checkpoints from a trusted source.
        model = torch.load(checkpoint_path, map_location=DEVICE, weights_only=False)
        model.to(DEVICE).eval()
        return model

    model_bert = load_for_inference(PATH_BERT_MODEL)
    model_deberta = load_for_inference(PATH_DEBERTA_MODEL)

    results = []
    print("3. Calculating Uncertainty (Entropy)...")

    with torch.no_grad():
        for batch_no, (bert_batch, deberta_batch) in enumerate(tqdm(zip(loader_bert, loader_deberta), total=len(loader_bert))):
            # Inputs for each model come from its own tokenizer.
            bert_ids = bert_batch['input_ids'].to(DEVICE)
            bert_mask = bert_batch['attention_mask'].to(DEVICE)
            deberta_ids = deberta_batch['input_ids'].to(DEVICE)
            deberta_mask = deberta_batch['attention_mask'].to(DEVICE)

            bert_logits = model_bert(bert_ids, bert_mask)
            deberta_logits = model_deberta(deberta_ids, deberta_mask)

            # Soft voting: weighted average of the per-class sigmoid outputs.
            probs = (torch.sigmoid(bert_logits) * 0.4) + (torch.sigmoid(deberta_logits) * 0.6)

            # The 1e-9 offset keeps log() finite when a probability is 0.
            # NOTE(review): sigmoid outputs do not sum to 1 per row, so this is
            # a sum of -p*log(p) terms rather than a normalised categorical
            # entropy -- confirm this is the intended uncertainty score.
            entropy = -(probs * torch.log(probs + 1e-9)).sum(dim=1)

            entropy_np = entropy.cpu().numpy()
            probs_np = probs.cpu().numpy()

            # Position of this batch's first sample within the unshuffled corpus.
            first_sample = batch_no * BATCH_SIZE
            for offset, (ent, class_probs) in enumerate(zip(entropy_np, probs_np)):
                sample_idx = first_sample + offset
                pid = test_pids[sample_idx]
                text = test_texts[sample_idx]

                # Ten highest-probability classes, best first (hints for the LLM).
                top10_idx = class_probs.argsort()[-10:][::-1]
                hint = " | ".join(
                    f"{idx}: {id2name.get(idx, 'Unknown')}" for idx in top10_idx
                )

                results.append({
                    "pid": pid,
                    "text": text,
                    "entropy": ent,
                    "candidates": hint
                })

    # Highest score first, then keep the top TOP_K_SAMPLES rows.
    ranked = pd.DataFrame(results).sort_values(by="entropy", ascending=False)
    ranked.head(TOP_K_SAMPLES).to_csv(OUTPUT_CSV, index=False)
    

# Run the extraction only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    find_hard_samples()