In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

model_id = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model_id)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    attn_implementation="flash_attention_2",
)

`low_cpu_mem_usage` was None, now default to True since model is quantized.
Loading checkpoint shards: 100%|██████████| 3/3 [01:46<00:00, 35.66s/it]


In [3]:
from typing import Dict


def process_cot_example(
    example: Dict,
    tokenizer,
):
    thinking_trajectory = example["deepseek_thinking_trajectory"]
    question = example["question"]
    answer = example["deepseek_attempt"]

    thinking = "\n".join(thinking_trajectory).strip()
    answer = "Answer: " + answer if "Answer:" not in answer else answer

    assistant_text = (
        "I'll think through this step by step:\n" + thinking + answer.strip()
    )

    text = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": question.strip()},
            {
                "role": "assistant",
                "content": assistant_text,
            },
        ],
        tokenize=False,
    )
    return dict(text=text)

In [4]:
from datasets import load_dataset

ds = load_dataset("simplescaling/s1K-1.1")["train"]

In [None]:
ds[0]

In [None]:
ds_text_example = process_cot_example(ds[0], tokenizer)
ds_text_example

In [7]:
ds_text = ds.map(
    lambda x: process_cot_example(x, tokenizer),
    batched=False,
    remove_columns=ds.column_names,
)

In [None]:
print(ds_text[0])

In [None]:
print(model.get_memory_footprint())

In [10]:
from peft import prepare_model_for_kbit_training

model = prepare_model_for_kbit_training(model)

In [11]:
from peft import LoraConfig, get_peft_model

peft_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules="all-linear",
    modules_to_save=["lm_head", "embed_token"],
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, peft_config)

In [12]:
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="outputs",
    num_train_epochs=2,
    per_device_train_batch_size=8,
    learning_rate=2e-5,
    logging_steps=10,
    optim="adamw_8bit",
    weight_decay=0.01,
    lr_scheduler_type="linear",
    save_strategy="epoch",
    warmup_ratio=0.1,
    gradient_accumulation_steps=4,
)

In [13]:
train_ds = ds_text

In [None]:
train_ds

In [None]:
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_ds,
)

trainer.train()

In [16]:
trainer.save_model(output_dir="models/mistral-7b-reasoning-lora")

In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

adapter_path = "models/mistral-7b-reasoning-lora"
model = PeftModel.from_pretrained(base_model, adapter_path)

In [4]:
model = model.merge_and_unload()
model.eval()



MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32768, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): MistralMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): MistralRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): Mist

In [11]:
def generate_response(prompt):
    # Format the input using the chat template
    inputs = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}], return_tensors="pt"
    ).to(model.device)

    # Generate with parameters suited for reasoning
    outputs = model.generate(
        input_ids=inputs,
        max_new_tokens=32768,
        temperature=0.5,
        top_p=0.9,
        do_sample=True,
    )

    # Decode the response, removing the prompt
    response = tokenizer.decode(outputs[0][inputs.shape[1] :], skip_special_tokens=True)
    return response

In [12]:
question = "Consider the following two person game. A number of pebbles are situated on the table. Two players make their moves alternately. A move consists of taking off the table  $x$  pebbles where  $x$  is the square of any positive integer. The player who is unable to make a move loses. Prove that there are infinitely many initial situations in which the second player can win no matter how his opponent plays."
response = generate_response(question)
print(response)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Let's think through this step by step:

1. Note that the smallest number of pebbles that the second player can take in a move is 1 (since 1^2 = 1).

2. If the initial number of pebbles on the table is 1, the first player will take 1 pebble, leaving 0. The second player can't make a move, so the first player wins.

3. If the initial number of pebbles is 4, the first player can take 4 pebbles (4^2 = 16 > 4), leaving 0. Again, the second player can't make a move, so the first player wins.

4. Now, consider any number of pebbles that is not a perfect square. For example, let's take 9. The smallest perfect square greater than 9 is 16 (4^2). The first player can take 9 pebbles, leaving 0. The second player can't make a move, so the first player wins.

5. However, for any perfect square number of pebbles, the second player can win. For example, if there are 4 pebbles (1^2 + 3^2 = 4), the first player can take 4 pebbles, leaving 0. But the second player can take 1 pebble (1^2), leaving 0. The 