In [3]:
# Reload imported modules automatically before each cell executes, so edits
# to the local .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [10]:
from model import Model
from prompts import system_prompt
from vector_store import VectorStore

In [11]:
# Shared handles used by the cells below: the vector store, plus the chat
# client and model identifier exposed by the Model wrapper.
vector_store = VectorStore()
model = Model(model_name="llama-3.3-70b-versatile")
client, model_name = model.client, model.model_name

In [16]:
# JSON-schema description of the arguments accepted by the `query_sql` tool.
_QUERY_SQL_PARAMETERS = {
    "type": "object",
    "properties": {
        "client_id": {
            "type": "string",
            "description": "The client ID for which to retrieve transactions.",
        },
        "start_date": {
            "type": "string",
            "format": "date",
            "description": "The start date for the transaction filter in YYYY-MM-DD format.",
        },
        "end_date": {
            "type": "string",
            "format": "date",
            "description": "The end date for the transaction filter in YYYY-MM-DD format.",
        },
        "aggregation": {
            "type": "string",
            "enum": ["sum", "count", "avg", "max", "min"],
            "description": "The type of aggregation to perform on the transaction amounts.",
        },
        "direction": {
            "type": "string",
            "enum": ["spend", "income", "both"],
            "description": "Filter transactions based on amount direction. 'spend' for negative, 'income' for positive, 'both' for all.",
        },
        "category": {
            "type": "string",
            "description": "Filter transactions by a specific category.",
        },
        "merchants": {
            "type": "array",
            "items": {"type": "string"},
            "description": "List of merchants to filter the transactions by (case-insensitive).",
        },
        "descriptions": {
            "type": "array",
            "items": {"type": "string"},
            "description": "List of keywords to search for in the transaction descriptions.",
        },
        "group_by": {
            "type": "array",
            "items": {
                "type": "string",
                "enum": ["bank_id", "acc_id", "txn_id", "txn_date", "desc", "amt", "cat", "merchant"],
            },
            "description": "Columns to group the results by.",
        },
        "limit": {
            "type": "integer",
            "minimum": 1,
            "description": "Limit the number of rows returned.",
        },
    },
    "required": ["client_id", "aggregation"],
}

# Tool definitions handed to the chat-completions call (see `tools=TOOLS`).
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "query_sql",
            "description": "Generates a SQL query to retrieve transaction summaries or aggregations for a specific client.",
            "parameters": _QUERY_SQL_PARAMETERS,
        },
    }
]

In [17]:
def call_model(context, merchants, descriptions):
    """Send the conversation to the chat model and return its reply message.

    Args:
        context: Dict whose "messages" entry is the list of chat messages so far.
        merchants: Forwarded to ``system_prompt``.
        descriptions: Forwarded to ``system_prompt``.

    Returns:
        The ``message`` of the first choice in the completion response.
        Note that the reply is returned, not appended to ``context``.
    """
    history = context["messages"]
    system_message = {"role": "system", "content": system_prompt(merchants, descriptions)}

    # The system prompt is rebuilt and prepended on every call; it is never
    # stored in the conversation history itself.
    completion = client.chat.completions.create(
        model=model_name,
        messages=[system_message] + history,
        temperature=0.6,
        tools=TOOLS,
    )
    return completion.choices[0].message


In [18]:
# Example question used to exercise the tool-calling flow below.
query = "Can you show me how much I spent on Uber each day last week?"
# NOTE(review): 880 and 30 are unexplained positional arguments — presumably a
# client id and a result limit; confirm against VectorStore and name them.
merchants, descriptions = vector_store.get_unique_merchants_and_descriptions(query, 880, 30)

In [19]:
# Minimal conversation state: a single user message.
context = {
    "messages": [
        {"role": "user", "content": query}
    ]
}
# `res` is the reply message returned by call_model (not an updated context).
res = call_model(context, merchants, descriptions)

In [32]:
# Tool invocations requested by the model in its reply.
tool_calls = res.tool_calls

In [33]:
# Inspect the first requested tool call.
tool_calls[0]

ChatCompletionMessageToolCall(id='call_3ana', function=Function(arguments='{"client_id": "user123", "aggregation": "sum", "direction": "spend", "merchants": ["uber"], "start_date": "2025-05-05", "end_date": "2025-05-11", "group_by": ["txn_date"]}', name='query_sql'), type='function')

In [28]:
# The arguments arrive as a JSON-encoded string (see the `str` output), so
# they have to be parsed before they can be passed to a Python function.
type(tool_calls[0].function.arguments)

str

In [30]:
import json
# Decode the JSON argument string into a dict of keyword arguments.
args = json.loads(tool_calls[0].function.arguments)


In [31]:
# Show the parsed tool arguments.
args

{'client_id': 'user123',
 'aggregation': 'sum',
 'direction': 'spend',
 'merchants': ['uber'],
 'start_date': '2025-05-05',
 'end_date': '2025-05-11',
 'group_by': ['txn_date']}

In [None]:
from tools import query_sql
# Registry mapping the tool name emitted by the model to the Python callable
# that implements it.
tools_available = {
    "query_sql": query_sql
}

def tool_node(context, tools=None):
    """Run every tool call requested by the latest message and record the results.

    Args:
        context: Dict whose "messages" list ends with the model's reply.
        tools: Optional mapping of tool name -> callable. When omitted, the
            notebook-level ``tools_available`` registry is used.

    Returns:
        The same ``context`` object. For each requested call one
        ``{"role": "function", ...}`` message is appended, carrying either the
        stringified result or a "Tool error: ..." description so the model can
        react to the failure.
    """
    registry = tools_available if tools is None else tools

    messages = context["messages"]
    if not messages:
        return context

    last_msg = messages[-1]
    # The last message may be a plain dict or an SDK message object. On the
    # SDK object the attribute always exists and is None when no tool was
    # requested, so test the value rather than using hasattr().
    if isinstance(last_msg, dict):
        requested = last_msg.get("tool_calls")
    else:
        requested = getattr(last_msg, "tool_calls", None)
    if not requested:
        return context

    for tool_call in requested:
        tool_name = tool_call.function.name
        tool_func = registry.get(tool_name)
        if tool_func is None:
            content = f"Tool error: unknown tool '{tool_name}'"
        else:
            try:
                # Parsing happens inside the try so malformed JSON from the
                # model is reported back instead of aborting the loop.
                tool_args = json.loads(tool_call.function.arguments)
                content = str(tool_func(**tool_args))
            except Exception as e:  # deliberate best-effort: surface any tool failure to the model
                content = f"Tool error: {e}"
        messages.append({"role": "function", "name": tool_name, "content": content})

    return context

In [36]:
# Inspect the full reply message returned by call_model.
res

ChatCompletionMessage(content=None, refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_3ana', function=Function(arguments='{"client_id": "user123", "aggregation": "sum", "direction": "spend", "merchants": ["uber"], "start_date": "2025-05-05", "end_date": "2025-05-11", "group_by": ["txn_date"]}', name='query_sql'), type='function')])

In [37]:
def should_continue(context):
    """Pick the next state of the chat loop from the latest message.

    The history mixes plain dicts (user messages and tool results) with SDK
    message objects (model replies), so fields are read in a way that works
    for both.

    Args:
        context: Dict whose "messages" entry is the conversation so far.

    Returns:
        "call_model" when the latest message is a tool result,
        "tools" when the model requested tool calls,
        "end" otherwise (including an empty history).
    """
    messages = context["messages"]
    if not messages:
        return "end"
    last_msg = messages[-1]

    def field(name):
        # Dict messages use keys; SDK message objects use attributes.
        if isinstance(last_msg, dict):
            return last_msg.get(name)
        return getattr(last_msg, name, None)

    # A tool result has to be fed back to the model for a final answer.
    if field("role") in ("function", "tool"):
        return "call_model"
    if field("tool_calls"):
        return "tools"
    return "end"

In [39]:
# Model replies carry role 'assistant' (see output below).
res.role

'assistant'

In [None]:
def chat(query):
    """Run the model/tool loop for one user question and return the exchange.

    The loop alternates between calling the model and running requested tools
    until ``should_continue`` reports "end".

    Args:
        query: The user's question.

    Returns:
        A two-item list: the user message and the final message of the
        conversation as ``{"role": ..., "content": ...}`` dicts.
    """
    # NOTE(review): `merchants` and `descriptions` are notebook globals computed
    # earlier for a different example query; they are not refreshed for `query`.
    context = {
        "messages": [{"role": "user", "content": query}],
        "state": "call_model",
    }

    while context["state"] != "end":
        if context["state"] == "call_model":
            # call_model returns only the reply message, so it has to be
            # appended to the history here; assigning it to `context` would
            # discard the conversation state.
            reply = call_model(context, merchants, descriptions)
            context["messages"].append(reply)
        elif context["state"] == "tools":
            context = tool_node(context)
        else:
            break
        context["state"] = should_continue(context)

    # Keep the whole message: both its role and its content are needed below.
    last_msg = context["messages"][-1]
    if isinstance(last_msg, dict):
        role, content = last_msg.get("role"), last_msg.get("content")
    else:
        role, content = last_msg.role, last_msg.content

    # Content is None on tool-call-only replies; treat that as empty text.
    # Any echoed system prompt is stripped from the answer.
    content = (content or "").replace(system_prompt(merchants, descriptions), "")

    return [
        {"role": "user", "content": query},
        {"role": role, "content": content},
    ]

---

## Testing Agent class

In [1]:
# Reload imported modules automatically before each cell executes, so edits
# to agent.py and its dependencies are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [14]:
from agent import Agent
from model import Model
from vector_store import VectorStore

# Build the agent from a Model and a VectorStore, both constructed with no arguments.
agent = Agent(Model(), VectorStore())


In [23]:
# NOTE(review): the positional 880 and "2023-09-19" look like a client id and
# a reference date for resolving "last month" — confirm against Agent.chat.
res = agent.chat("What is my last month spending on Travel?", 880, "2023-09-19")

Last Message:  ChatCompletionMessage(content=None, refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_ch8z', function=Function(arguments='{"client_id": "your_client_id", "aggregation": "sum", "direction": "spend", "category": "Travel", "start_date": "2023-08-01", "end_date": "2023-08-31", "limit": 100}', name='query_sql'), type='function')])
[ChatCompletionMessageToolCall(id='call_ch8z', function=Function(arguments='{"client_id": "your_client_id", "aggregation": "sum", "direction": "spend", "category": "Travel", "start_date": "2023-08-01", "end_date": "2023-08-31", "limit": 100}', name='query_sql'), type='function')]
Last Message:  {'role': 'function', 'name': 'query_sql', 'content': "{'rows': [{'SUM(amt)': -645.292}]}"}
Last Message:  ChatCompletionMessage(content='Your last month spending on Travel was approximately $645.29.', refusal=None, role='assistant', annotations=None, audio=None, function_call=N

In [24]:
# Final answer returned by agent.chat (a plain string, per the output below).
res

'Your last month spending on Travel was approximately $645.29.'