In [385]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [386]:
import json
from inspect import signature
from pydantic import BaseModel, create_model
from typing import List, Any, Callable, Union, Mapping

In [387]:
class Function(BaseModel):
    """
    A wrapper around a callable to store its metadata and provide a way to run it with validation.
    """
    def __init__(self, func: Callable[..., Any]):
        super().__init__()
        self._func = func
        self._name = func.__name__
        self._doc = func.__doc__
        self._signature = signature(func)
        self._return_type = self._signature.return_annotation
        self._dynamic_model = get_args_model(self._signature)  # Use the dynamic model class here.

    def get_args_model(signature: inspect.Signature):
        model = create_model(
            'DynamicArgs',
            **{
                param_name: (param.annotation, ...)
                if param.annotation != param.empty
                else (Any, ...)
                for param_name, param in signature.parameters.items()
            }
        )
        return model

    @property
    def name(self):
        return self._name

    @property
    def doc(self):
        return self._doc

    @property
    def return_type(self):
        return self._return_type

    def run(self, args: Mapping[str, Any]) -> Any:
        """
        Run the function with the provided arguments after validation.
        """
        # Validate the args using the dynamic model.
        validated_args = self._dynamic_model.model_validate(args)
        
        # Call the function with unpacked arguments from the validated BaseModel.
        result = self._func(**validated_args.model_dump())
        
        # Check if the result matches the expected return type.
        assert self._return_type == type(result), f"Expected return type {self._return_type}, but got {type(result)}"
        return result

    def __repr__(self):
        doc_preview = (self._doc[:30] + '...') if self._doc and len(self._doc) > 30 else self._doc
        return f"Function(name={self._name}, signature={self._signature}, doc={doc_preview})"

class Tool(Function):
    def __init__(self, func: Callable[..., Any]):
        super().__init__(func)
    
    @property
    def schema(self):
        schema = {
            "name": self.name,
            "doc": self.doc,
            "parameters": {
                "properties": {
                    param_name: {"type": param.annotation.__name__ if param.annotation != param.empty else "unknown"}
                    for param_name, param in self._signature.parameters.items()
                }
            }
        }
    
        return schema

    def __repr__(self):
        doc_preview = (self._doc[:30] + '...') if self._doc and len(self._doc) > 30 else self._doc
        return f"Tool(name={self._name}, signature={self._signature}, doc={doc_preview})"


class ToolKit(BaseModel):
    def __init__(self, tools: Union[Callable[..., Any], List[Callable[..., Any]]]):
        super().__init__()
        self._tools = {}
        self.tools = tools  # Use the setter to initialize tools

    @property
    def tools(self) -> List[Tool]:
        return list(self._tools.values())

    @tools.setter
    def tools(self, tools: Union[Callable[..., Any], List[Callable[..., Any]]]):
        if not isinstance(tools, list):
            tools = [tools]

        for tool in tools:
            tool_instance = Tool(tool)
            name = tool_instance.name

            # Handle conflicts by appending a number if a name already exists
            if name in self._tools:
                counter = 1
                new_name = f"{name}_{counter}"
                while new_name in self._tools:
                    counter += 1
                    new_name = f"{name}_{counter}"
                name = new_name

            # Store the tool instance with the resolved name
            self._tools[name] = tool_instance

    def tools_schemas(self) -> List[dict]:
        # Return a list of schemas for each Tool in the toolkit
        return [tool.schema for tool in self._tools.values()]

    def get_tool_by_name(self, name: str) -> Union[Tool, None]:
        """
        Retrieve a Tool by its name.

        Parameters:
        - name (str): The name of the tool.

        Returns:
        - Tool: The Tool instance if found.
        - None: If no tool with the given name exists.
        """
        return self._tools.get(name)

    def __repr__(self):
        return 'ToolKit'

def test(a: int, b: List = [1,2,3]) -> list:
    return a * b

def test2(c: str, d: int) -> str:
    return c*d


In [388]:
def sum2(a: float, b: float) -> float:
    return a + b

def sub(a: float, b: float) -> float:
    return a - b

def list_mul(a: int, b: List = [1,2,3]) -> list:
    return a * b

In [389]:
sum2_func = Function(sum2)
sub_func = Function(sub)

In [390]:
sub_func

Function(name=sub, signature=(a: float, b: float) -> float, doc=None)

In [391]:
sum2_func.run({'a': 2, 'b': 3})

5.0

In [392]:
sub_func.run({'a': 2, 'b': 3})

-1.0

In [393]:
sum2_tool = Tool(func=sum2)

In [394]:
sum2_tool.schema

{'name': 'sum2',
 'doc': None,
 'parameters': {'properties': {'a': {'type': 'float'},
   'b': {'type': 'float'}}}}

In [395]:
sum2_tool.run({'a': 2, 'b': 3})

5.0

In [396]:
assert sum2_tool.run({'a': 2, 'b': 3}) == sum2_func.run({'a': 2, 'b': 3})

In [397]:
sum2_tool.run({'a': 2, 'b': 3})

5.0

In [398]:
tk = ToolKit(tools=[sum2, sub, list_mul])

In [399]:
tk.tools

[Tool(name=sum2, signature=(a: float, b: float) -> float, doc=None),
 Tool(name=sub, signature=(a: float, b: float) -> float, doc=None),
 Tool(name=list_mul, signature=(a: int, b: List = [1, 2, 3]) -> list, doc=None)]

In [400]:
assert tk.tools[0].run({'a': 2, 'b': 3}) == sum2_tool.run({'a': 2, 'b': 3}) == sum2_func.run({'a': 2, 'b': 3})

In [401]:
assert tk.get_tool_by_name('sum2').run({'a': 2, 'b': 3}) == sum2_tool.run({'a': 2, 'b': 3}) == sum2_func.run({'a': 2, 'b': 3})