<a href="https://colab.research.google.com/github/hajraanwar/NLP_to_SQL/blob/main/nlp_to_sql_github.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!apt-get install libaio1
!pip install cx_Oracle


In [None]:
!mkdir -p /opt/oracle
!wget https://download.oracle.com/otn_software/linux/instantclient/211000/instantclient-basic-linux.x64-21.1.0.0.0.zip
!unzip -o instantclient-basic-linux.x64-21.1.0.0.0.zip
!cp -r instantclient_21_1/* /opt/oracle/
!sh -c "echo /opt/oracle > /etc/ld.so.conf.d/oracle-instantclient.conf"
!ldconfig


In [None]:
import os
# Record the Instant Client directory in LD_LIBRARY_PATH for this process.
os.environ["LD_LIBRARY_PATH"] = "/opt/oracle"
import cx_Oracle
# Tell cx_Oracle explicitly which directory holds the Oracle client libraries.
cx_Oracle.init_oracle_client(lib_dir="/opt/oracle")


In [None]:
def get_db_connection():
    """Open a new cx_Oracle connection to the configured database.

    Connection settings are read from the environment so credentials do not
    have to be typed into the notebook:

    - ``ORACLE_HOST``          database host name
    - ``ORACLE_PORT``          listener port (defaults to 1521)
    - ``ORACLE_USER``          user name
    - ``ORACLE_PASSWORD``      password
    - ``ORACLE_SERVICE_NAME``  service name used to build the DSN

    Any variable that is not set falls back to the placeholder value this
    function has always used, so existing workflows that edit the placeholders
    in place keep working.

    :return: an open ``cx_Oracle`` connection; the caller is responsible for closing it
    """
    host = os.environ.get("ORACLE_HOST", "HOST.com")
    port = int(os.environ.get("ORACLE_PORT", "1521"))
    username = os.environ.get("ORACLE_USER", "YOUR DATABASE USERNAME")
    password = os.environ.get("ORACLE_PASSWORD", "YOUR  DATABASE PASSWORD")
    database_name = os.environ.get("ORACLE_SERVICE_NAME", "DB NAME")

    dsn = cx_Oracle.makedsn(host, port, service_name=database_name)
    return cx_Oracle.connect(user=username, password=password, dsn=dsn)


Cell 5: Extract the Database Schema (tables and columns owned by ATOMCAMP)

In [None]:
def extract_schema():
    """Read the table and column layout of the ATOMCAMP schema.

    Lists every table owned by ``ATOMCAMP`` in ``all_tables`` and, for each one,
    fetches its ``(column_name, data_type)`` rows from ``all_tab_columns``.

    Database errors are printed rather than raised. Whatever was collected
    before the error is still returned, so the result may be partial or empty.

    :return: dict mapping table name to a list of ``(column_name, data_type)`` tuples
    """
    schema = {}
    try:
        connection = get_db_connection()
        try:
            cursor = connection.cursor()
            try:
                cursor.execute("""
                    SELECT table_name
                    FROM all_tables
                    WHERE owner = 'ATOMCAMP'
                """)
                tables = cursor.fetchall()

                for table in tables:
                    table_name = table[0]
                    # Bind variable instead of string formatting for the table name.
                    cursor.execute("""
                        SELECT column_name, data_type
                        FROM all_tab_columns
                        WHERE owner = 'ATOMCAMP'
                        AND table_name = :table_name
                    """, {'table_name': table_name})

                    schema[table_name] = cursor.fetchall()
            finally:
                # Runs on failure too, so a failing query no longer leaks the cursor.
                cursor.close()
        finally:
            # Likewise, always hand the connection back.
            connection.close()
    except cx_Oracle.DatabaseError as e:
        print("Database error:", e)
    return schema

# Run the extraction once; the resulting `schema` dict is passed to the query helpers in later cells.
schema = extract_schema()
print("Extracted Schema:", schema)


In [None]:
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def load_llm():
    """Load the Llama 3 8B Instruct tokenizer and model from the Hugging Face Hub.

    The access token is written to the ``HUGGINGFACE_HUB_TOKEN`` environment
    variable and read back for both downloads. The model is loaded in bfloat16
    with ``device_map="auto"``.

    :return: ``(tokenizer, model)`` tuple
    """
    model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

    os.environ["HUGGINGFACE_HUB_TOKEN"] = "YOUR HUGGING FACE TOKEN"
    hub_token = os.getenv("HUGGINGFACE_HUB_TOKEN")

    llm_tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=hub_token)

    model_options = {
        "use_auth_token": hub_token,
        "torch_dtype": torch.bfloat16,
        "device_map": "auto",
    }
    llm = AutoModelForCausalLM.from_pretrained(model_id, **model_options)

    return llm_tokenizer, llm

# Load the tokenizer and model once; later cells pass these two objects to the query helpers.
tokenizer, model = load_llm()


In [None]:
import re


In [None]:
def generate_sql_query(schema, user_query, tokenizer, model):
    """
    Generate an SQL query based on the schema and user query.

    Only the first 10 tables of ``schema`` are described to the model. The
    model's completion is searched for the first ``SELECT ... FROM ... ;``
    statement, which is returned without its trailing semicolon.

    :param schema: Dictionary of tables and columns
    :param user_query: Natural language query from the user
    :param tokenizer: Tokenizer for the LLM (its ``pad_token`` is set to the EOS token if unset)
    :param model: LLM for SQL generation
    :return: A valid SQL query as a string
    :raises ValueError: if no SQL statement can be extracted from the model output
    """
    # Restrict schema to the first 10 tables
    first_10_tables = dict(list(schema.items())[:10])
    table_descriptions = []
    for table, columns in first_10_tables.items():
        column_names = ", ".join(col[0] for col in columns)
        table_descriptions.append(f"Table: ATOMCAMP.{table}\nColumns: {column_names}")
    schema_description = "\n\n".join(table_descriptions)

    prompt = f"""
    You are an SQL expert. Generate a valid SQL query using the provided schema.
    Ensure that the schema name 'ATOMCAMP.' is prefixed to all table names in the query.

    Schema:
    {schema_description}

    User Query:
    {user_query}

    Generate the SQL query only in your response. Ensure the SQL query is syntactically valid for Oracle databases.
    """

    # Tokenizer configuration
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024
    ).to(model.device)

    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=150,
        temperature=0.5,
        top_p=0.8
    )

    # generate() returns the prompt tokens followed by the newly generated ones.
    # Decode only the new tokens: otherwise the echoed prompt is searched too, and a
    # "select ... from ...;" phrase inside the user's question would be picked up as
    # the query.
    prompt_length = inputs["input_ids"].shape[-1]
    completion_ids = outputs[0][prompt_length:]

    raw_response = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
    print("Raw LLM Output:\n", raw_response)  # Debugging: Print the raw output

    # Extract the first valid SQL query using regex
    sql_match = re.search(r"SELECT\s.*?FROM\s.*?(WHERE\s.*?|);", raw_response, re.DOTALL | re.IGNORECASE)

    if sql_match:
        # Ensure the query ends properly for Oracle
        generated_query = sql_match.group(0).strip()
        if generated_query.endswith(";"):
            generated_query = generated_query[:-1]  # Remove trailing semicolon
        return generated_query
    else:
        print("Model Output Debugging: The model did not generate a valid SQL query.")
        raise ValueError("Failed to extract a valid SQL query from the model's output.")


In [None]:
def execute_sql_query(connection_details, schema, user_query, tokenizer, model):
    """
    Generate and execute an SQL query for the first 10 tables.

    The cursor and the connection are closed whether or not the query
    succeeds. Errors are printed and reported as ``None`` instead of being
    raised.

    :param connection_details: Dictionary containing database connection parameters
        (``host``, ``port``, ``username``, ``password``, ``database_name``)
    :param schema: Dictionary containing table schema
    :param user_query: The natural language query provided by the user
    :param tokenizer: Tokenizer for the LLM
    :param model: The LLM model
    :return: List of result rows, or ``None`` if generation or execution failed
    """
    try:
        # Generate the SQL query
        generated_query = generate_sql_query(schema, user_query, tokenizer, model)
        print("Generated SQL Query:\n", generated_query)

        # Establish database connection
        connection = cx_Oracle.connect(
            user=connection_details["username"],
            password=connection_details["password"],
            dsn=cx_Oracle.makedsn(
                connection_details["host"],
                connection_details["port"],
                service_name=connection_details["database_name"]
            )
        )
        try:
            cursor = connection.cursor()
            try:
                # Execute the SQL query
                print("Executing SQL Query...")
                cursor.execute(generated_query)
                results = cursor.fetchall()
            finally:
                # Close the cursor even when execute()/fetchall() raised.
                cursor.close()
        finally:
            # Close the connection even when the query failed.
            connection.close()

        # Display the results
        print("Query Results:")
        if results:
            for row in results:
                print(row)
        else:
            print("No results found.")

        return results
    except cx_Oracle.DatabaseError as e:
        print("Database error:", e)
        return None
    except ValueError as ve:
        print(f"Query generation error: {ve}")
        return None
    except Exception as ex:
        print("An unexpected error occurred:", ex)
        return None


In [None]:
def generate_natural_language_response(user_query, sql_query, query_results, tokenizer, model):
    """
    Generate a natural language explanation for the SQL query results.

    :param user_query: The natural language question asked by the user
    :param sql_query: The SQL statement that was executed
    :param query_results: Iterable of result rows; falsy means "no results"
    :param tokenizer: Tokenizer for the LLM (its ``pad_token`` is set to the EOS token if unset)
    :param model: The LLM model
    :return: The model's explanation text, without the prompt that was sent to it
    """
    if query_results:
        # One line per row, values separated by ", ".
        results_description = "\n".join(", ".join(map(str, row)) for row in query_results)
    else:
        results_description = "No results found."

    prompt = f"""
    User Query: {user_query}
    SQL Query: {sql_query}
    Query Results:
    {results_description}

    Provide a human-readable explanation of the results. Respond in a natural language style, summarizing the query and the key results in a way that the user can understand.
    """

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=1024
    ).to(model.device)

    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=150,
        temperature=0.7,
        top_p=0.9
    )

    # generate() returns the prompt tokens followed by the newly generated ones.
    # Drop the prompt part so the caller gets the explanation only, not the
    # instructions and raw rows echoed back.
    prompt_length = inputs["input_ids"].shape[-1]
    completion_ids = outputs[0][prompt_length:]

    natural_language_response = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
    return natural_language_response


def execute_and_explain_query(connection_details, schema, user_query, tokenizer, model):
    """
    Generate, execute, and explain an SQL query based on user input.

    The cursor and the connection are closed as soon as the rows have been
    fetched, including when execution fails, so no connection stays open while
    the explanation is generated.

    :param connection_details: Dictionary with ``host``, ``port``, ``username``,
        ``password`` and ``database_name``
    :param schema: Dictionary containing table schema
    :param user_query: The natural language query provided by the user
    :param tokenizer: Tokenizer for the LLM
    :param model: The LLM model
    :return: ``(query_results, explanation)`` on success, or ``(None, error_message)`` on failure
    """
    try:
        # Step 1: Generate SQL query
        sql_query = generate_sql_query(schema, user_query, tokenizer, model)
        print("Generated SQL Query:\n", sql_query)

        # Step 2: Execute SQL query
        connection = cx_Oracle.connect(
            user=connection_details["username"],
            password=connection_details["password"],
            dsn=cx_Oracle.makedsn(
                connection_details["host"],
                connection_details["port"],
                service_name=connection_details["database_name"]
            )
        )
        try:
            cursor = connection.cursor()
            try:
                print("Executing SQL Query...")
                cursor.execute(sql_query)
                query_results = cursor.fetchall()
            finally:
                # Close the cursor even when execute()/fetchall() raised.
                cursor.close()
        finally:
            # Close the connection even when the query failed.
            connection.close()

        print("Query Results:", query_results)

        # Step 3: Generate natural language response
        explanation = generate_natural_language_response(user_query, sql_query, query_results, tokenizer, model)
        print("Natural Language Explanation:\n", explanation)

        return query_results, explanation

    except cx_Oracle.DatabaseError as e:
        print("Database error:", e)
        return None, f"Database error occurred: {e}"
    except ValueError as ve:
        print(f"Query generation error: {ve}")
        return None, f"Query generation error: {ve}"
    except Exception as ex:
        print("An unexpected error occurred:", ex)
        return None, f"An unexpected error occurred: {ex}"


In [None]:
# Define connection details
# SECURITY: never commit the database password to the notebook. Set ORACLE_PASSWORD
# (and optionally the other ORACLE_* variables) in the environment before running this
# cell. A password that was committed earlier should be treated as compromised and rotated.
connection_details = {
    "host": os.environ.get("ORACLE_HOST", "atomcamp.cdk6ie0wqi3p.ap-southeast-2.rds.amazonaws.com"),
    "port": int(os.environ.get("ORACLE_PORT", "1521")),
    "username": os.environ.get("ORACLE_USER", "admin"),
    # Placeholder fallback only; the real secret comes from the environment.
    "password": os.environ.get("ORACLE_PASSWORD", "YOUR  DATABASE PASSWORD"),
    "database_name": os.environ.get("ORACLE_SERVICE_NAME", "atomcamp")
}


# Sample natural-language questions that can be assigned to `user_query` below:
#
# Announcements made in the year 2025
# In the announcements table dividend less than 1000
# In the announcements table bonus_share greater than 5000, or less than 1000

# First 10 ANNUAL_INCOME_ID from ANNUAL_INCOME table

# Fetch details of Rabobank bank from the Banks table

# Fetch details of Riverside Plaza from branch table

# Country code greater than 50 in City table

# Clearing Type with cash in clearing_calendar

# Package type carton in the client table
# Female/male gender entries in the client table

# BANK_ACCOUNT_TITLES named David Green in the client_bank_info table


# Example user query
user_query = "In the announcements table dividend less than 1000"

# Generate the SQL, execute it, and get the rows together with an explanation
# (on failure the first element is None and the second holds the error message)
query_results, explanation = execute_and_explain_query(connection_details, schema, user_query, tokenizer, model)

# Display the natural language explanation
print("\nHuman-readable Explanation:")
print(explanation)



In [None]:
!pip install gradio

In [None]:
import gradio as gr

# Conversation so far: a module-level list of (user line, assistant line) pairs
chat_history = []

def gpt_chatbot(user_query):
    """Record one chat turn and return the whole conversation.

    The reply is currently a mock that echoes the question. The new
    (user, assistant) pair is appended to the module-level ``chat_history``
    list, and that same list object is returned.

    :param user_query: Text typed by the user
    :return: The module-level ``chat_history`` list after the new turn was added
    """
    # Generate the model response (replace `main_agent` with your actual function)
    model_response = f"Response to: {user_query}"  # Mock response for testing

    # The list is only mutated, never rebound, so no `global` statement is needed
    turn = (f"You: {user_query}", f"Assistant: {model_response}")
    chat_history.append(turn)

    return chat_history

def clear_chat():
    """Forget the conversation: rebind ``chat_history`` to a new empty list.

    :return: A separate empty list
    """
    global chat_history
    chat_history = list()
    return list()

# Gradio Interface
# Components are created in the order written here: a title row, the chat display,
# then a row holding the input box and the two buttons.
with gr.Blocks() as demo:
    with gr.Row():
        gr.Markdown("### GPT-Style Dynamic Chat Interface")

    # Dynamic chat history display (starts empty)
    with gr.Column(elem_id="chat_column"):
        chatbox = gr.Chatbot(label="Chat", value=[], elem_id="chat_display")

    # Input box and buttons at the bottom
    with gr.Row():
        user_input = gr.Textbox(
            show_label=False, placeholder="Type your question here...", lines=1
        )
        submit_button = gr.Button("Submit")
        clear_button = gr.Button("Clear Chat")

    # Define button actions
    # Submit: send the textbox content to gpt_chatbot and show the returned history in the chat display
    submit_button.click(
        gpt_chatbot, inputs=[user_input], outputs=[chatbox]
    )
    # Clear Chat: reset the stored history and show clear_chat's empty list in the chat display
    clear_button.click(clear_chat, outputs=[chatbox])

# Launch the interface
demo.launch()
