In [None]:
import json
import os
import random
import re

import numpy as np
import pandas as pd
import torch
from huggingface_hub import login
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hugging Face login.
# SECURITY: never commit an access token to the source file. The token is read
# from the HF_TOKEN environment variable; when it is unset, `login` receives
# None and falls back to its interactive prompt.
# NOTE(review): the token that used to be written in plain text on this line
# must be treated as leaked and revoked in the Hugging Face account settings.
login(token=os.environ.get("HF_TOKEN"))

# Seed every random number generator used here (Python, NumPy, PyTorch CPU and
# CUDA) with the same value so that runs are reproducible.
SEED = 456
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(SEED)

# Pick the compute device: the GPU when CUDA is available, otherwise the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the tokenizer and the causal language model from the same checkpoint,
# then move the model onto the selected device.
model_name = "beomi/Llama-3-Open-Ko-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to(DEVICE)

# Load the training data.
data = pd.read_csv('original_train.csv')

# Topic names listed so that each position equals the integer class id it
# stands for (0 -> first entry, 6 -> last entry).
label_mapping = dict(enumerate([
    '생활문화',
    '스포츠',
    '정치',
    '사회',
    'IT과학',
    '경제',
    '세계',
]))

# Add a human-readable 'label' column by looking up each 'target' id.
data['label'] = data['target'].map(label_mapping)

# Matches a `"corrected_text": "<value>"` pair. The value may contain
# backslash-escaped characters (such as an escaped double quote) without the
# match ending early at that quote.
_CORRECTED_TEXT_RE = re.compile(
    r'"corrected_text":\s*"((?:[^"\\]|\\.)+)"',
    re.DOTALL,
)


# Extract the corrected text from the LLaMA model's output
def extract_corrected_text(output):
    """Return the last non-empty ``corrected_text`` value found in *output*.

    Args:
        output: Text to search, e.g. the decoded output of the model.

    Returns:
        The value of the last ``"corrected_text": "..."`` pair, with JSON
        escape sequences decoded, or ``None`` when no such pair with a
        non-empty value exists.
    """
    matches = _CORRECTED_TEXT_RE.findall(output)
    if not matches:
        return None

    raw_value = matches[-1]
    if '\\' not in raw_value:
        # Nothing to unescape; return the text exactly as it was written.
        return raw_value

    try:
        # Decode JSON escapes; strict=False tolerates raw control characters
        # (e.g. literal newlines) inside the value.
        return json.loads(f'"{raw_value}"', strict=False)
    except ValueError:
        # Not valid JSON string content (e.g. an unknown escape): keep it raw.
        return raw_value

# Rewrite a noisy sentence so that it fits the given topic
def clean_sentence(noisy_sentence, label):
    """Ask the language model for a corrected version of *noisy_sentence*.

    Args:
        noisy_sentence: The sentence to be corrected.
        label: Topic name inserted into the prompt alongside the sentence.

    Returns:
        The text that ``extract_corrected_text`` finds in the model's newly
        generated continuation, or *noisy_sentence* unchanged when nothing
        is found.
    """
    prompt = f"""
You are required to return the corrected sentence in JSON format. 
Ensure your response strictly adheres to the JSON structure below.
Understand the meaning of the sentence and correctly generate the sentence to fit the topic.
Input:
Original: "topic: {label}, sentence: {noisy_sentence}"
Output:
{{
    "corrected_text": "<corrected sentence>"
}}
"""
    inputs = tokenizer(prompt, return_tensors='pt').to(DEVICE)
    outputs = model.generate(**inputs, max_new_tokens=128, pad_token_id=tokenizer.eos_token_id)

    # The generated sequence starts with the prompt tokens. The prompt itself
    # contains the template `"corrected_text": "<corrected sentence>"`, so
    # decoding the full sequence would always yield a match and return the
    # placeholder whenever the model produced nothing usable. Decode only the
    # tokens that follow the prompt.
    prompt_length = inputs['input_ids'].shape[1]
    generated_ids = outputs[0][prompt_length:]
    decoded_output = tokenizer.decode(generated_ids, skip_special_tokens=True)

    corrected_sentence = extract_corrected_text(decoded_output)
    return corrected_sentence if corrected_sentence else noisy_sentence

# Enable tqdm's progress-bar integration for pandas.
tqdm.pandas()

# NOTE(review): clean_sentence is defined in this file but nothing in the
# visible code applies it to `data`, so the CSV written below holds the loaded
# data as-is. Confirm whether the cleaning step was dropped by mistake.

# Remove the helper 'label' column, which is no longer needed.
data = data.drop(labels=['label'], axis=1)

# Write the final data set without the index column.
data.to_csv(path_or_buf='train_final.csv', index=False)