In [1]:
import torch
from transformers import pipeline

# Checkpoint identifier handed to `pipeline` as its `model` argument.
model_id = "RekaAI/reka-flash-3"

# Options for the text-generation pipeline: bfloat16 is requested as the torch
# dtype, and device placement is left to the library (device_map="auto").
pipeline_options = {
    "model": model_id,
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
}

# Shared generator used by the cells below.
pipe = pipeline("text-generation", **pipeline_options)


  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 5/5 [00:33<00:00,  6.77s/it]
Device set to use cuda:0


In [None]:
# Tool (function-calling) schemas that are advertised to the model.


def _function_tool(name, description, parameters):
    """Wrap a parameter schema in a ``{"type": "function", ...}`` tool entry.

    Args:
        name: Function name exposed to the model.
        description: Human-readable summary of what the function does.
        parameters: Schema object describing the function's arguments.

    Returns:
        A dict with the keys ``type`` and ``function`` (in that order); the
        inner dict holds ``name``, ``description`` and ``parameters``.
    """
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": parameters,
        },
    }


# Weather lookup: `location` is the only required argument; `unit` declares a
# default of "metric".
weather_parameters = {
    "type": "object",
    "properties": {
        "location": {
            "type": "string",
            "description": "Enter the 'city name' to get the weather. e.g. 'London'",
        },
        "unit": {
            "type": "string",
            "description": "Enter the unit of temperature. e.g. 'metric', 'imperial', 'standard'",
            "default": "metric",
        },
    },
    "required": ["location"],
}
tools_get_weather = _function_tool(
    "get_weather",
    "Get the current weather in a given location.",
    weather_parameters,
)

# Device location lookup: no arguments, so `properties` is empty and no
# `required` list is given.
tools_get_location = _function_tool(
    "get_location",
    "Returns the current location based on the user's device information.",
    {"type": "object", "properties": {}},
)

# Every tool the prompt template should list, in presentation order.
tools = [tools_get_weather, tools_get_location]
print(tools)


In [None]:
import json

import jinja2


def to_prompt_json(value):
    """Serialize ``value`` as a JSON string for embedding in the prompt.

    ``ensure_ascii=False`` keeps non-ASCII text as-is instead of escaping it.
    """
    return json.dumps(value, ensure_ascii=False)


# Dedicated environment so the template can use a plain-JSON filter.
# Jinja's built-in `tojson` targets HTML output and additionally escapes
# characters such as `'` and `<`, which would garble tool descriptions here.
prompt_env = jinja2.Environment()
prompt_env.filters["to_json"] = to_prompt_json

# Renders `tools` as a comma-separated sequence of JSON objects, one per tool:
#   {"name": ..., "arguments": {<arg>: {"type", "description", "required"}}, "description": ...}
# Each object is serialized as a whole, so quotes inside names/descriptions are
# escaped and booleans come out as JSON `true`/`false` rather than Python repr.
tools_template = prompt_env.from_string(
"""
{%- for tool in tools %}
    {%- set required = tool.function.parameters.required | default([], true) %}
    {%- set properties = tool.function.parameters.properties | default({}, true) %}
    {%- set arguments = {} %}
    {%- for key, value in properties.items() %}
        {%- set _ = arguments.update({key: {"type": value.type | default(none), "description": value.description | default(none), "required": key in required}}) %}
    {%- endfor %}

    {{- {"name": tool.function.name, "arguments": arguments, "description": tool.function.description} | to_json }}

    {%- if not loop.last %}
        {{- ', ' }}
    {%- endif %}
{%- endfor %}
"""
)

# System prompt skeleton; `{{ tools }}` receives the rendered tool list above.
system_template = jinja2.Template(
"""
Agentic model with function call capability.
Do not explicitly state that you call tools or functions to a user.
If the response can be generated from your internal knowledge which is self-evident or does not change over time, do so.
The available tools are: {{ tools }}
If you decide to perform a function call, respond in the format below:
```toolcall
[
    {'name': <function-name>, 'arguments': <args-dict>}
]
```
""".strip()
)

# Show the rendered tool list on its own, then the full system prompt.
tools_string = tools_template.render(tools=tools)
print(tools_string)
print()

system_prompt = system_template.render(tools=tools_string)
print(system_prompt)

In [None]:
# Example conversations for exercising the tool-calling prompt. Each value is
# the list of turns that follows the system message; the comment on each entry
# names the behaviour it is meant to probe.
example_conversations = {
    # single call (get_location)
    "single_location": [
        {"role": "user", "content": "Where am I?"},
    ],
    # single call (get_weather)
    "single_weather": [
        {"role": "user", "content": "What's the weather like in Seoul right now?"},
    ],
    # parallel calls (get_weather)
    "parallel_weather": [
        {"role": "user", "content": "What's the weather like in Seoul and London right now?"},
    ],
    # nested calls (get_location -> get_weather)
    "nested": [
        {"role": "user", "content": "What's the weather like in my current location?"},
    ],
    # mixed parallel (x4 calls)
    "mixed_parallel": [
        {
            "role": "user",
            "content": "I want to know my current location and the current weather in Seattle, New York and London.",
        },
    ],
    # irrelevance (expecting a plain text response); the Korean sentence asks
    # "What is x in this equation?"
    "irrelevance": [
        {"role": "user", "content": "y = 3\n60 / (x + y) = 12\n이 식에서 x는 뭐야?"},
    ],
    # multi-lingual multi-turn: "Hi?" / "Hello, how can I help you?" /
    # "How's the weather?"
    "multilingual_multi_turn": [
        {"role": "user", "content": "안녕?"},
        {"role": "assistant", "content": "안녕하세요, 무엇을 도와드릴까요?"},
        {"role": "user", "content": "날씨 어때?"},
    ],
}

# Which example conversation to run.
scenario = "parallel_weather"

# Upper bound on the number of newly generated tokens for the reply.
max_new_tokens = 1000

# The system prompt always comes first, followed by the selected turns.
messages = [
    {"role": "system", "content": system_prompt},
    *example_conversations[scenario],
]

outputs = pipe(messages, max_new_tokens=max_new_tokens)

# The reply is read from the last entry of the returned chat transcript.
print(outputs[0]["generated_text"][-1]["content"])