# Name: Oliver JACK

# TLDR: GRPO with PPO-style multiple updates per batch (similar to DeepSeek implementation)

# GRPO Training project: teach an LLM to do additions, again

In this notebook, you'll find:
* A basic Transformer with basic tokenizer
* A basic dataset for additions
* A classical pre-trainer, minimizing cross entropy loss
* A Vanilla GRPO

You're not supposed to edit the existing code (you can if you want to...).
You should implement one (or more) of the following:
* GRPO with PPO (the `usual` one)
* RLOO
* ReMax
* DPO
* RAFT
* your own RLHF method!

In [1]:
import torch
from torch import nn
from torch.nn import functional as F

import random
import math
import re
import time

import copy

In [2]:
num_digits = 3

dataset_size = 64_000
train_proportion = 0.9

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


## Step 1: Construct a tokenizer

In [4]:
pad_token="[PAD]"
eos_token="[EOS]"

In [5]:
class character_level_tokenizer:
    """
    character-level
    """
    def __init__(self):
        self.vocab = [str(x) for x in range(10)] + ["+", "="] + [pad_token, eos_token]
        self.token_to_id = {v : k for k, v in enumerate(self.vocab)}
        self.id_to_token = {k : v for k, v in enumerate(self.vocab)}
        self.ntokens = len(self.vocab)
        self.pattern = f"[^{re.escape(''.join(self.vocab))}]"

    def clean(self, text):
        """
        removes all characters not in the vocabulary
        """
        out = re.sub(self.pattern, "", text)
        return out

    def pre_tokenization(self, text):
        """
        character-level
        """
        return [c for c in text]

    def encode(self, text):
        text_list = self.pre_tokenization(self.clean(text))
        return [self.token_to_id[c] for c in text_list]

    def decode(self, token_list):
        return "".join([self.id_to_token[x] for x in token_list])

In [6]:
tokenizer = character_level_tokenizer()
ntokens = tokenizer.ntokens
ntokens

14

In [7]:
prompt = "12 + 42 ="
inputs = tokenizer.encode(prompt)
inputs, tokenizer.decode(inputs)

([1, 2, 10, 4, 2, 11], '12+42=')

## Step 2: Create a dataset for arithmetic operations

In [8]:
def sample_datapoint(num_digits = 3):
    a_list = [random.randint(0, 9) for _ in range(num_digits)]
    b_list = [random.randint(0, 9) for _ in range(num_digits)]
    a_int = int("".join([str(x) for x in a_list]))
    b_int = int("".join([str(x) for x in b_list]))
    a_str = "".join([str(x) for x in a_list])
    b_str = "".join([str(x) for x in b_list])
    sum_int = a_int + b_int
    return (a_str + "+" + b_str + "=", str(sum_int))

sample_datapoint(3)

('386+039=', '425')

In [9]:
data = []
for _ in range(dataset_size):
    data.append(sample_datapoint(num_digits))
data[:4]

[('891+316=', '1207'),
 ('747+075=', '822'),
 ('358+301=', '659'),
 ('213+636=', '849')]

In [10]:
data_train = data[: int(train_proportion * dataset_size)]
data_test = data[int(train_proportion * dataset_size):]

len(data_train),len(data_test)

(57600, 6400)

## Step 3: Construct a model

In [11]:
class PositionalEncoding(nn.Module):
    r"""Inject some information about the relative or absolute position of the tokens in the sequence.
        The positional encodings have the same dimension as the embeddings, so that the two can be summed.
        Here, we use sine and cosine functions of different frequencies.
    .. math:
        \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model))
        \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model))
        \text{where pos is the word position and i is the embed idx)
    Args:
        d_model: the embed dim (required).
        dropout: the dropout value (default=0.1).
        max_len: the max. length of the incoming sequence (default=5000).
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        r"""Inputs of forward function
        Args:
            x: the sequence fed to the positional encoder model (required).
        Shape:
            x: [sequence length, batch size, embed dim]
            output: [sequence length, batch size, embed dim]
        """

        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

In [12]:
class TransformerModel(nn.Transformer):
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__(d_model=ninp,
                                               nhead=nhead,
                                               dim_feedforward=nhid,
                                               num_encoder_layers=nlayers)
        self.input_emb = nn.Embedding(ntoken, ninp)
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        self.decoder = nn.Linear(ninp, ntoken)

        self.ninp = ninp
        self.init_weights()

    def init_weights(self):
        initrange = 0.1
        nn.init.uniform_(self.input_emb.weight, -initrange, initrange)
        nn.init.zeros_(self.decoder.bias)
        nn.init.uniform_(self.decoder.weight, -initrange, initrange)

    def _generate_square_subsequent_mask(self, sz):
        return torch.log(torch.tril(torch.ones(sz,sz)))

    def forward(self, src):
        mask = self._generate_square_subsequent_mask(len(src)).to(device)
        self.src_mask = mask

        src = self.input_emb(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output_enc = self.encoder(src, mask=self.src_mask)
        output_dec = self.decoder(output_enc)
        return output_dec, output_enc

In [13]:
model = TransformerModel(ntoken = ntokens,
                         ninp = 128,
                         nhead = 16,
                         nhid = 64,
                         nlayers = 8)
model.to(device)



TransformerModel(
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=64, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=64, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
    (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (decoder): Linear(in_features=128, out_features=14, bias=True)
  (input_emb): Embedding(14, 128)
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.5, inplace=False)
  )
)

In [14]:
print("number of parameters: {}".format(sum([x.numel() for x in model.parameters()])))

number of parameters: 668942


### Useful functions

In [15]:
def generate(model, prompts, new_tokens = 5, mode = "greedy", num_samples = 1, temperature = 0.8):
    input_tensor = torch.repeat_interleave(prompts, repeats = num_samples, dim = 1).to(device)
    # (prompt_length, batch_size * num_samples)
    for _ in range(new_tokens):
        output, _ = model(input_tensor) # (prompt_length, batch_size * num_samples, ntokens)
        logits = output[-1,:,:] # (batch_size * num_samples, ntokens)
        if mode == "greedy":
            tokens = torch.argmax(logits, -1).view((1,-1)) # (1, batch_size * num_samples)
        else: # mode == "sampling"
            logits /= temperature
            probs = torch.softmax(logits, dim=-1)
            tokens = torch.multinomial(probs, num_samples = 1).view((1,-1)) # (1, batch_size * num_samples)
        input_tensor = torch.cat((input_tensor, tokens), 0)
    return input_tensor

In [16]:
model.eval()

prompt = "2+3="
prompt_tensor = torch.tensor(tokenizer.encode(prompt)).view((-1,1))
output = generate(model, prompt_tensor).view((1,-1))
output, tokenizer.decode(output[0].tolist())

(tensor([[ 2, 10,  3, 11,  7, 13,  7, 13,  7]], device='cuda:0'),
 '2+3=7[EOS]7[EOS]7')

In [17]:
def pad(token_list, type_list = "prompts"):
    max_length = max([len(x) for x in token_list])
    out = []
    for x in token_list:
        if type_list == "prompts":
            out.append([tokenizer.token_to_id[pad_token]] * (max_length - len(x)) + x)
        if type_list == "answers":
            out.append(x + [tokenizer.token_to_id[eos_token]] + [tokenizer.token_to_id[pad_token]] * (max_length - len(x)))
    return out, max_length

In [18]:
prompts = [tokenizer.encode("1+1="), tokenizer.encode("21+35=")]
answers = [tokenizer.encode("2"), tokenizer.encode("56")]
padded_prompts, _ = pad(prompts, "prompts")
padded_answers, _ = pad(answers, "answers")
padded_prompts, padded_answers
[tokenizer.decode(p) for p in padded_prompts], [tokenizer.decode(p) for p in padded_answers]

(['[PAD][PAD]1+1=', '21+35='], ['2[EOS][PAD]', '56[EOS]'])

In [19]:
def get_batch(split, i, batch_size):
    data = data_train if split == 'train' else data_test

    prompts = [data[i][0] for i in range(i, i + batch_size)]
    encoded_prompts = [tokenizer.encode(prompt) for prompt in prompts]
    padded_prompts, prompt_length = pad(encoded_prompts, "prompts")

    answers = [data[i][1] for i in range(i, i + batch_size)]
    encoded_answers = [tokenizer.encode(answer) for answer in answers]
    padded_answers, answers_length = pad(encoded_answers, "answers")

    X = torch.stack([torch.tensor(x) for x in padded_prompts], 1)
    Y = torch.stack([torch.tensor(x) for x in padded_answers], 1)
    return X, Y, prompt_length, answers_length, prompts, answers

In [20]:
X, Y, prompt_length, answers_length, prompts, answers = get_batch("train", 43, 16)
X.shape, Y.shape, prompt_length, answers_length, prompts[0], answers[0]

(torch.Size([8, 16]), torch.Size([5, 16]), 8, 4, '575+346=', '921')

## Step 4: Evaluate

In [21]:
batch_size = 16

In [22]:
def evaluate(batch_size = batch_size):
    # Turn on evaluation mode disables dropout.
    model.eval()
    correct = 0.
    with torch.no_grad():
        for batch, i in enumerate(range(0, len(data_test) - 1, batch_size)):
            prompts, target_answers, prompt_length, answers_length, _, _ = get_batch("test", i, batch_size)
            prompts = prompts.to(device) # (prompt_length, batch_size)
            target_answers = target_answers.to(device) # (answers_length + 1, batch_size)
            output = generate(model, prompts, answers_length + 1) # (prompt_length + answers_length + 1, batch_size)
            answers_tokens = output[prompt_length:, :] # (answers_length + 1, batch_size), contains tokens
            equality_test = answers_tokens == target_answers # (answers_length + 1, batch_size), contains boolean values
            correct += torch.all(equality_test, axis=0).float().sum()
        accuracy = correct / len(data_test)
    return accuracy.item()

In [23]:
evaluate()

0.0

## Step 5: Train the model, classical approach

### Hyperparameters

In [24]:
epochs = 5
batch_size = 16
learning_rate = 8e-4

reporting_per_epoch = 5
log_interval = len(data_train) // (reporting_per_epoch + 1)
assert(log_interval % batch_size == 0)

In [25]:
def train():
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    best_test_accuracy = None
    test_accuracy = evaluate()
    print('-' * 89)
    print('| initialisation | test accuracy {:5.2f}'.format(test_accuracy))
    print('-' * 89)
    for epoch in range(1, epochs+1):
        epoch_start_time = time.time()
        total_loss = 0.
        start_time = time.time()
        for batch, i in enumerate(range(0, len(data_train) - 1, batch_size)):
            prompts, target_answers, prompt_length, answers_length, _, _ = get_batch("train", i, batch_size)
            prompts = prompts.to(device) # (prompt_length, batch_size)
            target_answers = target_answers.to(device) # (answers_length + 1, batch_size)
            input_tensor = torch.cat((prompts, target_answers), 0) # (prompt_length + answers_length + 1, batch_size)
            model.zero_grad()
            output, _ = model(input_tensor) # (prompt_length + answers_length + 1, batch_size, ntokens)
            output_answers = output[prompt_length-1:-1,:,:].reshape(-1, ntokens) # ((answers_length + 1) * batch_size, ntokens)
            target_answers = target_answers.view(-1)
            loss = F.cross_entropy(output_answers, target_answers)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            if i % log_interval == 0 and batch > 0:
                cur_loss = total_loss / log_interval
                elapsed = time.time() - start_time
                print('| {:5d}/{:5d} batches | ms/batch {:5.2f} | loss {:5.2f} | perplexity {:8.2f}'.format(batch, len(data_train) // batch_size,
                                                                                                            elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss)))
                total_loss = 0
                start_time = time.time()
        test_accuracy = evaluate()
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | test accuracy {:5.2f}'.format(epoch, (time.time() - epoch_start_time), test_accuracy))
        print('-' * 89)
        # Save the model if the test accuracy is the best we've seen so far.
        if not best_test_accuracy or test_accuracy > best_test_accuracy:
            with open("arithmetic_v1.pt", 'wb') as f:
                torch.save(model, f)
            best_test_accuracy = test_accuracy

In [26]:
train()

-----------------------------------------------------------------------------------------
| initialisation | test accuracy  0.00
-----------------------------------------------------------------------------------------
|   600/ 3600 batches | ms/batch  1.11 | loss  0.09 | perplexity     1.09
|  1200/ 3600 batches | ms/batch  1.09 | loss  0.07 | perplexity     1.07
|  1800/ 3600 batches | ms/batch  1.07 | loss  0.07 | perplexity     1.07
|  2400/ 3600 batches | ms/batch  1.07 | loss  0.07 | perplexity     1.07
|  3000/ 3600 batches | ms/batch  1.07 | loss  0.07 | perplexity     1.07
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 72.07s | test accuracy  0.01
-----------------------------------------------------------------------------------------
|   600/ 3600 batches | ms/batch  1.07 | loss  0.06 | perplexity     1.07
|  1200/ 3600 batches | ms/batch  1.07 | loss  0.06 | perplexity     1.07
|  1800/ 3600 batches | ms/

In [27]:
model.eval()

for i in range(20):
    prompt, answers = data_test[i]
    prompt_tensor = torch.tensor(tokenizer.encode(prompt)).view((-1,1))
    output = generate(model, prompt_tensor, len(answers) + 1).view((1,-1))
    print(tokenizer.decode(output.tolist()[0]) + "\t actual result: " + answers)

533+955=1488[EOS]	 actual result: 1488
969+150=1118[EOS]	 actual result: 1119
876+729=1605[EOS]	 actual result: 1605
376+353=729[EOS]	 actual result: 729
903+755=1658[EOS]	 actual result: 1658
452+736=1188[EOS]	 actual result: 1188
153+522=675[EOS]	 actual result: 675
663+621=1284[EOS]	 actual result: 1284
089+361=450[EOS]	 actual result: 450
706+614=1320[EOS]	 actual result: 1320
245+176=421[EOS]	 actual result: 421
659+747=1406[EOS]	 actual result: 1406
969+305=1274[EOS]	 actual result: 1274
314+784=1097[EOS]	 actual result: 1098
994+280=1274[EOS]	 actual result: 1274
049+609=657[EOS]	 actual result: 658
801+751=1552[EOS]	 actual result: 1552
615+012=626[EOS]	 actual result: 627
844+221=1064[EOS]	 actual result: 1065
596+812=1408[EOS]	 actual result: 1408


## Step 4 bis: Vanilla GRPO training

### Custom reward functions

In [28]:
def accuracy_reward(output, answer):
    pattern = r"\[EOS\]"
    output = re.sub(pattern, "", output)
    pattern = r"(\[PAD\])*$"
    output = re.sub(pattern, "", output)
    return 1. if output == answer else 0.

accuracy_reward("123[EOS][PAD][PAD]", "123"), accuracy_reward("123", "124")

(1.0, 0.0)

In [29]:
def distance_accuracy_reward(output, answer):
    pattern = r"\[EOS\]"
    output = re.sub(pattern, "", output)
    pattern = r"(\[PAD\])*$"
    output = re.sub(pattern, "", output)
    int_output = int(output)
    int_answer = int(answer)
    return 1 - abs(int_output - int_answer) / max(int_output, int_answer)

distance_accuracy_reward("123[EOS]", "123"), distance_accuracy_reward("123[PAD]", "124")

(1.0, 0.9919354838709677)

In [30]:
def digit_accuracy_reward(output, answer):
    pattern = r"\[EOS\]"
    output = re.sub(pattern, "", output)
    pattern = r"(\[PAD\])*$"
    output = re.sub(pattern, "", output)
    return sum(c1 == c2 for (c1,c2) in zip(output, answer)) / max(len(output), len(answer))

digit_accuracy_reward("123[EOS][PAD][PAD]", "123"), digit_accuracy_reward("123[EOS]", "123")

(1.0, 1.0)

In [31]:
def reward_format(output):
    pattern = r"\d+\[EOS\](\[PAD\])*$"
    return 1. if bool(re.match(pattern, output)) else 0.

reward_format("123[EOS][PAD][PAD]"), reward_format("123[EOS]"), reward_format("123")

(1.0, 1.0, 0.0)

### Hyperparameters

In [32]:
epochs = 3
batch_size = 16
learning_rate = 1e-4
num_samples = 16
temperature = .8

beta = 0.04
m = 5
epsilon = 0.2

reporting_per_epoch = 5
log_interval = len(data_train) // (reporting_per_epoch + 1)
assert(log_interval % batch_size == 0)

reward_fun = digit_accuracy_reward
reward_format = reward_format

In [33]:
def compute_rewards(text_outputs, answers):
    repeated_answers = [answer for answer in answers for _ in range(num_samples)]
    rewards = torch.tensor(
        [0.2 * reward_format(output) + 0.8 * reward_fun(output, answer)
         for output, answer in zip(text_outputs, repeated_answers)],
        dtype=torch.float32,
        device=device
    )
    return rewards

In [34]:
def calculate_grpo_advantages(rewards):
    # Reshape rewards to have shape [batch_size, num_samples]
    grouped_rewards = rewards.view(-1, num_samples)

    # Calculate mean and std within each group
    mean_rewards = grouped_rewards.mean(dim=1, keepdim=True)
    std_rewards = grouped_rewards.std(dim=1, keepdim=True) + 1e-8

    # Normalize rewards within each group
    normalized_rewards = (grouped_rewards - mean_rewards) / std_rewards

    # Reshape back to original shape
    advantages = normalized_rewards.view(-1)

    return advantages

In [35]:
def compute_log_probs(model, outputs, prompt_length):
    logits, _ = model(outputs)
    # logits.shape = (prompt_length + answers_length + 1, batch_size * num_samples, vocab_size)

    # we only need the log probabilities for the new tokens
    # this introduces a shift: the logits for a position are the predictions for the next token
    logits = logits[prompt_length-1:-1, :, :]
    # logits.shape = (answers_length + 1, batch_size * num_samples, vocab_size)

    # convert raw logits into log probabilities along the vocabulary axis
    log_probs = F.log_softmax(logits, dim=-1)
    # log_probs.shape = (answers_length + 1, batch_size * num_samples, vocab_size)
    return log_probs

In [36]:
def get_selected_logprobs(log_probs, responses):
    # Add dimension for gathering
    responses = responses.unsqueeze(-1)

    # Gather log probabilities for the tokens that were actually chosen
    selected_log_probs = log_probs.gather(dim=-1, index=responses)
    selected_log_probs = selected_log_probs.squeeze(-1)

    return selected_log_probs

In [37]:
def compute_ppo_grpo_loss(logprobs_current, logprobs_ref, advantages, epsilon, num_samples, logprobs_old=None, beta=0.04):
    # Use current logprobs as old logprobs if not provided
    if logprobs_old is None:
        old_logprobs = logprobs_current.detach()
    else:
        old_logprobs = logprobs_old

    # Calculate KL divergence between current and reference model
    if beta > 0:
        per_token_kl = (torch.exp(logprobs_ref.detach() - logprobs_current) - (logprobs_ref.detach() - logprobs_current) - 1)

    # Calculate PPO clipping objective
    ratio = torch.exp(logprobs_current - old_logprobs)
    clipped_ratio = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)

    surrogate1 = ratio * advantages.unsqueeze(0)
    surrogate2 = clipped_ratio * advantages.unsqueeze(0)

    # Take minimum of surrogate terms
    per_token_loss = -torch.min(surrogate1, surrogate2)

    # Add KL penalty if beta > 0
    if beta > 0:
        per_token_loss = per_token_loss + beta * per_token_kl

    # Average over tokens and samples
    loss = per_token_loss.mean()

    return loss

In [38]:
def train_PPO_GRPO(verbose=False):
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    best_test_accuracy = None
    test_accuracy = evaluate()
    print('-' * 89)
    print('| initialisation | test accuracy {:5.2f}'.format(test_accuracy))
    print('-' * 89)

    # switch eval for train model (enables dropout)
    model.train()

    for epoch in range(1, epochs+1):
        epoch_start_time = time.time()
        start_time = time.time()

        # Create/update reference model at the beginning of each epoch
        reference_model = copy.deepcopy(model)
        reference_model.eval()

        for batch, i in enumerate(range(0, len(data_train) - 1, batch_size)):

            # get a batch of prompts and answers
            prompts, _, prompt_length, answers_length, questions, answers = get_batch("train", i, batch_size)
            prompts = prompts.to(device)

            # generate samples for each prompt using current model
            outputs = generate(model,
                              prompts,
                              new_tokens=answers_length + 1,
                              mode="sampling",
                              num_samples=num_samples,
                              temperature=temperature)

            # Get responses
            responses = outputs[prompt_length:, :]

            # Decode outputs for reward calculation
            text_outputs = [tokenizer.decode(outputs[prompt_length:, i].tolist()) for i in range(outputs.size(1))]

            # Calculate rewards and advantages
            rewards = compute_rewards(text_outputs, answers)
            advantages = calculate_grpo_advantages(rewards)

            # Get reference model log probabilities
            with torch.no_grad():
                log_probs_ref = compute_log_probs(reference_model, outputs, prompt_length)
                selected_log_probs_ref = get_selected_logprobs(log_probs_ref, responses)

            # Multiple PPO updates per batch
            selected_log_probs_old = None
            for update_idx in range(m):
                # Compute current model log probabilities
                log_probs_current = compute_log_probs(model, outputs, prompt_length)
                selected_log_probs_current = get_selected_logprobs(log_probs_current, responses)

                # Compute PPO GRPO loss
                loss = compute_ppo_grpo_loss(
                    selected_log_probs_current,
                    selected_log_probs_ref,
                    advantages,
                    epsilon,
                    num_samples,
                    logprobs_old=selected_log_probs_old,
                    beta=beta
                )

                # Optimize
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Update old log probs for next iteration
                selected_log_probs_old = selected_log_probs_current.detach()

            if i % log_interval == 0 and batch > 0:
                elapsed = time.time() - start_time
                print('| {:5d}/{:5d} batches | ms/batch {:5.2f} | loss {:5.2f}'.format(
                    batch, len(data_train) // batch_size, elapsed * 1000 / log_interval, loss.item()))

                if verbose:
                    print("\nquestion:", questions[0],
                          "\nanswer:", answers[0],
                          "\noutput:", text_outputs[:num_samples],
                          "\nreward:", rewards[:num_samples],
                          "\nadvantage:", advantages[:num_samples], "\n")

                start_time = time.time()

        test_accuracy = evaluate()
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | test accuracy {:5.2f}'.format(
            epoch, (time.time() - epoch_start_time), test_accuracy))
        print('-' * 89)

        # Save the model if the test accuracy is the best we've seen so far
        if not best_test_accuracy or test_accuracy > best_test_accuracy:
            with open("arithmetic_PPO_GRPO.pt", 'wb') as f:
                torch.save(model, f)
            best_test_accuracy = test_accuracy

In [39]:
train_PPO_GRPO(verbose = False)

-----------------------------------------------------------------------------------------
| initialisation | test accuracy  0.79
-----------------------------------------------------------------------------------------
|   600/ 3600 batches | ms/batch  9.58 | loss 41.24
|  1200/ 3600 batches | ms/batch  9.50 | loss 93.90
|  1800/ 3600 batches | ms/batch  9.49 | loss  1.83
|  2400/ 3600 batches | ms/batch  9.53 | loss  0.91
|  3000/ 3600 batches | ms/batch  9.51 | loss  0.72
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 558.57s | test accuracy  0.21
-----------------------------------------------------------------------------------------
|   600/ 3600 batches | ms/batch  8.90 | loss  0.03
|  1200/ 3600 batches | ms/batch  8.85 | loss  0.02
|  1800/ 3600 batches | ms/batch  8.86 | loss  0.02
|  2400/ 3600 batches | ms/batch  8.89 | loss  0.02
|  3000/ 3600 batches | ms/batch  8.86 | loss  0.02
------------------------

In [40]:
model.eval()

for i in range(20):
    prompt, answers = data_test[i]
    prompt_tensor = torch.tensor(tokenizer.encode(prompt)).view((-1,1))
    output = generate(model, prompt_tensor, len(answers) + 1).view((1,-1))
    print(tokenizer.decode(output.tolist()[0]) + "\t actual result: " + answers)

533+955=1488[EOS]	 actual result: 1488
969+150=1119[EOS]	 actual result: 1119
876+729=1605[EOS]	 actual result: 1605
376+353=729[EOS]	 actual result: 729
903+755=1658[EOS]	 actual result: 1658
452+736=1188[EOS]	 actual result: 1188
153+522=675[EOS]	 actual result: 675
663+621=1284[EOS]	 actual result: 1284
089+361=450[EOS]	 actual result: 450
706+614=1320[EOS]	 actual result: 1320
245+176=421[EOS]	 actual result: 421
659+747=1406[EOS]	 actual result: 1406
969+305=1274[EOS]	 actual result: 1274
314+784=1098[EOS]	 actual result: 1098
994+280=1274[EOS]	 actual result: 1274
049+609=658[EOS]	 actual result: 658
801+751=1552[EOS]	 actual result: 1552
615+012=627[EOS]	 actual result: 627
844+221=1065[EOS]	 actual result: 1065
596+812=1408[EOS]	 actual result: 1408
