In [8]:
!pip install transformers datasets evaluate rouge-score bert-score -q

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from datasets import load_dataset
from transformers import GPT2Tokenizer, AutoModelForCausalLM, AutoTokenizer, set_seed
from torch.cuda.amp import autocast, GradScaler

# Shared GPT-2 tokenizer. The EOS token is reused as the padding token so that
# padding="max_length" can be used below; padded positions are therefore only
# distinguishable from a real EOS via the attention mask.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

import json, os
import evaluate
import matplotlib.pyplot as plt
import random

# Global seed for reproducibility (transformers' helper; per its docs it seeds
# python `random`, numpy and torch).
set_seed(42)

# Prefer the GPU when one is visible to torch.
_has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _has_cuda else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [9]:
# Fixed tokenized length of every (prompt + answer) sequence, and DataLoader batch size.
MAX_LEN, BATCH_SIZE = 384, 8

# First 3000 preference pairs of the HH-RLHF train split.
dataset = load_dataset(path="Anthropic/hh-rlhf", split="train[:3000]")

def split_prompt_answer(text: str):
    """Split a Human/Assistant dialogue into (prompt, final assistant answer).

    Lines are stripped and blank lines are dropped. The prompt is every line before
    the LAST line starting with "Assistant:", followed by a bare "Assistant:" cue;
    the answer is the text after that marker plus any following lines.

    If no line starts with "Assistant:", the original ``text`` (unmodified) is
    returned together with an empty answer.
    """
    marker = "Assistant:"
    stripped = (raw.strip() for raw in text.strip().split("\n"))
    cleaned = [ln for ln in stripped if ln]

    # Scan backwards for the final assistant turn.
    last_turn = next(
        (pos for pos in range(len(cleaned) - 1, -1, -1) if cleaned[pos].startswith(marker)),
        None,
    )
    if last_turn is None:
        return text, ""

    prompt = "\n".join([*cleaned[:last_turn], marker])
    head = cleaned[last_turn][len(marker):].strip()
    answer = "\n".join([head, *cleaned[last_turn + 1:]]).strip()
    return prompt, answer

def format_dpo_example(example):
    """Add prompt / chosen / rejected text columns to one HH-RLHF record.

    The prompt is taken from the ``chosen`` conversation only. This assumes the
    ``rejected`` conversation shares the same history and differs only in its final
    Assistant turn, so the prompt parsed from ``rejected`` is discarded.

    The record is updated in place and returned (the ``datasets.map`` convention).
    """
    prompt, chosen_answer = split_prompt_answer(example["chosen"])
    _, rejected_answer = split_prompt_answer(example["rejected"])

    example["prompt_formatted"] = prompt
    example["chosen_formatted"] = chosen_answer
    example["rejected_formatted"] = rejected_answer
    return example

dataset = dataset.map(format_dpo_example)

# Explicit seed: without it the split depends on global RNG state, so re-running this
# cell (or any earlier randomness) silently produces a different train/val split.
dataset = dataset.train_test_split(test_size=0.1, seed=42)
train_data_raw = dataset["train"]
val_data_raw = dataset["test"]

class DPOPairDataset(Dataset):
    """Tokenized (prompt + chosen, prompt + rejected) pairs for preference training.

    Each item contains ``max_len``-long ``input_ids`` / ``attention_mask`` tensors for
    the chosen and the rejected continuation of the same prompt, plus the raw prompt
    and chosen-answer strings (unmodified, for evaluation).
    """

    def __init__(self, hf_dataset, tokenizer, max_len=MAX_LEN):
        self.dataset = hf_dataset
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def _as_continuation(answer):
        """Prefix a non-empty answer with one space so it tokenizes like a continuation.

        Prompts end with "Assistant:". With GPT-2's byte-level BPE, the model typically
        continues such a prompt with space-prefixed tokens (e.g. " Hello"), whereas the
        bare answer "Hello" tokenizes without the space. Without the prefix, training would
        score token sequences that differ from what is produced at generation time.
        """
        if answer and not answer[0].isspace():
            return " " + answer
        return answer

    def _encode(self, prompt, answer):
        """Tokenize prompt + answer as a text pair, padded and truncated to ``max_len``.

        Returns (input_ids, attention_mask), each of shape (max_len,).
        """
        enc = self.tokenizer(
            prompt, self._as_continuation(answer),
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
            return_tensors="pt",
        )
        return enc.input_ids.squeeze(0), enc.attention_mask.squeeze(0)

    def __getitem__(self, idx):
        ex = self.dataset[idx]
        prompt = ex["prompt_formatted"]

        chosen_ids, chosen_mask = self._encode(prompt, ex["chosen_formatted"])
        rejected_ids, rejected_mask = self._encode(prompt, ex["rejected_formatted"])

        return {
            "chosen_input_ids": chosen_ids,
            "chosen_attention_mask": chosen_mask,
            "rejected_input_ids": rejected_ids,
            "rejected_attention_mask": rejected_mask,
            "prompt": prompt,
            "chosen_answer": ex["chosen_formatted"],
        }
        
# Page-locked host memory only speeds up host->GPU copies, so enable it only with CUDA.
_pin_memory = torch.cuda.is_available()

train_dataset = DPOPairDataset(train_data_raw, tokenizer)
val_dataset = DPOPairDataset(val_data_raw, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=_pin_memory)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=_pin_memory)

print(f"Train size: {len(train_dataset)}, Validation size: {len(val_dataset)}")

Train size: 2700, Validation size: 300


In [10]:
# Policy model being fine-tuned.
model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
model.to(device)
# The tokenizer only reuses EOS as padding (no new tokens), so this resize should be
# a no-op; it keeps the embedding matrix in sync if tokens are ever added.
model.resize_token_embeddings(len(tokenizer))

optimizer = optim.AdamW(params=model.parameters(), lr=1e-5, weight_decay=0.01)


In [11]:
scaler = GradScaler()

def train_one_epoch(model, dataloader, optimizer, scaler, device, beta):
    model.train()
    total_loss, total_correct, total_samples = 0, 0, 0

    for batch in dataloader:
        optimizer.zero_grad()

        chosen_input_ids = batch['chosen_input_ids'].to(device)
        chosen_attention_mask = batch['chosen_attention_mask'].to(device)
        rejected_input_ids = batch['rejected_input_ids'].to(device)
        rejected_attention_mask = batch['rejected_attention_mask'].to(device)

        with autocast(dtype=torch.float16):
            
            chosen_outputs = model(input_ids=chosen_input_ids, attention_mask=chosen_attention_mask)
            logp_chosen = -F.cross_entropy(
                chosen_outputs.logits[:, :-1, :].reshape(-1, chosen_outputs.logits.size(-1)),
                chosen_input_ids[:, 1:].reshape(-1),
                reduction="none"
            ).reshape(chosen_outputs.logits.size(0), -1).sum(dim=1)

            rejected_outputs = model(input_ids=rejected_input_ids, attention_mask=rejected_attention_mask)
            logp_rejected = -F.cross_entropy(
                rejected_outputs.logits[:, :-1, :].reshape(-1, rejected_outputs.logits.size(-1)),
                rejected_input_ids[:, 1:].reshape(-1),
                reduction="none"
            ).reshape(rejected_outputs.logits.size(0), -1).sum(dim=1)

            # DPO loss
            loss = -torch.mean(F.logsigmoid(beta * (logp_chosen - logp_rejected)))

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Accuracy
        correct = (logp_chosen > logp_rejected).sum().item()
        total_correct += correct
        total_samples += chosen_input_ids.size(0)
        total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    avg_acc = total_correct / total_samples
    return avg_loss, avg_acc


def validate(model, dataloader, device, beta):
    model.eval()
    total_loss, total_correct, total_samples = 0, 0, 0

    with torch.no_grad():
        for batch in dataloader:
            chosen_input_ids = batch['chosen_input_ids'].to(device)
            chosen_attention_mask = batch['chosen_attention_mask'].to(device)
            rejected_input_ids = batch['rejected_input_ids'].to(device)
            rejected_attention_mask = batch['rejected_attention_mask'].to(device)

            chosen_outputs = model(input_ids=chosen_input_ids, attention_mask=chosen_attention_mask)
            logp_chosen = -F.cross_entropy(
                chosen_outputs.logits[:, :-1, :].reshape(-1, chosen_outputs.logits.size(-1)),
                chosen_input_ids[:, 1:].reshape(-1),
                reduction="none"
            ).reshape(chosen_outputs.logits.size(0), -1).sum(dim=1)

            rejected_outputs = model(input_ids=rejected_input_ids, attention_mask=rejected_attention_mask)
            logp_rejected = -F.cross_entropy(
                rejected_outputs.logits[:, :-1, :].reshape(-1, rejected_outputs.logits.size(-1)),
                rejected_input_ids[:, 1:].reshape(-1),
                reduction="none"
            ).reshape(rejected_outputs.logits.size(0), -1).sum(dim=1)

            loss = -torch.mean(F.logsigmoid(beta * (logp_chosen - logp_rejected)))
            correct = (logp_chosen > logp_rejected).sum().item()

            total_loss += loss.item()
            total_correct += correct
            total_samples += chosen_input_ids.size(0)

    avg_loss = total_loss / len(dataloader)
    avg_acc = total_correct / total_samples
    return avg_loss, avg_acc

  scaler = GradScaler()


In [None]:
import logging
# Only show ERROR-level messages from transformers' tokenization_utils_base logger.
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)

# Training hyperparameters.
epochs = 12
patience = 2  # epochs without val-loss improvement before stopping early
beta = 0.2    # scale applied to the chosen-vs-rejected margin inside logsigmoid

# Where the best model/tokenizer of this run are saved and reloaded from.
BEST_CHECKPOINT_DIR = "results/dpo_finetuned_gpt2"

patience_counter = 0
best_epoch = 0  # stays 0 if no epoch ever improves on best_val_loss

train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

best_val_loss = float("inf")

# NOTE: re-running this cell continues training the current `model` (the best
# checkpoint reloaded at the end of a previous run), not a fresh GPT-2.
for epoch in range(epochs):

    print(f"Epoch {epoch+1}/{epochs} ")

    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, scaler, device, beta)
    val_loss, val_acc = validate(model, val_loader, device, beta)

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accuracies.append(train_acc)
    val_accuracies.append(val_acc)

    print(
        f"Epoch {epoch+1}/{epochs} | "
        f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
        f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}"
    )

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch + 1
        patience_counter = 0
        os.makedirs("results", exist_ok=True)
        model.save_pretrained(BEST_CHECKPOINT_DIR)
        tokenizer.save_pretrained(BEST_CHECKPOINT_DIR)
        print(f"New best model saved (val_loss={val_loss:.4f}, val_acc={val_acc:.4f})")
    else:
        patience_counter += 1
        print(f"No improvement ({patience_counter}/{patience})")

    if patience_counter >= patience:
        print("!!! Early stopping triggered !!!")
        break

# A NaN validation loss never compares below +inf, and epochs == 0 runs nothing. In
# either case this run saved no checkpoint, and loading would silently pick up a stale
# model from an earlier run.
if best_epoch == 0:
    raise RuntimeError(
        f"No checkpoint was saved during this run; refusing to load {BEST_CHECKPOINT_DIR!r}"
    )

print(f"\n Loading best model from epoch {best_epoch} (val_loss={best_val_loss:.4f})")
model = AutoModelForCausalLM.from_pretrained(BEST_CHECKPOINT_DIR).to(device)
tokenizer = GPT2Tokenizer.from_pretrained(BEST_CHECKPOINT_DIR)
tokenizer.pad_token = tokenizer.eos_token

# `model` is a new object now; the old optimizer still references the discarded
# model's parameters, so re-running this cell would otherwise update nothing.
optimizer = optim.AdamW(
    model.parameters(),
    lr=optimizer.defaults["lr"],
    weight_decay=optimizer.defaults["weight_decay"],
)

In [None]:
def _plot_metric(ax, train_values, val_values, metric):
    """Draw the per-epoch train and validation curves of one metric on ``ax``."""
    ax.plot(range(1, len(train_values) + 1), train_values, label=f"Train {metric}")
    ax.plot(range(1, len(val_values) + 1), val_values, label=f"Validation {metric}")
    ax.set_xlabel("Epoch")
    ax.set_ylabel(metric)
    ax.set_title(f"DPO Training {metric}")
    ax.legend()
    ax.grid(True)


# Left panel: loss curves; right panel: pairwise accuracy curves.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
_plot_metric(axes[0], train_losses, val_losses, "Loss")
_plot_metric(axes[1], train_accuracies, val_accuracies, "Accuracy")

plt.tight_layout()
plt.show()

In [None]:
# Metric used below to compare generated answers against the reference answers.
bertscore = evaluate.load("bertscore")
# Same device choice as the setup cell, repeated so this cell is self-contained.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def generate_response(model, tokenizer, prompt, max_new_tokens=80):
    """Sample an Assistant reply from ``model`` for a Human/Assistant dialogue prompt.

    The prompt is stripped and, if it does not already end with "Assistant:", a new
    "\\nAssistant:" turn is appended. Generation uses nucleus sampling (top_p=0.9,
    temperature=1.0) and produces at least 8 new tokens.

    Args:
        model: Hugging Face causal LM (uses ``.generate``, ``.device``, ``.config``).
        tokenizer: matching tokenizer; its EOS id is also used as the pad id.
        prompt: dialogue text.
        max_new_tokens: upper bound on the number of generated tokens.

    Returns:
        Only the generated continuation (prompt excluded), whitespace-stripped.
    """
    prompt = prompt.strip()
    if not prompt.endswith("Assistant:"):
        prompt += "\nAssistant:"

    encoded = tokenizer(prompt, return_tensors="pt")
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]

    # Prompt + new tokens must fit the model's position embeddings (1024 for GPT-2);
    # overflowing them makes generation fail. Drop the oldest tokens so the final
    # "Assistant:" cue is kept.
    context_len = getattr(getattr(model, "config", None), "max_position_embeddings", None)
    if context_len is not None:
        max_prompt_len = max(context_len - max_new_tokens, 1)
        input_ids = input_ids[:, -max_prompt_len:]
        attention_mask = attention_mask[:, -max_prompt_len:]

    input_ids = input_ids.to(model.device)
    attention_mask = attention_mask.to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            min_new_tokens=8,
            do_sample=True,
            top_p=0.9,
            temperature=1.0,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated ids. Slicing the decoded text by len(prompt)
    # breaks whenever decode() does not reproduce the prompt verbatim (e.g. space
    # clean-up), which leaked the whole prompt into the answer.
    new_tokens = outputs[0, input_ids.shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

# Baseline: the pretrained GPT-2 without fine-tuning.
base_model_name = "gpt2"
base_tokenizer = AutoTokenizer.from_pretrained(base_model_name)
base_model = AutoModelForCausalLM.from_pretrained(base_model_name).to(device)
base_model.eval()
# train_one_epoch switches the model to train mode; make sure the tuned model is in
# eval mode here even if the reload in the training cell was skipped.
model.eval()

# Re-seed so the sampled validation subset and the sampled generations are the same
# every time this cell runs, independent of RNG state consumed by earlier cells.
set_seed(42)

# Evaluate on a small random subset of the validation split.
val_samples = random.sample(list(val_data_raw), min(10, len(val_data_raw)))

references, base_outputs, tuned_outputs = [], [], []

for idx, sample in enumerate(val_samples):
    prompt = sample['prompt_formatted']
    ref_answer = sample['chosen_formatted']

    tuned_out = generate_response(model, tokenizer, prompt)
    base_out = generate_response(base_model, base_tokenizer, prompt)

    print(f"\n=== Prompt {idx+1} ===\n{prompt}\n{'-'*40}")
    print(f"Reference:\n{ref_answer}\n{'-'*40}")
    print(f"Base model:\n{base_out}\n{'-'*40}")
    print(f"Tuned model:\n{tuned_out}\n{'='*40}")

    # One-element list per sample; later cells read the reference as r[0].
    references.append([ref_answer])
    base_outputs.append(base_out)
    tuned_outputs.append(tuned_out)

# BERTScore F1 against the reference (chosen) answers, for BOTH models, so the
# fine-tuned model is actually compared with the baseline it was generated alongside.
reference_texts = [r[0] for r in references]
bert_results = bertscore.compute(
    predictions=tuned_outputs,
    references=reference_texts,
    lang="en"
)
base_bert_results = bertscore.compute(
    predictions=base_outputs,
    references=reference_texts,
    lang="en"
)
mean_f1 = sum(bert_results["f1"]) / len(bert_results["f1"])
base_mean_f1 = sum(base_bert_results["f1"]) / len(base_bert_results["f1"])
print(f"\nTuned BERTScore (mean F1): {mean_f1:.4f}")
print(f"Base BERTScore (mean F1): {base_mean_f1:.4f}")

# One record per evaluated prompt: prompt, reference answer and both models' outputs.
results = [
    {
        "prompt": sample['prompt_formatted'],
        "reference": ref[0],
        "base_output": base_out,
        "tuned_output": tuned_out,
    }
    for sample, ref, base_out, tuned_out in zip(val_samples, references, base_outputs, tuned_outputs)
]

# In this notebook "results/" is otherwise only created by the training cell, so make
# sure it exists before writing.
os.makedirs("results", exist_ok=True)
with open("results/eval_results.json", "w", encoding="utf-8") as f:
    json.dump(results, f, ensure_ascii=False, indent=2)