In [1]:
from __future__ import annotations

import datetime as dt
from enum import StrEnum

import numpy as np
import polars as pl
from config import get_chat_model
from langchain.agents import create_agent
from langchain.messages import HumanMessage
from langchain.tools import BaseTool, tool
from pydantic import BaseModel, Field
from sklearn import datasets

from dfkit.context import DataFrameContext
from dfkit.models import DataFrameReference

  from pydantic.v1.fields import FieldInfo as FieldInfoV1


In [2]:
data, target = datasets.load_diabetes(return_X_y=True, scaled=False)
df = pl.DataFrame(
    data=data,
    schema=["age", "sex", "bmi", "bp", "s1", "s2", "s3", "s4", "s5", "s6"],
)

df = df.with_columns(
    pl.col("sex").map_elements(lambda x: "male" if x == 1 else "female", return_dtype=pl.String),
    pl.Series(target).alias("disease_progression"),
)

df

age,sex,bmi,bp,s1,s2,s3,s4,s5,s6,disease_progression
f64,str,f64,f64,f64,f64,f64,f64,f64,f64,f64
59.0,"""female""",32.1,101.0,157.0,93.2,38.0,4.0,4.8598,87.0,151.0
48.0,"""male""",21.6,87.0,183.0,103.2,70.0,3.0,3.8918,69.0,75.0
72.0,"""female""",30.5,93.0,156.0,93.6,41.0,4.0,4.6728,85.0,141.0
24.0,"""male""",25.3,84.0,198.0,131.4,40.0,5.0,4.8903,89.0,206.0
50.0,"""male""",23.0,101.0,192.0,125.4,52.0,4.0,4.2905,80.0,135.0
…,…,…,…,…,…,…,…,…,…,…
60.0,"""female""",28.2,112.0,185.0,113.8,42.0,4.0,4.9836,93.0,178.0
47.0,"""female""",24.9,75.0,225.0,166.0,42.0,5.0,4.4427,102.0,104.0
60.0,"""female""",24.9,99.67,162.0,106.6,43.0,3.77,4.1271,95.0,132.0
36.0,"""male""",30.0,95.0,201.0,125.2,42.0,4.79,5.1299,85.0,220.0


In [4]:
dfr = DataFrameReference.from_dataframe(
    name="Diabetes Progression Dataset",
    dataframe=df,
    description="""
    Ten baseline variables, age, sex, body mass index, average blood pressure,
    and six blood serum measurements were obtained for each diabetes patient,
    as well as the response of interest, a quantitative measure of disease 
    progression one year after baseline.
    """,
    column_descriptions={
        "age": "Age of the patient in years.",
        "sex": "Sex of the patient",
        "bmi": "Body mass index.",
        "bp": "Average blood pressure.",
        "s1": "TC, total serum cholesterol.",
        "s2": "LDL, low-density lipoproteins.",
        "s3": "HDL, high-density lipoproteins.",
        "s4": "TCH, total cholesterol / HDL.",
        "s5": "LTG, possibly log of serum triglycerides level.",
        "s6": "GLU, blood sugar level.",
        "disease_progression": "A quantitative measure of disease progression one year after baseline.",
    },
)
dfr

DataFrameReference(id='df_fcf14bba', name='Diabetes Progression Dataset', description='\n    Ten baseline variables, age, sex, body mass index, average blood pressure,\n    and six blood serum measurements were obtained for each diabetes patient,\n    as well as the response of interest, a quantitative measure of disease \n    progression one year after baseline.\n    ', num_rows=442, num_columns=11, column_names=['age', 'sex', 'bmi', 'bp', 's1', 's2', 's3', 's4', 's5', 's6', 'disease_progression'], column_summaries={'age': ColumnSummary(description='Age of the patient in years.', dtype='Float64', count=442, null_count=0, unique_count=58, min=19.0, max=79.0, mean=48.51809954751131, std=13.10902782204109, p25=38.0, p50=50.0, p75=59.0), 'sex': ColumnSummary(description='Sex of the patient', dtype='String', count=442, null_count=0, unique_count=2, min='female', max='male', mean=None, std=None, p25=None, p50=None, p75=None), 'bmi': ColumnSummary(description='Body mass index.', dtype='Float

In [5]:
def to_markdown_table(df: pl.DataFrame, columns: list[str] | None = None, num_rows: int = 10) -> str:
    """Convert a Polars DataFrame to a markdown table string representation.

    Args:
        df (pl.DataFrame): The Polars DataFrame to convert.
        columns (list[str] | None): Optional list of column names to include in the output.
            If None, all columns are included (by not necessarily displayed). Defaults to None.
        num_rows (int): The number of rows to include in the output. Defaults to 10.

    Returns:
        str: A string representation of the DataFrame in markdown table format.
    """
    with pl.Config(
        tbl_formatting="MARKDOWN",
        tbl_hide_column_data_types=True,
        tbl_hide_column_names=False,
        tbl_hide_dataframe_shape=True,
        tbl_rows=num_rows,
        tbl_cols=len(columns) if columns is not None else None,
    ):
        # Select only the specified columns if provided, otherwise include all columns
        if columns is not None:
            if (extra_columns := set(columns) - set(df.columns)):
                raise ValueError(f"Columns {extra_columns} not found in DataFrame.")
            df = df.select(columns)

        return str(df)

In [6]:
print(to_markdown_table(df, num_rows=25))

| age  | sex    | bmi  | bp    | … | s4   | s5     | s6    | disease_progression |
|------|--------|------|-------|---|------|--------|-------|---------------------|
| 59.0 | female | 32.1 | 101.0 | … | 4.0  | 4.8598 | 87.0  | 151.0               |
| 48.0 | male   | 21.6 | 87.0  | … | 3.0  | 3.8918 | 69.0  | 75.0                |
| 72.0 | female | 30.5 | 93.0  | … | 4.0  | 4.6728 | 85.0  | 141.0               |
| 24.0 | male   | 25.3 | 84.0  | … | 5.0  | 4.8903 | 89.0  | 206.0               |
| 50.0 | male   | 23.0 | 101.0 | … | 4.0  | 4.2905 | 80.0  | 135.0               |
| 23.0 | male   | 22.6 | 89.0  | … | 2.0  | 4.1897 | 68.0  | 97.0                |
| 36.0 | female | 22.0 | 90.0  | … | 3.0  | 3.9512 | 82.0  | 138.0               |
| 66.0 | female | 26.2 | 114.0 | … | 4.55 | 4.2485 | 92.0  | 63.0                |
| 60.0 | female | 32.1 | 83.0  | … | 4.0  | 4.4773 | 94.0  | 110.0               |
| 29.0 | male   | 30.0 | 85.0  | … | 4.0  | 5.3845 | 88.0  | 310.0               |
| 22