<h1>DataFrame Worker Agent</h1>

In [None]:
%pip install langchain langgraph pandas python-dotenv

In [None]:
from dotenv import load_dotenv

# Load KEY=VALUE pairs from a local .env file into os.environ; the
# TAVILY_API_KEY read further down via os.environ is expected to come from it.
# Binding the return value to `_` presumably just keeps the cell from echoing it.
_ = load_dotenv()

In [None]:
from langgraph.graph import StateGraph, END
from typing import TypedDict, Annotated, List
from langchain_core.messages import SystemMessage, HumanMessage
from langgraph.checkpoint.memory import MemorySaver
from IPython.display import Markdown, display
import os
import json
import pandas as pd

# Checkpointer handed to builder.compile(); runs are keyed by the thread_id in
# the stream config. MemorySaver keeps checkpoints in process memory, so they
# do not survive a kernel restart.
memory = MemorySaver()

In [None]:
from typing import TypedDict, Union, List, Dict, Any

class AgentState(TypedDict):
    """Shared LangGraph state passed between the nodes of the DataFrame agent.

    The keys are filled in pipeline order: ``task`` by the caller, ``query`` by
    task_to_query_node, ``result`` by execution_node and ``company_info`` by
    get_company_data_node.
    """

    # Free-text request supplied by the user.
    task: str
    # DataFrame query expression the LLM generated from ``task``.
    query: str
    # Structured output of running ``query`` (a dict describing a DataFrame,
    # Series, list, dict or scalar), or a string for errors / fallback text.
    result: Union[Dict[str, Any], str]
    # Company information gathered through get_company_info_tool.
    company_info: List[Dict[str, Union[str, Any]]]
    

In [None]:
from langchain_core.tools import tool
from tavily import TavilyClient

# Shared Tavily search client used by the tool below. A missing TAVILY_API_KEY
# environment variable raises KeyError here, when the cell runs.
tavily = TavilyClient(api_key=os.environ["TAVILY_API_KEY"])


@tool
def get_company_info_tool(ticker_symbol: str) -> str:

    """
    metadata:
      name: get_company_info_tool
      description: "Fetches the most recent company info for a given ticker symbol using Tavily."
      usage: "Useful when web search and scrape must be done to retrieve required information."
      args: ticker_symbol (str): The ticker symbol for the company.
      returns: str: JSON string with company information.
    """

    # The docstring above is left verbatim on purpose: @tool uses it as the tool
    # description shown to the model. Every outcome, including failures, is
    # returned as a JSON object keyed by "ticker_symbol", so the caller can
    # always json.loads() the result.
    try:
        search_query = f"{ticker_symbol}, What is the company name, company business and company recent stock market news?"
        # Only the single top hit is requested, so only element [0] is ever read.
        search_response = tavily.search(search_query, max_results=1)
        hits = search_response["results"] if search_response else None
        if hits:
            payload = {"ticker_symbol": ticker_symbol, "company_info": hits[0]}
        else:
            payload = {"ticker_symbol": ticker_symbol, "error": "No information found."}
        # Serialization stays inside the try so an unencodable hit is reported
        # through the error branch instead of raising.
        return json.dumps(payload)
    except Exception as exc:
        return json.dumps({"ticker_symbol": ticker_symbol, "error": str(exc)})


In [None]:
from langchain_openai import ChatOpenAI
model = ChatOpenAI(temperature=0, model="gpt-4o")

# Variant of the same model that may answer with get_company_info_tool calls.
# parallel_tool_calls=True lets one response request several calls at once;
# get_company_data_node then executes every call it finds in the response.
model_with_tools = model.bind_tools([get_company_info_tool], parallel_tool_calls=True)

In [None]:
# System prompt for task_to_query_node. The user's task is sent separately as a
# HumanMessage, so this text deliberately contains no "{task}" placeholder (it
# is never .format()-ed, so a placeholder would reach the model literally).
# The rules match how execution_node runs the reply: eval() of a single
# expression with builtins disabled and the DataFrame bound to `df`.
USER_INPUT_TRANSLATION_PROMPT = """
You are an assistant that translates free-text user input into Python queries for a pandas DataFrame.
The DataFrame is available as the variable `df`.
The user input will include exact mappings to the DataFrame's column names.
Your task is to generate a well-formatted Python query based on the input.

Guidelines:
1. Do not include any additional text, labels, backticks, or newlines in the output. Return only the clean query string.
2. Ensure the query is ready for direct execution.
3. The query must be a single Python expression (no assignments, imports, or multiple statements) that references the DataFrame as `df`.
4. Built-in functions such as len, sum, or list are not available; use pandas methods instead.

Example:
User input: "Give me all unique industries listed in the Industry column"
Output: df['Industry'].unique()
"""

In [None]:
def _strip_code_fences(text: str) -> str:
    """Return *text* without surrounding whitespace or Markdown code fences.

    The prompt asks for a bare expression, but chat models sometimes still wrap
    it in ```python ... ``` or single backticks, which eval() would reject.
    """
    cleaned = text.strip()
    if len(cleaned) >= 6 and cleaned.startswith("```") and cleaned.endswith("```"):
        inner = cleaned[3:-3]
        # Drop a language tag written on the opening fence line (e.g. ```python).
        tag, newline, rest = inner.partition("\n")
        if newline and tag.strip().lower() in {"python", "py", "python3", "pandas"}:
            inner = rest
        cleaned = inner.strip()
    if len(cleaned) >= 2 and cleaned.startswith("`") and cleaned.endswith("`"):
        cleaned = cleaned[1:-1].strip()
    return cleaned


def task_to_query_node(state: AgentState):
    """Translate the user's free-text task into a DataFrame query expression.

    Args:
        state (AgentState): Graph state; only ``state['task']`` is read.

    Returns:
        dict: Partial state update ``{"query": <expression>}``, with any
        surrounding whitespace or Markdown code fences removed.
    """
    messages = [
        SystemMessage(content=USER_INPUT_TRANSLATION_PROMPT),
        HumanMessage(content=state['task'])
    ]
    response = model.invoke(messages)

    # Log the raw model response (the whole message object, not just its text).
    print(response)

    content = response.content
    # Non-string content (e.g. a list of content blocks) is passed through untouched.
    query = _strip_code_fences(content) if isinstance(content, str) else content
    return {"query": query}

In [None]:
def _to_tagged_output(result):
    """Wrap an eval() result in a dict tagged with its kind ("type" key).

    NumPy arrays and pandas Index objects become lists, and NumPy numeric/bool
    scalars (e.g. numpy.int64 from ``df['col'].sum()``) become native Python
    scalars. Without this they would fall through to the string-only "Other"
    branch and lose their structure.
    """
    import numpy as np  # pandas depends on NumPy, so this import always succeeds

    if isinstance(result, pd.DataFrame):
        return {
            "type": "DataFrame",
            "columns": result.columns.tolist(),
            "rows": result.to_dict(orient="records"),
        }
    if isinstance(result, pd.Series):
        return {"type": "Series", "name": result.name, "values": result.tolist()}
    # datetime64/timedelta64 ndarrays are left to the "Other" branch because
    # ndarray.tolist() turns nanosecond timestamps into bare integers.
    is_plain_array = isinstance(result, np.ndarray) and result.dtype.kind not in "mM"
    if is_plain_array or isinstance(result, pd.Index):
        return {"type": type(result).__name__, "values": result.tolist()}
    if isinstance(result, (np.number, np.bool_)):
        # Unwrap NumPy scalars (numpy.int64, numpy.bool_, ...) into Python objects.
        result = result.item()
    if isinstance(result, (list, dict)):
        return {"type": type(result).__name__, "values": result}
    if isinstance(result, (int, float, str)):
        return {"type": "Scalar", "value": result}
    # Catch-all (None, Timestamp, tuple, ...): report the string form.
    return {"type": "Other", "value": str(result)}


def execution_node(state: AgentState):
    """
    Executes the query generated by the LLM on a DataFrame read from 'ipo_data.xlsx' and provides consistent output.

    Args:
        state (AgentState): The state containing the query to execute.

    Returns:
        AgentState: The same state with "result" set to a dict tagged with
        "type" (DataFrame, Series, list/dict/array, Scalar or Other), or to
        {"error": ...} if loading the file or evaluating the query failed.
    """
    try:
        # The workbook is re-read from the current working directory on every run.
        file_path = os.path.join(os.getcwd(), "ipo_data.xlsx")
        df = pd.read_excel(file_path)

        query = state["query"]

        # SECURITY: this eval() runs model-generated code. Emptying __builtins__
        # blocks direct names like open/__import__ but is NOT a sandbox (object
        # introspection can still escape it); only use with trusted models/inputs.
        result = eval(query, {"df": df, "__builtins__": {}})

        state["result"] = _to_tagged_output(result)
        return state
    except Exception as e:
        # Report the failure inline and record it in the state instead of raising,
        # so the graph run still completes.
        error_message = f"Error executing query: {e}"
        display(Markdown(f"**Execution Error:** {error_message}"))
        state["result"] = {"error": error_message}
        return state


In [None]:
# System prompt for get_company_data_node, which sends the JSON-serialized
# execution result as the human message. The model is asked to collect every
# unique `proposedTickerSymbol` in that input and call get_company_info_tool
# once per symbol (the model is bound with parallel_tool_calls=True).
# The ["AAPL", "JWN", "STRL"] list appears to be only a format example.
GET_COMPANY_INFO_PROMPT = """
You are an assistant with access to a tool for retrieving company information.

Tool Available:
- `get_company_info_tool`: Fetches the most recent company information for a given ticker symbol. Provide the `ticker_symbol` as the argument.

Your task:
1. Parse the input structure and extract ALL unique `proposedTickerSymbol` values, regardless of the structure complexity or nesting.
   - The structure may contain nested objects or arrays, and `proposedTickerSymbol` may appear multiple times.

Expected Output:
A list of unique ticker symbols:
   ["AAPL", "JWN", "STRL"]
   
2. Using generated list of unique ticker symbols, iterate over each value in the list and invoke the tool available:
   - For each invocation pass the current list value as ticker symbol.
   - Continue invoking the tool for **all** elements in the list.

Expected Output:
Combined results from all tool invocations
"""

In [None]:
def get_company_data_node(state: AgentState):
    """
    Look up company information for every ticker found in the execution result.

    The execution result is serialized to JSON and sent, together with
    GET_COMPANY_INFO_PROMPT, to ``model_with_tools``. Each requested
    ``get_company_info_tool`` call is then executed and its JSON payload parsed.

    Args:
        state (AgentState): The agent's current state; ``state["result"]`` is read.

    Returns:
        AgentState: The same state with ``company_info`` set to a list of parsed
        tool payloads. If no tool call was requested or an error occurred, it is
        a one-element list ``[{"error": ...}]``, matching the declared
        ``List[Dict]`` type.
    """
    try:
        result_data = state["result"]
        if not isinstance(result_data, str):
            # default=str: execution results may hold values json cannot encode
            # natively (e.g. pandas Timestamps if the sheet has date columns);
            # the model only needs their text form.
            result_data = json.dumps(result_data, indent=2, default=str)

        messages = [
            SystemMessage(content=GET_COMPANY_INFO_PROMPT),
            HumanMessage(content=f"Input: {result_data}")
        ]
        response = model_with_tools.invoke(messages)

        if response.tool_calls:
            company_info_list = []
            for tool_call in response.tool_calls:
                # Ignore calls to any tool other than the company-info lookup.
                if tool_call["name"] != "get_company_info_tool":
                    continue
                ticker_symbol = tool_call["args"]["ticker_symbol"]
                tool_result = get_company_info_tool.invoke({"ticker_symbol": ticker_symbol})
                company_info_list.append(json.loads(tool_result))
            state["company_info"] = company_info_list
        else:
            state["company_info"] = [{"error": "No tool invocation detected."}]
        return state

    except Exception as e:
        # Report the failure inline and record it in the state instead of raising.
        error_message = f"Error in get_company_data_node: {e}"
        display(Markdown(f"**Error:** {error_message}"))
        state["company_info"] = [{"error": error_message}]
        return state


In [None]:
# Initialise the graph; every node receives and updates the shared AgentState.
builder = StateGraph(AgentState)

In [None]:
# Register the three pipeline steps under the names the edges refer to:
# "input" turns the task into a query, "execute" evaluates it against
# ipo_data.xlsx, and "get_company_data" looks up the resulting tickers.
builder.add_node("input", task_to_query_node)
builder.add_node("execute", execution_node)
builder.add_node("get_company_data", get_company_data_node)

In [None]:
# Every run starts at the "input" node (task -> query translation).
builder.set_entry_point("input")

In [None]:
# Linear pipeline: input -> execute -> get_company_data.
builder.add_edge("input", "execute")
builder.add_edge("execute", "get_company_data")

In [None]:
# End the run after the company lookup by wiring the last node to END
# (the same edge that set_finish_point("get_company_data") adds).
builder.add_edge("get_company_data", END)

In [None]:
# Compile with the MemorySaver checkpointer so each run is checkpointed under
# the thread_id passed in the stream config.
graph = builder.compile(checkpointer=memory)

In [None]:
from IPython.display import Image

# Render the compiled graph as a PNG for visual inspection.
# NOTE(review): draw_png() appears to require pygraphviz (and system Graphviz),
# which the %pip cell does not install; draw_mermaid_png() may be an alternative.
Image(graph.get_graph().draw_png())

In [None]:
def pretty_print_event(event: dict):
    """Pretty-print a single streamed graph event as indented JSON.

    ``default=str`` prints values json cannot encode natively (e.g. pandas
    Timestamps or dates inside DataFrame rows) in their string form, instead of
    raising TypeError in the middle of the stream.
    """
    print(json.dumps(event, indent=2, default=str))

thread = {"configurable": {"thread_id": "3"}}

# Single-quoted literals keep the double quotes around column names. Writing
# ""name"" inside a double-quoted literal is implicit string concatenation,
# which silently drops those quotes.
task_text = (
    'Given "proposedTickerSymbol" and "Month 1" through "Month 13" columns, '
    'give me only the proposedTickerSymbols and its "Industry" which only have '
    'positive performance in the Months 1 through 13. If no value in a month 1 '
    'through 13 count it as positive performance!'
)

# Stream the run and print each node's state update as it arrives.
for s in graph.stream({"task": task_text}, thread):
    pretty_print_event(s)