In [23]:
# ! pip install datasets
# ! pip install transformers
# ! pip install accelerate

In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from datasets import Dataset, DatasetDict
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load dataset
from experiments.model_training.prepare_dataset import OpenAICompletionDataPoint
dataset_path = "data/test_full-v1.jsonl"
dataset = []
with open(dataset_path, "r") as f:
    for line in f:
        datapoint = OpenAICompletionDataPoint.model_validate_json(line)
        dataset.append(datapoint)
print(len(dataset))

[32m2025-08-01 21:22:54.757[0m | [1mINFO    [0m | [36mtau2.utils.utils[0m:[36m<module>[0m:[36m27[0m - [1mUsing data directory from source: /lambda/nfs/victor-north-tx/tau2-bench-private/data[0m
[32m2025-08-01 21:22:55.799[0m | [1mINFO    [0m | [36mtau2.utils.llm_utils[0m:[36m<module>[0m:[36m65[0m - [1mLiteLLM: Cache is disabled[0m
[32m2025-08-01 21:22:55.893[0m | [34m[1mDEBUG   [0m | [36mtau2.registry[0m:[36m<module>[0m:[36m194[0m - [34m[1mRegistering default components...[0m
[32m2025-08-01 21:22:55.893[0m | [34m[1mDEBUG   [0m | [36mtau2.registry[0m:[36m<module>[0m:[36m236[0m - [34m[1mDefault components registered successfully. Registry info: {
  "domains": [
    "mock",
    "airline",
    "retail",
    "telecom",
    "telecom-workflow"
  ],
  "agents": [
    "llm_agent",
    "llm_agent_gt",
    "llm_agent_solo",
    "llm_agent_completion",
    "llm_agent_gt_completion",
    "llm_agent_solo_completion"
  ],
  "users": [
    "user_simul

1107


In [3]:
dp0: OpenAICompletionDataPoint = dataset[0]
print(type(dp0))
print(dp0.parallel_tool_calls)
print(len(dp0.messages))
print(len(dp0.tools))
for i, message in enumerate(dp0.messages):
    if message["role"] in ["system", "tool"]:
        continue
    print(f"{i} {message['role']}: {message.get('content', 'tool call...')}")

<class 'experiments.model_training.prepare_dataset.OpenAICompletionDataPoint'>
True
18
14
1 assistant: Hi! How can I help you today?
2 user: Hi, I need to change my upcoming flight from JFK. My cat is really sick and I need to get home sooner. Can you help me look into changing my flight to an earlier nonstop option?
3 assistant: To assist you with changing your flight, I need to know your user ID and the reservation ID for the flight you want to change. Could you please provide those? If you don't know your reservation ID, I can help you locate it.
4 user: I don’t have my reservation ID handy, but my user ID is daiki_lee_6144. Can you look up my reservation with that?
5 assistant: tool call...
7 assistant: You have three reservations under your user ID daiki_lee_6144: DF89BM, COVE6R, and IIHXDG.

Could you please specify which reservation you want to change? If you are not sure, you can provide more details about the flight such as the destination or date, and I can help identify the 

In [17]:
# Load model and tokenizer

# Model checkpoint
MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"
# MODEL_NAME = "Qwen/Qwen2.5-0.5B"

def get_model_and_tokenizer(model_name, torch_dtype="auto", device_map="auto"):
    # Load tokenizer and model (4-bit optional)
    print(f"Loading tokenizer and model for {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.chat_template is None:
        raise ValueError(f"Tokenizer for model {model_name} does not have a chat template.")
    else:
        print("Tokenizer has a chat template.")
    print(f"Tokenizer eos_token_id: {tokenizer.eos_token_id}, pad_token_id: {tokenizer.pad_token_id}")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch_dtype, # Automatically determines the best dtype (e.g., float16, bfloat16)
        device_map=device_map   # Automatically distributes the model across available devices (e.g., GPUs)
    )
    return model, tokenizer


model, tokenizer = get_model_and_tokenizer(MODEL_NAME, torch_dtype="auto", device_map="auto")

Loading tokenizer and model for Qwen/Qwen2.5-3B-Instruct
Tokenizer has a chat template.
Tokenizer eos_token_id: 151645, pad_token_id: 151643


Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.30it/s]


In [None]:
# Load gemnerate response function
# GPU_BACKEND = "cuda" if torch.cuda.is_available() else "cpu"
# GPU_BACKEND = "mps" if torch.cuda.is_available() else "cpu"

def generate_response_qwen(messages, tools, tokenizer, model, temperature=0.5, add_generation_prompt=True, enable_thinking=False, max_new_tokens=32768, verbose=False):
    """
    Generate completion response for a given prompt.
    Args:
        messages: List of messages in the chat format.
        tools: List of tools to use for the completion.
        tokenizer: Tokenizer to use for the completion.
        model: Model to use for the completion.
        max_new_tokens: Maximum number of new tokens to generate.
    """
    prompt = tokenizer.apply_chat_template(messages, tools=tools, tokenize=False, add_generation_prompt=add_generation_prompt, enable_thinking=enable_thinking)
    if verbose:
        print(f"Prompt:\n{prompt}")
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # inputs.pop('token_type_ids', None)
    with torch.no_grad():
        outputs = model.generate(**inputs, 
                                  max_new_tokens=max_new_tokens,
                                  temperature=temperature
                                  )
    input_length = inputs.input_ids.shape[1]
    generated_token_ids = outputs[0][input_length+1:]
    # parsing thinking content
    try:
        # rindex finding 151668 (</think>)
        index = len(generated_token_ids) - generated_token_ids[::-1].index(151668)
    except ValueError:
        index = 0

    thinking_content = tokenizer.decode(generated_token_ids[:index], skip_special_tokens=True).strip("\n")
    content = tokenizer.decode(generated_token_ids[index:], skip_special_tokens=True).strip("\n")
    if verbose:
        print(f"Thinking content: {thinking_content}")
        print(f"Content: {content}")
    return content, thinking_content


test_messages = [
    {"role": "user", "content": "What is the capital of France?"}
]
test_tools = None
test_response, test_thinking_content = generate_response_qwen(test_messages, test_tools, tokenizer, model, add_generation_prompt=True, enable_thinking=True)
print("thinking content:", test_thinking_content)
print("content:", test_response)

# Parse tool calls
import json
import re

def parse_tool_calls_qwen(response: str) -> list[dict]:
    """
    Parse tool calls from a response.
    System instructions in Qwen2.5:
    ```
    For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
    <tool_call>
    {"name": <function-name>, "arguments": <args-json-object>}
    </tool_call>
    ```
    """
    pattern = r"<tool_call>(.*?)</tool_call>"
    matches = re.findall(pattern, response, re.DOTALL)
    tool_calls = []
    
    for match in matches:
        # Process each line of the tool call content
        lines = match.strip().split('\n')
        for line in lines:
            line = line.strip()
            if line:  # Skip empty lines
                try:
                    tool_call = json.loads(line)
                    tool_calls.append(tool_call)
                except json.JSONDecodeError:
                    continue  # Skip lines that aren't valid JSON
    
    return tool_calls

test_response = "I'm thinking about the capital of France. <tool_call>{\"name\": \"get_capital\", \"arguments\": {\"country\": \"France\"}}</tool_call>"
tool_calls = parse_tool_calls_qwen(test_response)
print(tool_calls)



    


Prompt:
<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
What is the capital of France?<|im_end|>
<|im_start|>assistant

thinking content: 
content:  capital of France is Paris.
[{'name': 'get_capital', 'arguments': {'country': 'France'}}]


In [None]:
def create_partial_trajectories(datapoint: OpenAICompletionDataPoint) -> list[OpenAICompletionDataPoint]:
    messages = datapoint.messages
    tools = datapoint.tools
    partial_trajectories = []
    for i, message in enumerate(messages):
        if message["role"] == "assistant":

            partial_trajectories.append(OpenAICompletionDataPoint(messages=messages[:i+1], tools=tools, parallel_tool_calls=datapoint.parallel_tool_calls))
    return partial_trajectories

partial_trajectories = create_partial_trajectories(dataset[0])

def replay(datapoint):
    partial_trajectories = create_partial_trajectories(datapoint)
    for i, partial_trajectory in enumerate(partial_trajectories[-2:]):
        original_response = partial_trajectory.messages[-1]
        convo_history = partial_trajectory.messages[:-1]
        replay_response, thinking_content = generate_response_qwen(convo_history, partial_trajectory.tools, tokenizer, model, add_generation_prompt=True, enable_thinking=True, verbose=False)
        print(f"Partial trajectory {i} (number of messages: {len(partial_trajectory.messages)}):")
        print(f"Original response: {original_response['content']}")
        print(f"Thinking content: {thinking_content}")
        print(f"Replay response: {replay_response}")
        tool_calls = parse_tool_calls_qwen(replay_response)
        print(f"Tool calls: {tool_calls}")
        return

replay(dataset[0])

Prompt:
<|im_start|>system
<instructions>
You are a customer service agent that helps the user according to the <policy> provided below.
During each turn you can either:
- Send a message to the user.
- Make a tool call.
IMPORTANT: You cannot do both at the same time!!
If you send text content while making a tool call, the agent will raise an error.
Text content will be sent to the user. Do not use this field for your own reasoning.

Try to be helpful and always follow the policy. Always make sure you generate valid JSON only.
</instructions>
<policy>
# Airline Agent Policy

The current time is 2024-05-15 15:00:00 EST.

As an airline agent, you can help users **book**, **modify**, or **cancel** flight reservations. You also handle **refunds and compensation**.

Before taking any actions that update the booking database (booking, modifying flights, editing baggage, changing cabin class, or updating passenger information), you must list the action details and obtain explicit user confirma

DatasetDict({
    train: Dataset({
        features: ['messages', 'tools', 'task_id', 'simulation_id', 'turn_idx'],
        num_rows: 74
    })
    test: Dataset({
        features: ['messages', 'tools', 'task_id', 'simulation_id', 'turn_idx'],
        num_rows: 40
    })
})


In [None]:
from trl import SFTConfig, SFTTrainer
from transformers import AutoTokenizer, AutoModelForCausalLM, EarlyStoppingCallback
from peft import LoraConfig, TaskType, get_peft_model

MODEL_NAME = "Qwen/Qwen2.5-7B"

USE_PEFT = True
if USE_PEFT:
    learning_rate = 1e-4 # Higher learning rate for PEFT?
else:
    learning_rate = 8e-5

sft_config = SFTConfig(
    assistant_only_loss=True,                # Only compute the loss on the assistant messages
    report_to="none",                        # disable logging to W&B
    logging_strategy="steps",
    learning_rate=learning_rate,                      # Learning rate for training. 
    num_train_epochs=20,                     #  Set the number of epochs to train the model.
    per_device_train_batch_size=2,           # Batch size for each device (e.g., GPU) during training. 
    gradient_accumulation_steps=8,           # Number of steps before performing a backward/update pass to accumulate gradients.
    gradient_checkpointing=True,             # Enable gradient checkpointing to reduce memory usage during training at the cost of slower training speed.
    logging_steps=2,                         # Frequency of logging training progress (log every 2 steps).
    eval_strategy="epoch",                   # evaluate at end of each epoch
    save_strategy="epoch",                   # save checkpoint at end of each epoch
    save_total_limit=1,                      # keep only the best/latest model
    load_best_model_at_end=True,             # load best model according to eval loss
    metric_for_best_model="eval_loss",       # use eval loss for best model selection
    greater_is_better=False,                 # lower eval_loss is better
    output_dir="./SFTcheckpoints"               # directory to save checkpoints
)


# Instantiate early stopping callback
early_stopping_callback = EarlyStoppingCallback(
    early_stopping_patience=2  # Stop if no improvement for 2 evals (epochs)
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True)

if USE_PEFT: # FIXME: Check what's the right config.
    # lora_config = LoraConfig(
    #     r=64,
    #     lora_alpha=16,
    #     target_modules=["c_attn", "q_proj", "v_proj"],  # adjust to Qwen architecture
    #     lora_dropout=0.05,
    #     bias="none",
    #     task_type=TaskType.CAUSAL_LM,
    # )
    lora_config = LoraConfig()
else:
    lora_config = None

sft_trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=sft_dataset["train"],
    processing_class=tokenizer,
    callbacks=[early_stopping_callback],
    peft_config=lora_config 
    
)

