# Groupe Relative Policy Optimization (GRPO)

Install the Hugging Face libraries to run this notebook.

In [1]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [3]:
model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [4]:
from transformers import AutoConfig

class Embedder(nn.Module):
    def __init__(self, level, predictor=None, model_name = "gpt2"):
        super(Embedder, self).__init__()
        if level ==1:
            self.model = AutoModelForCausalLM.from_pretrained(model_name).get_input_embeddings().to(device)
            self.model.requires_grad_(False)
        else:
            if predictor is None:
                self.model = AutoModelForCausalLM.from_pretrained(model_name).transformer.to(device)
                self.model.wte = nn.Identity()
            else:
                self.model = predictor
        self.level = level
        self.cls_token = nn.Parameter(torch.randn(1, 1, 768))
        self.cls_token.requires_grad_(True)
        self.no_pred = predictor is None
    
    def forward(self, input):
        if self.level == 1:
            return self.model(input.squeeze(-1))
        cls = self.cls_token.repeat(input.size(0), 1, 1)
        input = torch.cat([input, cls], dim=1)
        if self.no_pred:
            out = self.model(inputs_embeds = input)[0][:, -1, :]
        else:
            out = self.model(input)[:,-1]
        return out

In [5]:
class Abstract_level(nn.Module):
    def __init__(self, level, model_name = "gpt2"):
        super(Abstract_level, self).__init__()
        self.abstract_size = level

        self.level = level
        if level==1:
            self.model = AutoModelForCausalLM.from_pretrained(model_name).transformer.to(device)
            self.model.wte = nn.Identity()
            self.model.requires_grad_(False)
        else:
            config = AutoConfig.from_pretrained(model_name)
            # Create the model from configuration (with random weights)
            self.model = AutoModelForCausalLM.from_config(config).transformer.to(device)
            self.model.wte = nn.Identity()
    
    def forward(self, input):
        out = self.model(inputs_embeds = input)[0]
        return out

In [6]:
class Token_pred(nn.Module):
    def __init__(self, model_name = "gpt2"):
        super(Token_pred, self).__init__()
        self.layer = AutoModelForCausalLM.from_pretrained(model_name).lm_head.to(device)
        self.layer.requires_grad_(False)
        #self.layer = nn.Linear(768*2, 50257).to(device)
        #self.lin2 = nn.Linear(768, 768).to(device) n
    
    def forward(self, input):
        #x = torch.cat([input, condition], dim = -1)
        #x = input + condition
        return self.layer(input)

In [None]:
class WorldModel_LLM(nn.Module):
    def __init__(self, level, use_abstract = True, sum_inside=False):
        super(WorldModel_LLM, self).__init__()
        self.alevel = level
        self.aalevel = 4
        self.use_abstract = use_abstract

        self.lvl0_predictor = Abstract_level(1)
        if use_abstract:
            self.lvl1_predictor = Abstract_level(self.alevel)

        self.token_embedder = Embedder(1)
        if use_abstract:
            self.lvl1_embedder = Embedder(self.alevel)
            #self.lvl1_embedder = Embedder(self.alevel, predictor=self.lvl0_predictor)

        self.token_pred = Token_pred()

        self.criterion = nn.CrossEntropyLoss()

        self.lvl1_embed_save = []
        self.lvl1_embed_pred_save = [torch.zeros(768).to(device)]

        self.sum_inside = sum_inside
    
    def forward(self, tokens):
        nseq = tokens.shape[1]//self.alevel
        self.aalevel = nseq
        tokens = tokens[:, :tokens.shape[1]//self.alevel*self.alevel]
        
        tokens = tokens.reshape(tokens.shape[0], -1, self.alevel).reshape(-1, self.alevel) # (batch_size * nseq, alevel)
        token_embedding = self.token_embedder(tokens) # (batch_size * nseq, alevel, embedding_size)

        if self.use_abstract:
            lvl1_embedding = self.lvl1_embedder(token_embedding).reshape(-1, nseq, 768) # (batch_size, nseq, embedding_size)
        
            lvl1_embedding_pred = self.lvl1_predictor(lvl1_embedding) # (batch_size, nseq, embedding_size)
            lvl1_embedding_pred = torch.cat((torch.zeros_like(lvl1_embedding_pred[:,:1]),lvl1_embedding_pred[:,:-1]), dim = 1)
            lvl1_embedding_pred = lvl1_embedding_pred.unsqueeze(2).repeat(1,1, self.alevel, 1).reshape(-1, self.alevel, 768) # (batch_size * nseq, alevel, embedding_size)

            if self.sum_inside:
                token_embedding_pred = self.lvl0_predictor(token_embedding[:,:-1] + lvl1_embedding_pred[:,:-1]) # (batch_size * nseq, alevel, embedding_size)
            else:
                token_embedding_pred = self.lvl0_predictor(token_embedding[:,:-1])+ lvl1_embedding_pred[:,:-1] # (batch_size * nseq, alevel, embedding_size)
            tokens_pred = self.token_pred(token_embedding_pred) # (batch_size * nseq, alevel, vocab_size)
        else:
            token_embedding_pred = self.lvl0_predictor(token_embedding[:,:-1]) # (batch_size * nseq, alevel, embedding_size)
            tokens_pred = self.token_pred(token_embedding_pred) # (batch_size * nseq, alevel, vocab_size)

        loss_token = self.criterion(tokens_pred.reshape(-1, 50257), tokens[:,1:].reshape(-1))
        return tokens_pred, loss_token
    
    def generate(self, tokens, ntokens):
        with torch.no_grad():
            while tokens.shape[1] <= ntokens:
                if tokens.shape[1] < self.alevel:
                    token_embedding = self.token_embedder(tokens)
                    token_embedding_pred = self.lvl0_predictor(token_embedding)[:,-1]

                else:
                    reste = tokens.shape[1]%self.alevel
                    token_embedding = self.token_embedder(tokens[:, -self.alevel:] )

                    if reste==0:
                        new_embedding = self.lvl1_embedder(token_embedding)[-1]
                        self.lvl1_embed_save.append(new_embedding)
                        lvl1_embedding= torch.stack(self.lvl1_embed_save)
                        self.lvl1_embed_pred_save.append(self.lvl1_predictor(lvl1_embedding[-self.aalevel:])[-1])
                        lvl1_embed_pred_save = torch.stack(self.lvl1_embed_pred_save)
                    
                        lvl1_embedding_pred = lvl1_embed_pred_save[-2].repeat(self.alevel-reste, 1)
                    else:
                        current_pred_lvl1 = lvl1_embed_pred_save[-1].repeat(reste, 1)
                        prev_pred_lvl1 = lvl1_embed_pred_save[-2].repeat(self.alevel-reste, 1)
                        lvl1_embedding_pred = torch.cat([prev_pred_lvl1, current_pred_lvl1], dim = 0)
                    
                    if self.sum_inside:
                        token_embedding_pred = self.lvl0_predictor(token_embedding[:,1:] + lvl1_embedding_pred[1:])[:,-1]
                    else:
                        token_embedding_pred = self.lvl0_predictor(token_embedding[:,1:])[:,-1] + lvl1_embedding_pred[1:][-1]

                tokens_pred = self.token_pred(token_embedding_pred)
                next_token = torch.multinomial(tokens_pred.softmax(-1), 1)
                tokens = torch.cat([tokens, next_token], dim = 1)
        return tokens

In [39]:
text = "This is an example of me and of"
input_ids = tokenizer.encode(text, return_tensors='pt').to(device)
print(input_ids.shape)

wmllm = WorldModel_LLM(2).to(device)
wmllm(input_ids)
with torch.no_grad():
    #op = wmllm(input_ids)
    op = wmllm.generate(input_ids, 32)
    print(tokenizer.decode(op[0].tolist()))


torch.Size([1, 8])
This is an example of me and of Starting:
k.


s in to on,



.



1:



In [9]:
wmllm.aalevel

4

## Dataset

In [8]:
with open('sotu.txt', 'r', encoding='utf-8') as f:
    text = f.read()
    text = text.replace('\n', ' ')
    text = text.replace('  ', '\n')


text = text.lower()
print(text[:1000])

mr. speaker, mr. president, and distinguished members of the house and senate, honored guests, and fellow citizens: less than 3 weeks ago, i joined you on the west front of this very building and, looking over the monuments to our proud past, offered you my hand in filling the next page of american history with a story of extended prosperity and continued peace. and tonight i'm back to offer you my plans as well. the hand remains extended; the sleeves are rolled up; america is waiting; and now we must produce. together, we can build a better america.
it is comforting to return to this historic chamber. here, 22 years ago, i first raised my hand to be sworn into public life. so, tonight i feel as if i'm returning home to friends. and i intend, in the months and years to come, to give you what friends deserve: frankness, respect, and my best judgment about ways to improve america's future. in return, i ask for an honest commitment to our common mission of progress. if we seize the opport

In [9]:
data = tokenizer.encode(text, return_tensors='pt').squeeze()
data.shape

Token indices sequence length is longer than the specified maximum sequence length for this model (260738 > 1024). Running this sequence through the model will result in indexing errors


torch.Size([260738])

In [10]:
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

lendata = train_data.shape[0]
data.shape, data[:50]

(torch.Size([260738]),
 tensor([43395,    13, 10834,    11,   285,    81,    13,  1893,    11,   290,
         18876,  1866,   286,   262,  2156,   290, 34548,    11, 21014, 10650,
            11,   290,  5891,  4290,    25,  1342,   621,   513,  2745,  2084,
            11,  1312,  5399,   345,   319,   262,  7421,  2166,   286,   428,
           845,  2615,   290,    11,  2045,   625,   262, 28814,   284,   674]))

In [11]:
context_length = 128
batch_size = 32
 
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - context_length - 1, (batch_size,))
    X = torch.stack([data[i:i+context_length] for i in ix])
    return X

X = get_batch("train")
X.shape

torch.Size([32, 128])

## Training

In [12]:
def  validation(model):
    # Define hyperparameters
    epoch = 1

    # Training loop
    with torch.no_grad():
        model.eval()
        total_loss = 0
        num =5
        for i in range(num):
            # Get batch
            X = get_batch('val')
            
            # Forward pass
            output, loss_token = model(X.to(device))
            loss = loss_token
            
            total_loss += loss.item()  

        avg_loss = total_loss / num 
        print(f"Validation: {epoch+1}, Average Loss: {     avg_loss}")
    return avg_loss

In [13]:
model = WorldModel_LLM(16).to(device)

In [14]:
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
#model.load_state_dict(torch.load("MW_sum_inside.pt"))
best_loss = 10000

In [15]:
# Define hyperparameters
epochs = 10

# Define optimizer

print(int(1.33*lendata/(batch_size*context_length)))

# Training loop
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for i in range(0,int(1.33*lendata/(batch_size*context_length))):
        # Get batch
        X = get_batch('train')
        
        # Forward pass
        optimizer.zero_grad()
        output, loss_token = model(X.to(device))
        loss = loss_token

        # Backward pass
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()  
         
        if i % 1 == 0:
            print(f"Epoch: {epoch+1}, Step: {i}, loss_token: {loss_token.item()}")
    
    avg_loss = total_loss/int(1.33*lendata/(batch_size*context_length))
    print(f"Epoch: {epoch+1}, Average Loss: {     avg_loss}")

    if epoch > -1:
        val_loss = validation(model)
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), "model.pt")
            print("Model saved")



76
Epoch: 1, Step: 0, loss_token: 5.647838115692139
Epoch: 1, Step: 1, loss_token: 6.901641845703125
Epoch: 1, Step: 2, loss_token: 5.1891279220581055
Epoch: 1, Step: 3, loss_token: 5.018613338470459
Epoch: 1, Step: 4, loss_token: 4.989016056060791
Epoch: 1, Step: 5, loss_token: 4.783699035644531
Epoch: 1, Step: 6, loss_token: 4.966264247894287
Epoch: 1, Step: 7, loss_token: 4.771054744720459
Epoch: 1, Step: 8, loss_token: 4.942931652069092
Epoch: 1, Step: 9, loss_token: 4.854284763336182
Epoch: 1, Step: 10, loss_token: 4.758890628814697
Epoch: 1, Step: 11, loss_token: 4.736922264099121
Epoch: 1, Step: 12, loss_token: 4.703001499176025
Epoch: 1, Step: 13, loss_token: 4.7543768882751465
Epoch: 1, Step: 14, loss_token: 4.678609848022461
Epoch: 1, Step: 15, loss_token: 4.812869548797607
Epoch: 1, Step: 16, loss_token: 4.659309387207031
Epoch: 1, Step: 17, loss_token: 4.771054267883301
Epoch: 1, Step: 18, loss_token: 4.656765937805176
Epoch: 1, Step: 19, loss_token: 4.583284378051758
Epoch

KeyboardInterrupt: 

## Validation

In [None]:
model = WorldModel_LLM(16).to(device)

In [18]:
model.load_state_dict(torch.load("model.pt"))

  model.load_state_dict(torch.load("model.pt"))


<All keys matched successfully>

In [37]:
validation(model)

Validation: 2, Average Loss: 4.662626457214356


4.662626457214356

In [19]:
text = "I will explain the meaning of life. The meaning of life is"
input_ids = tokenizer.encode(text, return_tensors='pt').to(device)

with torch.no_grad():
    res = model.generate(input_ids, ntokens=300)
    output = tokenizer.decode(res[0].tolist())
print(output)


I will explain the meaning of life. The meaning of life is that our happiness begins with us.of finding our home again.our lifelong poor child who grew up restricted on food alone during recess and have little interest in her childhood childhood children. "if this last point is too far to burden again, and we could halt installation at this stage without the necessary delays, but we let it grow steadily until we could keep it within our budget goals. 24:00 so we did what we had to do: we changed departments making sure that educators could work together to help students achieve early success.

"One way of improving our solution is to extend the program to include our neighbors on double commutes than we do here.

Largest rail service fee in government process in decades in groceries waste-laden baggies. America's hops and potato industry has also been big, and remain major operations, meaning unemployed people now face up to five years of unemployment insurance. the important thing is 

Visualize sentences worresponding to different level 1 embedding values:

In [20]:
i = 0
size = 32
while i*size < len(res[0]):
    print(tokenizer.decode(res[0][i*size:size+i*size].tolist()))
    print('-----------')
    i += 1

I will explain the meaning of life. The meaning of life is that our happiness begins with us.of finding our home again.our lifelong poor child who grew
-----------
 up restricted on food alone during recess and have little interest in her childhood childhood children. "if this last point is too far to burden again, and we could
-----------
 halt installation at this stage without the necessary delays, but we let it grow steadily until we could keep it within our budget goals. 24:00 so we did
-----------
 what we had to do: we changed departments making sure that educators could work together to help students achieve early success.

"One way of improving our solution
-----------
 is to extend the program to include our neighbors on double commutes than we do here.

Largest rail service fee in government process in decades in
-----------
 groceries waste-laden baggies. America's hops and potato industry has also been big, and remain major operations, meaning unemployed people now face up