In [7]:
from litellm import completion
from typing import List, Dict, Any, Optional
import jinja2
from abc import ABC, abstractmethod

# Render the system prompt once at import time, with an empty ``tools`` value.
# An explicit encoding keeps template decoding independent of the OS locale.
with open("templates/system_message.jinja2", "r", encoding="utf-8") as file:
    system_message = jinja2.Template(file.read()).render(tools="")



In [None]:
# Objective: a conversation object that sends accumulated messages to the model and returns its responses (optionally offering tools).
class Conversation:
    """Accumulate chat messages and send them to a model via litellm's ``completion``.

    Attributes:
        model: model identifier forwarded to ``completion``.
        messages: ordered list of ``{"role": ..., "content": ...}`` dicts sent
            with every request.
    """

    def __init__(self, model: str = "gpt-4o-mini"):
        self.model = model
        self.messages: List[Dict[str, Any]] = []

    def add_message(self, role: str, content: Any) -> None:
        """Append one message with the given ``role`` and ``content``."""
        self.messages.append({"role": role, "content": content})

    def get_response(self, tool_choice: str = "auto", tools: Optional[List] = None):
        """Send the current messages to the model and return its raw response.

        Args:
            tool_choice: forwarded to ``completion`` only when ``tools`` is
                non-empty; otherwise it is sent as ``None``.
            tools: tool configurations to offer the model. ``None`` or an empty
                list means no tools.

        Returns:
            Whatever ``completion`` returns (a litellm ``ModelResponse``).
        """
        if not tools:
            # ``tool_choice`` is meaningless without tools, and an empty
            # ``tools`` array may be rejected by OpenAI-compatible APIs, so
            # send neither.
            tools = None
            tool_choice = None
        return completion(
            model=self.model,
            messages=self.messages,
            tools=tools,
            tool_choice=tool_choice,
        )


class Tool(ABC):
    """Abstract base class for tools that operate on a ``Conversation``.

    Concrete tools must provide ``run`` (do the tool's work against a
    conversation) and ``get_config`` (describe the tool as a config dict).
    """

    def __init__(self):
        # No shared state yet; present so subclasses have a base initializer.
        ...

    @abstractmethod
    def run(self, conversation: Conversation) -> Dict[str, Any]:
        """Execute the tool against ``conversation`` and return its result."""

    @abstractmethod
    def get_config(self) -> Dict[str, Any]:
        """Return the configuration dict that describes this tool."""


class Reasoning(Tool):
    """Tool that asks the model to reason about the current conversation.

    The reasoning prompt is appended as a temporary system message for a single
    request and removed afterwards, whether or not the request succeeds.
    """

    def __init__(self):
        pass

    def run(self, conversation: Conversation) -> Dict[str, Any]:
        """Return the model's reply to the conversation plus the reasoning prompt.

        Args:
            conversation: conversation to reason about. The temporary system
                message added here is popped before returning or raising.

        Returns:
            The raw response object from ``Conversation.get_response``.
        """
        with open("templates/reasoning.jinja2", "r", encoding="utf-8") as file:
            reasoning_template = jinja2.Template(file.read()).render()
        conversation.add_message("system", reasoning_template)
        try:
            return conversation.get_response()
        finally:
            # Remove the temporary prompt even if the request fails. add_message
            # appended it last and nothing else is appended before this pop.
            conversation.messages.pop()

    def get_config(self) -> Dict[str, Any]:
        """Return the function-tool schema advertising this tool to the model."""
        return {
            "type": "function",
            "function": {
                "name": "reasoning",
                "description": "This tool provides reasoning for the given context.",
            },
        }


class Respond(Tool):
    """Tool that asks the model to produce a response to the user's message.

    The respond prompt is appended as a temporary system message for a single
    request and removed afterwards, whether or not the request succeeds.
    """

    def __init__(self):
        pass

    def run(self, conversation: Conversation) -> Dict[str, Any]:
        """Return the model's reply to the conversation plus the respond prompt.

        Args:
            conversation: conversation to respond to. The temporary system
                message added here is popped before returning or raising.

        Returns:
            The raw response object from ``Conversation.get_response``.
        """
        with open("templates/respond.jinja2", "r", encoding="utf-8") as file:
            respond_template = jinja2.Template(file.read()).render()
        conversation.add_message("system", respond_template)
        try:
            return conversation.get_response()
        finally:
            # Remove the temporary prompt even if the request fails. add_message
            # appended it last and nothing else is appended before this pop.
            conversation.messages.pop()

    def get_config(self) -> Dict[str, Any]:
        """Return the function-tool schema advertising this tool to the model."""
        return {
            "type": "function",
            "function": {
                "name": "respond",
                "description": "This tool provides responses for the given user message.",
            },
        }

# An agent is a tool that interacts with the model, gets responses, and uses other tools as required.
class Agent(Tool):
    """Placeholder for an agent that drives the model and dispatches tools.

    NOTE(review): ``run`` and ``get_config`` are not implemented yet, so this
    class is still abstract and ``Agent()`` raises ``TypeError``.
    """

    def __init__(self):
        # No state yet.
        ...



In [None]:
# Smoke test: send the system prompt plus one user greeting, then print the raw reply.
conversation = Conversation()
for role, content in (
    ("system", system_message),
    ("user", "Hello, what are you?"),
):
    conversation.add_message(role, content)
response = conversation.get_response()
print(response)

ModelResponse(id='chatcmpl-AyPJ9E7PMLtAgUfvJ0CidSDyLfTSS', created=1738960327, model='gpt-4o-mini-2024-07-18', object='chat.completion', system_fingerprint='fp_72ed7ab54c', choices=[Choices(finish_reason='stop', index=0, message=Message(content='I am an autonomous market research agent, designed to assist you with information and insights related to market trends, consumer behavior, and various industries. How can I help you today?', role='assistant', tool_calls=None, function_call=None, provider_specific_fields={'refusal': None}, refusal=None))], usage=Usage(completion_tokens=36, prompt_tokens=61, total_tokens=97, completion_tokens_details=CompletionTokensDetailsWrapper(accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0, text_tokens=None), prompt_tokens_details=PromptTokensDetailsWrapper(audio_tokens=0, cached_tokens=0, text_tokens=None, image_tokens=None)), service_tier='default')
