In [None]:
from __future__ import annotations

import datetime as dt

import polars as pl
from langchain.agents import create_agent
from langchain.chat_models import init_chat_model
from langchain.messages import HumanMessage
from langchain.tools import BaseTool, tool
from pydantic import BaseModel, Field

from chain_reaction.config import APIKeys, ModelBehavior, ModelName

# Toolkit

In [None]:
class ColumnSummary(BaseModel):
    """A summary of a single column in a DataFrame."""

    # dtype is stored as its string rendering (str(series.dtype)) so the model stays JSON-friendly.
    dtype: str = Field(description="The data type of the column.")
    count: int = Field(description="The number of non-null entries in the column.")
    null_count: int = Field(description="The number of null entries in the column.")
    unique_count: int = Field(description="The number of unique entries in the column.")

    @classmethod
    def from_series(cls, series: pl.Series) -> ColumnSummary:
        """Build a column summary from a Polars Series.

        Args:
            series: The column to describe.

        Returns:
            A ColumnSummary with the dtype name, non-null/null counts and the
            number of distinct values of ``series``.
        """
        # Compute the null count once; the non-null count is derived from it.
        missing = series.null_count()
        total = series.len()
        # NOTE(review): Polars' n_unique appears to count null as a distinct
        # value when present — confirm this is the intended meaning.
        distinct = series.n_unique()
        return cls(
            dtype=str(series.dtype),
            count=total - missing,
            null_count=missing,
            unique_count=distinct,
        )


class DataFrameSummary(BaseModel):
    """A summary of a DataFrame."""

    num_rows: int = Field(description="The number of rows in the DataFrame.")
    num_columns: int = Field(description="The number of columns in the DataFrame.")
    column_names: list[str] = Field(description="The names of the columns in the DataFrame.")
    column_summaries: dict[str, ColumnSummary] = Field(description="A summary of each column in the DataFrame.")

    @classmethod
    def from_dataframe(cls, df: pl.DataFrame) -> DataFrameSummary:
        """Build a summary of ``df`` with one ColumnSummary per column.

        Args:
            df: The DataFrame to describe.

        Returns:
            A DataFrameSummary whose ``column_summaries`` is keyed by column
            name, in the DataFrame's column order.
        """
        # Walk the Series directly; each Series carries its own column name.
        per_column = {
            series.name: ColumnSummary.from_series(series)
            for series in df.get_columns()
        }
        return cls(
            num_rows=df.height,
            num_columns=df.width,
            column_names=df.columns,
            column_summaries=per_column,
        )


class PolarsToolkit:
    """A toolkit for working with Polars DataFrames.

    Holds a name -> DataFrame mapping (the "dataframe cache") and exposes
    read-only helpers over it that can be wrapped as LangChain tools via
    ``get_tools``.
    """

    def __init__(self, dataframes: dict[str, pl.DataFrame] | None = None) -> None:
        """Initialize the PolarsToolkit with an optional dataframe cache.

        Args:
            dataframes: Mapping from dataframe name to DataFrame. The mapping is
                used as-is (not copied), so entries the caller adds later are
                visible to the toolkit. Defaults to a new empty dict.
        """
        # `is not None` rather than `or`: an explicitly passed empty dict must be
        # kept, otherwise the caller's dict and the toolkit's cache diverge.
        self.dataframes = dataframes if dataframes is not None else {}

    def list_dataframes(self) -> list[str]:
        """List the names of available dataframes in the dataframe cache.

        Returns:
            list[str]: A list of dataframe names in the dataframe cache.
        """
        # NOTE: this docstring is used as the LLM-facing tool description.
        return list(self.dataframes.keys())

    def get_dataframe_summary(self, name: str) -> DataFrameSummary:
        """Get a summary of the specified dataframe.

        Args:
            name (str): The name of the dataframe to summarize.

        Returns:
            DataFrameSummary: A summary of the specified dataframe.

        Raises:
            KeyError: If the specified dataframe is not found in the cache.
        """
        # NOTE: this docstring is used as the LLM-facing tool description.
        df = self.dataframes.get(name)
        if df is None:
            # List the valid names so a caller (often an LLM) can recover from a
            # misspelled or invented dataframe name.
            available = ", ".join(f"'{key}'" for key in self.dataframes) or "none"
            raise KeyError(f"DataFrame '{name}' not found in cache. Available dataframes: {available}.")
        return DataFrameSummary.from_dataframe(df)

    def get_tools(self) -> list[BaseTool]:
        """Wrap this toolkit's public helpers as LangChain tools.

        Tools are built from the bound methods, so they operate on this
        instance's dataframe cache. LangChain derives each tool's name and
        description from the method name and docstring.

        Returns:
            list[BaseTool]: Tools for ``list_dataframes`` and ``get_dataframe_summary``.
        """
        return [tool(self.list_dataframes), tool(self.get_dataframe_summary)]

# Example

In [None]:
# Define a sample DataFrame: four people with name, birth date, weight and height.
df = pl.DataFrame(
    data={
        "name": ["Alice Archer", "Ben Brown", "Chloe Cooper", "Daniel Donovan"],
        "birthdate": [
            dt.date(1997, 1, 10),
            dt.date(1985, 2, 15),
            dt.date(1983, 3, 22),
            dt.date(1981, 4, 30),
        ],
        "weight": [57.9, 72.5, 53.6, 83.1],  # kilograms
        "height": [1.56, 1.77, 1.65, 1.75],  # metres
    },
)

In [None]:
# Register the sample frame under the name the tools will look it up by.
toolkit = PolarsToolkit({"people": df})
df_tools = toolkit.get_tools()

In [None]:
# Initialize a chat model
chat_model = init_chat_model(
    model=ModelName.CLAUDE_HAIKU,
    timeout=60,
    max_retries=2,
    api_key=APIKeys().anthropic,
    **ModelBehavior.factual().model_dump(),
)

# The toolkit serves a cache of named DataFrames (possibly several), so point
# the agent at `list_dataframes` for discovering valid names instead of implying
# a single implicit DataFrame. `.strip()` drops the literal's edge newlines.
system_prompt = """
You are an agent designed to interact with one or more Polars DataFrames.
Use the tools provided to you to answer user questions about the DataFrames.
Call `list_dataframes` to discover which DataFrames are available, and only pass
names it returns to the other tools.
""".strip()

df_agent = create_agent(
    chat_model,
    tools=df_tools,
    system_prompt=system_prompt,
)

In [None]:
# Run the agent on a single user question; the result is shown as the cell output.
df_agent.invoke(input={"messages": [HumanMessage("What data is available?")]})