In [None]:
# This code has been run and tested on Kaggle

In [None]:
!pip install pyngrok
!pip install git+https://github.com/huggingface/transformers.git@main bitsandbytes accelerate==0.28.0  # we need latest transformers for this
!pip install peft==0.7.1
!pip install datasets==2.15.0
!pip install wandb
!pip install torch
!pip install numpy==1.23.2
!pip install psycopg2-binary

In [None]:
# add the huggingface token
!huggingface-cli login --token <huggingface_token>

In [None]:
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer

# Fine-tuned CodeLlama checkpoint used for text-to-SQL generation.
model_name = "apoorv-kr/CodeLlama_Text-to-SQL"

# Quantization settings belong in `quantization_config`. The `config` keyword is
# the slot for the model's own configuration object, so a BitsAndBytesConfig
# passed there does not configure quantization. Only the 8-bit switch is set:
# the bnb_4bit_* options apply to 4-bit loading and are irrelevant here, and
# `load_in_8bit` is no longer passed a second time as a direct keyword.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)

# The tokenizer is loaded from the base CodeLlama repository, not from `model_name`.
tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-hf")

In [None]:
def get_sql(query: str, context: str) -> str:
    """Generate the SQL statement that answers ``query`` for the given schema.

    Uses the module-level ``model`` and ``tokenizer``.

    Args:
        query: Natural-language question about the database.
        context: Schema description (``CREATE TABLE ...`` statements) the model
            should use to answer the question.

    Returns:
        The generated SQL, stripped of surrounding whitespace, cut at the first
        ``#`` (the start of any further ``###`` section the model emits) and
        always terminated with ``;``.
    """
    eval_prompt = f"""You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables.

You must output the SQL query that answers the question.
### Input:
{query}

### Context:
{context}

### Response:
"""
    # Put the inputs on the device the model lives on instead of hard-coding "cuda".
    model_input = tokenizer(eval_prompt, return_tensors="pt").to(model.device)
    prompt_length = model_input["input_ids"].shape[1]

    model.eval()
    with torch.no_grad():
        output_ids = model.generate(**model_input, max_new_tokens=100)[0]

    # The generated sequence starts with the prompt tokens. Decode only what
    # follows them: searching the full text for "### Response:" would pick the
    # wrong segment whenever the question or context contains that marker.
    response = tokenizer.decode(output_ids[prompt_length:], skip_special_tokens=True).strip()

    # remove anything after the response (which will be the next ###)
    response = response.split("#")[0].strip()

    # if the response does not contain ; at the end, we add it
    if not response.endswith(";"):
        response += ";"

    return response


In [None]:
# add the ngrok token
!ngrok config add-authtoken <ngrok_token>

In [None]:
import psycopg2
from psycopg2 import OperationalError

# Function to connect to the PostgreSQL server       
def connect_to_remote_postgresql_url(url):
    """Open a connection to the PostgreSQL server identified by ``url``.

    Args:
        url: Connection string passed unchanged to ``psycopg2.connect``.

    Returns:
        The open connection on success, or ``None`` when connecting fails (the
        error is printed instead of being raised).
    """
    try:
        # Deliberately no ``with`` block: psycopg2's connection context manager
        # scopes a transaction (commit/rollback on exit), not the lifetime of
        # the connection, so wrapping the call in it only hid the fact that the
        # connection is handed back open to the caller.
        conn = psycopg2.connect(url)
    except Exception as error:
        # psycopg2.DatabaseError is an Exception subclass, so one clause covers
        # everything the original tuple caught.
        print(error)  # Report the failure; the caller receives None
        return None

    print('Connected to the PostgreSQL server.')
    return conn

# insert the database url here
# (the value below is only a placeholder; replace it with a real connection URL)
database_url = "database_url"
# Module-level connection shared by the database helper functions defined below.
# connect_to_remote_postgresql_url prints the error and returns None on failure,
# so check that `conn` is not None before running the later cells.
conn = connect_to_remote_postgresql_url(database_url)

In [None]:
# function to create the following tables
""" 

-- Table for Departments
CREATE TABLE Departments (
    DepartmentID INT PRIMARY KEY,
    DepartmentName VARCHAR(100)
);
 
-- Table for Employees
CREATE TABLE Employees (
    EmployeeID INT PRIMARY KEY,
    FirstName VARCHAR(50),
    LastName VARCHAR(50),
    DepartmentID INT references Departments(DepartmentID),
    Position VARCHAR(50),
    Salary DECIMAL(10, 2),
    HireDate DATE
);


-- Table for Projects
CREATE TABLE Projects (
    ProjectID INT PRIMARY KEY,
    ProjectName VARCHAR(100),
    StartDate DATE,
    EndDate DATE
);

-- Table for Assignments
CREATE TABLE Assignments (
    AssignmentID INT PRIMARY KEY,
    EmployeeID INT,
    ProjectID INT,
    AssignmentDate DATE,
    FOREIGN KEY (EmployeeID) REFERENCES Employees(EmployeeID),
    FOREIGN KEY (ProjectID) REFERENCES Projects(ProjectID)
);

-- Table for Salaries
CREATE TABLE Salaries (
    SalaryID INT PRIMARY KEY,
    EmployeeID INT,
    Salary DECIMAL(10, 2),
    EffectiveDate DATE,
    FOREIGN KEY (EmployeeID) REFERENCES Employees(EmployeeID)
);

-- Table for Benefits
CREATE TABLE Benefits (
    BenefitID INT PRIMARY KEY,
    EmployeeID INT,
    BenefitType VARCHAR(100),
    BenefitAmount DECIMAL(10, 2),
    FOREIGN KEY (EmployeeID) REFERENCES Employees(EmployeeID)
);
"""
def create_tables():
    """Create the demo schema on the module-level connection ``conn``.

    Creates the Departments, Employees, Projects, Assignments, Salaries and
    Benefits tables with a single ``execute`` call and commits on success.

    Raises:
        Exception: Whatever the database driver raises (for example when a
            table already exists). The transaction is rolled back before the
            error propagates so that ``conn`` remains usable.
    """
    query = """
    CREATE TABLE Departments (
        DepartmentID INT PRIMARY KEY,
        DepartmentName VARCHAR(100)
    );
     
    CREATE TABLE Employees (
        EmployeeID INT PRIMARY KEY,
        FirstName VARCHAR(50),
        LastName VARCHAR(50),
        DepartmentID INT references Departments(DepartmentID),
        Position VARCHAR(50),
        Salary DECIMAL(10, 2),
        HireDate DATE
    );
    
    CREATE TABLE Projects (
        ProjectID INT PRIMARY KEY,
        ProjectName VARCHAR(100),
        StartDate DATE,
        EndDate DATE
    );
    
    CREATE TABLE Assignments (
        AssignmentID INT PRIMARY KEY,
        EmployeeID INT,
        ProjectID INT,
        AssignmentDate DATE,
        FOREIGN KEY (EmployeeID) REFERENCES Employees(EmployeeID),
        FOREIGN KEY (ProjectID) REFERENCES Projects(ProjectID)
    );
    
    CREATE TABLE Salaries (
        SalaryID INT PRIMARY KEY,
        EmployeeID INT,
        Salary DECIMAL(10, 2),
        EffectiveDate DATE,
        FOREIGN KEY (EmployeeID) REFERENCES Employees(EmployeeID)
    );
    
    CREATE TABLE Benefits (
        BenefitID INT PRIMARY KEY,
        EmployeeID INT,
        BenefitType VARCHAR(100),
        BenefitAmount DECIMAL(10, 2),
        FOREIGN KEY (EmployeeID) REFERENCES Employees(EmployeeID)
    );
    """
    try:
        with conn.cursor() as cursor:
            cursor.execute(query)
        conn.commit()
    except Exception:
        # Same recovery as execute_query: undo the failed transaction so the
        # shared connection is not left in an errored state for later cells.
        conn.rollback()
        raise
    print("Tables created successfully")

# create the tables
#create_tables()

In [None]:
# function to insert sample values in the tables
def insert_sample_values():
    """Insert the sample rows into the demo tables using the module-level ``conn``.

    All INSERT statements are sent in a single ``execute`` call and committed on
    success.

    Raises:
        Exception: Whatever the database driver raises (for example a primary
            key violation when the rows were already inserted). The transaction
            is rolled back before the error propagates so that ``conn`` remains
            usable.
    """
    query = """
    INSERT INTO Departments (DepartmentID, DepartmentName) VALUES
    (1, 'CSE'),
    (2, 'ECE'),
    (3, 'ME');
    
    INSERT INTO Employees (EmployeeID, FirstName, LastName, DepartmentID, Position, Salary, HireDate) VALUES
    (1, 'John', 'Doe', 1, 'Manager', 10000.00, '2022-01-01'),
    (2, 'Jane', 'Smith', 2, 'Secretary', 8000.00, '2022-02-01'),
    (3, 'Bob', 'Johnson', 3, 'Supervisor', 9000.00, '2022-03-01'),
    (4, 'Alice', 'Williams', 1, 'Advisor', 9500.00, '2022-04-01'),
    (5, 'Charlie', 'Brown', 2, 'Coordinator', 8500.00, '2022-05-01');
    
    INSERT INTO Projects (ProjectID, ProjectName, StartDate, EndDate) VALUES
    (1, 'Project1', '2022-01-01', '2022-02-01'),
    (2, 'Project2', '2022-02-01', '2022-03-01'),
    (3, 'Project3', '2022-03-01', '2022-04-01'),
    (4, 'Project4', '2022-04-01', '2022-05-01'),
    (5, 'Project5', '2022-05-01', '2022-06-01');
    
    INSERT INTO Assignments (AssignmentID, EmployeeID, ProjectID, AssignmentDate) VALUES
    (1, 1, 1, '2022-01-01'),
    (2, 2, 2, '2022-02-01'),
    (3, 3, 3, '2022-03-01'),
    (4, 4, 4, '2022-04-01'),
    (5, 5, 5, '2022-05-01');
    
    INSERT INTO Salaries (SalaryID, EmployeeID, Salary, EffectiveDate) VALUES
    (1, 1, 10000.00, '2022-01-01'),
    (2, 2, 8000.00, '2022-02-01'),
    (3, 3, 9000.00, '2022-03-01'),
    (4, 4, 9500.00, '2022-04-01'),
    (5, 5, 8500.00, '2022-05-01');

    INSERT INTO Benefits (BenefitID, EmployeeID, BenefitType, BenefitAmount) VALUES
    (1, 1, 'Health', 500.00),
    (2, 2, 'Dental', 200.00),
    (3, 3, 'Vision', 300.00),
    (4, 4, 'Health', 400.00),
    (5, 5, 'Dental', 100.00);
    """
    try:
        with conn.cursor() as cursor:
            cursor.execute(query)
        conn.commit()
    except Exception:
        # Same recovery as execute_query: undo the failed transaction so the
        # shared connection is not left in an errored state for later cells.
        conn.rollback()
        raise
    print("Sample values inserted")

#insert_sample_values()

In [None]:
# function to get all the tables and their columns in the database, as a dictionary with key as table name and value as list of columns
def get_tables_columns():
    """Map every table in the ``public`` schema to the list of its column names.

    Reads ``information_schema.columns`` through the module-level ``conn``.
    Columns are listed in the order the query returns them (by table name,
    then ordinal position). The ``pg_stat_statements`` entry is dropped from
    the result.

    Returns:
        dict: ``{table_name: [column_name, ...]}``.
    """
    query = """
    SELECT table_name, column_name
    FROM information_schema.columns
    WHERE table_schema = 'public'
    ORDER BY table_name, ordinal_position;
    """
    with conn.cursor() as cur:
        cur.execute(query)
        rows = cur.fetchall()

    columns_by_table = {}
    for table_name, column_name in rows:
        columns_by_table.setdefault(table_name, []).append(column_name)

    # "pg_stat_statements" is a system table, so it is not reported
    columns_by_table.pop("pg_stat_statements", None)
    return columns_by_table

# function to get the tables, their columns and its data types in the database, as a dictionary with key as table name and value as dictionary with key as column name and value as data type
def get_tables_columns_datatypes():
    """Map every table in the ``public`` schema to its columns and their data types.

    Reads ``information_schema.columns`` through the module-level ``conn``.

    Returns:
        dict: ``{table_name: {column_name: data_type, ...}}`` with columns in
        the order the query returns them (by table name, then ordinal position).
    """
    query = """
    SELECT table_name, column_name, data_type
    FROM information_schema.columns
    WHERE table_schema = 'public'
    ORDER BY table_name, ordinal_position;
    """
    with conn.cursor() as cur:
        cur.execute(query)
        rows = cur.fetchall()

    datatypes_by_table = {}
    for table_name, column_name, data_type in rows:
        datatypes_by_table.setdefault(table_name, {})[column_name] = data_type
    return datatypes_by_table

# function to get the foreign key references (if any) in the given table
def get_foreign_key_references(table):
    """Return the foreign-key constraints declared on ``table``.

    Queries ``information_schema`` through the module-level ``conn``.

    Args:
        table: Table name to look up. It is sent to the driver as a bound
            parameter, never spliced into the SQL text, because callers pass
            names taken from the submitted request form.

    Returns:
        list: One ``(table_name, column_name, foreign_table_name,
        foreign_column_name)`` tuple per matching row; empty if there are none.
    """
    query = """
    SELECT
        tc.table_name, kcu.column_name, ccu.table_name AS foreign_table_name, ccu.column_name AS foreign_column_name
    FROM
        information_schema.table_constraints AS tc
        JOIN information_schema.key_column_usage AS kcu
          ON tc.constraint_name = kcu.constraint_name
          AND tc.table_schema = kcu.table_schema
        JOIN information_schema.constraint_column_usage AS ccu
          ON ccu.constraint_name = tc.constraint_name
          AND ccu.table_schema = tc.table_schema
    WHERE tc.table_name = %s AND tc.constraint_type = 'FOREIGN KEY';
    """
    with conn.cursor() as cursor:
        # %s is filled in by the driver, which handles quoting of the value
        cursor.execute(query, (table,))
        foreign_keys = cursor.fetchall()
    return foreign_keys

# function to create context string given the table names as list
# context string is of format "CREATE TABLE table1 (col1 datatype1, col2 datatype2, foreign key references[if any]); CREATE TABLE table2 (col1 datatype1, col2 datatype2, foreign key references[if any]);" and so on
def create_context_string(tables):
    """Build the schema context passed to the model for the selected tables.

    Args:
        tables: List of table names to describe.

    Returns:
        str: One ``"CREATE TABLE <table> (<col> <type>, ..., FOREIGN KEY (<col>)
        REFERENCES <other>(<other_col>)); "`` statement per table, concatenated
        in the given order; ``""`` when ``tables`` is empty.

    Raises:
        KeyError: If a name in ``tables`` is not returned by
            ``get_tables_columns_datatypes``.
    """
    tables = list(tables)
    if not tables:
        return ""

    # The schema lookup does not depend on the table being processed, so run
    # it once instead of re-querying the whole schema on every iteration.
    schema = get_tables_columns_datatypes()

    statements = []
    for table in tables:
        parts = [f"{column} {datatype}" for column, datatype in schema[table].items()]
        # add foreign key references (if any)
        parts.extend(
            f"FOREIGN KEY ({foreign_key[1]}) REFERENCES {foreign_key[2]}({foreign_key[3]})"
            for foreign_key in get_foreign_key_references(table)
        )
        # Joining avoids trimming a trailing ", ", which would eat the opening
        # parenthesis if a table ever produced no parts.
        statements.append(f"CREATE TABLE {table} ({', '.join(parts)}); ")
    return "".join(statements)


# function to perform the query on the database and return the result, and if error occurs, return the error message (also use another return variable to check if error occurred or not)
def execute_query(query):
    """Run a single SELECT statement on the module-level ``conn``.

    Args:
        query: SQL text to run (in this notebook, the model-generated statement).

    Returns:
        tuple: ``(payload, error)``. On success ``payload`` is
        ``[column_names, rows]`` and ``error`` is ``False``. Otherwise
        ``payload`` is an error message string and ``error`` is ``True``.
    """
    statement = query.strip()

    # if the query does not start with 'select', we do not execute it
    if not statement.lower().startswith("select"):
        return "Only SELECT queries are allowed!", True

    # A trailing ";" is fine, but any other ";" means more statements follow the
    # SELECT (e.g. "SELECT 1; DROP TABLE t;"), which the prefix check above
    # cannot see. This is deliberately conservative: a ";" inside a string
    # literal is rejected as well.
    if ";" in statement.rstrip("; \t\r\n"):
        return "Only a single SELECT statement is allowed!", True

    try:
        # execute the query and get the result, along with the column names
        with conn.cursor() as cursor:
            cursor.execute(statement)
            result = cursor.fetchall()
            colnames = [desc[0] for desc in cursor.description]
        return [colnames, result], False
    except Exception as e:
        conn.rollback()  # Rollback the changes in case of an exception
        return str(e), True

In [None]:
# Import necessary libraries
from flask import Flask, request, jsonify, render_template
from pyngrok import ngrok

# Create the Flask app (the server itself is started by app.run in the last cell)
app = Flask(__name__)

# query endpoint, which gets the query and table names for generating context from the request, gets the corresponding SQL query, executes the query on the database and gets the result
# gets the query from form as text
# gets the tables from form (the form contains checkboxes for each table, so the tables are the names of the tables which are checked)
# renders the query.html page with the sql query generated and result/ or error message
@app.route('/query', methods=['POST'])
def query():
    """Handle the question form: generate SQL, run it and render the outcome.

    Reads the question text (``query``) and the checked table names
    (``tables[]``) from the submitted form, builds the schema context for those
    tables, asks the model for a SQL statement and executes it. Renders
    ``query.html`` with the generated SQL plus either the result rows or the
    error message.
    """
    question = request.form['query']
    selected_tables = request.form.getlist('tables[]')

    context = create_context_string(selected_tables)
    print(context)

    sql_query = get_sql(question, context)
    output, err = execute_query(sql_query)

    # On failure ``output`` is the error message. On success it is
    # [column_names, rows], so the header row is placed on top of the rows.
    result = output if err else [output[0]] + output[1]

    print(err)
    print(result)
    return render_template('query.html', sql_query=sql_query, result=result, err=err)

@app.route('/')
def home():
    """Render the landing page listing the database tables and their columns."""
    # The template receives the {table: [columns]} mapping under ``tables_columns``
    return render_template('index.html', tables_columns=get_tables_columns())

In [None]:
# Open an ngrok tunnel to local port 8084, the same port that app.run uses in
# the next cell, and print the public URL that forwards to it.
# NOTE(review): bind_tls=True presumably restricts the tunnel to HTTPS; confirm against the pyngrok docs.
ngrok_tunnel = ngrok.connect(addr="8084", proto="http", bind_tls=True)
print("Public URL:", ngrok_tunnel.public_url)

In [None]:
# Start the Flask app on port 8084 (the port the ngrok tunnel above was opened for), with debug mode off.
app.run(port=8084, debug=False)