diff --git a/articles/gpt-oss/fine-tune-transfomers.ipynb b/articles/gpt-oss/fine-tune-transfomers.ipynb new file mode 100644 index 0000000000..7131dfe9dd --- /dev/null +++ b/articles/gpt-oss/fine-tune-transfomers.ipynb @@ -0,0 +1,674 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "b40b4db2", + "metadata": {}, + "source": [ + "# Fine-tuning a Multilingual Reasoner with Hugging Face\n", + "\n", + "Authored by: [Edward Beeching](https://huggingface.co/edbeeching), [Quentin Gallouédec](https://huggingface.co/qgallouedec), and [Lewis Tunstall](https://huggingface.co/lewtun)\n", + "\n", + "Large reasoning models like [OpenAI o3](https://openai.com/index/introducing-o3-and-o4-mini/) generate a chain-of-thought to improve the accuracy and quality of their responses. However, most of these models reason in English, even when a question is asked in another language.\n", + "\n", + "In this notebook, we show how OpenAI's open-weight reasoning model [OpenAI gpt-oss-20b](https://huggingface.co/openai/gpt-oss-20b) can be fine-tuned to reason effectively in multiple languages. We'll do this by adding a new _\"reasoning language\"_ option to the model's system prompt, and applying [supervised fine-tuning](https://huggingface.co/learn/llm-course/chapter11/1) with Hugging Face's [TRL library](https://github.com/huggingface/trl) on a multilingual reasoning dataset.\n", + "\n", + "We'll cover the following steps:\n", + "\n", + "1. **Setup:** Install the required libraries.\n", + "2. **Prepare the dataset:** Download and format the dataset for fine-tuning.\n", + "3. **Prepare the model:** Loading the base model and configure it for fine-tuning [LoRA](https://huggingface.co/learn/llm-course/chapter11/4), a memory efficient technique.\n", + "4. **Fine-tuning:** Train the model with our multilingual reasoning data.\n", + "5. **Inference:** Generate reasoning responses in different languages using the fine-tuned model.\n", + "\n", + "The end result is a multilingual reasoning model that can generate a chain-of-thought in English, Spanish, French, Italian, or German. You can even _mix languages_—for example, ask a question in Spanish, request reasoning in German, and receive the final response in Spanish:\n", + "\n", + "```txt\n", + "User:\n", + " ¿Cuál es el capital de Australia?\n", + "Assistant reasoning:\n", + " Okay, der Benutzer fragt nach der Hauptstadt Australiens. Ich erinnere mich, dass Canberra die Hauptstadt ist. Ich\n", + " sollte das bestätigen. Lass mich sehen, ob es irgendwelche potenziellen Verwirrungen gibt. Der Benutzer könnte auch\n", + " an der größten Stadt interessiert sein. Die größte Stadt ist Sydney, aber die Hauptstadt ist Canberra. Ich sollte\n", + " das klarstellen. Vielleicht auch erwähnen, dass Canberra eine geplante Stadt ist und nicht die größte. Der Benutzer\n", + " könnte auch nach der Geografie fragen. Vielleicht erwähne ich, dass Canberra im südwestlichen Teil der Australian\n", + " Capital Territory liegt. Ich sollte die Antwort präzise und freundlich halten. Vielleicht auch erwähnen, dass\n", + " Canberra oft mit Sydney verwechselt wird. Ich sollte sicherstellen, dass die Antwort klar und korrekt ist.\n", + "Assistant response:\n", + " La capital de Australia es **Canberra**. Aunque es la ciudad más pequeña de las principales capitales del país, fue\n", + " elegida en 1908 como la sede del gobierno federal para equilibrar la influencia entre las ciudades de Sydney y\n", + " Melbourne. 
Canberra está ubicada en el Territorio de la Capital Australiana (ACT), en el este de Australia.\n", + "```\n", + "\n", + "We hope this tutorial will enable AI developers working with under-represented languages to improve the interpretability of [`openai/gpt-oss-20b`](https://huggingface.co/openai/gpt-oss-20b) in their native languages.\n", + "\n", + "> **Note:** This notebook is designed to be run on a single H100 GPU with 80GB of memory. If you have access to a smaller GPU, you can reduce the batch size and sequence length in the hyperparameters below." + ] + }, + { + "cell_type": "markdown", + "id": "ba4d0215", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "To get started, let’s install all the necessary libraries. First install PyTorch:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5e95b98a", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install torch --index-url https://download.pytorch.org/whl/cu128" + ] + }, + { + "cell_type": "markdown", + "id": "46c6f749", + "metadata": {}, + "source": [ + "Next, install the remaining dependencies:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b9bad857", + "metadata": {}, + "outputs": [], + "source": [ + "%pip install \"trl>=0.20.0\" \"peft>=0.17.0\" \"transformers>=4.55.0\" trackio" + ] + }, + { + "cell_type": "markdown", + "id": "dc7497b7", + "metadata": {}, + "source": [ + "Finally, log into your Hugging Face account as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c891517e", + "metadata": {}, + "outputs": [], + "source": [ + "from huggingface_hub import notebook_login\n", + "\n", + "notebook_login()" + ] + }, + { + "cell_type": "markdown", + "id": "7b235c41", + "metadata": {}, + "source": [ + "Now that we've installed the required libraries, let's take a look at the dataset that we will use for fine-tuning." + ] + }, + { + "cell_type": "markdown", + "id": "d9534203", + "metadata": {}, + "source": [ + "## Prepare the dataset\n", + "\n", + "We will be using [Multilingual-Thinking](https://huggingface.co/datasets/HuggingFaceH4/Multilingual-Thinking), which is a reasoning dataset where the chain-of-thought has been translated into several languages such as French, Spanish, and German. By fine-tuning `openai/gpt-oss-20b` on this dataset, it will learn to generate reasoning steps in these languages, and thus its reasoning process can be interpreted by users who speak those languages." + ] + }, + { + "cell_type": "markdown", + "id": "078f9b34", + "metadata": {}, + "source": [ + "" + ] + }, + { + "cell_type": "markdown", + "id": "33f77056", + "metadata": {}, + "source": [ + "\n", + "\n", + "Let's download this dataset from the Hugging Face Hub:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8307b239", + "metadata": {}, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "\n", + "dataset = load_dataset(\"HuggingFaceH4/Multilingual-Thinking\", split=\"train\")\n", + "dataset" + ] + }, + { + "cell_type": "markdown", + "id": "38c50d00", + "metadata": {}, + "source": [ + "This is a small dataset of 1,000 examples, but this is usually more than sufficient for models like `openai/gpt-oss-20b` which have undergone extensive post-training. 
Let's take a look at one of the training examples:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "51c44ed6", + "metadata": {}, + "outputs": [], + "source": [ + "dataset[0]" + ] + }, + { + "cell_type": "markdown", + "id": "9ba75fde", + "metadata": {}, + "source": [ + "The `gpt-oss` models were trained on the Harmony response format for defining conversation structures, generating reasoning output and structuring function calls. The format is designed to mimic the OpenAI Responses API, and the table below summarizes the different message types used in the dataset:" + ] + }, + { + "cell_type": "markdown", + "id": "d8b01cef", + "metadata": {}, + "source": [ + "|||\n", + "| :---- | :--|\n", + "| `developer` | The developer message is used to provide custom instructions for the model (what we usually call the `system` role) |\n", + "| `user` | The user message is used to provide the input to the model |\n", + "| `assistant` | Output by the model which can either be a tool call or a message output. The output might also be associated with a particular “channel” identifying what the intent of the message is. |\n", + "| `analysis` | These are messages that are being used by the model for its chain-of thought |\n", + "| `final` | Messages tagged in the final channel are messages intended to be shown to the end-user and represent the responses from the model. |\n", + "| `messages` | The list of messages that combine the content of the above to produce a full conversation. This is the input to the model. |" + ] + }, + { + "cell_type": "markdown", + "id": "69d496f0", + "metadata": {}, + "source": [ + "If you're familiar with [OpenAI's messages format](https://platform.openai.com/docs/api-reference/messages/object), you will recognise this as being quite similar, but with an important difference:\n", + "\n", + "> The `assistant` turn contains two special fields: a `thinking` one which contains the model's reasoning process, and a `content` one which contains the final response to the user.\n", + "\n", + "In order to fine-tune the model, we need to convert these messages into a format that the model can understand. In practice this is done by formatting each message with the model's [_chat template_](https://huggingface.co/docs/transformers/chat_templating) and then tokenizing the resulting text. The TRL library does this automatically, but let's walk through it step by step to understand how it works.\n", + "\n", + "To do so, let's first load the tokenizer:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "72d11c36", + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoTokenizer\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(\"openai/gpt-oss-20b\")" + ] + }, + { + "cell_type": "markdown", + "id": "7ee27751", + "metadata": {}, + "source": [ + "Then we can use the tokenizer's `apply_chat_template()` method to format the messages:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "86445ef7", + "metadata": {}, + "outputs": [], + "source": [ + "messages = dataset[0][\"messages\"]\n", + "conversation = tokenizer.apply_chat_template(messages, tokenize=False)\n", + "print(conversation)" + ] + }, + { + "cell_type": "markdown", + "id": "d74423b1", + "metadata": {}, + "source": [ + "This chat template is quite sophisticated, so let's take a closer look at it! First, we can see there are special tokens `<|start|>` and `<|end|>` that indicate the start and end of each message. 
There is also a `<|return|>` token that marks the end of the conversation. These tokens help the model understand the structure of the conversation.\n", + "\n", + "We can also see there are _two_ types of system message: \n", + "\n", + "* A default `system` one that is used for all messages. In the example above, this refers to the text _\"You are ChatGPT, a large language model trained by OpenAI...\"_\n", + "* A special `developer` one that contains custom instructions (defined by the `system` role in our `messages` object). This allows us to provide additional context to the model about how it should behave for a given conversation. In the example above, this refers to the text _\"You are an AI chatbot with a lively and energetic personality.\"_\n", + "\n", + "Finally, we can see that the assistant response is contained in a series of _channels_:\n", + "\n", + "* The `analysis` channel is used for the model's reasoning process, where it can think step by step about the user's question. In the example above, this refers to the French text _\"D'accord, l'utilisateur demande les tendances Twitter...\"_\n", + "* The `final` channel is used for the model's final response to the user. In the example above, this refers to the text _\"Hey there! While I can't check Twitter...\"_" + ] + }, + { + "cell_type": "markdown", + "id": "ddc392a9", + "metadata": {}, + "source": [ + "Now that we understand how the dataset will be prepared, let's move on to preparing the model for training." + ] + }, + { + "cell_type": "markdown", + "id": "6f1186ef", + "metadata": {}, + "source": [ + "## Prepare the model\n", + "\n", + "To prepare the model for training, let's first download the weights from the [Hugging Face Hub](https://huggingface.co). We will use the `AutoModelForCausalLM` class from 🤗 Transformers to load the model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "04dc1f4a", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from transformers import AutoModelForCausalLM, Mxfp4Config\n", + "\n", + "quantization_config = Mxfp4Config(dequantize=True)\n", + "model_kwargs = dict(\n", + " attn_implementation=\"eager\",\n", + " torch_dtype=torch.bfloat16,\n", + " quantization_config=quantization_config,\n", + " use_cache=False,\n", + " device_map=\"auto\",\n", + ")\n", + "\n", + "model = AutoModelForCausalLM.from_pretrained(\"openai/gpt-oss-20b\", **model_kwargs)" + ] + }, + { + "cell_type": "markdown", + "id": "0f349ee8", + "metadata": {}, + "source": [ + "This will load the model with the necessary configurations for training. The `attn_implementation` is set to `eager` for better performance, and `use_cache` is set to `False` since we will fine-tune the model with gradient checkpointing.\n", + "\n", + "If you're familiar with 🤗 Transformers, you might notice that we are using the `Mxfp4Config` for quantization. This is a specific configuration for the OpenAI models that allows us to use mixed precision training with a special 4-bit floating point format called [MXFP4](https://en.wikipedia.org/wiki/Block_floating_point) that is optimized for AI workloads.\n", + "\n", + "Before we train the model, let's generate a sample response to see how the model behaves with the default settings. 
To do so, we need to tokenize a sample prompt and then use the model to generate a response:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ff27bb35", + "metadata": {}, + "outputs": [], + "source": [ + "messages = [\n", + " {\"role\": \"user\", \"content\": \"¿Cuál es el capital de Australia?\"},\n", + "]\n", + "\n", + "input_ids = tokenizer.apply_chat_template(\n", + " messages,\n", + " add_generation_prompt=True,\n", + " return_tensors=\"pt\",\n", + ").to(model.device)\n", + "\n", + "output_ids = model.generate(input_ids, max_new_tokens=512)\n", + "response = tokenizer.batch_decode(output_ids)[0]\n", + "print(response)" + ] + }, + { + "cell_type": "markdown", + "id": "d9bb2ba1", + "metadata": {}, + "source": [ + "In this example, we can see that the model first reasons about the question in English, and then provides a final response in Spanish. This is the default behavior of the model, but let's see if we can change it with a bit of fine-tuning. \n", + "\n", + "To do so, we will use a technique called [LoRA](https://huggingface.co/learn/llm-course/chapter11/4) (Low-Rank Adaptation) to fine-tune the model. This technique allows us to tune a few specific layers of the model, which is particularly useful for large models like `openai/gpt-oss-20b`.\n", + "\n", + "First we need wrap the model as a `PeftModel` and define the LoRA configuration. We will use the `LoraConfig` class from the [PEFT library](https://github.com/huggingface/peft) to do this:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5a14f256", + "metadata": {}, + "outputs": [], + "source": [ + "from peft import LoraConfig, get_peft_model\n", + "\n", + "peft_config = LoraConfig(\n", + " r=8,\n", + " lora_alpha=16,\n", + " target_modules=\"all-linear\",\n", + " target_parameters=[\n", + " \"7.mlp.experts.gate_up_proj\",\n", + " \"7.mlp.experts.down_proj\",\n", + " \"15.mlp.experts.gate_up_proj\",\n", + " \"15.mlp.experts.down_proj\",\n", + " \"23.mlp.experts.gate_up_proj\",\n", + " \"23.mlp.experts.down_proj\",\n", + " ],\n", + ")\n", + "peft_model = get_peft_model(model, peft_config)\n", + "peft_model.print_trainable_parameters()" + ] + }, + { + "cell_type": "markdown", + "id": "59cb24bb", + "metadata": {}, + "source": [ + "Here we've used some basic hyperparameters for LoRA, but you can experiment with different values to see how they affect the model's performance. For instance, if you increase `r` you will enable more trainable parameters, which may produce a better model at the expense of requiring more VRAM and time to train.\n", + "\n", + "**Note:** The `openai/gpt-oss-20b` model is a [Mixture-of-Experts (MoE)](https://huggingface.co/blog/moe) architecture. In addition to targeting the attention layers (`target_modules=\"all-linear\"`), it’s also important to include the projection layers within the expert modules. PEFT facilitates this via the `target_parameters` argument, which allows you to specify expert-specific layers such as `mlp.experts.down_proj` and `mlp.experts.gate_up_proj`. In this example, we target a subset of these projection layers, but you are encouraged to experiment with different configurations." + ] + }, + { + "cell_type": "markdown", + "id": "a1328c16", + "metadata": {}, + "source": [ + "Now that we have the model and dataset ready, we can define the hyperparameters for training." 
+ ] + }, + { + "cell_type": "markdown", + "id": "40347f0d", + "metadata": {}, + "source": [ + "## Fine-tuning" + ] + }, + { + "cell_type": "markdown", + "id": "6a8fb02e", + "metadata": {}, + "source": [ + "TRL provides a convenient way to define hyperparameters for training using the `SFTConfig` class. We will set the learning rate, batch size, number of epochs, and other parameters as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ff75e13a", + "metadata": {}, + "outputs": [], + "source": [ + "from trl import SFTConfig\n", + "\n", + "training_args = SFTConfig(\n", + " learning_rate=2e-4,\n", + " gradient_checkpointing=True,\n", + " num_train_epochs=1,\n", + " logging_steps=1,\n", + " per_device_train_batch_size=4,\n", + " gradient_accumulation_steps=4,\n", + " max_length=2048,\n", + " warmup_ratio=0.03,\n", + " lr_scheduler_type=\"cosine_with_min_lr\",\n", + " lr_scheduler_kwargs={\"min_lr_rate\": 0.1},\n", + " output_dir=\"gpt-oss-20b-multilingual-reasoner\",\n", + " report_to=\"trackio\",\n", + " push_to_hub=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "1337d6cc", + "metadata": {}, + "source": [ + "Note that the `per_device_train_batch_size` is set to 4, and the `gradient_accumulation_steps` is set to 4. This means that we will effectively have a batch size of 4 x 4 = 16 across 1 GPU. You may need to adjust these values based on your hardware setup. We also use [Trackio](https://huggingface.co/blog/trackio) to log the training progress and metrics, but you can use any other logging library of your choice." + ] + }, + { + "cell_type": "markdown", + "id": "cde26ee3", + "metadata": {}, + "source": [ + "We now have all the pieces needed to train the model. We will use the `SFTTrainer` class from TRL to handle the training process. The trainer will take care of formatting the dataset, applying the chat template, and training the model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fe048a6b", + "metadata": {}, + "outputs": [], + "source": [ + "from trl import SFTTrainer\n", + "\n", + "trainer = SFTTrainer(\n", + " model=peft_model,\n", + " args=training_args,\n", + " train_dataset=dataset,\n", + " processing_class=tokenizer,\n", + ")\n", + "trainer.train()" + ] + }, + { + "cell_type": "markdown", + "id": "6af29870", + "metadata": {}, + "source": [ + "On a H100 GPU, this takes about 18 minutes to train, but may take longer depending on your hardware." + ] + }, + { + "cell_type": "markdown", + "id": "3296e42f", + "metadata": {}, + "source": [ + "## Save the model and push to the Hugging Face Hub" + ] + }, + { + "cell_type": "markdown", + "id": "b57ebaa4", + "metadata": {}, + "source": [ + "Finally, you can push the fine-tuned model to your Hub repository to share with the community:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5de8ca50", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.save_model(training_args.output_dir)\n", + "trainer.push_to_hub(dataset_name=\"HuggingFaceH4/Multilingual-Thinking\")" + ] + }, + { + "cell_type": "markdown", + "id": "3caa03d3", + "metadata": {}, + "source": [ + "**Note**: To avoid out-of-memory (OOM) errors, we recommend restarting the kernel at this point. The trained model is still occupying GPU memory, but it's no longer needed." + ] + }, + { + "cell_type": "markdown", + "id": "dcf67472", + "metadata": {}, + "source": [ + "## Inference\n", + "\n", + "Once the model is uploaded to Hub, we can use it for inference. 
To do so we first initialize the original base model and its tokenizer. Next, we need to merge the fine-tuned weights with the base model for fast inference:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "515c4b64", + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", + "from peft import PeftModel\n", + "\n", + "# Load the tokenizer\n", + "tokenizer = AutoTokenizer.from_pretrained(\"openai/gpt-oss-20b\")\n", + "\n", + "# Load the original model first\n", + "model_kwargs = dict(attn_implementation=\"eager\", torch_dtype=\"auto\", use_cache=True, device_map=\"auto\")\n", + "base_model = AutoModelForCausalLM.from_pretrained(\"openai/gpt-oss-20b\", **model_kwargs).cuda()\n", + "\n", + "# Merge fine-tuned weights with the base model\n", + "peft_model_id = \"gpt-oss-20b-multilingual-reasoner\"\n", + "model = PeftModel.from_pretrained(base_model, peft_model_id)\n", + "model = model.merge_and_unload()" + ] + }, + { + "cell_type": "markdown", + "id": "ad0bd360", + "metadata": {}, + "source": [ + "Now that the model is loaded, the final step is to generate some tokens from it! Here we use the model's `generate` method to produce output based on the input prompt. Let's first define the prompt:" + ] + }, + { + "cell_type": "markdown", + "id": "082afb43", + "metadata": {}, + "source": [ + "Now we can tokenize the prompt and generate the output. Finally, we can decode the output tokens to get the final response:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e8d1007", + "metadata": {}, + "outputs": [], + "source": [ + "REASONING_LANGUAGE = \"German\"\n", + "SYSTEM_PROMPT = f\"reasoning language: {REASONING_LANGUAGE}\"\n", + "USER_PROMPT = \"¿Cuál es el capital de Australia?\" # Spanish for \"What is the capital of Australia?\"\n", + "\n", + "messages = [\n", + " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n", + " {\"role\": \"user\", \"content\": USER_PROMPT},\n", + "]\n", + "\n", + "input_ids = tokenizer.apply_chat_template(\n", + " messages,\n", + " add_generation_prompt=True,\n", + " return_tensors=\"pt\",\n", + ").to(model.device)\n", + "\n", + "gen_kwargs = {\"max_new_tokens\": 512, \"do_sample\": True, \"temperature\": 0.6, \"top_p\": None, \"top_k\": None}\n", + "\n", + "output_ids = model.generate(input_ids, **gen_kwargs)\n", + "response = tokenizer.batch_decode(output_ids)[0]\n", + "print(response)" + ] + }, + { + "cell_type": "markdown", + "id": "28452ff8", + "metadata": {}, + "source": [ + "Let's also try with languages that the model has not been explicitly fine-tuned on, such as Chinese and Hindi:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "50ae0810", + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "REASONING_LANGUAGE = \"Chinese\" # or Hindi, or any other language...\n", + "SYSTEM_PROMPT = f\"reasoning language: {REASONING_LANGUAGE}\"\n", + "USER_PROMPT = \"What is the national symbol of Canada?\"\n", + "\n", + "messages = [\n", + " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n", + " {\"role\": \"user\", \"content\": USER_PROMPT},\n", + "]\n", + "\n", + "input_ids = tokenizer.apply_chat_template(\n", + " messages,\n", + " add_generation_prompt=True,\n", + " return_tensors=\"pt\",\n", + ").to(model.device)\n", + "\n", + "output_ids = model.generate(input_ids, **gen_kwargs)\n", + "response = tokenizer.batch_decode(output_ids)[0]\n", + "print(response)" + ] + }, + { + "cell_type": "markdown", + "id": "930b7b03", + 
"metadata": {}, + "source": [ + "Great, it works - we've now fine-tuned `openai/gpt-oss-20b` to reason in multiple languages!" + ] + }, + { + "cell_type": "markdown", + "id": "fd3ff811", + "metadata": {}, + "source": [ + "## Conclusion" + ] + }, + { + "cell_type": "markdown", + "id": "f404075f", + "metadata": {}, + "source": [ + "Congratulations! You have successfully fine-tuned a multilingual reasoning model using the TRL library and LoRA. The steps in this notebook can be adapted to fine-tune [`openai/gpt-oss-20b`](https://huggingface.co/openai/gpt-oss-20b) on many other [datasets](https://huggingface.co/datasets) on the Hugging Face Hub - we are excited to see what you'll build!" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "openai-tsm", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/articles/gpt-oss/handle-raw-cot.md b/articles/gpt-oss/handle-raw-cot.md new file mode 100644 index 0000000000..7334202ebc --- /dev/null +++ b/articles/gpt-oss/handle-raw-cot.md @@ -0,0 +1,123 @@ +# How to handle the raw chain of thought in gpt-oss + +The [gpt-oss models](https://openai.com/open-models) provide access to a raw chain of thought (CoT) meant for analysis and safety research by model implementors, but it’s also crucial for the performance of tool calling, as tool calls can be performed as part of the CoT. At the same time, the raw CoT might contain potentially harmful content or could reveal information to users that the person implementing the model might not intend (like rules specified in the instructions given to the model). You therefore should not show raw CoT to end users. + +## Harmony / chat template handling + +The model encodes its raw CoT as part of our [harmony response format](https://cookbook.openai.com/articles/openai-harmony). If you are authoring your own chat templates or are handling tokens directly, make sure to [check out harmony guide first](https://cookbook.openai.com/articles/openai-harmony). + +To summarize a couple of things: + +1. CoT will be issued to the `analysis` channel +2. After a message to the `final` channel in a subsequent sampling turn all `analysis` messages should be dropped. Function calls to the `commentary` channel can remain +3. If the last message by the assistant was a tool call of any type, the analysis messages until the previous `final` message should be preserved on subsequent sampling until a `final` message gets issued + +## Chat Completions API + +If you are implementing a Chat Completions API, there is no official spec for handling chain of thought in the published OpenAI specs, as our hosted models will not offer this feature for the time being. We ask you to follow [the following convention from OpenRouter instead](https://openrouter.ai/docs/use-cases/reasoning-tokens). Including: + +1. Raw CoT will be returned as part of the response unless `reasoning: { exclude: true }` is specified as part of the request. [See details here](https://openrouter.ai/docs/use-cases/reasoning-tokens#legacy-parameters) +2. The raw CoT is exposed as a `reasoning` property on the message in the output +3. For delta events the delta has a `reasoning` property +4. 
On subsequent turns you should be able to receive the previous reasoning (as `reasoning`) and handle it in accordance with the behavior specified in the chat template section above. + +When in doubt, please follow the convention / behavior of the OpenRouter implementation. + +## Responses API + +For the Responses API we augmented our Responses API spec to cover this case. Below are the changes to the spec as type definitions. At a high level we are: + +1. Introducing a new `content` property on `reasoning`. This allows a reasoning `summary` that could be displayed to the end user to be returned at the same time as the raw CoT (which should not be shown to the end user, but which might be helpful for interpretability research). +2. Introducing a new content type called `reasoning_text` +3. Introducing two new events `response.reasoning_text.delta` to stream the deltas of the raw CoT and `response.reasoning_text.done` to indicate a turn of CoT to be completed +4. On subsequent turns you should be able to receive the previous reasoning and handle it in accordance with the behavior specified in the chat template section above. + +**Item type changes** + +```typescript +type ReasoningItem = { + id: string; + type: "reasoning"; + summary: SummaryContent[]; + // new + content: ReasoningTextContent[]; +}; + +type ReasoningTextContent = { + type: "reasoning_text"; + text: string; +}; + +type ReasoningTextDeltaEvent = { + type: "response.reasoning_text.delta"; + sequence_number: number; + item_id: string; + output_index: number; + content_index: number; + delta: string; +}; + +type ReasoningTextDoneEvent = { + type: "response.reasoning_text.done"; + sequence_number: number; + item_id: string; + output_index: number; + content_index: number; + text: string; +}; +``` + +**Event changes** + +```typescript +... +{ + type: "response.content_part.added" + ... +} +{ + type: "response.reasoning_text.delta", + sequence_number: 14, + item_id: "rs_67f47a642e788191aec9b5c1a35ab3c3016f2c95937d6e91", + output_index: 0, + content_index: 0, + delta: "The " +} +... +{ + type: "response.reasoning_text.done", + sequence_number: 18, + item_id: "rs_67f47a642e788191aec9b5c1a35ab3c3016f2c95937d6e91", + output_index: 0, + content_index: 0, + text: "The user asked me to think" +} +``` + +**Example responses output** + +```typescript +"output": [ + { + "type": "reasoning", + "id": "rs_67f47a642e788191aec9b5c1a35ab3c3016f2c95937d6e91", + "summary": [ + { + "type": "summary_text", + "text": "**Calculating volume of gold for Pluto layer**\n\nStarting with the approximation..." + } + ], + "content": [ + { + "type": "reasoning_text", + "text": "The user asked me to think..." + } + ] + } +] + +``` + +## Displaying raw CoT to end-users + +If you are providing a chat interface to users, you should not show the raw CoT because it might contain potentially harmful content or other information that you might not intend to show to users (like, for example, instructions in the developer message). Instead, we recommend showing a summarized CoT, similar to our production implementations in the API or ChatGPT, where a summarizer model reviews and blocks harmful content from being shown. diff --git a/articles/gpt-oss/run-locally-ollama.md b/articles/gpt-oss/run-locally-ollama.md new file mode 100644 index 0000000000..65a5d5fe2b --- /dev/null +++ b/articles/gpt-oss/run-locally-ollama.md @@ -0,0 +1,163 @@ +# How to run gpt-oss locally with Ollama + +Want to get [**OpenAI gpt-oss**](https://openai.com/open-models) running on your own hardware? 
This guide will walk you through how to use [Ollama](https://ollama.ai) to set up **gpt-oss-20b** or **gpt-oss-120b** locally, to chat with it offline, use it through an API, and even connect it to the Agents SDK. + +Note that this guide is meant for consumer hardware, like running a model on a PC or Mac. For server applications with dedicated GPUs like NVIDIA’s H100s, [check out our vLLM guide](https://cookbook.openai.com/articles/gpt-oss/run-vllm). + +## Pick your model + +Ollama supports both model sizes of gpt-oss: + +- **`gpt-oss-20b`** + - The smaller model + - Best with **≥16GB VRAM** or **unified memory** + - Perfect for higher-end consumer GPUs or Apple Silicon Macs +- **`gpt-oss-120b`** + - Our larger full-sized model + - Best with **≥60GB VRAM** or **unified memory** + - Ideal for multi-GPU or beefy workstation setup + +**A couple of notes:** + +- These models ship **MXFP4 quantized** out the box and there is currently no other quantization +- You _can_ offload to CPU if you’re short on VRAM, but expect it to run slower. + +## Quick setup + +1. **Install Ollama** → [Get it here](https://ollama.com/download) +2. **Pull the model you want:** + +```shell +# For 20B +ollama pull gpt-oss:20b + +# For 120B +ollama pull gpt-oss:120b +``` + +## Chat with gpt-oss + +Ready to talk to the model? You can fire up a chat in the app or the terminal: + +```shell +ollama run gpt-oss:20b +``` + +Ollama applies a **chat template** out of the box that mimics the [OpenAI harmony format](https://example.com/harmony-docs). Type your message and start the conversation. + +## Use the API + +Ollama exposes a **Chat Completions-compatible API**, so you can use the OpenAI SDK without changing much. Here’s a Python example: + +```py +from openai import OpenAI + +client = OpenAI( + base_url="http://localhost:11434/v1", # Local Ollama API + api_key="ollama" # Dummy key +) + +response = client.chat.completions.create( + model="gpt-oss:20b", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Explain what MXFP4 quantization is."} + ] +) + +print(response.choices[0].message.content) +``` + +If you’ve used the OpenAI SDK before, this will feel instantly familiar. + +Alternatively, you can use the Ollama SDKs in [Python](https://github.com/ollama/ollama-python) or [JavaScript](https://github.com/ollama/ollama-js) directly. + +## Using tools (function calling) + +Ollama can: + +- Call functions +- Use a **built-in browser tool** (in the app) + +Example of invoking a function via Chat Completions: + +```py +tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get current weather in a given city", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"] + }, + }, + } +] + +response = client.chat.completions.create( + model="gpt-oss:20b", + messages=[{"role": "user", "content": "What's the weather in Berlin right now?"}], + tools=tools +) + +print(response.choices[0].message) +``` + +Since the models can perform tool calling as part of the chain-of-thought (CoT) it’s important for you to return the reasoning returned by the API back into a subsequent call to a tool call where you provide the answer until the model reaches a final answer. + +## Responses API workarounds + +Ollama doesn’t (yet) support the **Responses API** natively. 
+ +If you do want to use the Responses API you can use [**Hugging Face’s `Responses.js` proxy**](https://github.com/huggingface/responses.js) to convert Chat Completions to Responses API. + +For basic use cases you can also [**run our example Python server with Ollama as the backend.**](https://github.com/openai/gpt-oss?tab=readme-ov-file#responses-api) This server is a basic example server and does not have the robustness of a production implementation. + +```shell +pip install gpt-oss +python -m gpt_oss.responses_api.serve \ + --inference_backend=ollama \ + --checkpoint gpt-oss:20b +``` + +## Agents SDK integration + +Want to use gpt-oss with OpenAI’s **Agents SDK**? + +Both Agents SDKs enable you to override the OpenAI base client to point to Ollama via Chat Completions, or to your Responses.js proxy, for your local models. Alternatively, you can use the built-in functionality to point the Agents SDK at third-party models. + +- **Python:** Use [LiteLLM](https://openai.github.io/openai-agents-python/models/litellm/) to proxy to Ollama +- **TypeScript:** Use [AI SDK](https://openai.github.io/openai-agents-js/extensions/ai-sdk/) with the [ollama adapter](https://ai-sdk.dev/providers/community-providers/ollama) + +Here’s a Python Agents SDK example using LiteLLM: + +```py +import asyncio +from agents import Agent, Runner, function_tool, set_tracing_disabled +from agents.extensions.models.litellm_model import LitellmModel + +set_tracing_disabled(True) + +@function_tool +def get_weather(city: str): + print(f"[debug] getting weather for {city}") + return f"The weather in {city} is sunny." + + +async def main(): + agent = Agent( + name="Assistant", + instructions="You only respond in haikus.", + model=LitellmModel(model="ollama/gpt-oss:120b"), + tools=[get_weather], + ) + + result = await Runner.run(agent, "What's the weather in Tokyo?") + print(result.final_output) + +if __name__ == "__main__": + asyncio.run(main()) +``` diff --git a/articles/gpt-oss/run-transformers.md b/articles/gpt-oss/run-transformers.md new file mode 100644 index 0000000000..b6b8016c16 --- /dev/null +++ b/articles/gpt-oss/run-transformers.md @@ -0,0 +1,271 @@ +# How to run gpt-oss with Hugging Face Transformers + +The Transformers library by Hugging Face provides a flexible way to load and run large language models locally or on a server. This guide will walk you through running [OpenAI gpt-oss-20b](https://huggingface.co/openai/gpt-oss-20b) or [OpenAI gpt-oss-120b](https://huggingface.co/openai/gpt-oss-120b) using Transformers, either with a high-level pipeline or via low-level `generate` calls with raw token IDs. + +We'll cover the use of [OpenAI gpt-oss-20b](https://huggingface.co/openai/gpt-oss-20b) or [OpenAI gpt-oss-120b](https://huggingface.co/openai/gpt-oss-120b) with the high-level pipeline abstraction, low-level `generate` calls, and serving models locally with `transformers serve` in a way that is compatible with the Responses API. + +In this guide we’ll run through various optimised ways to run the **gpt-oss** models via Transformers. + +Bonus: You can also fine-tune models via Transformers, [check out our fine-tuning guide here](https://cookbook.openai.com/articles/gpt-oss/fine-tune-transformers).
+ +## Pick your model + +Both **gpt-oss** models are available on Hugging Face: + +- **`openai/gpt-oss-20b`** + - \~16GB VRAM requirement when using MXFP4 + - Great for single high-end consumer GPUs +- **`openai/gpt-oss-120b`** + - Requires ≥60GB VRAM or multi-GPU setup + - Ideal for H100-class hardware + +Both are **MXFP4 quantized** by default. Please note that MXFP4 is supported on Hopper or later architectures. This includes data center GPUs such as H100 or GB200, as well as the latest RTX 50xx family of consumer cards. + +If you use `bfloat16` instead of MXFP4, memory consumption will be larger (\~48 GB for the 20B parameter model). + +## Quick setup + +1. **Install dependencies** + It’s recommended to create a fresh Python environment. Install transformers, accelerate, as well as the Triton kernels for MXFP4 compatibility: + +```bash +pip install -U transformers accelerate torch triton kernels +pip install git+https://github.com/triton-lang/triton.git@main#subdirectory=python/triton_kernels +``` + +2. **(Optional) Enable multi-GPU** + If you’re running large models, use Accelerate or torchrun to handle device mapping automatically. + +## Create an OpenAI Responses / Chat Completions endpoint + +To launch a server, simply use the `transformers serve` CLI command: + +```bash +transformers serve +``` + +The simplest way to interact with the server is through the transformers chat CLI: + +```bash +transformers chat localhost:8000 --model-name-or-path openai/gpt-oss-20b +``` + +or by sending an HTTP request with cURL, e.g. + +```bash +curl -X POST http://localhost:8000/v1/responses -H "Content-Type: application/json" -d '{"messages": [{"role": "system", "content": "hello"}], "temperature": 0.9, "max_tokens": 1000, "stream": true, "model": "openai/gpt-oss-20b"}' +``` + +Additional use cases, like integrating `transformers serve` with Cursor and other tools, are detailed in [the documentation](https://huggingface.co/docs/transformers/main/serving). + +## Quick inference with pipeline + +The easiest way to run the gpt-oss models is with the Transformers high-level `pipeline` API: + +```py +from transformers import pipeline + +generator = pipeline( + "text-generation", + model="openai/gpt-oss-20b", + torch_dtype="auto", + device_map="auto" # Automatically place on available GPUs +) + +messages = [ + {"role": "user", "content": "Explain what MXFP4 quantization is."}, +] + +result = generator( + messages, + max_new_tokens=200, + temperature=1.0, +) + +print(result[0]["generated_text"]) +``` + +## Advanced inference with `.generate()` + +If you want more control, you can load the model and tokenizer manually and invoke the `.generate()` method: + +```py +from transformers import AutoModelForCausalLM, AutoTokenizer + +model_name = "openai/gpt-oss-20b" + +tokenizer = AutoTokenizer.from_pretrained(model_name) +model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype="auto", + device_map="auto" +) + +messages = [ + {"role": "user", "content": "Explain what MXFP4 quantization is."}, +] + +inputs = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + return_tensors="pt", + return_dict=True, +).to(model.device) + +outputs = model.generate( + **inputs, + max_new_tokens=200, + temperature=0.7 +) + +print(tokenizer.decode(outputs[0])) +``` + +## Chat template and tool calling + +OpenAI gpt-oss models use the [harmony response format](https://cookbook.openai.com/article/harmony) for structuring messages, including reasoning and tool calls.
+ +To construct prompts you can use the built-in chat template of Transformers. Alternatively, you can install and use the [openai-harmony library](https://github.com/openai/harmony) for more control. + +To use the chat template: + +```py +from transformers import AutoModelForCausalLM, AutoTokenizer + +model_name = "openai/gpt-oss-20b" + +tokenizer = AutoTokenizer.from_pretrained(model_name) +model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map="auto", + torch_dtype="auto", +) + +messages = [ + {"role": "system", "content": "Always respond in riddles"}, + {"role": "user", "content": "What is the weather like in Madrid?"}, +] + +inputs = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + return_tensors="pt", + return_dict=True, +).to(model.device) + +generated = model.generate(**inputs, max_new_tokens=100) +print(tokenizer.decode(generated[0][inputs["input_ids"].shape[-1] :])) +``` + +To integrate the [`openai-harmony`](https://github.com/openai/harmony) library to prepare prompts and parse responses, first install it like this: + +```bash +pip install openai-harmony +``` + +Here’s an example of how to use the library to build your prompts and encode them to tokens: + +```py +import json +from openai_harmony import ( + HarmonyEncodingName, + load_harmony_encoding, + Conversation, + Message, + Role, + SystemContent, + DeveloperContent +) +from transformers import AutoModelForCausalLM, AutoTokenizer + +encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) + +# Build conversation +convo = Conversation.from_messages([ + Message.from_role_and_content(Role.SYSTEM, SystemContent.new()), + Message.from_role_and_content( + Role.DEVELOPER, + DeveloperContent.new().with_instructions("Always respond in riddles") + ), + Message.from_role_and_content(Role.USER, "What is the weather like in SF?") +]) + +# Render prompt +prefill_ids = encoding.render_conversation_for_completion(convo, Role.ASSISTANT) +stop_token_ids = encoding.stop_tokens_for_assistant_action() + +# Load model +model_name = "openai/gpt-oss-20b" +tokenizer = AutoTokenizer.from_pretrained(model_name) +model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto") + +# Generate +outputs = model.generate( + input_ids=[prefill_ids], + max_new_tokens=128, + eos_token_id=stop_token_ids +) + +# Parse completion tokens +completion_ids = outputs[0][len(prefill_ids):] +entries = encoding.parse_messages_from_completion_tokens(completion_ids, Role.ASSISTANT) + +for message in entries: + print(json.dumps(message.to_dict(), indent=2)) +``` + +Note that the `Developer` role in Harmony maps to the `system` prompt in the chat template. + +## Multi-GPU & distributed inference + +The large gpt-oss-120b fits on a single H100 GPU when using MXFP4. 
If you want to run it on multiple GPUs, you can: + +- Use `tp_plan="auto"` for automatic placement and tensor parallelism +- Launch with `accelerate launch` or `torchrun` for distributed setups +- Leverage Expert Parallelism +- Use specialised Flash attention kernels for faster inference + +```py +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.distributed import DistributedConfig +import torch + +model_path = "openai/gpt-oss-120b" +tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="left") + +device_map = { + # Enable Expert Parallelism + "distributed_config": DistributedConfig(enable_expert_parallel=1), + # Enable Tensor Parallelism + "tp_plan": "auto", +} + +model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype="auto", + attn_implementation="kernels-community/vllm-flash-attn3", + **device_map, +) + +messages = [ + {"role": "user", "content": "Explain how expert parallelism works in large language models."} +] + +inputs = tokenizer.apply_chat_template( + messages, + add_generation_prompt=True, + return_tensors="pt", + return_dict=True, +).to(model.device) + +outputs = model.generate(**inputs, max_new_tokens=1000) + +# Decode and print +response = tokenizer.decode(outputs[0]) +print("Model response:", response.split("<|channel|>final<|message|>")[-1].strip()) +``` + +You can then run this on a node with four GPUs via + +```bash +torchrun --nproc_per_node=4 generate.py +``` diff --git a/articles/gpt-oss/run-vllm.md b/articles/gpt-oss/run-vllm.md new file mode 100644 index 0000000000..fc70d119f0 --- /dev/null +++ b/articles/gpt-oss/run-vllm.md @@ -0,0 +1,230 @@ +# How to run gpt-oss with vLLM + +[vLLM](https://docs.vllm.ai/en/latest/) is an open-source, high-throughput inference engine designed to efficiently serve large language models (LLMs) by optimizing memory usage and processing speed. This guide will walk you through how to use vLLM to set up **gpt-oss-20b** or **gpt-oss-120b** on a server to serve gpt-oss as an API for your applications, and even connect it to the Agents SDK. + +Note that this guide is meant for server applications with dedicated GPUs like NVIDIA’s H100s. For local inference on consumer GPUs, [check out our Ollama guide](https://cookbook.openai.com/articles/gpt-oss/run-locally-ollama). + +## Pick your model + +vLLM supports both model sizes of gpt-oss: + +- [**`openai/gpt-oss-20b`**](https://huggingface.co/openai/gpt-oss-20b) + - The smaller model + - Only requires about **16GB of VRAM** +- [**`openai/gpt-oss-120b`**](https://huggingface.co/openai/gpt-oss-120b) + - Our larger full-sized model + - Best with **≥60GB VRAM** + - Can fit on a single H100 or on multi-GPU setups + +Both models are **MXFP4 quantized** out of the box. + +## Quick Setup + +1. **Install vLLM** + vLLM recommends using [uv](https://docs.astral.sh/uv/) to manage your Python environment. This will help with picking the right implementation based on your environment. [Learn more in their quickstart](https://docs.vllm.ai/en/latest/getting_started/quickstart.html#installation). To create a new virtual environment and install vLLM run: + +```shell +uv venv --python 3.12 --seed +source .venv/bin/activate +uv pip install vllm --torch-backend=auto +``` + +2. **Start up a server and download the model** + vLLM provides a `serve` command that will automatically download the model from HuggingFace and spin up an OpenAI-compatible server on `localhost:8000`.
Run the following command depending on your desired model size in a terminal session on your server. + +```shell +# For 20B +vllm serve openai/gpt-oss-20b + +# For 120B +vllm serve openai/gpt-oss-120b +``` + +## Use the API + +vLLM exposes a **Chat Completions-compatible API** and a **Responses-compatible API** so you can use the OpenAI SDK without changing much. Here’s a Python example: + +```py +from openai import OpenAI + +client = OpenAI( + base_url="http://localhost:8000/v1", + api_key="EMPTY" +) + +result = client.chat.completions.create( + model="openai/gpt-oss-20b", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Explain what MXFP4 quantization is."} + ] +) + +print(result.choices[0].message.content) + +response = client.responses.create( + model="openai/gpt-oss-120b", + instructions="You are a helpful assistant.", + input="Explain what MXFP4 quantization is." +) + +print(response.output_text) +``` + +If you’ve used the OpenAI SDK before, this will feel instantly familiar and your existing code should work by changing the base URL. + +## Using tools (function calling) + +vLLM supports function calling and giving the model browsing capabilities. + +Function calling works through both the Responses and Chat Completions APIs. + +Example of invoking a function via Chat Completions: + +```py +tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get current weather in a given city", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"] + }, + }, + } +] + +response = client.chat.completions.create( + model="openai/gpt-oss-120b", + messages=[{"role": "user", "content": "What's the weather in Berlin right now?"}], + tools=tools +) + +print(response.choices[0].message) +``` + +Since the models can perform tool calling as part of their chain-of-thought (CoT), it’s important to pass the reasoning returned by the API back in with each subsequent request, together with the tool results, until the model reaches a final answer. + +## Agents SDK Integration + +Want to use gpt-oss with OpenAI’s **Agents SDK**? + +Both Agents SDKs enable you to override the OpenAI base client to point to vLLM for your self-hosted models. Alternatively, for the Python SDK you can also use the [LiteLLM integration](https://openai.github.io/openai-agents-python/models/litellm/) to proxy to vLLM. + +Here’s a Python Agents SDK example: + +``` +uv pip install openai-agents +``` + +```py +import asyncio +from openai import AsyncOpenAI +from agents import Agent, Runner, function_tool, OpenAIResponsesModel, set_tracing_disabled + +set_tracing_disabled(True) + +@function_tool +def get_weather(city: str): + print(f"[debug] getting weather for {city}") + return f"The weather in {city} is sunny." + + +async def main(): + agent = Agent( + name="Assistant", + instructions="You only respond in haikus.", + model=OpenAIResponsesModel( + model="openai/gpt-oss-120b", + openai_client=AsyncOpenAI( + base_url="http://localhost:8000/v1", + api_key="EMPTY", + ), + ), + tools=[get_weather], + ) + + result = await Runner.run(agent, "What's the weather in Tokyo?") + print(result.final_output) + +if __name__ == "__main__": + asyncio.run(main()) +``` + +## Using vLLM for direct sampling + +Aside from running vLLM using `vllm serve` as an API server, you can use the vLLM Python library to control inference directly.
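+ +If you just want chat-style completions without managing tokens yourself, the minimal sketch below relies on the model’s built-in chat template, which renders the harmony format for you; it assumes a recent vLLM release that exposes the `LLM.chat` helper: + +```py +from vllm import LLM, SamplingParams + +# Load the model; vLLM downloads the weights on first use +llm = LLM(model="openai/gpt-oss-20b") + +messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Explain what MXFP4 quantization is."}, +] + +# chat() applies the model's chat template before sampling +outputs = llm.chat(messages, SamplingParams(max_tokens=256, temperature=1.0)) +print(outputs[0].outputs[0].text) +```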
+ +If you are using vLLM for sampling directly it’s important to ensure that your input prompts follow the [harmony response format](https://cookbook.openai.com/article/harmony) as the model will not function correctly otherwise. You can use the [`openai-harmony` SDK](https://github.com/openai/harmony) for this. + +``` +uv pip install openai-harmony +``` + +Afterwards you can use harmony to encode and parse the tokens generated by vLLM’s generate function. + +```py +import json +from openai_harmony import ( + HarmonyEncodingName, + load_harmony_encoding, + Conversation, + Message, + Role, + SystemContent, + DeveloperContent, +) + +from vllm import LLM, SamplingParams + +# --- 1) Render the prefill with Harmony --- +encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) + +convo = Conversation.from_messages( + [ + Message.from_role_and_content(Role.SYSTEM, SystemContent.new()), + Message.from_role_and_content( + Role.DEVELOPER, + DeveloperContent.new().with_instructions("Always respond in riddles"), + ), + Message.from_role_and_content(Role.USER, "What is the weather like in SF?"), + ] +) + +prefill_ids = encoding.render_conversation_for_completion(convo, Role.ASSISTANT) + +# Harmony stop tokens (pass to sampler so they won't be included in output) +stop_token_ids = encoding.stop_tokens_for_assistant_action() + +# --- 2) Run vLLM with prefill --- +llm = LLM( + model="openai/gpt-oss-120b", + trust_remote_code=True, +) + +sampling = SamplingParams( + max_tokens=128, + temperature=1, + stop_token_ids=stop_token_ids, +) + +outputs = llm.generate( + prompt_token_ids=[prefill_ids], # batch of size 1 + sampling_params=sampling, +) + +# vLLM gives you both text and token IDs +gen = outputs[0].outputs[0] +text = gen.text +output_tokens = gen.token_ids # <-- these are the completion token IDs (no prefill) + +# --- 3) Parse the completion token IDs back into structured Harmony messages --- +entries = encoding.parse_messages_from_completion_tokens(output_tokens, Role.ASSISTANT) + +# 'entries' is a sequence of structured conversation entries (assistant messages, tool calls, etc.). +for message in entries: + print(f"{json.dumps(message.to_dict())}") +``` diff --git a/articles/openai-harmony.md b/articles/openai-harmony.md new file mode 100644 index 0000000000..2942e998ce --- /dev/null +++ b/articles/openai-harmony.md @@ -0,0 +1,573 @@ +# OpenAI harmony response format + +The [`gpt-oss` models](https://openai.com/open-models) were trained on the harmony response format for defining conversation structures, generating reasoning output and structuring function calls. If you are not using `gpt-oss` directly but through an API or a provider like Ollama, you will not have to be concerned about this as your inference solution will handle the formatting. If you are building your own inference solution, this guide will walk you through the prompt format. The format is designed to mimic the OpenAI Responses API, so if you have used that API before, this format should hopefully feel familiar to you. `gpt-oss` should not be used without using the harmony format, as it will not work correctly. + +## Concepts + +### Roles + +Every message that the model processes has a role associated with it. 
The model knows about five types of roles: + +| Role | Purpose | +| :---------- | :---------- | +| `system` | A system message is used to specify reasoning effort, meta information like knowledge cutoff and built-in tools | +| `developer` | The developer message is used to provide information about the instructions for the model (what is normally considered the “system prompt”) and available function tools | +| `user` | Typically representing the input to the model | +| `assistant` | Output by the model which can either be a tool call or a message output. The output might also be associated with a particular “channel” identifying what the intent of the message is. | +| `tool` | Messages representing the output of a tool call. The specific tool name will be used as the role inside a message. | + +These roles also represent the information hierarchy that the model applies in case there are any instruction conflicts: `system` \> `developer` \> `user` \> `assistant` \> `tool` + +#### Channels + +Assistant messages can be output in three different “channels”. These are used to separate user-facing responses from internal-facing messages. + +| Channel | Purpose | +| :----------- | :---------- | +| `final` | Messages tagged in the final channel are messages intended to be shown to the end-user and represent the responses from the model. | +| `analysis` | These are messages that are being used by the model for its chain of thought (CoT). **Important:** Messages in the analysis channel do not adhere to the same safety standards as final messages do. Avoid showing these to end-users. | +| `commentary` | Any function tool call will typically be triggered on the `commentary` channel while built-in tools will normally be triggered on the `analysis` channel. However, occasionally built-in tools will still be output to `commentary`. Occasionally this channel might also be used by the model to generate a [preamble](#preambles) to calling multiple functions. | + +## Harmony renderer library + +We recommend using our harmony renderer through [PyPI](https://pypi.org/project/openai-harmony/) or [crates.io](https://crates.io/crates/openai-harmony) when possible as it will automatically handle rendering your messages in the right format and turning them into tokens for processing by the model. + +Below is an example of using the renderer to construct a system prompt and a short conversation. + +```py +from openai_harmony import ( + Author, + Conversation, + DeveloperContent, + HarmonyEncodingName, + Message, + Role, + SystemContent, + ToolDescription, + load_harmony_encoding, + ReasoningEffort +) + +encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS) + +system_message = ( + SystemContent.new() + .with_model_identity( + "You are ChatGPT, a large language model trained by OpenAI."
+    )
+    .with_reasoning_effort(ReasoningEffort.HIGH)
+    .with_conversation_start_date("2025-06-28")
+    .with_knowledge_cutoff("2024-06")
+    .with_required_channels(["analysis", "commentary", "final"])
+)
+
+developer_message = (
+    DeveloperContent.new()
+    .with_instructions("Always respond in riddles")
+    .with_function_tools(
+        [
+            ToolDescription.new(
+                "get_location",
+                "Gets the location of the user.",
+            ),
+            ToolDescription.new(
+                "get_current_weather",
+                "Gets the current weather in the provided location.",
+                parameters={
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "format": {
+                            "type": "string",
+                            "enum": ["celsius", "fahrenheit"],
+                            "default": "celsius",
+                        },
+                    },
+                    "required": ["location"],
+                },
+            ),
+        ]
+    )
+)
+
+convo = Conversation.from_messages(
+    [
+        Message.from_role_and_content(Role.SYSTEM, system_message),
+        Message.from_role_and_content(Role.DEVELOPER, developer_message),
+        Message.from_role_and_content(Role.USER, "What is the weather in Tokyo?"),
+        Message.from_role_and_content(
+            Role.ASSISTANT,
+            'User asks: "What is the weather in Tokyo?" We need to use the get_current_weather tool.',
+        ).with_channel("analysis"),
+        Message.from_role_and_content(Role.ASSISTANT, '{"location": "Tokyo"}')
+        .with_channel("commentary")
+        .with_recipient("functions.get_current_weather")
+        .with_content_type("json"),
+        Message.from_author_and_content(
+            Author.new(Role.TOOL, "functions.get_current_weather"),
+            '{ "temperature": 20, "sunny": true }',
+        ).with_recipient("assistant").with_channel("commentary"),
+    ]
+)
+
+tokens = encoding.render_conversation_for_completion(convo, Role.ASSISTANT)
+
+# After receiving the completion tokens from your inference run,
+# parse them back into messages (do not pass in the stop token)
+parsed_response = encoding.parse_messages_from_completion_tokens(new_tokens, Role.ASSISTANT)
+```
+
+Additionally, the `openai_harmony` library includes a `StreamableParser` for parsing and decoding tokens as the model generates them. This can be helpful, for example, for streaming output and handling unicode characters during decoding.
+
+```py
+from openai_harmony import (
+    HarmonyEncodingName,
+    Role,
+    StreamableParser,
+    load_harmony_encoding,
+)
+
+encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
+stream = StreamableParser(encoding, role=Role.ASSISTANT)
+
+tokens = [
+    200005,35644,200008,1844,31064,25,392,4827,382,220,17,659,220,17,16842,12295,81645,
+    13,51441,6052,13,200007,200006,173781,200005,17196,200008,17,659,220,17,314,220,19,
+    13,200002
+]
+
+for token in tokens:
+    stream.process(token)
+    print("--------------------------------")
+    print("current_role", stream.current_role)
+    print("current_channel", stream.current_channel)
+    print("last_content_delta", stream.last_content_delta)
+    print("current_content_type", stream.current_content_type)
+    print("current_recipient", stream.current_recipient)
+    print("current_content", stream.current_content)
+```
+
+## Prompt format
+
+If you choose to build your own renderer, you’ll need to adhere to the following format.
+
+### Special Tokens
+
+The model uses a set of special tokens to identify the structure of your input. If you are using [tiktoken](https://github.com/openai/tiktoken), these tokens are encoded in the `o200k_harmony` encoding. All special tokens follow the format `<|type|>`.
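+
+If you want to verify the token IDs on your side, the minimal sketch below (assuming your installed tiktoken release ships the `o200k_harmony` encoding) encodes a small harmony snippet; the IDs it prints should match the table that follows.
+
+```py
+import tiktoken
+
+# Load the harmony encoding (assumes a tiktoken release that includes o200k_harmony)
+enc = tiktoken.get_encoding("o200k_harmony")
+
+# Special tokens are rejected by default, so they have to be explicitly allowed
+text = "<|start|>user<|message|>What is 2 + 2?<|end|>"
+tokens = enc.encode(text, allowed_special="all")
+
+print(tokens[0])           # 200006, the ID of <|start|> per the table below
+print(enc.decode(tokens))  # round-trips back to the original string
+```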
+
+| Special token | Purpose | Token ID |
+| :------------ | :------ | :------- |
+| `<\|start\|>` | Indicates the beginning of a [message](#message-format). Followed by the “header” information of a message, starting with the [role](#roles) | `200006` |
+| `<\|end\|>` | Indicates the end of a [message](#message-format) | `200007` |
+| `<\|message\|>` | Indicates the transition from the message “header” to the actual content | `200008` |
+| `<\|channel\|>` | Indicates the transition to the [channel](#channels) information of the header | `200005` |
+| `<\|constrain\|>` | Indicates the transition to the data type definition in a [tool call](#receiving-tool-calls) | `200003` |
+| `<\|return\|>` | Indicates the model is done with sampling the response message. A valid “stop token”, indicating that you should stop inference. | `200002` |
+| `<\|call\|>` | Indicates the model wants to call a tool. A valid “stop token”, indicating that you should stop inference. | `200012` |
+
+### Message format
+
+The harmony response format consists of “messages”, with the model potentially generating multiple messages in one go. The general structure of a message is as follows:
+
+```
+<|start|>{header}<|message|>{content}<|end|>
+```
+
+The `{header}` contains a series of meta information, including the [role](#roles). `<|end|>` represents the end of a fully completed message, but the model might also use other stop tokens, such as `<|call|>` for tool calling and `<|return|>` to indicate that it is done with the completion.
+
+### Chat conversation format
+
+Following the message format above, the most basic chat format consists of a `user` message and the beginning of an `assistant` message.
+
+#### Example input
+
+```
+<|start|>user<|message|>What is 2 + 2?<|end|>
+<|start|>assistant
+```
+
+The output will begin by specifying the `channel`, for example `analysis` to output the chain of thought. The model might output multiple messages (primarily chain-of-thought messages), which it separates with the `<|end|>` token.
+
+Once it’s done generating, it will stop with either a `<|return|>` token, indicating it’s done generating the final answer, or a `<|call|>` token, indicating that a tool call needs to be performed. Either way, this means you should stop inference.
+
+#### Example output
+
+```
+<|channel|>analysis<|message|>User asks: "What is 2 + 2?" Simple arithmetic. Provide answer.<|end|>
+<|start|>assistant<|channel|>final<|message|>2 + 2 = 4.<|return|>
+```
+
+The `final` channel will contain the answer to your user’s request. Check out the [reasoning section](#reasoning) for more details on the chain-of-thought.
+
+### System message format
+
+The system message is used to provide general information to the model. This is different from what is commonly considered the “system prompt” in other prompt formats. For that, check out the [developer message format](#developer-message-format).
+
+We use the system message to define:
+
+1. The **identity** of the model — This should always stay as `You are ChatGPT, a large language model trained by OpenAI.` If you want to change the identity of the model, use the instructions in the [developer message](#developer-message-format).
+2. Meta **dates** — Specifically the `Knowledge cutoff:` and the `Current date:`
+3. The **reasoning effort** — As specified on the levels `high`, `medium`, `low`
+4. 
Available channels — For the best performance this should map to `analysis`, `commentary`, and `final`. +5. Built-in tools — The model has been trained on both a `python` and `browser` tool. Check out the [built-in tools section](#built-in-tools) for details. + +**If you are defining functions,** it should also contain a note that all function tool calls must go to the `commentary` channel. + +For the best performance stick to this format as closely as possible. + +#### Example system message + +The most basic system message you should use is the following: + +``` +<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI. +Knowledge cutoff: 2024-06 +Current date: 2025-06-28 + +Reasoning: high + +# Valid channels: analysis, commentary, final. Channel must be included for every message. +Calls to these tools must go to the commentary channel: 'functions'.<|end|> +``` + +### Developer message format + +The developer message represents what is commonly considered the “system prompt”. It contains the instructions that are provided to the model and optionally a list of [function tools](#function-calling) available for use or the output format you want the model to adhere to for [structured outputs](#structured-output). + +If you are not using function tool calling your developer message would just look like this: + +``` +<|start|>developer<|message|># Instructions + +{instructions}<|end|> +``` + +Where `{instructions}` is replaced with your “system prompt”. + +For defining function calling tools, [check out the dedicated section](#function-calling). +For defining an output format to be used in structured outputs, [check out this section of the guide](#structured-output). + +### Reasoning + +The gpt-oss models are reasoning models. By default, the model will do medium level reasoning. To control the reasoning you can specify in the [system message](#system-message-format) the reasoning level as `low`, `medium`, or `high`. The recommended format is: + +``` +Reasoning: high +``` + +The model will output its raw chain-of-thought (CoT) as assistant messages into the `analysis` channel while the final response will be output as `final`. + +For example for the question `What is 2 + 2?` the model output might look like this: + +``` +<|channel|>analysis<|message|>User asks: "What is 2 + 2?" Simple arithmetic. Provide answer.<|end|> +<|start|>assistant<|channel|>final<|message|>2 + 2 = 4.<|return|> +``` + +In this case the CoT is + +``` +User asks: “What is 2 + 2?” Simple arithmetic. Provide answer. +``` + +And the actual answer is: + +``` +2 + 2 = 4 +``` + +**Important:** +The model has not been trained to the same safety standards in the chain-of-thought as it has for final output. You should not show the chain-of-thought to your users, as they might contain harmful content. [Learn more in the model card](https://openai.com/index/gpt-oss-model-card/). + +#### Handling reasoning output in subsequent sampling + +In general, you should drop any previous CoT content on subsequent sampling if the responses by the assistant ended in a message to the `final` channel. Meaning if our first input was this: + +``` +<|start|>user<|message|>What is 2 + 2?<|end|> +<|start|>assistant +``` + +and resulted in the output: + +``` +<|channel|>analysis<|message|>User asks: "What is 2 + 2?" Simple arithmetic. 
Provide answer.<|end|> +<|start|>assistant<|channel|>final<|message|>2 + 2 = 4.<|return|> +``` + +For the model to work properly, the input for the next sampling should be + +``` +<|start|>user<|message|>What is 2 + 2?<|end|> +<|start|>assistant<|channel|>final<|message|>2 + 2 = 4.<|return|> +<|start|>user<|message|>What about 9 / 2?<|end|> +<|start|>assistant +``` + +The exception for this is tool/function calling. The model is able to call tools as part of its chain-of-thought and because of that, we should pass the previous chain-of-thought back in as input for subsequent sampling. Check out the [function calling section](#function-calling) for a complete example. + +### Function calling + +#### Defining available tools + +All functions that are available to the model should be defined in the [developer message](#developer-message-format) in a dedicated `Tools` section. + +To define the functions we use a TypeScript-like type syntax and wrap the functions into a dedicated `functions` namespace. It’s important to stick to this format closely to improve accuracy of function calling. You can check out the harmony renderer codebase for more information on how we are turning JSON schema definitions for the arguments into this format but some general formatting practices: + +- Define every function as a `type {function_name} = () => any` if it does not receive any arguments +- For functions that receive an argument name the argument `_` and inline the type definition +- Add comments for descriptions in the line above the field definition +- Always use `any` as the return type +- Keep an empty line after each function definition +- Wrap your functions into a namespace, generally `functions` is the namespace you should use to not conflict with [other tools](#built-in-tools) that the model might have been trained on. + +Here’s a complete input example including the definition of two functions: + +``` +<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI. +Knowledge cutoff: 2024-06 +Current date: 2025-06-28 + +Reasoning: high + +# Valid channels: analysis, commentary, final. Channel must be included for every message. +Calls to these tools must go to the commentary channel: 'functions'.<|end|><|start|>developer<|message|># Instructions + +Use a friendly tone. + +# Tools + +## functions + +namespace functions { + +// Gets the location of the user. +type get_location = () => any; + +// Gets the current weather in the provided location. +type get_current_weather = (_: { +// The city and state, e.g. San Francisco, CA +location: string, +format?: "celsius" | "fahrenheit", // default: celsius +}) => any; + +// Gets the current weather in the provided list of locations. +type get_multiple_weathers = (_: { +// List of city and state, e.g. ["San Francisco, CA", "New York, NY"] +locations: string[], +format?: "celsius" | "fahrenheit", // default: celsius +}) => any; + +} // namespace functions<|end|><|start|>user<|message|>What is the weather like in SF?<|end|><|start|>assistant +``` + +#### Receiving tool calls + +If the model decides to call a tool it will define a `recipient` in the header of the message using the format `to={name}`. For example, if it decides to trigger the `get_current_weather` function from above it would specify `to=functions.get_current_weather` in the header and `commentary` as the channel as specified in the [system message](#system-message-format). 
**The recipient might be defined in the role or channel section of the header.** + +The model might also specify a `<|constrain|>` token to indicate the type of input for the tool call. In this case since it’s being passed in as JSON the `<|constrain|>` is set to `json`. + +``` +<|channel|>analysis<|message|>Need to use function get_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_weather <|constrain|>json<|message|>{"location":"San Francisco"}<|call|> +``` + +#### Handling tool calls + +After the function call was handled we need to provide the output back to the model by specifying a new tool message with the output after the call message. + +A tool message has the following format: + +``` +<|start|>{toolname} to=assistant<|channel|>commentary<|message|>{output}<|end|> +``` + +So in our example above + +``` +<|start|>functions.get_weather to=assistant<|channel|>commentary<|message|>{"sunny": true, "temperature": 20}<|end|> +``` + +Once you have gathered the output for the tool calls you can run inference with the complete content: + +``` +<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI. +Knowledge cutoff: 2024-06 +Current date: 2025-06-28 + +Reasoning: high + +# Valid channels: analysis, commentary, final. Channel must be included for every message. +Calls to these tools must go to the commentary channel: 'functions'.<|end|><|start|>developer<|message|># Instructions + +Use a friendly tone. + +# Tools + +## functions + +namespace functions { + +// Gets the location of the user. +type get_location = () => any; + +// Gets the current weather in the provided location. +type get_current_weather = (_: { +// The city and state, e.g. San Francisco, CA +location: string, +format?: "celsius" | "fahrenheit", // default: celsius +}) => any; + +// Gets the current weather in the provided list of locations. +type get_multiple_weathers = (_: { +// List of city and state, e.g. ["San Francisco, CA", "New York, NY"] +locations: string[], +format?: "celsius" | "fahrenheit", // default: celsius +}) => any; + +} // namespace functions<|end|><|start|>user<|message|>What is the weather like in SF?<|end|><|start|>assistant<|channel|>analysis<|message|>Need to use function get_weather.<|end|><|start|>assistant<|channel|>commentary to=functions.get_weather <|constrain|>json<|message|>{"location":"San Francisco"}<|call|> <|start|>functions.get_weather to=assistant<|channel|>commentary<|message|>{"sunny": true, "temperature": 20}<|end|><|start|>assistant +``` + +As you can see above we are passing not just the function out back into the model for further sampling but also the previous chain-of-thought (“Need to use function get_weather.”) to provide the model with the necessary information to continue its chain-of-thought or provide the final answer. + +#### Preambles + +At times the model might choose to generate a “preamble” to inform the user about the tools it is about to call. For example, when it plans to call multiple tools. If this is the case it will generate an assistant message on the `commentary` channel that, unlike the chain-of-thought, is intended to be shown to the end-user. + +``` +<|channel|>analysis<|message|>{long chain of thought}<|end|><|start|>assistant<|channel|>commentary<|message|>**Action plan**: +1. Generate an HTML file +2. Generate a JavaScript for the Node.js server +3. 
Start the server
+---
+Will start executing the plan step by step<|end|><|start|>assistant<|channel|>commentary to=functions.generate_file<|constrain|>json<|message|>{"template": "basic_html", "path": "index.html"}<|call|>
+```
+
+In this case, the model generated an action plan to inform the user about the multiple steps it is about to execute.
+
+### Structured output
+
+To control the output behavior of the model, you can define a response format at the end of the [developer message](#developer-message-format) with the following structure:
+
+```
+# Response Formats
+
+## {format name}
+
+// {description or context}
+{schema}<|end|>
+```
+
+The format name functions similarly to the name you can specify for your schema in the [Responses API](https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#how-to-use), and the schema is a JSON Schema.
+
+As an example, here’s a developer message that defines a schema for a shopping list:
+
+```
+<|start|>developer<|message|># Instructions
+
+You are a helpful shopping assistant
+
+# Response Formats
+
+## shopping_list
+
+{"properties":{"items":{"type":"array","description":"entries on the shopping list","items":{"type":"string"}}},"type":"object"}<|end|><|start|>user<|message|>I need to buy coffee, soda and eggs<|end|><|start|>assistant
+
+```
+
+This prompt alone will, however, only influence the model’s behavior; it doesn’t guarantee full adherence to the schema. To guarantee that, you still need to construct your own grammar and enforce the schema during sampling.
+
+### Built-in tools
+
+The `gpt-oss` models were trained with two common built-in tools, one to browse for information and one to execute Python code, to improve their results.
+
+If you are trying to build this functionality, you should use the format below to improve reliability and accuracy.
+
+These tools should be defined in the [system message](#system-message-format), not in the developer message, by adding a `# Tools` section.
+
+#### Browser tool
+
+To define the browser tool, add it to the `# Tools` section of the system message:
+
+```
+<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.
+Knowledge cutoff: 2024-06
+Current date: 2025-06-28
+
+Reasoning: high
+
+# Tools
+
+## browser
+
+// Tool for browsing.
+// The `cursor` appears in brackets before each browsing display: `[{cursor}]`.
+// Cite information from the tool using the following format:
+// `【{cursor}†L{line_start}(-L{line_end})?】`, for example: `【6†L9-L11】` or `【8†L3】`.
+// Do not quote more than 10 words directly from the tool output.
+// sources=web (default: web)
+namespace browser {
+
+// Searches for information related to `query` and displays `topn` results.
+type search = (_: {
+query: string,
+topn?: number, // default: 10
+source?: string,
+}) => any;
+
+// Opens the link `id` from the page indicated by `cursor` starting at line number `loc`, showing `num_lines` lines.
+// Valid link ids are displayed with the formatting: `【{id}†.*】`.
+// If `cursor` is not provided, the most recent page is implied.
+// If `id` is a string, it is treated as a fully qualified URL associated with `source`.
+// If `loc` is not provided, the viewport will be positioned at the beginning of the document or centered on the most relevant passage, if available.
+// Use this function without `id` to scroll to a new location of an opened page.
+type open = (_: {
+id?: number | string, // default: -1
+cursor?: number, // default: -1
+loc?: number, // default: -1
+num_lines?: number, // default: -1
+view_source?: boolean, // default: false
+source?: string,
+}) => any;
+
+// Finds exact matches of `pattern` in the current page, or the page given by `cursor`.
+type find = (_: {
+pattern: string,
+cursor?: number, // default: -1
+}) => any;
+
+} // namespace browser
+
+# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|>
+```
+
+If the model decides to call actions in the browser, it will use the same format as for [function calls](#function-calling), with two notable exceptions:
+
+1. Requests will be made to the `analysis` channel
+2. The recipient will be `browser.search`, `browser.open`, or `browser.find`, respectively
+
+#### Python tool
+
+To define the python tool, add it to the `# Tools` section of the system message:
+
+```
+<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.
+Knowledge cutoff: 2024-06
+Current date: 2025-06-28
+
+Reasoning: high
+
+# Tools
+
+## python
+
+Use this tool to execute Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files).
+
+When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 120.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is UNKNOWN. Depends on the cluster.
+
+# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|>
+```
+
+If the model decides to call the python tool, it will use the same format as for [function calls](#function-calling), with two notable exceptions:
+
+1. Requests will be made to the `analysis` channel
+2. The recipient will always be `python`
diff --git a/authors.yaml b/authors.yaml
index d84aa18f87..8f72503ba8 100644
--- a/authors.yaml
+++ b/authors.yaml
@@ -396,3 +396,23 @@ daisyshe-oai:
   name: "Daisy Sheng"
   website: "https://www.linkedin.com/in/daisysheng/"
   avatar: "https://avatars.githubusercontent.com/u/212609991?v=4"
+
+dkundel-openai:
+  name: "Dominik Kundel"
+  website: "https://www.linkedin.com/in/dominik-kundel/"
+  avatar: "https://avatars.githubusercontent.com/u/200841172?v=4"
+
+edbeeching:
+  name: "Edward Beeching"
+  website: "https://huggingface.co/edbeeching"
+  avatar: "https://cdn-avatars.huggingface.co/v1/production/uploads/1644220542819-noauth.jpeg"
+
+qgallouedec:
+  name: "Quentin Gallouédec"
+  website: "https://huggingface.co/qgallouedec"
+  avatar: "https://cdn-avatars.huggingface.co/v1/production/uploads/1677431596830-631ce4b244503b72277fc89f.jpeg"
+
+lewtun:
+  name: "Lewis Tunstall"
+  website: "https://huggingface.co/lewtun"
+  avatar: "https://cdn-avatars.huggingface.co/v1/production/uploads/1594651707950-noauth.jpeg"
diff --git a/registry.yaml b/registry.yaml
index 7e9cf0b1b9..d965bb9052 100644
--- a/registry.yaml
+++ b/registry.yaml
@@ -4,6 +4,63 @@
 # should build pages for, and indicates metadata such as tags, creation date and
 # authors for each page.
+- title: Fine-tuning with gpt-oss and Hugging Face Transformers + path: articles/gpt-oss/fine-tune-transfomers.ipynb + date: 2025-08-05 + authors: + - edbeeching + - qgallouedec + - lewtun + tags: + - open-models + - gpt-oss + +- title: How to handle the raw chain of thought in gpt-oss + path: articles/gpt-oss/handle-raw-cot.md + date: 2025-08-05 + authors: + - dkundel-openai + tags: + - open-models + - gpt-oss + +- title: How to run gpt-oss with Transformers + path: articles/gpt-oss/run-transformers.md + date: 2025-08-05 + authors: + - dkundel-openai + tags: + - open-models + - gpt-oss + +- title: How to run gpt-oss with vLLM + path: articles/gpt-oss/run-vllm.md + date: 2025-08-05 + authors: + - dkundel-openai + tags: + - open-models + - gpt-oss + +- title: How to run gpt-oss locally with Ollama + path: articles/gpt-oss/run-locally-ollama.md + date: 2025-08-05 + authors: + - dkundel-openai + tags: + - open-models + - gpt-oss + +- title: OpenAI Harmony Response Format + path: articles/openai-harmony.md + date: 2025-08-05 + authors: + - dkundel-openai + tags: + - open-models + - gpt-oss + - harmony + - title: Temporal Agents with Knowledge Graphs path: examples/partners/temporal_agents_with_knowledge_graphs/temporal_agents_with_knowledge_graphs.ipynb date: 2025-07-22 @@ -18,7 +75,6 @@ - temporal-agents - RAG - - title: Using Evals API on Image Inputs path: examples/evaluation/use-cases/EvalsAPI_Image_Inputs.ipynb date: 2025-07-15