Commit
mistral caller, openai version 2.8, llama function caller, tests for …flow
Showing 7 changed files with 815 additions and 12 deletions.
This file was deleted.
@@ -0,0 +1,217 @@
# !pip install accelerate
# !pip install torch
# !pip install transformers
# !pip install bitsandbytes

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TextStreamer,
)
from typing import Callable, Dict, List


class LlamaFunctionCaller:
    """
    A class that loads a Llama function-calling model and manages a registry
    of callable functions.

    Attributes:
    -----------
    model: transformers.AutoModelForCausalLM
        The loaded Llama model.
    tokenizer: transformers.AutoTokenizer
        The tokenizer for the Llama model.
    functions: Dict[str, Callable]
        A dictionary of functions available for execution.

    Methods:
    --------
    __init__(self, model_id: str, cache_dir: str, runtime: str)
        Initializes the LlamaFunctionCaller with the specified model.
    add_func(self, name: str, function: Callable, description: str, arguments: List[Dict])
        Adds a new function to the LlamaFunctionCaller.
    call_function(self, name: str, **kwargs)
        Calls the specified function with the given arguments.
    __call__(self, task: str, **kwargs)
        Sends a user prompt to the model and returns (optionally streams) the response.

    Example:
        # Example usage
        model_id = "Your-Model-ID"
        cache_dir = "Your-Cache-Directory"
        runtime = "cuda"  # or 'cpu'
        llama_caller = LlamaFunctionCaller(model_id, cache_dir, runtime)

        # Add a custom function
        def get_weather(location: str, format: str) -> str:
            # This is a placeholder for the actual implementation
            return f"Weather at {location} in {format} format."

        llama_caller.add_func(
            name="get_weather",
            function=get_weather,
            description="Get the weather at a location",
            arguments=[
                {
                    "name": "location",
                    "type": "string",
                    "description": "Location for the weather",
                },
                {
                    "name": "format",
                    "type": "string",
                    "description": "Format of the weather data",
                },
            ],
        )

        # Call the function
        result = llama_caller.call_function("get_weather", location="Paris", format="Celsius")
        print(result)

        # Send a user prompt
        llama_caller("Tell me about the tallest mountain in the world.")
    """

    def __init__(
        self,
        model_id: str = "Trelis/Llama-2-7b-chat-hf-function-calling-v2",
        cache_dir: str = "llama_cache",
        runtime: str = "auto",
        max_tokens: int = 500,
        streaming: bool = False,
        *args,
        **kwargs,
    ):
        self.model_id = model_id
        self.cache_dir = cache_dir
        self.runtime = runtime
        self.max_tokens = max_tokens
        self.streaming = streaming

        # Load the model and tokenizer
        self.model = self._load_model()
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id, cache_dir=cache_dir, use_fast=True
        )
        self.functions = {}

    def _load_model(self):
        # 4-bit NF4 quantization with bfloat16 compute keeps memory usage low
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        return AutoModelForCausalLM.from_pretrained(
            self.model_id,
            quantization_config=bnb_config,
            device_map=self.runtime,  # e.g. "auto", "cuda", "cpu"
            trust_remote_code=True,
            cache_dir=self.cache_dir,
        )

    def add_func(
        self, name: str, function: Callable, description: str, arguments: List[Dict]
    ):
        """
        Adds a new function to the LlamaFunctionCaller.

        Args:
            name (str): The name of the function.
            function (Callable): The function to execute.
            description (str): Description of the function.
            arguments (List[Dict]): List of argument specifications.
        """
        self.functions[name] = {
            "function": function,
            "description": description,
            "arguments": arguments,
        }

    def call_function(self, name: str, **kwargs):
        """
        Calls the specified function with the given arguments.

        Args:
            name (str): The name of the function to call.
            **kwargs: Keyword arguments for the function call.

        Returns:
            The result of the function call.
        """
        if name not in self.functions:
            raise ValueError(f"Function {name} not found.")

        func_info = self.functions[name]
        return func_info["function"](**kwargs)

    def __call__(self, task: str, **kwargs):
        """
        Sends a user prompt to the model and returns the generated token ids,
        streaming the decoded text to stdout if ``streaming`` is enabled.

        Args:
            task (str): The user prompt to send.

        Returns:
            The generated token ids from ``model.generate``.
        """
        # Format the prompt
        prompt = f"{task}\n\n"

        # Encode and move the inputs to the model's device. Using
        # ``.to(self.runtime)`` would fail for the default runtime of "auto",
        # which is a device_map value rather than a device.
        inputs = self.tokenizer([prompt], return_tensors="pt").to(self.model.device)

        if self.streaming:
            streamer = TextStreamer(self.tokenizer)
            out = self.model.generate(
                **inputs, streamer=streamer, max_new_tokens=self.max_tokens, **kwargs
            )
            return out
        else:
            # Use max_new_tokens (not max_length) so the token budget matches
            # the streaming branch and excludes the prompt length.
            out = self.model.generate(**inputs, max_new_tokens=self.max_tokens, **kwargs)
            # return self.tokenizer.decode(out[0], skip_special_tokens=True)
            return out
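
    # Hypothetical sketch, not part of the original commit: the schemas
    # registered via ``add_func`` are never injected into the prompt in
    # ``__call__``, so a helper like this could serialize ``self.functions``
    # into a prompt preamble. The exact schema format the Trelis
    # function-calling checkpoint expects is an assumption here.
    def render_function_schemas(self) -> str:
        """Render registered functions as a JSON block for a prompt preamble."""
        import json

        schemas = [
            {
                "name": name,
                "description": info["description"],
                "arguments": info["arguments"],
            }
            for name, info in self.functions.items()
        ]
        return json.dumps(schemas, indent=2)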


llama_caller = LlamaFunctionCaller()


# Add a custom function
def get_weather(location: str, format: str) -> str:
    # This is a placeholder for the actual implementation
    return f"Weather at {location} in {format} format."


llama_caller.add_func(
    name="get_weather",
    function=get_weather,
    description="Get the weather at a location",
    arguments=[
        {
            "name": "location",
            "type": "string",
            "description": "Location for the weather",
        },
        {
            "name": "format",
            "type": "string",
            "description": "Format of the weather data",
        },
    ],
)

# Call the function
result = llama_caller.call_function("get_weather", location="Paris", format="Celsius")
print(result)

# Send a user prompt to the model
llama_caller("Tell me about the tallest mountain in the world.")
@@ -0,0 +1 @@
""""""