infer.py

In [2]:
import torch
from tqdm import tqdm
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
from openai import OpenAI
from transformers.generation.streamers import TextStreamer
import numpy as np

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


  from .autonotebook import tqdm as notebook_tqdm


🦥 Unsloth Zoo will now patch everything to make training faster!


In [3]:
max_seq_length = 4000  
load_in_4bit = True

def get_label(expr, base):
    lhs, rhs = expr.split("+")
    lhs_base10 = int(lhs, base)
    rhs_base10 = int(rhs, base)
    sum_base10 = lhs_base10 + rhs_base10
    return np.base_repr(sum_base10, base)

def load_model(model_path, chat_template, r=16, lora_alpha=32, peft_path=None):
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = model_path,
        max_seq_length = max_seq_length,
        dtype = None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
        load_in_4bit = load_in_4bit,
        # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
        device_map="auto"
    )

    # model = FastLanguageModel.get_peft_model(model,
    #     r = r,                                                      # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    #     target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
    #                     "gate_proj", "up_proj", "down_proj",],
    #     lora_alpha = lora_alpha,
    #     lora_dropout = 0,                                           # Supports any, but = 0 is optimized
    #     bias = "none",                                              # Supports any, but = "none" is optimized
    #     use_gradient_checkpointing = True,
    #     random_state = 3407
    # )
    
    tokenizer = get_chat_template(tokenizer, chat_template = chat_template)
    return model, tokenizer

def inf_oai(model, prompts):
    client = OpenAI()
    
    all_responses = []

    for prompt in tqdm(prompts):
        completion = client.chat.completions.create(
            model=model,
            messages=[
            {"role": "user", "content": prompt}
            ]
        )
        all_responses.append(completion.choices[0].message.content)
        
    return all_responses


query.py

In [4]:
import os

def parse_bool(flag):
    if isinstance(flag, bool):
        return flag
    assert flag in {"True", "False"}
    return flag == "True"

def load_data(data_file, size):
    x = [line.strip() for line in open(data_file)][:size]
    print(len(x))
    return x


def answer(expr, base):
    lhs, rhs = expr.split("+")
    lt, lo = lhs  # tens, ones
    rt, ro = rhs
    ones_sum = get_label(f"{lo}+{ro}", base)
    carry_over = len(ones_sum) > 1
    tens_sum_wo_carry = get_label(f"{lt}+{rt}", base)
    if carry_over:
        assert ones_sum[0] == "1"
        tens_sum_w_carry = get_label(f"{tens_sum_wo_carry}+1", base)
    else:
        tens_sum_w_carry = tens_sum_wo_carry
    assert get_label(expr, base) == tens_sum_w_carry + ones_sum[-1:]

    ret = f"We add the ones digits first. In base-{base}, {lo}+{ro}={ones_sum}. So the ones digit of the final sum is {ones_sum[-1:]}. "
    if carry_over:
        ret += f"We need to carry over the 1 to the tens place. "
    else:
        ret += f"We do not need to carry any digits over. "
    ret += f"Then we add the tens digits. In base-{base}, {lt}+{rt}={tens_sum_wo_carry}. "
    if carry_over:
        ret += f"Since we carried over the 1, {tens_sum_wo_carry}+1={tens_sum_w_carry}. "
    if len(tens_sum_w_carry) == 1:
        ret += f"So the tens digit of the final sum is {tens_sum_w_carry}. "
    else:
        ret += f"So the hundreds and tens digits of the final sum are {tens_sum_w_carry}. "
    ret += f"Putting the digits of the final sum together, we get \\boxed{{{tens_sum_w_carry}{ones_sum[-1:]}}}."
    return ret


def templatize(expr, base, cot=True, n_shots=0, icl_cot=True):
    if n_shots > 0:
        with open(f"ft_data/data_ft_{base}_2.txt",'r') as f:
            demos = f.read()
        shots = demos.split("\n")[:n_shots]
        assert len(shots) == n_shots
        if icl_cot:
            context = "\n".join(f"{templatize(shot, base)} {answer(shot, base)}" for shot in shots)
        else:
            context = "\n".join(f"{templatize(shot, base)} \\boxed{{{get_label(expr, base)}}}" for shot in shots)
        return context + "\n" + templatize(expr, base)
    digits = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    if cot:
        return f"You are a mathematician. Assuming that all numbers are in base-{base} where the digits are \"{digits[:base]}\", what is {expr}? Let's think step by step, and end the response with the result in \"\\boxed{{result}}\"."
    else:
        return f"You are a mathematician. Assuming that all numbers are in base-{base} where the digits are \"{digits[:base]}\", what is {expr}? End the response with the result in \"\\boxed{{result}}\"."


def escape(str):
    assert "\t" not in str and "\\\\n" not in str and "\\\\r" not in str
    return str.replace("\\n", "\\\\n").replace("\n", "\\n").replace("\\r", "\\\\r").replace("\r", "\\r")

In [5]:
n_digits = 3
base = 10
cot = True
n_shots = 0
size = 250
model_name = "unsloth/Meta-Llama-3.1-8B-Instruct"
# model_name = "unsloth/llama-3-8b-Instruct-bnb-4bit"
chat_template = "llama-3"
output_file = "output.txt"
icl_cot = True
device = "cuda"

In [6]:
data_file = f'arithmetic/data/0shot{"_3digits" if n_digits == 3 else ("_4digits" if n_digits == 4 else "")}/base{base}.txt'
data = load_data(data_file, size)

assert not os.path.exists(output_file)

print("templatizing...")
prompts = [templatize(expr, base, cot=cot, n_shots=n_shots, icl_cot=icl_cot) for expr in data]
print("\tdone!")

print("loading model...")
model, tokenizer = load_model(model_name, chat_template)
print("\tdone!")
# model = model.to(device)

250
templatizing...
	done!
loading model...
==((====))==  Unsloth 2025.6.9: Fast Llama patching. Transformers: 4.53.0.
   \\   /|    NVIDIA RTX 6000 Ada Generation. Num GPUs = 8. Max memory: 47.408 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.0+cu126. CUDA: 8.9. CUDA Toolkit: 12.6. Triton: 3.3.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.30. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
	done!


In [7]:
prompts[0]

'You are a mathematician. Assuming that all numbers are in base-10 where the digits are "0123456789", what is 760+587? Let\'s think step by step, and end the response with the result in "\\boxed{result}".'

In [None]:

print("inferencing...")
FastLanguageModel.for_inference(model)

batch_size = 1
# batch_size = 100

responses = []

# chat_prompts = [{"role": "user", "content": prompts[0]}]
chat_prompts = [{"role": "user", "content": 'You are a mathematician. Assuming that all numbers are in base-10 where the digits are "0123456789", what is 760+587?'}]
# print(chat_prompts)
inputs = tokenizer.apply_chat_template(
	chat_prompts,
	tokenize=True,
	add_generation_prompt=True,
	return_tensors="pt",
).to(device)

text_streamer = TextStreamer(tokenizer)
model.generate(
	input_ids=inputs,
	max_new_tokens= 50,
	streamer=text_streamer,
	use_cache=True,
	# temperature=0.1,
	# temperature=0.0,
	# min_p=0.1,
	# pad_token_id=tokenizer.pad_token_id
)

# for i in tqdm(range(0, len(prompts), batch_size)):
# 	batch_prompts = prompts[i:i + batch_size]
# 	print("\tgetting prompts")
# 	chat_prompts = [[{"role": "user", "content": p}] for p in batch_prompts]
	
# 	print("\tapplying chat template")
# 	inputs = tokenizer.apply_chat_template(
# 		chat_prompts,
# 		tokenize=True,
# 		add_generation_prompt=True,
# 		return_tensors="pt",
# 		padding=True,
# 		truncation=True,
# 		max_length=max_seq_length
# 	).to(device)
# 	# print(f"DEVICE: {model.device}")
# 	print("\tgenerating")
# 	outputs = model.generate(
# 		input_ids=inputs,
# 		max_new_tokens=max_seq_length - inputs.shape[1],
# 		# streamer=text_streamer,
# 		use_cache=False,
# 		temperature=0.1,
# 		# temperature=0.0,
# 		min_p=0.1,
# 		pad_token_id=tokenizer.pad_token_id
# 	)
	
# 	print("\tdecoding")
# 	batch_responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)
# 	responses.extend(batch_responses)
	
# 	if torch.cuda.is_available():
# 		torch.cuda.empty_cache()
    


# print("\tdone!")

# print("writing output...")
# with open(output_file, "w") as log:
# 	for expr, response in zip(data, responses, strict=True):
# 		log.write(f"{expr}\t{escape(response)}\n")
# print("\tdone!")

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


inferencing...
<|begin_of_text|><|start_header_id|>user<|end_header_id|>

You are a mathematician. Assuming that all numbers are in base-10 where the digits are "0123456789", what is 760+587?<|eot_id|><|start_header_id|>assistant<|end_header_id|>



/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1553: indexSelectLargeIndex: block: [369,0,0], thread: [32,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1553: indexSelectLargeIndex: block: [369,0,0], thread: [33,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1553: indexSelectLargeIndex: block: [369,0,0], thread: [34,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1553: indexSelectLargeIndex: block: [369,0,0], thread: [35,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1553: indexSelectLargeIndex: block: [369,0,0], thread: [36,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu:1553: indexSelectLargeIndex: block: [369,0,0], thread: [37,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
/pytorch/aten/src/ATen/native/cuda/Indexing.cu

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
