<a href="https://colab.research.google.com/github/dllochini/ai-dataset-agent/blob/main/notebooks/simple_data_analysing_agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install pandas matplotlib seaborn openai gradio



In [2]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import re
import gradio as gr

from google.colab import files, userdata
from openai import OpenAI

# Groq serves an OpenAI-compatible API, so the OpenAI client is pointed at Groq's base URL.
client = OpenAI(
    api_key=userdata.get("GROQ_API_KEY"),
    base_url="https://api.groq.com/openai/v1"
)
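The client setup above only runs inside Colab, because `userdata` comes from `google.colab`. If you want this step to also work elsewhere, one option is to fall back to an environment variable. This is a minimal sketch: the helper name `get_groq_api_key` is hypothetical, it assumes the key is exported as `GROQ_API_KEY` outside Colab, and `files.upload()` later in the notebook remains Colab-only.

```python
import os


def get_groq_api_key():
    """Return GROQ_API_KEY from Colab secrets inside Colab, else from the environment."""
    try:
        from google.colab import userdata  # importable only inside Colab
    except ImportError:
        # Outside Colab: assume the key was exported as an environment variable.
        return os.environ.get("GROQ_API_KEY")
    return userdata.get("GROQ_API_KEY")
```

The client would then be built as `OpenAI(api_key=get_groq_api_key(), base_url="https://api.groq.com/openai/v1")`.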

In [3]:
print("Upload your dataset (CSV):")
uploaded = files.upload()
filename = list(uploaded.keys())[0]
df = pd.read_csv(filename)

print("Dataset loaded successfully!")
display(df.head())

Upload your dataset (CSV):


Saving student_data.csv to student_data (2).csv
Dataset loaded successfully!


   student_id     name  age  gender  math_score  english_score  science_score  attendance_percentage
0           1    Alice   20  Female          85             78             92                     95
1           2      Bob   21    Male          72             65             70                     88
2           3  Charlie   19    Male          90             88             85                     97
3           4    Diana   22  Female          60             75             68                     80
4           5   Edward   20    Male          95             91             89                     99


In [4]:
SYSTEM_PROMPT = """

You are a simple dataset analyser AI agent.

You operate in a loop of:

Thought -> Action -> PAUSE -> Observation

When you have enough information, output:
Answer: <final answer>

Rules:
- Only call ONE action at a time.
- After calling Action, output PAUSE.
- Do not hallucinate observations.
- Wait for Observation before continuing.
- Use exact column names.
- Only perform actions directly necessary to answer the user's question.
- Do NOT perform additional exploratory analysis unless explicitly requested.
- If the user explicitly asks for a statistical summary of the dataset, only call statistical_summary and immediately output the result after receiving the observation.
- If the user asks for a summary or an overview, only call dataset_overview and immediately output the result after receiving the observation. Do not call any other actions like number_of_rows or number_of_columns.
- If a single action provides sufficient information, immediately output:
Answer: <final result>

When you receive an Observation, you MUST either:
- Call another Action directly relevant to answering the question, OR
- Output Answer: <final result>

You MUST eventually output:
Answer: <final result>

Available actions:

dataset_overview
statistical_summary
missing_values
duplicate_count
column_mean: <column_name>
column_min: <column_name>
column_max: <column_name>
value_counts: <column_name>
correlation_matrix
number_of_rows
number_of_columns
plot_numeric_columns
"""
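The prompt asks the model to emit lines such as `Action: dataset_overview` or `Action: column_min: math_score`, but in this notebook the action is read and executed by hand. A small regex-based parser is one way to pull out the action name and its optional column argument. This is a sketch only; `ACTION_RE` and `parse_action` are hypothetical names, not part of the original notebook.

```python
import re

# Matches "Action: name" or "Action: name: argument" on a line of its own.
# Group 1 is the action name; group 2 (None if absent) is the column argument.
ACTION_RE = re.compile(r"^Action:[ \t]*(\w+)(?:[ \t]*:[ \t]*(.+?))?[ \t]*$", re.MULTILINE)


def parse_action(reply):
    """Return (action_name, argument_or_None) for the first Action line, or None."""
    match = ACTION_RE.search(reply)
    if match is None:
        return None
    return match.group(1), match.group(2)


reply = "Thought: I need the minimum.\n\nAction: column_min: math_score\n\nPAUSE"
print(parse_action(reply))  # ('column_min', 'math_score')
```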

In [6]:
def dataset_overview(df):
    overview = {
        "rows": df.shape[0],
        "columns": df.shape[1],
        "column_names": df.columns.tolist(),
        "dtypes": df.dtypes.astype(str).to_dict()
    }

    return overview

def statistical_summary(df):
    numeric_summary = df.describe(include='number').to_dict()
    return numeric_summary

def missing_values(df):
    return df.isnull().sum().to_string()

def duplicate_count(df):
    return int(df.duplicated().sum())

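def column_mean(df, column):
    if column not in df.columns:
        return "Column not found."
    # mean() raises a TypeError on text columns such as "name" or "gender".
    if not pd.api.types.is_numeric_dtype(df[column]):
        return "Column is not numeric."
    return df[column].mean()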

def column_min(df, column):
    return df[column].min() if column in df.columns else "Column not found."

def column_max(df, column):
    return df[column].max() if column in df.columns else "Column not found."

def value_counts(df, column):
    return df[column].value_counts().to_string() if column in df.columns else "Column not found."

def correlation_matrix(df):
    return df.corr(numeric_only=True).to_string()

def number_of_rows(df):
    return df.shape[0]

def number_of_columns(df):
    return df.shape[1]

def plot_numeric_columns(df, save_dir="plots"):
    os.makedirs(save_dir, exist_ok=True)
    # 'number' covers every numeric dtype (int32, float32, ...), matching statistical_summary.
    numeric_cols = df.select_dtypes(include='number').columns

    saved_paths = []

    for col in numeric_cols:
        plt.figure()
        sns.histplot(df[col], kde=True)
        plt.title(f"Distribution of {col}")

        path = os.path.join(save_dir, f"{col}.png")
        plt.savefig(path)
        plt.close()

        saved_paths.append(path)

    return saved_paths


In [7]:
KNOWN_ACTIONS = {
    "dataset_overview": dataset_overview,
    "statistical_summary": statistical_summary,
    "missing_values": missing_values,
    "duplicate_count": duplicate_count,
    "column_mean": column_mean,
    "column_min": column_min,
    "column_max": column_max,
    "value_counts": value_counts,
    "correlation_matrix": correlation_matrix,
    "number_of_rows": number_of_rows,
    "number_of_columns": number_of_columns,
    "plot_numeric_columns": plot_numeric_columns,
}
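The model only knows the actions listed in `SYSTEM_PROMPT`, so that list and the keys of `KNOWN_ACTIONS` have to stay in sync. A hypothetical helper that compares the two might look like this; it assumes, as in the prompt above, that each action sits on its own line after `Available actions:` with any argument placeholder after a colon.

```python
def prompt_action_names(prompt):
    """Return the action names listed after 'Available actions:' in a system prompt."""
    _, _, tail = prompt.partition("Available actions:")
    names = set()
    for line in tail.splitlines():
        line = line.strip()
        if line:
            # "column_mean: <column_name>" -> "column_mean"
            names.add(line.split(":", 1)[0].strip())
    return names


def check_actions(prompt, actions):
    """Return (listed but not implemented, implemented but not listed)."""
    listed = prompt_action_names(prompt)
    implemented = set(actions)
    return listed - implemented, implemented - listed


demo_prompt = "Available actions:\n\nnumber_of_rows\ncolumn_max: <column_name>\n"
print(check_actions(demo_prompt, {"number_of_rows": None}))  # ({'column_max'}, set())
```

In the notebook, `check_actions(SYSTEM_PROMPT, KNOWN_ACTIONS)` should return two empty sets, since both list the same twelve actions.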


In [8]:
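class Agent:
    def __init__(self, system_prompt):
        # The full conversation is kept so the model sees its earlier
        # Thoughts/Actions and the Observations fed back to it.
        self.messages = [{"role": "system", "content": system_prompt}]

    def __call__(self, message):
        # Each call sends one user message (a question or an Observation)
        # and returns a single model reply; it does not run any action itself.
        self.messages.append({"role": "user", "content": message})

        response = client.chat.completions.create(
            model="llama-3.1-8b-instant",
            messages=self.messages,
        )

        content = response.choices[0].message.content
        self.messages.append({"role": "assistant", "content": content})

        return content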
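`Agent` makes exactly one model call per invocation, so the examples in this notebook drive the Thought -> Action -> PAUSE -> Observation loop by hand: call the bot, run the named function, then send back `Observation: ...`. A minimal sketch that automates this follows. It assumes the model sticks to the `Action: name[: column]` and `Answer:` formats from the prompt; `run_agent` and its `max_turns` limit are hypothetical additions, not part of the original notebook.

```python
import re

# Same Action-line format the system prompt asks for: "Action: name" or "Action: name: column".
ACTION_RE = re.compile(r"^Action:[ \t]*(\w+)(?:[ \t]*:[ \t]*(.+?))?[ \t]*$", re.MULTILINE)


def run_agent(agent, question, actions, data, max_turns=5):
    """Run Thought -> Action -> PAUSE -> Observation until a reply contains 'Answer:'.

    `agent` is any callable mapping a message string to the model's reply (e.g. an
    Agent instance); `actions` maps names to functions taking `data` and, optionally,
    one column name. Returns the answer text, or None if max_turns is exhausted.
    """
    message = question
    for _ in range(max_turns):
        reply = agent(message)
        if "Answer:" in reply:
            # Keep only the text after the first "Answer:" marker.
            return reply.split("Answer:", 1)[1].strip()

        match = ACTION_RE.search(reply)
        if match is None:
            message = "Observation: No Action found. Reply with an Action or an Answer."
            continue

        name, arg = match.group(1), match.group(2)
        if name not in actions:
            observation = f"Unknown action: {name}"
        else:
            try:
                func = actions[name]
                observation = func(data) if arg is None else func(data, arg.strip())
            except Exception as exc:  # report the failure back to the model instead of crashing
                observation = f"Error: {exc}"
        message = f"Observation: {observation}"
    return None
```

With the notebook's objects, `run_agent(bot, "What is the minimum math_score?", KNOWN_ACTIONS, df)` would stand in for the three manual cells of Example 2. Stopping at the first `Answer:` and capping the number of turns guards against a small model looping on extra actions, and an exception raised by an action (for example, an argument passed to an action that takes none) is returned to the model as an Observation.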


In [9]:
bot = Agent(system_prompt=SYSTEM_PROMPT)

Example 1

In [10]:
que = "give me an overview of the dataset"

print(bot(que))

Thought: To get a summary of the dataset, I need to call an action that provides an overview.

Action: dataset_overview

PAUSE

Please wait while I receive the observation...


In [11]:
result = dataset_overview(df)
print('function result :', result)
next_prompt = "Observation: {}".format(result)
print(next_prompt)

function result : {'rows': 10, 'columns': 8, 'column_names': ['student_id', 'name', 'age', 'gender', 'math_score', 'english_score', 'science_score', 'attendance_percentage'], 'dtypes': {'student_id': 'int64', 'name': 'object', 'age': 'int64', 'gender': 'object', 'math_score': 'int64', 'english_score': 'int64', 'science_score': 'int64', 'attendance_percentage': 'int64'}}
Observation: {'rows': 10, 'columns': 8, 'column_names': ['student_id', 'name', 'age', 'gender', 'math_score', 'english_score', 'science_score', 'attendance_percentage'], 'dtypes': {'student_id': 'int64', 'name': 'object', 'age': 'int64', 'gender': 'object', 'math_score': 'int64', 'english_score': 'int64', 'science_score': 'int64', 'attendance_percentage': 'int64'}}


In [12]:
bot(next_prompt)

"With the observation, I now have enough information to answer your question about the dataset.\n\nAnswer: The dataset has 10 rows and 8 columns, with data types as follows: 'int64' for 'student_id', 'age', 'math_score', 'english_score', 'science_score', 'attendance_percentage' and 'object' for 'name', 'gender'."

Example 2

In [13]:
que = "What is the minimum math_score?"

print(bot(que))

Thought: To find the minimum 'math_score', I need to call an action that can retrieve the minimum value.

Action: column_min: math_score

PAUSE

Please wait while I receive the observation...


In [14]:
result = column_min(df, "math_score")
print(result)
next_prompt = "Observation: {}".format(result)
print(next_prompt)

60
Observation: 60


In [15]:
bot(next_prompt)

"With the observation of the minimum 'math_score', I now have enough information to answer your question.\n\nAnswer: 60"

Example 3

In [16]:
que = "give me a statistical overview of the dataset"

print(bot(que))

Thought: To get a statistical summary of the dataset, I need to call an action that provides the summary.

Action: statistical_summary

PAUSE

Please wait while I receive the observation...


In [17]:
result = statistical_summary(df)
print(result)
next_prompt = "Observation: {}".format(result)
print(next_prompt)

{'student_id': {'count': 10.0, 'mean': 5.5, 'std': 3.0276503540974917, 'min': 1.0, '25%': 3.25, '50%': 5.5, '75%': 7.75, 'max': 10.0}, 'age': {'count': 10.0, 'mean': 20.7, 'std': 1.3374935098492586, 'min': 19.0, '25%': 20.0, '50%': 20.5, '75%': 21.75, 'max': 23.0}, 'math_score': {'count': 10.0, 'mean': 80.0, 'std': 11.8227652339788, 'min': 60.0, '25%': 72.5, '50%': 81.5, '75%': 89.5, 'max': 95.0}, 'english_score': {'count': 10.0, 'mean': 79.6, 'std': 10.035492569431314, 'min': 65.0, '25%': 71.25, '50%': 80.0, '75%': 87.25, 'max': 94.0}, 'science_score': {'count': 10.0, 'mean': 82.2, 'std': 9.986657765906136, 'min': 68.0, '25%': 73.0, '50%': 84.5, '75%': 89.75, 'max': 96.0}, 'attendance_percentage': {'count': 10.0, 'mean': 91.2, 'std': 6.250333324444918, 'min': 80.0, '25%': 87.25, '50%': 91.5, '75%': 96.5, 'max': 99.0}}
Observation: {'student_id': {'count': 10.0, 'mean': 5.5, 'std': 3.0276503540974917, 'min': 1.0, '25%': 3.25, '50%': 5.5, '75%': 7.75, 'max': 10.0}, 'age': {'count': 10.0
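The raw `statistical_summary` Observation carries full float precision (for example a standard deviation of `3.0276503540974917`), which costs prompt tokens even though the model only reports three decimals in its answer. A hypothetical `round_floats` helper, sketched here with the standard library only, could shorten such nested summaries before they are formatted into `next_prompt`.

```python
def round_floats(obj, ndigits=3):
    """Recursively round floats in nested dicts/lists (e.g. describe().to_dict() output)."""
    if isinstance(obj, float):
        # float() also normalises NumPy float64 values to plain Python floats.
        return round(float(obj), ndigits)
    if isinstance(obj, dict):
        return {key: round_floats(value, ndigits) for key, value in obj.items()}
    if isinstance(obj, list):
        return [round_floats(value, ndigits) for value in obj]
    return obj


summary = {"student_id": {"count": 10.0, "std": 3.0276503540974917}}
print(round_floats(summary))  # {'student_id': {'count': 10.0, 'std': 3.028}}
```

In the cell above this would become `next_prompt = "Observation: {}".format(round_floats(statistical_summary(df)))`.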

In [18]:
bot(next_prompt)

'With the observation of the statistical summary, I now have enough information to answer your question about the dataset.\n\nAnswer: \n\n Student ID: mean 5.5, std 3.028, range 1-10\n Age: mean 20.7, std 1.337, range 19-23\n Math Score: mean 80, std 11.823, range 60-95\n English Score: mean 79.6, std 10.035, range 65-94\n Science Score: mean 82.2, std 9.987, range 68-96\n Attendance Percentage: mean 91.2, std 6.25, range 80-99'