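This notebook shows how to format a dataset with thinking traces for fine-tuning a model like Qwen3: each conversation is rendered with the Qwen3 chat template, tokenized, and masked so that only the final assistant turn contributes to the loss.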

The dataset consists of 1.5k samples from [Salesforce/APIGen-MT-5k](https://huggingface.co/datasets/Salesforce/APIGen-MT-5k), augmented with reasoning traces generated by gpt-oss-120b at high reasoning effort.


In [1]:
from datasets import load_dataset

# Load the dataset and keep only its 'train' split; `ds` is used by every later cell.
ds = load_dataset("nbroad/apigen-with-thinking-1.5k")['train']
# Bare last expression: the cell output is the dataset summary (features and row count).
ds

Dataset({
    features: ['tools', 'messages'],
    num_rows: 1500
})

In [2]:
import html
import json

import ipywidgets as widgets
from IPython.display import display, HTML

def format_message(msg):
    """Render a single chat message as an HTML chat bubble.

    All text taken from the message (role label, content, reasoning, tool-call
    names and arguments) is HTML-escaped before being embedded, so literal
    markup in the data (e.g. ``<tool_call>``) is shown verbatim instead of
    being interpreted by the browser.

    Args:
        msg: Mapping describing one message. The keys read are ``role``,
            ``content``, ``reasoning_content`` and ``tool_calls``; all are
            optional. ``tool_calls`` may be a list of dicts or a JSON string
            encoding such a list.

    Returns:
        str: An HTML fragment containing a role label and a bubble holding, in
        order, the reasoning block (if any), one block per tool call (if any),
        and the message content. If none of those are present the bubble
        contains ``(empty message)``.
    """
    role = msg.get('role', 'unknown')
    content = msg.get('content', '')
    reasoning_content = msg.get('reasoning_content', '')
    tool_calls = msg.get('tool_calls', None)

    # Color scheme for different roles
    colors = {
        'user': {'bg': '#007bff', 'text': 'white', 'align': 'right', 'label': 'Human'},
        'assistant': {'bg': '#e9ecef', 'text': 'black', 'align': 'left', 'label': 'Assistant'},
        'tool': {'bg': '#28a745', 'text': 'white', 'align': 'left', 'label': 'Tool Response'},
        'system': {'bg': '#6c757d', 'text': 'white', 'align': 'left', 'label': 'System'},
        'thinking': {'bg': '#ffc107', 'text': 'black', 'align': 'left', 'label': 'Thinking'}
    }

    # Unknown roles fall back to the neutral "system" look, labelled with the raw role.
    style = colors.get(role, {'bg': '#6c757d', 'text': 'white', 'align': 'left', 'label': role})

    def to_html(text):
        """Escape text for HTML, then turn newlines into <br> tags."""
        # Escape first so the <br> tags we insert are not themselves escaped.
        return html.escape(str(text)).replace('\n', '<br>')

    formatted_parts = []

    # Add thinking/reasoning content if present
    if reasoning_content:
        thinking_style = colors['thinking']
        # Computed outside the f-string: a backslash inside an f-string
        # expression is a SyntaxError before Python 3.12.
        reasoning_html = to_html(reasoning_content)
        formatted_parts.append(f"""
            <div style="
                background-color: {thinking_style['bg']};
                color: {thinking_style['text']};
                padding: 8px 12px;
                border-radius: 12px;
                margin-bottom: 8px;
                font-size: 0.9em;
                font-style: italic;
            ">
                <strong>{thinking_style['label']}:</strong><br>
                {reasoning_html}
            </div>
        """)

    # Add tool calls if present
    if tool_calls:
        # tool_calls might be a string (JSON) or already a list
        if isinstance(tool_calls, str):
            try:
                tool_calls = json.loads(tool_calls)
            except json.JSONDecodeError:
                # Unparseable payload: nothing to render as tool calls.
                tool_calls = None

        if isinstance(tool_calls, list):
            for tool_call in tool_calls:
                if not isinstance(tool_call, dict):
                    continue

                function = tool_call.get('function')
                if not isinstance(function, dict):
                    # Covers a missing key as well as an explicit None.
                    function = {}
                func_name = function.get('name', 'unknown')
                func_args = function.get('arguments', '{}')

                try:
                    # func_args might be a string (JSON) or already a dict
                    if isinstance(func_args, str):
                        func_args_parsed = json.loads(func_args)
                    else:
                        func_args_parsed = func_args
                    func_args_formatted = json.dumps(func_args_parsed, indent=2)
                except (TypeError, ValueError):
                    # Not valid JSON / not JSON-serializable: show the raw value.
                    func_args_formatted = str(func_args)

                func_name_html = html.escape(str(func_name))
                func_args_html = html.escape(func_args_formatted)
                formatted_parts.append(f"""
                    <div style="
                        background-color: #ffc107;
                        color: black;
                        padding: 8px 12px;
                        border-radius: 12px;
                        margin-bottom: 8px;
                        font-size: 0.9em;
                    ">
                        <strong>Function Call: {func_name_html}</strong><br>
                        <pre style='margin: 5px 0; font-size: 0.85em;'>{func_args_html}</pre>
                    </div>
                """)

    # Add main content
    if content:
        formatted_parts.append(to_html(content))

    if not formatted_parts:
        formatted_parts = ['<em>(empty message)</em>']

    bubble_style = f"""
        background-color: {style['bg']};
        color: {style['text']};
        padding: 10px 15px;
        border-radius: 18px;
        margin: 8px 0;
        max-width: 70%;
        word-wrap: break-word;
        display: inline-block;
        text-align: left;
        box-shadow: 0 1px 2px rgba(0,0,0,0.1);
    """

    label_style = """
        font-size: 0.75em;
        font-weight: bold;
        margin-bottom: 4px;
        opacity: 0.8;
    """

    # Only the outer container is aligned; text inside the bubble stays left-aligned.
    if style['align'] == 'right':
        container_style = "text-align: right; margin: 8px 0;"
    else:
        container_style = "text-align: left; margin: 8px 0;"

    label_html = html.escape(str(style['label']))
    body_html = ''.join(formatted_parts)

    return f"""
    <div style="{container_style}">
        <div style="{label_style}">{label_html}</div>
        <div style="{bubble_style}">
            {body_html}
        </div>
    </div>
    """

def display_conversation(index):
    """Build the HTML view of the conversation stored at ``ds[index]``.

    Despite the name, nothing is displayed here: the HTML is returned so the
    caller can put it into a widget.

    Args:
        index: Row index into the module-level dataset ``ds``.

    Returns:
        str: A scrollable HTML container holding an optional "Available Tools"
        box followed by one chat bubble per message (via ``format_message``).

    Raises:
        json.JSONDecodeError: If the row's ``messages`` column is not valid JSON.
    """
    sample = ds[index]
    messages = json.loads(sample['messages'])

    # Parse tools - handle both string and already-parsed formats
    tools = []
    tools_raw = sample.get('tools')
    if isinstance(tools_raw, str):
        try:
            tools = json.loads(tools_raw)
        except json.JSONDecodeError:
            tools = []
    elif isinstance(tools_raw, list):
        tools = tools_raw
    if not isinstance(tools, list):
        # Valid JSON that is not a list (e.g. an object or number) cannot be
        # rendered as a tool list.
        tools = []

    html_content = """
    <div style="
        font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
        padding: 20px;
        background-color: #f8f9fa;
        border-radius: 10px;
        max-height: 600px;
        overflow-y: auto;
    ">
    """

    # Display tools info if available
    if tools:
        tool_names = []
        for tool in tools:
            # Handle case where tool might be a string (shouldn't happen, but be safe)
            if isinstance(tool, str):
                try:
                    tool = json.loads(tool)
                except json.JSONDecodeError:
                    tool_names.append('<li>Invalid tool format</li>')
                    continue

            if not isinstance(tool, dict):
                tool_names.append('<li>Unknown tool format</li>')
                continue

            # Format: {'type': 'function', 'function': {'name': '...', 'description': '...', ...}}
            func_info = tool.get('function', {})
            if not isinstance(func_info, dict):
                tool_names.append('<li>Invalid tool format</li>')
                continue

            # Names/descriptions come from the dataset: escape before embedding in HTML.
            name = html.escape(str(func_info.get('name', 'unknown')))
            description = func_info.get('description', '')
            if description:
                description_html = html.escape(str(description))
                tool_names.append(f'<li><strong>{name}</strong>: {description_html}</li>')
            else:
                tool_names.append(f'<li>{name}</li>')

        tools_html = f"""
        <div style="
            background-color: #fff3cd;
            border: 1px solid #ffc107;
            border-radius: 8px;
            padding: 10px;
            margin-bottom: 15px;
            font-size: 0.9em;
        ">
            <strong>Available Tools ({len(tools)}):</strong>
            <ul style="margin: 5px 0; padding-left: 20px;">
                {''.join(tool_names)}
            </ul>
        </div>
        """
        html_content += tools_html

    html_content += ''.join(format_message(msg) for msg in messages)

    html_content += "</div>"

    return html_content

def create_ui():
    """Assemble the conversation browser: a title, an index slider and an HTML pane.

    Moving the slider re-renders the pane with the conversation at the selected
    dataset index.

    Returns:
        widgets.VBox: The composed widget, ready to be passed to ``display``.
    """
    total = len(ds)

    # HTML pane, pre-filled with the first conversation. Its ``value`` is simply
    # overwritten on every change, so nothing has to be cleared between updates.
    viewer = widgets.HTML(
        value=display_conversation(0),
        layout=widgets.Layout(width='100%'),
    )

    # Index selector; continuous_update=False means the observer fires when the
    # slider is released rather than on every intermediate position.
    picker = widgets.IntSlider(
        value=0,
        min=0,
        max=total - 1,
        step=1,
        description='Index:',
        layout=widgets.Layout(width='500px'),
        continuous_update=False,
    )

    def _refresh(change):
        """Swap the pane content for the newly selected conversation."""
        viewer.value = display_conversation(change['new'])

    picker.observe(_refresh, names='value')

    title = widgets.HTML("<h2>Conversation Viewer</h2>")
    controls = widgets.HBox([picker, widgets.Label(f"Total: {total}")])

    return widgets.VBox([title, controls, viewer])

# Build the viewer widget and render it as this cell's output.
display(create_ui())

VBox(children=(HTML(value='<h2>Conversation Viewer</h2>'), HBox(children=(IntSlider(value=0, continuous_update…

In [3]:
from transformers import AutoTokenizer

# Tokenizer (and its chat template) used by `tokenize` and `visualize_masking` below.
# NOTE(review): loaded from "Qwen/Qwen3-14B" while the fine-tuning job at the end of the
# notebook targets "Qwen/Qwen3-1.7B" — presumably both sizes share the same tokenizer and
# chat template; confirm.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-14B")

In [4]:
# Parse the first sample's JSON-encoded `messages` column for ad-hoc inspection.
# NOTE(review): `x` is reassigned in a later cell before it is used anywhere, so this
# looks like leftover scratch work.
x = json.loads(ds[0]["messages"])



In [5]:
import json


def tokenize(sample):
    """Tokenize one conversation and mask everything but its last message from the loss.

    The conversation is rendered twice with the module-level ``tokenizer``'s
    chat template: once in full, and once without the last message but with the
    generation prompt appended. Tokens belonging to that second rendering (the
    "prefix") get label ``-100``; the remaining tokens keep their ids as labels.

    Args:
        sample: A dataset row whose ``messages`` and ``tools`` columns are
            JSON-encoded strings. The last message is assumed to be the
            assistant turn to train on.

    Returns:
        dict: ``input_ids`` (full conversation), ``attention_mask`` (all ones)
        and ``labels`` (``-100`` over the prefix, token ids afterwards), all of
        the same length.

    Raises:
        ValueError: If the prefix rendering is not an exact token prefix of the
            full rendering, in which case the mask would be misaligned.
    """
    messages = json.loads(sample["messages"])
    # Parsed once and shared by both renderings.
    tools = json.loads(sample["tools"])

    # This is the full conversation.
    # return_dict=False: we need a plain list of token ids regardless of the
    # library's default return type.
    full_ids = tokenizer.apply_chat_template(
        messages, tokenize=True, tools=tools, return_dict=False
    )

    # This contains everything except the last assistant message
    prefix_ids = tokenizer.apply_chat_template(
        messages[:-1],
        tokenize=True,
        tools=tools,
        add_generation_prompt=True,
        return_dict=False,
    )

    # The masking below is only correct if the full rendering really starts
    # with the prefix rendering; a length comparison alone cannot detect a
    # template that renders earlier turns differently in the two calls.
    prefix_len = len(prefix_ids)
    if list(full_ids[:prefix_len]) != list(prefix_ids):
        raise ValueError(
            "Prefix tokens are not a prefix of the full conversation tokens; "
            "cannot build a loss mask for the last message."
        )

    # Masking all tokens from loss except for the last assistant message
    labels = [-100] * prefix_len + list(full_ids[prefix_len:])

    assert len(full_ids) == len(labels)

    return {
        "input_ids": full_ids,
        "attention_mask": [1] * len(full_ids),
        "labels": labels
    }


# Apply `tokenize` to every row, one sample per call (batched=False), using 16 worker
# processes. The original `tools` and `messages` columns are still present afterwards;
# they are dropped in a later cell.
tokenized_ds = ds.map(tokenize, batched=False, num_proc=16)

Map (num_proc=16):   0%|          | 0/1500 [00:00<?, ? examples/s]

In [6]:
def visualize_masking(input_ids, labels, tokens=None):
    """
    Visualize token masking using ANSI colors.

    Prints a short legend followed by all tokens concatenated on one line,
    each wrapped in red (label is -100, i.e. masked from the loss) or green
    (any other label, i.e. trained on).

    Args:
        input_ids: List of token IDs
        labels: List of labels (-100 for masked, token_id for unmasked)
        tokens: Optional list of token strings, one per label. When omitted,
            each entry of ``input_ids`` is decoded individually with the
            module-level ``tokenizer``.

    Raises:
        ValueError: If the number of tokens and labels differ (the coloring
            would otherwise be silently truncated to the shorter one).
    """
    # ANSI color codes
    RED = '\033[31m'      # Masked tokens
    GREEN = '\033[32m'    # Unmasked tokens
    RESET = '\033[0m'     # Reset color
    BOLD = '\033[1m'      # Bold text

    if tokens is None:
        # Decode one id at a time so every token can be colored on its own.
        tokens = [tokenizer.decode(token_id) for token_id in input_ids]

    if len(tokens) != len(labels):
        raise ValueError(
            f"Got {len(tokens)} tokens but {len(labels)} labels; they must align one-to-one."
        )

    print(f"\n{BOLD}Token Masking Visualization{RESET}")
    print(f"{RED}Red = Masked (-100){RESET}, {GREEN}Green = Unmasked (tokens to train on){RESET}")
    print("-" * 80)

    colored_tokens = [
        f"{RED if label == -100 else GREEN}{token}{RESET}"
        for token, label in zip(tokens, labels)
    ]

    # Nothing beyond the legend is printed for an empty sequence.
    if colored_tokens:
        print(''.join(colored_tokens))

In [7]:
# Pick one tokenized example to inspect. The fixed seed keeps the selected sample
# (and therefore this cell's output) the same across kernel restarts.
x = tokenized_ds.shuffle(seed=42)[0]

visualize_masking(x["input_ids"], x["labels"])


[1mToken Masking Visualization[0m
[31mRed = Masked (-100)[0m, [32mGreen = Unmasked (tokens to train on)[0m
--------------------------------------------------------------------------------
[31m<|im_start|>[0m[31msystem[0m[31m
[0m[31m#[0m[31m Retail[0m[31m agent[0m[31m policy[0m[31m
[0m[31mAs[0m[31m a[0m[31m retail[0m[31m agent[0m[31m,[0m[31m you[0m[31m can[0m[31m help[0m[31m users[0m[31m cancel[0m[31m or[0m[31m modify[0m[31m pending[0m[31m orders[0m[31m,[0m[31m return[0m[31m or[0m[31m exchange[0m[31m delivered[0m[31m orders[0m[31m,[0m[31m modify[0m[31m their[0m[31m default[0m[31m user[0m[31m address[0m[31m,[0m[31m or[0m[31m provide[0m[31m information[0m[31m about[0m[31m their[0m[31m own[0m[31m profile[0m[31m,[0m[31m orders[0m[31m,[0m[31m and[0m[31m related[0m[31m products[0m[31m.
[0m[31m-[0m[31m At[0m[31m the[0m[31m beginning[0m[31m of[0m[31m the[0m[31m conversation[0m

In [9]:
# Keep only the columns needed for training and drop everything else
# (the raw `tools` / `messages` strings).
cols2keep = ["input_ids", "labels", "attention_mask"]
cols = tokenized_ds.column_names
# Pass an ordered list (not a set): remove_columns is documented to take a
# column name or a list of column names.
final_ds = tokenized_ds.remove_columns([col for col in cols if col not in cols2keep])

# 90/10 train/test split. The fixed seed makes the split reproducible, so a
# re-run produces the same training and evaluation sets.
split_final_ds = final_ds.train_test_split(test_size=0.1, seed=42)


# NOTE(review): the file names say "5k" although the dataset loaded above has 1,500 rows —
# presumably a reference to the APIGen-MT-5k source dataset; confirm before renaming.
train_filename = "apigen-with-thinking-qwen3-5k-train.parquet"
test_filename = "apigen-with-thinking-qwen3-5k-test.parquet"

# Write both splits to parquet files in the current working directory.
split_final_ds["train"].to_parquet(train_filename)
split_final_ds["test"].to_parquet(test_filename)

Creating parquet from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

8792283

In [10]:
from together import Together
import os
from dotenv import load_dotenv

# Read environment variables from the .env file one directory above the notebook.
# load_dotenv's return value is used as the success flag; fail loudly instead of
# continuing without credentials.
# NOTE(review): override=True presumably lets values from the file replace variables
# that are already set in the environment — confirm against python-dotenv.
loaded = load_dotenv("../.env", override=True)
if not loaded:
    raise ValueError("Failed to load .env file")

# The API key is read from the environment (never hard-coded); a missing
# TOGETHER_API_KEY raises KeyError here.
together_client = Together(api_key=os.environ["TOGETHER_API_KEY"])

In [11]:
from together.utils import check_file

# Run Together's file check on both parquet files before uploading; the cell output is
# the tuple of the two check reports (see `is_check_passed` / `num_samples`).
check_file(train_filename), check_file(test_filename)

({'is_check_passed': True,
  'message': 'Checks passed',
  'found': True,
  'file_size': 2925857,
  'utf8': None,
  'line_type': None,
  'text_field': None,
  'key_value': None,
  'has_min_samples': None,
  'num_samples': 1350,
  'load_json': None,
  'load_csv': None,
  'filetype': 'parquet'},
 {'is_check_passed': True,
  'message': 'Checks passed',
  'found': True,
  'file_size': 505340,
  'utf8': None,
  'line_type': None,
  'text_field': None,
  'key_value': None,
  'has_min_samples': None,
  'num_samples': 150,
  'load_json': None,
  'load_csv': None,
  'filetype': 'parquet'})

In [13]:
# Upload both parquet files to Together. The returned objects carry the file ids
# that the fine-tuning job below refers to.
train_upload_details = together_client.files.upload(train_filename)
test_upload_details = together_client.files.upload(test_filename)

# Print the upload metadata (file id, filename, size, ...) for reference.
print(train_upload_details)
print(test_upload_details)

Uploading file apigen-with-thinking-qwen3-5k-train.parquet: 100%|██████████| 2.93M/2.93M [00:00<00:00, 5.00MB/s]
Uploading file apigen-with-thinking-qwen3-5k-test.parquet: 100%|██████████| 505k/505k [00:00<00:00, 2.38MB/s]


id='file-f23a174f-a6e1-4709-9e95-9234b6b2d721' object='file' created_at=1770110347 type=None purpose=<FilePurpose.FineTune: 'fine-tune'> filename='apigen-with-thinking-qwen3-5k-train.parquet' bytes=2925857 line_count=0 processed=True FileType='parquet' project_id='proj_CL6JHwQ7zDQ7q5A3AeQuS' organization_id='org_CL6JHwNPVGzWNHJUKXcFi'
id='file-7c0aba7b-a33a-4757-b9a3-534a672fb2c2' object='file' created_at=1770110350 type=None purpose=<FilePurpose.FineTune: 'fine-tune'> filename='apigen-with-thinking-qwen3-5k-test.parquet' bytes=505340 line_count=0 processed=True FileType='parquet' project_id='proj_CL6JHwQ7zDQ7q5A3AeQuS' organization_id='org_CL6JHwNPVGzWNHJUKXcFi'


In [14]:
# Launch a LoRA fine-tuning job on Together using the uploaded train/validation files.
# NOTE(review): the base model here is "Qwen/Qwen3-1.7B", while the data was tokenized
# with the "Qwen/Qwen3-14B" tokenizer — presumably identical across Qwen3 sizes; confirm.
ft_resp = together_client.fine_tuning.create(
    training_file=train_upload_details.id,
    validation_file=test_upload_details.id,
    model="Qwen/Qwen3-1.7B",
    n_epochs=1,
    n_evals=1,
    n_checkpoints=1,
    lora=True,  # LoRA adapter training, configured by lora_r / lora_alpha below
    lora_r=16,
    lora_alpha=32,
    warmup_ratio=0.1,
    learning_rate=5e-5,
    # NOTE(review): "5k" in the suffix although the dataset has 1,500 rows — presumably
    # refers to the APIGen-MT-5k source; confirm.
    suffix="apigen-5k-with-thinking",
    # NOTE(review): the uploaded files already contain a `labels` column with -100 on the
    # masked prefix; confirm how "auto" interacts with pre-tokenized data.
    train_on_inputs="auto",
    # .get() returns None when WANDB_API_KEY is not set in the environment.
    wandb_api_key=os.environ.get("WANDB_API_KEY"),
    wandb_project_name="apigen-ft"
)