In [None]:
import torch
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import Dataset
from tqdm import tqdm

# Fix the RNG seeds up front (CPU and all CUDA devices) so that any weights
# initialised randomly during model loading are reproducible across runs.
SEED = 456
torch.manual_seed(SEED)
for _seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_cuda(SEED)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the pretrained backbone and its tokenizer. num_labels=7 attaches a
# 7-way classification head (presumably one output per target class in
# noise_train.csv — TODO confirm against the label set).
model_name = "beomi/Llama-3-Open-Ko-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=7).to(DEVICE)

# Llama tokenizers ship without a pad token. Reuse EOS for padding rather than
# adding a brand-new "[PAD]" token. A new token gets an id one past the end of
# the embedding matrix unless model.resize_token_embeddings() is also called,
# and the first padded batch then fails with an index error.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# The model must know the pad id as well. LlamaForSequenceClassification uses
# config.pad_token_id to locate the last real token of each row, and it raises
# for batch sizes > 1 when the id is undefined.
model.config.pad_token_id = tokenizer.pad_token_id

# Load the noisy-label training split and the clean split to be relabelled
# (read in that order).
noise_data, clean_data = map(pd.read_csv, ("noise_train.csv", "clean_train.csv"))

# Tokenize the "text" column into model inputs (labels are handled separately).
def tokenize_function(examples, max_length=128):
    """Tokenize one batch passed by ``datasets.Dataset.map(..., batched=True)``.

    Args:
        examples: batch mapping; only the ``"text"`` column (a list of strings)
            is read.
        max_length: fixed length every sequence is padded/truncated to.
            It must be explicit. Without it, ``padding="max_length"`` falls back
            to ``tokenizer.model_max_length``. That value is often an unset
            huge sentinel for Llama-3 checkpoints, in which case transformers
            silently disables padding and truncation. The ragged rows it leaves
            cannot be stacked by the Trainer's default collator. 128 is an
            assumed budget — TODO confirm it covers the texts in this dataset.

    Returns:
        The tokenizer's batch encoding (``input_ids``, ``attention_mask``, ...).
    """
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )

# Build the training set from the noisy labels, tokenize it, and rename the
# target column to "labels", the name the Trainer feeds to the model's loss.
train_dataset = (
    Dataset.from_pandas(noise_data[['text', 'target']])
    .map(tokenize_function, batched=True)
    .rename_column("target", "labels")
)

# Training hyper-parameters. No evaluation is run during training: "no" is
# already the default evaluation strategy, so the kwarg is omitted.
# Passing `evaluation_strategy=` is deprecated in newer transformers releases
# (renamed `eval_strategy`). Omitting it keeps the behaviour identical and
# avoids a version-dependent keyword error.
training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    seed=SEED,
)

# Fine-tune the classifier on the noisy-label training set. No eval dataset is
# configured, so this is a pure training run.
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
trainer.train()

# Prepare the clean split for relabelling: keep only its text, tokenized the
# same way as the training data.
clean_dataset = Dataset.from_dict({"text": clean_data['text'].tolist()}).map(
    tokenize_function, batched=True
)

# Prediction helper used for relabelling.
def relabel_texts(dataset):
    """Return the fine-tuned model's predicted class index for each row.

    Args:
        dataset: tokenized dataset to run through ``trainer.predict``.

    Returns:
        Array of predicted label indices (argmax over the class axis of the
        prediction logits), one per row of ``dataset``.
    """
    logits = trainer.predict(dataset).predictions
    return logits.argmax(axis=1)

# Replace the clean split's targets with the model's predictions.
clean_data['target'] = relabel_texts(clean_dataset)

# Expose the row index (0..n-1 from read_csv) as an explicit 'id' column, then
# write only id/text/target.
# NOTE(review): an existing 'id' column in clean_train.csv would be overwritten
# here — confirm that is intended.
clean_data['id'] = clean_data.index
final_data = clean_data.loc[:, ['id', 'text', 'target']]
final_data.to_csv("relabel_clean_train.csv", index=False)