In [4]:
import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import ollama  # Using Ollama instead of OpenAI
import pandas as pd
from dotenv import load_dotenv
from pydantic import BaseModel, ValidationError  # Using pydantic for data validation

load_dotenv()  # Load environment variables from a .env file, if one is found

# Name of the Ollama model used to answer queries. The model must already be
# available to the local Ollama server (e.g. via `ollama pull llama3`); this
# script does not download it. Set OLLAMA_MODEL in the environment (or in the
# .env file loaded above) to use a different model.
OLLAMA_MODEL = os.environ.get("OLLAMA_MODEL", "llama3")

# Define a Pydantic model for structured data validation
class CSVData(BaseModel):
    """Pydantic schema for one CSV row, as checked by ``validate_csv_data``.

    Only the first two columns of a row are validated: the first against
    ``column`` (a string) and the second against ``value`` (a float).
    """

    column: str  # value taken from the row's first column
    value: float  # value taken from the row's second column

# Function to process queries using Ollama
def query_llm(query: str, df_summary: str) -> str:
    """Ask the configured Ollama model a question about a dataset.

    Args:
        query: The user's natural-language question.
        df_summary: Text summary of the dataset, embedded in the prompt.

    Returns:
        The model's reply with surrounding whitespace stripped, or a string
        beginning with ``"Error: "`` if calling Ollama (or reading its
        response) raised any exception.
    """
    # The exact prompt text (including its indented continuation lines).
    prompt = (
        "You are an AI trained to analyze CSV data.\n"
        "    Here is a summary of the dataset:\n"
        f"    {df_summary}\n"
        "\n"
        f"    User Query: {query}\n"
        "    Provide a clear and concise answer.\n"
        "    "
    )
    chat_messages = [{"role": "user", "content": prompt}]

    try:
        reply = ollama.chat(model=OLLAMA_MODEL, messages=chat_messages)
        return reply["message"]["content"].strip()
    except Exception as exc:
        # Best-effort: surface the failure as text in the UI instead of raising.
        return f"Error: {str(exc)}"

# Function to validate CSV data using Pydantic
def validate_csv_data(df):
    """Validate each row of *df* against the ``CSVData`` schema.

    The first column of every row is checked as ``CSVData.column`` and the
    second as ``CSVData.value``; any further columns are ignored.

    Args:
        df: The parsed CSV data as a pandas DataFrame.

    Returns:
        A list of human-readable error messages, one per invalid row, with
        rows numbered from 1 in file order. The list is empty when every row
        validates. If *df* has fewer than two columns, a single explanatory
        message is returned instead of raising ``IndexError``.
    """
    column_count = df.shape[1]
    if column_count < 2:
        return [f"Expected at least 2 columns, found {column_count}."]

    errors = []
    # itertuples is cheaper than iterrows and, unlike iterrows, does not
    # coerce each row into a single Series (which can change value dtypes).
    rows = df.iloc[:, :2].itertuples(index=False, name=None)
    for row_number, (label, value) in enumerate(rows, start=1):
        try:
            CSVData(column=label, value=value)
        except ValidationError as e:
            errors.append(f"Row {row_number}: {str(e)}")
    return errors

# Function to process user query and return dataframe response
def _summarize_dataframe(df, max_categories=20):
    """Build the dataset summary text that ``process_query`` sends to the LLM.

    The summary is ``df.describe(include="all")`` followed by a
    "Categorical Data Summary" section with value counts for each
    text/categorical column, limited to the *max_categories* most frequent
    values per column (a trailer line reports how many were omitted).
    """
    parts = [
        df.describe(include="all").to_string(),
        "\n\nCategorical Data Summary:\n",
    ]
    # "string" and "category" catch text columns that are not plain object dtype.
    text_columns = df.select_dtypes(include=["object", "string", "category"]).columns
    for col in text_columns:
        counts = df[col].value_counts()
        parts.append(f"{col}:\n{counts.head(max_categories).to_string()}\n")
        hidden = len(counts) - max_categories
        if hidden > 0:
            parts.append(f"... and {hidden} more distinct values\n")
        parts.append("\n")
    return "".join(parts)


def process_query(query, file, max_categories=20):
    """Answer a natural-language question about an uploaded CSV file.

    Args:
        query: The user's question (text from the query textbox).
        file: The uploaded CSV, as anything ``pd.read_csv`` accepts (a path
            or file-like object), or ``None`` when nothing was uploaded.
        max_categories: Maximum number of most frequent values listed per
            text/categorical column in the summary sent to the model, so a
            high-cardinality column cannot blow up the prompt.

    Returns:
        A ``(message, dataframe)`` tuple. On success ``message`` is the
        model's answer and ``dataframe`` is the parsed CSV. On failure
        ``message`` starts with ``"Error"`` and ``dataframe`` is ``None``.
    """
    if file is None:
        return "Error: Please upload a CSV file.", None
    if not (query or "").strip():
        # Don't spend an LLM call on an empty question.
        return "Error: Please enter a query.", None

    try:
        df = pd.read_csv(file)
        if df.empty:
            return "Error: The uploaded CSV file is empty.", None

        full_summary = _summarize_dataframe(df, max_categories)
        response = query_llm(query, full_summary)  # Call Ollama model
        return response, df
    except pd.errors.EmptyDataError:
        # A zero-byte file has no header at all; report it like a header-only CSV.
        return "Error: The uploaded CSV file is empty.", None
    except Exception as e:
        return f"Error processing file: {str(e)}", None

# Function to generate and return a plot as an image
def _scatter_y_column(df, column):
    """Return the y-axis column for a scatter plot of *column* (needs >= 2 columns).

    The second CSV column is used unless *column* is itself the second column,
    in which case the first column is used so *column* is not plotted against itself.
    """
    second = df.columns[1]
    return df.columns[0] if column == second else second


def _plot_request_error(df, column, plot_type):
    """Return an error message if *column*/*plot_type* cannot be plotted, else ``None``."""
    if column not in df.columns:
        return f"Error: Column '{column}' not found in the CSV file."

    is_numeric = pd.api.types.is_numeric_dtype
    if plot_type == "Scatter Plot":
        if len(df.columns) < 2:
            return "Error: Need at least two columns for a scatter plot."
        y_col = _scatter_y_column(df, column)
        if not (is_numeric(df[column]) and is_numeric(df[y_col])):
            return "Error: Scatter plot requires numeric columns."
    elif plot_type == "Line Plot":
        if not is_numeric(df[column]):
            return "Error: Line plot requires a numeric column."
    elif plot_type not in ("Histogram", "Bar Chart"):
        return f"Error: Unsupported plot type '{plot_type}'."
    return None


def _draw_value_counts(ax, series, column):
    """Draw a bar chart of how often each value of *series* occurs onto *ax*."""
    series.value_counts().plot(kind="bar", ax=ax, edgecolor="black")
    ax.set_xlabel(column)
    ax.set_ylabel("Count")
    ax.set_title(f"Bar Chart of {column}")


def _render_plot(df, column, plot_type):
    """Draw an already-validated plot request into a temporary PNG; return its path."""
    series = df[column]
    fig, ax = plt.subplots(figsize=(6, 4))
    try:
        if plot_type == "Scatter Plot":
            y_col = _scatter_y_column(df, column)
            ax.scatter(series, df[y_col])
            ax.set_xlabel(column)
            ax.set_ylabel(y_col)
            ax.set_title(f"Scatter Plot of {column} vs {y_col}")
        elif plot_type == "Line Plot":
            series.plot(kind="line", ax=ax)
            ax.set_xlabel("Index")
            ax.set_ylabel(column)
            ax.set_title(f"Line Plot of {column}")
        elif plot_type == "Histogram" and pd.api.types.is_numeric_dtype(series):
            ax.hist(series.dropna().to_numpy(), bins=20, edgecolor="black")
            ax.grid(True)
            ax.set_xlabel(column)
            ax.set_ylabel("Frequency")
            ax.set_title(f"Histogram of {column}")
        else:
            # "Bar Chart", or a "Histogram" of a non-numeric column.
            _draw_value_counts(ax, series, column)

        # delete=False keeps the file on disk after the handle is closed, so the
        # returned path is still readable by the caller.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            fig.savefig(tmp, format="png")
        return tmp.name
    finally:
        # Release the figure on every path (including errors) so figures don't pile up.
        plt.close(fig)


def generate_plot(file, column, plot_type):
    """Plot one column of an uploaded CSV and return the path of the saved PNG.

    Args:
        file: The uploaded CSV, as anything ``pd.read_csv`` accepts, or ``None``.
        column: Name of the column to plot.
        plot_type: One of "Histogram", "Scatter Plot", "Bar Chart", "Line Plot".
            A "Scatter Plot" uses the second CSV column as the y axis (or the
            first, when *column* is the second column).

    Returns:
        The path of a temporary PNG file on success, otherwise an error message.
        NOTE(review): the result feeds a ``gr.Image``, which cannot display these
        message strings; consider raising ``gr.Error`` instead.
    """
    if file is None:
        return "Please upload a CSV file."

    try:
        df = pd.read_csv(file)
        # Validate before creating any figure so early returns cannot leak one.
        error = _plot_request_error(df, column, plot_type)
        if error is not None:
            return error
        return _render_plot(df, column, plot_type)
    except Exception as e:
        return f"Error generating plot: {str(e)}"
        
# Create Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## CSV Query and Visualization Tool (Ollama)")

    with gr.Row():
        csv_file = gr.File(label="Upload CSV")

    with gr.Row():
        query_input = gr.Textbox(label="Enter your query")
        query_output = gr.Textbox(label="LLM Response")

    query_button = gr.Button("Process Query")
    # Explicit, named output for the parsed CSV (rendered below the button)
    # instead of an anonymous component created inside the click() call.
    query_table = gr.Dataframe(label="CSV Data")
    query_button.click(
        process_query,
        inputs=[query_input, csv_file],
        outputs=[query_output, query_table],
    )

    with gr.Row():
        column_name = gr.Textbox(label="Column for Plot", placeholder="write a column name")
        # A default selection means generate_plot never receives plot_type=None.
        plot_type = gr.Dropdown(
            label="Plot Type",
            choices=["Histogram", "Scatter Plot", "Bar Chart", "Line Plot"],
            value="Histogram",
        )
        plot_output = gr.Image(label="Plot Output")

    plot_button = gr.Button("Generate Plot")
    plot_button.click(generate_plot, inputs=[csv_file, column_name, plot_type], outputs=plot_output)

# Launch only after the Blocks context has closed and the layout is complete.
demo.launch()

* Running on local URL:  http://127.0.0.1:7866

To create a public link, set `share=True` in `launch()`.
