In [None]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import math
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModel, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from collections import defaultdict

# --- 1. Environment setup ---
def seed_everything(seed=42):
    """Seed every random number generator the training run uses.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices), and
    configures cuDNN to select deterministic kernels.

    Args:
        seed: integer seed applied to every generator.
    """
    import random
    random.seed(seed)
    # Only affects Python interpreters started after this point (e.g. spawned
    # worker processes); the current interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # explicit for multi-GPU setups
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes kernels per run and can pick different
    # (non-deterministic) algorithms, defeating the flag above.
    torch.backends.cudnn.benchmark = False

seed_everything(42)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# [Key change] the backbone model is now DeBERTa v3
MODEL_NAME = "microsoft/deberta-v3-base"
BATCH_SIZE = 16  # reduce to 8 if you run out of GPU memory
EPOCHS = 5
LEARNING_RATE = 1e-5
MAX_LEN = 256
NUM_CLASSES = 531
HIDDEN_DIM = 768  # must equal the encoder's hidden size (768 for deberta-v3-base)

# Paths (same as before)
BASE_DIR = "../Amazon_products"
TRAIN_CORPUS_PATH = os.path.join(BASE_DIR, "train/train_corpus.txt")
HIERARCHY_PATH = os.path.join(BASE_DIR, "class_hierarchy.txt")
SILVER_LABELS_PATH = "train_round_2.csv"
OUTPUT_MODEL_DIR = "saved_model_deberta_gnn"

# exist_ok makes this a no-op when the directory already exists and avoids the
# check-then-create race of testing os.path.exists first.
os.makedirs(OUTPUT_MODEL_DIR, exist_ok=True)

# --- 2. Build the label adjacency matrix (same as before) ---
def build_adjacency_matrix(hierarchy_path, num_classes, device=None):
    """Build the symmetrically normalised label adjacency matrix for the GCN.

    Each line of ``hierarchy_path`` with at least two whitespace-separated
    tokens is read as ``parent child`` integer ids and adds an undirected edge
    between them. Lines with fewer tokens are skipped, as are edges touching an
    id outside ``[0, num_classes)``. Self-loops are always present. The result
    is ``D^-1/2 (A + I) D^-1/2``, where ``D`` is the row-sum (degree) matrix of
    ``A + I``.

    Args:
        hierarchy_path: path to the class-hierarchy text file.
        num_classes: number of labels (side length of the matrix).
        device: device for the returned tensor; defaults to the module-level
            ``DEVICE``.

    Returns:
        A float tensor of shape ``(num_classes, num_classes)``.
    """
    if device is None:
        device = DEVICE
    adj = torch.eye(num_classes)
    with open(hierarchy_path, "r", encoding="utf-8") as f:
        for line in f:
            parts = line.split()
            if len(parts) < 2:
                continue
            p, c = int(parts[0]), int(parts[1])
            # Negative ids would silently wrap around to the last rows, so they
            # are rejected together with ids >= num_classes.
            if 0 <= p < num_classes and 0 <= c < num_classes:
                adj[p, c] = 1
                adj[c, p] = 1
    d_inv_sqrt = adj.sum(dim=1).pow(-0.5)
    d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.
    # Broadcasting d_i * a_ij * d_j is equivalent to diag(d) @ adj @ diag(d)
    # without materialising two dense diagonal matrices.
    norm_adj = d_inv_sqrt[:, None] * adj * d_inv_sqrt[None, :]
    return norm_adj.to(device)

# --- 3. GCN layer (same as before) ---
class GraphConvolution(nn.Module):
    """One graph-convolution layer computing ``adj @ (input @ weight) + bias``.

    Args:
        in_features: width of each input node feature vector.
        out_features: width of each output node feature vector.
        bias: when True, a learnable per-output-feature bias is added.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.weight = nn.Parameter(
            torch.empty(in_features, out_features, dtype=torch.float32, device="cpu")
        )
        if not bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(
                torch.empty(out_features, dtype=torch.float32, device="cpu")
            )
        self.reset_parameters()

    def reset_parameters(self):
        """Draw weight, then bias, from U(-b, b) with b = 1 / sqrt(out_features)."""
        bound = 1. / math.sqrt(self.weight.size(1))
        with torch.no_grad():
            self.weight.uniform_(-bound, bound)
            if self.bias is not None:
                self.bias.uniform_(-bound, bound)

    def forward(self, input, adj):
        # Project node features first, then aggregate them over the graph.
        aggregated = torch.mm(adj, torch.mm(input, self.weight))
        if self.bias is None:
            return aggregated
        return aggregated + self.bias

# --- 4. [Modified] DeBERTa-compatible GNN model ---
class DebertaGCN(nn.Module):
    """Multi-label classifier scoring a document against GCN-refined label embeddings.

    The encoder's first-token hidden state is the document vector. A learnable
    label-embedding table is passed through one GraphConvolution over the label
    graph, squashed with tanh, and dotted with the document vector to give one
    logit per class.

    Args:
        model_name: model id or path passed to ``AutoModel.from_pretrained``.
        num_classes: number of labels (rows of the label-embedding table).
        hidden_dim: label-embedding width; must equal the encoder hidden size.
        adj_matrix: normalised ``(num_classes, num_classes)`` adjacency matrix.

    Raises:
        ValueError: if the encoder config reports a hidden size different from
            ``hidden_dim``.
    """

    def __init__(self, model_name, num_classes, hidden_dim, adj_matrix):
        super(DebertaGCN, self).__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        # Fail here with a clear message instead of a shape error in torch.mm.
        encoder_dim = getattr(self.encoder.config, "hidden_size", None)
        if encoder_dim is not None and encoder_dim != hidden_dim:
            raise ValueError(
                f"hidden_dim={hidden_dim} does not match the encoder hidden size {encoder_dim}"
            )
        self.dropout = nn.Dropout(0.1)

        # Label embeddings
        self.label_embedding = nn.Parameter(torch.FloatTensor(num_classes, hidden_dim))
        nn.init.xavier_uniform_(self.label_embedding)

        self.gcn = GraphConvolution(hidden_dim, hidden_dim)
        # A buffer (not a plain attribute) so .to()/.cuda() and multi-device
        # wrappers such as nn.DataParallel move it with the parameters.
        # persistent=False keeps state_dict() keys identical to before.
        self.register_buffer("adj_matrix", adj_matrix, persistent=False)

    def forward(self, input_ids, attention_mask):
        """Return logits of shape ``[Batch, num_classes]``."""
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)

        # DeBERTa may not provide pooler_output, so take the first ([CLS]) token directly.
        last_hidden_state = outputs.last_hidden_state
        doc_embedding = last_hidden_state[:, 0, :]  # [Batch, Hidden]
        doc_embedding = self.dropout(doc_embedding)

        # Refine the label embeddings over the label graph.
        refined_label_embedding = self.gcn(self.label_embedding, self.adj_matrix)
        refined_label_embedding = torch.tanh(refined_label_embedding)

        # Score every document against every label: [Batch, Hidden] x [Hidden, num_classes].
        logits = torch.mm(doc_embedding, refined_label_embedding.t())
        return logits

# --- 5. Training preparation ---
class ReviewDataset(Dataset):
    """Turn (text, comma-separated label ids) pairs into tokenised model inputs.

    Args:
        texts: indexable sequence of documents; each item is passed through ``str``.
        labels: indexable sequence of label specs such as ``"3,17,42"``. Empty or
            whitespace-only ids are ignored.
        tokenizer: a Hugging Face tokenizer, or any callable accepting the same
            keyword arguments and returning ``input_ids``/``attention_mask``.
        max_len: every document is padded/truncated to exactly this many tokens.
        num_classes: length of the multi-hot label vector; defaults to the
            module-level ``NUM_CLASSES``.
    """

    def __init__(self, texts, labels, tokenizer, max_len, num_classes=None):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.num_classes = NUM_CLASSES if num_classes is None else num_classes

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item):
        """Return a dict with 1-D ``input_ids``, ``attention_mask`` and multi-hot ``labels``."""
        text = str(self.texts[item])
        # Strip-aware filter: " " between commas must not reach int().
        label_list = [int(x) for x in str(self.labels[item]).split(",") if x.strip()]
        label_tensor = torch.zeros(self.num_classes)
        if label_list:
            label_tensor[label_list] = 1.0

        # Calling the tokenizer directly is the supported API; encode_plus is deprecated.
        encoding = self.tokenizer(
            text, padding='max_length', truncation=True, max_length=self.max_len, return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': label_tensor
        }

print("Loading data...")
# Corpus format: one "<pid>\t<text>" record per line.
pid2text = {}
with open(TRAIN_CORPUS_PATH, "r", encoding="utf-8") as f:
    for line_no, line in enumerate(f, start=1):
        record = line.strip()
        if not record:
            continue  # tolerate blank lines, e.g. an empty last line
        pid, sep, text = record.partition("\t")
        if not sep:
            raise ValueError(
                f"{TRAIN_CORPUS_PATH}:{line_no}: expected '<pid>\\t<text>', got {record[:80]!r}"
            )
        pid2text[int(pid)] = text

# Attach document text to each silver-labelled pid. Pids absent from the corpus
# map to NaN and are dropped together with rows that have no labels; other
# columns are not considered.
df = pd.read_csv(SILVER_LABELS_PATH)
df['text'] = df['pid'].map(pid2text)
df = df.dropna(subset=['text', 'labels'])

train_df, val_df = train_test_split(df, test_size=0.1, random_state=42)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)

train_dataset = ReviewDataset(train_df.text.to_numpy(), train_df.labels.to_numpy(), tokenizer, MAX_LEN)
val_dataset = ReviewDataset(val_df.text.to_numpy(), val_df.labels.to_numpy(), tokenizer, MAX_LEN)

# NOTE(review): with num_workers > 0 inside a notebook, platforms that spawn
# worker processes may fail to import ReviewDataset — drop to 0 if loading hangs.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# --- 6. Run training ---
adj_matrix = build_adjacency_matrix(HIERARCHY_PATH, NUM_CLASSES)
model = DebertaGCN(MODEL_NAME, NUM_CLASSES, HIDDEN_DIM, adj_matrix).to(DEVICE)

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
total_steps = len(train_loader) * EPOCHS
# Linear decay of the learning rate to 0 over all steps, with no warmup phase.
scheduler = get_linear_schedule_with_warmup(optimizer, 0, total_steps)
loss_fn = nn.BCEWithLogitsLoss()


def _unpack(batch):
    """Move one ReviewDataset batch to DEVICE as (input_ids, attention_mask, labels)."""
    return (
        batch['input_ids'].to(DEVICE),
        batch['attention_mask'].to(DEVICE),
        batch['labels'].to(DEVICE),
    )


def _train_epoch(loader, desc):
    """Run one optimisation pass over ``loader``; return the per-sample mean loss (NaN if empty)."""
    model.train()
    loss_sum, n_seen = 0.0, 0
    for batch in tqdm(loader, desc=desc):
        ids, mask, lbls = _unpack(batch)
        optimizer.zero_grad()
        loss = loss_fn(model(ids, mask), lbls)
        loss.backward()
        optimizer.step()
        scheduler.step()
        # loss_fn returns a batch mean; weight it by batch size so a short last
        # batch does not count as much as a full one in the epoch average.
        loss_sum += loss.item() * ids.size(0)
        n_seen += ids.size(0)
    return loss_sum / n_seen if n_seen else float('nan')


@torch.no_grad()
def _evaluate(loader):
    """Return the per-sample mean loss over ``loader`` without updating the model (NaN if empty)."""
    model.eval()
    loss_sum, n_seen = 0.0, 0
    for batch in loader:
        ids, mask, lbls = _unpack(batch)
        loss_sum += loss_fn(model(ids, mask), lbls).item() * ids.size(0)
        n_seen += ids.size(0)
    return loss_sum / n_seen if n_seen else float('nan')


print(f"Starting Training with {MODEL_NAME}...")
best_val_loss = float('inf')

for epoch in range(EPOCHS):
    avg_loss = _train_epoch(train_loader, desc=f"Epoch {epoch+1}")
    print(f"Epoch {epoch+1} Train Loss: {avg_loss:.4f}")

    avg_val_loss = _evaluate(val_loader)
    print(f"Epoch {epoch+1} Val Loss: {avg_val_loss:.4f}")

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        # Pickles the whole module (not only its state_dict); loading it later
        # requires DebertaGCN and GraphConvolution to be importable.
        torch.save(model, os.path.join(OUTPUT_MODEL_DIR, "best_model_deberta.pt"))
        print("Model Saved!")

print("Training Complete.")

Loading data...
Starting Training with microsoft/deberta-v3-base...


Epoch 1:   0%|          | 0/1659 [00:00<?, ?it/s]