In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from model import Model
from prompts import system_prompt
from vector_store import VectorStore

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Shared handles used by the cells below: the retrieval store, plus the chat
# client and model identifier exposed by the Model wrapper.
vector_store = VectorStore()
model = Model(model_name="llama-3.3-70b-versatile")
client, model_name = model.client, model.model_name

In [4]:
def visualize_data(data: dict, chart_type: str, x: str, y: str, title: str = "", in_notebook: bool = False) -> None:
    """Render a Plotly chart from a query result.

    Args:
        data: Mapping with a ``"rows"`` entry holding a list of row dicts.
            A missing or empty ``"rows"`` entry is treated as "no data".
        chart_type: One of ``"bar"``, ``"line"``, ``"area"`` or ``"pie"``.
        x: Column used for the X axis (or the slice names of a pie chart).
        y: Column used for the Y axis (or the slice values of a pie chart).
        title: Optional chart title; a default is derived from ``x``/``y``.
        in_notebook: When true the figure is shown with ``fig.show()``,
            otherwise it is rendered through ``streamlit``.

    Returns:
        None. Problems (no data, unknown column, unsupported chart type,
        empty pie) are reported with ``print`` and nothing is drawn.
    """
    import pandas as pd

    # Tolerate a missing/None "rows" entry instead of raising KeyError.
    rows = (data or {}).get("rows") or []
    df = pd.DataFrame(rows)
    if df.empty or x not in df.columns or y not in df.columns:
        print("Insufficient data or invalid fields.")
        return

    # Reject unknown chart types before doing any work or importing Plotly.
    if chart_type not in ("bar", "line", "area", "pie"):
        print(f"Unsupported chart type: {chart_type}")
        return

    df = df.sort_values(by=x)

    if chart_type == "pie":
        if df[y].sum() == 0:
            print("No data to display in pie chart.")
            return
        # Spend aggregates arrive as negative totals. Pie slices need
        # non-negative sizes (Plotly does not draw negative values), so when
        # every value is <= 0 plot the magnitudes instead.
        values = df[y].dropna()
        if pd.api.types.is_numeric_dtype(values) and (values <= 0).all():
            df = df.assign(**{y: df[y].abs()})

    # Imported lazily so validation failures do not require Plotly at all.
    import plotly.express as px

    if chart_type == "bar":
        fig = px.bar(df, x=x, y=y, title=title or f"{y} by {x}")
    elif chart_type == "line":
        fig = px.line(df, x=x, y=y, markers=True, title=title or f"{y} over {x}")
    elif chart_type == "area":
        fig = px.area(df, x=x, y=y, title=title or f"{y} over {x}")
    else:  # "pie" is the only remaining supported type
        fig = px.pie(df, names=x, values=y, title=title or f"{y} by {x}")

    if in_notebook:
        fig.show()
    else:
        import streamlit as st
        st.plotly_chart(fig, use_container_width=True)

In [5]:
# JSON-schema properties accepted by the `query_sql` tool.
_QUERY_SQL_PROPERTIES = {
    "start_date": {
        "type": "string",
        "format": "date",
        "description": "The start date for the transaction filter in YYYY-MM-DD format.",
    },
    "end_date": {
        "type": "string",
        "format": "date",
        "description": "The end date for the transaction filter in YYYY-MM-DD format.",
    },
    "aggregation": {
        "type": "string",
        "enum": ["sum", "count", "avg", "max", "min"],
        "description": "The type of aggregation to perform on the transaction amounts.",
    },
    "direction": {
        "type": "string",
        "enum": ["spend", "income", "both"],
        "description": "Filter transactions based on amount direction. 'spend' for negative, 'income' for positive, 'both' for all.",
    },
    "category": {
        "type": "string",
        "description": "Filter transactions by a specific category.",
    },
    "merchants": {
        "type": "array",
        "items": {"type": "string"},
        "description": "List of merchants to filter the transactions by (case-insensitive).",
    },
    "descriptions": {
        "type": "array",
        "items": {"type": "string"},
        "description": "List of keywords to search for in the transaction descriptions.",
    },
    "group_by": {
        "type": "array",
        "items": {
            "type": "string",
            "enum": ["bank_id", "acc_id", "txn_id", "txn_date", "desc", "amt", "cat", "merchant"],
        },
        "description": "Columns to group the results by.",
    },
    "limit": {
        "type": "integer",
        "minimum": 1,
        "description": "Limit the number of rows returned.",
    },
}

# JSON-schema properties accepted by the `visualize_data` tool.
# NOTE(review): "calendar" is advertised here, but the `visualize_data` defined
# in this notebook only handles bar/line/area/pie — confirm which side is right.
_VISUALIZE_DATA_PROPERTIES = {
    "chart_type": {
        "type": "string",
        "enum": ["pie", "bar", "line", "area", "calendar"],
    },
    "x": {
        "type": "string",
        "description": "Field to use on X-axis",
    },
    "y": {
        "type": "string",
        "description": "Field to use on Y-axis",
    },
    "title": {
        "type": "string",
        "description": "Title of the chart",
    },
}

# Tool definitions passed to the chat-completions call via `tools=`.
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "query_sql",
            "description": "Generates a SQL query to retrieve transaction summaries or aggregations for a specific client.",
            "parameters": {
                "type": "object",
                "properties": _QUERY_SQL_PROPERTIES,
                "required": ["aggregation"],
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "visualize_data",
            "description": "Generate a visualization based on structured transaction results.",
            "parameters": {
                "type": "object",
                "properties": _VISUALIZE_DATA_PROPERTIES,
                "required": ["chart_type", "x", "y"],
            },
        },
    },
]

In [6]:
def call_model(context, merchants, descriptions, today='2023-09-12', temperature=0.6):
    """Send the conversation to the chat model and return its reply message.

    Args:
        context: Dict whose ``"messages"`` entry is the list of conversation
            messages to send after the system prompt.
        merchants: Passed through to ``system_prompt``.
        descriptions: Passed through to ``system_prompt``.
        today: Reference date handed to ``system_prompt``. Defaults to the
            date previously hard-coded in this notebook.
        temperature: Sampling temperature for the completion request.
            Defaults to the previously hard-coded value.

    Returns:
        The ``message`` of the first choice in the completion response. The
        caller decides whether to append it to ``context``; this function
        does not modify ``context``.
    """
    system_message = {
        "role": "system",
        "content": system_prompt(merchants, descriptions, today=today),
    }
    # Build a fresh list so the caller's history is left untouched.
    messages = [system_message] + list(context["messages"])

    response = client.chat.completions.create(
        model=model_name,
        messages=messages,
        temperature=temperature,
        tools=TOOLS
    )

    return response.choices[0].message


In [7]:
# Sample question used to exercise the tool-calling flow.
query = "Visualise a pie chart of where I spent money last month."

# 880 is the client id used throughout this notebook; 30 is presumably the
# number of candidates to retrieve — TODO confirm against VectorStore.
retrieved = vector_store.get_unique_merchants_and_descriptions(query, 880, 30)
merchants, descriptions = retrieved

In [8]:
# Single-turn history containing only the user's question.
context = {"messages": [{"role": "user", "content": query}]}

# First model turn; `res` is the assistant message returned by call_model.
res = call_model(context, merchants, descriptions)

In [9]:
tool_calls = res.tool_calls

In [10]:
tool_calls[0]

ChatCompletionMessageToolCall(id='call_av52', function=Function(arguments='{"aggregation": "sum", "direction": "spend", "group_by": ["cat"], "start_date": "2023-08-01", "end_date": "2023-08-31"}', name='query_sql'), type='function')

In [11]:
import json
args = json.loads(tool_calls[0].function.arguments)


In [12]:
args["client_id"] = 880

In [13]:
from tools import query_sql

data = query_sql(**args)


In [14]:
data

{'rows': [{'cat': 'ATM', 'SUM(amt)': -42022.762},
  {'cat': 'Arts and Entertainment', 'SUM(amt)': -1635.04},
  {'cat': 'Bank Fee', 'SUM(amt)': -25.592000000000002},
  {'cat': 'Bank Fees', 'SUM(amt)': -4891.84},
  {'cat': 'Clothing and Accessories', 'SUM(amt)': -16148.888},
  {'cat': 'Convenience Stores', 'SUM(amt)': -8799.976},
  {'cat': 'Department Stores', 'SUM(amt)': -3185.616},
  {'cat': 'Digital Entertainment', 'SUM(amt)': -4913.164},
  {'cat': 'Gas Stations', 'SUM(amt)': -26300.254},
  {'cat': 'Gyms and Fitness Centers', 'SUM(amt)': -68.762},
  {'cat': 'Healthcare', 'SUM(amt)': -569.23},
  {'cat': 'Insurance', 'SUM(amt)': -11719.994},
  {'cat': 'Internal Account Transfer', 'SUM(amt)': -37901.236},
  {'cat': 'Loans', 'SUM(amt)': -40519.72},
  {'cat': 'Payment', 'SUM(amt)': -241.094},
  {'cat': 'Restaurants', 'SUM(amt)': -33843.042},
  {'cat': 'Service', 'SUM(amt)': -4227.642},
  {'cat': 'Shops', 'SUM(amt)': -20132.122},
  {'cat': 'Supermarkets and Groceries', 'SUM(amt)': -41212.58

In [18]:
visualize_data(data, "bar", "cat", "SUM(amt)", in_notebook=True)

In [28]:
type(tool_calls[0].function.arguments)

str

In [None]:
from tools import query_sql
tools_available = {
    "query_sql": query_sql
}

def tool_node(context):
    """Run the tool calls requested by the latest message and record the results.

    For each tool call on the last message, the matching callable in
    ``tools_available`` is invoked with the JSON-decoded arguments and a
    ``{"role": "function", ...}`` message holding ``str(result)`` is appended
    to ``context["messages"]``. Failures (unknown tool, malformed JSON
    arguments, an exception raised by the tool) are reported as a
    ``"Tool error: ..."`` function message rather than raised.

    Args:
        context: Dict with a ``"messages"`` list. The last entry may be an
            SDK message object or a plain dict.

    Returns:
        The same ``context`` object, with any tool results appended. It is
        returned unchanged when there is no message or no tool call.
    """
    messages = context["messages"]
    if not messages:
        return context

    def _field(obj, key):
        # Messages/tool calls appear both as SDK objects and as plain dicts.
        if isinstance(obj, dict):
            return obj.get(key)
        return getattr(obj, key, None)

    # `tool_calls` exists on every assistant message but is None when the
    # model answered directly, so test the value rather than the attribute.
    tool_calls = _field(messages[-1], "tool_calls")
    if not tool_calls:
        return context

    for tool_call in tool_calls:
        function = _field(tool_call, "function")
        tool_name = _field(function, "name")

        tool_func = tools_available.get(tool_name)
        if tool_func is None:
            # The model may request a tool that is advertised but not wired up.
            content = f"Tool error: unknown tool {tool_name!r}"
        else:
            try:
                # Arguments arrive as a JSON string; decoding is inside the
                # try so malformed output is reported, not raised.
                tool_args = json.loads(_field(function, "arguments") or "{}")
                content = str(tool_func(**tool_args))
            except Exception as e:
                content = f"Tool error: {e}"

        messages.append({"role": "function", "name": tool_name, "content": content})

    return context

In [36]:
res

ChatCompletionMessage(content=None, refusal=None, role='assistant', annotations=None, audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_3ana', function=Function(arguments='{"client_id": "user123", "aggregation": "sum", "direction": "spend", "merchants": ["uber"], "start_date": "2025-05-05", "end_date": "2025-05-11", "group_by": ["txn_date"]}', name='query_sql'), type='function')])

In [37]:
def should_continue(context):
    """Pick the next state of the chat loop from the latest message.

    Args:
        context: Dict with a ``"messages"`` list whose entries are either SDK
            message objects or plain dicts (tool results are appended as
            dicts with ``"role": "function"``).

    Returns:
        ``"call_model"`` when the last message is a tool result,
        ``"tools"`` when it carries tool calls, and ``"end"`` otherwise
        (including when there are no messages at all).
    """
    messages = context["messages"]
    if not messages:
        return "end"
    last_msg = messages[-1]

    def _field(key):
        # Support both attribute-style SDK messages and dict messages.
        if isinstance(last_msg, dict):
            return last_msg.get(key)
        return getattr(last_msg, key, None)

    if _field("role") == "function":
        return "call_model"
    if _field("tool_calls"):
        return "tools"
    return "end"

In [39]:
res.role

'assistant'

In [None]:
def chat(query, max_steps=10):
    """Run the model/tool loop for one user question and return the exchange.

    The loop alternates between ``call_model`` and ``tool_node`` as directed
    by ``should_continue`` until the state becomes ``"end"``.

    Args:
        query: The user's question.
        max_steps: Upper bound on loop iterations, so a model that keeps
            requesting tools cannot spin forever.

    Returns:
        A two-item list: the user message and a ``{"role", "content"}`` dict
        built from the last message in the history, with any echoed system
        prompt text removed from its content.
    """
    context = {
        "messages": [{"role": "user", "content": query}],
        "state": "call_model",
    }

    steps = 0
    while context["state"] != "end" and steps < max_steps:
        if context["state"] == "call_model":
            # call_model returns the reply message, not a context: append it
            # to the history instead of overwriting `context` with it.
            reply = call_model(context, merchants, descriptions)
            context["messages"].append(reply)
        elif context["state"] == "tools":
            context = tool_node(context)
        else:
            break
        context["state"] = should_continue(context)
        steps += 1

    # Keep the whole message (not just its content) so role and content can
    # both be read; it may be an SDK object or a plain dict.
    last_msg = context["messages"][-1]
    if isinstance(last_msg, dict):
        role, content = last_msg.get("role"), last_msg.get("content")
    else:
        role, content = getattr(last_msg, "role", None), getattr(last_msg, "content", None)

    # NOTE(review): system_prompt is called without `today` here, unlike in
    # call_model — confirm it has a default for that argument.
    # Content can be None (e.g. a tool-call-only message), hence the `or ""`.
    cleaned = (content or "").replace(system_prompt(merchants, descriptions), "")

    return [
        {"role": "user", "content": query},
        {"role": role, "content": cleaned},
    ]

---

## Testing Agent class

In [54]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [76]:
from agent import Agent
from model import Model
from vector_store import VectorStore

agent = Agent(Model(), VectorStore())


In [116]:
query = "Whats my last month Uber spendings?"
res = agent.chat(query, 880, "2023-09-19", model_name="llama-3.1-8b-instant")

Reply:  {'content': None, 'refusal': None, 'role': 'assistant', 'annotations': None, 'audio': None, 'function_call': None, 'tool_calls': [{'id': 'call_atqx', 'function': {'arguments': '{"aggregation": "sum", "direction": "spend", "start_date": "2023-08-01", "end_date": "2023-08-31", "merchants": ["uber"], "group_by": []}', 'name': 'query_sql'}, 'type': 'function'}]}
Reply:  {'content': 'Based on the query results, your total Uber spending for last month was approximately $14.09.', 'refusal': None, 'role': 'assistant', 'annotations': None, 'audio': None, 'function_call': None, 'tool_calls': None}


In [117]:
res

{'content': 'Based on the query results, your total Uber spending for last month was approximately $14.09.',
 'chart': None}

In [100]:
res["content"]

'Here\'s the analysis of your Uber spending based on the data I retrieved:\n\n**Monthly Uber Spending (Jun-Sep 2023):**\n- June: $28.33\n- July: $60.81\n- August: $14.06\n- September: $6.80 (partial)\n\nThe chart displays daily transactions, but I notice most amounts are unusually small (e.g., $0.07). This might indicate:\n1. Test transactions or scaled-down data\n2. Possible currency scaling issues (e.g., amounts stored in cents)\n3. Limited transaction history in this format\n\nWould you like me to:\n1. Adjust the grouping to monthly totals instead of daily?\n2. Check for alternative merchant names (e.g., "Uber Technologies")?\n3. Verify the data format with your bank?\n\nLet me know how you\'d like to proceed!'

In [101]:
res["chart"]

In [111]:
query_args = {"aggregation": "max", "direction": "spend", "group_by": ["cat"], "start_date": "2023-06-19", "end_date": "2023-09-19", "client_id": 880}
data = query_sql(**query_args)

In [112]:
data

{'rows': [{'cat': 'ATM', 'MAX(amt)': -0.2},
  {'cat': 'Arts and Entertainment', 'MAX(amt)': -0.024},
  {'cat': 'Bank Fee', 'MAX(amt)': -0.098},
  {'cat': 'Bank Fees', 'MAX(amt)': -0.002},
  {'cat': 'Clothing and Accessories', 'MAX(amt)': -0.004},
  {'cat': 'Convenience Stores', 'MAX(amt)': -0.004},
  {'cat': 'Department Stores', 'MAX(amt)': -0.018},
  {'cat': 'Digital Entertainment', 'MAX(amt)': -0.014},
  {'cat': 'Gas Stations', 'MAX(amt)': -1.0},
  {'cat': 'Gyms and Fitness Centers', 'MAX(amt)': -0.104},
  {'cat': 'Healthcare', 'MAX(amt)': -0.072},
  {'cat': 'Insurance', 'MAX(amt)': -0.008},
  {'cat': 'Internal Account Transfer', 'MAX(amt)': -0.002},
  {'cat': 'Loans', 'MAX(amt)': -0.002},
  {'cat': 'Payment', 'MAX(amt)': -2.0},
  {'cat': 'Restaurants', 'MAX(amt)': -0.002},
  {'cat': 'Service', 'MAX(amt)': -0.05},
  {'cat': 'Shops', 'MAX(amt)': -0.034},
  {'cat': 'Supermarkets and Groceries', 'MAX(amt)': -0.002},
  {'cat': 'Telecommunication Services', 'MAX(amt)': -1.0},
  {'cat': 'T

In [113]:
import pandas as pd
df = pd.DataFrame(data["rows"])
df.head()

Unnamed: 0,cat,MAX(amt)
0,ATM,-0.2
1,Arts and Entertainment,-0.024
2,Bank Fee,-0.098
3,Bank Fees,-0.002
4,Clothing and Accessories,-0.004


In [114]:
df.columns.tolist()


['cat', 'MAX(amt)']

In [115]:
# NOTE: this import rebinds `visualize_data`, replacing the function defined
# earlier in this notebook with the one from the `tools` module.
from tools import visualize_data

# `data` holds the MAX aggregation queried above, so its value column is
# "MAX(amt)" (see df.columns); "SUM(amt)" does not exist in these rows.
visualize_data(data, "pie", "cat", "MAX(amt)", title="Spending by Category")

In [67]:
reply = {'content': None, 'refusal': None, 'role': 'assistant', 'annotations': None, 'audio': None, 'function_call': None, 'tool_calls': [{'id': 'call_xb6h', 'function': {'arguments': '{"client_id":"user","aggregation":"sum","direction":"spend","start_date":"2023-06-19","end_date":"2023-09-19","group_by":["cat"]}', 'name': 'query_sql'}, 'type': 'function'}]}

In [73]:
reply.get("tool_calls")

[{'id': 'call_xb6h',
  'function': {'arguments': '{"client_id":"user","aggregation":"sum","direction":"spend","start_date":"2023-06-19","end_date":"2023-09-19","group_by":["cat"]}',
   'name': 'query_sql'},
  'type': 'function'}]

In [48]:
chart = res["chart"]


In [51]:
chart

---

## Trying Visualisation Tool