In [2]:
import io
import os
import re

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

class DataAnalystAgent:
    """Keyword-driven helper for exploring a CSV dataset loaded into pandas.

    The dataset is read once in ``__init__`` and stored on ``self.dataset``.
    ``analyze_query`` maps simple English phrases onto the summary and
    plotting methods of this class.
    """

    def __init__(self, dataset_path, api_key=None):
        """
        Initialize the Data Analyst Agent with a dataset and optional API key.

        :param dataset_path: Path to the dataset (CSV format).
        :param api_key: API key for external integrations (optional).
        """
        self.dataset = pd.read_csv(dataset_path)
        self.api_key = api_key

    def summarize_data(self):
        """Return ``describe(include='all')`` statistics for every column."""
        return self.dataset.describe(include='all')

    def get_column_info(self):
        """
        Print the ``DataFrame.info()`` report and return it with null counts.

        ``DataFrame.info()`` only prints and returns ``None``, so the report is
        captured in a buffer, echoed to stdout (the original side effect) and
        returned as text.

        :return: Tuple ``(info_text, missing_counts)`` where ``info_text`` is
            the report as a string and ``missing_counts`` is a Series of null
            counts per column.
        """
        buffer = io.StringIO()
        self.dataset.info(buf=buffer)
        info_text = buffer.getvalue()
        print(info_text, end="")
        return info_text, self.dataset.isnull().sum()

    def identify_unique_values(self):
        """
        Identify unique values in each object/category column.

        :return: Dict mapping column name to the array of its unique values.
        """
        text_columns = self.dataset.select_dtypes(include=['object', 'category']).columns
        return {col: self.dataset[col].unique() for col in text_columns}

    def identify_missing_values(self):
        """Return a Series with the number of null values per column."""
        return self.dataset.isnull().sum()

    def identify_duplicated_values(self):
        """Return the rows flagged by ``DataFrame.duplicated()``."""
        return self.dataset[self.dataset.duplicated()]

    def get_shape_of_data(self):
        """Return the ``(rows, columns)`` shape of the dataset."""
        return self.dataset.shape

    @staticmethod
    def _text_after(text, marker):
        """Return the stripped text after the first case-insensitive ``marker`` ('' if absent)."""
        found = re.search(re.escape(marker), text, flags=re.IGNORECASE)
        return text[found.end():].strip() if found else ""

    def _resolve_column(self, name):
        """Map a user-typed name onto a real column: exact match first, then a unique case-insensitive match."""
        name = name.strip()
        columns = list(self.dataset.columns)
        if name in columns:
            return name
        folded = name.lower()
        candidates = [col for col in columns if str(col).lower() == folded]
        # An ambiguous case-insensitive match is treated as "not found".
        return candidates[0] if len(candidates) == 1 else None

    def _masked_api_key(self):
        """Return the API key with all but (at most) its last four characters replaced by '*'."""
        key = str(self.api_key)
        visible = key[-4:] if len(key) > 8 else ""
        return "*" * (len(key) - len(visible)) + visible

    def analyze_query(self, query):
        """
        Answer a free-text query by keyword matching.

        Keywords are matched case-insensitively, but column names and filter
        conditions are taken from the query as typed, so datasets whose columns
        contain capital letters can be addressed. Column names are resolved by
        exact match first and then by a unique case-insensitive match.

        Recognized phrases, checked in this order: "basic statistics",
        "unique values", "missing values", "duplicated values",
        "shape of data", "distribution of <column>",
        "top <n> values in <column>",
        "average of <column> grouped by <column>",
        "correlation between <column> and <column>",
        "entries that meet <condition>", "time trend of <column>".

        :param query: The question to answer.
        :return: A pandas object, tuple or float with the result, or a string
            message when a plot was displayed, a column was not found, or the
            query could not be parsed or recognized.
        """
        lowered = query.lower()

        if "basic statistics" in lowered:
            return self.summarize_data()
        if "unique values" in lowered:
            return self.identify_unique_values()
        if "missing values" in lowered:
            return self.identify_missing_values()
        if "duplicated values" in lowered:
            return self.identify_duplicated_values()
        if "shape of data" in lowered:
            return self.get_shape_of_data()

        if "distribution of" in lowered:
            requested = self._text_after(query, "distribution of")
            column_name = self._resolve_column(requested)
            if column_name is None:
                return f"Column '{requested}' not found."
            plt.figure(figsize=(8, 5))
            sns.histplot(self.dataset[column_name], kde=True, color='blue')
            plt.title(f"Distribution of {column_name}")
            plt.xlabel(column_name)
            plt.ylabel("Frequency")
            plt.show()
            return f"Displayed distribution for {column_name}."

        if "top" in lowered and "values in" in lowered:
            try:
                n = int(lowered.split("top ")[1].split(" values in")[0])
            except (IndexError, ValueError) as e:
                return f"Error in parsing query: {e}"
            requested = self._text_after(query, "values in")
            column_name = self._resolve_column(requested)
            if column_name is None:
                return f"Column '{requested}' not found."
            return self.dataset[column_name].value_counts().head(n)

        if "average" in lowered and "grouped by" in lowered:
            parsed = re.search(
                r"average of\s+(.+?)\s+grouped by\s+(.+)", query, flags=re.IGNORECASE
            )
            if parsed is None:
                return ("Error in parsing query: expected "
                        "'average of <column> grouped by <column>'.")
            num_requested = parsed.group(1).strip()
            cat_requested = parsed.group(2).strip()
            num_col = self._resolve_column(num_requested)
            cat_col = self._resolve_column(cat_requested)
            if num_col is None or cat_col is None:
                return f"Columns '{num_requested}' or '{cat_requested}' not found."
            return self.dataset.groupby(cat_col)[num_col].mean()

        if "correlation between" in lowered:
            remainder = self._text_after(query, "correlation between")
            cols = [part.strip()
                    for part in re.split(r"\s+and\s+", remainder, flags=re.IGNORECASE)]
            resolved = [self._resolve_column(col) for col in cols]
            if len(cols) == 2 and all(col is not None for col in resolved):
                # Index with the resolved names, not the raw text.
                return self.dataset[resolved[0]].corr(self.dataset[resolved[1]])
            return f"One or both columns not found: {cols}"

        if "entries that meet" in lowered:
            condition = self._text_after(query, "entries that meet")
            # NOTE(security): DataFrame.query evaluates the expression it is
            # given; do not feed it conditions from untrusted users.
            # Try the condition as typed first so capitalised column names and
            # string literals survive; fall back to the lower-cased form that
            # earlier versions of this method always used.
            attempts = [condition]
            if condition.lower() != condition:
                attempts.append(condition.lower())
            first_error = None
            for attempt in attempts:
                try:
                    return self.dataset.query(attempt)
                except Exception as e:  # deliberate best effort: report, don't crash
                    if first_error is None:
                        first_error = e
            return f"Error in parsing condition: {first_error}"

        if "time trend of" in lowered:
            requested = self._text_after(query, "time trend of")
            column_name = self._resolve_column(requested)
            if column_name is None:
                # Previously this case fell through and returned None.
                return f"Column '{requested}' not found."
            plt.figure(figsize=(10, 6))
            self.dataset[column_name].plot()
            plt.title(f"Time Trend of {column_name}")
            plt.xlabel("Index")
            plt.ylabel(column_name)
            plt.show()
            return f"Displayed time trend for {column_name}."

        return "Query not recognized. Try asking about basic statistics, unique values, distributions, correlations, or groupings."

    def plot_histogram(self, column_name):
        """
        Plot a 20-bin histogram for a specific column.

        :param column_name: Exact name of the column to plot.
        :raises KeyError: If the column is not in the dataset.
        """
        plt.figure(figsize=(8, 5))
        self.dataset[column_name].hist(bins=20, color='skyblue', edgecolor='black')
        plt.title(f"Histogram of {column_name}")
        plt.xlabel(column_name)
        plt.ylabel("Frequency")
        plt.show()

    def plot_boxplot(self, column_name):
        """
        Plot a vertical boxplot for a specific column.

        :param column_name: Exact name of the column to plot.
        :raises KeyError: If the column is not in the dataset.
        """
        plt.figure(figsize=(8, 5))
        sns.boxplot(y=self.dataset[column_name], color='lightcoral')
        plt.title(f"Boxplot of {column_name}")
        plt.ylabel(column_name)
        plt.show()

    def plot_pairplot(self, columns=None):
        """
        Plot pairwise relationships between columns.

        :param columns: List of column names to include; when omitted or
            empty, all numeric columns are used.
        """
        if columns:
            sns.pairplot(self.dataset[columns])
        else:
            sns.pairplot(self.dataset.select_dtypes(include=['number']))
        plt.show()

    def plot_heatmap(self):
        """Plot a heatmap of correlations for numeric columns in the dataset."""
        # Restrict to numeric columns explicitly: DataFrame.corr() on a frame
        # that also holds text columns raises on pandas >= 2.0.
        numeric_data = self.dataset.select_dtypes(include=['number'])
        plt.figure(figsize=(10, 8))
        sns.heatmap(numeric_data.corr(), annot=True, cmap='coolwarm', fmt='.2f')
        plt.title("Correlation Heatmap")
        plt.show()

    def use_api(self, endpoint, params):
        """
        Example method to demonstrate using an API with the provided key.

        No network request is made; the call is only simulated.

        :param endpoint: Endpoint that would be called.
        :param params: Parameters that would be sent.
        :return: A fixed dict describing a successful simulated response.
        :raises ValueError: If no API key (``None`` or empty) is configured.
        """
        if not self.api_key:
            raise ValueError("API key is not set. Please provide a valid API key.")

        # Simulated API call. The key is masked so the secret does not end up
        # in notebook output.
        print(f"Calling API at {endpoint} with params {params} and API key {self._masked_api_key()}")
        return {"status": "success", "data": "This is a simulated API response."}

# Example usage (uncomment to use):
# agent = DataAnalystAgent(dataset_path="C:/Users/XXXX/Downloads/orders.csv", api_key=os.getenv("YOUR_API_KEY"))
# print(agent.analyze_query("unique values"))
# print(agent.analyze_query("missing values"))
# print(agent.analyze_query("duplicated values"))
# print(agent.analyze_query("shape of data"))


In [4]:
# Load orders.csv from the working directory. The API key is left empty, so
# use_api() would raise ValueError until a real key is supplied.
agent = DataAnalystAgent("orders.csv", api_key="")


In [12]:
# describe(include='all') statistics for every column, rendered as a rich table.
summary_table = agent.summarize_data()
display(summary_table)

Unnamed: 0,Column1,Column2,Column3,Column4,Column5,Column6,Column7,Column8,Column9,Column10,Column11
count,21.0,20,18,18,17.0,15,20.0,21,21.0,20.0,18
unique,,18,15,8,15.0,13,17.0,18,,,14
top,,Kate,Smith,F,111111111.0,"Atlanta, GA",56.0,pasta and pizza,,,$14
freq,,2,2,5,2.0,2,2.0,2,,,3
mean,521.666667,,,,,,,,2.095238,4.125,
std,283.703249,,,,,,,,0.889087,1.890941,
min,98.0,,,,,,,,1.0,1.0,
25%,234.0,,,,,,,,2.0,3.375,
50%,568.0,,,,,,,,2.0,4.25,
75%,765.0,,,,,,,,2.0,5.0,


In [5]:
# Null count per column, obtained through the keyword interface.
print(agent.analyze_query(query="missing values"))

Column1     0
Column2     1
Column3     3
Column4     3
Column5     4
Column6     6
Column7     1
Column8     0
Column9     0
Column10    1
Column11    3
dtype: int64


In [6]:
# Rows flagged as duplicates by DataFrame.duplicated().
print(agent.analyze_query(query="duplicated values"))

Empty DataFrame
Columns: [Column1, Column2, Column3, Column4, Column5, Column6, Column7, Column8, Column9, Column10, Column11]
Index: []


In [8]:
# "columns" is not one of the phrases analyze_query recognizes, so this prints
# the fallback help message.
print(agent.analyze_query(query="columns"))

Query not recognized. Try asking about basic statistics, unique values, distributions, correlations, or groupings.
