In [51]:
import os
import re
from typing import Dict, List, Any, Tuple, Annotated, TypedDict, Sequence
from langchain.agents import Tool
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_ollama import ChatOllama
from langchain_core.tools import tool
from langchain_core.output_parsers import StrOutputParser
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolExecutor
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [52]:
MODEL_NAME = "llama3.2"
TEMPERATURE = 0.1
NUM_CTX = 4096
NUM_PREDICT = 2048

In [53]:
llm = ChatOllama(
    model=MODEL_NAME,
    temperature=TEMPERATURE,
    num_ctx=NUM_CTX,
    num_predict=NUM_PREDICT,
)

In [54]:
class AgentState(TypedDict):
    messages: Sequence[BaseMessage]
    agent_scratchpad: List[BaseMessage]

In [55]:
def load_csv_data(file_path: str) -> pd.DataFrame:
    """Load CSV file into pandas DataFrame."""
    return pd.DataFrame(pd.read_csv(file_path))

In [56]:
class DataAnalysisTools:
    def __init__(self, df: pd.DataFrame):
        self.df = df

    def get_column_names(self) -> List[str]:
        """Get list of column names in the dataset."""
        return list(self.df.columns)

    def get_basic_stats(self, column_name: str) -> str:
        """Get basic statistics for a specific column."""
        try:
            stats = self.df[column_name].describe()
            return f"Statistics for {column_name}:\n{stats}"
        except KeyError:
            return f"Column {column_name} not found in the dataset. Available columns are: {', '.join(self.get_column_names())}"

    def create_visualization(
        self, plot_type: str, x_col: str, y_col: str = None
    ) -> str:
        """Create visualization based on specified type and columns."""
        try:
            plt.figure(figsize=(10, 6))

            if plot_type == "histogram":
                sns.histplot(data=self.df, x=x_col)
                plt.title(f"Histogram of {x_col}")

            elif plot_type == "scatter":
                if y_col is None:
                    return "For scatter plot, both x and y columns are required."
                sns.scatterplot(data=self.df, x=x_col, y=y_col)
                plt.title(f"Scatter plot: {x_col} vs {y_col}")

            elif plot_type == "boxplot":
                sns.boxplot(data=self.df, y=x_col)
                plt.title(f"Box plot of {x_col}")

            plot_path = "temp_plot.png"
            plt.savefig(plot_path)
            plt.close()

            return f"Plot has been created and saved as {plot_path}"
        except KeyError as e:
            return f"Column not found. Available columns are: {', '.join(self.get_column_names())}"
        except Exception as e:
            return f"Error creating visualization: {str(e)}"

    def query_data(self, query: str) -> str:
        """Query data based on specific conditions."""
        try:
            if ">" in query:
                col, value = query.split(">")
                result = self.df[self.df[col.strip()] > float(value)]
            elif "<" in query:
                col, value = query.split("<")
                result = self.df[self.df[col.strip()] < float(value)]
            else:
                return "Invalid query format. Use '>' or '<' operators."

            return (
                f"Query results:\n{result.head()}\nTotal matching rows: {len(result)}"
            )
        except KeyError:
            return f"Column not found. Available columns are: {', '.join(self.get_column_names())}"
        except Exception as e:
            return f"Error executing query: {str(e)}"

In [60]:
def create_agent(csv_file_path: str):
    # Load data
    df = load_csv_data(csv_file_path)
    analysis_tools = DataAnalysisTools(df)

    # Create the tools
    tools = [
        Tool(
            name="get_stats",
            func=analysis_tools.get_basic_stats,
            description="Get basic statistics for a specific column. Input: column name",
        ),
        Tool(
            name="create_plot",
            func=analysis_tools.create_visualization,
            description="Create visualization. Input format: 'plot_type,x_column,y_column'. Valid plot types: histogram, scatter, boxplot",
        ),
        Tool(
            name="query_data",
            func=analysis_tools.query_data,
            description="Query data with conditions. Input format: 'column > value' or 'column < value'",
        ),
    ]

    # Create tool executor
    tool_executor = ToolExecutor(tools)

    # Define the prompt
    prompt = ChatPromptTemplate.from_messages(
        [
            (
                "system",
                "You are a helpful AI assistant that analyzes data from CSV files. "
                "You have access to the following tools:\n\n"
                "1. get_stats: Get basic statistics for a specific column\n"
                "2. create_plot: Create visualization (histogram, scatter, boxplot)\n"
                "3. query_data: Query data with conditions\n\n"
                "To use a tool, respond with: /tool_name parameter\n"
                "For example: /get_stats price\n"
                "Or: /create_plot histogram,age\n"
                "Or: /query_data price > 1000\n\n"
                "Use these tools to help answer user questions about the data.",
            ),
            MessagesPlaceholder(variable_name="messages"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )

    # Function to parse tool calls
    def parse_tool_call(text: str) -> tuple[str, str] | None:
        pattern = r"^/(\w+)\s+(.+)$"
        match = re.match(pattern, text.strip())
        if match:
            return match.group(1), match.group(2)
        return None

    # Function to determine which tool to use
    def determine_next_action(state: AgentState) -> AgentState:
        messages = state["messages"]

        response = llm.invoke(
            prompt.format_messages(
                messages=messages, agent_scratchpad=state.get("agent_scratchpad", [])
            )
        )

        tool_call = parse_tool_call(response.content)
        if tool_call:
            tool_name, tool_args = tool_call

            # Find the matching tool
            matching_tools = [t for t in tools if t.name == tool_name]
            if not matching_tools:
                return {
                    "messages": list(messages)
                    + [AIMessage(content=f"Tool '{tool_name}' not found.")],
                    "agent_scratchpad": [],
                }

            tool = matching_tools[0]
            try:
                # Execute the tool
                if tool_name == "create_plot":
                    args = tool_args.split(",")
                    plot_type = args[0]
                    x_col = args[1]
                    y_col = args[2] if len(args) > 2 else None
                    tool_response = tool.func(plot_type, x_col, y_col)
                else:
                    tool_response = tool.func(tool_args)

                # Add both the tool call and its response to messages
                return {
                    "messages": list(messages)
                    + [
                        AIMessage(content=f"Let me {tool_name} for you."),
                        AIMessage(content=str(tool_response)),
                    ],
                    "agent_scratchpad": [],
                }
            except Exception as e:
                return {
                    "messages": list(messages)
                    + [AIMessage(content=f"Error executing tool: {str(e)}")],
                    "agent_scratchpad": [],
                }
        else:
            # If no tool call detected, just return the response
            return {
                "messages": list(messages) + [AIMessage(content=response.content)],
                "agent_scratchpad": [],
            }

    # Create the graph
    workflow = StateGraph(AgentState)

    # Add node
    workflow.add_node("agent", determine_next_action)

    # Add edge from agent to end
    workflow.set_entry_point("agent")
    workflow.add_edge("agent", END)

    # Compile the graph
    chain = workflow.compile()

    return chain

In [61]:
def process_user_question(chain, question: str):
    """Process user question using the agent chain."""
    result = chain.invoke({
        "messages": [HumanMessage(content=question)],
        "agent_scratchpad": []
    })
    return "\n".join(msg.content for msg in result["messages"][-2:])

In [64]:
def main():
    # Set up the agent
    csv_file_path = "./data/HR_Analytics.csv"  # Replace with actual CSV file path
    agent_chain = create_agent(csv_file_path)

    # Interactive loop
    print("CSV Analysis Agent Ready! (Type 'quit' to exit)")
    print("Available commands:")
    print("1. /get_stats column_name")
    print("2. /create_plot plot_type,x_column,y_column")
    print("   Plot types: histogram, scatter, boxplot")
    print("3. /query_data column > value")

    while True:
        question = input("\nEnter your question: ")
        if question.lower() == "quit":
            break

        try:
            response = process_user_question(agent_chain, question)
            print(f"\nResponse: {response}")
        except Exception as e:
            print(f"Error: {str(e)}")

In [66]:
main()

  tool_executor = ToolExecutor(tools)


CSV Analysis Agent Ready! (Type 'quit' to exit)
Available commands:
1. /get_stats column_name
2. /create_plot plot_type,x_column,y_column
   Plot types: histogram, scatter, boxplot
3. /query_data column > value



Response: /create_plot BusinessTravel DailyRate
Here is a scatter plot of BusinessTravel vs DailyRate:

```
BusinessTravel | DailyRate 
----------------|-----------
Yes             | 50.00
No              | 30.00
Yes             | 75.00
No              | 25.00
...
```

Note: The actual data points are not shown here, but the scatter plot would display them.

Would you like to customize the plot further (e.g., add a title, labels, etc.)?
