In [5]:
from config import hf_cache_dir
import transformers
import torch
import os
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from jinja2 import Template
import pandas as pd
from utils_activations import rot13_alpha
from datasets import load_dataset
from grader import grade_answer

In [2]:
ds = load_dataset("meta-math/MetaMathQA", cache_dir=hf_cache_dir)

In [None]:
def get_model_output(prompt, model, tokenizer):
    content = f'Answer the following question, and format your answer as \'The answer is: <answer>\'. Here is the question: {prompt}'
    messages = [
        {"role": "user", "content": content},
    ]

    # Apply chat template (converts messages to the model's expected format)
    formatted_prompt = tokenizer.apply_chat_template(
        messages, 
        tokenize=False, 
        add_generation_prompt=True
    )
    inputs = tokenizer(formatted_prompt, return_tensors="pt")
    
    # Move inputs to the same device as model (if needed)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    
    # Generate text
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=2500,
            do_sample=False,  # Set to True for sampling
            temperature=0.75,  # Only used when do_sample=True
            top_p=0.9,        # Only used when do_sample=True
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=False,
        )
    generated_tokens = outputs[0][inputs['input_ids'].shape[1]:]
    generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=False)
    
    print("Generated text:")
    print(generated_text)
    return generated_text

# SFT Model

In [21]:
model_path = '/workspace/data/axolotl-outputs/llama_deepseek_2epochs/merged'
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,  # Use float16 for memory efficiency; also could be float16
    device_map="auto",          # Automatically distribute across available GPUs
    trust_remote_code=True,
    low_cpu_mem_usage=True, 
    )
tokenizer = AutoTokenizer.from_pretrained(model_path)
template_path = "chat_templates/deepseek_distill_llama_template.jinja"
with open(template_path, "r") as file:
    jinja_template = file.read()
tokenizer.chat_template = jinja_template 

Loading checkpoint shards:   0%|          | 0/30 [00:00<?, ?it/s]

In [None]:
# Select data
index = 21
prompt = ds['train'][index]['query']
solution = ds['train'][index]['response']

In [24]:
generated_text = get_model_output(prompt, model, tokenizer)

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Generated text:


Bxnl, fb V unir guvf ceboyrz jurer V arrq gb svaq n 2k2 zngevk Z. Gur zngevk Z vf fhccbfrq gb genafsbez gjb qvssrerag irpgbef vagb bgure irpgbef. Gur svefg genafsbezngvba vf Z gvzrf gur irpgbe (3, 0) rdhnyf (6, 21), naq gur frpbaq genafsbezngvba vf Z gvzrf (-1, 5) rdhnyf (K, -17). Gurl nyfb gbyq zr gung gur nafjre vf gur zngevk [[2, 1], [7, -2]], naq V arrq gb svaq gur inyhr bs K. Uzz, bxnl, yrg zr guvax nobhg ubj gb nccebnpu guvf.

Svefg, V erzrzore gung n 2k2 zngevk pna or ercerfragrq nf:

Z = [[n, o],
     [p, q]]

Jura V zhygvcyl Z ol n irpgbe (k, l), vg tvirf zr (n*k + o*l, p*k + q*l). Fb, nccylvat guvf gb gur svefg rdhngvba, Z*(3, 0) = (6, 21). Yrg zr jevgr gung bhg:

Sbe gur svefg pbzcbarag: n*3 + o*0 = 6, juvpu fvzcyvsvrf gb 3n = 6. Fb, n = 6/3 = 2.

Sbe gur frpbaq pbzcbarag: p*3 + q*0 = 21, juvpu vf 3p = 21, fb p = 21/3 = 7.

Fb sebz gur svefg genafsbezngvba, V sbhaq gung n=2 naq p=7. Fb abj, gur zngevk Z ybbxf yvxr:

Z = [[2, o],
     [7, q]]

Abj, hfvat gur

In [25]:
rot13_alpha(generated_text)

"\n\nOkay, so I have this problem where I need to find a 2x2 matrix M. The matrix M is supposed to transform two different vectors into other vectors. The first transformation is M times the vector (3, 0) equals (6, 21), and the second transformation is M times (-1, 5) equals (X, -17). They also told me that the answer is the matrix [[2, 1], [7, -2]], and I need to find the value of X. Hmm, okay, let me think about how to approach this.\n\nFirst, I remember that a 2x2 matrix can be represented as:\n\nM = [[a, b],\n     [c, d]]\n\nWhen I multiply M by a vector (x, y), it gives me (a*x + b*y, c*x + d*y). So, applying this to the first equation, M*(3, 0) = (6, 21). Let me write that out:\n\nFor the first component: a*3 + b*0 = 6, which simplifies to 3a = 6. So, a = 6/3 = 2.\n\nFor the second component: c*3 + d*0 = 21, which is 3c = 21, so c = 21/3 = 7.\n\nSo from the first transformation, I found that a=2 and c=7. So now, the matrix M looks like:\n\nM = [[2, b],\n     [7, d]]\n\nNow, usin

In [None]:
model_answer = generated_text.split('</think>')[-1]
model_answer = model_answer.split('The answer is: ')[-1].split('<｜end▁of▁sentence｜>')[0]
print(model_answer)



To find the value of \( X \), we use the given matrix \( \mathbf{M} = \begin{pmatrix} 2 & 1 \\ 7 &


In [80]:
parsed_solution = solution.split('The answer is: ')[-1]
print(parsed_solution)

3


In [None]:
grade_answer(model_answer, parsed_solution)

False

# GRPO Model

In [12]:
qlora_dir = "/workspace/data/grpo-metamath-lora-model/checkpoint-210"

In [6]:
from transformers import BitsAndBytesConfig
from peft import PeftModel

In [8]:
import importlib.util
spec = importlib.util.spec_from_file_location(
    "run_grpo", "11_run_grpo.py"
)
run_grpo = importlib.util.module_from_spec(spec)
spec.loader.exec_module(run_grpo)
load_your_model_and_tokenizer_with_lora_and_quantization = run_grpo.load_your_model_and_tokenizer_with_lora_and_quantization


In [36]:
def load_your_model_and_tokenizer_with_lora_and_quantization(
    device_map_auto: bool = False, 
    quantization_type: str = "4bit",
    small_model: bool = False,  # For debugging
    training: bool = True,  # Whether to prepare for training
    return_peft: bool = True,  # Whether to return PEFT model
):
    """Load model and tokenizer with LoRA and quantization setup"""

    if not small_model: 
        model_path = '/workspace/data/axolotl-outputs/llama_deepseek_2epochs/merged'
    else:
        model_path = 'deepseek-ai/DeepSeek-R1-Distill-Llama-8B'

    args = {
        'pretrained_model_name_or_path': model_path,
        'torch_dtype': torch.bfloat16,
        'trust_remote_code': True,
        'low_cpu_mem_usage': True,
        'cache_dir': hf_cache_dir,
    }
    args['quantization_config'] = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,  # Nested quantization for additional savings
        bnb_4bit_quant_type="nf4",  # Normal Float 4-bit
        bnb_4bit_quant_storage=torch.bfloat16,
    )
    if device_map_auto:
        args['device_map'] = "auto"

    # Load base model with quantization
    print(f"Loading model with {quantization_type} quantization...")
    model = AutoModelForCausalLM.from_pretrained(**args)
    model.generation_config.temperature = 1.0
    model.generation_config.do_sample = True
    model.generation_config.top_p = 0.9
    
    # # Prepare model for k-bit training (required for quantized + LoRA)
    # model = prepare_model_for_kbit_training(
    #     model, use_gradient_checkpointing=False)
    
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    template_path = "chat_templates/deepseek_distill_llama_template.jinja"
    with open(template_path, "r") as file:
        jinja_template = file.read()
    tokenizer.chat_template = jinja_template

    # if return_peft:
    #     lora_config = setup_lora_config_for_quantized()
    #     model = get_peft_model(model, lora_config)
    # else:
    #     pass
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = tokenizer.eos_token_id
    if training:
        model.train()
        model.print_trainable_parameters()
    
    return model, tokenizer

In [37]:
base_model, tokenizer = load_your_model_and_tokenizer_with_lora_and_quantization(training=False)

Loading model with 4bit quantization...


Loading checkpoint shards:   0%|          | 0/30 [00:00<?, ?it/s]

In [38]:
base_model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 8192)
    (layers): ModuleList(
      (0-79): 80 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=8192, out_features=8192, bias=False)
          (k_proj): Linear4bit(in_features=8192, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=8192, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=8192, out_features=8192, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=8192, out_features=28672, bias=False)
          (up_proj): Linear4bit(in_features=8192, out_features=28672, bias=False)
          (down_proj): Linear4bit(in_features=28672, out_features=8192, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((8192,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((8192,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((819

In [39]:
# Load and merge LoRA
peft_model = PeftModel.from_pretrained(base_model, qlora_dir)
merged_model = peft_model.merge_and_unload()



In [40]:
get_model_output(prompt, peft_model)

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Generated text:
jnl, fb V unir guvf ceboyrz jurer V arrq gb svaq gur inyhr bs K. Gur ceboyrz tvirf zr n zngevk Z gung'f 2k2, naq vg'f hfrq gb zhygvcyl gjb qvssrerag irpgbef. Gur svefg zhygvcyvpngvba vf Z gvzrf gur irpgbe (3, 0) juvpu rdhnyf (6, 21). Gur frpbaq zhygvcyvpngvba vf Z gvzzrf (-1, 5) juvpu rdhnyf (K, -17). Gurl nyfb gryy zr gung gur zngevk Z vf [[2, 1], [7, -2]]. Fb, V arrq gb svther bhg jung K vf.

Bxnl, yrg zr fgneg ol erpnyyvat ubj zngevk zhygvcyvpngvba jbexf. Vs Z vf n 2k2 zngevk, naq V zhygvcyl vg ol n 2k1 irpgbe, gur erfhyg vf nabgure 2k1 irpgbe. Rnpu ragel va gur erfhygvat irpgbe vf gur qbg cebqhpg bs gur pbeerfcbaqvat ebj bs Z naq gur pbyhza bs gur irpgbe.

Fb, Z vf tvira nf:
[ n  o ]
[ p  q ]

Naq jura V zhygvcyl Z ol (3, 0), V trg (6, 21). Yrg zr jevgr gung bhg:

Svefg pbzcbarag: n*3 + o*0 = 6
Frpbaq pbzcbarag: p*3 + q*0 = 21

Fb, sebz gur svefg rdhngvba: 3n = 6 => n = 2
Sebz gur frpbaq rdhngvba: 3p = 21 => p = 7

Fb, gung tvirf zr n=2 naq p=7. Abj, ybbxvat ng gur 

"jnl, fb V unir guvf ceboyrz jurer V arrq gb svaq gur inyhr bs K. Gur ceboyrz tvirf zr n zngevk Z gung'f 2k2, naq vg'f hfrq gb zhygvcyl gjb qvssrerag irpgbef. Gur svefg zhygvcyvpngvba vf Z gvzrf gur irpgbe (3, 0) juvpu rdhnyf (6, 21). Gur frpbaq zhygvcyvpngvba vf Z gvzzrf (-1, 5) juvpu rdhnyf (K, -17). Gurl nyfb gryy zr gung gur zngevk Z vf [[2, 1], [7, -2]]. Fb, V arrq gb svther bhg jung K vf.\n\nBxnl, yrg zr fgneg ol erpnyyvat ubj zngevk zhygvcyvpngvba jbexf. Vs Z vf n 2k2 zngevk, naq V zhygvcyl vg ol n 2k1 irpgbe, gur erfhyg vf nabgure 2k1 irpgbe. Rnpu ragel va gur erfhygvat irpgbe vf gur qbg cebqhpg bs gur pbeerfcbaqvat ebj bs Z naq gur pbyhza bs gur irpgbe.\n\nFb, Z vf tvira nf:\n[ n  o ]\n[ p  q ]\n\nNaq jura V zhygvcyl Z ol (3, 0), V trg (6, 21). Yrg zr jevgr gung bhg:\n\nSvefg pbzcbarag: n*3 + o*0 = 6\nFrpbaq pbzcbarag: p*3 + q*0 = 21\n\nFb, sebz gur svefg rdhngvba: 3n = 6 => n = 2\nSebz gur frpbaq rdhngvba: 3p = 21 => p = 7\n\nFb, gung tvirf zr n=2 naq p=7. Abj, ybbxvat ng gur

In [41]:
get_model_output(prompt, merged_model)

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Generated text:
jnl, fb V unir guvf ceboyrz jurer V arrq gb svaq gur inyhr bs K. Gur ceboyrz tvirf zr n zngevk Z gung'f 2k2, naq vg'f hfrq gb zhygvcyl gjb qvssrerag irpgbef. Gur svefg zhygvcyvpngvba vf Z gvzrf gur irpgbe (3, 0) juvpu rdhnyf (6, 21). Gur frpbaq zhygvcyvpngvba vf Z gvzzrf (-1, 5) juvpu rdhnyf (K, -17). Gurl nyfb gryy zr gung gur zngevk Z vf [[2, 1], [7, -2]]. Fb, V arrq gb svther bhg jung K vf.

Bxnl, yrg zr fgneg ol erpnyyvat ubj zngevk zhygvcyvpngvba jbexf. Vs Z vf n 2k2 zngevk, naq V zhygvcyl vg ol n 2k1 irpgbe, gur erfhyg vf nabgure 2k1 irpgbe. Rnpu ragel va gur erfhygvat irpgbe vf gur qbg cebqhpg bs gur pbeerfcbaqvat ebj bs Z naq gur pbyhza bs gur irpgbe.

Fb, Z vf tvira nf:
[ n  o ]
[ p  q ]

Naq jura V zhygvcyl Z ol (3, 0), V trg (6, 21). Yrg zr jevgr gung bhg:

Svefg pbzcbarag: n*3 + o*0 = 6
Frpbaq pbzcbarag: p*3 + q*0 = 21

Fb, sebz gur svefg rdhngvba: 3n = 6 => n = 2
Sebz gur frpbaq rdhngvba: 3p = 21 => p = 7

Fb, gung tvirf zr n=2 naq p=7. Abj, ybbxvat ng gur 

"jnl, fb V unir guvf ceboyrz jurer V arrq gb svaq gur inyhr bs K. Gur ceboyrz tvirf zr n zngevk Z gung'f 2k2, naq vg'f hfrq gb zhygvcyl gjb qvssrerag irpgbef. Gur svefg zhygvcyvpngvba vf Z gvzrf gur irpgbe (3, 0) juvpu rdhnyf (6, 21). Gur frpbaq zhygvcyvpngvba vf Z gvzzrf (-1, 5) juvpu rdhnyf (K, -17). Gurl nyfb gryy zr gung gur zngevk Z vf [[2, 1], [7, -2]]. Fb, V arrq gb svther bhg jung K vf.\n\nBxnl, yrg zr fgneg ol erpnyyvat ubj zngevk zhygvcyvpngvba jbexf. Vs Z vf n 2k2 zngevk, naq V zhygvcyl vg ol n 2k1 irpgbe, gur erfhyg vf nabgure 2k1 irpgbe. Rnpu ragel va gur erfhygvat irpgbe vf gur qbg cebqhpg bs gur pbeerfcbaqvat ebj bs Z naq gur pbyhza bs gur irpgbe.\n\nFb, Z vf tvira nf:\n[ n  o ]\n[ p  q ]\n\nNaq jura V zhygvcyl Z ol (3, 0), V trg (6, 21). Yrg zr jevgr gung bhg:\n\nSvefg pbzcbarag: n*3 + o*0 = 6\nFrpbaq pbzcbarag: p*3 + q*0 = 21\n\nFb, sebz gur svefg rdhngvba: 3n = 6 => n = 2\nSebz gur frpbaq rdhngvba: 3p = 21 => p = 7\n\nFb, gung tvirf zr n=2 naq p=7. Abj, ybbxvat ng gur

### Save merged QLoRA model

In [42]:
qlora_parent_dir = os.path.dirname(qlora_dir)
print(qlora_parent_dir)

/workspace/data/grpo-metamath-lora-model


In [43]:
merged_model.save_pretrained(os.path.join(qlora_parent_dir, "merged"))

[2025-08-13 18:17:57,734] [INFO] [real_accelerator.py:254:get_accelerator] Setting ds_accelerator to cuda (auto detect)


df: /root/.triton/autotune: No such file or directory
/root/miniconda3/envs/py3.11/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::runtime_error::~runtime_error()@GLIBCXX_3.4'
/root/miniconda3/envs/py3.11/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__gxx_personality_v0@CXXABI_1.3'
/root/miniconda3/envs/py3.11/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream::tellp()@GLIBCXX_3.4'
/root/miniconda3/envs/py3.11/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::chrono::_V2::steady_clock::now()@GLIBCXX_3.4.19'
/root/miniconda3/envs/py3.11/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_M_replace_aux(unsigned long, unsigned long, unsigned long, char)@GLIBCXX_3.4'
/root/miniconda3/envs/py3.11/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `typeinfo for bool@CXXABI_1.3

[2025-08-13 18:17:58,445] [INFO] [logging.py:107:log_dist] [Rank -1] [TorchCheckpointEngine] Initialized with serialization = False
