In [None]:
import numpy as np
import random
import torch
from accelerate import Accelerator
from utils import *
import grpo_utils

from grpo_utils import load_model, load_tokenizer, get_dataloader

model_name = "HuggingFaceTB/SmolLM-135M-Instruct"
batch_size = 2
# number of sequences generated
n_rollouts = 3
buffer_size = 6
# thinking space
max_new_tokens = 100

llm = load_model(model_name)
for param in llm.parameters():
    param.requires_grad = True
tokenizer = load_tokenizer(model_name)
dataloader = get_dataloader("webinstruct", tokenizer)
optimizer = torch.optim.AdamW(llm.parameters(), lr = 1e-5)

accelerator = Accelerator()
llm, tokenizer, dataloader, optimizer = accelerator.prepare(llm, tokenizer, dataloader, optimizer)

Loaded 2335220 samples from WebInstructSub


In [2]:
batch = next(iter(dataloader))
batch.keys()

dict_keys(['inputs', 'validator'])

In [3]:
batch["validator"]

[{'question': 'I need to prove: If $p\\equiv 5 \\bmod 16$ then there exists $x\\in \\mathbb{Z}_2$ such that $px^4=1$. How can I approach this proof?',
  'expected_answer': 'According to the corollary of Theorem 4 in Serre\'s "Cours d\'arithmétique", there exists a solution in $\\mathbb{Z}_2$ if $p$ is congruent to $1$ modulo a higher power of $2$. Since $p\\equiv 5 \\bmod 16$, it is congruent to $1$ modulo $8$. Therefore, there exists a solution $x\\in \\mathbb{Z}_2$ such that $px^4=1$.',
  'source': 'mathstackexchange',
  'orig_question': "I need to prove: If $p\\equiv 1 \\bmod 16$ then there exists $x\\in \\mathbb{Z}_2$ ($2$-adic ring) so that $$px^4=1.$$ I'm not sure how to start this. I thought maybe to use some results on quadratic form $py^2=1$ and then to prove that solution $y$ has root in $\\mathbb{Z}_2$?",
  'orig_answer': 'Unless you made a misprint, this does not seem possible. If p is simply odd, then it is invertible in $Z_2$, and the problem amounts to extract a 4-th roo

In [4]:
batch["inputs"].keys()

dict_keys(['input_ids', 'attention_mask'])

In [10]:
batch["inputs"]["input_ids"]

tensor([[   57,   737,   288,  ...,     2,     2,     2],
        [16865,   260,  8431,  ...,     2,     2,     2]], device='mps:0')

In [None]:
batch["inputs"]["input_ids"].shape

torch.Size([2, 512])

In [6]:
print(tokenizer.decode(batch["inputs"]["input_ids"][0]))

I need to prove: If $p\equiv 5 \bmod 16$ then there exists $x\in \mathbb{Z}_2$ such that $px^4=1$. How can I approach this proof?<|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><|im_end|><

In [11]:
batch["validator"]

[{'question': 'I need to prove: If $p\\equiv 5 \\bmod 16$ then there exists $x\\in \\mathbb{Z}_2$ such that $px^4=1$. How can I approach this proof?',
  'expected_answer': 'According to the corollary of Theorem 4 in Serre\'s "Cours d\'arithmétique", there exists a solution in $\\mathbb{Z}_2$ if $p$ is congruent to $1$ modulo a higher power of $2$. Since $p\\equiv 5 \\bmod 16$, it is congruent to $1$ modulo $8$. Therefore, there exists a solution $x\\in \\mathbb{Z}_2$ such that $px^4=1$.',
  'source': 'mathstackexchange',
  'orig_question': "I need to prove: If $p\\equiv 1 \\bmod 16$ then there exists $x\\in \\mathbb{Z}_2$ ($2$-adic ring) so that $$px^4=1.$$ I'm not sure how to start this. I thought maybe to use some results on quadratic form $py^2=1$ and then to prove that solution $y$ has root in $\\mathbb{Z}_2$?",
  'orig_answer': 'Unless you made a misprint, this does not seem possible. If p is simply odd, then it is invertible in $Z_2$, and the problem amounts to extract a 4-th roo

In [12]:
input_ids = batch["inputs"]["input_ids"]
attention_mask = batch["inputs"]["attention_mask"]
validator = batch["validator"]
input_size = input_ids.shape[1]

In [None]:
with torch.no_grad():
    full_responses = llm.generate(
        input_ids = input_ids,
        attention_mask = attention_mask,
        max_new_tokens = max_new_tokens,
        do_sample = True,
        top_p = 0.95,
        num_return_sequences = n_rollouts,
        temperature = 1,
        eos_token_id = tokenizer.eos_token_id)

    assistant_reponses = full_responses[:, input_size:]

    log_probs = grpo_utils.calculate_logits(llm, full_responses, attention_mask)