<a href="https://colab.research.google.com/github/loni9164/text_sql/blob/main/sql_starcoder_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install libraries

In [None]:
# Print the NVIDIA driver/CUDA versions and per-GPU memory usage of this runtime.
!nvidia-smi

Mon Jan  8 07:04:29 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla V100-DGXS-16GB           On  | 00000000:05:00.0 Off |                    0 |
| N/A   29C    P0              51W / 300W |  10942MiB / 16384MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  Tesla V100-DGXS-16GB           On  | 00000000:06:00.0 Off |  

In [None]:
!pip install langchain langchain-experimental
!pip install -q  langchain
!pip install sentence-transformers
!pip install chromadb

!pip3 install transformers optimum
!pip3 install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu122/  # Use cu117 if on CUDA 11.7
!pip install gradio

Collecting langchain
  Downloading langchain-0.1.0-py3-none-any.whl.metadata (13 kB)
Collecting langchain-experimental
  Downloading langchain_experimental-0.0.47-py3-none-any.whl.metadata (1.9 kB)
Collecting SQLAlchemy<3,>=1.4 (from langchain)
  Downloading SQLAlchemy-2.0.25-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.6 kB)
Collecting aiohttp<4.0.0,>=3.8.3 (from langchain)
  Downloading aiohttp-3.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.4 kB)
Collecting async-timeout<5.0.0,>=4.0.0 (from langchain)
  Downloading async_timeout-4.0.3-py3-none-any.whl.metadata (4.2 kB)
Collecting dataclasses-json<0.7,>=0.5.7 (from langchain)
  Downloading dataclasses_json-0.6.3-py3-none-any.whl.metadata (25 kB)
Collecting jsonpatch<2.0,>=1.33 (from langchain)
  Downloading jsonpatch-1.33-py2.py3-none-any.whl.metadata (3.0 kB)
Collecting langchain-community<0.1,>=0.0.9 (from langchain)
  Downloading langchain_community-0.0.10-py3-none-any.whl.met

In [None]:
# Install the PostgreSQL server through apt, then the psycopg2 driver that the
# Python cells use to talk to it.
!sudo apt-get -qq update
!sudo apt-get -qq -y install postgresql
!pip install psycopg2-binary

[0m

# Imports

In [None]:
import psycopg2
import sqlite3
import time
import pickle
import re

# from langchain.llms import CTransformers
from langchain import PromptTemplate, LLMChain
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.utilities import SQLDatabase
from langchain.prompts import PromptTemplate
from langchain_experimental.sql import SQLDatabaseChain
import gradio as gr

# Create PostgreSQL

In [None]:
# Start the local PostgreSQL service.
!service postgresql start
# Set the password of the built-in `postgres` role; the psycopg2.connect(...) calls
# in this notebook log in with this same value.
!sudo -u postgres psql -U postgres -c "ALTER USER postgres PASSWORD '12345';"
!service postgresql restart

 * Starting PostgreSQL 12 database server
   ...done.
ALTER ROLE
 * Restarting PostgreSQL 12 database server
   ...done.


In [None]:
# Admin session on the default "postgres" database. The password must match the
# one assigned to the postgres role. Autocommit is enabled on this connection
# before any statement is issued through its cursor.
conn = psycopg2.connect(dbname="postgres", user="postgres", password="12345", host="localhost")
conn.autocommit = True
cursor = conn.cursor()

In [None]:
# Create the application database only when it is missing, so re-running this
# cell no longer fails with DuplicateDatabase (PostgreSQL has no
# CREATE DATABASE IF NOT EXISTS).
cursor.execute("SELECT 1 FROM pg_database WHERE datname = %s", ("credit_card_system",))
if cursor.fetchone() is None:
    cursor.execute("CREATE DATABASE credit_card_system")
cursor.close()
conn.close()

DuplicateDatabase: database "credit_card_system" already exists


In [None]:
# Open a session on the credit_card_system database and keep a cursor for the
# table-creation and data-loading cells.
conn = psycopg2.connect(dbname="credit_card_system", user="postgres", password="12345", host="localhost")
cursor = conn.cursor()

In [None]:
# DDL for the credit-card schema, listed parent-first so every FOREIGN KEY
# refers to a table that has already been created.
# IF NOT EXISTS makes the cell safe to re-run against an existing database.
create_table_statements = [
    """
    CREATE TABLE IF NOT EXISTS branch (
        branch_id INTEGER,
        branch_name TEXT,
        branch_address TEXT,
        branch_phone TEXT,
        branch_manager TEXT,
        branch_email TEXT,
        established_date DATE,
        number_of_employees INTEGER,
        PRIMARY KEY (branch_id)
    )
    """,
    """
    CREATE TABLE IF NOT EXISTS category (
        category_id INTEGER,
        category_name TEXT,
        PRIMARY KEY (category_id)
    )
    """,

    """
    CREATE TABLE IF NOT EXISTS users (
        user_id INTEGER,
        user_name TEXT,
        user_email TEXT,
        user_address TEXT,
        user_phone TEXT,
        date_of_birth DATE,
        registration_date DATE,
        status TEXT,
        branch_id INTEGER,
        PRIMARY KEY (user_id),
        FOREIGN KEY(branch_id) REFERENCES branch (branch_id)
    )
    """,

    """
    CREATE TABLE IF NOT EXISTS credit_card (
        card_id INTEGER,
        user_id INTEGER,
        card_number TEXT,
        card_type TEXT,
        expiry_date DATE,
        cvv INTEGER,
        issue_date DATE,
        total_credit_limit REAL,
        current_outstanding_amount REAL,
        remaining_credit_limit REAL,
        total_amount_due REAL,
        minimum_amount_due REAL,
        statement_date DATE,
        amount_due_on DATE,
        control_limit REAL,
        PRIMARY KEY (card_id),
        FOREIGN KEY(user_id) REFERENCES users (user_id)
    )
    """,

    """
    CREATE TABLE IF NOT EXISTS transactions (
        transaction_id INTEGER,
        card_id INTEGER,
        transaction_date DATE,
        amount REAL,
        merchant TEXT,
        category_id INTEGER,
        transaction_type TEXT,
        description TEXT,
        PRIMARY KEY (transaction_id),
        FOREIGN KEY(card_id) REFERENCES credit_card (card_id),
        FOREIGN KEY(category_id) REFERENCES category (category_id)
    )
    """,

    """
    CREATE TABLE IF NOT EXISTS credit_card_financial (
        financial_id INTEGER,
        card_id INTEGER,
        overdue_charges REAL,
        loan_amount REAL,
        emi_amount REAL,
        emi_due_date DATE,
        interest_rate REAL,
        payment_due_date DATE,
        minimum_payment REAL,
        PRIMARY KEY (financial_id),
        FOREIGN KEY(card_id) REFERENCES credit_card (card_id)
    )
    """,

    """
    CREATE TABLE IF NOT EXISTS reward (
        reward_id INTEGER,
        transaction_id INTEGER,
        points_earned INTEGER,
        points_redeemed INTEGER,
        current_balance INTEGER,
        PRIMARY KEY (reward_id),
        FOREIGN KEY(transaction_id) REFERENCES transactions (transaction_id)
    )
    """
]


# Run the statements in list order and persist them in a single commit.
for statement in create_table_statements:
    cursor.execute(statement)

conn.commit()

In [None]:
import pandas as pd

# Function to load data from CSV to a table
def load_csv_to_table(csv_file_path, table_name):
    """Insert every row of a CSV file into ``table_name`` via the global ``cursor``.

    Values are inserted positionally, so the CSV column order must match the
    column order of the target table. Missing cells are sent as SQL NULL.
    The function does not commit; the caller decides when to commit.

    Args:
        csv_file_path: Path of the CSV file, read with ``pd.read_csv``.
        table_name: Name of the target table. It is interpolated into the SQL
            text as-is, so it must come from trusted code, never from user input.
    """
    data = pd.read_csv(csv_file_path)
    # The statement is identical for every row, so build it once.
    insert_query = "INSERT INTO {} VALUES %s".format(table_name)
    # astype(object) boxes numpy scalars into native Python values; psycopg2
    # cannot adapt numpy.int64 / numpy.float64, which all-numeric rows would yield.
    for row in data.astype(object).itertuples(index=False, name=None):
        # pandas marks empty cells as NaN; None makes psycopg2 send NULL instead.
        values = tuple(None if pd.isna(value) else value for value in row)
        cursor.execute(insert_query, (values,))

In [None]:
!git clone https://github.com/loni9164/text_sql.git

Cloning into 'text_sql'...
remote: Enumerating objects: 133, done.[K
remote: Counting objects: 100% (133/133), done.[K
remote: Compressing objects: 100% (115/115), done.[K
remote: Total 133 (delta 73), reused 41 (delta 16), pack-reused 0[K
Receiving objects: 100% (133/133), 557.82 KiB | 5.69 MiB/s, done.
Resolving deltas: 100% (73/73), done.


In [None]:
# Load the CSV fixtures parent-first so every foreign key finds its referenced row.
# Tables that already hold rows are skipped, which keeps the cell re-runnable
# (a second bulk insert fails with UniqueViolation on the primary keys).
csv_tables = [
    'branch',
    'category',
    'users',
    'credit_card',
    'transactions',
    'credit_card_financial',
    'reward',
]

try:
    for table_name in csv_tables:
        cursor.execute("SELECT EXISTS (SELECT 1 FROM {})".format(table_name))
        if cursor.fetchone()[0]:
            continue
        load_csv_to_table('text_sql/csv_files/{}.csv'.format(table_name), table_name)
except Exception:
    # Undo the partial load so the connection is not left in an aborted transaction.
    conn.rollback()
    raise
else:
    conn.commit()

UniqueViolation: duplicate key value violates unique constraint "branch_pkey"
DETAIL:  Key (branch_id)=(1) already exists.


# DB connection

In [None]:
# Session on the credit_card_system database; query_db() runs its statements on it.
conn = psycopg2.connect(dbname="credit_card_system", user="postgres", password="12345", host="localhost")

In [None]:
def query_db(query):
    """Execute ``query`` on the global ``conn`` and return its rows.

    Whatever transaction is pending on the connection is rolled back first, so
    an earlier failed statement cannot block this one. The function never
    commits, so changes made by ``query`` are discarded by the next call.

    Args:
        query: SQL text to execute.

    Returns:
        The list of row tuples from ``fetchall()`` when the statement produced
        a result set, otherwise the empty string ``""`` (e.g. for an UPDATE).
        Errors raised while executing or fetching propagate to the caller.
    """
    conn.rollback()
    # The context manager closes the cursor even when execute() raises.
    with conn.cursor() as cursor:
        cursor.execute(query)
        # description is None when the statement returned no result set; checking
        # it replaces the former bare `except:` that hid every kind of error.
        if cursor.description is None:
            return ""
        return cursor.fetchall()

query_db('SELECT user_id FROM users LIMIT 5;')

[(3,), (4,), (5,), (6,), (7,)]

In [None]:
# Connection settings for the LangChain SQLDatabase wrapper.
db_user, db_password = "postgres", "12345"
db_host, db_name = "localhost", "credit_card_system"

cursor = conn.cursor()

# Assemble the database URI, then let LangChain open the database from it.
connection_string = "postgresql://{}:{}@{}/{}".format(db_user, db_password, db_host, db_name)
db = SQLDatabase.from_uri(connection_string)
# Keep the wrapper's table description at hand for inspection.
table_info = db.table_info

# Load model

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Hugging Face repository of the checkpoint. `revision` selects the branch to
# download, e.g. revision="gptq-4bit-128g-actorder_True".
model_name_or_path = "TheBloke/sqlcoder2-GPTQ"
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    revision="main",
    trust_remote_code=False,
    device_map="auto",
)

In [None]:
from auto_gptq import exllama_set_max_input_length

# Load the fast tokenizer of the same checkpoint and set the exllama maximum
# input length of the model to 10000.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
model = exllama_set_max_input_length(model, max_input_length=10000)

In [None]:
from transformers import GenerationConfig

# Sampling settings; the text-generation pipeline built later in this notebook
# is given the same values as explicit keyword arguments.
generation_config = GenerationConfig(
    do_sample=True,
    temperature=0.1,
    top_k=40,
    top_p=0.95,
    repetition_penalty=1.1,
    max_new_tokens=512,
)

In [None]:
# print("\n\n*** Generate:")



# Inference can also be done using transformers' pipeline

# print("*** Pipeline:")
# pipe = pipeline(
#     "text-generation",
#     model=model,
#     tokenizer=tokenizer,
#     generation_config=generation_config
# )


# print(pipe(prompt_template)[0]['generated_text'])

In [None]:
# print("\n\n*** Generate:")

# input_ids = tokenizer(prompt_template, return_tensors='pt').input_ids.cuda()
# output = model.generate(inputs=input_ids, temperature=0, do_sample=True, top_p=0.95, top_k=40, max_new_tokens=512)
# print(tokenizer.decode(output[0]))

# Inference can also be done using transformers' pipeline

# print("*** Pipeline:")
# Text-generation pipeline around the loaded model and tokenizer, with the
# sampling parameters passed directly.
pipe = pipeline(
    task="text-generation",
    model=model,
    tokenizer=tokenizer,
    do_sample=True,
    temperature=0.1,
    top_k=40,
    top_p=0.95,
    repetition_penalty=1.1,
    max_new_tokens=512,
)

# print(pipe(prompt_template)[0]['generated_text'])

In [None]:
from langchain.llms import HuggingFacePipeline
from langchain import PromptTemplate, LLMChain
# Wrap the transformers pipeline so that LangChain chains can use it as their LLM.
llm = HuggingFacePipeline(pipeline=pipe)

# Few shot learning

We will use few shot learning to fix issues we have seen so far

### Creating Semantic Similarity Based example selector

- Create embeddings for the few-shot examples
- Store the embeddings in Chroma DB
- Retrieve the most semantically similar examples from the vector store

In [None]:
from langchain.embeddings import HuggingFaceEmbeddings

# Sentence-embedding model used to vectorise the few-shot examples.
embeddings = HuggingFaceEmbeddings(model_name="BAAI/bge-large-en-v1.5")

In [None]:
few_shot_examples = [
   {
      "Question":"Can you update my email id to loni12b@gmail.com for user_id = 1",
      "SQLQuery":"UPDATE users SET user_email = 'loni12b@gmail.com' WHERE user_id = 1;",
      "SQLResult":"",
      "Answer":"Your email has been updated to loni12b@gmail.com"
   },
   {
      "Question":"Can you update my name to loni for user_id = 2",
      "SQLQuery":"UPDATE users SET user_name = 'loni' WHERE user_id = 2;",
      "SQLResult":"",
      "Answer":"Your name has been updated to loni"
   },
   {
      "Question":"How many transactions do i have for user_id = 1?",
      "SQLQuery":"SELECT COUNT(*) AS transaction_count FROM transactions JOIN credit_card ON transactions.card_id = credit_card.card_id WHERE credit_card.user_id = 1;",
      "SQLResult":"[(507,)]",
      "Answer":"You have 507 transactions."
   },
   {
      "Question":"What is my total credit limit for user_id = 1?",
      "SQLQuery":"SELECT SUM(total_credit_limit) AS total_credit_limit FROM credit_card WHERE user_id = 1;",
      "SQLResult":"[(40000.0,)]",
      "Answer":"Your total credit limit is 40000"
   },
   {
      "Question":"Can I see my last 5 transactions for user_id = 2?",
      "SQLQuery":"SELECT * FROM transactions JOIN credit_card ON transactions.card_id = credit_card.card_id WHERE credit_card.user_id = 2 ORDER BY transaction_date DESC LIMIT 5;",
      "SQLResult":"[(1565, 4, datetime.date(2023, 11, 30), 352.09, 'Evans, Norman and Howe', 4, 'Credit', 'Student language chance scene rock sort building need.', 4, 2, '6360000000000000.0', 'VISA 19 digit', datetime.date(2031, 5, 31), 660, datetime.date(2022, 11, 1), 20000.0, 3431.64, 16568.36, 3431.64, 204.0, datetime.date(2023, 12, 23), datetime.date(2023, 12, 13), 501.0), (1308, 6, datetime.date(2023, 11, 30), 362.46, 'Brown-Bailey', 3, 'Credit', 'Road early may whom.', 6, 2, '4860000000000000.0', 'VISA 13 digit', datetime.date(2033, 3, 31), 744, datetime.date(2021, 6, 16), 5000.0, 1814.35, 3185.65, 1814.35, 206.0, datetime.date(2023, 12, 25), datetime.date(2023, 12, 15), 501.0), (1442, 4, datetime.date(2023, 11, 30), 230.35, 'Jenkins PLC', 6, 'Credit', 'Red president anyone.', 4, 2, '6360000000000000.0', 'VISA 19 digit', datetime.date(2031, 5, 31), 660, datetime.date(2022, 11, 1), 20000.0, 3431.64, 16568.36, 3431.64, 204.0, datetime.date(2023, 12, 23), datetime.date(2023, 12, 13), 501.0), (1189, 5, datetime.date(2023, 11, 30), 8.57, 'Mills Inc', 10, 'Credit', 'Serious career thought on exactly director.', 5, 2, '7860000000000000.0', 'Diners Club / Carte Blanche', datetime.date(2026, 9, 30), 758, datetime.date(2018, 12, 8), 15000.0, 66.25, 14933.75, 66.25, 205.0, datetime.date(2023, 12, 24), datetime.date(2023, 12, 14), 501.0), (1580, 4, datetime.date(2023, 11, 29), 347.8, 'Price, Contreras and Gomez', 6, 'Credit', 'Edge run writer perform between.', 4, 2, '6360000000000000.0', 'VISA 19 digit', datetime.date(2031, 5, 31), 660, datetime.date(2022, 11, 1), 20000.0, 3431.64, 16568.36, 3431.64, 204.0, datetime.date(2023, 12, 23), datetime.date(2023, 12, 13), 501.0)]",
      "Answer":"Here are your last 5 transactions, [(1565, 4, datetime.date(2023, 11, 30), 352.09, 'Evans, Norman and Howe', 4, 'Credit', 'Student language chance scene rock sort building need.', 4, 2, '6360000000000000.0', 'VISA 19 digit', datetime.date(2031, 5, 31), 660, datetime.date(2022, 11, 1), 20000.0, 3431.64, 16568.36, 3431.64, 204.0, datetime.date(2023, 12, 23), datetime.date(2023, 12, 13), 501.0), (1308, 6, datetime.date(2023, 11, 30), 362.46, 'Brown-Bailey', 3, 'Credit', 'Road early may whom.', 6, 2, '4860000000000000.0', 'VISA 13 digit', datetime.date(2033, 3, 31), 744, datetime.date(2021, 6, 16), 5000.0, 1814.35, 3185.65, 1814.35, 206.0, datetime.date(2023, 12, 25), datetime.date(2023, 12, 15), 501.0), (1442, 4, datetime.date(2023, 11, 30), 230.35, 'Jenkins PLC', 6, 'Credit', 'Red president anyone.', 4, 2, '6360000000000000.0', 'VISA 19 digit', datetime.date(2031, 5, 31), 660, datetime.date(2022, 11, 1), 20000.0, 3431.64, 16568.36, 3431.64, 204.0, datetime.date(2023, 12, 23), datetime.date(2023, 12, 13), 501.0), (1189, 5, datetime.date(2023, 11, 30), 8.57, 'Mills Inc', 10, 'Credit', 'Serious career thought on exactly director.', 5, 2, '7860000000000000.0', 'Diners Club / Carte Blanche', datetime.date(2026, 9, 30), 758, datetime.date(2018, 12, 8), 15000.0, 66.25, 14933.75, 66.25, 205.0, datetime.date(2023, 12, 24), datetime.date(2023, 12, 14), 501.0), (1580, 4, datetime.date(2023, 11, 29), 347.8, 'Price, Contreras and Gomez', 6, 'Credit', 'Edge run writer perform between.', 4, 2, '6360000000000000.0', 'VISA 19 digit', datetime.date(2031, 5, 31), 660, datetime.date(2022, 11, 1), 20000.0, 3431.64, 16568.36, 3431.64, 204.0, datetime.date(2023, 12, 23), datetime.date(2023, 12, 13), 501.0)]."
   },
   {
      "Question":"How many transactions do I have in category 2 for user_id = 1?",
      "SQLQuery":"SELECT COUNT(*) AS transaction_count FROM transactions JOIN credit_card ON transactions.card_id = credit_card.card_id WHERE credit_card.user_id = 1 AND transactions.category_id = 2;",
      "SQLResult":"[(52,)]",
      "Answer":"You have 52 transactions in category 2."
   },
   {
      "Question":"How many transactions do I have in category 4 for user_id = 1?",
      "SQLQuery":"SELECT COUNT(*) AS transaction_count FROM transactions JOIN credit_card ON transactions.card_id = credit_card.card_id WHERE credit_card.user_id = 1 AND transactions.category_id = 4;",
      "SQLResult":"[(46,)]",
      "Answer":"You have 46 transactions in category 4."
   },
   {
      "Question":"How many transactions do I have in category 5 for user_id = 2?",
      "SQLQuery":"SELECT COUNT(*) AS transaction_count FROM transactions JOIN credit_card ON transactions.card_id = credit_card.card_id WHERE credit_card.user_id = 2 AND transactions.category_id = 5;",
      "SQLResult":"[(61,)]",
      "Answer":"You have 61 transactions in category 5."
   },
    {
      "Question":"What are my credit card numbers for user_id = 1?",
      "SQLQuery":"SELECT card_number FROM credit_card WHERE user_id = 1;",
      "SQLResult":"[('5360000000000000.0',), ('4360000000000000.0',), ('7360000000000000.0',)]",
      "Answer":"Your credit card numbers are 5360000000000000, 4360000000000000, 7360000000000000."
   },
   {
      "Question":"What types of credit cards do I hold for user_id = 2?",
      "SQLQuery":"SELECT card_type FROM credit_card WHERE user_id = 2;",
      "SQLResult":"[('VISA 19 digit',), ('Diners Club / Carte Blanche',), ('VISA 13 digit',)]",
      "Answer":"You hold the following types of credit cards: VISA 19 digit, Diners Club / Carte Blanche, VISA 13 digit."
   },
   {
      "Question":"When do my credit cards expire for user_id = 2?",
      "SQLQuery":"SELECT card_number, expiry_date FROM credit_card WHERE user_id = 2;",
      "SQLResult":"[('6360000000000000.0', '2031-05-31'), ('7860000000000000.0', '2026-09-30'), ('4860000000000000.0', '2033-03-31')]",
      "Answer":"Your credit cards expire on the following dates: 6360000000000000 on 31-05-2031, 7860000000000000 on 30-09-2026, 4860000000000000 on 31-03-2033."
   },
   {
      "Question":"When were my credit cards issued for user_id = 1?",
      "SQLQuery":"SELECT card_number, issue_date FROM credit_card WHERE user_id = 1;",
      "SQLResult":"[('5360000000000000.0', '2017-04-28'), ('4360000000000000.0', '2019-11-22'), ('7360000000000000.0', '2014-11-10')]",
      "Answer":"Your credit cards were issued on the following dates: 5360000000000000 on 28-04-2017, 4360000000000000 on 22-11-2019, 7360000000000000 on 10-11-2014."
   },
   {
      "Question":"What is my total credit limit across all cards for user_id = 1?",
      "SQLQuery":"SELECT SUM(total_credit_limit) FROM credit_card WHERE user_id = 1;",
      "SQLResult":"[(40000.0,)]",
      "Answer":"Your total credit limit across all cards is ₹40000.0."
   },
   {
      "Question":"How much do I currently owe on my credit cards for user_id = 2?",
      "SQLQuery":"SELECT card_number, current_outstanding_amount FROM credit_card WHERE user_id = 2;",
      "SQLResult":"[('6360000000000000.0', 3431.64), ('7860000000000000.0', 66.25), ('4860000000000000.0', 1814.35)]",
      "Answer":"You currently owe ₹3431.64 on card 6360000000000000, ₹66.25 on card 7860000000000000, and ₹1814.35 on card 4860000000000000."
   },
   {
      "Question":"What is the remaining credit limit on my cards for user_id = 1?",
      "SQLQuery":"SELECT card_number, remaining_credit_limit FROM credit_card WHERE user_id = 1;",
      "SQLResult":"[('5360000000000000.0', 4034.3), ('4360000000000000.0', 10728.85), ('7360000000000000.0', 16458.87)]",
      "Answer":"The remaining credit limits on your cards are ₹4034.3 for card 5360000000000000, ₹10728.85 for card 4360000000000000, and ₹16458.87 for card 7360000000000000."
   },
   {
      "Question":"What is the minimum amount I need to pay on each card for user_id = 2?",
      "SQLQuery":"SELECT card_number, minimum_amount_due FROM credit_card WHERE user_id = 2;",
      "SQLResult":"[('6360000000000000.0', 204.0), ('7860000000000000.0', 205.0), ('4860000000000000.0', 206.0)]",
      "Answer":"The minimum amounts you need to pay are ₹204.0 for card 6360000000000000, ₹205.0 for card 7860000000000000, and ₹206.0 for card 4860000000000000."
   },
   {
      "Question":"When is my next statement date for each card for user_id = 1?",
      "SQLQuery":"SELECT card_number, statement_date FROM credit_card WHERE user_id = 1;",
      "SQLResult":"[('5360000000000000.0', '2023-12-20'), ('4360000000000000.0', '2023-12-21'), ('7360000000000000.0', '2023-12-22')]",
      "Answer":"Your next statement dates for each card are: for card 5360000000000000 on 20-12-2023, for card 4360000000000000 on 21-12-2023, and for card 7360000000000000 on 22-12-2023."
   },
   {
      "Question":"By when do I need to pay my credit card bills for user_id = 2?",
      "SQLQuery":"SELECT card_number, amount_due_on FROM credit_card WHERE user_id = 2;",
      "SQLResult":"[('6360000000000000.0', '2023-12-13'), ('7860000000000000.0', '2023-12-14'), ('4860000000000000.0', '2023-12-15')]",
      "Answer":"You need to pay your credit card bills by: for card 6360000000000000 on 13-12-2023, for card 7860000000000000 on 14-12-2023, and for card 4860000000000000 on 15-12-2023."
   },
   {
      "Question":"What is my branch address for user_id = 1?",
      "SQLQuery":"SELECT branch_address FROM branch WHERE branch_id = (SELECT branch_id FROM users WHERE user_id = 1);",
      "SQLResult":"[('178 Young Neck Suite 826, New Jennifer, FL 64057',)]",
      "Answer":"The address of your branch is 178 Young Neck Suite 826, New Jennifer, FL 64057."
   },
   {
      "Question":"What is my branch name  for user_id = 2?",
      "SQLQuery":"SELECT branch_name FROM branch WHERE branch_id = (SELECT branch_id FROM users WHERE user_id = 2);",
      "SQLResult":"[('Branch 2',)]",
      "Answer":"The name of your branch is Branch 2."
   },
   {
      "Question":"What is my branch email id for user_id = 1?",
      "SQLQuery":"SELECT branch_email FROM branch WHERE branch_id = (SELECT branch_id FROM users WHERE user_id = 1);",
      "SQLResult":"[('kristen70@morgan.com',)]",
      "Answer":"The email ID of your branch is kristen70@morgan.com."
   },
   {
      "Question":"When did the branch associated with my account open for user_id = 1?",
      "SQLQuery":"SELECT established_date FROM branch WHERE branch_id = (SELECT branch_id FROM users WHERE user_id = 1);",
      "SQLResult":"[('14-10-2008',)]",
      "Answer":"The branch associated with your account was opened on 14-10-2008."
   },
   {
      "Question":"How many employees do we have in the branch associated with my account for user_id = 2?",
      "SQLQuery":"SELECT number_of_employees FROM branch WHERE branch_id = (SELECT branch_id FROM users WHERE user_id = 2);",
      "SQLResult":"[(39,)]",
      "Answer":"The branch associated with your account has 39 employees."
   },
   {
      "Question":"What is my email id?",
      "SQLQuery":"SELECT user_email FROM users WHERE user_id = 1;",
      "SQLResult":"[('jacobmoore@hotmail.com',)]",
      "Answer":"Your email ID is jacobmoore@hotmail.com."
   },
   {
      "Question":"When did I open my account?",
      "SQLQuery":"SELECT registration_date FROM users WHERE user_id = 1;",
      "SQLResult":"[('27-01-2020',)]",
      "Answer":"You opened your account on 27-01-2020."
   },
   {
      "Question":"What is my account status?",
      "SQLQuery":"SELECT status FROM users WHERE user_id = 1;",
      "SQLResult":"[('Active',)]",
      "Answer":"Your account status is Active."
   },
   {
      "Question":"What is my registered phone number?",
      "SQLQuery":"SELECT user_phone FROM users WHERE user_id = 1;",
      "SQLResult":"[('2527317877',)]",
      "Answer":"Your registered phone number is 2527317877.",
   },
   {
      "Question":"What is my total spending on fuel for user_id = 1?",
      "SQLQuery":"SELECT SUM(transactions.amount) AS total_amount FROM transactions JOIN category ON transactions.category_id = category.category_id JOIN credit_card ON transactions.card_id = credit_card.card_id WHERE category.category_name ILIKE '%fuel%' AND transactions.transaction_type ILIKE '%Debit%' AND credit_card.user_id = 1;",
      "SQLResult":"[(8282.859,)]",
      "Answer":"Your total spending on fuel for completed transactions is ₹8282.86."
   },
   {
      "Question":"Find my oldest credit card for user_id = 1.",
      "SQLQuery":"SELECT card_number FROM credit_card WHERE user_id = 1 ORDER BY issue_date ASC LIMIT 1;",
      "SQLResult":"[('7360000000000000.0',)]",
      "Answer":"Your oldest credit card is the one with the number 7360000000000000."
   },
   {
      "Question":"List my cards which are going to expire by Sep 2026 for user_id = 2",
      "SQLQuery":"SELECT card_number FROM credit_card WHERE user_id = 2 AND expiry_date <= '2026-09-30';",
      "SQLResult":"[('7860000000000000.0',)]",
      "Answer":"Your card number which is going to expire by September 2026 is 7860000000000000."
   },
   {
      "Question":"What is my spending for food and groceries for user_id = 2?",
      "SQLQuery":"SELECT SUM(transactions.amount) AS total_amount FROM transactions JOIN category ON transactions.category_id = category.category_id JOIN credit_card ON transactions.card_id = credit_card.card_id WHERE category.category_name ILIKE '%food%and%groceries%' AND credit_card.user_id = 2 AND transactions.transaction_type ILIKE '%debit%';",
      "SQLResult":"[(6724.62,)]",
      "Answer":"Your total spending for food and groceries is ₹6724.62."
   },
   {
      "Question":"List my cards which are expiring in 2024 for user_id = 1.",
      "SQLQuery":"SELECT card_number FROM credit_card WHERE user_id = 1 AND EXTRACT(YEAR FROM expiry_date) = 2024;",
      "SQLResult":"[('5360000000000000.0',)]",
      "Answer":"Your card number which is expiring in 2024 is 5360000000000000."
   },
   {
      "Question":"What is my spending on Movies and Entertainment for user_id = 2?",
      "SQLQuery":"SELECT SUM(transactions.amount) AS total_amount FROM transactions JOIN category ON transactions.category_id = category.category_id JOIN credit_card ON transactions.card_id = credit_card.card_id WHERE category.category_name ILIKE '%movies%and%entertainment%' AND credit_card.user_id = 2 AND transactions.transaction_type ILIKE '%debit%';",
      "SQLResult":"[(6335.5806,)]",
      "Answer":"Your total spending on Movies and Entertainment is ₹6335.58."
   }
]

In [None]:
len(few_shot_examples)

33

In [None]:
# Re-run every example query against the live database so that SQLResult shows
# the current data. The loop used to read `few_shot_examples2`, a name that is
# not defined in this notebook (NameError on a fresh kernel); it now reads
# `few_shot_examples`. Copies are built so the source list is not mutated.
few_shot_examples3 = [
    {**item, "SQLResult": query_db(item["SQLQuery"])}
    for item in few_shot_examples
]

In [None]:
print(few_shot_examples3)

[{'Question': 'What are my credit card numbers for user_id = 1?', 'SQLQuery': 'SELECT card_number FROM credit_card WHERE user_id = 1;', 'SQLResult': [('5360000000000000.0',), ('4360000000000000.0',), ('7360000000000000.0',)], 'Answer': 'Your credit card numbers are [List of Card Numbers].'}, {'Question': 'What types of credit cards do I hold for user_id = 2?', 'SQLQuery': 'SELECT card_type FROM credit_card WHERE user_id = 2;', 'SQLResult': [('VISA 19 digit',), ('Diners Club / Carte Blanche',), ('VISA 13 digit',)], 'Answer': 'You hold the following types of credit cards: [List of Card Types].'}, {'Question': 'When do my credit cards expire for user_id = 2?', 'SQLQuery': 'SELECT card_number, expiry_date FROM credit_card WHERE user_id = 2;', 'SQLResult': [('6360000000000000.0', datetime.date(2031, 5, 31)), ('7860000000000000.0', datetime.date(2026, 9, 30)), ('4860000000000000.0', datetime.date(2033, 3, 31))], 'Answer': 'Your credit cards expire on the following dates: [List of Expiry Date

In [None]:
# Make sure every SQLResult is a string: results that came back as something
# else (e.g. a list of row tuples) are replaced by their str() form, in place.
updated_few_shot_examples = []
for shot in few_shot_examples3:
    result = shot['SQLResult']
    if not isinstance(result, str):
        shot['SQLResult'] = str(result)
    updated_few_shot_examples.append(shot)

# updated_few_shot_examples

In [None]:
print(updated_few_shot_examples)

[{'Question': 'What are my credit card numbers for user_id = 1?', 'SQLQuery': 'SELECT card_number FROM credit_card WHERE user_id = 1;', 'SQLResult': "[('5360000000000000.0',), ('4360000000000000.0',), ('7360000000000000.0',)]", 'Answer': 'Your credit card numbers are [List of Card Numbers].'}, {'Question': 'What types of credit cards do I hold for user_id = 2?', 'SQLQuery': 'SELECT card_type FROM credit_card WHERE user_id = 2;', 'SQLResult': "[('VISA 19 digit',), ('Diners Club / Carte Blanche',), ('VISA 13 digit',)]", 'Answer': 'You hold the following types of credit cards: [List of Card Types].'}, {'Question': 'When do my credit cards expire for user_id = 2?', 'SQLQuery': 'SELECT card_number, expiry_date FROM credit_card WHERE user_id = 2;', 'SQLResult': "[('6360000000000000.0', datetime.date(2031, 5, 31)), ('7860000000000000.0', datetime.date(2026, 9, 30)), ('4860000000000000.0', datetime.date(2033, 3, 31))]", 'Answer': 'Your credit cards expire on the following dates: [List of Expir

In [None]:
len(updated_few_shot_examples)

8

In [None]:
# Build one text per example from the very list that is passed as `metadatas`
# to the vector store, so texts and metadata line up one-to-one. Building the
# texts from `few_shot_examples` could give a different length and order.
to_vectorize = [
    " ".join(str(value) for value in example.values())
    for example in updated_few_shot_examples
]

In [None]:
from langchain.vectorstores import Chroma

# Embed the example texts and index them in Chroma, with the example dicts
# attached as metadata.
vectorstore = Chroma.from_texts(
    texts=to_vectorize,
    embedding=embeddings,
    metadatas=updated_few_shot_examples,
)

In [None]:
from langchain.prompts import SemanticSimilarityExampleSelector

# Selector that returns the two stored examples closest to the incoming question.
example_selector = SemanticSimilarityExampleSelector(k=2, vectorstore=vectorstore)

In [None]:
example_selector.select_examples({"Question": "Can you update my name to loni for user_id = 2"})

[{'Answer': 'Your name has been updated to loni',
  'Question': 'Can you update my name to loni for user_id = 2',
  'SQLQuery': "UPDATE users SET user_name = 'loni' WHERE user_id = 2;",
  'SQLResult': ''},
 {'Answer': 'Your email has been updated to loni12b@gmail.com',
  'Question': 'Can you update my email id to loni12b@gmail.com for user_id = 1',
  'SQLQuery': "UPDATE users SET user_email = 'loni12b@gmail.com' WHERE user_id = 1;",
  'SQLResult': ''}]

### Setting up PromptTemplate using input variables

In [None]:
from langchain.prompts.prompt import PromptTemplate
from langchain.prompts import FewShotPromptTemplate
from langchain.chains.sql_database.prompt import POSTGRES_PROMPT, PROMPT_SUFFIX

# Template used to render one few-shot example. The placeholders match the keys
# of the example dicts (Question, SQLQuery, SQLResult, Answer).
example_prompt = PromptTemplate(
    template=(
        "\nQuestion: {Question}"
        "\nSQLQuery: {SQLQuery}"
        "\nSQLResult: {SQLResult}"
        "\nAnswer: {Answer}"
    ),
    input_variables=["Question", "SQLQuery", "SQLResult", "Answer"],
)

In [None]:
# Few-shot prompt: the examples picked by `example_selector` are rendered with
# `example_prompt` and followed by LangChain's PROMPT_SUFFIX.
# No prefix is supplied here (the `prefix=prefix` argument stays disabled).
few_shot_prompt = FewShotPromptTemplate(
    input_variables=["input", "table_info", "top_k"],  # These variables are used in the prefix and suffix
    suffix=PROMPT_SUFFIX,
    example_prompt=example_prompt,
    example_selector=example_selector,
)

# Two chains over the same LLM, database and prompt:
#   db_chain     -> run() gives the chain's final answer
#   db_chain_sql -> built with return_sql=True
db_chain = SQLDatabaseChain.from_llm(
    llm=llm,
    db=db,
    prompt=few_shot_prompt,
    verbose=True,
)
db_chain_sql = SQLDatabaseChain.from_llm(
    llm=llm,
    db=db,
    prompt=few_shot_prompt,
    verbose=True,
    return_sql=True,
)

In [None]:
# Manual end-to-end check: scope a free-text question to one user by appending
# the user id, echo the final question, then run it through the answer chain.
user_id = 1
message = "how much tax i paid"
message = f"{message} for user_id = {user_id}"
print(message)
db_chain.run(query=message, batch_size=1)

how much tax i paid for user_id = 1


[1m> Entering new SQLDatabaseChain chain...[0m
how much tax i paid for user_id = 1
SQLQuery:



[32;1m[1;3mSELECT SUM(amount::float::integer) AS total_tax_paid FROM transactions JOIN credit_card ON transactions.card_id = credit_card.card_id WHERE credit_card.user_id = 1 AND transactions.category_id::integer > 4;[0m
SQLResult: [33;1m[1;3m[(79103,)][0m
Answer:



[32;1m[1;3mYou have earned $79,103 in tax today![0m
[1m> Finished chain.[0m


'You have earned $79,103 in tax today!'

In [None]:
# Display the few-shot examples whose values were flattened into `to_vectorize`.
few_shot_examples

[{'Question': 'Can you update my email id to loni12b@gmail.com for user_id = 1',
  'SQLQuery': "UPDATE users SET user_email = 'loni12b@gmail.com' WHERE user_id = 1;",
  'SQLResult': '',
  'Answer': 'Your email has been updated to loni12b@gmail.com'},
 {'Question': 'Can you update my name to loni for user_id = 2',
  'SQLQuery': "UPDATE users SET user_name = 'loni' WHERE user_id = 2;",
  'SQLResult': '',
  'Answer': 'Your name has been updated to loni'},
 {'Question': 'How many transactions do i have for user_id = 1?',
  'SQLQuery': 'SELECT COUNT(*) AS transaction_count FROM transactions JOIN credit_card ON transactions.card_id = credit_card.card_id WHERE credit_card.user_id = 1;',
  'SQLResult': '[(507,)]',
  'Answer': 'You have 507 transactions.'},
 {'Question': 'What is my total credit limit for user_id = 1?',
  'SQLQuery': 'SELECT SUM(total_credit_limit) AS total_credit_limit FROM credit_card WHERE user_id = 1;',
  'SQLResult': '[(40000.0,)]',
  'Answer': 'Your total credit limit is 

# Test

In [None]:
# NOTE(review): this loader is disabled, yet the evaluation loop in the next
# cell iterates over `examples` -- confirm `examples` is defined elsewhere
# before running it on a fresh kernel.
# Caution: pickle.load can execute arbitrary code; only load trusted files.
# with open('text_sql/few_shot_examples', 'rb') as f:
#   examples = pickle.load(f)

# examples[0]

{'Question': 'List all the cards going to expire by sep 2026',
 'SQLQuery': "SELECT CardNumber FROM credit_card WHERE ExpiryDate <= '2026-09-30';",
 'SQLResult': [('5040000000001',),
  ('35300000000000000',),
  ('52300000000000000',)],
 'Answer': '5040000000001, 35300000000000000, 52300000000000000'}

In [None]:
# Run every stored example question through the answer chain and collect the
# questions for which the chain raised, so they can be inspected afterwards.
# NOTE(review): `examples` is not assigned in any active cell shown nearby
# (only in a commented-out loader) -- confirm it is defined before running.
failed_items = []
for item in examples:
    que = item['Question']
    print(que)
    print('----------------------------------')
    print(item['Answer'])  # reference answer stored with the example
    print('----------------------------------')
    try:
        model_output = db_chain.run(que)
    except Exception as exc:
        # Catch Exception rather than using a bare `except:` so that
        # KeyboardInterrupt / SystemExit can still stop a long evaluation run,
        # and report why the chain failed instead of hiding the cause.
        failed_items.append(que)
        print('\nModel failed')
        print(f'{type(exc).__name__}: {exc}')

    print('**********************************')
    print('\n')

In [None]:
def sql_agent(query):
  """Answer a natural-language question with the few-shot SQL chain.

  A fresh SQLDatabaseChain is built from the notebook-level `llm`, `db` and
  `few_shot_prompt` and run on `query`. If that run raises, the question is
  retried on a chain built with `return_sql=True`, so the caller gets the
  generated SQL text instead of a final answer.

  Args:
    query: the question to ask, as plain text.

  Returns:
    Whatever the chain's `run` returns: the answer on the normal path, or the
    output of the `return_sql=True` chain on the fallback path.

  Raises:
    Any exception raised by the fallback chain. KeyboardInterrupt and
    SystemExit are never intercepted.
  """
  try:
    return SQLDatabaseChain.from_llm(llm, db, verbose=True, prompt=few_shot_prompt).run(query)
  except Exception as exc:
    # Only ordinary errors trigger the fallback; a bare `except:` would also
    # swallow Ctrl-C and start a second LLM call instead of stopping.
    print(f"\nAnswer chain failed ({type(exc).__name__}: {exc}); falling back to return_sql=True")
    return SQLDatabaseChain.from_llm(llm, db, verbose=True, prompt=few_shot_prompt, return_sql=True).run(query)

In [None]:
# Smoke test of sql_agent with a simple single-table lookup question.
sql_agent("what is the category name for category id 5")



[1m> Entering new SQLDatabaseChain chain...[0m
what is the category name for category id 5
SQLQuery:



[32;1m[1;3mSELECT category_name FROM category WHERE category_id = 5;[0m
SQLResult: [33;1m[1;3m[('Loan',)][0m
Answer:



[32;1m[1;3mLoan[0m
[1m> Finished chain.[0m


'Loan'

# UI

In [None]:
def respond(message, user_id, chat_history):
    """Gradio submit handler: answer one question on behalf of one user.

    The user id is appended to the question as ``for user_id = <id>`` and the
    combined text is sent to `sql_agent`. The (question, answer) pair is then
    added to the chat history.

    Args:
        message: question text typed by the user.
        user_id: id of the user the question is about. It may be a string (as
            delivered by the textbox) or an int; surrounding whitespace is
            ignored. Only plain ASCII digits are accepted because the value is
            embedded verbatim in the prompt that produces the SQL.
        chat_history: list of (question, answer) tuples shown in the chatbot;
            ``None`` is treated as an empty history.

    Returns:
        A tuple ``("", user_id, chat_history)``: an empty string for the
        question box, the user id exactly as received, and the updated history.
    """
    if chat_history is None:
        chat_history = []

    user_id_text = str(user_id).strip()
    if not (user_id_text.isascii() and user_id_text.isdigit()):
        # Refuse to build a query from a missing or free-form id; nothing is
        # sent to the SQL agent in this case.
        chat_history.append((str(message), "Please enter a numeric user id."))
        return "", user_id, chat_history

    message = str(message) + ' ' + 'for user_id = ' + user_id_text
    sql_output = sql_agent(message)
    chat_history.append((message, sql_output))
    return "", user_id, chat_history

# Chat UI: a conversation pane, a user-id box and a question box.
with gr.Blocks() as demo:
    gr.Markdown("Text to SQL agent")

    chat_history = gr.Chatbot(
        label="Conversations",
        layout="bubble",
        bubble_full_width=False,
        show_copy_button=True,
        visible=True,
    )
    user_id = gr.Textbox(label="User Id", placeholder="Enter user id here")
    question = gr.Textbox(label="Question", placeholder="Ask your question here")

    # One button that resets the question, the conversation and the user id.
    clear = gr.ClearButton([question, chat_history, user_id])

    # Submitting the question box calls `respond`; its three return values are
    # written back to the same three components.
    question.submit(
        fn=respond,
        inputs=[question, user_id, chat_history],
        outputs=[question, user_id, chat_history],
    )

# Launch the Gradio app when this code runs as the main module.
# NOTE(review): share=True publishes the app on a public gradio.live URL (see
# the captured output below). The few-shot examples include UPDATE statements,
# so anyone with the link can presumably trigger writes and read user rows --
# confirm this exposure is intended before sharing.
if __name__ == "__main__":
    demo.launch(debug=True, share=True)

Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://bb1ad65894a47705ef.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)




[1m> Entering new SQLDatabaseChain chain...[0m
Can you update my email id to loni12b@gmail.com for user_id = 1
SQLQuery:



[32;1m[1;3mSELECT user_id::TEXT || ', '::TEXT || user_name::TEXT || ', '::TEXT || user_email::TEXT AS info FROM users WHERE user_id::TEXT ilike '%1%' ORDER BY length(user_id::TEXT);[0m
SQLResult: [33;1m[1;3m[('1, Gabrielle Anderson, loni12b@gmail.com',), ('61, Larry Scott, watsonjennifer@gmail.com',), ('71, Kathryn Landry, cody76@schneider.com',), ('10, Juan Bailey, sandersshelby@barber-allen.com',), ('81, Monica Gilbert, davidkey@holmes.com',), ('91, David Barker, alfred46@yahoo.com',), ('11, Dwayne Rogers, patriciawebster@gmail.com',), ('12, Lisa Roberson, elizabethwalters@burton.info',), ('13, Jessica Barber, ortegalisa@williams-munoz.com',), ('51, Courtney Norton, wilsongregory@yahoo.com',), ('15, Anthony White, harrissarah@yahoo.com',), ('16, Rick Johnson, michele09@graham-baker.net',), ('17, Joseph Garcia, stacey81@clay-perez.org',), ('18, Brenda Mckinney, xkim@gmail.com',), ('19, Shane Walker MD, jmoore@hotmail.com',), ('21, Kevin Lowe, tbrown@fields.com',), ('31, William C




[1m> Finished chain.[0m
Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://bb1ad65894a47705ef.gradio.live
