<a href="https://colab.research.google.com/github/deekshayennam/RLFineTune_GRPO/blob/main/RL_Reasoning_Writing_GRPO_on_base.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<h1 align="center">RL, Reasoning & Writing</h1>

---

<h1 align="center">GRPO on Base model</h1>

This notebook introduces a series of experiments with RL training on a base model rather than an instruct model. It is obviously inspired by R0 (DeepSeek-R1-Zero), the *other* DeepSeek model trained from DeepSeek-V3: while R1 was post-trained more classically, through a series of instruction finetuning and RL stages, R0 was trained with RL directly on the base model.

We reuse the same RL method as R0: GRPO. For a more straightforward way of testing GRPO on an instruct model, you can check Will Brown's script, which I ported to Google Colab. Here, instead, we'll take the opportunity to explore alternative forms of RL tuning that fit better with a base model as a starting point: poetry writing. We're going to make an RL poet.

<img src="https://raw.githubusercontent.com/Pleias/RL-Reasoning/refs/heads/main/apollinaire_2.png">

It's well known that instruct tuning harms creative writing capabilities. This has been my frustration ever since ChatGPT was released, as I was used to GPT-3's charming, diverse, lustrous style.

We're going to use a new base model pretrained by Pleias, Pleias-350m on an entirely open training set, Common Corpus. Despite its small size (350 million parameters, about the size of GPT-2 medium), Pleias-350m is multilingual and trained on a lot of (public domain) literary works. The model is already able to output poems spontaneously, although in a very "raw" format.

If you're more used to LLM fine-tuning, this notebook will look quite alien: since we are not even using an instruct version, it is a form of "no-data" training. In the end, there is no data aside from the pre-training set. We're just tapping into the base model's multiverse.

## Setting up the models.

Before getting to the actual matter at hand, let's install the necessary libraries. We'll use vllm to speed up generation during training: as a reminder, RL requires generating multiple "drafts" that are then evaluated with a reward function. The faster the drafting, the faster the training, and vllm speeds up inference considerably compared to trl's default generation engine.
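As a rough illustration of why we need several drafts per prompt: GRPO does not use a learned value model; it scores each completion relative to the other completions sampled for the same prompt. Below is a minimal sketch of that group normalization, assuming the standard formulation (reward minus the group mean, divided by the group standard deviation plus a small epsilon). The function name and epsilon value are illustrative choices, not trl's internal code.

```python
from statistics import mean, stdev

def group_advantages(rewards, eps=1e-4):
    """GRPO-style advantages for the drafts of ONE prompt (illustrative sketch)."""
    mu = mean(rewards)
    # Sample standard deviation over the group; eps avoids a division by zero
    # when every draft receives exactly the same reward.
    sigma = stdev(rewards) if len(rewards) > 1 else 0.0
    return [(r - mu) / (sigma + eps) for r in rewards]

# Four hypothetical drafts for the same prompt, scored by some reward function
rewards = [0.5, 1.5, 0.5, 1.5]
print(group_advantages(rewards))
```

A draft that beats its siblings gets a positive advantage and is reinforced; a group where every draft gets the same reward yields zero advantages, so that prompt brings no reward-driven learning signal.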

In [None]:
!pip install vllm

Collecting vllm
  Downloading vllm-0.7.1-cp38-abi3-manylinux1_x86_64.whl.metadata (12 kB)
Collecting blake3 (from vllm)
  Downloading blake3-1.0.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.2 kB)
Collecting transformers>=4.48.2 (from vllm)
  Downloading transformers-4.48.2-py3-none-any.whl.metadata (44 kB)
Collecting fastapi!=0.113.*,!=0.114.0,>=0.107.0 (from vllm)
  Downloading fastapi-0.115.8-py3-none-any.whl.metadata (27 kB)
Collecting uvicorn[standard] (from vllm)
  Downloading uvicorn-0.34.0-py3-none-any.whl.metadata (6.5 kB)
Collecting prometheus-fastapi-instrumentator>=7.0.0 (from vllm)
  Downloading prometheus_fastapi_instrumentator-7.0.2-py3-none-any.whl.metadata (13 kB)
Collecting tiktoken>=0.6.0 (from vllm)
  Downloading tiktoken-0.8.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.6 kB)
Collecting lm-form

In [None]:
!pip install trl
!pip install datasets

Collecting trl
  Downloading trl-0.14.0-py3-none-any.whl.metadata (12 kB)
Collecting datasets>=2.21.0 (from trl)
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets>=2.21.0->trl)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets>=2.21.0->trl)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets>=2.21.0->trl)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets>=2.21.0->trl)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.10.0->accelerate>=0.34.0->trl)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.10.0->a

Now we can import Pleias-350m:

In [None]:
# First, let's import our required libraries
import torch
from torch.nn import functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the Pleias-350m base model and its tokenizer
model_name = "PleIAs/Pleias-350m-Preview"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

Since Pleias-350m is a base model that has never been instruct-tuned, we are going to use it in continuation mode, like in the good old days of GPT-2 and GPT-3: given a few lines of poetry, it should predict the next ones. Let's warm up with a classic (Shakespeare's Sonnet 24):

In [None]:
# Our input prompt
prompt="""Mine eye hath play'd the painter and hath stell'd
Thy beauty's form in table of my heart;
My body is the frame wherein 'tis held,
And perspective it is best painter's art.
For through the painter must you see his skill,
"""

We'll start by converting this text into tokens:

In [None]:
input_ids = tokenizer.encode(prompt, return_tensors='pt')
input_ids = input_ids.to(device)
print(f"\nTokenized input shape: {input_ids.shape}")
print(f"Last Token IDs: {input_ids[0].tolist()[-10:]}")
print(f"Last Decoded tokens: {[tokenizer.decode([id]) for id in input_ids[0].tolist()[-10:]]}")


Tokenized input shape: torch.Size([1, 58])
Last Token IDs: [1231, 265, 49154, 2103, 699, 2142, 771, 11911, 15, 189]
Last Decoded tokens: [' through', ' the', ' painter', ' must', ' you', ' see', ' his', ' skill', ',', '\n']


Let's first do a full generation with `model.generate`. We'll stick to deterministic (greedy) decoding for now.

In [None]:
inputs = tokenizer(prompt, return_tensors="pt", padding=False).to(device)

outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=200,
            repetition_penalty=1.2,
            do_sample=False,
            use_cache=True,
            pad_token_id=2,
            eos_token_id=2
        )

result = tokenizer.decode(outputs[0], skip_special_tokens=False)

print(result)

<|end_of_text|>Mine eye hath play'd the painter and hath stell'd
Thy beauty's form in table of my heart;
My body is the frame wherein 'tis held,
And perspective it is best painter's art.
For through the painter must you see his skill,
To make a picture perfect ; but to paint
A face or figure with such perfection as yours,
Or that which I have seen, would be an act so great
As if your eyes had been all one size together :
But this I say, though they are not equal yet,
I am sure their parts will never fail me till then:
Their figures shall be like mine, nor can I tell how
They may become more beautiful than myself.
THE FAIRY TALE OF THE BABYLONIAN NIGHTS.
" The fairies' daughters were called by the name of "Babylonians," because they came from Babylonia." — Sir Thomas Browne.
1 1. A Fairy Tale.
By Mr. John Hickes.
[This story was first published at London in 1609.]

It has long since fallen into disuse among us, being now only known to those who know nothing


This is a very typical "raw" base model output, which reads like a hallucinated book page, complete with footnotes, background information, etc. Pleias-350m can also work in multiple other languages, including 17th-century French with historical spelling (something even Claude can struggle with).

In [None]:
# Our input prompt
prompt="""C’eſt cet amour payé de trop d’ingratitude,
Qui me rend en ces Lieux ſa preſence ſi rude.
Quelle honte pour moy ! Quel triomphe pour luy,
De voir mon infortune égaler ſon ennuy !
"""

inputs = tokenizer(prompt, return_tensors="pt", padding=False).to(device)

outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=200,
            repetition_penalty=1.2,
            do_sample=False,
            use_cache=True,
            pad_token_id=2,
            eos_token_id=2
        )

result = tokenizer.decode(outputs[0], skip_special_tokens=False)

print(result)

<|end_of_text|>C’eſt cet amour payé de trop d’ingratitude,
Qui me rend en ces Lieux ſa preſence ſi rude.
Quelle honte pour moy! Quel triomphe pour luy,
De voir mon infortune égaler ſon ennuy!
Je ne puis que pleurer, & je n'en veux point dire :
Mais ce quil y a de plus affreux dans le Monde,
Ce qui fait ma douleur eſt la perte du monde ;
Et c'eft à moi-meſine de l'avoir perdu.
Jamais il n'y eut un homme dont les vertus
N'eurent jamais tant de gloire auec mes défauts;
Que jai vu parmy eux tous une vertu fi belle,
Quelquefois elle ſe perdit avec leur éclat:
Car ils ont toujours eu des hommes bien faits,
Pour avoir été heureux quand on étoit malfait.
Il faut donc que nous ayons quelque chofe de bon,
Dont nos maux puſſent faire connoitre notre bonté.
Nous avons beau vouloir, mais nous ne pouvons rien.
Lui-même eft malheureux comme nous, &c.
A MONSIEUR


Now for the issues. First, notice that we set a "repetition penalty": otherwise we run into the usual curse of small base models, which get stuck in repetition loops. Base-model generation is very open-ended, and this is a problem that usually disappears with instruction tuning. The nice thing about reinforcement learning is that it can correct this bias while keeping the base model "raw", in continuation mode, without any overfitting on an instruct set. Given how these instruction sets have been made, that also means no ChatGPT-ization.
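For intuition, here is a simplified pure-Python sketch of the CTRL-style penalty that, to our understanding, transformers applies when `repetition_penalty` is set: every token already present in the context (prompt included) has its logit divided by the penalty if positive, or multiplied by it if negative. The function name and toy logits are hypothetical; this is not the library code.

```python
def apply_repetition_penalty(logits, context_token_ids, penalty=1.2):
    """Return a copy of next-token logits with a CTRL-style repetition penalty."""
    penalized = list(logits)
    for token_id in set(context_token_ids):
        score = penalized[token_id]
        # Positive logits shrink and negative ones grow more negative:
        # either way, a token already in the context becomes less likely.
        penalized[token_id] = score / penalty if score > 0 else score * penalty
    return penalized

logits = [2.4, -1.0, 0.6, 3.0]   # toy 4-token vocabulary
context = [0, 1, 0]               # tokens 0 and 1 already appear in the context
print(apply_repetition_penalty(logits, context))
print(apply_repetition_penalty(logits, context, penalty=1.0))
```

With `penalty=1.0` the logits come back unchanged, which is why `repetition_penalty=1` amounts to no penalty at all.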

If we regenerate without the penalty we immediately see the issue:

In [None]:
outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=200,
            repetition_penalty=1,
            do_sample=False,
            use_cache=True,
            pad_token_id=2,
            eos_token_id=2
        )

result = tokenizer.decode(outputs[0], skip_special_tokens=False)

print(result)

<|end_of_text|>C’eſt cet amour payé de trop d’ingratitude,
Qui me rend en ces Lieux ſa preſence ſi rude.
Quelle honte pour moy! Quel triomphe pour luy,
De voir mon infortune égaler ſon ennuy!


Mais ſi je puis, je le puis, & ſi je le veux,


Je le puis, & je le veux, & je le veux,


Et je le veux, & je le puis, & je le veux,


Et je le veux, & je le puis, & je le veux,


Et je le veux, & je le puis, & je le veux,


Et je le veux, & je le puis, & je le veux,


Et je le veux, & je le puis, & je le veux,


Et je le veux, & je le puis, & je le veux,


Et je le veux, & je le puis, & je le veux,


Et je le veux, & je le puis, & je le veux,


Et je le veux, & je le puis, & je le veux,




The second issue: sometimes the model lacks enough context to recognize the prompt as poetry. If we really want to make a "poet", it should switch into poetry mode automatically.

Let's try with some Keats:

In [None]:
# Our input prompt
prompt="""Saturn is fallen, am I too to fall?
Am I to leave this haven of my rest"""

inputs = tokenizer(prompt, return_tensors="pt", padding=False).to(device)

outputs = model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=200,
            repetition_penalty=1.2,
            do_sample=False,
            use_cache=True,
            pad_token_id=2,
            eos_token_id=2
        )

result = tokenizer.decode(outputs[0], skip_special_tokens=False)

print(result)

<|end_of_text|>Saturn is fallen, am I too to fall?
Am I to leave this haven of my rest?
I have no more than a few hours left.
And if the sun should rise again,
The moon will be in her place ; and then
When Saturn shall return from his flight,
He'll find me here at last! "

" And you are not going away now?" said he, as she
stood looking down on him with an expression which was
not very pleasant. She had been thinking that it would
be better for her husband's health to go back to England;
but there were many things about which she could not
understand why they did not come together sooner or
later. The first thing she thought was how much happi-
ness might be expected when all these people who lived
in London met once more — ^how soon their happiness
would begin to increase! But what was the use of talk-
ing any longer about such matters? He must see them
again before long. They seemed so different from


As you can see, we get some poetry at first, and then the model switches to novel mode.

Finally, notice that our poems have little structure. We might prefer a clear division into stanzas, such as quatrains or tercets.

## Setting up our reward functions

The good news: all of these problems can be computed with simple algorithmic text analysis. No need for an LLM-as-a-judge or human evaluators, just a proper operationalization of the problem at hand. In an RL context, this operationalization is called a "reward function".

In [None]:
import re
import vllm
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

Now let's go back to problem 1: repetition. We are going to produce both a bad and a good generation.

In [None]:
def text_generation(prompt, repetition_penalty=1):
  inputs = tokenizer(prompt, return_tensors="pt", padding=False).to(device)

  outputs = model.generate(
              inputs.input_ids,
              attention_mask=inputs.attention_mask,
              max_new_tokens=200,
              repetition_penalty=repetition_penalty,
              do_sample=False,
              use_cache=True,
              pad_token_id=2,
              eos_token_id=2
          )

  result = tokenizer.decode(outputs[0], skip_special_tokens=False)

  return result

We'll stay in deterministic mode and contrast a generation with the repetition penalty (good) and one without it (bad). Reinforcement learning itself will, of course, sample completions non-deterministically:

In [None]:
# Our input prompt
prompt="""C’eſt cet amour payé de trop d’ingratitude,
Qui me rend en ces Lieux ſa preſence ſi rude.
Quelle honte pour moy ! Quel triomphe pour luy,
De voir mon infortune égaler ſon ennuy !
"""

good_generation = text_generation(prompt, repetition_penalty=1.2)
bad_generation = text_generation(prompt, repetition_penalty=1)

Now let's build a function that checks for repetition. I won't comment on it in detail; suffice it to say that it combines n-gram uniqueness (from single words up to 4-grams), local repetition within a sliding window, and overall vocabulary diversity.

In [None]:
def calculate_repetition_score(text):
    """Calculate a continuous repetition score between 0 and 1.
    Lower scores indicate more repetition, higher scores indicate more uniqueness.
    """
    # Clean and tokenize
    words = text.lower().split()
    if len(words) < 8:  # Very short texts get full score
        return 1.0

    # Calculate various repetition metrics
    from collections import Counter

    # 1. N-gram uniqueness scores (for different n-gram sizes)
    def get_ngram_uniqueness(n):
        ngrams = [' '.join(words[i:i+n]) for i in range(len(words)-n+1)]
        if not ngrams:
            return 1.0
        counts = Counter(ngrams)
        # Calculate ratio of unique n-grams to total n-grams
        uniqueness = len(counts) / len(ngrams)
        # Penalize heavily repeated n-grams
        repetition_penalty = sum(1 for count in counts.values() if count > 2) / len(counts) if counts else 0
        return uniqueness * (1 - repetition_penalty)

    # Get scores for different n-gram sizes
    unigram_score = get_ngram_uniqueness(1)
    bigram_score = get_ngram_uniqueness(2)
    trigram_score = get_ngram_uniqueness(3)
    fourgram_score = get_ngram_uniqueness(4)

    # 2. Local repetition (phrases repeating close to each other)
    def local_repetition_penalty():
        window_size = 10
        local_repetitions = 0
        for i in range(len(words) - window_size):
            window = words[i:i+window_size]
            window_counts = Counter(window)
            local_repetitions += sum(1 for count in window_counts.values() if count > 2)
        return 1 / (1 + local_repetitions)

    local_score = local_repetition_penalty()

    # 3. Vocabulary diversity
    vocab_diversity = len(set(words)) / len(words)

    # Combine scores with weights
    weights = {
        'unigram': 0.1,
        'bigram': 0.2,
        'trigram': 0.3,
        'fourgram': 0.2,
        'local': 0.1,
        'vocab': 0.1
    }

    final_score = (
        weights['unigram'] * unigram_score +
        weights['bigram'] * bigram_score +
        weights['trigram'] * trigram_score +
        weights['fourgram'] * fourgram_score +
        weights['local'] * local_score +
        weights['vocab'] * vocab_diversity
    )

    # Squash with tanh: since the weights sum to 1, final_score lies in [0, 1],
    # so the returned score lies between 0.5 and about 0.98
    from math import tanh
    normalized_score = tanh(2 * final_score) / 2 + 0.5

    return normalized_score
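To get a feel for the core metric, here is a standalone copy of the n-gram uniqueness computation (same logic as the nested `get_ngram_uniqueness`, but taking an already-lowercased word list), applied to a simplified looping line in the style of the degenerate French output and to a varied one. The sample lines are illustrative.

```python
from collections import Counter

def ngram_uniqueness(words, n):
    """Standalone copy of get_ngram_uniqueness, taking a list of words."""
    ngrams = [" ".join(words[i:i + n]) for i in range(len(words) - n + 1)]
    if not ngrams:
        return 1.0
    counts = Counter(ngrams)
    uniqueness = len(counts) / len(ngrams)
    # Share of distinct n-grams that occur more than twice
    heavy = sum(1 for c in counts.values() if c > 2) / len(counts)
    return uniqueness * (1 - heavy)

looping = "et je le veux et je le puis et je le veux et je le veux".split()
varied = "je ne puis que pleurer et je n'en veux point dire".split()
for n in (1, 2, 3):
    print(n, round(ngram_uniqueness(looping, n), 3), round(ngram_uniqueness(varied, n), 3))
```

The looping line is hit twice: few distinct n-grams, and most of those occur more than twice, which drives its unigram score down to 0.0625.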

Finally, to make our scoring function usable by the trainer, we wrap it in a reward function that scores a whole batch of completions at once:

In [None]:
def no_repetition_reward_func(completions, **kwargs) -> list[float]:
    # Handle both string and conversational formats
    responses = []
    for completion in completions:
        if isinstance(completion, str):
            responses.append(completion)
        elif isinstance(completion, list):
            responses.append(completion[0]["content"])
        else:
            raise ValueError(f"Unexpected completion format: {type(completion)}")

    # Calculate continuous scores
    scores = [calculate_repetition_score(response) for response in responses]

    return scores
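For reference, the contract trl expects is simple: a reward function receives the whole batch of completions (with the prompts and any other dataset columns passed as keyword arguments) and returns one float per completion. The toy reward below is hypothetical, only there to show that shape; it reuses the same two-format handling as `no_repetition_reward_func`.

```python
def completion_text(completion):
    """Return the generated text for either completion format."""
    if isinstance(completion, str):       # plain-text format (our base model case)
        return completion
    if isinstance(completion, list):      # conversational format: list of messages
        return completion[0]["content"]
    raise ValueError(f"Unexpected completion format: {type(completion)}")

def line_count_reward_func(completions, **kwargs) -> list[float]:
    """Hypothetical toy reward: non-empty lines, capped at 4, scaled to [0, 1]."""
    scores = []
    for completion in completions:
        lines = [line for line in completion_text(completion).split("\n") if line.strip()]
        scores.append(min(len(lines), 4) / 4)
    return scores

batch = [
    "One line only",
    "First line\nSecond line\n\nThird line",
    [{"role": "assistant", "content": "a\nb\nc\nd\ne"}],
]
print(line_count_reward_func(batch))
```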

Now let's move on to the other issue: enforcing verse even with short prompts.

In [None]:
prompt="""Saturn is fallen, am I too to fall?
Am I to leave this haven of my rest,
This cradle of my glory, this soft clime,
This calm luxuriance of blissful light"""

good_generation = text_generation(prompt, repetition_penalty=1.2)

prompt="""Saturn is fallen, am I too to fall?
Am I to leave this haven of my rest"""

mixed_generation = text_generation(prompt, repetition_penalty=1.2)

prompt="""Saturn is fallen, am I too to fall?"""

bad_generation = text_generation(prompt, repetition_penalty=1.2)

print("Bad generation")
print(bad_generation)
print("Mixed generation")
print(mixed_generation)
print("Good generation")
print(good_generation)

Bad generation
<|end_of_text|>Saturn is fallen, am I too to fall?
And the sun will rise again.
The moon will be a-rising; and it shall shine as brightly as ever."
"I have heard of this," said Mr. Hickman, "but not in my own family or mine. It was only when we were young that our father had any idea what he meant by saying so much about his son's going into business. He did say something like this: 'If you are going to make money,' says your father, 'you must do it.' And then there came up from the other side of the house an old woman who sat on one side of him talking with her husband, and she told him all sorts of things which he could never understand at first sight. But after some time they began to talk together more freely than before, till finally their conversation became very animated, and they talked over everything for several hours without speaking another word. Then they went away, leaving me alone with them, and I thought
Mixed generation
<|end_of_text|>Saturn is fallen, a

Let's first target the bad generation (the one that switches to prose almost immediately, since a one-line prompt gives it too little poetic inertia):

In [None]:
from statistics import mean

prompt = "Saturn is fallen, am I too to fall?"

#First we get the values from the initial prompt:
prompt_lines = [line.strip() for line in prompt.split('\n') if line.strip()]
prompt_lengths = [len(line) for line in prompt_lines]
target_length = mean(prompt_lengths)

lines = [line.strip() for line in bad_generation.split('\n') if line.strip()]

#Then we analyze the generations.
line_lengths = [len(line) for line in lines]
actual_mean = mean(line_lengths)
length_ratio = actual_mean / target_length
print(length_ratio)

uppercase_count = sum(1 for line in lines if line and line[0].isupper())
uppercase_ratio = uppercase_count / len(lines)

print(uppercase_ratio)

# Text is verse-like if:
# 1. Mean line length is within 30% of prompt
# 2. At least 80% of lines start with uppercase
is_verse_like = (0.7 <= length_ratio <= 1.3) and (uppercase_ratio >= 0.8)

print(is_verse_like)

6.642857142857143
0.5
False


So definitely a bad generation: lines are on average about 6.6 times longer than the prompt line, and only half of them start with an uppercase letter.

At this point let's consolidate into a function:

In [None]:
from statistics import mean

def check_verse(prompt, completion):

  #First we get the values from the initial prompt:
  prompt_lines = [line.strip() for line in prompt.split('\n') if line.strip()]
  prompt_lengths = [len(line) for line in prompt_lines]
  target_length = mean(prompt_lengths)

  lines = [line.strip() for line in completion.split('\n') if line.strip()]

  #Then we analyze the generations.
  line_lengths = [len(line) for line in lines]
  actual_mean = mean(line_lengths)
  length_ratio = actual_mean / target_length

  print("Length ratio:")
  print(length_ratio)

  uppercase_count = sum(1 for line in lines if line and line[0].isupper())
  uppercase_ratio = uppercase_count / len(lines)

  print("Uppercase ratio:")
  print(uppercase_ratio)

  # Text is verse-like if:
  # 1. Mean line length is within 30% of prompt
  # 2. At least 80% of lines start with uppercase
  is_verse_like = (0.7 <= length_ratio <= 1.3) and (uppercase_ratio >= 0.8)

  return is_verse_like

In [None]:
print("Mixed generation")
print(check_verse(prompt, mixed_generation))

print("\nGood generation")
print(check_verse(prompt, good_generation))

Mixed generation
Length ratio:
1.3473684210526315
Uppercase ratio:
0.3157894736842105
False

Good generation
Length ratio:
1.1324675324675324
Uppercase ratio:
0.9545454545454546
True


So the function clearly works: the mixed generation nearly passes the line-length requirement (a ratio of 1.35 against a 1.3 ceiling, since the prose part is output line by line), but clearly fails the uppercase requirement (32% against 80%). The good generation passes both.
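To see which of the two conditions each generation trips, the decision rule can be isolated as a pure function and fed the ratios printed above. A minimal sketch; the helper name `verse_conditions` is hypothetical and the thresholds are copied from `check_verse`:

```python
def verse_conditions(length_ratio: float, uppercase_ratio: float) -> tuple[bool, bool]:
    """Return (length_ok, uppercase_ok) with the same thresholds as check_verse."""
    length_ok = 0.7 <= length_ratio <= 1.3   # mean line length within 30% of the prompt
    uppercase_ok = uppercase_ratio >= 0.8    # at least 80% of lines start uppercase
    return length_ok, uppercase_ok

# Ratios printed earlier in this notebook
print("bad  ", verse_conditions(6.642857142857143, 0.5))
print("mixed", verse_conditions(1.3473684210526315, 0.3157894736842105))
print("good ", verse_conditions(1.1324675324675324, 0.9545454545454546))
```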

As before, let's wrap it into a reward function. This time it takes both the prompts and the completions, and `check_verse` now returns `False` instead of crashing on edge cases such as an empty completion:

In [None]:
from statistics import mean

def check_verse(prompt, completion):

  try:
    #First we get the values from the initial prompt:
    prompt_lines = [line.strip() for line in prompt.split('\n') if line.strip()]
    prompt_lengths = [len(line) for line in prompt_lines]
    target_length = mean(prompt_lengths)

    lines = [line.strip() for line in completion.split('\n') if line.strip()]

    #Then we analyze the generations.
    line_lengths = [len(line) for line in lines]
    actual_mean = mean(line_lengths)
    length_ratio = actual_mean / target_length

    uppercase_count = sum(1 for line in lines if line and line[0].isupper())
    uppercase_ratio = uppercase_count / len(lines)

    # Text is verse-like if:
    # 1. Mean line length is within 30% of prompt
    # 2. At least 80% of lines start with uppercase
    is_verse_like = (0.7 <= length_ratio <= 1.3) and (uppercase_ratio >= 0.8)

  except Exception:  # e.g. statistics.StatisticsError when a completion has no lines
    is_verse_like=False

  return is_verse_like

def verse_reward_func(prompts, completions, **kwargs) -> list[float]:
    return [0.5 if check_verse(prompt, response) else 0.0 for prompt, response in zip(prompts, completions)]

Finally, let's add some built-in structure. We won't aim for anything complex yet: just favoring poems formatted in blocks of four lines, i.e. quatrains.

Since this will be the most dramatic reward, we also print each poem: this lets us follow the learning progress of our RL model along the way.

In [None]:
def check_quatrain(text):
    # Split text by double newlines to get stanzas
    raw_groups = text.split('\n\n')

    # Process each group to get non-empty lines
    verse_groups = []
    for group in raw_groups:
        lines = [line.strip() for line in group.split('\n') if line.strip()]
        if lines:  # Only add non-empty groups
            verse_groups.append(lines)

    if not verse_groups:
        return 0.0

    # Count groups of exactly 4 verses
    quatrain_count = sum(1 for group in verse_groups if len(group) == 4)

    print("Poem:")
    print(text)
    print("Quatrain estimation: " + str(quatrain_count))

    # Return reward equal to the number of quatrains found
    # If no quatrains are found, return 0
    return float(quatrain_count) if quatrain_count > 0 else 0.0

def quatrain_reward_func(prompts, completions, **kwargs) -> list[float]:
    return [check_quatrain(completion) for completion in completions]

Now let's check the quatrain detection method on a sonnet from Petrarch:

In [None]:
sonnet = """Pace non trovo, e non ho da far guerra;
E temo, e spero, ed ardo, e son’ un ghiaccio;
E volo sopra ’l cielo, e giaccio in terra;
E nulla stringo, e tutto ’l mondo abbraccio.

Tal m’ha in prigion, che non m’apre, nè serra,
Nè per suo mi riten, nè scioglie il laccio;
E non m’ancide Amor’, e non mi sferra,
Nè mi vuol vivo, nè mi trae d’impaccio.

Veggio senz'occhi; e non ho lingua, e grido;
E bramo di perir, e cheggio aita;
Ed ho in odio me stesso, ed amo altrui:

Pascomi di dolor; piangendo rido;
Egualmente mi spiace morte, e vita.
In questo stato son, Donna, per vui. """

check_quatrain(sonnet)

2.0

Two quatrains, as expected from a Petrarchan sonnet (two quatrains followed by two tercets): we are ready to go.
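One caveat: `text.split('\n\n')` only recognizes a stanza break made of exactly two consecutive newlines. If a blank line contains a stray space, two stanzas merge into one eight-line group and the quatrains go uncounted. A possible sketch of a more tolerant, print-free variant (the name `count_quatrains` is hypothetical) splits on blank lines with a regular expression instead:

```python
import re

def count_quatrains(text: str) -> float:
    """Count 4-line stanzas; stanza breaks may contain stray whitespace."""
    # Split on any blank line, even one holding spaces, tabs or a carriage return
    stanzas = re.split(r"\n\s*\n", text)
    # Keep only non-empty lines within each stanza
    groups = [[line for line in s.split("\n") if line.strip()] for s in stanzas]
    return float(sum(1 for group in groups if len(group) == 4))

poem = "A\nB\nC\nD\n \nE\nF\nG\nH"   # stanza break is a line holding one space
print(count_quatrains(poem))         # 2.0
print(len(poem.split("\n\n")))       # 1: the plain split sees a single 8-line block
```

Dropping the `print` calls also keeps the training log readable, since the reward runs on every sampled completion.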

## Setting up the RL dataset

At this point we have the first part of the RL pipeline: three reward functions that score whether a generation is a success or not. All we need now is a collection of seed prompts and... oh wait, is there any good LLM dataset for poetry?

There are almost none but, as it happens, I'm in the middle of a big parsing project on Wikisource, and this is as good an occasion as any to introduce the largest set of public domain verses: 207,554 individual verses, in random order. Likely enough for some nice training.

In [None]:
import pandas as pd
dataset = pd.read_parquet("https://huggingface.co/datasets/PleIAs/verse-wikisource/resolve/main/verse_wikisource.parquet")

In [None]:
dataset

| | verse | size_verse | order_verse | title | link |
|---|---|---|---|---|---|
| 0 | Pensive to foster cares, careless of joys; | 7 | 742 | St. Peter's Complaint | https://en.wikisource.org/wiki/St%2E%5FPeter%2... |
| 1 | On the obscure and fluctuating main, | 6 | 8 | Elegiac Sonnets, and Other Poems, Volume 2, Th... | https://en.wikisource.org/wiki/Elegiac%5FSonne... |
| 2 | On the floor the poor mother groped madly abou... | 15 | 81 | The Last Bullet | https://en.wikisource.org/wiki/The%5FLast%5FBu... |
| 3 | To tower up in completeness, trophy-like, | 6 | 357 | Balaustion's Adventure/V | https://en.wikisource.org/wiki/Balaustion%27s%... |
| 4 | The waves cleave not to him nor he to the waves; | 11 | 24 | Ode to Youth | https://en.wikisource.org/wiki/Ode%5Fto%5FYouth |
| ... | ... | ... | ... | ... | ... |
| 207549 | And meets me with ten thousand smiles! | 7 | 22 | Blockhead and Beehive | https://en.wikisource.org/wiki/Blockhead%5Fand... |
| 207550 | And all the souls that her burden madeCried ou... | 12 | 185 | The Story and Song of Black Roderick | https://en.wikisource.org/wiki/The%5FStory%5Fa... |
| 207551 | Other times, stung by the œstrum of some swift... | 10 | 11 | The Bothie of Toper-na-fuosich/3 | https://en.wikisource.org/wiki/The%5FBothie%5F... |
| 207552 | They smelled a dead one passing near! | 7 | 94 | The Czechoslovak Review/Volume 3/Spectre's Bride | https://en.wikisource.org/wiki/The%5FCzechoslo... |


We are going to make some slight changes to get into prompt mode: we keep the first 1,000 verses and append a newline to each, so that our RL pipeline gets at least a "clue" that we expect more lines of poetry.

In [None]:
prompt_list = []

for verse in dataset["verse"].tolist()[0:1000]:
  prompt_list.append(f"{verse}\n")

dataset = Dataset.from_dict({'prompt': prompt_list})

dataset

Dataset({
    features: ['prompt'],
    num_rows: 1000
})

In [None]:
dataset[0]

{'prompt': 'Pensive to foster cares, careless of joys;\n'}

## Training

We now have everything we need for some actual training.

Let's define an output directory and a run name (nothing is persisted here, but if you want to keep traces across sessions, you may want to connect your Colab to Google Drive).

In [None]:
output_dir="outputs/Pleias-350m-GRPO"
run_name="Pleias-350m-GRPO-Poetry"

Since it's a small model, and there are some instability issues anyway, we're dropping PEFT and training the full model. Most of the arguments are rather standard (if you've done any fine-tuning, you're in familiar territory). In contrast with the Llama GRPO script, I'm lowering the completion length, as poems do not need to be that long anyway.

In [None]:
training_args = GRPOConfig(
    output_dir=output_dir,
    run_name=run_name,
    learning_rate=5e-5,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    bf16=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    num_generations=16,
    max_prompt_length=256,
    max_completion_length=200,
    num_train_epochs=1,
    save_steps=100,
    max_grad_norm=0.1,
    log_on_each_node=False,
    use_vllm=True,
    vllm_gpu_memory_utilization=.3,
    vllm_device="cuda:0",
    report_to="none" #I'm disabling Wandb.
)

torch.distributed process group is initialized, but parallel_mode != ParallelMode.DISTRIBUTED. In order to use Torch DDP, launch your script with `python -m torch.distributed.launch
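A quick back-of-the-envelope check of what these settings imply. This assumes trl 0.14 semantics, where `per_device_train_batch_size` counts prompts and each prompt is expanded into `num_generations` completions (later trl releases count completions instead), and a single GPU; treat both as assumptions rather than guarantees.

```python
per_device_train_batch_size = 1
gradient_accumulation_steps = 4
num_generations = 16
num_prompts = 1000   # rows in our dataset
num_gpus = 1         # assumption: a single Colab GPU

# Prompts and sampled completions consumed by one optimizer step
prompts_per_step = per_device_train_batch_size * gradient_accumulation_steps * num_gpus
completions_per_step = prompts_per_step * num_generations
# Optimizer steps for one epoch over the dataset
optimizer_steps = num_prompts // prompts_per_step

print(prompts_per_step, completions_per_step, optimizer_steps)
```

Under these assumptions, `save_steps=100` would yield checkpoints at steps 100 and 200 over the single epoch.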


Unfortunately, due to a persistent bug (either on the GRPO side or in the tokenizer config), we have to wrap the tokenizer so that it never returns `token_type_ids` (sorry for the verbose code):

In [None]:
# First, initialize your tokenizer as before
tokenizer = AutoTokenizer.from_pretrained(model_name,
    padding_side="left"
)
tokenizer.eos_token = "<|end_of_text|>"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"

class TokenizerWrapper:
    def __init__(self, base_tokenizer):
        self.base_tokenizer = base_tokenizer

        # Explicitly set important attributes
        self.pad_token_id = base_tokenizer.pad_token_id
        self.eos_token_id = base_tokenizer.eos_token_id
        self.padding_side = base_tokenizer.padding_side
        self.vocab_size = base_tokenizer.vocab_size
        self.pad_token = base_tokenizer.pad_token
        self.eos_token = base_tokenizer.eos_token

        # Copy remaining attributes
        for attr_name in dir(base_tokenizer):
            if not attr_name.startswith('_') and not hasattr(self, attr_name):
                setattr(self, attr_name, getattr(base_tokenizer, attr_name))

    def __call__(self, *args, **kwargs):
        kwargs['return_token_type_ids'] = False
        outputs = self.base_tokenizer(*args, **kwargs)
        # BatchEncoding subclasses UserDict rather than dict, so check for keys() instead
        if hasattr(outputs, 'keys') and 'token_type_ids' in outputs:
            del outputs['token_type_ids']
        return outputs

    def batch_decode(self, *args, **kwargs):
        return self.base_tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        return self.base_tokenizer.decode(*args, **kwargs)

# Wrap your tokenizer
wrapped_tokenizer = TokenizerWrapper(tokenizer)

# Create your trainer with the wrapped tokenizer
trainer = GRPOTrainer(
    model=model,
    processing_class=wrapped_tokenizer,
    reward_funcs=[
        no_repetition_reward_func,
        verse_reward_func,
        quatrain_reward_func
    ],
    args=training_args,
    train_dataset=dataset
)



INFO 02-02 19:43:56 config.py:526] This model supports multiple tasks: {'generate', 'score', 'embed', 'reward', 'classify'}. Defaulting to 'generate'.
INFO 02-02 19:43:56 llm_engine.py:232] Initializing a V0 LLM engine (v0.7.1) with config: model='PleIAs/Pleias-350m-Preview', speculative_config=None, tokenizer='PleIAs/Pleias-350m-Preview', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda:0, decoding_config=DecodingConfig(guided_decoding_backend='xgrammar'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=PleIAs/Pleias-350m-Preview, nu



INFO 02-02 19:44:08 model_runner.py:1116] Loading model weights took 0.6828 GB
INFO 02-02 19:44:19 worker.py:266] Memory profiling takes 5.65 seconds
INFO 02-02 19:44:19 worker.py:266] the current vLLM instance can use total_gpu_memory (39.56GiB) x gpu_memory_utilization (0.30) = 11.87GiB
INFO 02-02 19:44:19 worker.py:266] model weights take 0.68GiB; non_torch_memory takes 0.00GiB; PyTorch activation peak memory takes 0.60GiB; the rest of the memory reserved for KV Cache is 10.59GiB.
INFO 02-02 19:44:25 executor_base.py:108] # CUDA blocks: 13342, # CPU blocks: 5041
INFO 02-02 19:44:25 executor_base.py:113] Maximum concurrency for 2048 tokens per request: 104.23x
INFO 02-02 19:44:28 model_runner.py:1435] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occurs during cudagraph capture, consider decreasing `gpu_memory_u

Capturing CUDA graph shapes: 100%|██████████| 35/35 [06:36<00:00, 11.32s/it]

INFO 02-02 19:51:05 model_runner.py:1563] Graph capturing finished in 396 secs, took 0.14 GiB
INFO 02-02 19:51:05 llm_engine.py:429] init engine (profile, create kv cache, warmup model) took 416.48 seconds
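The memory figures in the vLLM log above can be checked with a little arithmetic. This sketch rebuilds the KV-cache numbers under the following assumptions: a bf16 cache (2 bytes per value), 26 decoder layers with 512-wide key and value projections (matching the model printout further down), and a block size of 16 tokens, which we take to be vLLM's default. The estimated block count differs slightly from the logged 13342 because the log rounds its GiB values.

```python
GIB = 2 ** 30

# Numbers copied from the vLLM log above
total_gpu_gib = 39.56
gpu_memory_utilization = 0.30
weights_gib, activation_peak_gib = 0.68, 0.60
logged_blocks = 13342

# Assumptions: bf16 cache, 26 layers, 512-wide K and V, 16 tokens per block
bytes_per_value = 2
num_layers = 26
kv_width = 512
block_size = 16
max_seq_len = 2048

budget_gib = total_gpu_gib * gpu_memory_utilization            # ~11.87 GiB
kv_cache_gib = budget_gib - weights_gib - activation_peak_gib   # ~10.59 GiB

# Each cached token stores one key and one value vector per layer
bytes_per_token = 2 * kv_width * bytes_per_value * num_layers   # 53,248 bytes
estimated_blocks = int(kv_cache_gib * GIB // (bytes_per_token * block_size))

# "Maximum concurrency": how many full 2048-token sequences fit in the cache at once
max_concurrency = logged_blocks * block_size / max_seq_len

print(f"budget {budget_gib:.2f} GiB, KV cache {kv_cache_gib:.2f} GiB")
print(f"estimated blocks {estimated_blocks} (logged: {logged_blocks})")
print(f"max concurrency {max_concurrency:.2f}x")
```

With only 0.68 GiB of weights, almost the entire 30% share of the GPU goes to the KV cache. This is why a 350M model can serve about a hundred full-length generations concurrently for the GRPO sampling step.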





In [None]:
trainer.train()

Poem:

Pattison — Boodle's to bis pleasure — Charles II.'s "Adonais."

Some hear what ye do. You young lady,

In love's most hummer martyrdom!

Bright eyes are glowing o'er your fair face.
Steadfast brow, and steady eye.
Digitized by VjOOQIC



228 SEVENTEENTH CENTURY.
The same glance on moonhght, and the same

Written words of passionate cry.
Never can it be forgot,

That now you teil me. When you will.
Or sound me a short threat.
Hammer-bird, in vain you seek to kill.
And not to pinch you at last.
Awake, awakon, my spirit.
And rouse me slowly from yon gew-day 1
CALL me, my darling — call me T



Sad is the
Quatrain estimation: 0
Poem:

Until they have pardon'd our sin and made

That sight, which proves such pain, a sight of wrath, 43

Which is to punishment all an assail.
Thus the painter, so that he may not rash

Such art-work, the painting must not depart;

Or the more pure, the more unpity'd saint

Never saw mortals naked, sick, or lame,

Spider-stricken, or any man but man.
The t

| Step | Training Loss |
|---:|---:|
| 1 | 0.0 |
| 2 | -0.0 |
| 3 | 0.0 |
| 4 | 0.0 |
| 5 | 0.0 |
| 6 | 0.0 |
| 7 | 0.0 |
| 8 | 0.0 |
| 9 | 0.0 |
| 10 | 0.0 |


Output truncated: only the last 5000 lines are shown.
Who is an angel
That hath no stint
To me, but the best

Quatrain estimation: 3
Poem:
That fills her thro’ the air, and blesses

her,
The Captain of the sea

For the brave hearts of women
Who work and toil with their hands,
And deck their faces
And be like the ladies

in their hearts
That exclaim, "Oh, my brave little girls!
For to-morrow we will live
By our songs."

[Thorn, an English sea captain:]
A song, which the ladies sing
With the sailors,

In our North Sea
Who are twenty or

And are shipmates there.
For they are half-way on
The sea captain's business,
And the sailors

Will answer to it
With their voices
And with their hearts
And with their souls.
And it is so —

They are sweet friends
And it's theirs to kiss
And it's theirs to sing

The song when
The Colonel
Quatrain estimation: 3
Poem:
       "If the wind blew high
      With gentle power
  A last note, the wheeling ear
       Put on my 

TrainOutput(global_step=250, training_loss=0.005283422563225031, metrics={'train_runtime': 1074.7294, 'train_samples_per_second': 0.93, 'train_steps_per_second': 0.233, 'total_flos': 0.0, 'train_loss': 0.005283422563225031})
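A training loss near zero (exactly 0.0 in the first steps, 0.005 on average over the 250 steps) is expected with GRPO. It does not mean nothing was learned. Rewards are normalized within each group of completions sampled for the same prompt, so the advantages of a group average to about zero. In a single-update setup like this one, the probability ratio equals 1 at the moment the loss is computed, so the policy term cancels out in the logged value. What remains is mostly the KL penalty against the reference model, and that starts at zero. The gradient itself is not zero. The sketch below assumes the common formulation, with a sample standard deviation, an epsilon of 1e-4 and a KL coefficient `beta=0.04`. These values are illustrative assumptions, not read from this notebook's `training_args`.

```python
import statistics

def group_advantages(rewards, eps=1e-4):
    """Group-relative advantages: (r - group mean) / (group std + eps)."""
    mean = statistics.fmean(rewards)
    std = statistics.stdev(rewards)  # sample standard deviation (assumption)
    return [(r - mean) / (std + eps) for r in rewards]

def logged_loss(advantages, kl_per_completion, beta=0.04):
    """Value of the GRPO loss when the policy ratio is exactly 1 (on-policy).

    Each completion contributes -(ratio * A - beta * KL) with ratio == 1,
    i.e. -A + beta * KL; the -A terms cancel across a normalized group.
    """
    terms = [-a + beta * kl for a, kl in zip(advantages, kl_per_completion)]
    return statistics.fmean(terms)

# Hypothetical quatrain counts for 4 completions of the same prompt
rewards = [3.0, 0.0, 1.0, 2.0]
adv = group_advantages(rewards)
print([round(a, 3) for a in adv])   # the best completion gets the largest advantage
print(logged_loss(adv, [0.0] * 4))  # ~0: first step, policy == reference
print(logged_loss(adv, [0.1] * 4))  # ~0.004: only the KL term remains

# If every completion gets the same reward, all advantages are 0: no learning signal
print(group_advantages([2.0, 2.0, 2.0, 2.0]))
```

The last line is worth keeping in mind when reading the logs. A group where every sample gets the same score (for instance all "Quatrain estimation: 0") contributes nothing to the update. The reward has to separate better poems from worse ones within a group.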

In [None]:
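# 1. Save the fine-tuned model
output_dir = "my_trained_model"
trainer.save_model(output_dir)

# 2. For inference, reload the fine-tuned model. The tokenizer was not changed
# by training, so wrapped_tokenizer is reused below.
from transformers import AutoModelForCausalLM

# Load the saved model (on CPU by default; call model.to("cuda") to generate on GPU)
model = AutoModelForCausalLM.from_pretrained(output_dir)
model.eval()  # Set to evaluation mode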

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(65536, 1024)
    (layers): ModuleList(
      (0-25): 26 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=1024, out_features=1024, bias=False)
          (k_proj): Linear(in_features=1024, out_features=512, bias=False)
          (v_proj): Linear(in_features=1024, out_features=512, bias=False)
          (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=1024, out_features=2560, bias=False)
          (up_proj): Linear(in_features=1024, out_features=2560, bias=False)
          (down_proj): Linear(in_features=2560, out_features=1024, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((1024,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((1024,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((1024,), eps=1e-05)
    (rotary_emb): 
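The printout is enough to count the parameters. The sketch below tallies the modules shown: the embeddings, the 26 decoder layers, and the final norm. None of the projections have a bias. The output head is cut off in the printout. The total of about 353M matches the model's "350m" name if that head shares its weights with the embedding matrix, but this is an assumption here. The printout also shows grouped-query attention: `k_proj` and `v_proj` are half as wide as `q_proj`. Assuming equal head dimensions, each key/value head is shared by two query heads.

```python
# Dimensions read off the printout above
vocab_size, hidden, intermediate, num_layers = 65536, 1024, 2560, 26
q_out, kv_out = 1024, 512

def linear_params(n_in, n_out):
    return n_in * n_out  # every projection above has bias=False

attention = (linear_params(hidden, q_out)
             + 2 * linear_params(hidden, kv_out)   # k_proj and v_proj
             + linear_params(q_out, hidden))       # o_proj
mlp = 2 * linear_params(hidden, intermediate) + linear_params(intermediate, hidden)
norms = 2 * hidden                                  # input + post-attention RMSNorm weights
per_layer = attention + mlp + norms

embeddings = vocab_size * hidden
total_shown = embeddings + num_layers * per_layer + hidden  # + final norm

print(f"per layer: {per_layer:,}")                    # per layer: 11,012,096
print(f"total shown: {total_shown:,}")                # total shown: 353,424,384
print(f"query heads per KV head: {q_out // kv_out}")  # query heads per KV head: 2 (assumes equal head dims)
```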

We are now ready to run some inference:

In [None]:
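# 3. Test inference
def generate_text(prompt, model, tokenizer, max_length=200):
    # Tokenize the prompt (the wrapper drops token_type_ids, which Llama does not accept)
    inputs = tokenizer(prompt, return_tensors="pt", padding=True)

    # Move the tensors to the same device as the model
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Sample a continuation. max_length counts the prompt tokens too;
    # use max_new_tokens instead to bound only the generated part.
    # repetition_penalty=1 means no penalty is applied.
    outputs = model.generate(
        **inputs,
        max_length=max_length,
        num_return_sequences=1,
        repetition_penalty=1,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id
    )

    # Decode the full sequence (prompt + continuation)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)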
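`generate_text` samples with `do_sample=True` and `temperature=0.7`. The temperature divides the next-token logits before the softmax. Values below 1 sharpen the distribution toward the likeliest tokens while keeping some variety, a reasonable middle ground for poetry. The plain-Python sketch below shows only that rescaling, on made-up logits. It is not the `transformers` code path, which may also apply other filters such as top-k.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Turn logits into sampling probabilities after dividing by the temperature."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.5]  # made-up scores for three candidate tokens
p_default = softmax_with_temperature(logits, 1.0)
p_cooler = softmax_with_temperature(logits, 0.7)
print([round(p, 3) for p in p_default])
print([round(p, 3) for p in p_cooler])  # the top token gets more mass at T=0.7
```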

In [None]:
from pprint import pprint

test_prompt = """Saturn is fallen, am I too to fall?
Am I to leave this haven of my rest,
"""
generated_text = generate_text(test_prompt, model, wrapped_tokenizer)
print(f"Generated: {generated_text}")

Generated: Saturn is fallen, am I too to fall?
Am I to leave this haven of my rest,

And to part with all that life has made me,
That I might know what love was and what
It is.
For in the same place from which I flew,
I saw a young man so happy,

Who did not care for me at first;
To him he never went.
He had no love but the love he bore.
He had the heart of an old man,

And could love nothing else.
The woman whom he loved
Was the mother of my children:
She knew how to be kind,

And yet she could have none.
His love was only the love
Of her nature and his own ;
But her nature was different.
For she was beautiful.
Now Saturn is fallen,

And I lie in the tomb.
With our blood we shall meet again


A very nice property of pure RL training: although nearly all of our prompts are in English, the training does not appear to harm the multilingual capabilities of the original model. Instead, the new poetic rules transfer very well to another language, here French:

In [None]:
test_prompt = """Vers l’Azur attendri d’Octobre pâle et pur
"""
generated_text = generate_text(test_prompt, model, wrapped_tokenizer)
print(f"Generated: {generated_text}")

Generated: Vers l’Azur attendri d’Octobre pâle et pur
Le Ciel des vents et du vent est plein,
Et la terre ne se découvre plus guère.
Les nuages sont moins grands que ceux de ces jours;
On ne voit pas encore le matin; mais on croit voir
Depuis deux heures du matin les étoiles noires.
Du ciel une lumière qui vient de loin

Doit paraître en couleur, à la vue du monde ;
Elle monte aussi dans notre âme
Par ce mystère si doux
Pour nous faire penser aux choses célestes :

C’est elle dont il faut tenir compte...
Aussi cette nuit va-t-elle bientôt sonner
Dans nos yeux pour entendre les sons que produit
L'ombre du jour ou les rayons du soir.
Mais dès que nous aurons vu
Ce qui s’envole avec la lumière,

La connaissance aura besoin
De quelques
