In [1]:
import ast
import copy
import operator
import os
import re

from datasets import load_dataset
from transformers import AutoModelForCausalLM, pipeline
from trl import SFTConfig, SFTTrainer

from constants import MODEL_NAME, HUGGINGFACE_TOKEN, CHECKPOINTS, INITIAL_SAVE_PATH
from model.setup import load_model, load_tokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the arithmetic_1dc benchmark (only a "validation" split is used), shuffle it
# with a fixed seed, and carve it 90/10 into train and held-out eval subsets.
dataset = load_dataset(
    "EleutherAI/arithmetic", "arithmetic_1dc", split="validation"
).shuffle(seed=42)
split_point = int(0.9 * len(dataset))
train_dataset, eval_dataset = (
    dataset.select(range(split_point)),
    dataset.select(range(split_point, len(dataset))),
)

# Convert the dataset for SFTTrainer
def combine_text(example):
    """Attach a `text` field holding the question immediately followed by its answer.

    The row is modified in place and the same object is returned, which is the
    shape `datasets.Dataset.map` expects from a per-row function.
    """
    prompt, answer = example["context"], example["completion"]
    example["text"] = prompt + answer
    return example

# Apply the transformation for the pure dataset: the result keeps only the combined `text` column.
sft_train_dataset = (
    train_dataset
    .map(combine_text)
    .remove_columns(["context", "completion"])
)

# Build the toolformer targets: each question is rewritten as a calculator tool call.
def transform_question(input_text):
    """
    Rewrite arithmetic questions into calculator-tool call syntax.

    Spelled-out operators ("plus", "minus", "times") are first swapped for
    their symbols. Then every "Question: What is <expr>?\\nAnswer:" span is
    replaced by "<tool:calculator><expr></tool>". Text that does not match the
    question pattern comes back with only the operator substitution applied.

    Args:
        input_text: The input question text

    Returns:
        The text with each matched question replaced by a calculator tool tag
    """
    # Applied in this order: plus, minus, times.
    word_to_symbol = {
        r'\bplus\b': '+',
        r'\bminus\b': '-',
        r'\btimes\b': '*',
    }
    normalized = input_text
    for word_pattern, symbol in word_to_symbol.items():
        normalized = re.sub(word_pattern, symbol, normalized)

    tool_call = re.sub(
        r'Question: What is (.*?)\?\nAnswer:',
        r'<tool:calculator>\1</tool>',
        normalized,
    )
    # Defensive cleanup: drop a '?' sitting right after a closing tool tag.
    return tool_call.replace('</tool>?', '</tool>')

def _to_tool_example(row):
    """Map a raw arithmetic row to a prompt/completion pair whose target is a calculator call."""
    question = row["context"]
    return {"prompt": question, "completion": transform_question(question)}


sft_train_tool_dataset = train_dataset.map(
    _to_tool_example,
    remove_columns=["context", "completion"],
)

# Preview a handful of converted rows.
print("Transformed dataset:")
for row_idx in range(5):
    print(sft_train_tool_dataset[row_idx])


Transformed dataset:
{'completion': '<tool:calculator>(6 - 6) + 3</tool>', 'prompt': 'Question: What is (6 - 6) + 3?\nAnswer:'}
{'completion': '<tool:calculator>(5 * 6) + 5</tool>', 'prompt': 'Question: What is (5 * 6) + 5?\nAnswer:'}
{'completion': '<tool:calculator>(2 - 2) * 3</tool>', 'prompt': 'Question: What is (2 - 2) * 3?\nAnswer:'}
{'completion': '<tool:calculator>(4 - 2) + 7</tool>', 'prompt': 'Question: What is (4 - 2) + 7?\nAnswer:'}
{'completion': '<tool:calculator>(7 * 5) * 2</tool>', 'prompt': 'Question: What is (7 * 5) * 2?\nAnswer:'}


In [3]:
# Load the pretrained checkpoint once, then deep-copy it so the "pure" and "tool"
# fine-tunes start from identical weights while `model` itself stays untouched
# as the pretrained baseline used in the evaluation cell.
pretrained_path = os.path.join(CHECKPOINTS, "pretrained", INITIAL_SAVE_PATH)

model, metadata = load_model(pretrained_path)
tokenizer = load_tokenizer(pretrained_path)

tool_model = copy.deepcopy(model)
pure_model = copy.deepcopy(model)

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


Loading model from local path: ./checkpoints\pretrained\qwen-initial


In [None]:
# PURE TRAINING
# Plain-SFT baseline: the model learns to emit the numeric answer directly.
# It must train on `sft_train_dataset` (single `text` column = question + answer),
# NOT on the tool dataset; otherwise this "baseline" would also learn to emit
# calculator tool calls and the comparison would be meaningless.
# A dedicated output_dir keeps its checkpoints from colliding with the tool run.
training_args = SFTConfig(output_dir=os.path.join("/tmp", "sft_pure"))

pure_trainer = SFTTrainer(
    pure_model,
    train_dataset=sft_train_dataset,
    args=training_args,
)

pure_trainer.train()

pure_model.save_pretrained(os.path.join(CHECKPOINTS, "pure"))

In [None]:
# TOOLFORMER TRAINING
# Fine-tune on prompt/completion pairs whose completion is a calculator tool call.
# `completion_only_loss=True` asks TRL to compute the loss on completion tokens only.
training_args = SFTConfig(
    output_dir=os.path.join("/tmp", "sft_tool"),
    completion_only_loss=True,
)

# Held-out eval set in the same prompt/completion format as the training data.
# Evaluating on a slice of the *training* set would leak training examples into the metric.
sft_eval_tool_dataset = eval_dataset.map(
    lambda x: {"prompt": x["context"], "completion": transform_question(x["context"])},
    remove_columns=["context", "completion"],
)

# Train the toolformer model
tool_trainer = SFTTrainer(
    tool_model,
    train_dataset=sft_train_tool_dataset,
    # Small eval set (capped at 100 rows, but never more than exist).
    # NOTE(review): with the default eval strategy no evaluation appears to run during
    # train(); call tool_trainer.evaluate() or set `eval_strategy` if periodic eval is wanted.
    eval_dataset=sft_eval_tool_dataset.select(range(min(100, len(sft_eval_tool_dataset)))),
    args=training_args,
)

tool_trainer.train()

tool_model.save_pretrained(os.path.join(CHECKPOINTS, "tool"))

[34m[1mwandb[0m: Currently logged in as: [33mdaniel-chuang[0m ([33mdaniel-chuang-cornell[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss
500,0.0105


In [None]:
# Marker and pattern for the calculator tool call the tool-trained model emits.
_CALCULATOR_TAG = "<tool:calculator>"
_CALCULATOR_CALL_RE = re.compile(r"<tool:calculator>(.*?)</tool>", re.DOTALL)
# Operators the calculator supports; anything else (calls, names, **, /, ...) is rejected.
_CALCULATOR_BINOPS = {ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul}
_CALCULATOR_UNARYOPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}


def _safe_calculate(expression):
    """Evaluate an integer expression built from +, -, *, unary +/- and parentheses.

    Walks the AST instead of calling `eval`, so model output cannot execute code.

    Raises:
        SyntaxError: if `expression` is not valid expression syntax.
        ValueError: if it contains anything but integer literals and supported operators.
    """
    def _eval(node):
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        # `type(...) is int` rejects bools and floats.
        if isinstance(node, ast.Constant) and type(node.value) is int:
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _CALCULATOR_BINOPS:
            return _CALCULATOR_BINOPS[type(node.op)](_eval(node.left), _eval(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in _CALCULATOR_UNARYOPS:
            return _CALCULATOR_UNARYOPS[type(node.op)](_eval(node.operand))
        raise ValueError(f"Unsupported calculator expression: {expression!r}")

    # strip(): compile() in "eval" mode rejects leading whitespace.
    return _eval(ast.parse(expression.strip(), mode="eval"))


def _extract_first_integer(text):
    """Return the first (optionally negative) integer in `text` as a string, or None."""
    match = re.search(r"-?\d+", text)
    return match.group(0) if match else None


def _extract_tool_answer(text):
    """Answer for tool mode: the result of the first calculator call as a string, or None.

    If the model answered without any tool tag, fall back to the first integer so a
    direct numeric answer is still scored. A tag that cannot be evaluated (unclosed,
    e.g. truncated generation, or an unsupported expression) yields None.
    """
    if _CALCULATOR_TAG not in text:
        return _extract_first_integer(text)
    match = _CALCULATOR_CALL_RE.search(text)
    if match is None:
        return None
    try:
        return str(_safe_calculate(match.group(1)))
    except (SyntaxError, ValueError):
        return None


def evaluate_arithmetic(examples, pipe, use_tool, num_examples=100, max_new_tokens=40):
    """Score a text-generation pipeline on arithmetic question/answer pairs.

    Args:
        examples: Iterable of rows with "context" (the prompt) and "completion"
            (the expected answer, possibly padded with whitespace).
        pipe: Callable used like a transformers text-generation pipeline:
            ``pipe(prompt, max_new_tokens=...)`` returning ``[{"generated_text": ...}]``.
            The generated text is assumed to start with the prompt.
        use_tool: If True, the model is expected to emit ``<tool:calculator>EXPR</tool>``.
            The expression is evaluated and its result is compared with the expected
            answer. If False, the first integer in the continuation is the answer.
        num_examples: Maximum number of examples to evaluate.
        max_new_tokens: Generation budget passed to ``pipe`` on every call.

    Returns:
        Accuracy in [0, 1]; 0.0 when no examples were evaluated.
    """
    correct = 0
    total = 0

    for i, example in enumerate(examples):
        if i >= num_examples:
            break

        context = example["context"]
        expected_answer = example["completion"].strip()

        generated = pipe(context, max_new_tokens=max_new_tokens)[0]["generated_text"]
        # The pipeline echoes the prompt; score only the newly generated continuation.
        continuation = generated[len(context):].strip()

        if use_tool:
            generated_answer = _extract_tool_answer(continuation)
        else:
            generated_answer = _extract_first_integer(continuation)
        if generated_answer is None:
            print("--")
            print(f"Failed to extract answer for example {i}: \n{generated}")
            print("--")
            generated_answer = "N/A"

        is_correct = generated_answer == expected_answer
        if is_correct:
            correct += 1
        total += 1

        if i < 5:  # Print first 5 examples
            print(f"Question: {context}")
            print(f"Expected: {expected_answer}")
            print(f"Generated: {generated_answer}")
            print(f"Correct: {is_correct}")
            print("-" * 50)

    if total == 0:
        print("\nAccuracy: n/a (no examples evaluated)")
        return 0.0

    accuracy = correct / total
    print(f"\nAccuracy: {accuracy:.2%} ({correct}/{total})")
    return accuracy

In [None]:
# Reload both fine-tuned models from the checkpoints written by the training cells.
(pure_model, _), (tool_model, _) = (
    load_model(os.path.join(CHECKPOINTS, "pure")),
    load_model(os.path.join(CHECKPOINTS, "tool")),
)

In [None]:
# Create text-generation pipelines for the baseline and both fine-tuned models.
def _build_greedy_pipe(gen_model, max_new_tokens=10):
    """Return a text-generation pipeline that decodes greedily with `gen_model`.

    Greedy decoding is requested with `do_sample=False`. A `temperature=0.0`
    would raise an error if sampling were enabled (e.g. by the checkpoint's
    generation config) and is otherwise ignored.
    """
    return pipeline(
        "text-generation",
        model=gen_model,
        tokenizer=tokenizer,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )


pretrained_pipe = _build_greedy_pipe(model)
pure_pipe = _build_greedy_pipe(pure_model)
tool_pipe = _build_greedy_pipe(tool_model)

# Run evaluation: (label, pipeline, use_tool). The baseline runs are disabled;
# uncomment them to compare against the tool model. evaluate_arithmetic passes
# its own max_new_tokens on every call, overriding the pipeline default above.
eval_runs = [
    # ("Pretrained", pretrained_pipe, False),
    # ("Pure", pure_pipe, False),
    ("Tool", tool_pipe, True),
]
for label, eval_pipe, use_tool in eval_runs:
    accuracy = evaluate_arithmetic(eval_dataset, eval_pipe, use_tool=use_tool, num_examples=100)
    print(f"{label} Model Accuracy: {accuracy:.2%}")