# RLHF Fine-Tuning

## Load the SFT and Reward Models

In [1]:
%cp /content/drive/MyDrive/copy\ files/reward_model.pt .
%cp /content/drive/MyDrive/copy\ files/sft_model_epoch_1.zip .

In [None]:
!unzip sft_model_epoch_1.zip

## Reward model

In [3]:
import torch
from typing import Optional
from torch import nn
import numpy as np
from transformers import AutoModelForCausalLM

class RewardHead(nn.Module):
    """
    The RewardHead class implements a head for GPT2
    that returns a scalar for each output token.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.reward = nn.Linear(self.hidden_size, 1)
        self._post_init()

    def _post_init(self):
        nn.init.normal_(self.reward.weight, std=(1.0 / np.sqrt(self.hidden_size + 1)))
        nn.init.zeros_(self.reward.bias)

    def forward(self, hidden_states):
        output = hidden_states
        return self.reward(output)


class GPT2RewardModel(nn.Module):
    """
    GPT2 model with a reward head on top.
    """

    def __init__(self, model_name):
        super().__init__()
        self.llm = AutoModelForCausalLM.from_pretrained(model_name)
        # config = self.llm.config
        # Add the reward head
        self.reward_head = RewardHead(self.llm.config)

    def forward(
        self,
        input_ids,
        attention_mask,
    ) -> Optional[torch.FloatTensor]:

        transformer_outputs = self.llm.forward(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states = True,
        )

        # Get the last hidden state
        last_hidden_state = transformer_outputs.hidden_states[-1]

        # Apply the reward head
        rewards = self.reward_head(last_hidden_state).squeeze(-1)

        return torch.sigmoid(rewards)



In [4]:
model_name = "gpt2"
reward_model = GPT2RewardModel(model_name)
reward_model.load_state_dict(torch.load("reward_model.pt", map_location='cpu'))

<All keys matched successfully>

## Model with Value Head

In [5]:
import torch
from typing import Optional
from torch import nn
import numpy as np
from transformers import AutoModelForCausalLM

class ValueHead(nn.Module):
    """
    The ValueHead class implements a head for GPT2
    that returns a scalar for each output token.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.value = nn.Linear(self.hidden_size, 1)
        self._post_init()

    def _post_init(self):
        nn.init.normal_(self.value.weight, std=(1.0 / np.sqrt(self.hidden_size + 1)))
        nn.init.zeros_(self.value.bias)

    def forward(self, hidden_states):
        output = hidden_states
        return self.value(output)


class ModelForCausalLMWithValueHead(nn.Module):
    """
    GPT2 model with a value head on top.
    """

    def __init__(self, model_path):
        super().__init__()
        self.llm = AutoModelForCausalLM.from_pretrained(model_path)
        # config = self.llm.config
        # Add the reward head
        self.v_head = ValueHead(self.llm.config)

    def forward(
        self,
        input_ids,
        attention_mask,
    ) -> Optional[torch.FloatTensor]:

        transformer_outputs = self.llm.forward(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states = True,
        )
        lm_logits = transformer_outputs.logits
        # Get the last hidden state
        last_hidden_state = transformer_outputs.hidden_states[-1]

        # Apply the reward head
        value = self.v_head(last_hidden_state).squeeze(-1)
        return lm_logits, value

    def generate(self, *args, **kwargs):
        return self.llm.generate(*args, **kwargs)


In [6]:
model_path = './sft_model_epoch_1'
model = ModelForCausalLMWithValueHead(model_path)

## Preparing Dataset

In [7]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

In [8]:
%pip install datasets==3.5.0



In [9]:
from datasets import load_dataset
dataset = load_dataset("sst2")
dataset

DatasetDict({
    train: Dataset({
        features: ['idx', 'sentence', 'label'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['idx', 'sentence', 'label'],
        num_rows: 872
    })
    test: Dataset({
        features: ['idx', 'sentence', 'label'],
        num_rows: 1821
    })
})

In [10]:
ds_train, ds_val = dataset['train'], dataset['validation']

## Filtering

In [11]:
len(ds_train)

67349

In [12]:
ds_train = ds_train.filter(lambda x: len(x['sentence'].split(' ')) > 8)

In [13]:
len(ds_train)

31105

In [14]:
ds_val = ds_val.filter(lambda x: len(x['sentence'].split(' ')) > 8)

In [15]:
len(ds_val)

807

In [16]:
import random
input_min_token_length = 2
input_max_token_length = 8
input_token_length_range = list(range(input_min_token_length, input_max_token_length))
print(input_token_length_range)

[2, 3, 4, 5, 6, 7]


In [17]:
random.choice(input_token_length_range)

6

In [18]:
def tokenize(sample):
    input_size = random.choice(input_token_length_range)
    sample['input_ids'] = tokenizer.encode(sample['sentence'])[:input_size]
    sample['attention_mask'] = [1] * len(sample['input_ids'])
    sample['query'] = tokenizer.decode(sample['input_ids'])
    return sample

map_kwargs = {
    "batched": False,
    "remove_columns": ['idx', 'sentence', 'label']
}

tokenized_dataset_train = ds_train.map(tokenize, **map_kwargs)
tokenized_dataset_val = ds_val.map(tokenize, **map_kwargs)


In [19]:
tokenized_dataset_train.set_format(type='torch')
tokenized_dataset_val.set_format(type='torch')

In [20]:
tokenized_dataset_train[6]

{'input_ids': tensor([1640,  883]),
 'attention_mask': tensor([1, 1]),
 'query': 'for those'}

In [21]:
REWARD_TOKEN_ID = tokenizer.eos_token_id

In [22]:
from torch.utils.data import DataLoader

batch_size = 32

def collator(batch):
    return dict((key, [d[key] for d in batch]) for key in batch[0])

train_dataloader = DataLoader(tokenized_dataset_train, batch_size=batch_size, collate_fn=collator, shuffle=True)
val_dataloader = DataLoader(tokenized_dataset_val, batch_size=batch_size, collate_fn=collator, shuffle=True)

In [23]:
batch = next(iter(train_dataloader))
batch

{'input_ids': [tensor([ 271,  257, 5391,   12,   86, 2175]),
  tensor([8340,  257, 3807]),
  tensor([23442,   262]),
  tensor([19188,   326,  4077, 29815,   550,  3750,   257]),
  tensor([  292, 29408]),
  tensor([5832,  705,  260, 9431]),
  tensor([  338,  3729,   281, 30438,  1700,   286,   326]),
  tensor([1640,  257]),
  tensor([ 361,  262, 2587,  318, 3731,  290]),
  tensor([  86, 4066,  656,  262]),
  tensor([ 271,  257, 1339,  286, 1165,  867]),
  tensor([  533, 47300,   290,  7744,   503]),
  tensor([ 72, 588, 326, 895, 342, 837, 339]),
  tensor([15596,   935,  5442,  6616,  1108,   326]),
  tensor([   8, 1266, 2499, 1833]),
  tensor([ 1169,   442,   378,   559, 14448,   284,   374]),
  tensor([  944,    12, 13716,   284,  5698,   257,  9155]),
  tensor([ 1169, 32339,   290,   262,   584,  3435,  1949]),
  tensor([  271,   517,   287,  1842,   351, 36666,  9449]),
  tensor([21754,  4729, 44918, 15579, 14720,   329,  3081]),
  tensor([   64, 41456,   837]),
  tensor([ 338,  281,

In [24]:
output_min_length = 5
output_max_length = 16

# https://huggingface.co/docs/trl/how_to_train#how-to-generate-text-for-training

generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.pad_token_id
}

## Sample Generation

In [25]:
new_tokens = random.choice(list(range(output_min_length, output_max_length)))
generation_kwargs["max_new_tokens"] = new_tokens
sample = tokenizer('Hi, this')
sample

{'input_ids': [17250, 11, 428], 'attention_mask': [1, 1, 1]}

In [26]:
query_response = model.generate(
    input_ids=torch.tensor(sample['input_ids']).unsqueeze(0),
    attention_mask=torch.tensor(sample['attention_mask']).unsqueeze(0),
    **generation_kwargs
    ).squeeze(0)
query_response

tensor([17250,    11,   428,   318, 17774,   220,  1849,   220,   220,   220,
          220,   220,   220,   220])

In [27]:
tokenizer.decode(query_response)

'Hi, this is entertaining \xa0       '

In [28]:
with torch.no_grad():
    query_response_score = torch.cat([query_response, torch.tensor([REWARD_TOKEN_ID])])
    attention_mask = torch.ones_like(query_response_score, dtype=torch.long)
    score = reward_model(query_response_score.unsqueeze(0), attention_mask.unsqueeze(0)).squeeze(0)[-1]
score

tensor(0.9943)

## Batch Generation

In [29]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
reward_model = reward_model.to(device)

query_tensors = batch['input_ids']
query_attention_masks = batch['attention_mask']

response_tensors = []
query_response_tensors = []
score_tensors = []

for i, query in enumerate(query_tensors):
    query = query.to(device)
    query_attention_mask = query_attention_masks[i].to(device)
    new_tokens = random.choice(list(range(output_min_length, output_max_length)))
    generation_kwargs["max_new_tokens"] = new_tokens
    query_response = model.generate(
        input_ids=query.unsqueeze(0),
        attention_mask=query_attention_mask.unsqueeze(0),
        **generation_kwargs
    ).squeeze(0)

    response_len = len(query_response) - len(query)
    response_tensors.append(query_response[-response_len:])
    query_response_tensors.append(query_response)

    with torch.no_grad():
        query_response_score = torch.cat([query_response, torch.tensor([REWARD_TOKEN_ID]).to(device)])
        attention_mask = torch.ones_like(query_response_score, dtype=torch.long)
        score = reward_model(query_response_score.unsqueeze(0), attention_mask.unsqueeze(0)).squeeze(0)[-1]
        score = 2 * (score - 0.5)
    score_tensors.append(score)

batch["response"] = [tokenizer.decode(response) for response in response_tensors]
print(batch['response'])

[' hunk of a movie .   ', ' but fail to make it feel authentic      ', ' edge of his seat to bring this nicely made comedy to', ' step further ?        ', ' in nature     could have', ' huston is smart .     !', ' time period. \xa0      ', ' nuance that few such one-dimensional thrillers could ever hope to', ' under-designed or ', " realm of one man 's farts , like neither a short movie nor", ' thin-skinned kids being forced to watch overexposed , too forced to', ' of place in every frame \xa0of the film  ', ' knows how to get the ball rolling , hitting his target', ' is extraordinary .            ', ' state conscience  and a condition had to be', 'alph polanski  \xa0and there are some', ' storyline  \ue607  \ue607', ' desperately to convince -- or even understand -- the audience that this is', ' often than in love with the story itself', ' entertainment    ', ' deeply affecting examination of a genteel two-hour', ' theme episode , encountering it with freshness and curiousity', ' ever

## Compute Reward

$\text {reward} = \text {score} - \log (\frac {\pi^{RL}_\theta} {\pi^{SFT}})$

In [30]:
from copy import deepcopy
sft_model = deepcopy(model)

In [31]:
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [32]:
input_data = data_collator([
    {'input_ids': ids,
     'attention_mask': torch.ones_like(ids)} for ids in query_response_tensors
]).to(device)
input_data

{'input_ids': tensor([[  271,   257,  5391,    12,    86,  2175,   289,  2954,   286,   257,
          3807,   764,   220,   220,   220, 50256, 50256, 50256, 50256, 50256,
         50256],
        [ 8340,   257,  3807,   475,  2038,   284,   787,   340,  1254, 16425,
           220,   220,   220,   220,   220,   220, 50256, 50256, 50256, 50256,
         50256],
        [23442,   262,  5743,   286,   465,  5852,   284,  2222,   428, 16576,
           925, 10997,   284, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256],
        [19188,   326,  4077, 29815,   550,  3750,   257,  2239,  2252,  5633,
           220,   220,   220,   220,   220,   220,   220,   220, 50256, 50256,
         50256],
        [  292, 29408,   287,  3450,   220,   220,   220,   220,   714,   423,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256],
        [ 5832,   705,   260,  9431, 33376,   261,   318,  4451,   764,   220,
           220,   220,   220,  5

In [33]:
def compute_rewards(input_data, query_tensors, response_tensors, score_tensors):
    with torch.no_grad():
        logits, values = model(**input_data) # b, seq, vocab
        ref_logits, _ = sft_model(**input_data)
        logp = torch.nn.functional.log_softmax(logits[:, :-1, :], dim=-1)
        ref_logp = torch.nn.functional.log_softmax(ref_logits[:, :-1, :], dim=-1)

        labels = input_data['input_ids'][:, 1:] # b, seq

        logp = torch.gather(logp, 2, labels.unsqueeze(-1)).squeeze(-1) # batch, seq
        ref_logp = torch.gather(ref_logp, 2, labels.unsqueeze(-1)).squeeze(-1) # batch, seq

        kl = logp - ref_logp
        beta = 0.2
        rewards = - beta * kl
        attention_mask = input_data['attention_mask']
        masks = torch.zeros_like(attention_mask[:, 1:])
        masks[:,:] = attention_mask[:, 1:]
        for j in range(len(query_tensors)):
            start = len(query_tensors[j]) - 1
            end = start + len(response_tensors[j])
            masks[j, :start] = 0
            masks[j, end:] = 0
            rewards[j, end - 1] += score_tensors[j]
            rewards[j, :] *= masks[j, :]
            values[j, :-1] *= masks[j, :]

    return logp, rewards, values[:, :-1], masks


In [34]:
logprobs, rewards, values, masks = compute_rewards(input_data, query_tensors, response_tensors, score_tensors)
print(rewards[0])
print(input_data['input_ids'][0])
print(input_data['attention_mask'][0])

tensor([-0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.0000,
        -0.0000, -0.0000, -0.0000, -0.0000, -0.0000, -0.9969, -0.0000, -0.0000,
        -0.0000, -0.0000, -0.0000, -0.0000], device='cuda:0')
tensor([  271,   257,  5391,    12,    86,  2175,   289,  2954,   286,   257,
         3807,   764,   220,   220,   220, 50256, 50256, 50256, 50256, 50256,
        50256], device='cuda:0')
tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
       device='cuda:0')


In [35]:
print(masks[0])
print(values[0])

tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
       device='cuda:0')
tensor([-0.0000, -0.0000, -0.0000,  0.0000,  0.0000, -1.6054, -0.1053, -1.0586,
        -1.0855, -0.2100, -1.4899,  3.4527,  0.9113,  1.7163,  0.0000, -0.0000,
        -0.0000, -0.0000, -0.0000, -0.0000], device='cuda:0')


## Compute Advantage

In [36]:
def masked_mean(values, mask):
    return (values * mask).sum() / mask.sum()

def masked_var(values, mask):
    mean = masked_mean(values, mask)
    centred_values = values - mean
    return masked_mean(centred_values ** 2, mask)

def masked_whiten(values, mask):
    mean, var = masked_mean(values, mask), masked_var(values, mask)
    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
    whitened += mean
    return whitened

def compute_advantage(rewards, values, masks):
    lastgae = 0.0
    advantage_reversed = []
    seq_length = rewards.shape[-1]
    gamma, lam = 1.0, 0.95

    for t in reversed(range(seq_length)):
        nextvalues = values[:, t + 1] if t < seq_length - 1 else 0.0
        delta = rewards[:, t] + gamma * nextvalues - values[:, t]
        lastgae = delta + gamma * lam * lastgae
        advantage_reversed.append(lastgae)
    advantages = torch.stack(advantage_reversed[::-1], dim=1)
    advantages = masked_whiten(advantages, masks)

    returns = advantages + values
    return advantages, returns


In [37]:
advantages, returns = compute_advantage(rewards, values, masks)
print(advantages[0])
print(returns[0])

tensor([ 0.1104,  0.0946,  0.0779,  0.0603,  0.0418,  0.9713,  0.1140,  0.6619,
         0.6910,  0.1882,  0.9330, -1.9612, -0.5839, -1.1121,  0.4116,  0.4116,
         0.4116,  0.4116,  0.4116,  0.4116], device='cuda:0')
tensor([ 0.1104,  0.0946,  0.0779,  0.0603,  0.0418, -0.6341,  0.0088, -0.3967,
        -0.3945, -0.0218, -0.5569,  1.4915,  0.3275,  0.6042,  0.4116,  0.4116,
         0.4116,  0.4116,  0.4116,  0.4116], device='cuda:0')


## Mini-batch PPO Training

### Training Config

In [38]:
learning_rate = 1e-5
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [39]:
np.random.permutation(batch_size)

array([13, 17, 20,  7, 26,  1, 25, 23, 11, 21,  4, 14,  6,  3,  9,  2,  0,
       15,  5, 10, 30, 19, 29, 22, 24,  8, 28, 12, 27, 18, 16, 31])

In [40]:
mini_batch_size = 4
ppo_epochs = 4

cliprange_ratio = 0.2

v_loss_coeff = 0.1

ratio_threshold = 10

def compute_loss(old_logprobs, values, logprobs, vpreds, masks, advantages, returns):
    ratio = torch.exp(logprobs - old_logprobs)
    pg_loss1 = - ratio * advantages
    pg_loss2 = - torch.clamp(ratio, 1 - cliprange_ratio, 1 + cliprange_ratio) * advantages
    pg_loss = masked_mean(torch.max(pg_loss1, pg_loss2), masks)

    v_loss = masked_mean((vpreds - returns) ** 2, masks)
    loss = pg_loss + v_loss_coeff * v_loss

    avg_ratio = masked_mean(ratio, masks)
    if avg_ratio > ratio_threshold:
        pg_loss = pg_loss * 0.0
        v_loss = v_loss * 0.0
        loss = loss * 0.0

    return loss, v_loss

def mini_batch_train():
    for ep in range(ppo_epochs):
        batch_inds = np.random.permutation(batch_size)

        for start in range(0, batch_size, mini_batch_size):
            end = start + mini_batch_size
            mini_batch_inds = batch_inds[start:end]

            mb_model_inputs = {
                'input_ids': input_data['input_ids'][mini_batch_inds],
                'attention_mask': input_data['attention_mask'][mini_batch_inds]
            }
            mb_logits, mb_vpreds = model(**mb_model_inputs)
            mb_logits = torch.nn.functional.log_softmax(mb_logits[:, :-1, :], dim=-1)
            mb_logprobs = torch.gather(mb_logits, 2, mb_model_inputs['input_ids'][:, 1:].unsqueeze(-1)).squeeze(-1)

            loss, loss_v = compute_loss(
                logprobs[mini_batch_inds],
                values[mini_batch_inds],
                mb_logprobs,
                mb_vpreds[:, :-1],
                masks[mini_batch_inds],
                advantages[mini_batch_inds],
                returns[mini_batch_inds]
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print('loss/total', loss.item())
    print('mini-batch training finished')



In [41]:
mini_batch_train()

loss/total -0.6469668745994568
loss/total -1.084718942642212
loss/total -0.9519063234329224
loss/total -0.5013611912727356
loss/total -0.7945360541343689
loss/total -0.4254860281944275
loss/total -1.0511527061462402
loss/total -0.9178730845451355
loss/total -0.31340065598487854
loss/total -1.4144445657730103
loss/total -0.74092698097229
loss/total -1.5609691143035889
loss/total -0.8986895084381104
loss/total -1.1131693124771118
loss/total -1.3536183834075928
loss/total -0.83698570728302
loss/total -0.8659009337425232
loss/total -1.666944980621338
loss/total -0.9409258365631104
loss/total -1.2696202993392944
loss/total -0.6142494678497314
loss/total -1.3109718561172485
loss/total -0.7585713267326355
loss/total -0.9758541584014893
loss/total -1.3880150318145752
loss/total -0.833629310131073
loss/total -0.638252317905426
loss/total -1.1900427341461182
loss/total -0.4314558207988739
loss/total -1.388943076133728
loss/total -1.1693115234375
loss/total -1.4005227088928223
mini-batch training

## Train RLHF

In [42]:
num_epochs = 1

for epoch in range(num_epochs):
    for batch in train_dataloader:
        # Generate responses
        query_tensors = batch['input_ids']
        query_attention_masks = batch['attention_mask']

        response_tensors = []
        query_response_tensors = []
        score_tensors = []

        for i, query in enumerate(query_tensors):
            query = query.to(device)
            query_attention_mask = query_attention_masks[i].to(device)
            new_tokens = random.choice(list(range(output_min_length, output_max_length)))
            generation_kwargs["max_new_tokens"] = new_tokens
            query_response = model.generate(
                input_ids=query.unsqueeze(0),
                attention_mask=query_attention_mask.unsqueeze(0),
                **generation_kwargs
                ).squeeze(0)

            response_len = len(query_response) - len(query)
            response_tensors.append(query_response[-response_len:])
            query_response_tensors.append(query_response)

            with torch.no_grad():
                query_response_score = torch.cat([query_response, torch.tensor([REWARD_TOKEN_ID]).to(device)])
                attention_mask = torch.ones_like(query_response_score, dtype=torch.long)
                score = reward_model(query_response_score.unsqueeze(0), attention_mask.unsqueeze(0)).squeeze(0)[-1]
                score = 2 * (score - 0.5)
            score_tensors.append(score)

        input_data = data_collator([
            {
                'input_ids': ids,
                'attention_mask': torch.ones_like(ids)
            }
            for ids in query_response_tensors
        ]).to(device)

        # rewards and advantages
        logprobs, rewards, values, masks = compute_rewards(input_data, query_tensors, response_tensors, score_tensors)
        advantages, returns = compute_advantage(rewards, values, masks)

        # mini batch training
        mini_batch_train()
    print(f'epoch {epoch + 1} finished')

loss/total -0.2181205004453659
loss/total -0.27999866008758545
loss/total -0.04347766935825348
loss/total -0.09446786344051361
loss/total -0.2365225851535797
loss/total -0.1936279833316803
loss/total -0.18569737672805786
loss/total 0.015834838151931763
loss/total -0.3695436716079712
loss/total -0.3765247166156769
loss/total -0.1999007761478424
loss/total -0.04889871925115585
loss/total -0.3259511888027191
loss/total -0.5705932378768921
loss/total -0.5666579008102417
loss/total 0.31472232937812805
loss/total -0.24711640179157257
loss/total 0.11507843434810638
loss/total -0.6125611066818237
loss/total -0.2456435263156891
loss/total -0.31063202023506165
loss/total -0.14951065182685852
loss/total -0.42727845907211304
loss/total -0.6759032011032104
loss/total -0.24428074061870575
loss/total -0.3461518883705139
loss/total -0.35590457916259766
loss/total -0.20021793246269226
loss/total -0.36694443225860596
loss/total -0.5016358494758606
loss/total -0.2143326699733734
loss/total -0.49306899309

KeyboardInterrupt: 

## Validation

In [43]:
len(tokenized_dataset_val)

807

In [44]:
val_gen_lengths = [0] * len(tokenized_dataset_val)
for i in range(len(tokenized_dataset_val)):
    val_gen_lengths[i] = random.choice(list(range(output_min_length, output_max_length)))

In [45]:
val_gen_lengths[:10]

[5, 8, 7, 11, 10, 8, 8, 9, 8, 10]

In [46]:
def validate():
    scores = []
    for b, batch in enumerate(val_dataloader):
        # Generate_responses
        query_tensors = batch['input_ids']
        query_attention_masks = batch['attention_mask']
        for i, query in enumerate(query_tensors):
            query = query.to(device)
            query_attention_mask = query_attention_masks[i].to(device)
            new_tokens = val_gen_lengths[b * len(query_tensors) + i]
            generation_kwargs["max_new_tokens"] = new_tokens
            query_response = model.generate(
                input_ids=query.unsqueeze(0),
                attention_mask=query_attention_mask.unsqueeze(0),
                **generation_kwargs
                ).squeeze(0)
            query_response_score = torch.cat([query_response, torch.tensor([REWARD_TOKEN_ID]).to(device)])
            attention_mask = torch.ones_like(query_response_score, dtype=torch.long)
            score = reward_model(query_response_score.unsqueeze(0), attention_mask.unsqueeze(0)).squeeze(0)[-1]
            score = 2 * (score - 0.5)
            scores.append(score.item())
    print('avg score:', sum(scores) / len(scores))

In [47]:
validate()

avg score: 0.6572876147916621


In [48]:
torch.save(model.state_dict(), 'ppo_model_epoch_1.pt')

In [51]:
model_path = './sft_model_epoch_1'
model = ModelForCausalLMWithValueHead(model_path).to(device)

In [52]:
validate()

avg score: 0.07222306484626571
