In [1]:
import json
from typing import List, Dict, Union, Optional
import random
from datasets import Dataset

In [2]:
from google.colab import files
uploaded = files.upload()

Saving tool invocation.json to tool invocation.json


In [3]:
# Load the dataset
with open('tool invocation.json', 'r') as f:
    data= json.load(f)

## PREPROCESSING

In [4]:
import re
def extract_and_normalize_schemas(system_text: str) -> List[Dict]:
    """
    arguments:system_text: raw text from the "system" field containing json schemas
    returns: list of normalized function schemas (empty list if none found)
    """
    schemas= []
    #extract all json objects using regex
    json_objects= re.findall(r'\{\s*".*?"\s*:\s*\{.*?\}\s*\}(?=\s*\{|\s*$|\,)', system_text, re.DOTALL)  #matches json objects till the first '}' and valid boundaries (verified by the lookahead function)

    for obj in json_objects:
        try:
            obj= obj.strip()  #removes possible whitespaces after '}' or before '{'
            obj= re.sub(r'\}\s*\{', '},{', obj)  #add missing commas
            obj = re.sub(r"(?<=\:|\,)\s*'", '"', obj)
            obj= re.sub(r'(?<!\\)"(?!\s*[:}\],])', r'\"', obj)  #escape quotes (only those which are not already escaped or followed by ':' or '}' or ']' or ',')

            #parse json (try as array if multiple objects)
            try:
                parse_data= json.loads(obj)
            except json.JSONDecodeError:
                if re.search(r'^\s*\{.*\}\s*\{', obj):
                    parse_data= json.loads(f'[{obj}]')
                else:
                    raise

            #normalize to list format
            functions= parse_data if isinstance(parse_data, list) else [parse_data]

            #standardize each schema
            for func in functions:
                if not isinstance(func, dict):
                    continue
                normalized= {
                    'name': str(func.get('name', '')),
                    'description': str(func.get('description', '')),
                    'parameters': {
                        'type': 'object',    #ensures all schemas declare parameters as key-value pairs
                        'properties': func.get('parameters', {}).get('properties', {}),
                        'required': sorted(func.get('parameters', {}).get('required', []))
                    }
                }
                schemas.append(normalized)

        except (json.JSONDecodeError, AttributeError):
            continue

    return schemas

In [5]:
def extract_user_query(chat_text: str) -> Optional[str]:
    """
    extract the user query from chat text.
    arguments: chat_text:raw text from "chat" field
    returns: extracted user query (none if not found)
    """
    for line in chat_text.split('\n'):
        if line.startswith("USER:"):
            return line[len("USER:"):].strip()
    return None


In [6]:
def format_prompt(functions: List[Dict], user_query: str) -> str:
    """
    arguments: functions: list of normalized function schemas, user_query: extracted user question/command
    returns: formatted prompt string
    """
    system_part= json.dumps(functions) if functions else "No functions available"    #converts list to json string (if list not empty)
    return f"<system>{system_part}</system>\n<user>{user_query}</user>"

In [7]:
def prepare_target(assistant_text: str) -> Union[Dict, str]:
    #match function call pattern
    func_call= re.search(
        r'<functioncall>\s*({.*?})\s*<\|endoftext\|>',
        assistant_text,
        re.DOTALL
    )

    if not func_call:
        #return clean text response if no function call found
        return assistant_text.split("ASSISTANT:")[-1].split("<|endoftext|>")[0].strip()

    try:
        #extract and deep clean JSON string
        json_str= func_call.group(1)

        #emove all control characters except \t, \n, \r
        json_str = ''.join(char for char in json_str if char == '\t' or char == '\n' or char == '\r' or ord(char) >= 32)

        #Fix common JSON issues
        json_str = (json_str
                   .replace("'", '"')  #Convert single to double quotes
                   .replace('\\n', ' ')  #Replace newlines with spaces
                   .strip())

        #Parse with multiple fallback strategies
        try:
            #first try standard parsing
            func_call = json.loads(json_str)
        except json.JSONDecodeError as e:
            #try fixing common malformations
            try:
                #handle unquoted property names
                json_str = re.sub(r'([{,]\s*)([a-zA-Z_][a-zA-Z0-9_]*)\s*:', r'\1"\2":', json_str)
                func_call = json.loads(json_str)
            except:
                #final fallback: manual extraction
                name_match = re.search(r'"name"\s*:\s*"([^"]+)"', json_str) or \
                            re.search(r'name\s*:\s*"([^"]+)"', json_str)

                args_match = re.search(r'"arguments"\s*:\s*(\{.*?\}|"[^"]*")', json_str, re.DOTALL) or \
                            re.search(r'arguments\s*:\s*(\{.*?\}|"[^"]*")', json_str, re.DOTALL)

                if not name_match:
                    raise ValueError("Could not extract function name")

                func_call = {"name": name_match.group(1).strip('"\'')}

                if args_match:
                    args_str = args_match.group(1)
                    if args_str.startswith('{'):
                        try:
                            func_call["arguments"] = json.loads(args_str)
                        except:
                            #try cleaning the arguments JSON
                            args_str = args_str.replace("'", '"')
                            args_str = re.sub(r'([{,]\s*)([a-zA-Z_][a-zA-Z0-9_]*)\s*:', r'\1"\2":', args_str)
                            try:
                                func_call["arguments"] = json.loads(args_str)
                            except:
                                func_call["arguments"] = {}
                    else:
                        func_call["arguments"] = {}

        #validate and normalize the structure
        if not isinstance(func_call, dict):
            raise ValueError("Function call is not a dictionary")

        if "name" not in func_call:
            raise ValueError("Missing function name")

        #ensure arguments is always a dict
        arguments = func_call.get("arguments", {})
        if isinstance(arguments, str):
            try:
                arguments = json.loads(arguments.replace("'", '"'))
            except:
                arguments = {}

        return {
            "name": str(func_call["name"]).strip(),
            "arguments": arguments if isinstance(arguments, dict) else {}
        }

    except Exception as e:
        print(f"Invalid function call (recovered): {str(e)}")
        # Return clean text response as fallback
        return assistant_text.split("ASSISTANT:")[-1].split("<|endoftext|>")[0].strip()

In [8]:
def process_dataset(original_data: List[Dict]) -> List[Dict]:
    processed = []

    for entry in original_data:
        try:
            functions = extract_and_normalize_schemas(entry["system"])
            user_query = extract_user_query(entry["chat"])
            if not user_query:
                continue

            prompt = format_prompt(functions, user_query)
            target = prepare_target(entry["chat"])

            #create consistent structure
            is_function_call = isinstance(target, dict)
            processed.append({
                "input": prompt,
                "output": target,
                "is_function_call": is_function_call  # At root level
            })

        except Exception as e:
            print(f"Skipping entry: {e}")

    return processed


In [9]:
def balance_data(dataset: List[Dict]) -> List[Dict]:
    tool_calls = [ex for ex in dataset if ex.get("is_function_call", False)]
    plain_replies = [ex for ex in dataset if not ex.get("is_function_call", False)]

    print(f"Tool calls: {len(tool_calls)}, Plain replies: {len(plain_replies)}")

    if not tool_calls or not plain_replies:
        print("Warning: Could not balance - returning all data")
        return dataset

    min_len = min(len(tool_calls), len(plain_replies))
    return tool_calls[:min_len] + plain_replies[:min_len]

In [10]:
additional_examples = [
    {
        "system": "SYSTEM: You are a helpful assistant.\n[]",
        "chat": "USER: Hi, how are you?\nASSISTANT: I'm doing well, thanks! <|endoftext|>"
    },
    {
        "system": "SYSTEM: You are a helpful assistant.\n[]",
        "chat": "USER: What's 2+2?\nASSISTANT: The answer is 4. <|endoftext|>"
    }
]
data.extend(additional_examples)

In [11]:
def create_dataset(balanced_data: List[Dict]) -> Dataset:
    return Dataset.from_list([
        {
            "input": x["input"],
            "output": json.dumps(x["output"]) if isinstance(x["output"], dict) else x["output"],
            "is_function_call": x["is_function_call"]
        }
        for x in balanced_data
    ])


In [12]:
processed = process_dataset(data)
balanced = balance_data(processed)
train_data = create_dataset(balanced)

Tool calls: 63212, Plain replies: 49750


In [13]:
print(f"Raw data count: {len(data)}")
print(f"Processed data count: {len(processed)}")
print(f"Balanced data count: {len(balanced)}")
print(f"Final train_dataset size: {len(train_data)}")

Raw data count: 112962
Processed data count: 112962
Balanced data count: 99500
Final train_dataset size: 99500
