## 데이터 불러오기

In [None]:
import pandas as pd

# Read the train and test splits from the local data/ directory.
train_df, test_df = map(pd.read_csv, ('data/train.csv', 'data/test.csv'))

## 데이터 전처리

In [None]:
import re
import string
import nltk
from nltk.corpus import stopwords

# Fetch the NLTK stop-word corpus so that stopwords.words() can read it.
nltk.download('stopwords')

# Hold the English stop words in a set for fast membership checks.
stop_words = {*stopwords.words('english')}

def preprocess_text(text: str) -> str:
    """Normalize a raw text string before tokenization.

    Every ASCII punctuation character is replaced with a space, the text is
    lower-cased, and tokens found in the module-level ``stop_words``
    collection are dropped.

    Args:
        text: Raw input string.

    Returns:
        The surviving tokens joined by single spaces (``""`` if none remain).
    """
    # Escape the punctuation before embedding it in a character class.
    # Unescaped, the backslash-bracket pair inside string.punctuation is
    # parsed as an escaped ']', so a literal backslash was never removed.
    text = re.sub(f"[{re.escape(string.punctuation)}]", " ", text)

    # Lower-case before comparing tokens against the stop-word collection.
    text = text.lower()

    # Drop stop words; split() also collapses any run of whitespace.
    return " ".join(word for word in text.split() if word not in stop_words)

# Add the cleaned-text column to both splits using the same preprocessing.
for frame in (train_df, test_df):
    frame["preprocessed_text"] = frame["text"].apply(preprocess_text)


## Dataset 클래스

In [None]:
# Notebook shell escape: install the `transformers` package, which provides
# the BertTokenizer imported in the next cell.
!pip install transformers

In [None]:
from torch.utils.data import Dataset
from transformers import BertTokenizer

class NewsDataset(Dataset):
    """Torch ``Dataset`` that tokenizes the ``preprocessed_text`` column lazily.

    A row is encoded the first time it is requested and the result is kept in
    ``encoded_dict``, so asking for the same index again does not re-tokenize.

    Args:
        dataframe: Frame with a ``preprocessed_text`` column and, unless
            ``is_test`` is true, a ``label`` column.
        tokenizer: Object exposing ``encode_plus`` (e.g. a ``BertTokenizer``).
        max_length: Forwarded to the tokenizer as ``max_length`` together with
            ``padding='max_length'`` and ``truncation=True``.
        is_test: When true, returned items carry no ``labels`` entry.
    """

    def __init__(self, dataframe, tokenizer, max_length=512, is_test=False):
        self.dataframe = dataframe
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.is_test = is_test
        # Memoized tokenizer outputs, keyed by the requested index.
        self.encoded_dict = {}

    def __len__(self):
        """Return the number of rows in the wrapped frame."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return the encoded sample at row position ``idx``.

        Returns:
            A dict with ``input_ids`` and ``attention_mask`` (the tokenizer
            outputs with their leading dimension squeezed away) and, when
            ``is_test`` is false, ``labels`` taken from the ``label`` column.
        """
        if idx not in self.encoded_dict:
            # Positional lookup: idx counts rows 0..len-1, which only matches
            # the frame's index labels for a default RangeIndex. Using .iloc
            # keeps filtered / shuffled / split frames working.
            text = self.dataframe["preprocessed_text"].iloc[idx]
            self.encoded_dict[idx] = self.tokenizer.encode_plus(
                text,
                add_special_tokens=True,
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt',
                return_attention_mask=True,
            )

        encoding = self.encoded_dict[idx]
        item = {
            # return_tensors='pt' adds a leading dimension; drop it so each
            # sample is a single sequence.
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
        }

        if not self.is_test:
            item['labels'] = self.dataframe["label"].iloc[idx]

        return item
