In [1]:
import numpy as np
import random
import torch
from accelerate import Accelerator
from utils import *
import grpo_utils

from grpo_utils import load_model, load_tokenizer, get_dataloader

model_name = "HuggingFaceTB/SmolLM-135M-Instruct"
batch_size = 2
n_rollouts = 3
buffer_size = 6
max_new_tokens = 100

llm = load_model(model_name)
for param in llm.parameters():
    param.requires_grad = True
tokenizer = load_tokenizer(model_name)
dataloader = get_dataloader("webinstruct", tokenizer)
optimizer = torch.optim.AdamW(llm.parameters(), lr = 1e-5)

accelerator = Accelerator()
llm, tokenizer, dataloader, optimizer = accelerator.prepare(llm, tokenizer, dataloader, optimizer)

Loading WebInstructSub dataset...
Loaded 2335220 samples from WebInstructSub


In [2]:
batch = next(iter(dataloader))
batch.keys()

dict_keys(['inputs', 'validator'])

In [3]:
batch["validator"]

[{'question': 'How do you find the standard form equation for a line that is perpendicular to the line defined by the equation \\( x - 7y = 9 \\) and passes through the point \\( (5, 4) \\)?',
  'expected_answer': 'The standard form equation for the line perpendicular to \\( x - 7y = 9 \\) and passing through \\( (5, 4) \\) is \\( 7x + y - 39 = 0 \\).\n\nExplanation:\n1. The slope of the given line \\( x - 7y = 9 \\) can be found by rewriting it in standard form: \\( x - 7y - 9 = 0 \\), which gives \\( a = 1 \\) and \\( b = -7 \\). The slope of this line is \\( -\\frac{a}{b} = -\\frac{1}{-7} = \\frac{1}{7} \\).\n\n2. The slope of a line perpendicular to the given line is the negative reciprocal of its slope, which is \\( -\\frac{1}{\\frac{1}{7}} = -7 \\).\n\n3. Using the point-slope form, the equation of the line passing through \\( (5, 4) \\) with slope \\( -7 \\) is \\( \\frac{y - 4}{x - 5} = -7 \\).\n\n4. To express this in standard form, we solve for \\( y \\) and rearrange terms:\

In [4]:
batch["inputs"].keys()

dict_keys(['input_ids', 'attention_mask'])

In [5]:
batch["inputs"]["input_ids"].shape

torch.Size([2, 512])

In [6]:
print(tokenizer.decode(batch["inputs"]["input_ids"][0]))

How do you find the standard form equation for a line that is perpendicular to the line defined by the equation \( x - 7y = 9 \) and passes through the point \( (5, 4) \)?<|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|