In [None]:
from __future__ import annotations

import polars as pl
from config import get_chat_model
from langchain.agents import create_agent
from langchain.messages import HumanMessage
from sklearn import datasets

from dfkit import DataFrameToolkit

# Load the diabetes dataset as a Polars DataFrame

In [None]:
data, target = datasets.load_diabetes(return_X_y=True, scaled=False)
df = pl.DataFrame(
    data=data,
    schema=["age", "sex", "bmi", "bp", "s1", "s2", "s3", "s4", "s5", "s6"],
    # scikit-learn returns one row per patient; state the orientation explicitly
    # instead of relying on Polars' shape-based inference for NumPy input.
    orient="row",
)

df = df.with_columns(
    # Decode the numeric sex code with a vectorised expression rather than a
    # per-row Python lambda. Codes other than 1/2 (and nulls) become null
    # instead of being silently labelled "female".
    # NOTE(review): the dataset does not document which code means which sex;
    # the 1 -> "male", 2 -> "female" mapping is carried over as an assumption — confirm.
    pl.when(pl.col("sex") == 1)
    .then(pl.lit("male"))
    .when(pl.col("sex") == 2)
    .then(pl.lit("female"))
    .alias("sex"),
    pl.Series("disease_progression", target),
)

df

# Register the dataset with the toolkit

In [None]:
# Initialize the toolkit
toolkit = DataFrameToolkit()

# Single source of truth for the registration name; it is reused below to look
# the dataset up again, so the two references cannot drift apart.
DATASET_NAME = "Diabetes Progression Dataset"

# Register the diabetes dataset with the toolkit. The dataset and column
# descriptions travel with the data (presumably surfaced to the agent — confirm in dfkit).
_ = toolkit.register_dataframe(
    name=DATASET_NAME,
    dataframe=df,
    # Built from adjacent literals so the text carries no leading newline,
    # source-code indentation, embedded line breaks, or trailing whitespace.
    description=(
        "Ten baseline variables, age, sex, body mass index, average blood pressure, "
        "and six blood serum measurements were obtained for each diabetes patient, "
        "as well as the response of interest, a quantitative measure of disease "
        "progression one year after baseline."
    ),
    column_descriptions={
        "age": "Age of the patient in years.",
        # The values are the labels produced when the dataframe was built above.
        "sex": "Sex of the patient, encoded as 'male' or 'female'.",
        "bmi": "Body mass index.",
        "bp": "Average blood pressure.",
        "s1": "TC, total serum cholesterol.",
        "s2": "LDL, low-density lipoproteins.",
        "s3": "HDL, high-density lipoproteins.",
        "s4": "TCH, total cholesterol / HDL.",
        "s5": "LTG, possibly log of serum triglycerides level.",
        "s6": "GLU, blood sugar level.",
        "disease_progression": "A quantitative measure of disease progression one year after baseline.",
    },
)

# View the registered dataset as a markdown table
print(toolkit.view_as_markdown_table(DATASET_NAME))

# Create an agent with the toolkit

In [None]:
# Gather the agent's ingredients first: the chat model from the local config,
# then the toolkit's tools, then the toolkit-generated system prompt.
chat_model = get_chat_model()
dataframe_tools = toolkit.get_tools()
dataframe_system_prompt = toolkit.get_system_prompt()

# Wire them together into a tool-calling agent.
df_agent = create_agent(
    model=chat_model,
    tools=dataframe_tools,
    system_prompt=dataframe_system_prompt,
)

# Ask questions about the dataset

In [None]:
response = df_agent.invoke({
    "messages": [HumanMessage("What's the relationship between BMI and disease progression in the diabetes dataset?")]
})
messages = response.get("messages", [])
if messages:
    last_message = messages[-1]
    print(last_message.content)

In [None]:
response = df_agent.invoke({
    "messages": [
        HumanMessage("Is there a significant difference in disease progression between male and female patients?")
    ]
})
messages = response.get("messages", [])
if messages:
    last_message = messages[-1]
    print(last_message.content)

In [None]:
# Report how many dataframes the toolkit currently holds.
registered_dataframes = toolkit.list_dataframes()
print(len(registered_dataframes))