<a href="https://colab.research.google.com/github/loni9164/text_sql/blob/main/sql_starcoder_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install libraries

In [None]:
!pip install langchain langchain-experimental
!pip install -q  langchain
!pip install sentence-transformers
!pip install chromadb

!pip3 install transformers optimum
!pip3 install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu121/  # Use cu117 if on CUDA 11.7
!pip install gradio

In [None]:
!apt-get -qq update
!apt-get -qq -y install postgresql
!pip install psycopg2-binary

# Imports

In [1]:
import psycopg2
import sqlite3
import time
import pickle
import re

# from langchain.llms import CTransformers
from langchain import PromptTemplate, LLMChain
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.utilities import SQLDatabase
from langchain.prompts import PromptTemplate
from langchain_experimental.sql import SQLDatabaseChain
import gradio as gr

# Create PostgreSQL

In [2]:
!service postgresql start
!sudo -u postgres psql -U postgres -c "ALTER USER postgres PASSWORD '12345';"
!service postgresql restart

 * Starting PostgreSQL 14 database server
   ...done.
ALTER ROLE
 * Restarting PostgreSQL 14 database server
   ...done.


In [3]:
conn = psycopg2.connect(
    host="localhost",
    user="postgres",
    password="12345",  # Use the correct password here
    dbname="postgres"
)
conn.autocommit = True
cursor = conn.cursor()

In [4]:
# Creating a new database
cursor.execute("CREATE DATABASE credit_card_system")
cursor.close()
conn.close()

In [5]:
# Connecting to the new database
conn = psycopg2.connect(
    host="localhost",
    user="postgres",
    password="12345",
    dbname="credit_card_system"
)
cursor = conn.cursor()

In [6]:
# SQL statements to create tables
create_table_statements = [
    """
    CREATE TABLE branch (
        branch_id INTEGER,
        branch_name TEXT,
        branch_address TEXT,
        branch_phone TEXT,
        branch_manager TEXT,
        branch_email TEXT,
        established_date DATE,
        number_of_employees INTEGER,
        PRIMARY KEY (branch_id)
    )
    """,
    """
    CREATE TABLE category (
        category_id INTEGER,
        category_name TEXT,
        PRIMARY KEY (category_id)
    )
    """,

    """
    CREATE TABLE users (
        user_id INTEGER,
        user_name TEXT,
        user_email TEXT,
        user_address TEXT,
        user_phone TEXT,
        date_of_birth DATE,
        registration_date DATE,
        status TEXT,
        branch_id INTEGER,
        PRIMARY KEY (user_id),
        FOREIGN KEY(branch_id) REFERENCES branch (branch_id)
    )
    """,

    """
    CREATE TABLE credit_card (
        card_id INTEGER,
        user_id INTEGER,
        card_number TEXT,
        card_type TEXT,
        expiry_date DATE,
        cvv INTEGER,
        issue_date DATE,
        total_credit_limit REAL,
        current_outstanding_amount REAL,
        remaining_credit_limit REAL,
        total_amount_due REAL,
        minimum_amount_due REAL,
        statement_date DATE,
        amount_due_on DATE,
        control_limit REAL,
        PRIMARY KEY (card_id),
        FOREIGN KEY(user_id) REFERENCES users (user_id)
    )
    """,

    """
    CREATE TABLE transactions (
        transaction_id INTEGER,
        card_id INTEGER,
        transaction_date DATE,
        amount REAL,
        merchant TEXT,
        category_id INTEGER,
        transaction_type TEXT,
        description TEXT,
        PRIMARY KEY (transaction_id),
        FOREIGN KEY(card_id) REFERENCES credit_card (card_id),
        FOREIGN KEY(category_id) REFERENCES category (category_id)
    )
    """,

    """
    CREATE TABLE credit_card_financial (
        financial_id INTEGER,
        card_id INTEGER,
        overdue_charges REAL,
        loan_amount REAL,
        emi_amount REAL,
        emi_due_date DATE,
        interest_rate REAL,
        payment_due_date DATE,
        minimum_payment REAL,
        PRIMARY KEY (financial_id),
        FOREIGN KEY(card_id) REFERENCES credit_card (card_id)
    )
    """,

    """
    CREATE TABLE reward (
        reward_id INTEGER,
        transaction_id INTEGER,
        points_earned INTEGER,
        points_redeemed INTEGER,
        current_balance INTEGER,
        PRIMARY KEY (reward_id),
        FOREIGN KEY(transaction_id) REFERENCES transactions (transaction_id)
    )
    """
]


# Execute each CREATE TABLE statement
for statement in create_table_statements:
    cursor.execute(statement)

conn.commit()

In [7]:
import pandas as pd

# Function to load data from CSV to a table
def load_csv_to_table(csv_file_path, table_name):
    data = pd.read_csv(csv_file_path)
    for i, row in data.iterrows():
        insert_query = "INSERT INTO {} VALUES %s".format(table_name)
        cursor.execute(insert_query, (tuple(row),))

In [8]:
!git clone https://github.com/loni9164/text_sql.git

Cloning into 'text_sql'...
remote: Enumerating objects: 130, done.[K
remote: Counting objects: 100% (130/130), done.[K
remote: Compressing objects: 100% (111/111), done.[K
remote: Total 130 (delta 71), reused 43 (delta 17), pack-reused 0[K
Receiving objects: 100% (130/130), 548.02 KiB | 4.72 MiB/s, done.
Resolving deltas: 100% (71/71), done.


In [9]:
# Load data from CSV files in the correct order
load_csv_to_table('text_sql/csv_files/branch.csv', 'branch')
load_csv_to_table('text_sql/csv_files/category.csv', 'category')
load_csv_to_table('text_sql/csv_files/users.csv', 'users')
load_csv_to_table('text_sql/csv_files/credit_card.csv', 'credit_card')
load_csv_to_table('text_sql/csv_files/transactions.csv', 'transactions')
load_csv_to_table('text_sql/csv_files/credit_card_financial.csv', 'credit_card_financial')
load_csv_to_table('text_sql/csv_files/reward.csv', 'reward')

conn.commit()

# DB connection

In [10]:
# Connecting to the new database
conn = psycopg2.connect(
    host="localhost",
    user="postgres",
    password="12345",
    dbname="credit_card_system"
)

In [11]:
def query_db(query):
  cursor = conn.cursor()
  cursor.execute("ROLLBACK")
  cursor.execute(query)
  return cursor.fetchall()

query_db('SELECT user_id FROM users LIMIT 5;')

[(1,), (2,), (3,), (4,), (5,)]

In [12]:
db_user = "postgres"
db_password = "12345"
db_host = "localhost"
db_name = "credit_card_system"

cursor = conn.cursor()

connection_string = f"postgresql://{db_user}:{db_password}@{db_host}/{db_name}"
db = SQLDatabase.from_uri(connection_string)
table_info = db.table_info
# print(table_info)

# Load model

In [13]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_name_or_path = "TheBloke/sqlcoder2-GPTQ"
# To use a different branch, change revision
# For example: revision="gptq-4bit-128g-actorder_True"
model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
                                             device_map="auto",
                                             trust_remote_code=False,
                                             revision="main")

config.json:   0%|          | 0.00/1.44k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/9.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

In [14]:
from auto_gptq import exllama_set_max_input_length
model = exllama_set_max_input_length(model, max_input_length=3000)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

tokenizer_config.json:   0%|          | 0.00/4.04k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/777k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/442k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.06M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/532 [00:00<?, ?B/s]

In [15]:
from transformers import GenerationConfig
generation_config = GenerationConfig(
    # max_length=1000,  # Adjust as needed
    max_new_tokens=512,
    do_sample=True,
    temperature=0.1,
    top_p=0.95,
    top_k=40,
    repetition_penalty=1.1
)

In [16]:
# print("\n\n*** Generate:")



# Inference can also be done using transformers' pipeline

# print("*** Pipeline:")
# pipe = pipeline(
#     "text-generation",
#     model=model,
#     tokenizer=tokenizer,
#     generation_config=generation_config
# )


# print(pipe(prompt_template)[0]['generated_text'])

In [17]:
# print("\n\n*** Generate:")

# input_ids = tokenizer(prompt_template, return_tensors='pt').input_ids.cuda()
# output = model.generate(inputs=input_ids, temperature=0, do_sample=True, top_p=0.95, top_k=40, max_new_tokens=512)
# print(tokenizer.decode(output[0]))

# Inference can also be done using transformers' pipeline

# print("*** Pipeline:")
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.1,
    top_p=0.95,
    top_k=40,
    repetition_penalty=1.1
)

# print(pipe(prompt_template)[0]['generated_text'])

In [18]:
from langchain.llms import HuggingFacePipeline
from langchain import PromptTemplate, LLMChain
llm = HuggingFacePipeline(pipeline=pipe)

# Few shot learning

We will use few shot learning to fix issues we have seen so far

### Creating Semantic Similarity Based example selector

- create embedding on the few_shots
- Store the embeddings in Chroma DB
- Retrieve the the top most Semantically close example from the vector store

In [None]:
from langchain.embeddings import HuggingFaceEmbeddings

embeddings = HuggingFaceEmbeddings(model_name='BAAI/bge-large-en-v1.5')

In [20]:
few_shot_examples = [
  {'Question': 'What is the total number of active credit cards?',
  'SQLQuery': 'SELECT COUNT(DISTINCT card_id) AS number_of_cards FROM credit_card WHERE expiry_date > CURRENT_DATE;',
  'SQLResult': '[(6,)]',
  'Answer': 'There are six active credit cards in the system'},
 {'Question': 'Find the total amount spent on fuel by all users.',
  'SQLQuery': "SELECT SUM(transactions.amount::FLOAT) AS total_amount FROM transactions JOIN category ON transactions.category_id = category.category_id WHERE category.category_name ilike '%fuel%' AND transactions.transaction_status ILIKE '%completed%' AND transactions.transaction_type ILIKE '%Debit%';",
  'SQLResult': '[(4508.70002746582,)]',
  'Answer': 'Total amount spent on fuel by all users for completed transactions in ₹4508.70'},
 {'Question': 'Find the user with the oldest credit card.',
  'SQLQuery': 'SELECT users.user_name FROM users JOIN credit_card ON users.user_id = credit_card.user_id ORDER BY credit_card.expiry_date ASC LIMIT 1;',
  'SQLResult': "[('Gabrielle Anderson',)]",
  'Answer': 'The user with oldest credit card is Gabrielle Anderson'},
 {'Question': 'How many rewards were earned for transactions above ₹1000?',
  'SQLQuery': 'SELECT COUNT(DISTINCT r.reward_id) AS total_rewards FROM transactions t JOIN reward r ON t.transaction_id = r.transaction_id WHERE t.amount > 1000;',
  'SQLResult': '[(None,)]',
  'Answer': 'There were no rewards earned for transactions above ₹1000'},
 {'Question': 'List the top 3 users by total transaction amount.',
  'SQLQuery': 'SELECT u.user_name, SUM(t.amount) AS total_amount FROM users u JOIN credit_card c ON u.user_id = c.user_id JOIN transactions t ON c.card_id = t.card_id GROUP BY u.user_name ORDER BY total_amount DESC NULLS LAST LIMIT 3;',
  'SQLResult': "[('Gabrielle Anderson', 128059.26), ('Michael Baldwin', 121319.25)]",
  'Answer': 'Top 3 users by total transaction amount are Gabrielle Anderson (with ₹128059.26) and Michael Baldwin (With ₹121319.25).'},
  {'Question': 'List all the cards going to expire by sep 2026',
  'SQLQuery': "SELECT * FROM credit_card WHERE expiry_date <= '2026-09-30';",
  'SQLResult': '''
  [('53600000000000000', datetime.date(2029, 4, 30)),
   ('22900000000000000', datetime.date(2031, 5, 31)),
   ('40900000000006', datetime.date(2033, 3, 31))]
   ''',
  'Answer': 'Card ids which are going to expire by sep 2026 are 1, 3 and 5'},
 {'Question': 'what is the total amount spent by user 2 for food and groceries?',
  'SQLQuery': "SELECT SUM(transactions.amount::FLOAT) AS total_amount FROM transactions JOIN category ON transactions.category_id = category.category_id JOIN credit_card ON transactions.card_id = credit_card.card_id JOIN users ON credit_card.user_id = users.user_id WHERE category.category_name ILIKE '%food%and%groceries%' AND users.user_id = 2 AND transactions.transaction_status ILIKE '%Completed%' AND transactions.transaction_type ILIKE '%debit%';",
  'SQLResult': [(1634.1600036621094,)],
  'Answer': "Total amount spent by user 2 for food and groceries is ₹1634.16"},
 {'Question': 'How many unique users do we have in transactions',
  'SQLQuery': 'SELECT COUNT(DISTINCT user_id) FROM credit_card WHERE card_id IN (SELECT card_id FROM transactions);',
  'SQLResult': [(2,)],
  'Answer': 'There are two users in transactions'},
 {'Question': 'What is the total amount lended to the user 1',
  'SQLQuery': "SELECT SUM(loan_amount) AS total_amount_lended FROM credit_card_financial WHERE card_id IN (SELECT card_id FROM credit_card WHERE user_id = 1);",
  'SQLResult': [(27493.361,)],
  'Answer': 'Total amount lended to the user1 is ₹27493.36'},
 {'Question': 'Total reward points earned by user2',
  'SQLQuery': "SELECT SUM(points_earned) AS reward_earned FROM reward WHERE transaction_id IN (SELECT transaction_id FROM transactions WHERE card_id IN (SELECT card_id FROM credit_card WHERE user_id=2));",
  'SQLResult': [(24340,)],
  'Answer': 'Total reward points earned by user2 is ₹24340'},
 {'Question': 'List all branches established after 2010.',
  'SQLQuery': 'SELECT * FROM branch WHERE EXTRACT(YEAR FROM established_date) > 2010;',
  'SQLResult': '''[(5,
  'Branch 5',
  '31651 Scott Ranch, East Sydney, HI 99471',
  '-9482',
  'Hailey Newton',
  'lorialexander@stevenson.org',
  datetime.date(2012, 3, 2),
  16)]''',
  'Answer': 'Branch established after 2010 is branch id 5'},
 {'Question': 'List the user names and total reward points earned by each user.',
  'SQLQuery': 'SELECT u.user_name, SUM(r.points_earned) AS total_points_earned FROM users u JOIN credit_card cc ON u.user_id = cc.user_id JOIN transactions t ON cc.card_id = t.card_id JOIN reward r ON t.transaction_id = r.transaction_id GROUP BY u.user_id, u.user_name;',
  'SQLResult': [('Michael Baldwin', 24340), ('Gabrielle Anderson', 26762)],
  'Answer': 'Rewared earned by Michael Baldwin and Gabrielle Andersoand are ₹24340 and ₹26762 respectively.'},
 {'Question': 'Find the average balance of credit cards issued by Branch 2.',
  'SQLQuery': "SELECT AVG(cc.current_balance::FLOAT) AS avg_balance FROM credit_card cc JOIN users u ON cc.user_id = u.user_id JOIN branch b ON u.branch_id = b.branch_id WHERE b.branch_name ILIKE '%branch%2%';",
  'SQLResult': [(1770.747817993164,)],
  'Answer': 'Average balance of credit cards issued by Branch 2 is ₹1770.74'},
 {'Question': 'List all users who have a credit card expiring in 2024.',
  'SQLQuery': "SELECT DISTINCT u.user_name FROM users u JOIN credit_card cc ON u.user_id = cc.user_id WHERE EXTRACT(YEAR FROM cc.expiry_date) = 2024;",
  'SQLResult': [('Gabrielle Anderson',)],
  'Answer': "Gabrielle Anderson's credit card is expiring in 2024"},
 {'Question': "Find the total amount spent in 'Movies and Entertainment' category by User 1.",
  'SQLQuery': "SELECT SUM(transactions.amount::FLOAT) AS total_amount FROM transactions JOIN category ON transactions.category_id = category.category_id JOIN credit_card ON transactions.card_id = credit_card.card_id JOIN users ON credit_card.user_id = users.user_id WHERE category.category_name ILIKE '%movies%and%entertainment%' AND users.user_id = 1 AND transactions.transaction_status ILIKE '%Completed%' AND transactions.transaction_type ILIKE '%debit%' ;",
  'SQLResult': [(1664.789981842041,)],
  'Answer': "Total amount spent in 'Movies and Entertainment' category by User1 is ₹1664.78"},
 {'Question': 'What is the total loan amount issued to users of Branch 3?',
  'SQLQuery': "SELECT SUM(ccfi.loan_amount::FLOAT) AS total_loan_amount FROM branch b JOIN users u ON CAST(b.branch_id AS integer) = u.branch_id JOIN credit_card cc ON CAST(u.user_id AS integer) = cc.user_id JOIN credit_card_financial ccfi ON cc.card_id = ccfi.card_id WHERE b.branch_id::TEXT ilike '%2';",
  'SQLResult': [(20251.509757995605,)],
  'Answer': 'Total loan amount issued to users from branch2 is ₹20251.50'},
 {'Question': 'What is the total amount of all transactions completed in the last month?',
  'SQLQuery': "SELECT SUM(amount) AS total_amount_last_month FROM transactions WHERE transaction_date > CURRENT_DATE - INTERVAL '1 month';",
  'SQLResult': [(20251.512,)],
  'Answer': 'Total amount of all the transactions completed in the last month is ₹20251.51'}]

In [21]:
len(few_shot_examples)

17

In [22]:
updated_few_shot_examples = []
for example in few_shot_examples:
  if not isinstance(example['SQLResult'], str):
    example['SQLResult'] = str(example['SQLResult'])
  else:
    example['SQLResult'] = example['SQLResult']
  updated_few_shot_examples.append(example)

# updated_few_shot_examples

In [23]:
len(updated_few_shot_examples)

17

In [24]:
to_vectorize = [" ".join(example.values()) for example in few_shot_examples]

In [25]:
from langchain.vectorstores import Chroma
vectorstore = Chroma.from_texts(to_vectorize, embeddings, metadatas=updated_few_shot_examples)

In [26]:
from langchain.prompts import SemanticSimilarityExampleSelector

example_selector = SemanticSimilarityExampleSelector(
    vectorstore=vectorstore,
    k=2,
)

In [27]:
example_selector.select_examples({"Question": "what is the total spending for user 1"})

[{'Answer': 'Total amount spent by user 2 for food and groceries is ₹1634.16',
  'Question': 'what is the total amount spent by user 2 for food and groceries?',
  'SQLQuery': "SELECT SUM(transactions.amount::FLOAT) AS total_amount FROM transactions JOIN category ON transactions.category_id = category.category_id JOIN credit_card ON transactions.card_id = credit_card.card_id JOIN users ON credit_card.user_id = users.user_id WHERE category.category_name ILIKE '%food%and%groceries%' AND users.user_id = 2 AND transactions.transaction_status ILIKE '%Completed%' AND transactions.transaction_type ILIKE '%debit%';",
  'SQLResult': '[(1634.1600036621094,)]'},
 {'Answer': "Total amount spent in 'Movies and Entertainment' category by User1 is ₹1664.78",
  'Question': "Find the total amount spent in 'Movies and Entertainment' category by User 1.",
  'SQLQuery': "SELECT SUM(transactions.amount::FLOAT) AS total_amount FROM transactions JOIN category ON transactions.category_id = category.category_id

### Setting up PromptTemplete using input variables

In [31]:
from langchain.prompts.prompt import PromptTemplate
from langchain.prompts import FewShotPromptTemplate
from langchain.chains.sql_database.prompt import POSTGRES_PROMPT, PROMPT_SUFFIX

example_prompt = PromptTemplate(
    input_variables=["Question", "SQLQuery", "SQLResult","Answer",],
    template="\nQuestion: {Question}\nSQLQuery: {SQLQuery}\nSQLResult: {SQLResult}\nAnswer: {Answer}",
)

In [32]:
few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    # prefix= prefix,
    suffix=PROMPT_SUFFIX,
    input_variables=["input", "table_info", "top_k"], #These variables are used in the prefix and suffix
)

db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, prompt=few_shot_prompt)
db_chain_sql = SQLDatabaseChain.from_llm(llm, db, verbose=True, prompt=few_shot_prompt, return_sql=True)

In [None]:
db_chain_sql.run("Who user is having more spending user1 or user2")

# Test

In [None]:
with open('text_sql/few_shot_examples', 'rb') as f:
  examples = pickle.load(f)

examples[0]

{'Question': 'List all the cards going to expire by sep 2026',
 'SQLQuery': "SELECT CardNumber FROM credit_card WHERE ExpiryDate <= '2026-09-30';",
 'SQLResult': [('5040000000001',),
  ('35300000000000000',),
  ('52300000000000000',)],
 'Answer': '5040000000001, 35300000000000000, 52300000000000000'}

In [None]:
failed_items = []
for i, item in enumerate(examples):
    que = item['Question']
    print(que)
    print('----------------------------------')
    print(item['Answer'])
    print('----------------------------------')
    try:
      model_output = db_chain.run(que)
    except:
      failed_items.append(que)
      print('\nModel failed')

    print('**********************************')
    print('\n')

In [None]:
def sql_agent(query):
  try:
    return SQLDatabaseChain.from_llm(llm, db, verbose=True, prompt=few_shot_prompt).run(query)
  except:
    return SQLDatabaseChain.from_llm(llm, db, verbose=True, prompt=few_shot_prompt, return_sql=True).run(query)

In [None]:
sql_agent("what is the category name for category id 5")



[1m> Entering new SQLDatabaseChain chain...[0m
what is the category name for category id 5
SQLQuery:



[32;1m[1;3mSELECT category_name FROM category WHERE category_id = 5;[0m
SQLResult: [33;1m[1;3m[('Loan',)][0m
Answer:



[32;1m[1;3mLoan[0m
[1m> Finished chain.[0m


'Loan'

# UI

In [82]:
def respond(message, user_id, chat_history):
    # message = message.replace('i', 'we').replace('my', 'the')
    message = str(message) + ' ' + 'for user_id='+ user_id
    # bot_message = sql_agent(message)
    sql_output = message
    chat_history.append((message, sql_output))
    return "", user_id, chat_history

with gr.Blocks() as demo:
    gr.Markdown("Text to SQL agent")
    chat_history = gr.Chatbot(bubble_full_width=False, label="Conversations", show_copy_button=True, layout="bubble", visible=True)
    user_id = gr.Textbox(label = "User Id", placeholder = "Enter user id here")
    question = gr.Textbox(label = "Question", placeholder = "Ask your question here")

    clear = gr.ClearButton([question, chat_history, user_id])

    question.submit(respond, [question, user_id, chat_history], [question, user_id, chat_history])

if __name__ == "__main__":
    demo.launch(debug=True)

Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Running on public URL: https://4aae523ac09569b937.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


KeyboardInterrupt: ignored