In [None]:
# note: same thing for REPL
# note: we use this instead of magic because `black` will otherwise fail to format
#
# Enable autoreload to automatically reload modules when they change

from IPython import get_ipython

# Use the programmatic API instead of the `%load_ext` / `%autoreload` magics so that
# the formatter (`black`) can still parse and format this cell.
ipython = get_ipython()

# `get_ipython()` returns None when no IPython shell is active (e.g. this cell was
# exported and is being run as a plain python script). Autoreload only makes sense
# inside an interactive shell, so skip the setup instead of raising AttributeError.
if ipython is not None:
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Import commonly used libraries
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# graphics
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go

# type annotation
import jaxtyping
from jaxtyping import Float32, Int64, jaxtyped
from typeguard import typechecked as typechecker

# more itertools
import more_itertools as mi

# itertools
import itertools
import collections

# tensor manipulation
from einops import rearrange, reduce, repeat

# automatically apply jaxtyping
# %load_ext jaxtyping
# %jaxtyping.typechecker typeguard.typechecked

In [None]:
from typing import Any, Callable
import dataclasses
import rich


@dataclasses.dataclass(frozen=True)
class ToolDefinition:
    """Description of a tool, without the code that implements it.

    Frozen dataclass: fields cannot be reassigned after construction (the
    `arguments` list object itself is still mutable).
    """

    # Name the tool is registered under; `FunctionCallManager` uses it as the lookup key.
    function_name: str
    # Presumably a human-readable explanation of what the tool does (cf. "description"
    # in the protocol example in this notebook) -- TODO confirm
    description: str
    # Presumably the argument names, in the positional order the callable expects
    # (`FunctionCallManager.execute_tool` passes values positionally) -- TODO confirm
    arguments: list[str]
    # Presumably a textual description of what the tool returns (cf. "return" in the
    # protocol example in this notebook), not an actual returned value -- TODO confirm
    return_value: str


@dataclasses.dataclass(frozen=True)
class ToolUseRequest:
    """Request to invoke a single tool by name with a set of arguments.

    Frozen dataclass: fields cannot be reassigned after construction (the
    `arguments` dict object itself is still mutable).
    """

    # Name of the tool to call; `FunctionCallManager.execute_tool` raises ValueError
    # if no tool is registered under this name.
    function_name: str
    # Argument name -> value. Note that `FunctionCallManager.execute_tool` ignores the
    # keys and passes only the values, positionally, in dict insertion order.
    arguments: dict[str, Any]


@dataclasses.dataclass(frozen=True)
class ToolUseResponse:
    """Result of executing a `ToolUseRequest`.

    Frozen dataclass: fields cannot be reassigned after construction.
    """

    # Name of the tool that produced this response (copied from the request by
    # `FunctionCallManager.execute_tool`).
    function_name: str
    # Value produced by the tool. Typed `Any`, although
    # `FunctionCallManager.execute_tool` always stores `str(result)` here.
    return_value: Any


@dataclasses.dataclass(frozen=True)
class ToolDefinitionWithCallable[R, **P]:
    """Tool definition with an associated python callable that can be used to execute the tool.

    Type parameters:
        R: return type of `callable_fn`.
        P: parameter specification of `callable_fn`.
    """

    # Metadata describing the tool; its `function_name` is the key the tool is
    # registered under in `FunctionCallManager`.
    tool_definition: ToolDefinition
    # Python implementation of the tool; `FunctionCallManager.execute_tool` calls it
    # with the request's argument values as positional arguments.
    callable_fn: Callable[P, R]


@dataclasses.dataclass(frozen=True)
class ToolDefinitionList[R, **P]:
    """Wrapper around a list of tools, each paired with its python callable.

    Frozen dataclass: `tools` cannot be reassigned after construction (the list
    object itself is still mutable).

    NOTE(review): every entry is annotated with the same `R` / `P`, so the static
    type only describes tools that share one signature -- confirm this is intended
    when the list holds tools with different signatures.
    """

    # Tools in registration order; `FunctionCallManager` indexes them by
    # `tool_definition.function_name`.
    tools: list[ToolDefinitionWithCallable[R, P]]


# TODO(bschoen): Tool use record? Separate from conversation history?
class FunctionCallManager:
    """Dispatches `ToolUseRequest`s to the python callables registered for them.

    Tools are indexed by `tool_definition.function_name`. If several tools in the
    list share a name, the last one in the list is the one that gets called.
    """

    def __init__(self, tool_definition_list: ToolDefinitionList) -> None:
        """Index the given tools by function name.

        Args:
            tool_definition_list: Tools (definition plus callable) that this
                manager is able to execute.
        """

        self.tool_definition_list = tool_definition_list

        # Store a mapping of function names to their ToolDefinitionWithCallable objects
        self._tool_map: dict[str, ToolDefinitionWithCallable] = {
            tool_def.tool_definition.function_name: tool_def
            for tool_def in tool_definition_list.tools
        }

    def get_tool_definition_list(self) -> ToolDefinitionList:
        """Return the tool definition list this manager was constructed with."""
        # available for easy access by outside when defining everything together
        # with its corresponding callable
        return self.tool_definition_list

    # TODO(bschoen): Ignore type conversion for now (could try ast literal then fall back to
    #                string if can't parse)
    # TODO(bschoen): Ignoring errors for now
    def execute_tool(self, request: ToolUseRequest) -> ToolUseResponse:
        """Run the tool named in `request` and wrap its result.

        Args:
            request: Tool name plus arguments. Only the argument *values* are
                used; they are passed positionally in dict insertion order.

        Returns:
            A `ToolUseResponse` carrying the request's function name and
            `str(result)` of the tool's return value.

        Raises:
            ValueError: If no tool is registered under `request.function_name`.
            Exception: Anything raised by the tool's callable propagates unchanged.
        """
        # Local import: `import rich` alone does not guarantee that the
        # `rich.markup` submodule is bound as an attribute.
        from rich.markup import escape

        # `rich.print` interprets console markup in strings. The request has not been
        # validated yet and may contain arbitrary text (e.g. "[/x]" raises MarkupError,
        # "[bold]" is silently swallowed), so escape it before printing.
        rich.print(escape(f"Executing tool: {request}"))

        # Lookup the ToolDefinitionWithCallable for the requested function
        if request.function_name not in self._tool_map:
            raise ValueError(f"Unknown function: {request.function_name}")

        tool_def_with_callable = self._tool_map[request.function_name]

        # Call the function with the provided arguments

        # note: we just assume values are in order instead of enforcing kwargs,
        #       since we want to be able to easily use partial even if we
        #       don't know argument names

        result = tool_def_with_callable.callable_fn(*request.arguments.values())

        # Return the result wrapped in a ToolUseResponse
        response = ToolUseResponse(
            function_name=request.function_name,
            return_value=str(result),
        )

        # Tool names are user-defined text as well, so escape here for the same reason.
        rich.print(escape(f"Executed request. Tool response for {response.function_name}"))

        return response

In [None]:
"""
    
    Protocol:

    {
  "user_message": "Here are the available tools, please use them to add 14 and 21.",
  "tool_definitions": [
    {
      "function": "get_wikipedia_page",
      "description": "Returns full content of a given wikipedia page",
      "arguments": ["title_of_wikipedia_page"],
      "return": "returns full content of a given wikipedia page as a json string"
    },
    {
      "function": "add",
      "description": "Add two numbers together",
      "arguments": ["first_number", "second_number"],
      "return": "returns the numeric value of adding together `first_number` and `second_number`"
    }
  ],
  "assistant_tool_use_request": {
    "function": "add",
    "arguments": {
      "first_number": 14,
      "second_number": 21
    }
  },
  "user_tool_use_response": {
    "function": "add",
    "return": 35
  },
  "assistant_response": "The result of adding 14 and 21 is 35."
}

 """

In [None]:
# note: could ask it to write tests first
"""
Example aider usage
> aider --read inspect_explorer/conversation_manager.py --file tests/test_conversation_manager.py --yes --message 'Write unit tests for conversation manager that can be run with pytest'

"""