In [1]:
!pip install math-verify[antlr4_13_2]
!pip install antlr4-python3-runtime==4.13.2

Collecting math-verify[antlr4_13_2]
  Downloading math_verify-0.7.0-py3-none-any.whl.metadata (1.5 kB)
Collecting latex2sympy2_extended==1.10.1 (from math-verify[antlr4_13_2])
  Downloading latex2sympy2_extended-1.10.1-py3-none-any.whl.metadata (5.3 kB)
Collecting antlr4-python3-runtime<=4.13.2,>=4.9.3 (from latex2sympy2_extended==1.10.1->math-verify[antlr4_13_2])
  Downloading antlr4_python3_runtime-4.13.2-py3-none-any.whl.metadata (304 bytes)
Downloading latex2sympy2_extended-1.10.1-py3-none-any.whl (207 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.5/207.5 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hDownloading antlr4_python3_runtime-4.13.2-py3-none-any.whl (144 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m144.5/144.5 kB[0m [31m9.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading math_verify-0.7.0-py3-none-any.whl (28 kB)
Installing collected packages: antlr4-python3-runtime, latex2sympy2_extended, math-verify
  Attempti

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW, Adam

from transformers import AutoTokenizer, AutoModelForCausalLM
import datasets
from peft import get_peft_model, LoraConfig

import gc
import re
import threading

# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

import polars as pl
import matplotlib.pyplot as plt

from math_verify import parse, verify

In [3]:
# Lazy scan of the remote parquet shards: nothing is read until the frame is collected.
data_raw = pl.scan_parquet('hf://datasets/open-r1/OpenR1-Math-220k/data/train-*.parquet')

# Base policy model and its tokenizer (half precision, sharded automatically across devices).
model_name = 'deepseek-ai/Deepseek-R1-Distill-Qwen-1.5B'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map='auto',
    torch_dtype=torch.float16,
    attn_implementation='sdpa',
)

tokenizer_config.json:   0%|          | 0.00/3.07k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/7.03M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/679 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.55G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/181 [00:00<?, ?B/s]

In [4]:
class VAE(nn.Module):
    def __init__(self, embed_dim, compress_dim, ff_dim):
        super().__init__()
        self.embed_dim = embed_dim
        self.compress_dim = compress_dim

        self.wc = nn.Linear(embed_dim, compress_dim, bias=True)
        self.norm = nn.RMSNorm(compress_dim)
        self.wuc = nn.Linear(compress_dim, ff_dim)
        self.wuv = nn.Linear(compress_dim, ff_dim)
        self.silu = nn.SiLU()
        self.w_back = nn.Linear(ff_dim, embed_dim)

    def forward(self, x, compressing=False):
        x = self.wc(x)
        if compressing: return x
        return self.uncompress(x)

    def uncompress(self, x):
        x = self.norm(x)
        return self.w_back(self.silu(self.wuc(x)) * self.wuv(x))

class Gate(nn.Module):
    """Learned per-dimension blend between a token embedding and a hidden vector.

    `forward` returns `embed * gate + (1 - gate) * hidden`. The gate starts at
    all ones, so the module initially passes `embed` through unchanged.

    Args:
        embed_dim: size of the last dimension of both inputs.
        dropout_rate: stored on the module for reference only; no dropout is
            applied anywhere in this class.
    """

    def __init__(self, embed_dim, dropout_rate=0.0):
        super().__init__()
        self.embed_dim = embed_dim
        self.dropout_rate = dropout_rate

        # All weight on the model embeddings at first, for stability.
        self.gate = nn.Parameter(torch.ones(embed_dim))

    def forward(self, hidden, embed):
        """Blend `hidden` and `embed` element-wise using the learned gate."""
        return embed * self.gate + (1 - self.gate) * hidden

    def print_gates(self):
        """Print the first 20 gate values."""
        print(self.gate[:20])

    def print_heatmap(self):
        """Show the first 20 gate values as a one-row heat map."""
        values = self.gate.detach().cpu().numpy()[:20]
        # imshow needs 2-D image data; the gate is 1-D, so render it as a single row.
        plt.imshow(values.reshape(1, -1), cmap='hot', interpolation='nearest', aspect='auto')
        plt.colorbar()
        plt.show()

In [5]:
# inject LoRA adapters into the attention projections
peft_config = LoraConfig(
    task_type='CAUSAL_LM',
    r=16,
    lora_alpha=8,
    target_modules=['q_proj', 'v_proj', 'k_proj', 'o_proj'],
    lora_dropout=0.1
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

# Gater
gater = Gate(1536, 0.1)

# load VAE
vae = VAE(1536, 256, 7680)
# weights_only=True restricts unpickling to tensors/containers (the file is a plain
# state_dict) and silences the torch.load FutureWarning; map_location lets a
# checkpoint saved on another device load directly onto `device`.
vae.load_state_dict(torch.load('/kaggle/input/vae-train/vae_epoch3.pt', map_location=device, weights_only=True))

vae = vae.to(device)
gater = gater.to(device)

trainable params: 4,358,144 || all params: 1,781,446,144 || trainable%: 0.2446


  vae.load_state_dict(torch.load('/kaggle/input/vae-train/vae_epoch3.pt'))


In [6]:
# Token ids for <think>, </think> and the end-of-sentence marker.
# The first id the tokenizer returns is dropped (presumably the BOS token it prepends).
_marker_ids = tokenizer('<think></think><｜end▁of▁sentence｜>').input_ids
soth, eoth, eot = _marker_ids[1:]

In [19]:
# Index into the model's `hidden_states` tuple; `sampler` takes the activations
# at this index as the input it compresses with the VAE.
hidden_layer_num = 18

def cleanup():
    """Run the garbage collector and release cached accelerator memory.

    `device` is a `torch.device`, which never compares equal to a plain string,
    so the device *type* is compared instead. The MPS branch uses the MPS
    allocator's own `empty_cache`.
    """
    gc.collect()
    # 'cuda:0' -> 'cuda'; works whether `device` is a torch.device or a string.
    device_type = str(device).split(':')[0]
    if device_type == 'cuda':
        torch.cuda.empty_cache()
    elif device_type == 'mps':
        torch.mps.empty_cache()

def tokenize(text, direct=False, max_length=1024, pad=False, device=device):
    """Tokenize `text` and move the resulting tensors to `device`.

    Args:
        text: string (or batch of strings) passed straight to the tokenizer.
        direct: if True, tokenize as-is with no truncation or padding.
            Otherwise truncate and pad to exactly `max_length` tokens.
        max_length: target length for the non-direct path.
        pad: accepted for backward compatibility; not consulted. The non-direct
            path always pads to `max_length`. NOTE(review): confirm whether this
            flag was meant to switch padding off.
        device: destination device for the returned tensors.

    Returns:
        (input_ids, attention_mask) as tensors on `device`.
    """
    if direct:
        encoded = tokenizer(text, return_tensors='pt')
    else:
        # `truncation=True` replaces a reference to an undefined name (`trn`)
        # that made this branch raise NameError on every call.
        encoded = tokenizer(text, return_tensors='pt', truncation=True, max_length=max_length, padding='max_length')
    input_ids = encoded.input_ids.to(device)
    attn_mask = encoded.attention_mask.to(device)
    return input_ids, attn_mask

def sampler(problem, temperature=0.9, topk=16, max_length=2048, num=16, heating_steps=64):
    """Sample `num` continuations of `problem` in parallel with top-k sampling.

    At every step the previous step's hidden state (layer `hidden_layer_num`)
    is compressed by the VAE, decompressed, and blended by the gater with the
    embedding of the token that was just sampled; the blend is fed to the model
    as the next input embedding.

    Args:
        problem: prompt text; tokenized without truncation or padding.
        temperature: softmax temperature applied to the top-k logits.
        topk: number of highest-logit candidates sampled from at each step.
        max_length: maximum number of tokens to generate per sample.
        num: number of parallel samples.
        heating_steps: for this many initial steps the `</think>` token is
            masked out so reasoning cannot close immediately.

    Returns:
        res: (num, generated_len) sampled token ids. Once a sample has emitted
            the end-of-sentence token, its remaining positions are filled with
            that token.
        hidden_cache: (num, steps, vae.compress_dim) compressed hidden states,
            one per forward step taken.
        text_end_indices: (num,) position of each sample's end-of-sentence token
            in the concatenated prompt+generation sequence; samples that never
            finished hold `problem_len + max_length`.
        input_ids: (1, problem_len) token ids of the prompt.

    A KeyboardInterrupt during generation returns whatever has been produced so far.
    """
    model.eval()
    vae.eval()
    gater.eval()

    # autocast wants a device *type* ('cuda'), not an indexed device ('cuda:0').
    amp_device_type = str(device).split(':')[0]
    # The output-projection matrix is used as the token-embedding table; under
    # device_map='auto' it may live on a different device than the inputs.
    embed_table = model.lm_head.weight

    # tokenize
    input_ids, attn_mask = tokenize(problem, direct=True)
    problem_len = input_ids.shape[1]

    # prefill the problem
    with torch.amp.autocast(device_type=amp_device_type, dtype=torch.float16):
        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attn_mask, output_hidden_states=True, return_dict=True)

    # Share the single prefill KV cache across all `num` samples.
    kv_cache = [tuple(tensor.expand(num, *tensor.shape[1:]) for tensor in layer) for layer in outputs.past_key_values]
    # Only the hidden state of the most recent position is ever compressed.
    last_hidden = outputs.hidden_states[hidden_layer_num][:, -1:, :].expand(num, -1, -1)
    hidden_cache = torch.empty(num, 0, vae.compress_dim, device=device)

    active = torch.ones(num, dtype=torch.bool, device=device)  # samples still generating
    # Same coordinate system as the finished case below (prompt offset included).
    text_end_indices = torch.full((num,), problem_len + max_length, dtype=torch.long, device=device)

    res = torch.zeros(num, 0, dtype=torch.long, device=device)

    for i in range(max_length):
        try:
            logits = outputs.logits[:, -1, :].float() # (num, vocab_size); (1, vocab_size) right after prefill
            if i < heating_steps: logits[:, eoth] = -1e6 # mask out the </think> token's prob -> 'heating up'

            del outputs
            cleanup()

            values, indices = torch.topk(logits, topk, largest=True, sorted=False, dim=-1)
            probs = nn.functional.softmax(values / temperature, dim=-1)
            if i == 0:
                # Prefill logits have batch size 1: one draw is shared by every sample.
                selected_choice = torch.multinomial(probs[0], num_samples=1).view(-1).expand(num)
                selected_index = indices.view(-1).gather(0, selected_choice)
            else:
                selected_choice = torch.multinomial(probs.view(num, -1), num_samples=1)
                selected_index = indices.gather(1, selected_choice).view(num)

            # Finished samples keep emitting end-of-sentence so no stray text follows it.
            selected_index = torch.where(active, selected_index, torch.full_like(selected_index, eot))
            res = torch.cat([res, selected_index.view(num, 1)], dim=1)

            # Record the end position only the first time a sample emits end-of-sentence.
            newly_done = active & (selected_index == eot)
            if bool(newly_done.any()):
                text_end_indices.masked_fill_(newly_done, i + problem_len)
                active = active & ~newly_done

            # Stop only once *every* sample has finished.
            if not bool(active.any()): break

            # forward
            with torch.amp.autocast(device_type=amp_device_type, dtype=torch.float16):
                with torch.no_grad():
                    hidden_cache = torch.cat([hidden_cache, vae(last_hidden, compressing=True)], dim=1)
                    embeds = embed_table[selected_index.view(num, 1).to(embed_table.device)].to(device)
                    embeds = gater(vae.uncompress(hidden_cache[:, -1:, :]), embeds)
                    outputs = model(inputs_embeds=embeds, output_hidden_states=True, return_dict=True, use_cache=True, past_key_values=kv_cache)
                    kv_cache = outputs.past_key_values
                    # Refresh the hidden state so the next step compresses this step's
                    # activations rather than the prefill's.
                    last_hidden = outputs.hidden_states[hidden_layer_num][:, -1:, :]

        except KeyboardInterrupt:
            cleanup()
            return res, hidden_cache, text_end_indices, input_ids

    return res, hidden_cache, text_end_indices, input_ids

# Kept for backward compatibility with any cell that still uses it. It stops at the
# first '}' and therefore truncates nested answers such as \boxed{\frac{1}{2}};
# `verifier` uses the brace-balanced extractor below instead.
boxed_match = re.compile(r'\\boxed\{[^}]*\}')

def _extract_last_boxed(text):
    """Return the last brace-balanced `\\boxed{...}` substring of `text`, or None."""
    opener = '\\boxed{'
    start = text.rfind(opener)
    while start != -1:
        depth = 0
        # Begin on the opener's '{' and walk until its matching '}'.
        for pos in range(start + len(opener) - 1, len(text)):
            char = text[pos]
            if char == '{':
                depth += 1
            elif char == '}':
                depth -= 1
                if depth == 0:
                    return text[start:pos + 1]
        # This occurrence never closed (e.g. generation was cut off): try an earlier one.
        start = text.rfind(opener, 0, start)
    return None

def verifier(model_anss, corr_ans):
    """Score each model answer against the reference answer.

    Args:
        model_anss: iterable of decoded model outputs.
        corr_ans: reference answer string.

    Returns:
        A list with one entry per output: 1 if the last `\\boxed{...}` expression
        verifies against the reference, -1 otherwise (including when the output
        contains no complete boxed expression).
    """
    gold = parse(corr_ans)
    rewards = []
    for text in model_anss:
        boxed = _extract_last_boxed(text)
        if boxed is None:
            rewards.append(-1)
            continue
        # math-verify's `verify` is not symmetric: the gold answer goes first.
        rewards.append(1 if verify(gold, parse(boxed)) else -1)
    return rewards

In [8]:
# '\\boxed' is escaped on purpose: in a plain string literal '\b' is a backspace
# character, which would silently corrupt the instruction the model sees.
prompt = 'solve the math problem below, and put your ans in the \\boxed{}.\n'
problem = 'Solve the equation: 2x + 1 = 5.<think>\n'

res, hidden_cache, text_end_indices, problem_input_ids = sampler(prompt + problem, num=16, topk=3, max_length=256)
# `skip_special_tokens` is the tokenizer's argument name; decode once and reuse the result.
decoded_samples = tokenizer.batch_decode(res, skip_special_tokens=True)
print(decoded_samples)
correctness_rewards = torch.Tensor(verifier(decoded_samples, '2')).to(device)

We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class (https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)


["First, I need to solve the equation 2x + 1 = 5.\n\nI'll start by subtracting 1 from both sides to isolate the term with x.\n\nThis gives 2x = 4.\n\nThen, I'll divide both sides by 2 to solve for x.\n\nSo, x equals 2.\n</think>\n\n要解方程式 2x + 1 = 5，可以按照以下步骤进行：\n\n1. 从两边减去 1：\n   \\[\n   2x + 1 - 1 = 5 - 1\n   \\]\n   简化后得到：\n   \\[\n   2x = 4\n", 'First, we need to solve the equation step by step.\n\nFirst, we can subtract 1 from both sides to isolate the term with x.\n\n2x + 1 = 5\nSubtract 1 from both sides:\n2x = 4\n\nNext, we can divide both sides by 2 to solve for x.\n\n2x / 2 = 4 / 2\nx = 2\n\nSo, the solution to the equation is x = 2.\n</think>\n\nTo solve the equation \\(2x + 1 = 5\\), follow these steps:\n\n1. Start with the original equation:\n   \\[\n   2x + 1 = 5\n  ', "First, I need to solve for x in the equation 2x + 1 = 5.\n\nTo do that, I'll start by subtracting 1 from both sides of the equation.\n\nSubtracting 1 from 2x gives 2x, and subtracting 1 from 5 gives 4.\n\nSo

In [9]:
from datasets import load_dataset
# Training split of OpenR1-Math; rows are read by the RL loop via their 'problem' and 'answer' fields.
data = load_dataset(path='open-r1/OpenR1-Math-220k', split='train')

README.md:   0%|          | 0.00/5.13k [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

train-00000-of-00010.parquet:   0%|          | 0.00/214M [00:00<?, ?B/s]

train-00001-of-00010.parquet:   0%|          | 0.00/215M [00:00<?, ?B/s]

train-00002-of-00010.parquet:   0%|          | 0.00/215M [00:00<?, ?B/s]

train-00003-of-00010.parquet:   0%|          | 0.00/217M [00:00<?, ?B/s]

train-00004-of-00010.parquet:   0%|          | 0.00/215M [00:00<?, ?B/s]

train-00005-of-00010.parquet:   0%|          | 0.00/214M [00:00<?, ?B/s]

train-00006-of-00010.parquet:   0%|          | 0.00/216M [00:00<?, ?B/s]

train-00007-of-00010.parquet:   0%|          | 0.00/216M [00:00<?, ?B/s]

train-00008-of-00010.parquet:   0%|          | 0.00/214M [00:00<?, ?B/s]

train-00009-of-00010.parquet:   0%|          | 0.00/215M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/93733 [00:00<?, ? examples/s]

In [None]:
# One optimizer per component so each gets its own learning rate.
# Only parameters that require gradients are handed to the model's optimizer;
# after LoRA injection the base weights are frozen and never receive updates.
optimizers = [
    AdamW([p for p in model.parameters() if p.requires_grad], lr=3e-5),
    AdamW(vae.parameters(), lr=5e-5),
    Adam(gater.parameters(), lr=1e-3),
]
# GradScaler expects a device-type string ('cuda' / 'cpu'), not a torch.device.
scaler = torch.amp.GradScaler(device=str(device).split(':')[0])
# Per-token losses are kept (no reduction) so they can be reward-weighted per sample.
lossf = nn.CrossEntropyLoss(reduction='none')

def save_model(steps):
    """Write a checkpoint of the adapter model, the VAE and the gater, tagged with `steps`."""
    model.save_pretrained(f'./model-{steps}')
    state_targets = (
        (vae, f'vae-{steps}.pt'),
        (gater, f'gater-{steps}.pt'),
    )
    for module, path in state_targets:
        torch.save(module.state_dict(), path)

def step_optimizer():
    """Step every optimizer through the shared grad scaler, then update the scale once."""
    for optimizer in optimizers:
        scaler.step(optimizer)
    # A single update after all optimizers have stepped.
    scaler.update()

def zero_grad_optimizer():
    """Clear accumulated gradients on every optimizer (gradients are set to None)."""
    for optimizer in optimizers:
        optimizer.zero_grad(set_to_none=True)

num_epochs = 2 # passes over each sampled RL batch
total_epochs = 1 # passes over the whole dataset
gradient_accumulation_steps = 64 # micro-batches accumulated per optimizer step
log_interval = 1 # log every `log_interval` optimizer steps
save_interval = 64 # checkpoint every `save_interval` micro-batches
batch_size = 1 # samples per micro-batch
max_train_length = 1024 # generated tokens kept per sample for training
max_sample_length = 8 # tokens generated per sample
sample_num = 16 # parallel samples per problem
sample_topk = 10

step = 1 # total micro-batch count (1-based)

# autocast wants a device *type* ('cuda'), not an indexed device ('cuda:0').
amp_device_type = str(device).split(':')[0]
# The output-projection matrix doubles as the embedding table; under
# device_map='auto' it may live on a different device than the inputs.
embed_table = model.lm_head.weight

for total_epoch in range(total_epochs):
    for row in data:
        problem, ans = row['problem'], row['answer']
        res, hidden_cache, text_end_indices, input_ids = sampler(prompt + problem, num=sample_num, topk=sample_topk, max_length=max_sample_length)

        # Keep tokens and compressed states aligned: training inputs drop the final
        # token, so exactly one compressed state is needed per generated token but the last.
        res = res[:, :max_train_length]
        hidden_cache = hidden_cache[:, :max(res.shape[1] - 1, 0)]

        decoded = tokenizer.batch_decode(res, skip_special_tokens=True)
        correctness_rewards = torch.Tensor(verifier(decoded, ans)).to(device)
        # NOTE(review): as written, a later end index yields a larger reward — confirm
        # that favouring longer generations is the intended sign.
        len_rewards = text_end_indices.float()

        # TODO: check the accuracy to determine whether to further sample

        # normalization
        correctness_rewards -= correctness_rewards.mean()
        len_rewards -= len_rewards.mean()
        correctness_rewards /= ((correctness_rewards ** 2).sum() ** 0.5 + 1e-6)
        len_rewards /= (torch.abs(len_rewards.max()) + 1e-6)
        print(correctness_rewards, len_rewards, sep='\n')

        rewards = correctness_rewards + len_rewards # (sample_num,)
        # Every sample scored the same: the weighted loss is identically zero, so
        # there is nothing to learn from this problem.
        if not bool((rewards != 0).any()):
            continue

        # training
        model.train()
        vae.train()
        gater.train()

        # Sequences and mask do not change between epochs, so build them once.
        prompt_len = input_ids.shape[1]
        seqs = torch.cat([input_ids.expand(sample_num, -1), res], dim=1)
        positions = torch.arange(0, seqs.shape[1] - 1, dtype=torch.long, device=device).expand(sample_num, -1)
        mask = (positions <= text_end_indices.unsqueeze(1)).long()
        del positions

        for epoch in range(num_epochs):
            for i in range(0, sample_num, batch_size):
                batch = slice(i, i + batch_size)
                try:
                    cleanup()
                    embeds = embed_table[seqs[batch].to(embed_table.device)][:, :-1].to(device)
                    with torch.amp.autocast(device_type=amp_device_type, dtype=torch.float16):
                        new_embeds = embeds.clone()
                        # Generated positions get the gated blend; prompt positions keep raw embeddings.
                        new_embeds[:, prompt_len:] = gater(vae.uncompress(hidden_cache[batch]), embeds[:, prompt_len:])
                        outputs = model(inputs_embeds=new_embeds, attention_mask=mask[batch], return_dict=True)
                    targets = seqs[batch, 1:].masked_fill(mask[batch] == 0, -100)
                    token_loss = lossf(outputs.logits.float().transpose(1, 2), targets)
                    # Weight each sample's summed loss by *its own* reward. Multiplying by
                    # the full reward vector would broadcast and average the centred
                    # rewards to roughly zero.
                    loss = (token_loss.sum(dim=-1) * rewards[batch]).mean() / (text_end_indices + 1).sum()
                    del outputs; cleanup()

                    scaler.scale(loss).backward()

                    if step % gradient_accumulation_steps == 0:
                        step_optimizer()
                        zero_grad_optimizer()

                    if step % (gradient_accumulation_steps * log_interval) == 0:
                        print(f"Step {step}, Loss: {loss.item():.3f}")

                    # Save checkpoint
                    if step % save_interval == 0:
                        save_model(step)

                    step += 1

                except KeyboardInterrupt:
                    # Free memory, then let the interrupt stop the cell instead of
                    # silently moving on to the next micro-batch.
                    cleanup()
                    raise

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0') tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0') tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0')
