In [None]:
import ast
import csv

import pandas as pd
import sacrebleu
import torch
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm.auto import tqdm
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW

# Load the accompaniment (instrumental) vocabulary file
def load_instrumental_vocab(vocab_file):
    """Read a vocabulary file and map each token to its line number.

    The file is read as UTF-8 with one token per line. Surrounding
    whitespace is stripped from every line, and the token is mapped to the
    zero-based index of that line. When a token occurs more than once, the
    index of its last occurrence is kept.

    Args:
        vocab_file: Path of the vocabulary file.

    Returns:
        A dict mapping the stripped line text to its line index.
    """
    token_to_index = {}
    with open(vocab_file, 'r', encoding='utf-8') as handle:
        for line_number, raw_line in enumerate(handle):
            token_to_index[raw_line.strip()] = line_number
    return token_to_index

class MidiDataset(Dataset):
    """Dataset of (prompt, prompt + lyrics) token-id pairs built from a DataFrame.

    Every row must provide an ``'instrumental'`` entry (a sequence of
    accompaniment tokens, or the string form of a Python list of them) and
    a ``'lyrics'`` entry (text).

    Args:
        tokenizer: Callable tokenizer exposing ``pad_token_id``. Calling it
            on a string must return an object with an ``input_ids``
            attribute.
        dataframe: pandas DataFrame holding the rows described above.
        vocab: Mapping from accompaniment token to integer index.
        max_length: Exact length of both tensors returned per item.
    """

    def __init__(self, tokenizer, dataframe, vocab, max_length=512):
        self.tokenizer = tokenizer
        self.data = dataframe
        self.vocab = vocab  # accompaniment vocabulary
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows in the underlying DataFrame."""
        return len(self.data)

    def _encode(self, text):
        """Tokenize *text* without padding; return at most ``max_length`` ids."""
        encoding = self.tokenizer(text, max_length=self.max_length, truncation=True)
        return list(encoding.input_ids)

    def _to_fixed_length(self, ids):
        """Cut *ids* to ``max_length`` and right-pad them into a long tensor."""
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            raise ValueError("tokenizer must define a pad token before padding")
        ids = ids[:self.max_length]
        ids = ids + [pad_id] * (self.max_length - len(ids))
        return torch.tensor(ids, dtype=torch.long)

    def __getitem__(self, idx):
        """Return ``(input_ids, labels)`` for row *idx*.

        ``input_ids`` holds the tokenized prompt
        ``"instrumental: <indices> lyrics:"``; ``labels`` holds that same
        prompt followed by the tokenized lyrics. Both are 1-D ``torch.long``
        tensors of length ``max_length``, truncated on the right and
        right-padded with the tokenizer's pad id. When the prompt alone
        fills ``max_length`` there is no room left for lyric tokens.

        Raises:
            ValueError: If the tokenizer has no pad token. Errors raised by
                ``ast.literal_eval`` for a malformed ``'instrumental'``
                string propagate unchanged.
        """
        item = self.data.iloc[idx]
        raw_instrumental = item['instrumental']
        lyrics_data = item['lyrics']

        # The column stores a list as text. ast.literal_eval parses Python
        # literals only, so dataset content can no longer execute code the
        # way eval() allowed. Values that already are sequences pass through.
        if isinstance(raw_instrumental, str):
            instrumental_data = ast.literal_eval(raw_instrumental)
        else:
            instrumental_data = list(raw_instrumental)

        # Convert accompaniment tokens to vocabulary indices.
        # NOTE(review): unknown tokens map to '0', which is also the index of
        # the first vocabulary line -- confirm this collision is intended.
        instrumental_tokens = [str(self.vocab.get(token, 0)) for token in instrumental_data]
        instrumental_str = ' '.join(instrumental_tokens)

        input_text = f"instrumental: {instrumental_str} lyrics:"
        target_text = lyrics_data

        prompt_ids = self._encode(input_text)
        lyric_ids = self._encode(target_text)

        # Concatenate BEFORE padding. Padding the prompt to max_length first
        # fills the whole window, so cutting the concatenation back to
        # max_length would drop every lyric token from the labels.
        input_ids = self._to_fixed_length(prompt_ids)
        labels = self._to_fixed_length(prompt_ids + lyric_ids)

        return input_ids, labels

# Load the data and split it into training and validation sets

# Tokenizer: the end-of-text token is reused as the padding token.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# Accompaniment vocabulary and the chord/lyrics table.
vocab_file_path = '/content/drive/MyDrive/vocab_instrumental.vocab'  # adjust the vocabulary path as needed
df = pd.read_parquet('/content/drive/MyDrive/dataset_chords.parquet')  # adjust the dataset path as needed
instrumental_vocab = load_instrumental_vocab(vocab_file_path)

# 90 % of the rows go to training; the remainder is held out for validation.
total_rows = len(df)
train_size = int(0.9 * total_rows)
val_size = total_rows - train_size

dataset = MidiDataset(tokenizer, df, instrumental_vocab)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Only the training loader shuffles; validation keeps the split order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=8)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=8)

# Print one decoded input/target pair as a sanity check
first_batch = next(iter(train_loader))
batch_inputs, batch_targets = first_batch
source_example = batch_inputs[0]
target_example = batch_targets[0]

# Decode both tensors the same way, dropping special tokens.
source_text, target_text = (
    tokenizer.decode(token_ids, skip_special_tokens=True)
    for token_ids in (source_example, target_example)
)

for caption, decoded in (("입력 예시:", source_text), ("타겟 예시:", target_text)):
    print(caption, decoded)

# Model setup: use the GPU when one is available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Pretrained GPT-2 language model, moved to the selected device.
model = GPT2LMHeadModel.from_pretrained('gpt2')
model = model.to(device)

# All model parameters are optimized with a learning rate of 1e-4.
optimizer = AdamW(model.parameters(), lr=1e-4)

# Training
model.train()
pad_token_id = tokenizer.pad_token_id
for epoch in range(50):  # number of training epochs
    for _, labels in tqdm(train_loader, desc=f"Training Epoch {epoch}"):
        # GPT-2 is a causal LM and shifts `labels` internally, so the tokens
        # fed as input must be the very sequence being predicted. The dataset
        # builds `labels` as prompt + lyrics, so that tensor is used for both
        # input and target. Feeding the prompt-only tensor would pair
        # padding inputs with lyric targets.
        sequence = labels.to(device)

        # A pad that directly follows another pad is pure filler: hide it
        # from attention and from the loss (-100 is ignored by the loss).
        # The first pad after the content stays as a target; the pad token
        # is set to the end-of-text token, so the model still learns where
        # to stop.
        is_pad = sequence == pad_token_id
        filler = torch.zeros_like(is_pad)
        filler[:, 1:] = is_pad[:, 1:] & is_pad[:, :-1]
        attention_mask = (~filler).long()
        targets = sequence.masked_fill(filler, -100)

        optimizer.zero_grad()
        outputs = model(input_ids=sequence, attention_mask=attention_mask, labels=targets)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

# Validation and BLEU computation
model.eval()
pad_token_id = tokenizer.pad_token_id
max_new_tokens = 128  # generation budget per sample; tune to the expected lyric length

predictions = []
with torch.no_grad():
    for input_ids, _ in tqdm(val_loader, desc="Validating"):
        # Prompts arrive right-padded to the dataset's full length. Passing
        # them to generate() as-is leaves no room for new tokens and makes
        # the model continue from padding. Each prompt is therefore stripped
        # of padding and generated on its own.
        for padded_prompt in input_ids:
            prompt = padded_prompt[padded_prompt != pad_token_id].unsqueeze(0).to(device)
            outputs = model.generate(
                input_ids=prompt,
                attention_mask=torch.ones_like(prompt),
                max_new_tokens=max_new_tokens,
                num_beams=5,
                repetition_penalty=2.0,
                no_repeat_ngram_size=4,
                pad_token_id=pad_token_id,
            )
            # generate() returns prompt + continuation; keep only the newly
            # generated part as the predicted lyrics.
            continuation = outputs[0, prompt.size(1):]
            predictions.append(tokenizer.decode(continuation, skip_special_tokens=True))

# Reference lyrics come straight from the DataFrame. `val_dataset` is the
# Subset produced by random_split, and the validation loader does not
# shuffle, so `val_dataset.indices` lists the rows in prediction order.
lyrics_column = df['lyrics']
actuals = [str(lyrics_column.iloc[row]) for row in val_dataset.indices]

# sacreBLEU expects a list of reference streams, each aligned with the
# hypotheses; there is a single reference per prediction here.
references = [actuals]

# BLEU restricted to 1- and 2-grams
bleu_2_score = sacrebleu.metrics.BLEU(max_ngram_order=2).corpus_score(predictions, references).score

# BLEU restricted to 1- to 3-grams
bleu_3_score = sacrebleu.metrics.BLEU(max_ngram_order=3).corpus_score(predictions, references).score

print(f"BLEU-2 Score: {bleu_2_score}")
print(f"BLEU-3 Score: {bleu_3_score}")

# Save the results, one prediction/reference pair per row
with open('/content/drive/MyDrive/lyrics_predictions_gpt2.csv', 'w', newline='', encoding='utf-8') as csvfile:  # adjust the output path as needed
    writer = csv.writer(csvfile)
    writer.writerow(['Predicted Lyrics', 'Actual Lyrics'])
    writer.writerows(zip(predictions, actuals))

# Save the fine-tuned model and the tokenizer, each to its own directory
# (adjust both output paths as needed)
for artifact, output_dir in (
    (model, '/content/drive/MyDrive/saved_model_gpt2'),
    (tokenizer, '/content/drive/MyDrive/saved_tokenizer_gpt2'),
):
    artifact.save_pretrained(output_dir)
