##### Copyright 2025 Laurent Brusa

In [1]:

#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Workshop: Play with a Hugging Face model (Part 1)

<a target="_blank" href="https://colab.research.google.com/drive/1AFSZrjC5aMhtxnbYWh3RYOSjdxd1iPjZ?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


In [7]:
!rm -rf /content/playing-with-a-huggingface-model
!git clone https://github.com/multitudes/playing-with-a-huggingface-model.git
# Create the target directory if it doesn't exist
!mkdir -p /content/exercise_input/
# Copy the exercise inputs and the tokenizer merges file into /content
!cp playing-with-a-huggingface-model/exercise_input/function_calling_tests.json /content/exercise_input/
!cp playing-with-a-huggingface-model/exercise_input/functions_definition.json /content/exercise_input/
!cp playing-with-a-huggingface-model/merges.txt /content

Cloning into 'playing-with-a-huggingface-model'...
remote: Enumerating objects: 19, done.
remote: Counting objects: 100% (19/19), done.
remote: Compressing objects: 100% (16/16), done.
remote: Total 19 (delta 3), reused 14 (delta 2), pack-reused 0 (from 0)
Receiving objects: 100% (19/19), 711.95 KiB | 2.26 MiB/s, done.
Resolving deltas: 100% (3/3), done.


# Small_LLM Class Documentation

Utility class wrapping a lightweight Hugging Face causal language model for fast, 
low-memory experimentation and inference.

This class provides a simplified interface for loading and using causal language models
from Hugging Face, with automatic device selection, optimized memory usage, and utility
methods for tokenization and inference.

## Parameters

- **model_name** : `str`, default=`"Qwen/Qwen3-0.6B"`
  - Identifier of the model on the 🤗 Hub. Should be a valid causal language model that can be loaded with AutoModelForCausalLM.

- **device** : `str | None`, default=`None`
  - Computation device. If None, automatically selects device with priority: MPS (macOS) > CUDA (GPU) > CPU.

- **dtype** : `torch.dtype | None`, default=`None`
  - Numerical precision for model weights. If None, defaults to float16 for GPU/MPS devices to conserve memory, and float32 for CPU for compatibility.

- **trust_remote_code** : `bool`, default=`True`
  - Whether to allow custom code from the model repository. Required for some models that include custom modeling code.

## Attributes

- **_model_name** : `str`
  - The name/identifier of the loaded model.

- **_device** : `str`
  - The device where the model is loaded.

- **_dtype** : `torch.dtype`
  - The data type used for model weights.

- **_tokenizer** : `PreTrainedTokenizer`
  - The tokenizer associated with the model.

- **_model** : `PreTrainedModel`
  - The loaded causal language model in evaluation mode.

## Methods

- **get_logits_from_input_ids**(`input_ids: list[int]`) → `list[float]`
  - Get raw logits for the next token given a sequence of input token IDs.

- **get_path_to_vocabulary_json**() → `str`
  - Download and return the path to the model's vocabulary JSON file.

## Notes

- The model is automatically set to evaluation mode and gradients are disabled for all parameters to optimize for inference.
- If the tokenizer lacks a pad token, the EOS token is used as the pad token.
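- For CUDA devices, the model is loaded with `device_map="auto"`, which lets Accelerate distribute the weights across the available GPUs (this requires the `accelerate` package).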

## Examples

```python
# Uses default Qwen model
llm = Small_LLM()

# Use a specific model with CPU device
llm = Small_LLM("microsoft/DialoGPT-medium", device="cpu")

# Get logits for next token prediction
logits = llm.get_logits_from_input_ids([1, 2, 3, 4])
```
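`get_logits_from_input_ids` returns raw logits: one unnormalized score per vocabulary entry. To inspect them as probabilities, a softmax can be applied. The sketch below is a hypothetical helper (`top_k_probs` is not part of `Small_LLM`) that subtracts the maximum logit before exponentiating, for numerical stability, and returns the `k` most likely token IDs.

```python
import math


def top_k_probs(logits: list[float], k: int = 5) -> list[tuple[int, float]]:
    """Return the k most likely token IDs with their softmax probabilities."""
    # Subtract the max logit before exponentiating to avoid overflow
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Sort (token_id, probability) pairs, highest probability first
    ranked = sorted(enumerate(probs), key=lambda p: p[1], reverse=True)
    return ranked[:k]


# Toy logits for a 4-token vocabulary
print(top_k_probs([2.0, 1.0, 0.1, 2.0], k=2))
```

With a loaded model, something like `top_k_probs(llm.get_logits_from_input_ids(ids))` should list the five most likely next tokens. Softmax preserves the ordering of the logits, so the first entry is the token a greedy argmax would pick.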

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer, PreTrainedModel, logging
from huggingface_hub import hf_hub_download

logging.set_verbosity_info()


class Small_LLM:
    def __init__(
        self,
        model_name: str = "Qwen/Qwen3-0.6B",
        *,
        device: str | None = None,
        dtype: torch.dtype | None = None,
        trust_remote_code: bool = True,
    ) -> None:
        self._model_name = model_name

        # Auto-select device with priority: mps > cuda > cpu
        if device is None:
            if torch.backends.mps.is_available():
                device = "mps"
            elif torch.cuda.is_available():
                device = "cuda"
            else:
                device = "cpu"
        self._device = device

        if dtype is None:
            dtype = torch.float16 if self._device in ["cuda", "mps"] else torch.float32
        self._dtype = dtype

        # --- load tokenizer & model -------------------------------------------------
        self._tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=trust_remote_code
        )
        if self._tokenizer.pad_token_id is None:
            # ensure we have a pad token to keep batch helpers happy
            self._tokenizer.pad_token_id = self._tokenizer.eos_token_id

        self._model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=self._dtype,
            device_map="auto" if self._device == "cuda" else None,
            trust_remote_code=trust_remote_code,
        )
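        if self._device != "cuda":
            # With device_map="auto" the weights are already on the GPU(s)
            self._model.to(self._device)
        self._model.eval()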

        # switch to inference-only mode
        for p in self._model.parameters():
            p.requires_grad = False

    # -------------------------------------------------------------------------
    # Public helpers
    # -------------------------------------------------------------------------
    def get_logits_from_input_ids(self, input_ids: list[int]) -> list[float]:
        """
        Given a list of input token ids, return the raw logits (no softmax) 
        for the next token as a list of floats.
        """
        input_tensor = torch.tensor([input_ids], device=self._device, dtype=torch.long)
        with torch.no_grad():
            out = self._model(input_ids=input_tensor)
        # Get logits for the last token in the sequence for the batch (batch size 1)
        logits = out.logits[0, -1].tolist()
        return [float(x) for x in logits]

    def get_path_to_vocabulary_json(self) -> str:
        """
        Download (or reuse from the local cache) the model's vocabulary
        file from the Hub and return its local path.
        """
        vocab_file_name = self._tokenizer.vocab_files_names.get('vocab_file', "vocab.json")
        vocab_path = hf_hub_download(
            repo_id=self._model_name,
            filename=vocab_file_name
        )
        return vocab_path

In [None]:
# src/bpe_tokenizer.py
import re
import json

MAX_TOKENS = 150
MERGES_PATH = "merges.txt"
# Token IDs that signal the end of the JSON generation: '}}' and '"}}'
END_TOKEN_ID1 = 3417
END_TOKEN_ID2 = 30975
# Special tokens of the chat format and their token IDs
SPECIAL_TOKENS = {
    "<|im_start|>": 151644,
    "<|im_end|>": 151645,
    "<think>": 151667,
    "</think>": 151668,
}


def initialize_tokenizer(vocab_path):
    """
    Initialize the BPE tokenizer by loading the vocabulary and the merge
    ranks from their respective files.

    merge_ranks maps each token pair to its rank: a lower rank means a
    higher priority for merging. The special tokens are added to the
    vocabulary if they are not already present. The merges file
    (MERGES_PATH) must be present and must come from the same model as
    vocab.json; download it from the same source if needed. It contains
    the ordered rules for merging tokens during BPE tokenization.

    Args:
        vocab_path (str): Path to the model's vocab.json file.
    Returns:
        vocab (dict): Mapping of tokens to their IDs.
        merge_ranks (dict): Mapping of token pairs to their merge ranks.
    Raises:
        RuntimeError: If the vocabulary or the merges file cannot be loaded.
    """
    try:
        with open(vocab_path, "r", encoding="utf-8") as f:
            vocab = json.load(f)
    except Exception as e:
        raise RuntimeError(f"Error loading vocabulary: {e}")
    for tok, tid in SPECIAL_TOKENS.items():
        if tok not in vocab:
            vocab[tok] = tid
    try:
        with open(MERGES_PATH, "r", encoding="utf-8") as f:
            # Skip blank lines and the "#version" header only: some merge
            # rules may themselves start with "#"
            merges = [line.split() for line in f
                      if line.strip() and not line.startswith("#version")]
        merge_ranks = {tuple(merge): i for i, merge in enumerate(merges)}
    except Exception as e:
        raise RuntimeError(
            f"Error loading merges.txt file needed for the tokenizer: {e}")
    return vocab, merge_ranks
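merges.txt stores one merge rule per line (two space-separated symbols) in priority order; in GPT-2-style files the first line is typically a `#version` header. A standalone parser sketch, where `parse_merges` is a hypothetical name, that skips only blank lines and that header, so a rule which itself starts with `#` keeps its rank. The sample lines are made up for illustration.

```python
def parse_merges(lines: list[str]) -> dict[tuple[str, str], int]:
    """Map each merge pair to its rank (0 = highest priority)."""
    ranks: dict[tuple[str, str], int] = {}
    for i, line in enumerate(lines):
        line = line.strip()
        # Skip blank lines and a "#version" header on the first line only
        if not line or (i == 0 and line.startswith("#version")):
            continue
        left, right = line.split()
        ranks[(left, right)] = len(ranks)
    return ranks


sample = ["#version: 0.2\n", "Ġ t\n", "# #\n", "Ġt he\n", "\n"]
print(parse_merges(sample))
```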


def get_pairs(tokens):
    """
    Return set of adjacent token pairs.
    """
    return {(tokens[i], tokens[i+1]) for i in range(len(tokens)-1)}


def preprocess_for_bpe(text):
    """
    Map space, newline and tab to the stand-in characters used by
    byte-level BPE vocabularies (Ġ, Ċ and ĉ). Other non-ASCII or control
    characters are not remapped, so they may be mis-tokenized or dropped.
    """
    text = text.replace(" ", "Ġ")
    text = text.replace("\n", "Ċ")
    text = text.replace("\t", "ĉ")
    return text
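The three replacements above are special cases of a more general scheme. Byte-level BPE tokenizers in the GPT-2 family (which the Ġ/Ċ/ĉ symbols suggest Qwen's tokenizer follows) UTF-8 encode the text and map *every* byte to a printable character: space (0x20) becomes Ġ, newline (0x0A) becomes Ċ and tab (0x09) becomes ĉ. A sketch of the full mapping, following the well-known GPT-2 `bytes_to_unicode` construction; `byte_level_encode` and `byte_level_decode` are hypothetical names.

```python
def bytes_to_unicode() -> dict[int, str]:
    """Map every byte 0-255 to a printable Unicode character (GPT-2 scheme)."""
    # Bytes that are already printable keep their own character
    bs = (list(range(ord("!"), ord("~") + 1))
          + list(range(ord("¡"), ord("¬") + 1))
          + list(range(ord("®"), ord("ÿ") + 1)))
    cs = bs[:]
    n = 0
    # Remaining bytes (controls, space, ...) are shifted to U+0100 and up
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return dict(zip(bs, (chr(c) for c in cs)))


BYTE_ENCODER = bytes_to_unicode()
BYTE_DECODER = {c: b for b, c in BYTE_ENCODER.items()}


def byte_level_encode(text: str) -> str:
    """UTF-8 encode text and map each byte to its printable stand-in."""
    return "".join(BYTE_ENCODER[b] for b in text.encode("utf-8"))


def byte_level_decode(chars: str) -> str:
    """Invert byte_level_encode: stand-ins -> bytes -> UTF-8 text."""
    return bytes(BYTE_DECODER[c] for c in chars).decode("utf-8", errors="replace")


print(byte_level_encode("héllo world"))
```

Using `byte_level_encode` in place of `preprocess_for_bpe`, and `byte_level_decode` in place of the three `replace` calls in `custom_decode`, should let accented or non-Latin text round-trip; the regex pre-tokenization of the reference tokenizer would still be missing.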


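def bpe_tokenize(text, vocab, merge_ranks):
    """
    A simple BPE tokenizer implementation.

    Tokenize the input text with Byte Pair Encoding (BPE) using the given
    vocabulary and merge ranks: starting from single characters, the
    adjacent pair with the lowest merge rank is merged repeatedly until no
    mergeable pair remains.
    The special tokens needed for the prompt structure (the keys of
    SPECIAL_TOKENS, such as 151644: <|im_start|> and 151645: <|im_end|>)
    are treated as single tokens and never split.
    This simplifies the model's own tokenizer: merges are applied to the
    whole text rather than to regex-split pre-tokens, and symbols missing
    from vocab are silently dropped, so the resulting IDs may sometimes
    differ from those produced by AutoTokenizer.
    """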
    pattern = "(" + "|".join(
        re.escape(tok) for tok in SPECIAL_TOKENS.keys()) + ")"
    parts = re.split(pattern, text)
    tokens = []
    for part in parts:
        if part in SPECIAL_TOKENS:
            tokens.append(part)
        else:
            tokens.extend(list(preprocess_for_bpe(part)))
    while True:
        pairs = get_pairs(tokens)
        # Find the best pair to merge
        min_rank = float('inf')
        best_pair = None
        for pair in pairs:
            if pair in merge_ranks and merge_ranks[pair] < min_rank:
                min_rank = merge_ranks[pair]
                best_pair = pair
        if best_pair is None:
            break
        # Merge all occurrences of the best pair
        new_tokens = []
        i = 0
        while i < len(tokens):
            if i < len(tokens) - 1 and (tokens[i], tokens[i+1]) == best_pair:
                new_tokens.append(tokens[i] + tokens[i+1])
                i += 2
            else:
                new_tokens.append(tokens[i])
                i += 1
        tokens = new_tokens
    return [vocab[token] for token in tokens if token in vocab]
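The merge loop can be followed by hand on a toy example. The merges below are invented for illustration and are not taken from Qwen's merges.txt; `toy_bpe` is a hypothetical, self-contained copy of the loop. At each round the lowest-ranked pair present in the sequence is merged everywhere, and merging stops once no adjacent pair has a rank.

```python
def toy_bpe(word: str, merge_ranks: dict[tuple[str, str], int]) -> list[str]:
    """Greedy BPE: repeatedly merge the adjacent pair with the lowest rank."""
    tokens = list(word)
    while True:
        pairs = {(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)}
        ranked = [p for p in pairs if p in merge_ranks]
        if not ranked:
            break
        best = min(ranked, key=merge_ranks.__getitem__)
        merged, i = [], 0
        while i < len(tokens):
            # Merge every left-to-right occurrence of the best pair
            if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == best:
                merged.append(tokens[i] + tokens[i + 1])
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        tokens = merged
    return tokens


# Hypothetical merges in priority order (rank 0 is applied first)
merges = [("l", "o"), ("lo", "w"), ("e", "r"), ("Ġ", "low")]
ranks = {pair: i for i, pair in enumerate(merges)}
print(toy_bpe("lowĠlower", ranks))  # ['low', 'Ġlow', 'er']
```

The rounds are `l+o`, then `lo+w`, then `e+r`, then `Ġ+low`; `Ġlow+er` has no rank, so the loop ends there.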


def custom_decode(ids, id_to_token):
    """
    Convert token IDs back to text.

    Every ID listed in the SPECIAL_TOKENS dictionary above is skipped:
    the chat markers (151644, 151645) and the think tags
    151667 (<think>) and 151668 (</think>). IDs missing from id_to_token
    are rendered as "<unk>". Only the space, newline and tab stand-ins
    (Ġ, Ċ, ĉ) are mapped back to characters. The decoded text is also
    printed.
    """
    skip_ids = set(SPECIAL_TOKENS.values())
    tokens = [id_to_token.get(i, "<unk>") for i in ids if i not in skip_ids]
    text = "".join(tokens)
    text = text.replace("Ġ", " ").replace("Ċ", "\n").replace("ĉ", "\t")
    print("\n\nllm output:", text, end="")
    return text


def create_prompt(user_input: str, tools: str) -> str:
    """
    Create the prompt for the LLM based on user input and available tools.
    """
    system_msg = "You are a helpful assistant that uses tools. "
    system_msg += "Based on the user's request, you must call the "
    system_msg += "appropriate tool with the correct arguments. "
    system_msg += "You have access to the following tools:\n"
    system_msg += f"{tools}"
    system_msg += """
---
Here are some examples:

User: Multiply 45 by 11
Assistant: {"fn_name": "fn_multiply_numbers", "args": {"a": 45.0, "b": 11.0}}

User: can you reverse the word 'banana'?
Assistant: {"fn_name": "fn_reverse_string", "args": {"s": "banana"}}

User: Substitute the digits in the string 'Hello 34 I'm 233 years old' with 'NUMBERS'
Assistant: {"fn_name": "fn_substitute_string_with_regex", "args": {"source_string": "Hello 34 I'm 233 years old", "regex": "\\\\d+", "replacement": "NUMBERS"}}
---

Now, answer the following request. Only provide the JSON for the tool call.
"""
    return (
        f"<|im_start|>system\n{system_msg}<|im_end|>\n"
        f"<|im_start|>user\n{user_input}/no_think<|im_end|>\n"
        f"<|im_start|>assistant\n"
    )
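`create_prompt` writes the ChatML layout by hand: each turn is `<|im_start|>role`, a newline, the content and `<|im_end|>`, and an open assistant turn is left for the model to complete. A hypothetical helper, `to_chatml`, that renders any list of `{"role", "content"}` messages the same way:

```python
def to_chatml(messages: list[dict[str, str]],
              add_generation_prompt: bool = True) -> str:
    """Render {"role", "content"} messages in the ChatML layout."""
    parts = [f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n"
             for m in messages]
    if add_generation_prompt:
        # Open an assistant turn for the model to complete
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)


print(to_chatml([
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Reverse the string 'hello'/no_think"},
]))
```

For comparison, `llm._tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)` renders the model's official template, which may differ in details such as how Qwen3 handles thinking blocks.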


def get_answer_ids(llm, input_ids):
    """
    Greedily generate the answer one token at a time.

    At each step the LLM returns the logits for the next token. The token
    with the highest logit is chosen and appended to input_ids (modified
    in place), so it becomes part of the context for the next step. The
    generated IDs are also collected in answer_ids to be decoded later.
    Each step re-runs the model on the whole sequence (there is no
    key/value cache), so every step costs more as the sequence grows.

    Generation stops when MAX_TOKENS tokens have been generated or when an
    end token (END_TOKEN_ID1, END_TOKEN_ID2 or <|im_end|>) is produced.

    Args:
        llm: Instance of the Small_LLM class.
        input_ids: List of input token IDs (integers).
    Returns:
        List of generated token IDs (integers).
    """
    stop_ids = {END_TOKEN_ID1, END_TOKEN_ID2, SPECIAL_TOKENS["<|im_end|>"]}
    answer_ids = []
    for _ in range(MAX_TOKENS):
        print(".", end="", flush=True)
        logits = llm.get_logits_from_input_ids(input_ids)
        # Greedy choice: the index of the largest logit
        next_token_id = max(enumerate(logits), key=lambda x: x[1])[0]
        input_ids.append(next_token_id)
        answer_ids.append(next_token_id)
        if next_token_id in stop_ids:
            break
    return answer_ids
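The greedy loop can be exercised without downloading a model by swapping in a stub. `StubLLM` and `greedy_generate` are hypothetical names for this sketch: the stub returns logits that make a scripted token the argmax, which shows both the greedy choice and the early stop on a stop token.

```python
class StubLLM:
    """Hypothetical stand-in for Small_LLM that replays scripted tokens."""

    def __init__(self, script: list[int], vocab_size: int = 5) -> None:
        self.script = script
        self.vocab_size = vocab_size
        self.calls = 0

    def get_logits_from_input_ids(self, input_ids: list[int]) -> list[float]:
        # Put the largest logit on the next scripted token ID
        target = self.script[self.calls]
        self.calls += 1
        return [1.0 if i == target else 0.0 for i in range(self.vocab_size)]


def greedy_generate(llm, input_ids: list[int], stop_ids: set[int],
                    max_tokens: int) -> list[int]:
    """Greedy decoding loop with the stop set passed as a parameter."""
    answer_ids = []
    for _ in range(max_tokens):
        logits = llm.get_logits_from_input_ids(input_ids)
        # argmax: index of the largest logit (first one on ties)
        next_id = max(range(len(logits)), key=logits.__getitem__)
        input_ids.append(next_id)
        answer_ids.append(next_id)
        if next_id in stop_ids:
            break
    return answer_ids


prompt_ids = [0, 1]
print(greedy_generate(StubLLM([2, 3, 4, 2]), prompt_ids, stop_ids={4}, max_tokens=10))
print(prompt_ids)
```

The run stops at token 4, before the fourth scripted token, and `prompt_ids` is extended in place, just as `input_ids` is in `get_answer_ids`.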


In [10]:
# src/schemas.py
from pydantic import BaseModel
from typing import List, Dict, Any


class FunctionDefinition(BaseModel):
    """Schema for defining a function signature.

    This is what we get from exercise_input/functions_definition.json.
    This will be passed to the model in the prompt to let it know
    what functions are available to call.

    Example:
        {
            "fn_name": "fn_add_numbers",
            "args_names": ["a", "b"],
            "args_types": {
                "a": "float",
                "b": "float"
            },
            "return_type": "float"
        }

    Attributes:
        fn_name (str): Name of the function.
        args_names (List[str]): Ordered list of argument names.
        args_types (Dict[str, str]): Mapping of argument names to their types.
        return_type (str): The return type of the function.
    """
    fn_name: str
    args_names: List[str]
    args_types: Dict[str, str]
    return_type: str


class SelectedFunction(BaseModel):
    """Schema for the function selected by the model to call.

    This is what we expect the model to output after being prompted
    with a list of available functions and a user prompt.

    Attributes:
        prompt (str): The original natural-language request.
        fn_name (str | None): The name of the function to call.
        args (Dict[str, Any]): All required arguments with the correct types.
    """
    prompt: str
    fn_name: str | None
    args: Dict[str, Any]


class ToolParameter(BaseModel):
    """Represents the parameters for a tool function.

    Example:
        {
            "type": "object",
            "properties": {
                "city": {
                    "type": "string",
                    "description": "The city to get the weather for"
                }
            },
            "required": ["city"]
        }

    Attributes:
        type (str): The type of the parameters object (default: "object").
        properties (Dict[str, Dict[str, str]]): Properties of the parameters.
        required (List[str]): List of required parameter names.
    """
    type: str = "object"
    properties: Dict[str, Dict[str, str]]
    required: List[str] = []


class ToolFunction(BaseModel):
    """Represents a function within a tool.

    Example:
        {
            "name": "get_weather",
            "description": "Get the weather in a given city",
            "parameters": { ... }
        }

    Attributes:
        name (str): The function name.
        description (str): The function description.
        parameters (ToolParameter): The function parameters.
    """
    name: str
    description: str
    parameters: ToolParameter


class Tool(BaseModel):
    """Represents a tool schema for LLM tool selection.

    Example:
        {
            "type": "function",
            "function": { ... }
        }

    Attributes:
        type (str): The type of the tool (default: "function").
        function (ToolFunction): The function associated with the tool.
    """
    type: str = "function"
    function: ToolFunction


In [11]:
# src/utils.py
import json
import re
import os

from typing import List

from pydantic import ValidationError

OUTPUT_FILE = "output/function_calling_name.json"
TOOLS_DEFINITION_FILE = "exercise_input/functions_definition.json"
THINK_TAG = "</think>"


def get_functions() -> List[FunctionDefinition]:
    """
    Load the function definitions from TOOLS_DEFINITION_FILE and convert
    them to a list of FunctionDefinition objects.
    Raises: RuntimeError: If the file cannot be loaded, parsed as JSON or
        validated against the schema, because the program cannot continue
        without a valid functions list.
    """
    try:
        with open(TOOLS_DEFINITION_FILE) as f:
            functions_raw = json.load(f)
        functions = [FunctionDefinition(**fn) for fn in functions_raw]
        return functions
    except (FileNotFoundError, json.JSONDecodeError, ValidationError) as e:
        raise RuntimeError(f"Error loading functions for tools: {e}")


def get_input_prompts(file: str) -> List[str]:
    """
    Load the prompts from a JSON file.
    Each prompt should be under the "prompt" key.
    Args:
        file (str): Path to the JSON file containing the prompts.
    Returns: List[str]: List of prompt strings.
    Raises: RuntimeError: If there is an error loading the file
        or parsing the JSON.
    """
    try:
        with open(file) as f:
            prompts_raw = json.load(f)
        return [pr["prompt"] for pr in prompts_raw]
    except FileNotFoundError:
        raise RuntimeError(f"Error: {file} not found.")
    except json.JSONDecodeError:
        raise RuntimeError(f"Error: JSON decode failed for {file}.")


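def get_tool_list() -> str:
    """
    Load the function definitions and convert them to a JSON string of
    tool schemas (one {"type": "function", "function": {...}} object per
    function) to be included in the prompt.
    Raises: RuntimeError: If the conversion to JSON fails, because the
        program cannot continue without a valid tools list.
    """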
    tools = []
    functions = get_functions()
    for fn in functions:
        # Build properties for each argument
        properties = {}
        for name in fn.args_names:
            arg_type = fn.args_types[name]
            if arg_type == "float":
                type_str = "number"
            elif arg_type == "str":
                type_str = "string"
            elif arg_type == "int":
                type_str = "integer"
            else:
                type_str = "any"
            properties[name] = {"type": type_str}
        # Create a readable description from the function name
        readable_name = fn.fn_name
        if readable_name.startswith("fn_"):
            readable_name = readable_name[3:]
        description = readable_name.replace('_', ' ') + " function"
        tool = Tool(
            function=ToolFunction(
                name=fn.fn_name,
                description=description,
                parameters=ToolParameter(
                    properties=properties,
                    required=fn.args_names
                )
            )
        )
        tools.append(tool.model_dump())  # model_dump() replaces .dict() in Pydantic v2
    try:
        return json.dumps(tools, indent=2)
    except Exception as e:
        raise RuntimeError(f"Error converting tools to JSON: {e}")
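The `if/elif` chain that maps Python type names to JSON Schema types can also be written as a lookup table. A sketch with plain dicts instead of the Pydantic models; `PY_TO_JSON_TYPE` and `to_tool` are hypothetical names, and the `bool`/`list`/`dict` entries go beyond the types handled above. One deliberate difference: an unknown type is left untyped rather than labelled `"any"`, since current JSON Schema drafts define no `"any"` type.

```python
import json

# Python type names -> JSON Schema type names
PY_TO_JSON_TYPE = {"float": "number", "int": "integer", "str": "string",
                   "bool": "boolean", "list": "array", "dict": "object"}


def to_tool(fn_name: str, args_types: dict[str, str]) -> dict:
    """Build one OpenAI-style tool dict from a function definition."""
    properties = {}
    for name, py_type in args_types.items():
        json_type = PY_TO_JSON_TYPE.get(py_type)
        # Omit "type" for unknown types instead of inventing one
        properties[name] = {"type": json_type} if json_type else {}
    readable = fn_name.removeprefix("fn_").replace("_", " ")
    return {
        "type": "function",
        "function": {
            "name": fn_name,
            "description": f"{readable} function",
            "parameters": {"type": "object", "properties": properties,
                           "required": list(args_types)},
        },
    }


print(json.dumps(to_tool("fn_add_numbers", {"a": "float", "b": "float"}), indent=2))
```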


def enforce_arg_types(fn_name, args, functions_def):
    """
    Convert the argument values returned by the LLM to the types declared
    in the function definitions. The LLM sometimes returns, for example,
    1 where a float (1.0) is expected.
    If the function name is not found in the definitions, the arguments
    are returned unchanged; an argument that cannot be converted is left
    as is. Note that args is modified in place.
    Args:
        fn_name (str): The name of the function.
        args (dict): The arguments parsed from the LLM output.
        functions_def (List[dict]): List of function definitions
            (tools) as dicts.
    Returns:
        dict: The arguments with values converted to the correct types.
    """
    # Find the function definition
    fn_def = next((f for f in functions_def if f["fn_name"] == fn_name), None)
    if not fn_def:
        return args
    for arg_name, arg_type in fn_def["args_types"].items():
        if arg_name in args:
            try:
                if arg_type == "float":
                    args[arg_name] = float(args[arg_name])
                elif arg_type == "int":
                    args[arg_name] = int(args[arg_name])
                elif arg_type == "str":
                    args[arg_name] = str(args[arg_name])
                # Add more types as needed
            except (ValueError, TypeError):
                pass  # Leave as is if conversion fails
    return args
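One subtlety in the conversion above: `int()` truncates floats (`int(3.7)` is `3`) and rejects strings such as `"3.0"`. A hypothetical `coerce` helper, sketching a stricter per-value conversion that accepts only whole numbers for `int` and, like `enforce_arg_types`, falls back to the original value when conversion fails:

```python
def coerce(value, type_name: str):
    """Convert value to the named type, returning it unchanged on failure."""
    try:
        if type_name == "float":
            return float(value)
        if type_name == "int":
            if isinstance(value, int):
                return value
            as_float = float(value)  # accepts "3", "3.0" and 3.0
            # Reject values with a fractional part instead of truncating them
            return int(as_float) if as_float.is_integer() else value
        if type_name == "str":
            return str(value)
    except (ValueError, TypeError):
        pass
    return value


print(coerce("3.0", "int"), coerce(3.7, "int"), coerce(1, "float"))
```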


def extract_json_from_response(
        prompt: str,
        response: str
) -> SelectedFunction:
    """
    Extract and parse the tool-call JSON object from the model's full
    output string. If the output does not contain valid JSON, return a
    SelectedFunction with fn_name set to an empty string and empty args.

    The regex matches up to the first '}' that is followed (after optional
    whitespace) by another '}', so it assumes args is a flat object whose
    string values do not contain '}}'.

    Args:
        prompt (str): The original natural-language request.
        response (str): The full output string from the model.
    Returns:
        SelectedFunction: The parsed SelectedFunction object.
    """
    # First get rid of the think block if it exists
    if THINK_TAG in response:
        response = response.split(THINK_TAG, 1)[1].strip()
    pattern = r'\{\s*"fn_name":.*?\}\s*\}'
    match = re.search(pattern, response, re.DOTALL)
    if not match:
        print("No JSON object found in the response.")
        return SelectedFunction(prompt=prompt, fn_name="", args={})
    json_str = match.group(0)
    try:
        data = json.loads(json_str)
        fn_name = data.get("fn_name")
        args = data.get("args", {})
        functions_def = get_functions()
        # Convert to dicts for enforce_arg_types
        functions_def_dicts = [fn.model_dump() for fn in functions_def]
        args = enforce_arg_types(fn_name, args, functions_def_dicts)
        print(f"\nargs post check: {args}")
        return SelectedFunction(prompt=prompt, fn_name=fn_name, args=args)
    except Exception as e:
        print(f"Error parsing JSON from response: {e}")
        return SelectedFunction(prompt=prompt, fn_name="", args={})


def write_output_to_file(output_to_write_to_file):
    """Write the list of SelectedFunction results to OUTPUT_FILE as JSON."""
    os.makedirs(os.path.dirname(OUTPUT_FILE), exist_ok=True)
    with open(OUTPUT_FILE, "w") as f:
        # model_dump() replaces .dict() in Pydantic v2
        json.dump([o.model_dump() for o in output_to_write_to_file], f)

In [None]:
# src/__main__.py

INPUT_FILE = "exercise_input/function_calling_tests.json"


def main(input_file: str = INPUT_FILE) -> None:
    """
    Main entry point for the function-calling LLM pipeline.

    Args:
        input_file (str, optional): Path to the prompts JSON file. Defaults to
            "exercise_input/function_calling_tests.json".
    """
    try:
        llm = Small_LLM()
        vocab_path = llm.get_path_to_vocabulary_json()
        vocab, merge_ranks = initialize_tokenizer(vocab_path)
        outputs = []
        tools = get_tool_list()
        # Reverse the vocab dict for ID-to-token lookup
        id_to_token = {v: k for k, v in vocab.items()}

        for user_prompt in get_input_prompts(input_file):
            print(f"\n\nProcessing prompt: {user_prompt}")
            prompt = create_prompt(user_prompt, tools)
            # Encode with the custom BPE tokenizer defined above
            input_ids = bpe_tokenize(
                prompt, vocab=vocab, merge_ranks=merge_ranks)
            answer_ids = get_answer_ids(llm, input_ids)
            llm_output = custom_decode(answer_ids, id_to_token)
            result = extract_json_from_response(
                user_prompt, llm_output)
            outputs.append(result)

        write_output_to_file(outputs)
    except RuntimeError as e:
        print(f"Fatal error: {e}")


main()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
loading file vocab.json from cache at /root/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/c1899de289a04d12100db370d81485cdf75e47ca/vocab.json
loading file merges.txt from cache at /root/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/c1899de289a04d12100db370d81485cdf75e47ca/merges.txt
loading file tokenizer.json from cache at /root/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/c1899de289a04d12100db370d81485cdf75e47ca/tokenizer.json
loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at None
loading file



Processing prompt: What is the square root of 16?
..........................

llm output: 



{"fn_name": "fn_get_square_root", "args": {"a": 16.0}}
Extracted function call: fn_get_square_root with args {'a': 16.0}


Processing prompt: Reverse the string 'hello'
......................

llm output: 



{"fn_name": "fn_reverse_string", "args": {"s": "hello"}}
Extracted function call: fn_reverse_string with args {'s': 'hello'}


Processing prompt: Reverse the string 'world'
......................

llm output: 



{"fn_name": "fn_reverse_string", "args": {"s": "world"}}
Extracted function call: fn_reverse_string with args {'s': 'world'}


Processing prompt: Substitute the digits in the string 'Hello 34 I'm 233 years old' with 'NUMBERS'
..................................................

llm output: 



{"fn_name": "fn_substitute_string_with_regex", "args": {"source_string": "Hello 34 I'm 233 years old", "regex": "\\d+", "replacement": "NUMBERS"}}
Extracted function call: fn_substitute_st