### Шаг 1: Подготовка данных


In [1]:
from datasets import load_dataset
import pandas as pd

# Загрузка и подготовка данных
dataset = load_dataset('imdb')
train_data = dataset['train']
test_data = dataset['test']

# Составьте все возможные пары (positive, negative) из train подвыборки
train_df = pd.DataFrame(train_data)
positive_comments = train_df[train_df['label'] == 1]['text'].tolist()
negative_comments = train_df[train_df['label'] == 0]['text'].tolist()
pairs = [(pos, neg) for pos in positive_comments[:200] for neg in negative_comments[:200]]

# Обрезка текстов для датасета промптов
def truncate_text(text, max_tokens=20):
    tokens = text.split()
    return ' '.join(tokens[:max_tokens])

# Создание датасета промптов из тренировочной выборки
prompts_train = [truncate_text(text, max_tokens=15) for text in train_df['text'].tolist()]
prompts_train = prompts_train[:10000]

# Создание датасета промптов из тестовой выборки
prompts_test = [truncate_text(text, max_tokens=15) for text in test_data['text']]
prompts_test = prompts_test[:100]


Downloading readme:   0%|          | 0.00/7.81k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/21.0M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/42.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

### Шаг 2: Реализация модели наград

In [11]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments
from datasets import Dataset
from trl import RewardTrainer

# Инициализация модели и токенизатора
model_name = 'distilbert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
reward_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)

# Подготовка данных для модели наград
def preprocess_reward_data(pairs, tokenizer, max_length=128):
    chosen_texts = [pair[0] for pair in pairs]
    rejected_texts = [pair[1] for pair in pairs]

    chosen_encodings = tokenizer(chosen_texts, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')
    rejected_encodings = tokenizer(rejected_texts, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')

    dataset_dict = {
        'input_ids_chosen': chosen_encodings['input_ids'],
        'attention_mask_chosen': chosen_encodings['attention_mask'],
        'input_ids_rejected': rejected_encodings['input_ids'],
        'attention_mask_rejected': rejected_encodings['attention_mask']
    }

    return Dataset.from_dict(dataset_dict)

pairs_dataset = preprocess_reward_data(pairs, tokenizer)
print(pairs_dataset[:2])  # Вывод первых двух элементов для проверки

# Обучение модели наград
training_args = TrainingArguments(
    output_dir='./results',
    per_device_train_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir='./logs',
    save_total_limit=1,
    fp16=True
)

trainer = RewardTrainer(
    model=reward_model,
    tokenizer=tokenizer,
    train_dataset=pairs_dataset,
    args=training_args
)

trainer.train()


ModuleNotFoundError: No module named 'trl'

In [6]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments
# Сохраняем обученную reward model
# reward_model.save_pretrained("trained_reward_model")
# tokenizer.save_pretrained("trained_reward_model")

# Загружаем reward model
reward_model = AutoModelForSequenceClassification.from_pretrained("trained_reward_model", num_labels=1)

# from transformers import DistilBertTokenizer, DistilBertModel
# tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-cased')
# reward_model = DistilBertModel.from_pretrained("distilbert-base-cased")

### Шаг 3: Имплементация метода WARP

In [9]:
import torch
from tqdm import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer, AdamW, AutoTokenizer, AutoModelForCausalLM
from collections import defaultdict
import os
from adabelief_pytorch import AdaBelief
import math
import numpy as np

# Параметры WARP
I = 2
M = 2
T = 100
mu = 0.01
eta = 0.3
batch_size = 128
lambda_ = 0.5
beta = 0.1 

# Инициализация модели GPT-2
tokenizer = AutoTokenizer.from_pretrained("lvwerra/gpt2-imdb")
model = GPT2LMHeadModel.from_pretrained("lvwerra/gpt2-imdb")

# Pad token
tokenizer.pad_token = tokenizer.eos_token

# Функция генерации завершений предложений
def generate_completion(model, tokenizer, prompt, max_new_tokens=50):
    inputs = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True)
    outputs = model.generate(inputs['input_ids'], max_new_tokens=max_new_tokens, num_return_sequences=1, attention_mask=inputs['attention_mask'], pad_token_id=tokenizer.eos_token_id, temperature=0.9)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Функция для валидирования входных данных и замены OOV токенов
def validate_and_replace_oov(input_ids, vocab_size, pad_token_id):
    if torch.any(input_ids >= vocab_size):
        print(f"Found OOV token before replacement: {input_ids}")
        input_ids = torch.where(input_ids >= vocab_size, torch.full_like(input_ids, pad_token_id), input_ids)
    if torch.any(input_ids >= vocab_size):
        print(f"Error: Still OOV tokens present after replacement: {input_ids}")
    return input_ids

# Функция для подготовки входных данных и проверки того, что они в диапазоне валидных токенов
def preprocess_inputs(prompt, completion, tokenizer, vocab_size, pad_token_id):
    reward_input = tokenizer(prompt + completion, return_tensors='pt', padding=True, truncation=True)
    reward_input_ids = reward_input['input_ids']
    
    # Валидирование и замена OOV токенов
    reward_input_ids = validate_and_replace_oov(reward_input_ids, vocab_size, pad_token_id)
    
    # Проверка на то, что входные данные входят в диапазон валидных токенов
    if torch.any(reward_input_ids >= tokenizer.vocab_size):
        raise ValueError(f"final check Failed Removing OOV tokens: {reward_input_ids}")
    
    return reward_input['input_ids'], reward_input['attention_mask']

# Функции по поиску повторяющихся токенов
def find_repeated_ngrams(tokens, n):
    ngram_counts = defaultdict(int)
    repeated_ngrams = set()

    for i in range(len(tokens) - n + 1):
        ngram = tuple(tokens[i:i + n])
        ngram_counts[ngram] += 1
        if ngram_counts[ngram] > 1:
            repeated_ngrams.add(ngram)

    return repeated_ngrams

def clean_repetitive_tokens(text, max_repeats=2, ngram_min=2, ngram_max=5):
    tokens = text.split()
    
    # 1: Limit n-gram repetitions
    for n in range(ngram_min, ngram_max + 1):
        repeated_ngrams = find_repeated_ngrams(tokens, n)
        for ngram in repeated_ngrams:
            count = 0
            new_tokens = []
            i = 0
            while i < len(tokens):
                current_ngram = tuple(tokens[i:i + n])
                if current_ngram == ngram:
                    count += 1
                    if count <= max_repeats:
                        new_tokens.extend(current_ngram)
                    # Skip the repeated ngram
                    i += n
                else:
                    new_tokens.append(tokens[i])
                    i += 1
            tokens = new_tokens

    # 2: Limit single token repetitions
    clean_tokens = []
    prev_token = ''
    repeat_count = 0

    for token in tokens:
        if token == prev_token:
            repeat_count += 1
        else:
            repeat_count = 0

        if repeat_count < max_repeats:
            clean_tokens.append(token)
        prev_token = token

    return ' '.join(clean_tokens)

# SLERP interpolation
# def slerp(theta_init, thetas, lam):
#     theta_slerp = {}
#     for name in theta_init:
#         # Подсчёт deltas
#         deltas = [theta[name] - theta_init[name] for theta in thetas]

#         # Подсчёт omega
#         norm_delta1 = torch.norm(deltas[0])
#         norm_delta2 = torch.norm(deltas[1])
#         cos_omega = (deltas[0] * deltas[1]).sum() / (norm_delta1 * norm_delta2)  # Элементное умножение и сумма всех элементов
#         omega = torch.acos(cos_omega).item()

#         # Подсчёт SLERP interpolation
#         sin_omega = math.sin(omega)
#         theta_slerp[name] = theta_init[name] + (
#             (math.sin((1 - lam) * omega) / sin_omega) * deltas[0] +
#             (math.sin(lam * omega) / sin_omega) * deltas[1]
#         )
    
#     return theta_slerp

def slerp(val, low, high):
    omega = np.arccos(np.clip(np.dot(low / np.linalg.norm(low), high / np.linalg.norm(high)), -1.0, 1.0))
    so = np.sin(omega)
    if so == 0:
        return (1.0 - val) * low + val * high  # Fallback for very small angles
    return (np.sin((1.0 - val) * omega) / so) * low + (np.sin(val * omega) / so) * high

def slerp_tensors(val, low, high):
    assert low.shape == high.shape, "Input tensors must have the same shape for slerp"
    low_np, high_np = low.cpu().numpy(), high.cpu().numpy()
    result_np = np.zeros_like(low_np, dtype=low_np.dtype)
    if len(low_np.shape) == 1:
        result_np = slerp(val, low_np, high_np)
    else:
        for i in range(low_np.shape[0]):
            result_np[i] = slerp(val, low_np[i], high_np[i])
    result_tensor = torch.from_numpy(result_np).to(low.device)
    return result_tensor

def weighted_average(weights_list, weight_decay):
    """
    Выполняет взвешенное среднее по экземплярам состояний моделей.
    """
    avg_state_dict = weights_list[0].copy()
    for key in avg_state_dict.keys():
        avg_state_dict[key] = torch.stack([weights[key] for weights in weights_list], dim=0).mean(dim=0)
    return {key: weight_decay * avg_state_dict[key] + (1 - weight_decay) * torch.stack([weights[key] for weights in weights_list], dim=0).mean(dim=0) for key in avg_state_dict.keys()}

# Обучение WARP
def wrap(model, reward_model, tokenizer, prompts_train, I, M, T, beta, mu, eta, batch_size):
    theta_init = model.state_dict()
    for i in range(I):
        theta_m_list = []
        for m in range(M):
            model.load_state_dict(theta_init)
            ema_state_dict = {k: v.clone() for k, v in theta_init.items()}
            
            optimizer = AdaBelief(model.parameters(), lr=1e-6, print_change_log = False)
            
            for t in tqdm(range(T), desc=f"Iteration {i+1}, Run {m+1}"):
                batch_prompts = np.random.choice(prompts_train, batch_size, replace=False)
                
                # Генерация продолжений
                completions = [generate_completion(model, tokenizer, prompt) for prompt in batch_prompts]
                
                # Подсчёт наград
                rewards = []
                for prompt, completion in zip(batch_prompts, completions):
                    try:
                        completion = clean_repetitive_tokens(completion)
                        reward_input_ids, attention_mask = preprocess_inputs(prompt, completion, tokenizer, tokenizer.vocab_size, tokenizer.pad_token_id)
                        # print("0")
                        reward_output = reward_model(reward_input_ids, attention_mask=attention_mask)
                        # print("123")
                        if isinstance(reward_output, dict):
                            reward = reward_output['logits'][0].item()
                        else:
                            reward = reward_output.logits[0].item()

                        # print("456")
                        with torch.no_grad():
                            model_outputs = model(reward_input_ids, attention_mask=attention_mask).logits
                            ema_model_outputs = model(reward_input_ids, attention_mask=attention_mask).logits
                            kl_div = torch.nn.functional.kl_div(
                                torch.log_softmax(model_outputs, dim=-1),
                                torch.softmax(ema_model_outputs, dim=-1),
                                reduction='batchmean'
                            )
                        rewards.append(reward - beta * kl_div.item())
                        # print("789")
                    except Exception as e:
                        # print(f"Error during reward computation: {e}")
                        continue
                        
                # Обновление параметров модели
                if rewards:
                    rewards_tensor = torch.tensor(rewards, dtype=torch.float32, requires_grad=True)
                    loss = -torch.mean(rewards_tensor)

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    
                    # Обновление EMA weights
                    with torch.no_grad():
                        state_dict = model.state_dict()
                        for name, param in state_dict.items():
                            ema_state_dict[name].data.copy_((1 - mu) * ema_state_dict[name].data + mu * param)
            
            theta_m = model.state_dict()
            theta_m_list.append(theta_m)

        # SLERP Weight Interpolation
        theta_slerp_i = theta_m_list[0].copy()
        for param_name in theta_slerp_i:
            for j in range(1, len(theta_m_list)):
                theta_slerp_i[param_name] = slerp_tensors(1 / (j + 1), theta_slerp_i[param_name], theta_m_list[j][param_name])

        # LITI Interpolation with initial weights
        for name in theta_init:
            theta_init[name] = (1 - eta) * theta_init[name] + eta * theta_slerp_i[name]
    
    # Сохранение Pareto front weights
    pareto_front = []
    for eta in range(0, 5):
        eta_val = eta / 4.0
        model_weights = {name: (1 - eta_val) * theta_init[name] + eta_val * theta_slerp_i[name] for name in theta_init}
        pareto_front.append(model_weights)
        
    save_dir = "./pareto_models"
    os.makedirs(save_dir, exist_ok=True)
    for i, model_weights in enumerate(pareto_front):
        model_path = os.path.join(save_dir, f"model_weights_{i}.pth")
        torch.save(model_weights, model_path)
        print(f"Saved model: {model_path}")
    
    return pareto_front

# Вывод модели для тестирования
pareto_front = wrap(model, reward_model, tokenizer, prompts_train, I, M, T, beta, mu, eta, batch_size)


Weight decoupling enabled in AdaBelief
Rectification enabled in AdaBelief


Iteration 1, Run 1: 100%|██████████| 100/100 [3:06:01<00:00, 111.62s/it] 


Weight decoupling enabled in AdaBelief
Rectification enabled in AdaBelief


Iteration 1, Run 2: 100%|██████████| 100/100 [3:05:43<00:00, 111.43s/it] 


Weight decoupling enabled in AdaBelief
Rectification enabled in AdaBelief


Iteration 2, Run 1: 100%|██████████| 100/100 [3:07:34<00:00, 112.54s/it] 


Weight decoupling enabled in AdaBelief
Rectification enabled in AdaBelief


Iteration 2, Run 2: 100%|██████████| 100/100 [3:04:49<00:00, 110.90s/it] 


Saved model: ./pareto_models/model_weights_0.pth
Saved model: ./pareto_models/model_weights_1.pth
Saved model: ./pareto_models/model_weights_2.pth
Saved model: ./pareto_models/model_weights_3.pth
Saved model: ./pareto_models/model_weights_4.pth


### Шаг 4: Оценка модели

In [None]:
from collections import defaultdict
import torch.nn.functional as F

# Функция генерации завершений предложений
def generate_completion(model, tokenizer, prompt, max_length=50):
    inputs = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True, max_length=max_length)
    outputs = model.generate(inputs['input_ids'], max_length=max_length, num_return_sequences=1, attention_mask=inputs['attention_mask'], pad_token_id=tokenizer.eos_token_id)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Функция для валидирования входных данных и замены OOV токенов
def validate_and_replace_oov(input_ids, vocab_size, pad_token_id):
    if torch.any(input_ids >= vocab_size):
        print(f"Found OOV token before replacement: {input_ids}")
        input_ids = torch.where(input_ids >= vocab_size, torch.full_like(input_ids, pad_token_id), input_ids)
    if torch.any(input_ids >= vocab_size):
        print(f"Error: Still OOV tokens present after replacement: {input_ids}")
    return input_ids

# Функция для подготовки входных данных и проверки того, что они в диапазоне валидных токенов
def preprocess_inputs(prompt, completion, tokenizer, vocab_size, pad_token_id):
    reward_input = tokenizer(prompt + completion, return_tensors='pt', padding=True, truncation=True)
    reward_input_ids = reward_input['input_ids']
    
    # Валидирование и замена OOV токенов
    reward_input_ids = validate_and_replace_oov(reward_input_ids, vocab_size, pad_token_id)
    
    # Проверка на то, что входные данные входят в диапазон валидных токенов
    if torch.any(reward_input_ids >= tokenizer.vocab_size):
        raise ValueError(f"final check Failed Removing OOV tokens: {reward_input_ids}")
    
    return reward_input['input_ids'], reward_input['attention_mask']

# Функции по поиску повторяющихся токенов
def find_repeated_ngrams(tokens, n):
    ngram_counts = defaultdict(int)
    repeated_ngrams = set()

    for i in range(len(tokens) - n + 1):
        ngram = tuple(tokens[i:i + n])
        ngram_counts[ngram] += 1
        if ngram_counts[ngram] > 1:
            repeated_ngrams.add(ngram)

    return repeated_ngrams

def clean_repetitive_tokens(text, max_repeats=2, ngram_min=2, ngram_max=5):
    tokens = text.split()
    
    # 1: Limit n-gram repetitions
    for n in range(ngram_min, ngram_max + 1):
        repeated_ngrams = find_repeated_ngrams(tokens, n)
        for ngram in repeated_ngrams:
            count = 0
            new_tokens = []
            i = 0
            while i < len(tokens):
                current_ngram = tuple(tokens[i:i + n])
                if current_ngram == ngram:
                    count += 1
                    if count <= max_repeats:
                        new_tokens.extend(current_ngram)
                    # Skip the repeated ngram
                    i += n
                else:
                    new_tokens.append(tokens[i])
                    i += 1
            tokens = new_tokens

    # 2: Limit single token repetitions
    clean_tokens = []
    prev_token = ''
    repeat_count = 0

    for token in tokens:
        if token == prev_token:
            repeat_count += 1
        else:
            repeat_count = 0

        if repeat_count < max_repeats:
            clean_tokens.append(token)
        prev_token = token

    return ' '.join(clean_tokens)

# def generate_completion(model, tokenizer, prompt, pad_token_id=50256, max_length=50):
#     inputs = tokenizer(prompt, return_tensors="pt")
#     outputs = model.generate(
#         **inputs,
#         pad_token_id=pad_token_id,
#         max_length=max_length,
#         num_return_sequences=1,
#         no_repeat_ngram_size=3,  # Предотвращение повторений
#         repetition_penalty=2.5,  # Штраф за повторение
#         eos_token_id=pad_token_id,
#         temperature=0.7,         # Пониженная температура
#         top_k=50,                # Обрезка по top_k
#         top_p=0.95               # Метод nucleus sampling
#     )
#     return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Средняя награда и KL дивергенция для обученной модели
def evaluate_model(model, prompts_test, reward_model, tokenizer):
    rewards = []
    kl_divs = []
    
    for index, prompt in enumerate(prompts_test):
        print(f"Processing prompt {index + 1}/{len(prompts_test)}")
        try:
            completion = generate_completion(model, tokenizer, prompt)
            # completion = clean_repetitive_tokens(completion)
            print(f"Prompt: {prompt}")
            print(f"Completion: {completion}")

            reward_input_ids, attention_mask = preprocess_inputs(prompt, completion, tokenizer, tokenizer.vocab_size, tokenizer.pad_token_id)
            reward = reward_model(reward_input_ids, attention_mask=attention_mask).logits[0].item()
            kl_div = F.kl_div(
                F.log_softmax(model(reward_input_ids, attention_mask=attention_mask).logits, dim=-1),
                F.softmax(sft_model(reward_input_ids, attention_mask=attention_mask).logits, dim=-1), reduction='batchmean'
            )
            rewards.append(reward)
            kl_divs.append(kl_div.item())
        except IndexError as e:
            continue
            print(f"IndexError encountered for prompt {index + 1}: {e}")
    
    if len(rewards) > 0:
        avg_reward = sum(rewards) / len(rewards)
    else:
        avg_reward = 0.0  # или любое другое значение по умолчанию, которое имеет смысл в вашем контексте

    if len(kl_divs) > 0:
        avg_kl_div = sum(kl_divs) / len(kl_divs)
    else:
        avg_kl_div = 0.0
    print(f'Average Reward: {avg_reward}, Average KL Divergence: {avg_kl_div}')
    return avg_reward, avg_kl_div


# Средняя награда и KL дивергенция для обученной модели
model_name = "lvwerra/gpt2-imdb"
sft_model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained("lvwerra/gpt2-imdb")
tokenizer.pad_token = tokenizer.eos_token

# Оценка средних наград и KL дивергенции для нескольких моделей из Pareto фронта
model_paths = ['./pareto_models/model_weights_0.pth', 
               './pareto_models/model_weights_1.pth', 
               './pareto_models/model_weights_2.pth', 
               './pareto_models/model_weights_3.pth', 
               './pareto_models/model_weights_4.pth']
results = []

for model_path in model_paths:
    pareto_model = GPT2LMHeadModel.from_pretrained(model_name)
    state_dict = torch.load(model_path)
    pareto_model.load_state_dict(state_dict)
    avg_reward, avg_kl_div = evaluate_model(pareto_model, prompts_test, reward_model, tokenizer)
    results.append((model_path, avg_reward, avg_kl_div))

# Оценка модели SFT
avg_reward_sft, avg_kl_div_sft = evaluate_model(sft_model, prompts_test, reward_model, tokenizer)
results.append(("SFT Model", avg_reward_sft, avg_kl_div_sft))

# Оценка модели
avg_reward_model, avg_kl_div_model = evaluate_model(model, prompts_test, reward_model, tokenizer)
results.append(("Model", avg_reward_model, avg_kl_div_model))

for result in results:
    print(f"Model: {result[0]}, Avg Reward: {result[1]}, Avg KL Divergence: {result[2]}")



  state_dict = torch.load(model_path)


Processing prompt 1/100
Prompt: I love sci-fi and am willing to put up with a lot. Sci-fi movies/TV are
Completion: I love sci-fi and am willing to put up with a lot. Sci-fi movies/TV are not for everyone, but I do enjoy them. I love the fact that the characters are real people, and that they are not just a
Processing prompt 2/100
Prompt: Worth the entertainment value of a rental, especially if you like action movies. This one
Completion: Worth the entertainment value of a rental, especially if you like action movies. This one is a must see.
Processing prompt 3/100
Prompt: its a totally average film with a few semi-alright action sequences that make the plot
Completion: its a totally average film with a few semi-alright action sequences that make the plot seem like it's going nowhere. The acting is average, the script is bad, and the acting is bad. The only thing that makes this film worth watching
Processing prompt 4/100
Prompt: STAR RATING: ***** Saturday Night **** Friday Night *** 

In [None]:
import matplotlib.pyplot as plt

# Гипотетический пример значений для I, M и T
I_values = [2, 3, 4]
M_values = [5, 10, 15]
T_values = [50, 100, 150]

# Средние награды и KL дивергенции, полученные из экспериментов
avg_rewards = [result[1] for result in results[:-1]]
avg_kl_divs = [result[2] for result in results[:-1]]

plt.figure(figsize=(10, 5))
plt.plot(I_values, avg_rewards, label="Avg Reward", marker='o')
plt.plot(I_values, avg_kl_divs, label="Avg KL Divergence", marker='x')
plt.xlabel("I Values")
plt.ylabel("Metrics")
plt.legend()
plt.title("Impact of I on Average Reward and KL Divergence")
plt.show()