# Tool Calling

In [1]:
import torch
import os
import uuid
import datetime
import pprint
import json

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, TransformersEngine, CodeAgent, BitsAndBytesConfig

from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFacePipeline, HuggingFaceEmbeddings, ChatHuggingFace
from langchain_text_splitters import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
from langchain.tools import tool

from langchain.schema import AIMessage, HumanMessage
from langchain_core.messages import SystemMessage
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.runnables import RunnableConfig
from langchain_core.tools import render_text_description

from typing import Annotated, Any, Dict, Optional, TypedDict, Union, List
from lightning import Fabric
from peft import LoraConfig, get_peft_model, PeftModelForCausalLM, PeftModel

from IPython.display import display, Markdown, Image, SVG
from os import walk



### Set mixed precision

In [2]:
torch.set_float32_matmul_precision("medium")
fabric = Fabric(accelerator="cuda", devices=1, precision="bf16-mixed")
device = fabric.device
fabric.launch()

Using bfloat16 Automatic Mixed Precision (AMP)


## RAG

### Load Embedding Data

In [3]:
embed_model_name = "sentence-transformers/all-mpnet-base-v2"

embeddings = HuggingFaceEmbeddings(model_name=embed_model_name)

### Load vector store

In [4]:
vector_store = FAISS.load_local(
    "./faiss_spell_index",
    embeddings=embeddings,
    allow_dangerous_deserialization=True
)
retriever = vector_store.as_retriever()

## Load the model

In [5]:
model_name = "Salesforce/Llama-xLAM-2-8b-fc-r"
# model_name = "NousResearch/Hermes-3-Llama-3.1-8B"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, load_in_4bit=True)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

## Function

xLAM does not support native function calling, I referred to [this](https://python.langchain.com/docs/how_to/tools_prompting/) guide to parse the model's text output into JSON format.

ref: [How does function calling work under the hood?](https://www.reddit.com/r/LangChain/comments/1d8y7mq/how_does_function_calling_work_under_the_hood/)

In [6]:
@tool
def spell_retrieve(query: str) -> str:
    """Retrieve information about dungeons and dragons spell.
    
    Args:
        query (str): The spell name to search for.

    Returns:
        str: The spell information.
    """
    retrieved_docs = vector_store.similarity_search(query, k=3)

    contents = "\n\n".join(
        (f"{doc.page_content}")
        for doc in retrieved_docs
    )
    
    return contents

@tool
def get_weather(city: str) -> str:
    """
    Get the current weather for a city.
    Args:
        city (str): The name of the city.
        
    Returns:
        str: The current weather in the city.
    """
    return f"The weather in {city} is sunny."

@tool
def add(x: int, y: int) -> int:
    "Add two numbers."
    return x + y

@tool
def multiply(x: float, y: float) -> float:
    """Multiply two numbers together."""
    return x * y

@tool
def user(name: str) -> str:
    """
    User infomation retreiver

    Args:
        name (str): The name of user.

    Returns:
        str: The user information.
    """
    return f'Hi, {name}.'

tools = [
    get_weather,
    add,
    multiply,
    spell_retrieve,
    user
]

for t in tools:
    print("---------------------")
    print(t.name)
    print(t.description)
    print(t.args)

---------------------
get_weather
Get the current weather for a city.
Args:
    city (str): The name of the city.

Returns:
    str: The current weather in the city.
{'city': {'title': 'City', 'type': 'string'}}
---------------------
add
Add two numbers.
{'x': {'title': 'X', 'type': 'integer'}, 'y': {'title': 'Y', 'type': 'integer'}}
---------------------
multiply
Multiply two numbers together.
{'x': {'title': 'X', 'type': 'number'}, 'y': {'title': 'Y', 'type': 'number'}}
---------------------
spell_retrieve
Retrieve information about dungeons and dragons spell.

    Args:
        query (str): The spell name to search for.

    Returns:
        str: The spell information.
{'query': {'title': 'Query', 'type': 'string'}}
---------------------
user
User infomation retreiver

Args:
    name (str): The name of user.

Returns:
    str: The user information.
{'name': {'title': 'Name', 'type': 'string'}}


In [7]:
rendered_tools = render_text_description(tools)
print(rendered_tools)

get_weather(city: str) -> str - Get the current weather for a city.
Args:
    city (str): The name of the city.

Returns:
    str: The current weather in the city.
add(x: int, y: int) -> int - Add two numbers.
multiply(x: float, y: float) -> float - Multiply two numbers together.
spell_retrieve(query: str) -> str - Retrieve information about dungeons and dragons spell.

    Args:
        query (str): The spell name to search for.

    Returns:
        str: The spell information.
user(name: str) -> str - User infomation retreiver

Args:
    name (str): The name of user.

Returns:
    str: The user information.


In [8]:
pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    return_full_text=False,
    max_new_tokens=2048,
    top_k=10,
    device_map="auto"
)

llm = HuggingFacePipeline(pipeline=pipe)

Device set to use cuda:0


In [9]:
class ToolCallRequest(TypedDict):
    """A typed dict that shows the inputs into the invoke_tool function."""

    name: str
    arguments: Dict[str, Any]


def invoke_tool(
    tool_call_request: Union[ToolCallRequest, List[ToolCallRequest]], config: Optional[RunnableConfig] = None
):
    """A function that we can use the perform a tool invocation.

    Args:
        tool_call_request: a dict that contains the keys name and arguments.
            The name must match the name of a tool that exists.
            The arguments are the arguments to that tool.
        config: This is configuration information that LangChain uses that contains
            things like callbacks, metadata, etc.See LCEL documentation about RunnableConfig.

    Returns:
        output from the requested tool
    """

    print("Tool call request:", tool_call_request)
    
    # Sometimes the model outputs a list of tool call requests, 
    # so I loop each tool call and append to list
    
    if isinstance(tool_call_request, list):
        output = list()

        for tool_call in tool_call_request:
            tool_name_to_tool = {tool.name: tool for tool in tools}
            name = tool_call["name"]
            requested_tool = tool_name_to_tool[name]
            output.append(requested_tool.invoke(tool_call["arguments"], config=config))
        return output
    
    tool_name_to_tool = {tool.name: tool for tool in tools}
    name = tool_call_request["name"]
    requested_tool = tool_name_to_tool[name]
    return requested_tool.invoke(tool_call_request["arguments"], config=config)

In [17]:
system_prompt = SystemMessage(f"""\
You are an assistant that has access to the following set of tools. 
Here are the names and descriptions for each tool:

{rendered_tools}

Given the user input, return the name and input of the tool to use. 
Return your response as a JSON blob with 'name' and 'arguments' keys.

The `arguments` should be a dictionary, with keys corresponding 
to the argument names and the values corresponding to the requested values.
""")

chat = ChatHuggingFace(llm=llm, tokenizer=tokenizer)

chain = chat | JsonOutputParser() | invoke_tool

query = "what is a Absorb Elements spell?"

messages = [system_prompt, HumanMessage(query)]

response = chain.invoke(messages)

if isinstance(response, list):
    for idx, text in enumerate(response):
        print(f'{idx}\t {text}')
else:
    print("tool calling output:", response)

Tool call request: [{'name': 'spell_retrieve', 'arguments': {'query': 'Absorb Elements'}}]
0	 # Absorb Elements
## Spell Name
Absorb Elements  
From Xanathar's Guide to Everything, page 150; and Elemental Evil Player's Companion, page 15.
## Description
*1st-level abjuration*
* **Casting Time:** 1 reaction, which you take when you take acid, cold, fire, lightning, or thunder damage
* **Range:** Self
* **Components:** S
* **Duration:** 1 round
- **Casting Time:** 1 reaction, which you take when you take acid, cold, fire, lightning, or thunder damage
**Casting Time:**
- **Range:** Self
**Range:**

Disintegrate
## Learned By
* **Classes:** Artificer, Wizard
* **Subclasses:** Cleric (*Peace Domain*), Cleric (*Protection Domain*), Fighter (*Eldritch Knight*), Paladin (*Oath of Redemption*), Rogue (*Arcane Trickster*)
* **Backgrounds:** Izzet Engineer
- **Classes:** Artificer, Wizard
**Classes:**
Artificer
Wizard
- **Subclasses:** Cleric (*Peace Domain*), Cleric (*Protection Domain*), Fighte