In [8]:
import os
import re

import dotenv
import pandas as pd
import psycopg2
from groq import Groq
from psycopg2 import sql

In [9]:
def get_database_connection():
    """Open a PostgreSQL connection configured from environment variables.

    ``dotenv.load_dotenv()`` runs first, so a local ``.env`` file may supply
    DATABASE_NAME, DATABASE_USER, DATABASE_PASSWORD, DATABASE_HOST and
    DATABASE_PORT. Unset variables are passed through as None.

    Returns:
        A psycopg2 connection, or None if connecting failed (the error is printed).
    """
    dotenv.load_dotenv()

    # psycopg2.connect keyword -> environment variable that provides its value
    env_for_param = {
        "database": "DATABASE_NAME",
        "user": "DATABASE_USER",
        "password": "DATABASE_PASSWORD",
        "host": "DATABASE_HOST",
        "port": "DATABASE_PORT",
    }
    try:
        settings = {param: os.environ.get(var) for param, var in env_for_param.items()}
        db = psycopg2.connect(**settings)
        print("Connected to the PostgreSQL database!")
    except Exception as exc:  # psycopg2.Error is a subclass of Exception
        print("Error while connecting to PostgreSQL", exc)
        return None
    return db

In [10]:
def get_table_metadata(conn, table_name):
    """Collect column types and sample distinct values for ``table_name``.

    Args:
        conn: An open psycopg2 connection.
        table_name: Table to describe. It is matched against
            ``information_schema.columns`` without a schema filter, and
            interpolated as a quoted identifier in the sampling query.

    Returns:
        A dict keyed by *normalized* column name (lower-cased, spaces and
        slashes replaced by underscores), in column order, mapping to
        ``{'data_type': str, 'unique_values': list}``. ``unique_values`` holds
        up to 10 DISTINCT values (possibly including None for NULL) for
        columns whose type is exactly ``text``, and is empty otherwise.
        Returns None if any query fails; the error is printed and the
        transaction is rolled back.
    """
    try:
        # The context manager closes the cursor even when a query raises.
        with conn.cursor() as cursor:
            cursor.execute(sql.SQL("""
                SELECT column_name, data_type
                FROM information_schema.columns
                WHERE table_name = %s
                ORDER BY ordinal_position;
            """), [table_name])
            columns = cursor.fetchall()

            table_metadata = {}
            for column_name, data_type in columns:
                unique_values = []

                # NOTE(review): only `text` is sampled; varchar/char columns
                # get no example values. Widen this if the schema uses them.
                if data_type == 'text':
                    # Query with the column's REAL name. The normalized key
                    # built below may not exist in the table
                    # (e.g. "Job Title" vs job_title).
                    cursor.execute(sql.SQL("""
                        SELECT DISTINCT {column}
                        FROM {table}
                        LIMIT 10;
                    """).format(
                        column=sql.Identifier(column_name),
                        table=sql.Identifier(table_name)
                    ))
                    unique_values = [row[0] for row in cursor.fetchall()]

                # NOTE(review): the prompt shows the LLM these normalized
                # names, so columns with spaces/slashes/uppercase may be
                # referenced by names the database does not recognise.
                key = column_name.lower().replace(" ", "_").replace("/", "_")
                table_metadata[key] = {
                    'data_type': data_type,
                    'unique_values': unique_values
                }

        return table_metadata

    except Exception as error:  # psycopg2.Error is a subclass of Exception
        # A failed statement aborts the whole transaction. Roll back so later
        # queries on this connection are not rejected.
        try:
            conn.rollback()
        except Exception:
            pass  # connection unusable (or None); best-effort cleanup only
        print("Error while fetching table metadata", error)
        return None

In [11]:
def format_metadata(metadata):
    """Render table metadata as prompt text, one ``column: type`` line per column.

    Args:
        metadata: Mapping of column name -> ``{'data_type': str,
            'unique_values': list}``, as built by ``get_table_metadata``.
            None or an empty mapping yields an empty string.

    Returns:
        One newline-terminated line per column. Columns of type ``text`` also
        list their sampled values as ``(Unique Values: a, b, ...)``. None
        values (SQL NULL) are written as ``NULL``; other non-string values use
        ``str()``.
    """
    if not metadata:
        return ""

    lines = []
    for col, info in metadata.items():
        line = f"{col}: {info['data_type']}"
        if info['data_type'] == 'text':
            # DISTINCT can return NULL (None), and str.join rejects non-strings.
            values = ", ".join(
                "NULL" if value is None else str(value)
                for value in info['unique_values']
            )
            line += f" (Unique Values: {values})"
        lines.append(line + "\n")
    return "".join(lines)

In [12]:
def clean_sql_response(response_text):
    """Strip Markdown code fences (and their language tag) from an LLM reply.

    If the reply contains a complete fenced block, e.g. ```` ```sql\\n...\\n``` ````,
    only that block's contents are returned, so prose around it is dropped.
    Otherwise any stray fences, with an optional ``lang`` tag followed by a
    newline, are removed.

    Args:
        response_text: Raw text returned by the model.

    Returns:
        The cleaned text, stripped of surrounding whitespace.
    """
    text = response_text.strip()
    # The language tag is consumed only when followed by a newline, so a fence
    # written inline (```SELECT 1```) keeps its first word.
    fence_open = r"```(?:[A-Za-z0-9_+-]*[ \t]*\n)?"
    match = re.search(fence_open + r"(.*?)```", text, re.DOTALL)
    if match:
        return match.group(1).strip()
    return re.sub(fence_open, "", text).strip()


def get_llama_assistance(prompt, formatted_metadata, table_name,
                         model="llama3-70b-8192", temperature=0.0, max_tokens=1024):
    """Ask a Groq-hosted LLM to translate a natural-language question into SQL.

    Args:
        prompt: The user's question(s), appended after the instructions.
        formatted_metadata: Column description text embedded in the prompt so
            the model knows the schema.
        table_name: Name of the table the query should target.
        model: Groq model id. NOTE(review): Groq retires model ids over time;
            confirm this default is still served.
        temperature: Sampling temperature. Low values make SQL generation as
            deterministic as possible.
        max_tokens: Upper bound on the completion length. A single query needs
            far fewer tokens than the model's full context window.

    Returns:
        The generated SQL with any Markdown code fences removed (see
        ``clean_sql_response``).
    """
    main_purpose = f"""
    As an SQL Query Expert, your primary role is to understand the given data, answer the questions based on the provided input and generate accurate SQL queries ONLY. 
    Remember, you only have to answer the Query for the given input, don't give any explanation, just the query. 
    The database is PostgreSQL, so use PostgreSQL syntax.
    Here are the column names with respect to their information: 
    {formatted_metadata}
    The table name is {table_name}
    Here is/are the Questions:"""

    client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
    # Non-streaming: the whole answer is needed before it can be executed anyway.
    completion = client.chat.completions.create(
        model=model,
        messages=[
            {
                "role": "user",
                "content": f"{main_purpose} {prompt}"
            }
        ],
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=1,
        stream=False,
    )

    response_text = completion.choices[0].message.content or ""
    return clean_sql_response(response_text)

In [13]:
def execute_query(connection, query):
    """Run ``query`` verbatim and return its rows as a DataFrame.

    Args:
        connection: An open psycopg2 connection.
        query: SQL text to execute. In this notebook it is LLM-generated, i.e.
            untrusted, so the caller should use a read-only session.

    Returns:
        A DataFrame with one column per result column, or None if execution
        or fetching failed. On failure the error is printed and the transaction
        is rolled back. Statements that produce no result set also end up
        here, because psycopg2's ``fetchall()`` raises for them.
    """
    try:
        # The context manager closes the cursor even when execute/fetch raises.
        with connection.cursor() as cursor:
            cursor.execute(query)
            rows = cursor.fetchall()
            colnames = [desc[0] for desc in cursor.description]
        return pd.DataFrame(rows, columns=colnames)
    except Exception as error:  # psycopg2.Error is a subclass of Exception
        # A failed statement aborts the transaction. Without a rollback, every
        # later query on this connection would be rejected.
        try:
            connection.rollback()
        except Exception:
            pass  # connection itself is unusable; best-effort cleanup only
        print("Error while executing the query", error)
        return None

In [14]:
# Pipeline: describe the table -> ask the LLM for SQL -> run it -> print the rows.
connection = get_database_connection()
print("\n")
if connection:
    try:
        # The SQL below is generated by an LLM from free-form user input and
        # executed verbatim. A read-only session makes a hallucinated or
        # prompt-injected INSERT/UPDATE/DELETE/DROP fail instead of changing data.
        connection.set_session(readonly=True)

        table_name = "disc_off"
        metadata = get_table_metadata(connection, table_name)
        if not metadata:
            # None: the metadata query failed. {}: no such table, or no columns.
            print(f"No column metadata found for table {table_name!r}; not querying the LLM.")
        else:
            formatted_metadata = format_metadata(metadata)

            input_question = input("Enter the Question: ")
            print(f"Question asked: {input_question}")
            query = get_llama_assistance(input_question, formatted_metadata, table_name)
            print("Generated SQL Query:")
            print(query)

            print("\n")
            result_df = execute_query(connection, query)
            if result_df is not None:
                print("Query Results:")
                print(result_df.to_string(index=False))
    finally:
        # Release the connection even if input() is interrupted or the LLM call raises.
        connection.close()

Connected to the PostgreSQL database!


Question asked: List the occupations where the average medical insurance is less than the overall average medical insurance.
Generated SQL Query:
SELECT occupation
FROM disc_off
GROUP BY occupation
HAVING AVG(medical_insurance) < (SELECT AVG(medical_insurance) FROM disc_off);


Query Results:
      occupation
           Other
         Unknown
       Housewife
         Retired
Business Manager
