# Groupe Relative Policy Optimization (GRPO)

Install the Hugging Face libraries to run this notebook.

In [1]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

Your goal is to fill in the `GRPOTrainer` class. You have two options (and you can do both):
* the "normal GRPO" with clipped surrogate objective
* or the "vanilla GRPO" with original objective

In [3]:
model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [4]:
from transformers import AutoConfig

class Embedder(nn.Module):
    def __init__(self, level, model_name = "gpt2"):
        super(Embedder, self).__init__()
        if level ==1:
            self.model = AutoModelForCausalLM.from_pretrained(model_name).get_input_embeddings().to(device)
        else:
            #config = AutoConfig.from_pretrained(model_name)
            # Create the model from configuration (with random weights)
            self.model = AutoModelForCausalLM.from_pretrained(model_name).transformer.to(device)
            self.model.wte = nn.Identity()
        self.level = level
    
    def forward(self, input):
        if self.level == 1:
            return self.model(input.squeeze(-1))
        
        out = self.model(inputs_embeds = input)[0][:, -1, :]
        return out

In [5]:
class Abstract_model(nn.Module):
    def __init__(self, level, model_name = "gpt2"):
        super(Abstract_model, self).__init__()
        if level==1:
            self.model = AutoModelForCausalLM.from_pretrained(model_name).transformer.to(device)
            self.model.wte = nn.Identity()
        else:
            config = AutoConfig.from_pretrained(model_name)
            # Create the model from configuration (with random weights)
            self.model = AutoModelForCausalLM.from_config(config).transformer.to(device)
            self.model.wte = nn.Identity()
    
    def forward(self, input):
        out = self.model(inputs_embeds = input)[0]
        return out

In [6]:
class Abstract_level(nn.Module):
    def __init__(self, level, model_name = "gpt2"):
        super(Abstract_level, self).__init__()
        self.abstract_size = level
        self.model_abstract = Abstract_model(level)

        self.level = level
    
    def forward(self, embeddings):
        output_abstract = self.model_abstract(embeddings) # (batch_size, nseq, embedding_size)
        loss = 0
        '''
        if self.level >1:
            output_abstract = output_abstract/torch.linalg.norm(output_abstract, dim=-1).unsqueeze(-1)
            loss = - torch.sum(output_abstract[:,:-1].reshape(-1, 768) * embeddings[:,1:].reshape(-1, 768), axis = -1).mean() #+ 0*torch.abs(torch.sum(embeddings[:,1:].reshape(-1, 768)*embeddings[:,:-1].reshape(-1, 768), axis = -1)).mean()
        '''
        return output_abstract, loss

In [None]:
class Token_pred(nn.Module):
    def __init__(self, model_name = "gpt2"):
        super(Token_pred, self).__init__()
        #self.layer = AutoModelForCausalLM.from_pretrained(model_name).lm_head.to(device)
        self.layer = nn.Linear(768*2, 50257).to(device)
        #self.lin2 = nn.Linear(768, 768).to(device) n
    
    def forward(self, input, condition):
        x = torch.cat([input, condition], dim = -1)
        #x = input + condition
        return self.layer(x)

In [8]:
class WorldModel_LLM(nn.Module):
    def __init__(self, level, use_abstract = True, model_name = "gpt2"):
        super(WorldModel_LLM, self).__init__()
        self.alevel = level
        self.aalevel = 4
        self.use_abstract = use_abstract

        self.token_embedder = Embedder(1)
        if use_abstract:
            self.lvl1_embedder = Embedder(self.alevel)

        self.lvl0_predictor = Abstract_level(1)
        if use_abstract:
            self.lvl1_predictor = Abstract_level(self.alevel)

        self.token_pred = Token_pred()

        self.criterion = nn.CrossEntropyLoss()

        self.lvl1_embed_save = None
    
    def forward(self, tokens):
        nseq = tokens.shape[1]//self.alevel
        self.aalevel = nseq
        tokens = tokens[:, :tokens.shape[1]//self.alevel*self.alevel]
        
        tokens = tokens.reshape(tokens.shape[0], -1, self.alevel).reshape(-1, self.alevel) # (batch_size * nseq, alevel)
        token_embedding = self.token_embedder(tokens) # (batch_size * nseq, alevel, embedding_size)

        if self.use_abstract:
            lvl1_embedding = self.lvl1_embedder(token_embedding).reshape(-1, nseq, 768) # (batch_size, nseq, embedding_size)
        
        if nseq>1:
            token_embedding = token_embedding.reshape(-1, nseq, self.alevel, 768)[:,1:].reshape(-1, self.alevel, 768)
            tokens = tokens.reshape(-1, nseq, self.alevel)[:,1:].reshape(-1, self.alevel)

        token_embedding_pred, _ = self.lvl0_predictor(token_embedding) # (batch_size * nseq, alevel, embedding_size)
        loss_abstract = 0
        if self.use_abstract:
            lvl1_embedding_pred, loss_abstract = self.lvl1_predictor(lvl1_embedding) # (batch_size, nseq, embedding_size)

            lvl1_embedding_pred = lvl1_embedding_pred[:,:-1].unsqueeze(2).repeat(1,1, self.alevel, 1).reshape(-1, self.alevel, 768)

        if self.use_abstract:
            tokens_pred = self.token_pred(token_embedding_pred, lvl1_embedding_pred)
        else:
            tokens_pred = self.token_pred(token_embedding_pred, 0)
        loss_token = self.criterion(tokens_pred[:,:-1].reshape(-1, 50257), tokens[:,1:].reshape(-1))
        return tokens_pred, loss_abstract, loss_token
    
    def generate(self, tokens, ntokens):
        with torch.no_grad():
            while tokens.shape[1] <= ntokens:
                if tokens.shape[1] <= self.alevel:
                    token_embedding = self.token_embedder(tokens)
                    token_embedding_pred, _ = self.lvl0_predictor(token_embedding)
                    token_embedding_pred = token_embedding_pred[:,-1]

                    if self.lvl1_embed_save is None:
                        lvl1_embedding = self.lvl1_embedder(token_embedding).reshape(-1, 1, 768)
                        lvl1_embedding_pred, _ = self.lvl1_predictor(lvl1_embedding)
                        lvl1_embedding_pred = lvl1_embedding_pred[:,-1]
                        self.lvl1_embed_save = lvl1_embedding_pred
                    else:
                        lvl1_embedding_pred = self.lvl1_embed_save

                else:
                    nseq = tokens.shape[1]//self.alevel
                    reste = tokens.shape[1]%self.alevel
                    token_embedding = self.token_embedder(tokens)
                    token_embedding_pred, _ = self.lvl0_predictor(token_embedding[:, 1-self.alevel:])
                    token_embedding_pred = token_embedding_pred[:,-1]

                    if reste==0:
                        embed_for_lvl1 = token_embedding[:, :tokens.shape[1]//self.alevel*self.alevel].reshape(tokens.shape[0], -1, self.alevel, 768).reshape(-1, self.alevel, 768)
                        lvl1_embedding = self.lvl1_embedder(embed_for_lvl1).reshape(-1, nseq, 768)
                        lvl1_embedding_pred, _ = self.lvl1_predictor(lvl1_embedding[:, -self.aalevel-1:])
                        lvl1_embedding_pred = lvl1_embedding_pred[:,-2]
                        self.lvl1_embed_save = lvl1_embedding_pred
                    else:
                        lvl1_embedding_pred = self.lvl1_embed_save

                tokens_pred = self.token_pred(token_embedding_pred, lvl1_embedding_pred)
                next_token = torch.multinomial(tokens_pred.softmax(-1), 1)
                tokens = torch.cat([tokens, next_token], dim = 1)
        return tokens

In [9]:
text = "This is an example of me and of"
input_ids = tokenizer.encode(text, return_tensors='pt').to(device)
print(input_ids.shape)

wmllm = WorldModel_LLM(16).to(device)
with torch.no_grad():
    #op = wmllm(input_ids)
    op = wmllm.generate(input_ids, 15)
    tokenizer.decode(op[0].tolist())


torch.Size([1, 8])


In [11]:
wmllm.aalevel

4

In [12]:
text = "This is an example sentence. What do you think of this simple fact: i "
input_ids = tokenizer.encode(text, return_tensors='pt').to(device)
print(input_ids.shape)

wmllm = Simple_LLM().to(device)
res,_, loss1 = wmllm(input_ids)
print(res.shape, loss1)


torch.Size([1, 17])
torch.Size([1, 17, 50257]) tensor(4.2338, device='cuda:0', grad_fn=<NllLossBackward0>)


In [13]:
op = wmllm.generate(input_ids, 50)
tokenizer.decode(op[0].tolist())

'This is an example sentence. What do you think of this simple fact: i \xa0have a \xa0a \xa0a \xa0a \xa0a \xa0a \xa0a \xa0a \xa0a \xa0a \xa0a'

## Train

In [10]:
with open('llm_mva/sotu.txt', 'r', encoding='utf-8') as f:
    text = f.read()
    text = text.replace('\n', ' ')
    text = text.replace('  ', '\n')


text = text.lower()
print(text[:1000])

mr. speaker, mr. president, and distinguished members of the house and senate, honored guests, and fellow citizens: less than 3 weeks ago, i joined you on the west front of this very building and, looking over the monuments to our proud past, offered you my hand in filling the next page of american history with a story of extended prosperity and continued peace. and tonight i'm back to offer you my plans as well. the hand remains extended; the sleeves are rolled up; america is waiting; and now we must produce. together, we can build a better america.
it is comforting to return to this historic chamber. here, 22 years ago, i first raised my hand to be sworn into public life. so, tonight i feel as if i'm returning home to friends. and i intend, in the months and years to come, to give you what friends deserve: frankness, respect, and my best judgment about ways to improve america's future. in return, i ask for an honest commitment to our common mission of progress. if we seize the opport

In [11]:
data = tokenizer.encode(text, return_tensors='pt').squeeze()
data.shape

Token indices sequence length is longer than the specified maximum sequence length for this model (260738 > 1024). Running this sequence through the model will result in indexing errors


torch.Size([260738])

In [12]:
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

lendata = train_data.shape[0]
data.shape, data[:50]

(torch.Size([260738]),
 tensor([43395,    13, 10834,    11,   285,    81,    13,  1893,    11,   290,
         18876,  1866,   286,   262,  2156,   290, 34548,    11, 21014, 10650,
            11,   290,  5891,  4290,    25,  1342,   621,   513,  2745,  2084,
            11,  1312,  5399,   345,   319,   262,  7421,  2166,   286,   428,
           845,  2615,   290,    11,  2045,   625,   262, 28814,   284,   674]))

In [13]:
context_length = 512
batch_size = 8
 
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - context_length - 1, (batch_size,))
    X = torch.stack([data[i:i+context_length] for i in ix])
    return X

X = get_batch("train")
X.shape

torch.Size([8, 512])

In [14]:
def  validation(model):
    # Define hyperparameters
    epoch = 1

    # Training loop
    with torch.no_grad():
        model.eval()
        total_loss = 0
        num =5
        for i in range(num):
            # Get batch
            X = get_batch('val')
            
            # Forward pass
            output, loss_abstract, loss_token = model(X.to(device))
            loss = loss_abstract + loss_token
            
            total_loss += loss.item()  

        avg_loss = total_loss / num 
        print(f"Validation: {epoch+1}, Average Loss: {     avg_loss}")
    return avg_loss

In [15]:
model = WorldModel_LLM(32).to(device)
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
#model.load_state_dict(torch.load("model.pt"))
best_loss = 10000


In [None]:
# Define hyperparameters
epochs = 6

# Define optimizer

print(int(1.33*lendata/(batch_size*context_length)))

# Training loop
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for i in range(0,int(1.33*lendata/(batch_size*context_length))):
        # Get batch
        X = get_batch('train')
        
        # Forward pass
        optimizer.zero_grad()
        output, loss_abstract, loss_token = model(X.to(device))
        loss = loss_abstract + loss_token

        # Backward pass
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()  
         
        if i % 1 == 0:
            print(f"Epoch: {epoch+1}, Step: {i}, loss_token: {loss_token.item()}")
    
    avg_loss = total_loss/int(1.33*lendata/(batch_size*context_length))
    print(f"Epoch: {epoch+1}, Average Loss: {     avg_loss}")

    if epoch > -1:
        val_loss = validation(model)
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), "model.pt")
            print("Model saved")



76
Epoch: 1, Step: 0, loss_token: 4.844322681427002
Epoch: 1, Step: 1, loss_token: 4.419561862945557
Epoch: 1, Step: 2, loss_token: 4.005755424499512
Epoch: 1, Step: 3, loss_token: 4.0972371101379395
Epoch: 1, Step: 4, loss_token: 3.9675912857055664
Epoch: 1, Step: 5, loss_token: 3.9356353282928467
Epoch: 1, Step: 6, loss_token: 3.8363096714019775
Epoch: 1, Step: 7, loss_token: 4.015597820281982
Epoch: 1, Step: 8, loss_token: 3.6146605014801025
Epoch: 1, Step: 9, loss_token: 3.837869882583618
Epoch: 1, Step: 10, loss_token: 4.036443710327148
Epoch: 1, Step: 11, loss_token: 3.822021484375
Epoch: 1, Step: 12, loss_token: 3.77701997756958
Epoch: 1, Step: 13, loss_token: 3.7468626499176025
Epoch: 1, Step: 14, loss_token: 3.6518428325653076
Epoch: 1, Step: 15, loss_token: 3.598021984100342
Epoch: 1, Step: 16, loss_token: 3.7293009757995605
Epoch: 1, Step: 17, loss_token: 3.5807290077209473
Epoch: 1, Step: 18, loss_token: 3.741987943649292
Epoch: 1, Step: 19, loss_token: 3.6366846561431885
E

In [None]:
torch.save(model.state_dict(), 'model_weights.pth')  


In [25]:
text = "I will explain the meaning of life. The meaning of life is"
input_ids = tokenizer.encode(text, return_tensors='pt').to(device)

with torch.no_grad():
    res = model.generate(input_ids, ntokens=300)
    output = tokenizer.decode(res[0].tolist())
print(output)


I will explain the meaning of life. The meaning of life is to serve it by serving freedom. and life in a free country is nothing but gratitude for a country which has serviced us in spite of the pressures of austerity and unemployment.
well, yes. by the time too are we ready. so let's back to the business of agriculture. as i said, the livestock emissions caps have been cut--a's focus. all of us at home today will agree. we must take action now to keep our farmers from overstering home mortgages deregulated--and to prevent americans putting more of the house on the backs of loans. they have the same strategy the federal government is just pushing forward: depreciate mortgages, savers, corporate americans cut to rock bottom mortgages for savers who don't know what's going on, government cannot keep down a financial recession. this year, our economy in succor, anything without appreciation for the commodity prices, did not have all the same characteristics six years ago. pieces of our co

In [None]:
text = "I will explain the meaning of life. The meaning of life is"
input_ids = tokenizer.encode(text, return_tensors='pt').to(device)

with torch.no_grad():
    res = model.generate(input_ids, ntokens=300)
    output = tokenizer.decode(res[0].tolist())
print(output)


I will explain the meaning of life. The meaning of life is to serve it by serving freedom. and life in a free country is nothing but gratitude for a country which has serviced us in spite of the pressures of austerity and unemployment.
well, yes. by the time too are we ready. so let's back to the business of agriculture. as i said, the livestock emissions caps have been cut--a's focus. all of us at home today will agree. we must take action now to keep our farmers from overstering home mortgages deregulated--and to prevent americans putting more of the house on the backs of loans. they have the same strategy the federal government is just pushing forward: depreciate mortgages, savers, corporate americans cut to rock bottom mortgages for savers who don't know what's going on, government cannot keep down a financial recession. this year, our economy in succor, anything without appreciation for the commodity prices, did not have all the same characteristics six years ago. pieces of our co

In [26]:
i = 0
size = 32
while i*size < len(res[0]):
    print(tokenizer.decode(res[0][i*size:size+i*size].tolist()))
    print('-----------')
    i += 1

I will explain the meaning of life. The meaning of life is to serve it by serving freedom. and life in a free country is nothing but gratitude for a
-----------
 country which has serviced us in spite of the pressures of austerity and unemployment.
well, yes. by the time too are we ready. so let's
-----------
 back to the business of agriculture. as i said, the livestock emissions caps have been cut--a's focus. all of us at home today will agree.
-----------
 we must take action now to keep our farmers from overstering home mortgages deregulated--and to prevent americans putting more of the house on the backs of
-----------
 loans. they have the same strategy the federal government is just pushing forward: depreciate mortgages, savers, corporate americans cut to rock bottom mortgages for
-----------
 savers who don't know what's going on, government cannot keep down a financial recession. this year, our economy in succor, anything without appreciation for
-----------
 the commodity pri

In [24]:
model2 = AutoModelForCausalLM.from_pretrained(model_name).to(device)
res = model2.generate(input_ids, max_length=100)
output = tokenizer.decode(res[0])
print(output)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


I will explain the meaning of life. The meaning of life is that we are living in a world that is not ours. We are living in a world that is not ours. We are living in a world that is not ours. We are living in a world that is not ours. We are living in a world that is not ours. We are living in a world that is not ours. We are living in a world that is not ours. We are living in a world that is not


In [27]:
validation(model)

Validation: 2, Average Loss: 4.411460208892822


4.411460208892822

In [None]:
Epoch: 1, Step: 0, loss_token: 3.2408559322357178
Epoch: 1, Step: 1, loss_token: 3.53310227394104
Epoch: 1, Step: 2, loss_token: 3.226858615875244
Epoch: 1, Step: 3, loss_token: 3.383723258972168
Epoch: 1, Step: 4, loss_token: 3.362529754638672
Epoch: 1, Step: 5, loss_token: 3.258737325668335
Epoch: 1, Step: 6, loss_token: 3.2521677017211914
Epoch: 1, Step: 7, loss_token: 3.2775447368621826
Epoch: 1, Step: 8, loss_token: 3.3377628326416016
Epoch: 1, Step: 9, loss_token: 3.27563214302063
Epoch: 1, Average Loss: 0.001695163107930288

In [None]:
Epoch: 1, Step: 0, loss_token: 3.429766893386841
Epoch: 1, Step: 1, loss_token: 3.404374122619629
Epoch: 1, Step: 2, loss_token: 3.586282968521118
Epoch: 1, Step: 3, loss_token: 3.647005796432495
Epoch: 1, Step: 4, loss_token: 3.5233004093170166
Epoch: 1, Step: 5, loss_token: 3.4722964763641357
Epoch: 1, Step: 6, loss_token: 3.5063366889953613
Epoch: 1, Step: 7, loss_token: 3.5912227630615234
Epoch: 1, Step: 8, loss_token: 3.6837527751922607
Epoch: 1, Step: 9, loss_token: 3.5423386096954346
Epoch: 1, Average Loss: 0.0018095974177236417

## Other

In [None]:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")  # A lightweight model for embeddings

def get_embedding(text):
    return model.encode(text, convert_to_tensor=True)  # Returns a tensor



torch.Size([384])


In [None]:
text = "This is an example sentence."
text2 = "that sentence is an example"
embedding = get_embedding(text)
embedding2 = get_embedding(text2)

torch.sum(embedding * embedding2) / (torch.norm(embedding) * torch.norm(embedding2))

tensor(0.8691, device='cuda:0')

In [95]:
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
  (2): Normalize()
)