## Installations

In [1]:
!pip install pillow langchain_experimental "langchain[all]" langchain.tools matplotlib seaborn openpyxl pandas langchain-community langchain-experimental openai python-dotenv

Collecting langchain_experimental
  Downloading langchain_experimental-0.3.4-py3-none-any.whl (209 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m209.2/209.2 kB[0m [31m20.8 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting langchain[all]
  Downloading langchain-0.3.23-py3-none-any.whl (1.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m69.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting langchain.tools
  Downloading langchain_tools-0.1.34-py3-none-any.whl (8.5 kB)
Collecting openpyxl
  Downloading openpyxl-3.1.5-py2.py3-none-any.whl (250 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m250.9/250.9 kB[0m [31m39.2 MB/s[0m eta [36m0:00:00[0m
Collecting langchain-community
  Downloading langchain_community-0.3.21-py3-none-any.whl (2.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m96.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting openai
  Downloading openai-1.72.0-py3

## Import necessary Libraries

In [1]:
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import seaborn as sns
from dotenv import load_dotenv
from langchain_community.llms import OpenAI
from langchain_experimental.agents import create_pandas_dataframe_agent
from langchain_experimental.tools import PythonAstREPLTool
from langchain.agents import AgentType, initialize_agent, load_tools

ModuleNotFoundError: No module named 'dotenv'

## Build the SQL2text system

When you run this application, you will be prompted for a CSV or Excel file. Make sure the file is uploaded to the Files space in Deepnote, then copy its path from there and paste it into the text box. The next prompt asks you to give the dataset a name (press Enter to accept the default). After that, you can type your questions as queries.

In [3]:
# Load environment variables from a local .env file, if one exists
load_dotenv()

# Read the OpenAI API key from the environment instead of hardcoding it.
# SECURITY: a literal secret key used to be embedded on this line. Any key that
# has been saved inside a notebook must be treated as leaked and revoked.
# Provide the key via an OPENAI_API_KEY entry in .env or in the environment.
API_KEY = os.getenv("OPENAI_API_KEY")

# Module-level tool list read by get_plot_code(); nothing in this cell adds to it
tools = []

# ===== DATA HANDLER CLASS =====
class DataHandler:
    """Registry of pandas dataframes loaded from CSV/text or Excel files.

    Dataframes are stored by name in ``self.dataframes``; ``self.current_df``
    holds the *name* of the active dataframe (or None when nothing is loaded).
    Every method reports problems by returning a message string rather than
    raising, so callers detect failures by inspecting the returned value.
    """

    # File extensions routed to pd.read_excel
    _EXCEL_EXTENSIONS = ('.xlsx', '.xls', '.xlsm', '.xlsb')
    # File extensions routed to pd.read_csv ('.tsv' is read tab-separated)
    _TEXT_EXTENSIONS = ('.csv', '.txt', '.dat', '.tsv')
    # Encodings tried in order when the caller does not specify one. latin1
    # maps every byte to a character and therefore never fails, so it must be
    # last: anything listed after it would never be attempted.
    _FALLBACK_ENCODINGS = ('utf-8', 'cp1252', 'latin1')

    def __init__(self):
        """Initialize a data handler to manage CSV/Excel datasets."""
        self.dataframes = {}  # name -> loaded DataFrame
        self.current_df = None  # name of the currently active dataframe

    def load_file(self, file_path, name=None, encoding=None):
        """
        Load a data file (CSV/text or Excel) and make it the active dataframe.

        Parameters:
            file_path (str): Path to the data file
            name (str, optional): Name to register the dataframe under. When
                omitted or empty, the part of the file name before its first
                '.' is used.
            encoding (str, optional): Encoding for CSV/text files. When
                omitted, utf-8, cp1252 and latin1 are tried in that order.

        Returns:
            str: A success message, or a message starting with "Error" when
            the file is missing, unsupported, or cannot be parsed.
        """
        try:
            if not os.path.exists(file_path):
                return f"Error: File {file_path} not found."

            file_extension = os.path.splitext(file_path)[1].lower()

            if file_extension in self._EXCEL_EXTENSIONS:
                try:
                    df = pd.read_excel(file_path)
                except Exception as e:
                    return f"Error loading Excel file: {str(e)}"
                file_type = "Excel"

            elif file_extension in self._TEXT_EXTENSIONS:
                file_type = "CSV"
                # Only .tsv overrides the default comma separator
                read_kwargs = {"sep": "\t"} if file_extension == '.tsv' else {}

                if encoding is None:
                    df = None
                    for candidate in self._FALLBACK_ENCODINGS:
                        try:
                            df = pd.read_csv(file_path, encoding=candidate, **read_kwargs)
                        except UnicodeDecodeError:
                            # Wrong guess: move on to the next candidate
                            continue
                        except Exception as e:
                            return f"Error loading file: {str(e)}"
                        encoding = candidate
                        break

                    if df is None:
                        return "Error: Could not automatically detect file encoding. Please specify encoding parameter."
                else:
                    # Caller-supplied encoding: failures fall through to the
                    # outer handler below
                    df = pd.read_csv(file_path, encoding=encoding, **read_kwargs)
            else:
                return f"Error: Unsupported file format '{file_extension}'. Supported formats are: .csv, .txt, .tsv, .xlsx, .xls, .xlsm, .xlsb"

            # Derive a name from the file name when none (or an empty one) is
            # given; an empty key would be stored but could never become the
            # active dataframe because current_df is tested for truthiness.
            if not name:
                base_name = os.path.basename(file_path)
                name = base_name.split('.')[0] or base_name

            self.dataframes[name] = df
            self.current_df = name

            encoding_info = f" using encoding: {encoding}" if file_type == "CSV" and encoding else ""
            return f"Successfully loaded '{name}' ({file_type}) with {df.shape[0]} rows and {df.shape[1]} columns{encoding_info}."

        except Exception as e:
            return f"Error loading data: {str(e)}"

    def get_dataframe_info(self, name=None):
        """Describe a loaded dataframe.

        Parameters:
            name (str, optional): Dataframe to describe; defaults to the
                active dataframe.

        Returns:
            dict | str: A dict with keys "name", "shape", "columns", "dtypes"
            and "sample" (first three rows via ``to_dict()``), or a message
            string when no matching dataframe exists.
        """
        df_name = name if name else self.current_df

        if not df_name or df_name not in self.dataframes:
            return "No dataframe selected or specified dataframe not found."

        df = self.dataframes[df_name]

        return {
            "name": df_name,
            "shape": df.shape,
            "columns": list(df.columns),
            "dtypes": {col: str(dtype) for col, dtype in zip(df.columns, df.dtypes)},
            "sample": df.head(3).to_dict()
        }

    def list_dataframes(self):
        """Return ``{name: {"shape", "columns"}}`` for every loaded dataframe,
        or a message string when nothing has been loaded."""
        if not self.dataframes:
            return "No dataframes loaded."

        return {
            name: {"shape": df.shape, "columns": list(df.columns)}
            for name, df in self.dataframes.items()
        }

    def set_current_dataframe(self, name):
        """Make ``name`` the active dataframe; returns a status message string."""
        if name not in self.dataframes:
            return f"Error: Dataframe '{name}' not found."

        self.current_df = name
        return f"Current dataframe set to '{name}'."


# ===== QUERY ENGINE CLASS =====
class QueryEngine:
    """Answers natural-language questions about dataframes held by a data handler.

    Queries containing a visualization keyword are routed to a ReAct agent
    equipped with a Python REPL tool; all other queries go to a pandas
    dataframe agent. Failures are reported as returned message strings.
    """

    # Substrings that route a query to the plotting agent
    VIZ_KEYWORDS = ('visualize', 'plot', 'graph', 'chart', 'display', 'show', 'histogram', 'scatter')

    def __init__(self, data_handler, temperature=0):
        """Store the data handler and create the OpenAI LLM client.

        Parameters:
            data_handler: Object exposing ``dataframes``, ``current_df`` and
                ``get_dataframe_info`` (a DataHandler in this notebook).
            temperature (float): Sampling temperature passed to the LLM.
        """
        self.data_handler = data_handler
        self.llm = OpenAI(temperature=temperature, openai_api_key=API_KEY)

    @staticmethod
    def _format_columns(columns):
        """Join column labels for a prompt; labels may be non-strings (e.g. ints)."""
        return ', '.join(str(col) for col in columns)

    def get_visualization_tool(self, df):
        """Create a Python REPL tool whose namespace exposes ``df`` together
        with pandas (pd), numpy (np), matplotlib.pyplot (plt) and seaborn (sns)."""
        locals_dict = {
            "pd": pd,
            "df": df,
            "plt": plt,
            "sns": sns,
            "np": np
        }
        return PythonAstREPLTool(locals=locals_dict)

    def execute_query_with_viz(self, query, dataframe_name=None):
        """Execute a natural language query on a dataframe with visualization support.

        Parameters:
            query (str): The user's question or plotting request.
            dataframe_name (str, optional): Dataframe to query; defaults to
                the data handler's active dataframe.

        Returns:
            dict | str: ``{'text_result': ...}`` for ordinary queries; for
            visualization queries the dict also carries 'visualization_path'
            and 'visualization_type'. A message string is returned when the
            dataframe is unknown or the agent raises.
        """
        df_name = dataframe_name if dataframe_name else self.data_handler.current_df

        if not df_name or df_name not in self.data_handler.dataframes:
            return "No dataframe selected or specified dataframe not found."

        df = self.data_handler.dataframes[df_name]
        df_info = self.data_handler.get_dataframe_info(df_name)

        try:
            columns_text = self._format_columns(df_info['columns'])
            lowered_query = query.lower()
            is_viz_query = any(keyword in lowered_query for keyword in self.VIZ_KEYWORDS)

            if is_viz_query:
                # Start from a clean pyplot state: figures are not closed
                # automatically while this code runs inside one long-lived
                # cell, so a new plot would otherwise be drawn on top of the
                # previous query's figure.
                plt.close('all')

                plotting_tool = self.get_visualization_tool(df)

                agent = initialize_agent(
                    [plotting_tool],
                    self.llm,
                    agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
                    verbose=True
                )

                # Prompt that steers the agent towards producing and saving a plot
                viz_prompt = f"""
                Create a visual representation of the data using matplotlib or seaborn.
                
                Dataframe information:
                - Name: {df_name}
                - Shape: {df_info['shape'][0]} rows × {df_info['shape'][1]} columns
                - Columns: {columns_text}
                
                User query: {query}
                
                Follow these steps:
                1. Analyze what kind of visualization would best answer the query
                2. Create the visualization using matplotlib or seaborn
                3. Make sure to include:
                   - Appropriate title and axis labels
                   - Legend if multiple series are shown
                   - Use plt.tight_layout() for better spacing
                   - Use plt.savefig('visualization.png') to save the visualization
                
                Make sure to execute the code to generate the visualization.
                Finally, provide a brief explanation of what the visualization shows.
                """

                result = agent.run(viz_prompt)

                # The path is the one the prompt asks the agent to save to
                return {
                    'text_result': result,
                    'visualization_path': 'visualization.png',
                    'visualization_type': 'file'
                }

            # Non-visualization queries use the standard pandas agent
            agent = create_pandas_dataframe_agent(
                self.llm,
                df,
                verbose=True,
                allow_dangerous_code=True
            )

            enhanced_query = f"""
                Working with dataframe: {df_name}
                
                Dataframe information:
                - Shape: {df_info['shape'][0]} rows × {df_info['shape'][1]} columns
                - Columns: {columns_text}
                
                User query: {query}
                
                Please provide:
                1. A clear answer to the query
                2. Any calculations or reasoning used
                3. A brief explanation of the result
                """

            result = agent.run(enhanced_query)
            return {'text_result': result}

        except Exception as e:
            return f"Error executing query: {str(e)}"

    def execute_multi_table_query(self, query, dataframe_names):
        """Execute a query that references multiple tables.

        The named dataframes are placed side by side in one combined dataframe
        whose columns are renamed ``<dataframe>_<column>``; the pandas agent is
        then run on that combined frame.

        Parameters:
            query (str): The user's question.
            dataframe_names (list[str]): Names of loaded dataframes to combine.

        Returns:
            The agent's answer, or a message string when a dataframe is
            missing or the agent raises.
        """
        missing_dfs = [name for name in dataframe_names if name not in self.data_handler.dataframes]
        if missing_dfs:
            return f"Error: Dataframes not found: {', '.join(missing_dfs)}"

        dfs = {name: self.data_handler.dataframes[name] for name in dataframe_names}
        dfs_info = {name: self.data_handler.get_dataframe_info(name) for name in dataframe_names}

        try:
            # Prefix every column with its dataframe name so columns stay distinguishable
            combined_data = {}
            for name, df in dfs.items():
                for col in df.columns:
                    combined_data[f"{name}_{col}"] = df[col]

            combined_df = pd.DataFrame(combined_data)

            agent = create_pandas_dataframe_agent(
                self.llm,
                combined_df,
                verbose=True,
                allow_dangerous_code=True
            )

            dfs_info_text = "\n".join([
                f"Dataframe '{name}':\n- Columns: {self._format_columns(info['columns'])}\n- Shape: {info['shape'][0]} rows × {info['shape'][1]} columns"
                for name, info in dfs_info.items()
            ])

            enhanced_query = f"""
            Working with multiple dataframes: {', '.join(dataframe_names)}
            
            {dfs_info_text}
            
            I've created a combined dataframe where columns are prefixed with the dataframe name.
            For example, to access the 'price' column from the 'products' dataframe, use 'products_price'.
            
            User query: {query}
            
            Please provide:
            1. A clear answer to the query
            2. Any calculations or reasoning used
            3. A brief explanation of the result
            """

            result = agent.run(enhanced_query)
            return result

        except Exception as e:
            return f"Error executing multi-table query: {str(e)}"


# ===== MAIN LLM QUERY SYSTEM CLASS =====
class LLMQuerySystem:
    """Facade that wires a DataHandler to a QueryEngine and exposes the
    loading, selection, sampling and querying operations used by the
    interactive loop."""

    def __init__(self):
        """Initialize the LLM Query System."""
        self.data_handler = DataHandler()
        self.query_engine = QueryEngine(self.data_handler)

    def list_loaded_data(self):
        """List all loaded dataframes (delegates to the data handler)."""
        return self.data_handler.list_dataframes()

    def load_data(self, file_path, name=None, encoding=None):
        """Load a data file (CSV or Excel) into the system; returns a status message."""
        return self.data_handler.load_file(file_path, name, encoding)

    def set_active_dataframe(self, name):
        """Set the active dataframe for queries; returns a status message."""
        return self.data_handler.set_current_dataframe(name)

    def execute_query(self, query, dataframe=None):
        """Execute a natural language query on a dataframe, with visualization if appropriate.

        Delegates to ``QueryEngine.execute_query_with_viz``. An earlier
        duplicate definition of this method that called a non-existent
        ``QueryEngine.execute_query`` was removed; it was shadowed by this
        definition and never ran.
        """
        return self.query_engine.execute_query_with_viz(query, dataframe)

    def execute_multi_dataframe_query(self, query, dataframes):
        """Execute a query across multiple dataframes."""
        return self.query_engine.execute_multi_table_query(query, dataframes)

    def get_dataframe_sample(self, name=None, rows=5):
        """Return the first ``rows`` rows of a dataframe as ``DataFrame.to_dict()``
        output, or a message string when no matching dataframe exists.

        ``name`` defaults to the active dataframe.
        """
        df_name = name if name else self.data_handler.current_df

        if not df_name or df_name not in self.data_handler.dataframes:
            return "No dataframe selected or specified dataframe not found."

        df = self.data_handler.dataframes[df_name]
        return df.head(rows).to_dict()


    

# ===== INTERACTIVE MODE =====
def _print_loaded_dataframes(dfs):
    """Print one 'name: rows × columns' line per entry of a list_loaded_data() mapping."""
    print("\nLoaded dataframes:")
    for name, info in dfs.items():
        print(f"  {name}: {info['shape'][0]} rows × {info['shape'][1]} columns")


def _print_sample(sample, rows):
    """Print up to ``rows`` rows of a ``DataFrame.head().to_dict()`` mapping."""
    print(f"\nSample from dataframe:")
    # to_dict() is column -> {index label -> value}; rows are looked up by the
    # integer labels 0..rows-1, so other index labels are not displayed.
    first_column = next(iter(sample.values()), {})  # {} when there are no columns
    for i in range(rows):
        if i in first_column:
            row_data = {col: data[i] for col, data in sample.items()}
            print(f"Row {i}: {row_data}")


def run_interactive():
    """Run the system in interactive mode.

    Until at least one file is loaded, the loop asks for comma-separated file
    paths and a name for each. Afterwards every input line is either a
    command or a natural-language query for the active dataframe:

        exit                   quit
        load <path> [name]     load another file
        list                   list loaded dataframes
        use <name>             switch the active dataframe
        sample <name> [rows]   show the first rows (default 5) of a dataframe

    Anything else is sent to ``LLMQuerySystem.execute_query``.
    """
    system = LLMQuerySystem()

    print("Welcome to the LLM Query System!")

    loaded_files = False

    while True:
        # Loading phase: nothing has been loaded yet
        if not loaded_files and not system.data_handler.dataframes:
            print("\nYou need to load at least one data file to begin.")
            file_paths_input = input("Enter file path(s) separated by commas (or 'exit' to quit): ").strip()

            if file_paths_input.lower() == 'exit':
                print("Exiting system. Goodbye!")
                break

            # Ignore blank entries (empty input, trailing commas)
            file_paths = [path.strip() for path in file_paths_input.split(',') if path.strip()]
            all_loaded = True

            for file_path in file_paths:
                # Offer the file name (up to its first '.') as the default name
                default_name = os.path.basename(file_path).split('.')[0]
                name_input = input(f"Enter name for {file_path} (press Enter to use '{default_name}'): ").strip()
                name = name_input if name_input else default_name

                result = system.load_data(file_path, name)
                print(result)

                # load_data reports failures as strings containing "Error"
                if "Error" in result:
                    all_loaded = False

            if all_loaded and system.data_handler.dataframes:
                loaded_files = True
                _print_loaded_dataframes(system.list_loaded_data())

                # With a single dataframe there is nothing to choose
                if len(system.data_handler.dataframes) == 1:
                    system.set_active_dataframe(list(system.data_handler.dataframes.keys())[0])
                    print(f"Active dataframe set to '{system.data_handler.current_df}'")
                else:
                    name = input("Enter name of dataframe to use as active: ").strip()
                    print(system.set_active_dataframe(name))

        # Command/query phase
        else:
            if system.data_handler.current_df:
                prompt = f"\nEnter query for '{system.data_handler.current_df}' (or 'exit' to quit): "
            else:
                prompt = "\nNo active dataframe selected. Enter 'exit' to quit: "

            command = input(prompt).strip()
            lowered = command.lower()

            if not command:
                # Do not send an empty query to the LLM
                continue

            if lowered == 'exit':
                print("Exiting system. Goodbye!")
                break

            elif lowered.startswith('load '):
                parts = command[5:].strip().split(' ', 1)
                file_path = parts[0]
                name = parts[1] if len(parts) > 1 else None
                print(system.load_data(file_path, name))

            elif lowered == 'list':
                dfs = system.list_loaded_data()
                if isinstance(dfs, str):
                    print(dfs)
                else:
                    _print_loaded_dataframes(dfs)

            elif lowered.startswith('use '):
                name = command[4:].strip()
                print(system.set_active_dataframe(name))

            elif lowered.startswith('sample '):
                parts = command[7:].strip().split(' ', 1)
                name = parts[0] if parts and parts[0] else None
                try:
                    rows = int(parts[1]) if len(parts) > 1 else 5
                except ValueError:
                    # e.g. "sample sales ten" must not crash the session
                    print(f"Error: row count must be an integer, got '{parts[1]}'.")
                    continue
                sample = system.get_dataframe_sample(name, rows)

                if isinstance(sample, str):
                    print(sample)
                else:
                    _print_sample(sample, rows)

            elif system.data_handler.current_df:
                print("\nExecuting query...")
                result = system.execute_query(command)

                # A dict result carries the answer text and, for plots, a file path
                if isinstance(result, dict):
                    print("\nResult:")
                    print(result.get('text_result', ''))

                    if 'visualization_path' in result:
                        print(f"\nVisualization created and saved to {result['visualization_path']}")

                else:
                    print("\nResult:")
                    print(result)


# Add a simple plotting wrapper
def get_plot_code(df, query, llm=None):
    """Ask an agent for matplotlib/seaborn code that visualizes ``df`` for ``query``.

    Parameters:
        df (pandas.DataFrame): Dataframe whose column names are described to
            the agent.
        query (str): What the user wants to see.
        llm (optional): Language model driving the agent. When omitted, an
            OpenAI client is created from the module-level API_KEY. (The
            previous body read ``self.llm``, which is undefined in a
            module-level function and raised NameError on every call.)

    Returns:
        The agent's response, which the prompt asks to be a Python code block.
    """
    if llm is None:
        llm = OpenAI(temperature=0, openai_api_key=API_KEY)

    # Use the module-level tool list when it has entries; otherwise give the
    # agent a Python REPL that exposes df and the plotting libraries.
    agent_tools = tools or [
        PythonAstREPLTool(locals={"pd": pd, "df": df, "plt": plt, "sns": sns, "np": np})
    ]

    plotting_agent = initialize_agent(
        agent_tools,
        llm,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True
    )

    # Column labels may be non-strings (e.g. integers), so stringify before joining
    column_names = ', '.join(str(col) for col in df.columns)

    prompt = f"""
    Create Python code to visualize this data with matplotlib or seaborn.
    
    The dataframe is called 'df' and has these columns: {column_names}
    The query is: {query}
    
    Return only the plotting code as a Python code block.
    """

    result = plotting_agent.run(prompt)
    return result

# ===== MAIN EXECUTION =====
if __name__ == "__main__":
    # Refuse to start without an API key: every query needs the LLM
    if not API_KEY:
        print("Error: OPENAI_API_KEY not found in environment variables.")
        print("Please create a .env file with your API key or set it directly.")
        # exit() is an interactive-shell helper (and inside an IPython kernel
        # it is the kernel's own exit command); SystemExit is the portable way
        # to stop with a non-zero status.
        raise SystemExit(1)

    # Run in interactive mode directly
    run_interactive()

  self.llm = OpenAI(temperature=temperature, openai_api_key=API_KEY)
Welcome to the LLM Query System!

You need to load at least one data file to begin.
Successfully loaded 'the' (CSV) with 1656 rows and 14 columns using encoding: utf-8.

Loaded dataframes:
  the: 1656 rows × 14 columns
Active dataframe set to 'the'

Executing query...


[1m> Entering new AgentExecutor chain...[0m
  agent = initialize_agent(
  result = agent.run(viz_prompt)
[32;1m[1;3m I should use matplotlib or seaborn to create a visualization.
Action: [python_repl_ast]
Action Input: import matplotlib.pyplot as plt[0m
Observation: [python_repl_ast] is not a valid tool, try one of [python_repl_ast].
Thought:[32;1m[1;3m I should use matplotlib or seaborn to create a visualization.
Action: [python_repl_ast]
Action Input: import matplotlib.pyplot as plt[0m
Observation: [python_repl_ast] is not a valid tool, try one of [python_repl_ast].
Thought:[32;1m[1;3m I should use matplotlib or seaborn to create a visualiz

KeyboardInterrupt: Interrupted by user