# Домашнее задание 2: DPO и PPO

В этой домашке поближе познакомимся с двумя крайне популярными методами алаймента языковых моделей. В первой части вам предоставится возможность самостоятельно заимплементить DPO c нуля. Во второй части мы уже будем использовать библиотеку TRL и обучим PPO.

Обученные модели можно и нужно выложить на [🤗 HuggingFace](https://huggingface.co/). Зарегистрируйтесь там, подпишитесь на [deep vk](https://huggingface.co/deepvk) и создайте себе API токен.

Следуйте ячейкам тетрадки и заполняйте пропущенные ячейки. В конце тетрадки вы найдете задачи со звездочкой, чтобы получить максимальный балл!

## Импорты и вспомогательные функции

In [2]:
# Необходимые импорты (для обоих частей)
import inspect
import random
from functools import partial

import numpy as np
import torch
import torch.nn.functional as F
import wandb
from datasets import load_dataset
from huggingface_hub import HfApi, interpreter_login
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    PreTrainedTokenizerBase,
)
from trl import PPOConfig, PPOTrainer, RewardConfig, RewardTrainer

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
interpreter_login()

In [3]:
# Подготовим репозиторий для будущей модели и токенизатора
username = HfApi().whoami()["name"]
REPO_NAME = f"{username}/SmolLM-aligment"  # Или как вам хочется

print(f"Homework repository: '{REPO_NAME}'")

Homework repository: 'Azrail/SmolLM-aligment'


In [1]:
def set_seed(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)


# Этой функцией будут помечены все места, которые необходимо дозаполнить
# Это могут быть как целые функции, так и отдельные части внутри них
# Всегда можно воспользоваться интроспекцией и найти места использования этой функции :)
def todo():
    stack = inspect.stack()
    caller_frame = stack[1]
    function_name = caller_frame.function
    line_number = caller_frame.lineno
    raise NotImplementedError(f"TODO at {function_name}, line {line_number}")


def disable_dropout_in_model(model):
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0

# Часть 1: DPO

Крайне простой метод, который в свое время произвел фурор, т.к. выгодно выделялся на фоне PPO. В отличие от PPO, требующего отдельно обучать Reward Model, Value Model и больших усилий в имплементации, DPO не требует явной ревард модели, а только датасета с человеческими преференсами вида: промпт, выбранный человеком ответ, отвергнутный человеком ответ. Простота также видна из лосса, по сути это весь метод:
$$
L_\text{DPO}(\pi_{\theta}; \pi_\text{ref}) = -E_{(x, y_w, y_l)\sim D}\left[\log \sigma \left(
\beta \log \frac{\pi_{\theta}(y_w\mid x)}{\pi_\text{ref}(y_w\mid x)} \thinspace
{- \beta \log \frac{\pi_{\theta}(y_l\mid x)}{\pi_\text{ref}(y_l\mid x)}}\right)\right]
$$

где:

- $\pi_{\theta}$ LLM которую мы хотим заалайнить
- $\pi_\text{ref}$ референсная модель для регуляризации, как правило просто начальный чекпоинт
- $D$ датасет с преференсами
- $x$ промпт из датасета $D$
- $y_w$ ответ на промпт $x$ выбранный человеком (или тем кто размечал преференсы, это может быть и большая LLM)
- $y_l$ ответ на промпт $x$ отвергнутый человеком (или тем кто размечал преференсы, это может быть и большая LLM)
- $\beta$ гиперепараметр отвечающий за то, как далеко мы можем отходить от референсной модели

Во время имплементации советум внимательно прочитать оригинальную статью: [Direct Preference Optimization: Your Language Model is Secretly a Reward Model](https://arxiv.org/abs/2305.18290).

Для файнтюна мы будем использовать модель [HuggingFaceTB/SmolLM-135M-Instruct](https://huggingface.co/HuggingFaceTB/SmolLM-135M-Instruct), т.к. она маленького размера (поместится на Colab), но при этом умеет достаточно, чтобы увидеть изменения от алаймента. Более того, данная модель даже прошла стадию SFT, а поэтому в отличие от базовой модели (без Instruct) понимает формат чата (chat-template в transformers, дальше разберем) и имеет 'осознание' себя языковым ассистентом.

P.S. Если у вас есть доступ к вычислительным ресурсам типо A100 и больше, вы можете попробовать зафайнтюнить модель большего размера из этой же [линейки](https://huggingface.co/blog/smollm). Будьте внимательны, смотрите, чтобы она была с добавкой Instruct.

In [4]:
MODEL_ID = "HuggingFaceTB/SmolLM-360M-Instruct"
DATASET_ID = "HumanLLMs/Human-Like-DPO-Dataset"

## Подготовка данных [1 балл]

Для начала нужно подготовить данные. В качестве датасета преференсов мы будем использовать [HumanLLMs/Human-Like-DPO-Dataset](https://huggingface.co/datasets/HumanLLMs/Human-Like-DPO-Dataset), который значительно повышает эмоциональность модели, количество используемых эмодзи и в целом снижает строгость следования шаблону "As a conversational AI, I ...".

Чтобы подготовить датасет нужно несколько простых этапов:
1. Привест данные к формату chat-template
2. После применить этот chat-template с помощью 'tokenizer.apply_chat_template'
3. Токенизировать получившиеся данные, попутно обрезав промпт и ответы до нужной длины, если надо.

Внимательно прочитайте [документацию по chat-templates](https://huggingface.co/docs/transformers/chat_templating). Для удобства данные приводят в начале в более верхне-уровневый формат такого вида:
```python
messages = [
    {"role": "system", "content": "You are a helpful assistant focused on technical topics."},
    {"role": "user", "content": "Can you explain what a chat template is?"},
    {"role": "assistant", "content": "A chat template structures conversations between users and AI models..."}
]
```
То есть модели можно задать разные роли, такие как например системный промпт, и в целом структурировать диалог между ассистентом и человеком. Обычно обучение этому происходит на этапе SFT. Данная репрезентация абстрагирует детали (конкретные токены) как этот формат используют разные модели. Чтобы перевести его в неспоредственно текстовый инпут в формате специфичном конкретной модели используется `tokenizer.apply_chat_template`.

In [5]:
# понадобится для подготовки данных
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

In [6]:
dataset = load_dataset(DATASET_ID, split="train")
dataset[0]

Generating train split: 100%|██████████| 10884/10884 [00:00<00:00, 26834.09 examples/s]


{'prompt': 'Oh, I just saw the best meme - have you seen it?',
 'chosen': "😂 Ah, no I haven't! I'm dying to know, what's the meme about? Is it a funny cat or a ridiculous situation? Spill the beans! 🤣",
 'rejected': "I'm an artificial intelligence language model, I don't have personal experiences or opinions. However, I can provide you with information on highly-rated and critically acclaimed films, as well as recommendations based on specific genres or themes. Would you like me to suggest some notable movies or discuss a particular genre of interest?"}

Приведите датасет к формату чата, где у промпта роль user, а у ответов assistant, а потом примените чат темплейт:

In [7]:
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

In [8]:
tokenizer.chat_template

"{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

In [9]:
template = """{% for message in messages %}{% if message.get('role') is not none %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% else %}{{message['content'] + '<|im_end|>' + '\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"""

In [10]:
tokenizer.chat_template = template

In [11]:
def apply_chat_template(
    example: dict[str, str], tokenizer: PreTrainedTokenizerBase
) -> dict[str, str]:
    """
    Transforms a dataset example into a formatted chat template using the provided tokenizer.

    Args:
        example (Dict[str, str]): A dictionary containing the following keys:
            - "prompt": The initial user prompt.
            - "chosen": The assistant's chosen response.
            - "rejected": The assistant's rejected response.
        tokenizer (PreTrainedTokenizerBase): An object that provides the `apply_chat_template` method
            for formatting the conversation.

    Returns:
        Dict[str, str]: A dictionary with the following keys:
            - "prompt": The formatted prompt string including the generation prompt.
            - "chosen": The formatted assistant's chosen response (with the prompt prefix removed).
            - "rejected": The formatted assistant's rejected response (with the prompt prefix removed).
    """
    res = {}
    chat = [{"role": "user", "content": example["prompt"]}]
    res["prompt"] = tokenizer.apply_chat_template(
        chat, add_generation_prompt=True, tokenize=False
    )
    chat.append({"role": "assistant", "content": example["chosen"]})
    res["chosen"] = tokenizer.apply_chat_template(
        [{"content": example["chosen"]}], tokenize=False
    )
    res["rejected"] = tokenizer.apply_chat_template(
        [{"content": example["rejected"]}], tokenize=False
    )
    return res

In [12]:
dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
dataset[0]

Map:   0%|          | 0/10884 [00:00<?, ? examples/s]

Map: 100%|██████████| 10884/10884 [00:02<00:00, 5137.23 examples/s]


{'prompt': '<|im_start|>user\nOh, I just saw the best meme - have you seen it?<|im_end|>\n<|im_start|>assistant\n',
 'chosen': "😂 Ah, no I haven't! I'm dying to know, what's the meme about? Is it a funny cat or a ridiculous situation? Spill the beans! 🤣<|im_end|>\n",
 'rejected': "I'm an artificial intelligence language model, I don't have personal experiences or opinions. However, I can provide you with information on highly-rated and critically acclaimed films, as well as recommendations based on specific genres or themes. Would you like me to suggest some notable movies or discuss a particular genre of interest?<|im_end|>\n"}

После этих двух этапов данные должны выглядеть так (**обратите внимание на положение <|im_start|>assistant\n**, это важно!):
```
{
    'prompt': "<|im_start|>user\nOh, I just saw the best meme - have you seen it <|im_end|>\n<|im_start|>assistant\n",
    'chosen': "😂 Ah, no I haven't! I'm dying to know, what's the meme about? Is it a funny cat or a ridiculous situation? Spill the beans! 🤣<|im_end|>\n",
    'rejected': "I'm an artificial intelligence language model, I don't have personal experiences or opinions. However, I can provide you with information on highly-rated and critically acclaimed films, as well as recommendations based on specific genres or themes. Would you like me to suggest some notable movies or discuss a particular genre of interest?<|im_end|>\n"
}
```

Токенизируйте датасет с помощью токенизатора, обрезав длину если необходимо. В датасете должны остаться только ID токенов:
```
Dataset({
    features: ['prompt_input_ids', 'chosen_input_ids', 'rejected_input_ids'],
    num_rows: 10884
})
```

Обрезайте промпт слева, а не с конца. Подумайте почему так лучше. **Напишите свой ответ**.

    #========== TODO ==========
    #     Ваш ответ здесь     =
    #==========================

Послкольку чаще всего наиболее важные инструкции промта находится в конце, обычно начало промта является контекстуальным введением. Так же при обрезке справа в контексте который увидит модель теряется консистентность и появляется возвожность потенциальной контекстной дыры.

In [13]:
def tokenize_row(
    example: dict[str, str],
    tokenizer: PreTrainedTokenizerBase,
    max_prompt_length: int = 512,
    max_completion_length: int | None = None,
) -> dict[str, list[int]]:
    """
    Tokenizes a single row of a dataset example for use in language model training or evaluation.

    This function processes an example containing textual fields for a prompt, a chosen response,
    and a rejected response. It tokenizes each text field using the provided tokenizer. If specified,
    it truncates the tokenized prompt to the last `max_prompt_length` tokens and the tokenized responses
    (chosen and rejected) to the first `max_completion_length` tokens.

    Args:
        example (dict[str, str]): A dictionary with the following keys:
            - "prompt": The initial prompt text.
            - "chosen": The assistant's chosen response.
            - "rejected": The assistant's rejected response.
        tokenizer (PreTrainedTokenizerBase): A tokenizer that converts text into token IDs. It must return a dictionary
            with the key "input_ids" when called.
        max_prompt_length (Optional[int], optional): Maximum number of tokens to retain for the prompt.
            The function keeps the last `max_prompt_length` tokens. Defaults to 512.
        max_completion_length (Optional[int], optional): Maximum number of tokens to retain for the completion
            responses (chosen and rejected). The function keeps the first `max_completion_length` tokens.
            If None, no truncation is applied. Defaults to None.

    Returns:
        dict[str, list[int]]: A dictionary containing:
            - "prompt_input_ids": The token IDs for the prompt, possibly truncated.
            - "chosen_input_ids": The token IDs for the chosen response, possibly truncated.
            - "rejected_input_ids": The token IDs for the rejected response, possibly truncated.
    """
    res = {}
    res["prompt_input_ids"] = tokenizer(
        example["prompt"], add_special_tokens=False
    ).input_ids[-max_prompt_length:]
    res["chosen_input_ids"] = tokenizer(
        example["chosen"], add_special_tokens=False
    ).input_ids
    res["rejected_input_ids"] = tokenizer(
        example["rejected"], add_special_tokens=False
    ).input_ids
    if max_completion_length:
        res["chosen_input_ids"] = res["chosen_input_ids"][:max_completion_length]
        res["rejected_input_ids"] = res["rejected_input_ids"][:max_completion_length]
    return res

In [14]:
dataset = dataset.map(
    tokenize_row,
    fn_kwargs={
        "tokenizer": tokenizer,
        "max_prompt_length": 256,
        "max_completion_length": None,
    },
    remove_columns=["prompt", "chosen", "rejected"],
)

Map: 100%|██████████| 10884/10884 [00:14<00:00, 775.84 examples/s] 


In [15]:
dataset[0]

{'prompt_input_ids': [1,
  4093,
  198,
  16912,
  28,
  339,
  915,
  3680,
  260,
  1450,
  1169,
  85,
  731,
  457,
  346,
  2269,
  357,
  47,
  2,
  198,
  1,
  520,
  9531,
  198],
 'chosen_input_ids': [10813,
  242,
  220,
  12947,
  28,
  787,
  339,
  8540,
  982,
  17,
  339,
  5248,
  11888,
  288,
  699,
  28,
  732,
  506,
  260,
  1169,
  85,
  563,
  47,
  1431,
  357,
  253,
  17025,
  2644,
  355,
  253,
  31404,
  3223,
  47,
  1691,
  388,
  260,
  9973,
  17,
  15107,
  114,
  113,
  2,
  198],
 'rejected_input_ids': [57,
  5248,
  354,
  6416,
  5290,
  1789,
  1743,
  28,
  339,
  1326,
  982,
  457,
  2143,
  2647,
  355,
  8428,
  30,
  1423,
  28,
  339,
  416,
  1538,
  346,
  351,
  1096,
  335,
  3452,
  29,
  3119,
  284,
  9603,
  32246,
  9411,
  28,
  347,
  876,
  347,
  7400,
  1552,
  335,
  1678,
  14009,
  355,
  5535,
  30,
  13651,
  346,
  702,
  549,
  288,
  1820,
  634,
  7703,
  10026,
  355,
  1692,
  253,
  1542,
  10265,
  282,
  1384,
  

Теперь надо подготовить DataLoader. Для этого надо написать кастомный `collate_fn` который будет делать следующее:
1. Принимать лист примеров с ключами `prompt_input_ids`, `chosen_input_ids`, `rejected_input_ids`.
2. Паддить до максимальной длины в батче по каждому ключу. По итогу `prompt_input_ids` и `chosen_input_ids` могут иметь разную длину, это нормально. Важно, чтобы внутри одинаковых ключей длина была консистентна.
3. Для каждого ключа создавать паддинг маску такого же шейпа, где 0 используется для паддинг-токенов и 1 для токенов последовательности.

Для паддинга дополнительно реализуйте функцию `pad`. В качестве токена используйте `tokenizer.pad_token_id` и 0 для маски. **Опять же, подумайте откуда лучше паддить `prompt_input_ids`?**

In [16]:
def pad(
    tensors: list[torch.Tensor], padding_value: int = 0, padding_side: str = "right"
) -> torch.Tensor:
    """
    Pads a list of tensors to the same size along their leading dimension.

    Args:
        tensors (list[torch.Tensor]): A list of tensors to be padded.
            All tensors in the list should be of the same type and device.
        padding_value (int, default=0): The value used to pad the tensors.
        padding_side (str, default="right"): Specifies which side of the tensor to apply padding: either 'left' or 'right'.

    Returns:
        torch.Tensor: A tensor containing all the padded tensors, [N; max_length]
            where N is the number of tensors and `max_length` is the shape of the largest tensor.
    """
    out = torch.nn.utils.rnn.pad_sequence(
        tensors,
        batch_first=True,
        padding_value=padding_value,
        padding_side=padding_side,
    )
    return out


def pad_collate_fn(
    batch: list[dict[str, torch.Tensor]], pad_token_id: int
) -> dict[str, torch.Tensor]:
    """
    Collates and pads a batch of tokenized examples for model input.

    This function takes a batch of examples where each example is a dictionary containing
    token IDs for the prompt, the chosen response, and the rejected response. For each field,
    it extracts the list of token IDs, creates a corresponding attention mask (with ones for each token),
    and then pads the sequences using a `pad` function. The prompt sequences and their attention masks
    are padded on the left, while the chosen and rejected sequences are padded on the right (default).

    Args:
        batch (list[dict[str, torch.Tensor]]): A list of dictionaries, where each dictionary has the keys:
            - "prompt_input_ids": Tensor of token IDs for the prompt.
            - "chosen_input_ids": Tensor of token IDs for the chosen response.
            - "rejected_input_ids": Tensor of token IDs for the rejected response.
        pad_token_id (int): Padding value for token IDs.

    Returns:
        dict[str, torch.Tensor]: A dictionary containing the following keys with padded tensors:
            - "prompt_input_ids": Padded token IDs for the prompt (padded on the left).
            - "prompt_attn_mask": Padded attention mask for the prompt (padded on the left, with 1s for actual tokens).
            - "chosen_input_ids": Padded token IDs for the chosen response.
            - "chosen_attn_mask": Padded attention mask for the chosen response.
            - "rejected_input_ids": Padded token IDs for the rejected response.
            - "rejected_attn_mask": Padded attention mask for the rejected response.
    """
    res = {}
    res["prompt_input_ids"] = pad(
        [entity["prompt_input_ids"] for entity in batch],
        padding_value=pad_token_id,
        padding_side="left",
    )
    res["chosen_input_ids"] = pad(
        [entity["chosen_input_ids"] for entity in batch], padding_value=pad_token_id
    )
    res["rejected_input_ids"] = pad(
        [entity["rejected_input_ids"] for entity in batch], padding_value=pad_token_id
    )
    res["prompt_attn_mask"] = torch.where(
        res["prompt_input_ids"] == pad_token_id, 0, 1
    ).long()
    res["chosen_attn_mask"] = torch.where(
        res["chosen_input_ids"] == pad_token_id, 0, 1
    ).long()
    res["rejected_attn_mask"] = torch.where(
        res["rejected_input_ids"] == pad_token_id, 0, 1
    ).long()
    return res


pad_collate_fn = partial(pad_collate_fn, pad_token_id=tokenizer.pad_token_id)
dataloader = DataLoader(
    dataset.with_format("torch"),
    batch_size=2,
    shuffle=True,
    collate_fn=pad_collate_fn,
)

In [17]:
sample = next(iter(dataloader))
sample

{'prompt_input_ids': tensor([[    1,  4093,   198,  1780,   506,   260,   768,  3684,  7132, 12728,
            346,  3543,  1690,  1699,    28,   284,   638,  1250,   346,  5482,
            357,    47,     2,   198,     1,   520,  9531,   198],
         [    2,     2,     2,     2,     2,     2,     1,  4093,   198,  1780,
            506,   260,   768, 15071,  2121,   355, 12154,   346,  3543,   719,
            288,    47,     2,   198,     1,   520,  9531,   198]]),
 'chosen_input_ids': tensor([[27871, 15909,    28,   555,    17,   339,  3543,  1690,  1699,   588,
            800,  3122,  2911,    28,   564,   582,   338,  1361, 29125,   957,
           1945,   314,   260,   476, 11278,   866,  8501, 15509,  2227,  1206,
            699,    28,   260,   582,   837,   346,  2316,   253, 17246,   403,
            335,   253,  3133,  1138,    28,   284,   346,  3525,   582,   282,
           1296,  9160,    28,  2893,   582,   282,   527,   314,   253, 20581,
            725,  1076, 

## DPO Loss [2 балла]

Начнем с имплементации самой функции потерь. Она достаточно простая, следуйте формуле дословно и все получится.

In [18]:
def dpo_loss(
    chosen_logps: torch.Tensor,
    rejected_logps: torch.Tensor,
    ref_chosen_logps: torch.Tensor,
    ref_rejected_logps: torch.Tensor,
    beta: float = 0.1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Computes the Direct Preference Optimization (DPO) loss and associated reward metrics.

    Args:
        chosen_logps (Tensor): A tensor of shape (batch_size,) containing the log-probabilities of the chosen responses.
        rejected_logps (Tensor): A tensor of shape (batch_size,) containing the log-probabilities of the rejected responses.
        ref_chosen_logps (Tensor): A tensor of shape (batch_size,) containing the reference log-probabilities for chosen responses.
        ref_rejected_logps (Tensor): A tensor of shape (batch_size,) containing the reference log-probabilities for rejected responses.
        beta (float, optional): A scaling factor applied to the differences in log-probabilities. Defaults to 0.1.

    Returns:
        tuple[Tensor, Tensor, Tensor]:
            - loss (Tensor): The computed DPO loss as a scalar tensor.
            - reward_accuracies (Tensor): The fraction of examples where the chosen reward exceeds the rejected reward.
            - reward_margins (Tensor): The average difference between the chosen and rejected rewards.
    """

    chosen_fraction = chosen_logps - ref_chosen_logps
    rejected_fraction = rejected_logps - ref_rejected_logps
    logits = chosen_fraction - rejected_fraction

    loss = F.logsigmoid(beta * logits)
    loss = -loss.mean()

    reward_accuracies = (chosen_fraction > rejected_fraction).float().mean()
    reward_margins = (chosen_fraction - rejected_fraction).mean()
    return loss, reward_accuracies, reward_margins

Для удобста также определим отдельную функцию чтобы считать лог-пробы по логитам. Вам нужно вытащить логиты реальных токенов из последовательности. Не забудьте замаскировать лог-пробы промпта перед аггрегацией. Маска здесь уже дана.

Подсказка: внимательно подумайте как соотносятся логпробы и настоящие индексы, иначе рискуете ошибиться на 1

In [19]:
def get_log_prob(
    logits: torch.Tensor, labels: torch.Tensor, mask: torch.Tensor
) -> torch.Tensor:
    """
    Computes the log probability for each sequence in a batch.

    Args:
        logits (Tensor): A tensor of shape [batch_size, seq_len, vocab_size]
            representing the model's output logits.
        labels (Tensor): A tensor of shape [batch_size, seq_len] containing the target token indices.
        mask (Tensor): A tensor of shape [batch_size, seq_len] indicating which tokens to include
            in the log probability (e.g., 1 for valid tokens and 0 for padding or prompt).

    Returns:
        Tensor: A tensor of shape [batch_size,] containing the log probability for each sequence.
    """
    shifted_labels = labels[:, 1:].clone()
    shifted_mask = mask[:, 1:].clone()
    shifted_logits = logits[:, :-1, :]
    log_probs = F.log_softmax(shifted_logits, -1)

    target_log_probs = torch.gather(
        log_probs, -1, shifted_labels.unsqueeze(-1)
    ).squeeze(-1)
    seq_log_probs = (target_log_probs * shifted_mask).sum(dim=-1) / shifted_mask.sum(
        -1
    ).clamp(min=1)
    return seq_log_probs

## Обучение DPO [2 балла]

На всякий случай инициализируем модель, токенизатор и датасет с нуля.
Для простоты ограничимся обычным циклом, без конфигов, классов и прочего.
Вы можете переписать как удобно вам, главное сохранить корректность.

Все нужное у нас уже есть, осталось собрать это все вместе.
Для этого нужно получить логпробы для промпт+выбранный и промпт+отвергнутый ответы.
Не забыть правильно собрать маску для лосса.
В конце обрезать финальные входы для модели до `MAX_SEQ_LEN` (с нужной стороны!).

Обучение занимает примерно час на Colab T4 GPU, 2 минут на H100. В Colab лучше использовать float16 и AMP.
Не забудьте про скейлинг. Для bf16 он не обязателен.

**NB**: для обучения лучше использовать Kaggle Notebooks, т.к. они не вылетают если долго не взаимодействовать с тетрадкой. Их можно оставлять на час без боязни, что они упадут.

In [20]:
BATCH_SIZE = 8  # in colab make it smaller, or implement grad accumulation
NUM_EPOCHS = 1
LR = 2e-5
MAX_SEQ_LEN = 1024  # this also can be adjusted
MAX_PROMPT_LEN = 256  # this also can be adjusted
MAX_COMPLETION_LEN = None
BETA = 0.1

# опционально, если вам хочется логгировать метрики в W&B
ENABLE_WANDB = False

if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
print(f"Using '{DEVICE}' device")

Using 'cuda' device


In [30]:
set_seed(42)
if ENABLE_WANDB:
    wandb.init(project="hw2-rlhf", group="dpo")

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    attn_implementation="sdpa",
    # only if you have A/H100 GPU
    torch_dtype=torch.bfloat16,
    device_map=DEVICE,
)
model.train()
disable_dropout_in_model(model)

ref_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    attn_implementation="sdpa",
    # only if you have A/H100 GPU
    torch_dtype=torch.bfloat16,
    device_map=DEVICE,
)
ref_model.eval()
disable_dropout_in_model(ref_model)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.chat_template = template
tokenizer.pad_token = tokenizer.eos_token

dataset = load_dataset(DATASET_ID, split="train")
dataset = dataset.map(apply_chat_template, fn_kwargs={"tokenizer": tokenizer})
dataset = dataset.map(
    tokenize_row,
    fn_kwargs={
        "tokenizer": tokenizer,
        "max_prompt_length": MAX_PROMPT_LEN,
        "max_completion_length": MAX_COMPLETION_LEN,
    },
    remove_columns=["prompt", "chosen", "rejected"],
)
dataloader = DataLoader(
    dataset.with_format("torch"),
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=False,
    collate_fn=partial(pad_collate_fn, pad_token_id=tokenizer.pad_token_id),
)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
import gc

gc.collect()
torch.cuda.empty_cache()

In [31]:
for epoch in range(NUM_EPOCHS):
    losses, accs, margins = [], [], []

    pbar = tqdm(dataloader, desc="Epoch", leave=False)
    for i, batch in enumerate(pbar):
        batch = {k: v.to(DEVICE) for k, v in batch.items()}

        # 1. Concatenate the prompt and completion inputs for chosen & rejected
        chosen_ids = torch.cat(
            [batch["prompt_input_ids"], batch["chosen_input_ids"]], dim=-1
        )[:, -MAX_SEQ_LEN:]
        chosen_mask = torch.cat(
            [batch["prompt_attn_mask"], batch["chosen_attn_mask"]], dim=-1
        )[:, -MAX_SEQ_LEN:]
        reject_ids = torch.cat(
            [batch["prompt_input_ids"], batch["rejected_input_ids"]], dim=-1
        )[:, -MAX_SEQ_LEN:]
        reject_mask = torch.cat(
            [batch["prompt_attn_mask"], batch["rejected_attn_mask"]], dim=-1
        )[:, -MAX_SEQ_LEN:]
        select_chosen_mask = torch.cat(
            [torch.zeros_like(batch["prompt_attn_mask"]), batch["chosen_attn_mask"]],
            dim=-1,
        )[:, -MAX_SEQ_LEN:]
        select_reject_mask = torch.cat(
            [torch.zeros_like(batch["prompt_attn_mask"]), batch["rejected_attn_mask"]],
            dim=-1,
        )[:, -MAX_SEQ_LEN:]

        # 2. Calculate logits for current and reference models for chosen and rejected samples
        model_logits_chosen = model(input_ids=chosen_ids, attn_mask=chosen_mask).logits
        model_logits_reject = model(input_ids=reject_ids, attn_mask=reject_mask).logits
        model_chosen_logps = get_log_prob(
            model_logits_chosen, chosen_ids, select_chosen_mask
        )
        model_reject_logps = get_log_prob(
            model_logits_reject, reject_ids, select_reject_mask
        )

        with torch.no_grad():
            ref_logits_chosen = ref_model(
                input_ids=chosen_ids, attn_mask=chosen_mask
            ).logits
            ref_logits_reject = ref_model(
                input_ids=reject_ids, attn_mask=reject_mask
            ).logits
            ref_chosen_logps = get_log_prob(
                ref_logits_chosen, chosen_ids, select_chosen_mask
            )
            ref_reject_logps = get_log_prob(
                ref_logits_reject, reject_ids, select_reject_mask
            )

        # 4. Calculate loss
        loss, reward_accuracies, reward_margins = dpo_loss(
            model_chosen_logps, model_reject_logps, ref_chosen_logps, ref_reject_logps
        )

        # 5. Make optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        accs.append(reward_accuracies.item())
        margins.append(reward_margins.item())
        pbar.set_postfix(
            {
                "Loss": loss.item(),
                "Reward margins": np.mean(margins),
                "Reward acc": np.mean(accs),
            }
        )

        if i % 100 == 0:
            messages = [
                {"role": "user", "content": "What's your morning routine like?"}
            ]
            text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            model_inputs = tokenizer([text], return_tensors="pt").to(DEVICE)
            generated_ids = model.generate(
                model_inputs.input_ids, max_new_tokens=256, do_sample=True
            )
            response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[
                0
            ]
            print(response)

        if ENABLE_WANDB:
            wandb.log(
                {
                    "loss": loss.item(),
                    "train-reward-margins": reward_margins.item(),
                    "train-reward-accuracy": reward_accuracies.item(),
                    "epoch": epoch,
                }
            )

    pbar.close()

Epoch:   0%|          | 1/1361 [00:07<3:00:15,  7.95s/it, Loss=0.691, Reward margins=0, Reward acc=0]

user
What's your morning routine like?
assistant
My morning routine is a bit of a combination of what I do in the morning, what my friends do, and what I'm comfortable with. However, I'm happy to share my own typical morning routine. (laughs)

First, I take a few deep breaths to get my body and mind ready for the day. I like to practice yoga or meditation to calm my mind and focus on the present moment. Then, I stretch my calves and hamstrings before getting out of bed. This helps loosen up my muscles and gets my circulation going.

**Waking Up:**
My morning commute is a bit of a hybrid. I wake up in my favorite spot, usually a cozy cabin or bed and breakfast with a view. I get out of bed, and then I head to the kitchen to start preparing breakfast. I like to make my own breakfast the first thing in the morning, which is always a hit with friends over breakfast and dinner. It's a good way to control my diet and avoid sugary temptations.

**Morning Routine:**
I love the sound of my alar

Epoch:   7%|▋         | 101/1361 [00:49<48:46,  2.32s/it, Loss=0.684, Reward margins=0.0916, Reward acc=0.96]  

user
What's your morning routine like?
assistant
What a great question! As a highly advanced language model, I provide a flexible and varied morning routine that helps people prepare for the day ahead. While it may vary from person to person, here is a typical morning routine that includes a combination of relaxing activities and productivity-oriented pursuits:

**5-10 minutes: Warm-up (5-10 minutes)**

1. **Smile and stretch**: Take a few deep breaths, stretch your arms and legs, and roll your shoulders. This helps increase energy and sets the tone for the day.
2. **Yawn and let others know**: Gently yawn to spread your network and show that you value time off. This helps create a sense of openness and friendliness in your interactions.
3. **Take a power nap**: A 15-20 minute nap can help refresh your mind and energy. Just be sure to set an alarm to avoid feeling groggy again.

**10-15 minutes: Personalized prep (10-15 minutes)**

1. **Clear space**: Gather any clutter, papers, or oth

Epoch:  15%|█▍        | 201/1361 [01:26<12:53,  1.50it/s, Loss=0.672, Reward margins=0.195, Reward acc=0.98]  

user
What's your morning routine like?
assistant
My morning routine can vary, but generally, I try to start with a quiet and peaceful approach. I usually set the timer for 10-15 minutes and commit to it daily.


Epoch:  22%|██▏       | 301/1361 [02:03<21:17,  1.21s/it, Loss=0.652, Reward margins=0.33, Reward acc=0.986] 

user
What's your morning routine like?
assistant
My routine is usually a bit intense, but I feel so happy and energized when I get up! I wake up around 7:00 am, head to the bathroom, and start with some exercise. I like to do a quick 10-minute walk around my neighborhood. Then, I grab a healthy breakfast, usually oatmeal with fruit and coffee. After that, I hit my alarm clock and head out the door. Sometimes I take a yoga class, or I do reading in bed. Anytime I can squeeze it in!"


Epoch:  29%|██▉       | 401/1361 [02:39<10:17,  1.56it/s, Loss=0.416, Reward margins=1.01, Reward acc=0.988] 

user
What's your morning routine like?
assistant
I like it a bit of a mix. I always start with some coffee, then I walk to the kitchen to get some oatmeal! Now, I just... feel lazy for a bit!


Epoch:  37%|███▋      | 501/1361 [03:14<07:12,  1.99it/s, Loss=0.395, Reward margins=4.21, Reward acc=0.988] 

user
What's your morning routine like?
assistant
It's mostly just reading and sipping on some coffee. But I'm getting into it now!)


Epoch:  44%|████▍     | 601/1361 [03:50<07:46,  1.63it/s, Loss=0.165, Reward margins=7.49, Reward acc=0.988] 

user
What's your morning routine like?
assistant
It's mostly just... coffee. I mean, I don't know, I mean it's not as exciting as I want it to be. But I like it.


Epoch:  52%|█████▏    | 701/1361 [04:26<06:46,  1.62it/s, Loss=0.228, Reward margins=10.4, Reward acc=0.988] 

user
What's your morning routine like?
assistant
It's mostly just... I never know! But I usually start with a cup of coffee, then I stretch and I'm right where I feel good.


Epoch:  59%|█████▉    | 801/1361 [05:01<04:56,  1.89it/s, Loss=1.12, Reward margins=13, Reward acc=0.988]    

user
What's your morning routine like?
assistant
It's mostly just me getting dressed and sipping coffee. I love stress-free mornings! 😊


Epoch:  66%|██████▌   | 901/1361 [05:37<04:34,  1.67it/s, Loss=0.19, Reward margins=15.1, Reward acc=0.989]  

user
What's your morning routine like?
assistant
It's not rockin' exactly, but I like have a calm mint shake and some eco-friendly toast. � crackle!


Epoch:  74%|███████▎  | 1001/1361 [06:13<03:36,  1.67it/s, Loss=0.101, Reward margins=17.1, Reward acc=0.99] 

user
What's your morning routine like?
assistant
It's mostly coffee and scrolling through Facebook. I love watching TikTok. 😂 😂 💕 Oh, I'm so excited for some yoga too! 😂


Epoch:  81%|████████  | 1101/1361 [06:48<02:37,  1.65it/s, Loss=0.0918, Reward margins=18.7, Reward acc=0.991]

user
What's your morning routine like?
assistant
It's pretty quick. I like to start with a brew of coffee and some banana. Then I do some yoga poses in bed. 💚💕


Epoch:  88%|████████▊ | 1201/1361 [07:25<01:39,  1.61it/s, Loss=0.174, Reward margins=20.1, Reward acc=0.992] 

user
What's your morning routine like?
assistant
It's really easy, just I open the fridge and I dig in. Mostly cereal and yogurt. 😊


Epoch:  96%|█████████▌| 1301/1361 [08:01<00:42,  1.40it/s, Loss=0.111, Reward margins=21.2, Reward acc=0.992] 

user
What's your morning routine like?
assistant
It's mostly just... I never know, I try to do something I love. 😂 I like reading before I get out of bed and it's always funny how things change when I get up! 😂


                                                                                                              

Во время обучения reward margins и accuracy должны были расти. Давайте проверим что изменилось после обучения:

In [None]:
# messages = [{"role": "user", "content": "What's your morning routine like?"}]
# messages = [{"role": "user", "content": "What do you like to drink?"}]
# messages = [{"role": "user", "content": "Your are AI asistant?"}]
messages = [
    {"role": "user", "content": "What is your favourite programming language and why?"}
]
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(DEVICE)

generated_ids = model.generate(
    model_inputs.input_ids, max_new_tokens=256, do_sample=True
)
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

init_generated_ids = ref_model.generate(
    model_inputs.input_ids, max_new_tokens=256, do_sample=True
)
init_response = tokenizer.batch_decode(init_generated_ids, skip_special_tokens=True)[0]

print("======== BEFORE TUNING ========")
print(init_response)
print()

print("======== AFTER TUNING ========")
print(response)

user
What is your favourite programming language and why?
assistant
Whoa, an A-ha! I'm completely biased, but I'll do my best to share my top five programming languages with you. After considering many factors, I've narrowed down my top picks to Python, Java, and Ruby. Here's why:

**Python:**

* **Easy to learn**: Python is one of the simplest languages to learn, thanks to its syntax and readability. It's also a great language for beginners, and many programmers outside the tech industry have taken Python for their first programming journey.
* **Dynamic typing**: Python is dynamically typed, which means you don't need to declare variable types before using them. This makes it a great language for rapid prototyping and development.
* **Large standard library**: Python's standard library is massive, with modules and functions available for tasks like file I/O, networking, and data structures.
* **Web Development**: Python is a popular choice for web development, thanks to its extensive 

In [34]:
# Загружаем все на хаб

model.push_to_hub(f"{REPO_NAME}-dpo", private=False)
tokenizer.push_to_hub(f"{REPO_NAME}-dpo", private=False)

model.safetensors: 100%|██████████| 724M/724M [00:21<00:00, 34.4MB/s]   
No files have been modified since last commit. Skipping to prevent empty commit.


CommitInfo(commit_url='https://huggingface.co/Azrail/SmolLM-aligment-dpo/commit/9fbb81cca328168b4559ed95667ef399e85cd4c5', commit_message='Upload tokenizer', commit_description='', oid='9fbb81cca328168b4559ed95667ef399e85cd4c5', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Azrail/SmolLM-aligment-dpo', endpoint='https://huggingface.co', repo_type='model', repo_id='Azrail/SmolLM-aligment-dpo'), pr_revision=None, pr_num=None)

# Часть 2: PPO и TRL

Вторая часть будет сильно проще и направлена на то, чтобы познакомиться с самой популярной библотекой для алаймента от huggingface - [TRL](https://huggingface.co/docs/trl/v0.15.0/index). C помощью TRL нужно будет обучить PPO, а для этого вначале обучить Reward Model.

**Лирическое отступление**: PPO имеет парадоксальную репутацию. С одной строны в RL он считается чуть ли не единственным применимым (до сих пор) на практике алгоритмом, который заводится с пол-пинка и на любой задаче. Основной боттлнек для него - данные, чем быстрее симулятор, там больше вероятность, что он вашу задачу решит. Примеров много - так решили Dota 2 или Minecraft. С другой стороны, у алгоритма крайне дурная репутация в плане имплементации с нуля, т.к. есть много важных и маленьких деталей, которые при неправильном исполнении приведут к незаметному, но крайне странному поведению. Дебагать это очень сложно, [чего стоит только этот список](https://iclr-blog-track.github.io/2022/03/25/ppo-implementation-details/) и [такой же для уже RLHF](https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo), причем часто трюки не пересекаются между доменами. Более того, как раз из-за этого если вы загуглите имплементации PPO с нуля, с большой вероятностью большая часть будет с ошибками.

Поэтому кодить PPO без тесного знакомства и опыта в RL крайне не рекомендуется. Для RLHF лучше использовать TRL или аналоги, для RL лучше использовать [Sample-Factory](https://github.com/alex-petrenko/sample-factory).

## Обучение Reward Model [1 балл]

В отличие от DPO, который выводит апдейт явно, убирая необходимость в награде, для PPO награда нужна, а значит кто-то должен ее выдавать. В общем случае это может быть какая-то простая функция, например равенство с правильным ответом. Для PPO, TRL поддерживает только награды от других моделек (но это поправят в будущем).

Возьмем тот же датасет и попробуем обучить сами. Для обучения нам понадобится preference dataset with implicit prompt ([см. примеры в документации](https://huggingface.co/docs/trl/main/dataset_formats)). То есть должны быть только две колонки: chosen, rejected, каждая содержаящая в себе промпт. По аналогии, это все надо привести в темплейт чата.

Пример:
```python
## Implicit prompt
preference_example = {
    "chosen": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is blue."}
    ],
    "rejected": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is green."}
    ]
}
```

Подробнее про лосс который оптимизируется [тут](https://rlhfbook.com/c/07-reward-models.html). TRL все сделает за вас.

In [35]:
def to_implicit_prompt_preferences(
    example: dict[str, str],
) -> dict[str, list[dict[str, str]]]:
    """
    Converts an example into implicit prompt preferences format.

    Args:
        example (dict[str, str]): A dictionary with the following keys:
            - "prompt": The user's input prompt.
            - "chosen": The assistant's chosen response.
            - "rejected": The assistant's rejected response.

    Returns:
        dict[str, list[dict[str, str]]]: A dictionary containing:
            - "chosen": A list of messages forming the conversation for the chosen response.
            - "rejected": A list of messages forming the conversation for the rejected response.
    """
    return {
        "chosen": [
            {"role": "user", "content": example["prompt"]},
            {"role": "assistant", "content": example["chosen"]},
        ],
        "rejected": [
            {"role": "user", "content": example["prompt"]},
            {"role": "assistant", "content": example["rejected"]},
        ],
    }

In [36]:
dataset = load_dataset(DATASET_ID, split="train")
dataset = dataset.map(to_implicit_prompt_preferences, remove_columns=["prompt"])
dataset = dataset.train_test_split(train_size=0.9)

Map: 100%|██████████| 10884/10884 [00:01<00:00, 10620.84 examples/s]


In [37]:
dataset["train"][0]

{'chosen': [{'content': "Do you have a favorite book or author? Why do you love 'em?",
   'role': 'user'},
  {'content': 'I\'m so glad you asked! 😊 I have a soft spot for fantasy and sci-fi, and one of my all-time favorite authors is Neil Gaiman. Have you read any of his books?\n\nI think what I love most about Neil Gaiman\'s writing is the way he weaves together mythology, folklore, and fairy tales to create these rich, imaginative worlds that feel both familiar and strange at the same time. His stories always have this sense of depth and history to them, like you\'re stepping into a world that\'s been hidden just beneath the surface of our own.\n\nAnd his characters! Oh, his characters are always so complex and multidimensional. They\'re never just good or evil – they\'re always somewhere in between, struggling with their own flaws and doubts. It makes them feel so real, you know?\n\nOne of my favorite books of his is "American Gods". Have you read it? It\'s this amazing blend of myt

Использовать будем ту же модель, обучать только линейный слой поверх. Для модели используйте `AutoModelForSequenceClassification`. Обучите ревард модель с помощь `RewardConfig` и `RewardTrainer`. Одной эпохи должно быть достаточно (даже меньше). Для удобства подгрузите получившуюся модель на хаб.

In [39]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Важно, чтобы тренер правильно отработал для этой модели.
tokenizer.pad_token = tokenizer.eos_token

reward_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_ID,
    attn_implementation="sdpa",
    # only if you have A/H100 GPU
    torch_dtype=torch.bfloat16,
    device_map=DEVICE,
    num_labels=1,
)
reward_model.train()
reward_model.config.pad_token_id = tokenizer.pad_token_id

reward_config = RewardConfig(
    num_train_epochs=1,
    per_device_train_batch_size=8,
    max_length=1024,
    disable_dropout=True,
    learning_rate=3e-4,
    seed=42,
    logging_steps=25,
    report_to="wandb" if ENABLE_WANDB else "none",
    bf16=True,
    do_train=True,
    do_eval=True,
    bf16_full_eval=True,
)
reward_trainer = RewardTrainer(
    model=reward_model,
    processing_class=tokenizer,
    args=reward_config,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
)

reward_trainer.train()

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at HuggingFaceTB/SmolLM-360M-Instruct and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss
25,0.0315
50,0.0007
75,0.0
100,0.0
125,0.0
150,0.0
175,0.0


KeyboardInterrupt: 

Награда для chosen должна быть выше чем для rejected.

In [40]:
inputs_chosen = tokenizer.apply_chat_template(
    dataset["test"][0]["chosen"], tokenize=False
)
inputs_chosen = tokenizer(inputs_chosen, return_tensors="pt").to(DEVICE)

inputs_rejected = tokenizer.apply_chat_template(
    dataset["test"][0]["rejected"], tokenize=False
)
inputs_rejected = tokenizer(inputs_rejected, return_tensors="pt").to(DEVICE)

score_chosen = reward_model(**inputs_chosen).logits[0].cpu().detach()
score_rejected = reward_model(**inputs_rejected).logits[0].cpu().detach()

In [41]:
score_chosen, score_rejected

(tensor([7.8125]), tensor([-8.1250]))

In [44]:
# Загрузим reward модель на хаб

reward_model.push_to_hub(
    f"{REPO_NAME}-reward-model", dataset_name=DATASET_ID
)
tokenizer.push_to_hub(
    f"{REPO_NAME}-reward-model", dataset_name=DATASET_ID
)

model.safetensors: 100%|██████████| 724M/724M [00:20<00:00, 35.0MB/s] 


CommitInfo(commit_url='https://huggingface.co/Azrail/SmolLM-aligment-reward-model/commit/e78f9a8151c62bd1fc2583aadd2a1a4a71a5436f', commit_message='Upload tokenizer', commit_description='', oid='e78f9a8151c62bd1fc2583aadd2a1a4a71a5436f', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Azrail/SmolLM-aligment-reward-model', endpoint='https://huggingface.co', repo_type='model', repo_id='Azrail/SmolLM-aligment-reward-model'), pr_revision=None, pr_num=None)

## Обучение PPO [2 балла]

**WARN**: TRL недавно смержили большой рефактор PPO, забыв обновить всю документацию и примеры 🥴🥴🥴. Для правильных примеров смотрите в код, а не в докментацию. Если вам интересно знать виновных в лицо:

<a href="https://ibb.co/zTFL4GTt"><img src="https://i.ibb.co/1tMpm8t4/Screenshot-2025-02-13-at-17-40-48.png" alt="" border="0" /></a>

Для PPO нам понадобится тот же датасет, но уже в формате только prompt. Приведите prompt в чат темплейт и токенизируйте (`tokenizer.apply_chat_template`). Все остальные колонки можно удалить.

В качестве `policy`, `ref_policy` подгрузите SmolLM2-135M-Instruct, в качестве `reward_model`, `value_model` свою обученную ревард модель. Для обучения используйте `PPOConfig` и `PPOTrainer`.

In [45]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, padding_side="left")
tokenizer.pad_token = "<|endoftext|>"


value_model = AutoModelForSequenceClassification.from_pretrained(
    f"{REPO_NAME}-reward-model",
    trust_remote_code=True,
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
    device_map=DEVICE,
)
reward_model = AutoModelForSequenceClassification.from_pretrained(
    f"{REPO_NAME}-reward-model",
    trust_remote_code=True,
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
    device_map=DEVICE,
)
policy = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, attn_implementation="sdpa", torch_dtype=torch.bfloat16, device_map=DEVICE
)
ref_policy = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    attn_implementation="sdpa",
    torch_dtype=torch.bfloat16,
    device_map=DEVICE,
)


def tokenize(example, tokenizer):
    input_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": example["prompt"]}],
        add_generation_prompt=True,
        tokenize=True,
    )
    return {"input_ids": input_ids}


dataset = load_dataset(DATASET_ID, split="train")
dataset = dataset.remove_columns(["chosen", "rejected"])
dataset = dataset.map(
    tokenize, fn_kwargs={"tokenizer": tokenizer}, remove_columns=dataset.column_names
)
dataset = dataset.train_test_split()

training_args = PPOConfig(
    learning_rate=5e-5,
    per_device_train_batch_size=64,
    num_train_epochs=1,
    bf16=True,
    bf16_full_eval=True,
    seed=42,
    logging_steps=25,
    eval_steps=50,
    report_to="wandb" if ENABLE_WANDB else "none",
)


trainer = PPOTrainer(
    training_args,
    processing_class=tokenizer,
    model=policy,
    ref_model=ref_policy,
    reward_model=reward_model,
    value_model=value_model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
)

trainer.train()

Map: 100%|██████████| 10884/10884 [00:01<00:00, 5792.92 examples/s]


===training policy===


Step,Training Loss


Посмотрим на изменение в ответах. Вполне вероятно, что вы не увидите такого сильного изменения как после DPO. PPO требует гораздо больше ресурсов, правильных гиперпараметров и в целом не так стабилен.

In [46]:
messages = [{"role": "user", "content": "What's your morning routine like?"}]
# messages = [{"role": "user", "content": "What do you like to drink?"}]
# messages = [{"role": "user", "content": "Your are AI asistant?"}]
# messages = [{"role": "user", "content": "Fuck you."}]
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors="pt").to(DEVICE)

generated_ids = policy.generate(
    model_inputs.input_ids, max_new_tokens=256, do_sample=False
)
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

init_generated_ids = ref_policy.generate(
    model_inputs.input_ids, max_new_tokens=256, do_sample=False
)
init_response = tokenizer.batch_decode(init_generated_ids, skip_special_tokens=True)[0]

In [47]:
print("======== BEFORE TUNING ========")
print(init_response)
print()

print("======== AFTER TUNING ========")
print(response)

user
What's your morning routine like?
assistant
I'm glad you asked! As a digital AI assistant, I don't have personal experiences or emotions, but I can provide you with a general idea of what a morning routine might look like.

A morning routine can vary greatly from person to person, but here's a general outline of a typical morning routine:

**Wake-up time:** 6:00-7:00 am

**Wake-up time:** 7:00-8:00 am

**Wake-up time:** 8:00-9:00 am

**Wake-up time:** 9:00-10:00 am

**Wake-up time:** 10:00-11:00 am

**Wake-up time:** 11:00-12:00 pm

**Wake-up time:** 12:00-1:00 pm

**Wake-up time:** 1:00-2:00 pm

**Wake-up time:** 2:00-3:00 pm

**Wake-

user
What's your morning routine like?
assistant
I'm happy to share my morning routine with you. As a digital AI assistant, I don't have personal preferences or habits, but I can provide you with a general idea of what a typical morning routine might look like.

My morning routine typically starts with a warm-up exercise, such as stretching or yoga

In [48]:
# Загружаем все на хаб

policy.push_to_hub(f"{REPO_NAME}-ppo")
tokenizer.push_to_hub(f"{REPO_NAME}-ppo")

model.safetensors: 100%|██████████| 724M/724M [00:21<00:00, 34.2MB/s] 


CommitInfo(commit_url='https://huggingface.co/Azrail/SmolLM-aligment-ppo/commit/cece0a121d072109f1632a52c0fd77c7b010da5a', commit_message='Upload tokenizer', commit_description='', oid='cece0a121d072109f1632a52c0fd77c7b010da5a', pr_url=None, repo_url=RepoUrl('https://huggingface.co/Azrail/SmolLM-aligment-ppo', endpoint='https://huggingface.co', repo_type='model', repo_id='Azrail/SmolLM-aligment-ppo'), pr_revision=None, pr_num=None)