In [None]:
import inspect

import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments

# 1. Load and split the data
DATA_PATH = '../data/processed/cleaned_reviews.csv'
df = pd.read_csv(DATA_PATH)

# A previous run of this cell failed with `KeyError: 'label'` at the
# stratify step: the CSV has no column with that exact name. Fail fast with
# a message listing the columns that do exist, so the right name can be used
# here (or the column renamed upstream in the cleaning step).
_required_cols = ['review_text', 'label']
_missing_cols = [c for c in _required_cols if c not in df.columns]
if _missing_cols:
    raise KeyError(
        f"Missing required column(s) {_missing_cols} in {DATA_PATH}; "
        f"available columns: {list(df.columns)}"
    )

# Missing text would otherwise be tokenized as the literal string "nan", and
# missing labels break both stratification and the cast to a long tensor.
_n_before = len(df)
df = df.dropna(subset=_required_cols)
if len(df) < _n_before:
    print(f"Dropped {_n_before - len(df)} rows with missing review_text/label")

# The classifier head below is built with num_labels=2, so labels must be
# the integer class ids 0 and 1.
_label_values = set(df['label'].unique())
if not _label_values <= {0, 1}:
    raise ValueError(f"Expected binary labels {{0, 1}}, found {sorted(map(str, _label_values))}")
df['label'] = df['label'].astype(int)

# Stratify on the label so the 90/10 split keeps the class ratio in both parts.
train_df, val_df = train_test_split(df, test_size=0.1, random_state=42, stratify=df['label'])

# 2. Tokenizer setup
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# 3. PyTorch Dataset wrapper that tokenizes one review per __getitem__ call
class ReviewDataset(Dataset):
    """Map-style dataset that tokenizes reviews lazily for the HF ``Trainer``.

    Each item is a dict with ``input_ids`` and ``attention_mask`` (1-D tensors
    of length ``max_len``) plus a scalar ``labels`` tensor of dtype long.

    Args:
        reviews: Array-like of review texts. It is converted to a NumPy array,
            so indexing is always positional. A pandas Series with a shuffled
            index (e.g. straight from ``train_test_split``) is therefore safe.
        labels: Array-like of integer class ids, same length as ``reviews``.
        tokenizer: Callable tokenizer (e.g. a Hugging Face tokenizer) invoked
            once per item with ``return_tensors='pt'``.
        max_len: Every review is padded/truncated to exactly this many tokens.

    Raises:
        ValueError: If ``reviews`` and ``labels`` differ in length.
    """

    def __init__(self, reviews, labels, tokenizer, max_len):
        # `series[i]` on a pandas Series is a label lookup, not positional;
        # np.asarray guarantees `self.reviews[item]` means "the item-th row".
        self.reviews = np.asarray(reviews, dtype=object)
        self.labels = np.asarray(labels)
        if len(self.reviews) != len(self.labels):
            raise ValueError(
                f"reviews and labels must have the same length, "
                f"got {len(self.reviews)} and {len(self.labels)}"
            )
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.reviews)

    def __getitem__(self, item):
        encoding = self.tokenizer(
            str(self.reviews[item]),  # str() guards against non-string cells
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        # return_tensors='pt' yields shape (1, max_len); flatten drops the
        # batch dim so the DataLoader can stack items itself.
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(self.labels[item], dtype=torch.long),
        }

# 4. Build the dataset objects
MAX_LEN = 128  # max tokens per review; longer reviews are truncated
train_dataset = ReviewDataset(train_df['review_text'].values, train_df['label'].values, tokenizer, MAX_LEN)
val_dataset = ReviewDataset(val_df['review_text'].values, val_df['label'].values, tokenizer, MAX_LEN)

# 5. Load the model (binary classification head: labels must be 0/1)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)


def _compute_metrics(eval_pred):
    """Return validation accuracy from a Trainer ``EvalPrediction``."""
    preds = np.argmax(eval_pred.predictions, axis=-1)
    return {'accuracy': float((preds == eval_pred.label_ids).mean())}


# 6. Training configuration (Trainer API)
# transformers renamed `evaluation_strategy` to `eval_strategy` in 4.41 and
# later dropped the old name (passing it then raises TypeError), so use
# whichever keyword the installed version accepts.
_eval_strategy_kw = (
    'eval_strategy'
    if 'eval_strategy' in inspect.signature(TrainingArguments).parameters
    else 'evaluation_strategy'
)
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=2,              # only 2 epochs so the run finishes quickly
    per_device_train_batch_size=16,
    save_strategy="epoch",           # must match the eval strategy for load_best_model_at_end
    load_best_model_at_end=True,
    **{_eval_strategy_kw: "epoch"},  # evaluate at the end of every epoch
)

# 7. Create the trainer and start training
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=_compute_metrics,
)

print(" 학습을 시작합니다...")
trainer.train()
print(" 학습 완료!")

KeyError: 'label'

22641
