# Import all libraries

In [None]:
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain import LLMChain, HuggingFacePipeline, PromptTemplate
import pandas as pd
from sqlalchemy import create_engine, text, inspect
import re

In [None]:
!huggingface-cli login

# Download model

In [None]:
# Load the Llama 3 8B Instruct tokenizer and weights via from_pretrained.
# The weights are requested in float16 and loaded with device_map="auto".
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")

# Text-generation pipeline shared by the cells below. max_length=500 is only the
# default set here; generate_sql_query passes its own max_length on each call.
# NOTE(review): an earlier "Use CPU" remark on this line did not match the code --
# no device is pinned here and the model is loaded with device_map="auto".
# Confirm the actual placement on the target machine.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, max_length=500)

# Create and seed the database

In [None]:
# Create (if needed) and seed the SQLite database used by the rest of the notebook.
DATABASE_URL = "sqlite:///users.db"
engine = create_engine(DATABASE_URL)

with engine.connect() as conn:
    # Idempotent: the table is only created when it does not exist yet.
    conn.execute(text("""
        CREATE TABLE IF NOT EXISTS users (
            id INTEGER PRIMARY KEY,
            name TEXT,
            email TEXT UNIQUE,
            signup_date DATE,
            status TEXT CHECK(status IN ('active', 'inactive')),
            age INTEGER,
            country TEXT,
            last_login DATE,
            membership_type TEXT CHECK(membership_type IN ('free', 'premium', 'enterprise'))
        )
    """))

    # INSERT OR IGNORE skips rows whose email is already present, so re-running
    # this cell against an existing users.db neither violates the UNIQUE(email)
    # constraint nor duplicates the sample data.
    conn.execute(text("""
        INSERT OR IGNORE INTO users (name, email, signup_date, status, age, country, last_login, membership_type) VALUES
        ('Alice', 'alice@example.com', '2023-01-10', 'active', 28, 'USA', '2024-03-01', 'premium'),
        ('Bob', 'bob@example.com', '2022-12-15', 'inactive', 35, 'Canada', '2023-12-20', 'free'),
        ('Charlie', 'charlie@example.com', '2023-03-22', 'active', 22, 'UK', '2024-02-25', 'enterprise'),
        ('David', 'david@example.com', '2021-07-30', 'active', 40, 'Germany', '2024-01-15', 'premium'),
        ('Emma', 'emma@example.com', '2022-09-10', 'inactive', 30, 'France', '2023-11-05', 'free'),
        ('Frank', 'frank@example.com', '2023-05-18', 'active', 27, 'USA', '2024-02-20', 'free'),
        ('Grace', 'grace@example.com', '2021-12-05', 'inactive', 33, 'Australia', '2023-09-10', 'enterprise'),
        ('Henry', 'henry@example.com', '2022-04-15', 'active', 45, 'India', '2024-02-28', 'premium'),
        ('Ivy', 'ivy@example.com', '2023-06-22', 'inactive', 29, 'Canada', '2023-12-15', 'free'),
        ('Jack', 'jack@example.com', '2020-11-02', 'active', 38, 'UK', '2024-03-03', 'enterprise')
    """))

    # Persist both statements; the connection does not autocommit.
    conn.commit()

print("Database initialized successfully!")

In [None]:
# Build a per-column description of the users table. The result is what later
# gets stored in the vector database so the LLM has context about each column.
def get_schema_description():
    """Describe every column of the ``users`` table.

    Column names and types are read from the live database through the
    SQLAlchemy inspector and paired with a hand-written description.

    Returns:
        A list with one dict per column, each holding the keys "column",
        "type" (the column type rendered with ``str``) and "description".
        Columns without a hand-written entry get "No description available.".
    """
    table_name = "users"

    # Built once per call; it does not depend on the column being processed.
    descriptions = {
        "id": "Unique identifier for each user.",
        "name": "Full name of the user.",
        "email": "Unique email address for the user.",
        "signup_date": "Date when the user signed up.",
        "status": "Account status, either 'active' or 'inactive'.",
        "age": "Age of the user in years.",
        "country": "Country where the user resides.",
        "last_login": "Date when the user last logged in.",
        "membership_type": "Membership type: 'free', 'premium', or 'enterprise'.",
    }

    inspector = inspect(engine)
    return [
        {
            "column": col["name"],
            "type": str(col["type"]),
            "description": descriptions.get(col["name"], "No description available."),
        }
        for col in inspector.get_columns(table_name)
    ]

# Collect the column descriptions once (reused when filling the vector store)
# and echo them as "<column> (<type>): <description>" for a quick visual check.
schema_descriptions = get_schema_description()
for entry in schema_descriptions:
    print("{column} ({type}): {description}".format(**entry))


# Create vector database to store table description and query rules

In [None]:
import chromadb
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.schema import Document


# Keep the column descriptions in an on-disk Chroma collection so they can be
# looked up by similarity to the user's question.
chroma_client = chromadb.PersistentClient(path="./chroma_db")
collection = chroma_client.get_or_create_collection(name="schema_metadata")

# One batched upsert instead of one add() per column. The store is persistent and
# the ids are the column names, so re-running this cell would otherwise try to add
# ids that already exist; upsert inserts new ids and refreshes existing ones.
if schema_descriptions:
    collection.upsert(
        ids=[item["column"] for item in schema_descriptions],
        documents=[item["description"] for item in schema_descriptions],
        metadatas=[
            {"column": item["column"], "type": item["type"]}
            for item in schema_descriptions
        ],
    )

print("Schema descriptions stored in ChromaDB successfully!")

In [None]:
def retrieve_schema_info(query, n_results=3):
    """Return the column descriptions most relevant to a natural-language query.

    Args:
        query: The user's question; it is passed to the Chroma collection as
            the query text.
        n_results: How many column descriptions to request from the
            collection. Defaults to 3, the value that used to be hard-coded;
            raise it for questions that touch many columns.

    Returns:
        A list of strings formatted as "<column> (<type>): <description>",
        in the order returned by the collection.
    """
    results = collection.query(query_texts=[query], n_results=n_results)

    # A single query text was sent, so only the first result list is relevant.
    documents = results["documents"][0]
    metadatas = results["metadatas"][0]

    return [
        f"{meta['column']} ({meta['type']}): {doc}"
        for doc, meta in zip(documents, metadatas)
    ]

# LLM query creation

In [None]:
def generate_sql_query(question, table_name="users"):
    """Translate a natural-language question into a SQL query using the LLM.

    The question is first used to look up the most relevant column
    descriptions, which are embedded in the prompt together with the table
    name so the model does not have to guess either.

    Args:
        question: The user's question in plain language.
        table_name: Name of the table the generated query should target.
            Defaults to "users", the table created by this notebook.

    Returns:
        Whatever ``extract_sql_query`` returns for the model output: the
        extracted SQL statement, or its "ERROR: ..." string when no statement
        could be found.
    """
    # Retrieve schema details from ChromaDB.
    schema_context = "\n".join(retrieve_schema_info(question))

    # The prompt names the table and asks for a terminating semicolon because the
    # extraction step looks for a statement of the form "SELECT ... ;".
    prompt = f"""
    You are an AI SQL assistant. Convert natural language questions into SQL queries.

    The data is stored in a table named '{table_name}'.
    Below is the database schema information:
    {schema_context}

    **Important Instructions:**
    - If the query includes a continent (e.g., "Asia", "Europe"), replace it with a list of actual countries.
    - Use proper SQL syntax.
    - The column 'country' stores full country names (e.g., 'India', 'Germany').
    - End the SQL query with a semicolon.

    Question: {question}
    SQL Query:
    """

    # return_full_text=False keeps the prompt (and therefore the user's question)
    # out of the text handed to the extractor, so only model output is parsed.
    response = pipe(
        prompt,
        max_length=1000,
        truncation=True,
        return_full_text=False,
    )[0]["generated_text"]

    # Keep only the SQL part of the completion.
    return extract_sql_query(response)


In [None]:
# SELECT as a whole word, so prose such as "selected" is not mistaken for SQL.
_SELECT_KEYWORD = re.compile(r"\bSELECT\b", re.IGNORECASE)
# A blank line separates an unterminated statement from trailing explanation.
_BLANK_LINE = re.compile(r"\n\s*\n")
_NO_SQL_ERROR = "ERROR: No valid SQL query generated."


def extract_sql_query(response):
    """Extract the first SELECT statement from the LLM's generated response.

    The statement starts at the first standalone ``SELECT`` keyword
    (case-insensitive) and ends at the first semicolon. When the model wraps
    the SQL in a Markdown code fence, the closing fence also ends the
    statement. If no semicolon is present, the text up to the first blank
    line (or the end of the response) is used and a semicolon is appended.

    Args:
        response: Raw text produced by the model.

    Returns:
        The SQL statement, always terminated by ";", or the string
        "ERROR: No valid SQL query generated." when ``response`` is not a
        string or contains no SELECT keyword.
    """
    if not isinstance(response, str):
        return _NO_SQL_ERROR

    keyword = _SELECT_KEYWORD.search(response)
    if keyword is None:
        return _NO_SQL_ERROR

    candidate = response[keyword.start():]
    # Anything after a code fence is commentary, not part of the statement.
    candidate = candidate.partition("```")[0]

    statement, terminator, _ = candidate.partition(";")
    if terminator:
        return (statement + terminator).strip()

    # No semicolon: keep the first paragraph and terminate it ourselves.
    statement = _BLANK_LINE.split(candidate, maxsplit=1)[0].strip()
    return f"{statement};"

In [None]:
# Other questions that were tried with this pipeline:
#   "Find all premium users from the USA who logged in 2024"
#   "Find all inactive users"
#   "Find all records where age is more than 40 and belongs to India"
#   "Find all premium users from the USA who logged in this year."
# The continent question below is the one that was fixed by prompt engineering.
query = "Find all records where age is more than 40 and country belongs to Asia"

# Show which column descriptions the vector store returns for this question.
schema_info = retrieve_schema_info(query)
print("\nRelevant Schema Information:")
for line in schema_info:
    print(line)

# Example usage: turn the question into SQL. The statement is executed in a
# later cell; uncomment the print to inspect it here.
generated_sql = generate_sql_query(query)
# print("Generated SQL:", generated_sql)

# Agent to run SQL queries

In [None]:
def run_sql_query(sql_query):
    """Execute a SQL statement on the notebook's engine and return its rows.

    Args:
        sql_query: SQL text to run, typically the statement produced by the
            LLM step.

    Returns:
        The list of rows from ``fetchall()`` on success. On failure, a string
        starting with "SQL Execution Error:" -- either because there was no
        usable statement to run or because executing it raised an exception.
    """
    # The extraction step reports failure with an "ERROR: ..." string instead of
    # raising; do not send that (or an empty/non-string value) to the database.
    if not isinstance(sql_query, str) or not sql_query.strip():
        return "SQL Execution Error: no SQL query was provided."
    if sql_query.lstrip().startswith("ERROR:"):
        return f"SQL Execution Error: no SQL query to execute ({sql_query.strip()})"

    # NOTE(review): sql_query is model-generated and is executed verbatim. Nothing
    # in this function restricts it to read-only statements; consider a read-only
    # connection or an allow-list before using this outside a sandbox database.
    try:
        with engine.connect() as conn:
            return conn.execute(text(sql_query)).fetchall()
    except Exception as e:  # boundary: failures are reported as text, not raised
        return f"SQL Execution Error: {e}"

# Run the Llama-generated statement against the database and display the outcome
# (the fetched rows, or the error text returned by run_sql_query).
sql_result = run_sql_query(sql_query=generated_sql)
print("SQL Result:", sql_result)