In [1]:
## Existing Tools

### Data Setup
import pandas as pd

# Load the student-performance dataset into the module-level `df`;
# the aggregation tool and the system prompt in later cells read this name.
df = pd.read_csv(filepath_or_buffer="StudentPerformanceFactors.csv")
# Bare trailing expression so the notebook renders the frame as the cell output.
df

Unnamed: 0,Hours_Studied,Attendance,Parental_Involvement,Access_to_Resources,Extracurricular_Activities,Sleep_Hours,Previous_Scores,Motivation_Level,Internet_Access,Tutoring_Sessions,Family_Income,Teacher_Quality,School_Type,Peer_Influence,Physical_Activity,Learning_Disabilities,Parental_Education_Level,Distance_from_Home,Gender,Exam_Score
0,23,84,Low,High,No,7,73,Low,Yes,0,Low,Medium,Public,Positive,3,No,High School,Near,Male,67
1,19,64,Low,Medium,No,8,59,Low,Yes,2,Medium,Medium,Public,Negative,4,No,College,Moderate,Female,61
2,24,98,Medium,Medium,Yes,7,91,Medium,Yes,2,Medium,Medium,Public,Neutral,4,No,Postgraduate,Near,Male,74
3,29,89,Low,Medium,Yes,8,98,Medium,Yes,1,Medium,Medium,Public,Negative,4,No,High School,Moderate,Male,71
4,19,92,Medium,Medium,Yes,6,65,Medium,Yes,3,Medium,High,Public,Neutral,4,No,College,Near,Female,70
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6602,25,69,High,Medium,No,7,76,Medium,Yes,1,High,Medium,Public,Positive,2,No,High School,Near,Female,68
6603,23,76,High,Medium,No,8,81,Medium,Yes,3,Low,High,Public,Positive,2,No,High School,Near,Female,69
6604,20,90,Medium,Low,Yes,6,65,Low,Yes,3,Low,Medium,Public,Negative,2,No,Postgraduate,Near,Female,68
6605,10,86,High,High,Yes,6,91,High,Yes,2,Low,Medium,Private,Positive,3,No,High School,Far,Female,68


In [2]:
### Aggregation Tool
from langchain_core.tools import tool


@tool("aggregate")
def aggregate(column_name: str, aggregate_function: str, grouping_criteria: str | None = None) -> str:
    """
    A tool to apply an aggregation function to a specified column of the student
    performance pandas DataFrame. Optionally groups results by another column.
    The result is returned as text.

    Args:
        column_name: Name of the column to aggregate, e.g. 'Hours_Studied'.
        aggregate_function: Name of a pandas aggregation, e.g. 'mean', 'sum', 'count'.
        grouping_criteria: Optional name of a column to group by, e.g. 'Family_Income'.
            Leave unset to aggregate over all rows.

    Example:
      column_name='Hours_Studied', aggregate_function='mean', grouping_criteria='Family_Income'.
    """
    # NOTE: with the `@tool` decorator this docstring is also what the model sees as the
    # tool description, so the example uses columns that really exist in the dataset.
    # The function reads the module-level `df` loaded in the data-setup cell.

    # Treat any falsy value (None or an empty string) as "no grouping" instead of
    # failing with a KeyError on a column named ''.
    if not grouping_criteria:
        result = df[column_name].agg(aggregate_function)
    else:
        result = df.groupby(grouping_criteria)[column_name].agg(aggregate_function)

    # `.agg` yields a scalar or a Series, but the signature promises `str`;
    # convert explicitly so every caller gets the declared type.
    return str(result)

In [3]:
### Model Setup
from langchain_aws.chat_models import ChatBedrockConverse
from langchain_core.prompts.chat import ChatPromptTemplate

# System prompt: embeds the text form of `df.dtypes` (column names with their dtypes)
# so the model knows which columns it can refer to.
system_message = f"""You are a helpful agent that has access to a student performance DataFrame with columns:
{df.dtypes}"""

# The question sent to the model as the human turn.
task = "How many hours did students study across low income groups?"
# Two-message prompt (system + human). No input variables are declared here, which is
# why the later cells call `chain.invoke({})` with an empty dict.
# NOTE(review): these strings are handed to ChatPromptTemplate as templates, so a literal
# `{` or `}` in the dtypes text would presumably be read as a placeholder - confirm.
prompt = ChatPromptTemplate(
    [
        ("system", system_message),
        ("human", task),
    ]
)

# Chat model client for Llama 3.1 70B Instruct on Amazon Bedrock, region us-east-1.
llm_llama3_70b_instruct = ChatBedrockConverse(model="us.meta.llama3-1-70b-instruct-v1:0", region_name="us-east-1")

In [4]:
### Chain Execution (Initial Attempt)
# Make the `aggregate` tool available to the model, then pipe the prompt into it.
tools_llm = llm_llama3_70b_instruct.bind_tools([aggregate])
chain = prompt | tools_llm

# The prompt takes no input variables, so an empty dict is passed.
response = chain.invoke({})
# Print the full response message, then only the tool calls the model requested.
# Nothing is executed in this cell: the output merely describes the requested call.
print(response)
print(response.tool_calls)

content=[{'type': 'tool_use', 'name': 'aggregate', 'input': {'column_name': 'Hours_Studied', 'grouping_criteria': 'Family_Income', 'aggregate_function': 'sum'}, 'id': 'tooluse_ALw9FBdrSu6VuQnTrNYobQ'}] additional_kwargs={} response_metadata={'ResponseMetadata': {'RequestId': 'f81cb7a8-51bb-4c05-805b-b7246c33c31f', 'HTTPStatusCode': 200, 'HTTPHeaders': {'date': 'Tue, 04 Mar 2025 09:24:46 GMT', 'content-type': 'application/json', 'content-length': '353', 'connection': 'keep-alive', 'x-amzn-requestid': 'f81cb7a8-51bb-4c05-805b-b7246c33c31f'}, 'RetryAttempts': 0}, 'stopReason': 'tool_use', 'metrics': {'latencyMs': 2337}} id='run-824e702f-0c95-4f09-bcef-0e493a06787f-0' tool_calls=[{'name': 'aggregate', 'args': {'column_name': 'Hours_Studied', 'grouping_criteria': 'Family_Income', 'aggregate_function': 'sum'}, 'id': 'tooluse_ALw9FBdrSu6VuQnTrNYobQ', 'type': 'tool_call'}] usage_metadata={'input_tokens': 292, 'output_tokens': 45, 'total_tokens': 337}
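[{'name': 'aggregate', 'args': {'column_name': 'Hours_Studied', 'grouping_criteria': 'Family_Income', 'aggregate_function': 'sum'}, 'id': 'tooluse_ALw9FBdrSu6VuQnTrNYobQ', 'type': 'tool_call'}]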

In [5]:
### Forcing Tool Usage and Retry
from langchain_aws.function_calling import ToolsOutputParser
from langchain_core.output_parsers import (PydanticOutputParser,
                                           PydanticToolsParser)
from langchain_core.runnables.base import RunnableLambda


def verify(msg):
    """
    Pass a model response through only if it requests at least one tool call.

    The requested tool calls are printed so they show up in the cell output.

    Args:
        msg: Model response exposing a ``tool_calls`` sequence.

    Returns:
        ``msg`` unchanged, so it can be piped to the next step of a chain.

    Raises:
        ValueError: If ``msg.tool_calls`` is empty or None.
    """
    tool_calls = msg.tool_calls
    print(tool_calls)
    # Truthiness covers both an empty list and None (where `len()` would raise TypeError).
    # ValueError subclasses Exception, so existing `except Exception` handlers still match.
    if not tool_calls:
        raise ValueError("No tools were used in the chain")
    return msg


def execute_tool(msg):
    """
    Run the `aggregate` tool using the first tool call requested in `msg`.

    Only ``msg.tool_calls[0]`` is used; any additional tool calls are ignored.
    Returns whatever ``aggregate.invoke`` returns for that call.
    """
    return aggregate.invoke(msg.tool_calls[0])


# Pipeline: prompt -> tool-bound model -> verify (raises if no tool call was requested)
# -> execute_tool (runs `aggregate` with the first requested call).
# This rebinds `chain` and `response` from the previous cell.
chain = prompt | tools_llm | RunnableLambda(verify) | RunnableLambda(execute_tool)
response = chain.invoke({})
# `response` is what `aggregate.invoke(...)` returned; `.content` holds the aggregation result as text.
print(response.content)

[{'name': 'aggregate', 'args': {'column_name': 'Hours_Studied', 'grouping_criteria': 'Family_Income', 'aggregate_function': 'sum'}, 'id': 'tooluse_Id8nuvz0TqymJAzq_4R5iQ', 'type': 'tool_call'}]
Family_Income
High      25258
Low       53261
Medium    53458
Name: Hours_Studied, dtype: int64
