In [None]:
from io import StringIO
from os import getenv

from IPython.display import Image, display
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import AzureChatOpenAI
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode
from pandas import read_excel

In [None]:
# Underlying chat model for the whole notebook. The deployment name and the
# API version are read from the environment so no Azure settings are hard-coded.
llm = AzureChatOpenAI(
    azure_deployment=getenv("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"),
    # OPENAI_API_VERSION holds the Azure REST API version, so it is passed as
    # `api_version`. `model_version` describes the deployed model itself and
    # must not receive the API version string.
    api_version=getenv("OPENAI_API_VERSION"),
    temperature=0.25,
)

In [None]:
class MessagesState(MessagesState):
    """
    Graph state: the ``MessagesState`` imported from ``langgraph.graph`` plus the
    path of the Excel file being analysed.

    NOTE(review): this class reuses the name of its imported base class. After
    this cell runs, the name ``MessagesState`` refers to this subclass, and
    re-running the cell subclasses the previous subclass. Consider a distinct
    name if the cell is expected to be re-executed.
    """
    # Path of the Excel file; read by the metadata step and supplied by the caller of graph.invoke().
    excel_path: str

In [None]:
def columns(excel_path: str) -> str:
    """
    Returns the column names of an Excel file as a formatted string.

    Args:
        excel_path (str): The path to the Excel file.

    Returns:
        str: A "Columns:" header line followed by the string form of the
             list of column names read from the Excel file.
    """
    column_names = read_excel(excel_path).columns.to_list()
    # The result is a string, not a list: the header is prepended to the
    # list's string form so it can be handed back as plain text.
    return "Columns:\n" + str(column_names)


def information(excel_path: str) -> str:
    """
    Summarises the structure of an Excel file as text.

    Loads the file into a DataFrame and captures the report that
    ``DataFrame.info()`` writes, so it can be returned as a string.

    Parameters:
        excel_path (str): The path to the Excel file to be read.

    Returns:
        str: The captured ``DataFrame.info()`` report (column names, non-null
             counts, and data types), preceded by an "Information:" header line.
    """
    frame = read_excel(excel_path)
    # info() writes to a stream rather than returning text, so hand it an
    # in-memory buffer and read the contents back before the buffer closes.
    with StringIO() as sink:
        frame.info(buf=sink)
        report = sink.getvalue()
    return f"Information:\n{report}"


def description(excel_path: str) -> str:
    """
    Produces descriptive statistics for an Excel file as text.

    Loads the file into a DataFrame and renders the table returned by
    ``DataFrame.describe()`` as a string.

    Parameters:
        excel_path (str): The path to the Excel file to be read.

    Returns:
        str: The rendered ``describe()`` table (count, mean, standard
             deviation, min, max, and quartiles for the numerical columns),
             preceded by a "Description:" header line.
    """
    frame = read_excel(excel_path)
    stats_table = frame.describe().to_string()
    return f"Description:\n{stats_table}"


def head(excel_path: str) -> str:
    """
    Previews the first rows of an Excel file as text.

    Loads the file into a DataFrame and renders the result of
    ``DataFrame.head()`` (the first five rows) as a table string.

    Parameters:
        excel_path (str): The path to the Excel file to be read.

    Returns:
        str: The first five rows formatted as a table, preceded by a
             "Head:" header line.
    """
    frame = read_excel(excel_path)
    preview = frame.head().to_string()
    return f"Head:\n{preview}"

In [None]:
# Bind the four Excel helper functions to the chat model as tools it may call.
llm_with_tools = llm.bind_tools([columns, information, description, head])

In [None]:
def metadata_generator(state: MessagesState) -> MessagesState:
    """
    Asks the tool-enabled model to gather metadata about the Excel file.

    The model receives a system instruction plus a human message whose content
    is the Excel path taken from the state.

    Args:
        state (MessagesState): Graph state; only ``excel_path`` is read here.

    Returns:
        MessagesState: A state update whose ``messages`` entry is the model's reply.
    """
    instructions = SystemMessage(
        content="You will be given an excel file path. Your task is to generate the columns, information, description and head of the excel file."
    )
    path_message = HumanMessage(content=state["excel_path"], name="Human")
    # Only the model's reply is returned; the two prompt messages built above
    # are not written back into the state.
    reply = llm_with_tools.invoke([instructions, path_message])
    return MessagesState(messages=reply)


def required_plot_generator(state: MessagesState) -> MessagesState:
    """
    Asks the model for a tailored list of EDA plots based on the gathered metadata.

    The system prompt is sent first, followed by every message accumulated in
    the state so far.

    Args:
        state (MessagesState): Graph state; ``messages`` is read here.

    Returns:
        MessagesState: A state update whose ``messages`` entry is the model's reply.
    """
    system_message = SystemMessage(
        content="""You are a data analyst assistant. Using the following metadata from a pandas DataFrame (df.columns, df.info(), df.describe(), and df.head()), generate a comprehensive and customized list of EDA plots.

Each recommendation must:
    1. Name the specific plot type (e.g., Histogram, Box plot, Pair plot).
    2. Specify exact columns involved (e.g., Histogram of age, Box plot of income grouped by gender).
    3. Justify the plot based on the data's type, distribution, presence of nulls, or potential relationships.

Avoid generic suggestions; tailor each plot to the data provided."""
    )
    # Use the plain model, not the tool-bound one: this step must answer in
    # text, and nothing downstream of it executes tool calls. A tool call
    # emitted here would be left unanswered in the message history.
    response = llm.invoke([system_message, *state["messages"]])
    return MessagesState(messages=response)


def coder(state: MessagesState) -> MessagesState:
    """
    Asks the model to write Python plotting code for the conversation so far.

    The system prompt is sent first, followed by every message accumulated in
    the state so far.

    Args:
        state (MessagesState): Graph state; ``messages`` is read here.

    Returns:
        MessagesState: A state update whose ``messages`` entry is the model's reply.
    """
    system_message = SystemMessage(
        content="""You are a skilled Python developer focused on data wrangling and visualization. Your goal is to write clean, efficient code to create visualizations from Excel data and plot instructions.

Follow these rules:
    * Create all the plots that are requested.
    * Use common libraries like pandas, matplotlib, seaborn, and others if needed (e.g., plotly, openpyxl, numpy).
    * Prepare the data properly:
        - Handle missing or inconsistent values.
        - Convert categories to numbers when doing numeric analysis.
        - Filter or reshape data as needed for the plots.
    * Let the library handle layout sizes (don't set figure sizes manually).
    * Don't add values inside heatmaps.
    * Organize the code clearly:
        - Use comments and good structure.
        - Use functions where it makes sense.
    * If the metadata is incomplete, explain any assumptions clearly in comments."""
    )
    # Use the plain model, not the tool-bound one: this is the final step, so
    # its reply must be text. A tool call here would never be executed and
    # would leave the final message without usable content.
    response = llm.invoke([system_message, *state["messages"]])
    return MessagesState(messages=response)

In [None]:
# Assemble the workflow as a strictly linear pipeline:
#   START -> metadata_generator -> tools -> required_plot_generator -> coder -> END
builder = StateGraph(MessagesState)
builder.add_node("metadata_generator", metadata_generator)
# The tools node is given the same four functions that were bound to the model.
builder.add_node("tools", ToolNode([columns, information, description, head]))
builder.add_node("required_plot_generator", required_plot_generator)
builder.add_node("coder", coder)
# Every edge below is unconditional, so "tools" always runs right after
# "metadata_generator", and no step after it routes back to "tools".
# NOTE(review): this presumably relies on the metadata step always emitting
# tool calls — confirm how ToolNode behaves when it does not.
builder.add_edge(START, "metadata_generator")
builder.add_edge("metadata_generator", "tools")
builder.add_edge("tools", "required_plot_generator")
builder.add_edge("required_plot_generator", "coder")
builder.add_edge("coder", END)
graph = builder.compile()

In [None]:
# Draw the compiled graph as a Mermaid PNG and show it inline in the notebook.
display(Image(graph.get_graph().draw_mermaid_png()))

In [None]:
# Run the pipeline once; only excel_path is supplied in the initial state.
result = graph.invoke(MessagesState(excel_path="titanic_dataset.xlsx"))

In [None]:
# Show the content of the last message in the final state; "coder" is the last node before END.
print(result["messages"][-1].content)