<a href="https://colab.research.google.com/github/dllochini/ai-dataset-agent/blob/main/notebooks/simple_data_analysing_agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!pip install pandas matplotlib seaborn openai gradio

Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/pip/_internal/cli/base_command.py", line 179, in exc_logging_wrapper
    status = run_func(*args)
             ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/pip/_internal/cli/req_command.py", line 67, in wrapper
    return func(self, options, args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/pip/_internal/commands/install.py", line 377, in run
    requirement_set = resolver.resolve(
                      ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/pip/_internal/resolution/resolvelib/resolver.py", line 95, in resolve
    result = self._result = resolver.resolve(
                            ^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/pip/_vendor/resolvelib/resolvers.py", line 546, in resolve
    state = resolution.resolve(requirements, max_rounds=max_rounds)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In [4]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import re
import gradio as gr
import json

from google.colab import files, userdata
from openai import OpenAI

# The stock OpenAI client is pointed at Groq's OpenAI-compatible endpoint.
# The API key is read from Colab's user secrets instead of being written in the notebook.
client = OpenAI(
    base_url="https://api.groq.com/openai/v1",
    api_key=userdata.get("GROQ_API_KEY"),
)

In [None]:
# Disabled cell: the Colab upload flow below is wrapped in a string literal, so none of it executes.
# NOTE(review): this is the only visible place that assigns `df`, which the example cells and the
# batch query cell further down read -- re-enable it (or define `df` some other way) before running them.
'''
print("Upload your dataset (CSV):")
uploaded = files.upload()
filename = list(uploaded.keys())[0]
df = pd.read_csv(filename)

print("Dataset loaded successfully!")
display(df.head())
'''

In [5]:
# System prompt handed to every Agent. It defines the text protocol the query loop parses:
# an "Action: <name>[: <column>]" line followed by PAUSE, or a final "Answer:" line.
# The names under AVAILABLE ACTIONS are the same names used as keys of KNOWN_ACTIONS;
# keep the two lists in sync when adding or removing a tool.
SYSTEM_PROMPT = """
You are a dataset analysis AI agent.

You operate strictly in the following loop:

Thought -> Action -> PAUSE -> Observation -> Thought -> ... -> Answer

GENERAL RULES :

1. You must either:
   - Call exactly ONE Action and then output PAUSE
   OR
   - Output a final Answer.

2. After calling an Action, you MUST format it exactly as:

Action: <action_name>[: <column_name_if_required>]
PAUSE

3. After PAUSE, you will receive:
Observation: <tool_output>

4. After receiving an Observation:
   - You MUST continue reasoning.
   - If the observation already contains all information required to answer the user's question,
     you MUST immediately output:

Answer: <final result>

   - Do NOT call additional actions if the answer can already be produced.
   - Do NOT recompute values that are already present in the Observation.
   - Do NOT repeat the same action unless absolutely necessary.

5. Never output an empty message.
6. Never output PAUSE unless you are calling an Action.
7. Never output Observation yourself.
8. Use exact column names when required.
9. Only perform actions directly necessary to answer the question.
10. If a single action provides enough information, immediately produce Answer.

SPECIAL RULES :

• If the user asks for an overview:
  - Only call dataset_overview.
  - After receiving the observation, immediately produce Answer.

• If the user asks for a statistical summary:
  - Only call statistical_summary.
  - After receiving the observation, immediately produce Answer.
  - Do NOT call column_mean or other actions afterward.

AVAILABLE ACTIONS:

dataset_overview
statistical_summary
missing_values
duplicate_count
column_mean: <column_name>
column_min: <column_name>
column_max: <column_name>
value_counts: <column_name>
correlation_matrix
number_of_rows
number_of_columns
plot_numeric_columns
"""

In [6]:
def dataset_overview(df, column=None):
    """Summarise the shape and schema of `df`.

    Args:
        df: DataFrame to describe.
        column: Unused; present so the function can be called as fn(df, column).

    Returns:
        dict with keys "rows", "columns", "column_names" (list) and
        "dtypes" (column name -> dtype rendered as a string).
    """
    row_total, column_total = df.shape
    dtype_names = df.dtypes.astype(str)
    overview = {"rows": row_total, "columns": column_total}
    overview["column_names"] = df.columns.tolist()
    overview["dtypes"] = dtype_names.to_dict()
    return overview

def statistical_summary(df, column=None):
    """Descriptive statistics for the numeric columns of `df`.

    Args:
        df: DataFrame to summarise.
        column: Unused; present so the function can be called as fn(df, column).

    Returns:
        dict mapping each numeric column name to its describe() statistics
        (count, mean, std, min, quartiles, max). An empty dict when the frame
        has no numeric column.
    """
    # describe(include='number') raises ValueError when nothing numeric is selected,
    # so a text-only dataset is answered with an empty summary instead of crashing.
    if df.select_dtypes(include='number').shape[1] == 0:
        return {}
    return df.describe(include='number').to_dict()

def missing_values(df, column=None):
    """Count the missing entries in every column.

    Args:
        df: DataFrame to inspect.
        column: Unused; present so the function can be called as fn(df, column).

    Returns:
        dict mapping column name -> number of null cells in that column.
    """
    null_mask = df.isna()
    nulls_per_column = null_mask.sum()
    return nulls_per_column.to_dict()

def duplicate_count(df, column=None):
    """Number of rows that repeat an earlier row in full.

    Args:
        df: DataFrame to inspect.
        column: Unused; present so the function can be called as fn(df, column).

    Returns:
        The count as a plain Python int.
    """
    repeated_rows = df.duplicated()
    return int(repeated_rows.sum())

def column_mean(df, column=None):
    """Mean of a single column.

    Args:
        df: DataFrame holding the data.
        column: Name of the column to average.

    Returns:
        The column mean, or an explanatory string when the column is missing
        or its values cannot be averaged (e.g. a text column).
    """
    if column is None or column not in df.columns:
        return "Column not found."
    try:
        return df[column].mean()
    except (TypeError, ValueError):
        # pandas refuses to average non-numeric data; report that in the same
        # style as the missing-column case instead of letting the error escape.
        return f"Column '{column}' is not numeric."

def column_min(df, column=None):
    """Smallest value of a single column.

    Args:
        df: DataFrame holding the data.
        column: Name of the column to inspect.

    Returns:
        The minimum of the column, or the string "Column not found." when
        `column` is None or not a column of `df`.
    """
    if column is not None and column in df.columns:
        return df[column].min()
    return "Column not found."

def column_max(df, column=None):
    """Largest value of a single column.

    Args:
        df: DataFrame holding the data.
        column: Name of the column to inspect.

    Returns:
        The maximum of the column, or the string "Column not found." when
        `column` is None or not a column of `df`.
    """
    if column is not None and column in df.columns:
        return df[column].max()
    return "Column not found."

def value_counts(df, column=None):
    """Frequency of each distinct value in a single column.

    Args:
        df: DataFrame holding the data.
        column: Name of the column to tabulate.

    Returns:
        dict mapping value -> number of occurrences, or the string
        "Column not found." when `column` is None or not a column of `df`.
    """
    if column is not None and column in df.columns:
        frequencies = df[column].value_counts()
        return frequencies.to_dict()
    return "Column not found."

def correlation_matrix(df, column=None):
    """Pairwise correlation between the numeric columns of `df`.

    Args:
        df: DataFrame holding the data.
        column: Unused; present so the function can be called as fn(df, column).

    Returns:
        A DataFrame (not a dict) indexed and labelled by the numeric column names.
    """
    pairwise = df.corr(numeric_only=True)
    return pairwise

def number_of_rows(df, column=None):
    """Row count of `df`; `column` is unused and only keeps the fn(df, column) call shape."""
    return len(df.index)

def number_of_columns(df, column=None):
    """Column count of `df`; `column` is unused and only keeps the fn(df, column) call shape."""
    return len(df.columns)

def plot_numeric_columns(df, column=None, save_dir="plots"):
    """Save one histogram (with KDE overlay) per numeric column of `df`.

    Args:
        df: DataFrame holding the data.
        column: Unused; present so the function can be called as fn(df, column).
        save_dir: Directory the PNG files are written to; created if missing.

    Returns:
        List of the image paths written, in column order. Empty when the frame
        has no numeric column.
    """
    os.makedirs(save_dir, exist_ok=True)
    # 'number' covers every numeric dtype (int32, float32, ...), matching the
    # selector statistical_summary uses, instead of only int64/float64.
    numeric_cols = df.select_dtypes(include="number").columns
    saved_paths = []
    used_stems = set()
    for position, col in enumerate(numeric_cols):
        # Column names come from the uploaded file, so strip anything that could
        # act as a path separator or be invalid in a file name. Ordinary names
        # such as "math_score" are left unchanged.
        stem = re.sub(r"[^\w.-]+", "_", str(col)).strip("._") or f"column_{position}"
        if stem in used_stems:
            # Two columns collapsed to the same stem; keep both images.
            stem = f"{stem}_{position}"
        used_stems.add(stem)
        path = os.path.join(save_dir, f"{stem}.png")

        fig, ax = plt.subplots()
        try:
            sns.histplot(df[col], kde=True, ax=ax)
            ax.set_title(f"Distribution of {col}")
            fig.savefig(path)
        finally:
            # Release the figure even if plotting or saving fails.
            plt.close(fig)
        saved_paths.append(path)
    return saved_paths

# Dispatch table: action name as written by the model -> tool function.
# Each key is the function's own name, so the table is built from the functions.
KNOWN_ACTIONS = {
    tool.__name__: tool
    for tool in (
        dataset_overview,
        statistical_summary,
        missing_values,
        duplicate_count,
        column_mean,
        column_min,
        column_max,
        value_counts,
        correlation_matrix,
        number_of_rows,
        number_of_columns,
        plot_numeric_columns,
    )
}

In [7]:
class Agent:
    """Minimal chat wrapper that keeps the running message history.

    Each call appends the user message, asks the chat-completions endpoint of the
    module-level `client` for a reply, stores that reply in the history and returns it.
    """

    def __init__(self, system_prompt, model="llama-3.1-8b-instant"):
        """
        Args:
            system_prompt: Text of the initial system message.
            model: Model identifier sent with every request; defaults to the
                model this notebook was written against.
        """
        self.model = model
        self.messages = [{"role": "system", "content": system_prompt}]

    def __call__(self, message):
        """Send `message` as the next user turn and return the assistant's text.

        Returns:
            The reply text; an empty string when the API returns no content.
        """
        self.messages.append({"role": "user", "content": message})

        response = client.chat.completions.create(
            model=self.model,
            messages=self.messages,
        )

        # `message.content` is nullable in the chat-completions response; normalise
        # it to "" so callers can safely use string methods on the result.
        content = response.choices[0].message.content or ""
        self.messages.append({"role": "assistant", "content": content})

        return content


In [None]:
# Agent instance shared by the manual walkthrough cells below; its message
# history keeps growing across all of them.
bot = Agent(system_prompt=SYSTEM_PROMPT)

### Example 1: dataset overview (stepping through the agent loop by hand)

In [None]:
# Send the question as a user turn; per SYSTEM_PROMPT the reply should contain
# an "Action: ..." line followed by PAUSE.
que = "give me a overview of the datatset"

print(bot(que))

In [None]:
# Run the requested tool by hand and wrap its result as the Observation turn.
# Uses `df` like the other example cells; `current_df` is only assigned in the
# Gradio cell further down, so it is not available when running top to bottom.
result = dataset_overview(df)
print('function result :', result)
next_prompt = "Observation: {}".format(result)
print(next_prompt)

In [None]:
# Feed the Observation back to the agent; the bare call shows its reply as the cell output.
bot(next_prompt)

### Example 2: minimum of a single column

In [None]:
# Column-level question; the expected action here is column_min with the column name as input.
que = "What is the minimum math_score?"

print(bot(que))

In [None]:
# Run the tool by hand and build the Observation turn that is sent back to the agent.
result = column_min(df, "math_score")
print(result)
next_prompt = f"Observation: {result}"
print(next_prompt)

In [None]:
# Feed the Observation back to the agent; the bare call shows its reply as the cell output.
bot(next_prompt)

### Example 3: statistical summary of the numeric columns

In [None]:
# Summary-style question; SYSTEM_PROMPT's special rules steer this to statistical_summary only.
que = "give me a statistical overview of the datatset"

print(bot(que))

In [None]:
# Run the tool by hand and build the Observation turn that is sent back to the agent.
result = statistical_summary(df)
print(result)
next_prompt = f"Observation: {result}"
print(next_prompt)

In [None]:
# Feed the Observation back to the agent; the bare call shows its reply as the cell output.
bot(next_prompt)

## Query loop: automating the Thought → Action → Observation cycle

In [8]:
# Matches "Action: <name>" with an optional ": <input>" running to the end of the line.
action_re = re.compile(r"Action:\s*(\w+)(?::\s*(.*))?")

def query_loop(question, df, max_iters=8):
    """Drive one question through the Thought -> Action -> Observation loop.

    A fresh Agent is created for every question. Each model reply is either a final
    answer (contains "Answer:") or a tool request ("Action: ..." plus PAUSE), which is
    executed against `df` and fed back as an "Observation: ..." message.

    Args:
        question: The user's question, sent as the first message.
        df: DataFrame the tools operate on.
        max_iters: Maximum number of model calls before giving up.

    Returns:
        Tuple (answer, reasoning_steps, images):
        answer is the final text or an explanatory message,
        reasoning_steps is the list of model replies and observations in order,
        images is the list of plot paths when plot_numeric_columns ran, else None.
    """
    agent = Agent(SYSTEM_PROMPT)
    reasoning_steps = []
    next_input = question
    images = None

    for _ in range(max_iters):
        # Guard against a None reply so the string checks below cannot raise.
        output = agent(next_input) or ""
        reasoning_steps.append(output)

        if not output.strip():
            return "Model returned empty response.", reasoning_steps, images

        if "Answer:" in output:
            final_answer = output.split("Answer:", 1)[1].strip()
            return final_answer, reasoning_steps, images

        match = action_re.search(output)

        if not match:
            # Neither an Answer nor an Action: re-sending the same prompt would just
            # burn iterations, so tell the model what is expected instead.
            message = (
                "No 'Action:' or 'Answer:' found in your reply. Respond with exactly one "
                "Action followed by PAUSE, or with a final Answer."
            )
            reasoning_steps.append(message)
            next_input = f"Observation: {message}"
            continue

        if "PAUSE" not in output:
            return "Protocol error: Model must output PAUSE after Action.", reasoning_steps, images

        action_name = match.group(1).strip().lower()
        action_input = match.group(2).strip() if match.group(2) else None

        if action_name not in KNOWN_ACTIONS:
            message = f"Unable to perform action '{action_name}' — this action is not supported."
            reasoning_steps.append(message)
            next_input = f"Observation: {message}"
            continue

        # Models sometimes echo the prompt's placeholder syntax (<col>, 'col', `col`).
        # Only unwrap when the raw text is not a column but the unwrapped text is.
        if action_input is not None and action_input not in df.columns:
            unwrapped = action_input.strip("'\"`<>[] ")
            if unwrapped in df.columns:
                action_input = unwrapped

        try:
            observation = KNOWN_ACTIONS[action_name](df, action_input)
        except Exception as exc:
            # A failing tool is reported to the model as an observation so it can
            # recover or explain, rather than aborting the whole request.
            message = f"Action '{action_name}' failed: {exc}"
            reasoning_steps.append(message)
            next_input = f"Observation: {message}"
            continue

        if action_name == "plot_numeric_columns":
            images = observation

        try:
            serialized = json.dumps(observation, default=str)
        except (TypeError, ValueError):
            # e.g. dict keys that JSON cannot represent; fall back to plain text.
            serialized = str(observation)

        reasoning_steps.append(f"Observation (JSON): {serialized}")
        next_input = f"Observation: {serialized}"

    return (
        "The agent stopped after too many reasoning steps. "
        "This query may not be supported yet. Please try one of the available actions above.",
        reasoning_steps,
        images
    )

In [None]:
# Sample questions for exercising query_loop. The later entries ask for things that
# no single listed action provides (row lookups, filters, multi-step comparisons).
# NOTE(review): `que` held a single string in the example cells above and is rebound to a list here.
que = ["Give me an overview of this dataset.",
       "How many rows and columns are there?",
       "List all column names and their data types.",
       "Tell me which columns are numeric and which are categorical.",
       "Provide a statistical summary for all numeric columns.",
       "Give me the mean, min, max, and standard deviation of all columns.",
       "Show descriptive statistics of this dataset.",
       "What is the mean age of students?",
       "What is the maximum math_score?",
       "How many unique values are in the gender column?",
       "Check which columns have missing values.",
       "Are there any duplicate rows?",
       "Give me the correlation matrix for numeric columns.",
       "Plot the distribution of all numeric columns.",
       "Which student has the highest science_score and what is their attendance_percentage?",
       "Compare math_score and english_score correlation and mean values.",
       "Give me the mean of all columns.",
       "How many students have attendance_percentage above 90?"
       ]

# Runs the sixth question (index 5); requires `df` to be defined already.
# The bare call displays the (answer, reasoning_steps, images) tuple as the cell output.
query_loop(que[5], df)

In [9]:
# Markdown blurb shown in the Gradio UI; passed as `description=` to gr.Interface.
description = """
Welcome! This AI Dataset Analysis Agent helps you explore and analyze your datasets interactively.

You can ask it to:

- **Get an overview of the dataset**: rows, columns, column names, and data types.
- **Generate a statistical summary** of numeric columns, including mean, standard deviation, min, max, and percentiles.
- **Check for missing values** in any column.
- **Count duplicate rows** in the dataset.
- **Get column-specific insights** such as mean, min, max, or value counts.
- **Compute the correlation matrix** between numeric columns.
- **Get the total number of rows or columns**.
- **Visualize numeric columns** with histograms (plots will appear in the gallery).

The agent uses step-by-step reasoning to answer your questions. Type your query about the dataset, and it will show its reasoning along with a final answer.

*Note:* Only supported actions from the above list can be performed. Queries outside these actions may not be answered.
"""

In [10]:
# Cache of the most recently loaded dataset and the path it was read from.
current_df = None
current_file = None


def _load_dataset(file):
    """Read `file` as CSV into the module-level cache and remember its path."""
    global current_df, current_file
    current_df = pd.read_csv(file)
    current_file = file
    return current_df


def load_file(file):
    """Load a CSV into the shared cache.

    Args:
        file: Path of the CSV file, or None when nothing was uploaded.

    Returns:
        Tuple (status message, None); the second slot matches the two-output
        shape used by chat_interface.
    """
    if file is None:
        return "No file uploaded.", None
    try:
        frame = _load_dataset(file)
    except (OSError, ValueError) as exc:
        # pandas' parser, empty-file and decoding errors are ValueError subclasses.
        return f"Could not read '{file}' as CSV: {exc}", None
    return f"Dataset '{file}' loaded successfully with {frame.shape[0]} rows and {frame.shape[1]} columns.", None


def chat_interface(file, user_input):
    """Gradio handler: answer `user_input` about the uploaded dataset.

    Args:
        file: Path of the uploaded CSV, or None when the upload field is empty.
        user_input: The user's question.

    Returns:
        Tuple (answer text, list of plot image paths or None).
    """
    # Reload whenever a different file is supplied; previously the first dataset
    # stayed cached and later uploads were silently ignored.
    if file is not None and (current_df is None or file != current_file):
        try:
            _load_dataset(file)
        except (OSError, ValueError) as exc:
            return f"Could not read '{file}' as CSV: {exc}", None

    if current_df is None:
        return "Please upload a dataset first.", None

    if not user_input:
        return "Please enter a question about the dataset.", None

    answer, _reasoning, images = query_loop(user_input, current_df)
    return answer, images

# Gradio front end: a file upload plus a question box in,
# the agent's answer plus a gallery of any generated plots out.
demo = gr.Interface(
    title="AI Dataset Analysis Agent",
    description=description,
    fn=chat_interface,
    inputs=[
        gr.File(type="filepath", label="Upload your CSV dataset"),
        gr.Textbox(label="Ask a question about your dataset"),
    ],
    outputs=[
        gr.Markdown(label="Agent Response"),
        gr.Gallery(label="Generated Plots"),
    ],
)


# share=True makes Gradio publish the app on a public URL as well.
demo.launch(share=True)

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://3720b4015b9a25d7b1.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


