In [2]:
import random

import numpy as np
import torch
from accelerate import Accelerator

from utils import *
import grpo_utils

# --- Configuration -----------------------------------------------------------
model_name = "models/sft_SmolLM-135M-Instruct"  # checkpoint path, relative to the working directory
batch_size = 2        # NOTE: not passed to get_dataloader below; the loader's own batch size applies
n_rollouts = 3        # completions sampled per prompt; rewards are normalized within each group
buffer_size = 6       # NOTE(review): not referenced by any cell shown here (equals batch_size * n_rollouts)
max_new_tokens = 100  # generation budget per rollout
learning_rate = 1e-5  # lower this if the training-step loss does not go down
SEED = 42

# Seed every RNG the notebook touches so sampling and initialization are reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Load essentials ---------------------------------------------------------
llm = load_model(model_name)  # full fine-tuning
# llm = load_peft_model(model_name)  # alternative: train only the LoRA weights
tokenizer = load_tokenizer(model_name)
dataloader = get_dataloader("syllogism", tokenizer)
optimizer = torch.optim.Adam(llm.parameters(), lr=learning_rate)

# --- Accelerate --------------------------------------------------------------
# Only the model, dataloader and optimizer need wrapping. Accelerator.prepare returns
# objects it does not recognize (such as a tokenizer) unchanged, so the tokenizer is
# left out to make that explicit.
accelerator = Accelerator()
llm, dataloader, optimizer = accelerator.prepare(llm, dataloader, optimizer)



OSError: models/sft_SmolLM-135M-Instruct is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `hf auth login` or by passing `token=<your_token>`

In [None]:
# Inspect the loaded model's configuration.
llm.config

In [None]:
# Pull one batch from a fresh iterator and list the fields it carries.
batch = next(iter(dataloader))
batch.keys()

In [None]:
# Attention mask of the batch's prompt tokens (0 marks padding).
batch["inputs"]["attention_mask"]

In [None]:
# Decode the first prompt in the batch, keeping special tokens (decode's default).
print(tokenizer.decode(batch["inputs"]["input_ids"][0], skip_special_tokens=False))

In [None]:
# Alternative to the decode cell above: batch_decode every prompt, then take the first.
# print(tokenizer.batch_decode(batch["inputs"]["input_ids"])[0])

In [None]:
# Validator for the second prompt in the batch (index 1, so at least two prompts are needed).
print(batch["validator"][1])

# Data Collection Step

In [7]:
# Sample one batch of prompts and unpack it.
batch = next(iter(dataloader))
input_ids = batch["inputs"]["input_ids"]
attention_mask = batch["inputs"]["attention_mask"]
validator = batch["validator"]
input_size = input_ids.shape[1]
# Use the real number of prompts rather than `batch_size`, which is never passed to the
# dataloader (and a final partial batch can be smaller).
num_prompts = input_ids.shape[0]

with torch.no_grad():
    # n_rollouts sampled completions per prompt. Hugging Face `generate` keeps each
    # prompt's completions contiguous: rows are [p0, p0, p0, p1, p1, p1, ...].
    full_responses = llm.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=0.95,
        num_return_sequences=n_rollouts,
        temperature=1,
        eos_token_id=tokenizer.eos_token_id,
    )
    assistant_responses = full_responses[:, input_size:]

    # Build a mask shaped like `full_responses`: the prompt mask repeated once per rollout
    # (same row order as generate), followed by ones for the generated positions. The
    # training step calls calculate_logits with a sequence-shaped mask, so match that here.
    # NOTE(review): confirm grpo_utils.calculate_logits does not expand the mask itself.
    full_attention_mask = torch.cat(
        [
            attention_mask.repeat_interleave(n_rollouts, dim=0),
            torch.ones_like(assistant_responses, dtype=attention_mask.dtype),
        ],
        dim=1,
    )
    log_probs = grpo_utils.calculate_logits(llm, full_responses, full_attention_mask)

    # Score every completion. np.repeat pairs each validator with its prompt's rollouts,
    # matching the row order above.
    decoded_responses = tokenizer.batch_decode(
        assistant_responses, skip_special_tokens=True
    )
    rewards = grpo_utils.calculate_rewards(
        decoded_responses, np.repeat(validator, n_rollouts)
    )

    # Group-relative advantages: standardize the rewards within each prompt's group.
    rewards = np.reshape(rewards, (-1, n_rollouts))  # [num_prompts, n_rollouts]
    assert rewards.shape[0] == num_prompts, rewards.shape
    advantages = (rewards - np.mean(rewards, axis=1, keepdims=True)) / (
        np.std(rewards, axis=1, keepdims=True) + 1e-8
    )
    # One advantage per rollout, as a column: [num_prompts * n_rollouts, 1].
    advantages = torch.tensor(
        advantages.reshape(-1, 1), dtype=torch.float32, device=llm.device
    )

    # 1 where the token is not eos. eos is treated as padding.
    # NOTE(review): this assumes pad == eos for this tokenizer; otherwise the prompt's
    # left padding is not detected. Confirm.
    not_pad = (full_responses != tokenizer.eos_token_id).int()
    seq_len = not_pad.shape[1]
    # First non-pad position (skips left padding).
    response_start_idx = not_pad.argmax(dim=-1)
    # One past the last non-pad position (drops trailing padding after generation stops).
    response_end_idx = seq_len - torch.flip(not_pad, dims=[1]).argmax(dim=1)

    # 1 on generated tokens, i.e. positions in [input_size, response_end_idx).
    positions = torch.arange(seq_len, device=full_responses.device)
    response_mask = (
        (positions >= input_size) & (positions < response_end_idx.unsqueeze(1))
    ).to(not_pad.dtype)

# One record per rollout, trimmed of padding at both ends.
experience = [
    {
        "input_sequence": full_responses[i, start:end],
        "log_probs": log_probs[i, start:end],
        "response_mask": response_mask[i, start:end],
        "advantages": advantages[i],
    }
    for i, (start, end) in enumerate(
        zip(response_start_idx.tolist(), response_end_idx.tolist())
    )
]


## Training Step

Run this cell a few times. Each run takes one more optimizer step on the same `experience`, so the loss should go down.
If it doesn't, lower the optimizer's learning rate.

In [None]:

# Collate the stored rollouts: each field is left-padded to a common length and moved
# to the training device.
def _collate(key, *pad_args):
    return left_pad([item[key] for item in experience], *pad_args).to(accelerator.device)


full_sequence = _collate("input_sequence")
attention_mask = left_pad(
    [torch.ones_like(item["input_sequence"]) for item in experience], 0
).to(accelerator.device)
old_log_probs = _collate("log_probs")
response_mask = _collate("response_mask")
advantages = torch.cat(
    [item["advantages"] for item in experience], dim=0
).unsqueeze(-1).to(accelerator.device)

# Log-probs under the current policy (with gradients) versus the rollout-time ones.
log_probs = grpo_utils.calculate_logits(llm, full_sequence, attention_mask)
loss = grpo_utils.calculate_grpo_loss(
    log_probs=log_probs,
    old_log_probs=old_log_probs,
    response_mask=response_mask,
    advantages=advantages,
)
print(loss)

# A single optimization step.
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()



# Visualization helper

In [12]:
import pandas as pd
def get_df(idx, clip_epsilon=1, num_tokens=25):
    """Tabulate the clipped-surrogate loss terms for the last tokens of one sequence.

    Reads the notebook globals ``log_probs``, ``old_log_probs``, ``advantages``,
    ``full_sequence`` and ``tokenizer``, which earlier cells set.

    Args:
        idx: Row of the batch to inspect.
        clip_epsilon: The importance ratio is clamped to [1 - eps, 1 + eps].
        num_tokens: Number of trailing positions to show (default 25).

    Returns:
        A pandas DataFrame with one row per position, indexed by the decoded token.
    """

    def tail(values):
        # Last `num_tokens` positions of row `idx`, as a float32 numpy array.
        return values[idx, -num_tokens:].detach().cpu().float().numpy()

    ratio = torch.exp(log_probs - old_log_probs)
    clipped_ratio = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon)

    # `advantages` has one value per sequence and broadcasts across positions.
    unclipped_loss = ratio * advantages
    clipped_loss = clipped_ratio * advantages
    # Element-wise minimum of the two objectives, negated to give a loss.
    loss = -torch.min(unclipped_loss, clipped_loss)

    tokens = [tokenizer.decode([token]) for token in full_sequence[idx, -num_tokens:]]
    df = pd.DataFrame(
        {
            "advantages": advantages[idx].item(),
            "old_log_probs": tail(old_log_probs),
            "log_probs": tail(log_probs),
            "ratio": tail(ratio),
            "unclipped_ratio": tail(ratio),
            f"clipped_ratio (eps={clip_epsilon})": tail(clipped_ratio),
            "unclipped_loss": tail(unclipped_loss),
            f"clipped_loss (eps={clip_epsilon})": tail(clipped_loss),
            "loss": tail(loss),
        }
    )
    df.index = tokens
    return df

In [None]:
# Per-token loss breakdown for the first sequence, with a clip range of 0.2.
df = get_df(idx=0, clip_epsilon=0.2)
df

# Visualize how the log probs changed (current vs. old policy)

In [14]:
import matplotlib.pyplot as plt
import torch

def plot_log_probs(idx, num_tokens=25, gap=4, height=1):
    """Bar chart comparing current and old log-probs for the last tokens of one sequence.

    Reads the notebook globals ``tokenizer``, ``full_sequence``, ``log_probs``,
    ``old_log_probs`` and ``advantages``. Tokens read top to bottom in sequence order.
    Each token gets a pair of bars: ``log_probs`` above ``old_log_probs``.

    Args:
        idx: Row of the batch to plot.
        num_tokens: Number of trailing positions to show (default 25).
        gap: Vertical spacing between token groups.
        height: Bar thickness.
    """
    tokens = [tokenizer.decode([token]) for token in full_sequence[idx, -num_tokens:]]
    new_lp = log_probs[idx, -num_tokens:].detach().cpu().float().numpy()
    old_lp = old_log_probs[idx, -num_tokens:].detach().cpu().float().numpy()

    # Size the axis from the actual slice, so sequences shorter than num_tokens still plot.
    base = [gap * i for i in range(len(tokens))]

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.barh(base, new_lp, label="log_probs", height=height)
    ax.barh([y + 1 for y in base], old_lp, label="old_log_probs", height=height)

    # Centre each label between its own pair of bars. The previous reversed labels did
    # not line up with the bars. Then flip the axis so the first token is at the top.
    ax.set_yticks([y + 0.5 for y in base])
    ax.set_yticklabels(tokens)
    ax.invert_yaxis()

    ax.set_xlabel("log_probs")
    ax.legend(loc="upper right")
    ax.set_title(f"Advantage: {advantages[idx].item():.2f}")

In [None]:

# Row of the training batch to inspect here and in the cells below.
idx = 0
plot_log_probs(idx=idx)

In [None]:
# Full decoded text of the selected sequence, special tokens included (decode's default).
print(tokenizer.decode(full_sequence[idx]))

In [None]:
# Group-normalized advantage of the selected sequence.
advantages[idx]