<a href="https://colab.research.google.com/github/loni9164/text_sql/blob/main/sql_starcoder_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Create PostgreSQL

In [None]:
# Install the PostgreSQL server from the distribution packages and start the
# service so the psycopg2 connections further down have a server to reach.
!apt-get -qq update
!apt-get -qq -y install postgresql
!service postgresql start

In [None]:
!pip install psycopg2-binary

In [None]:
!service postgresql restart

In [None]:
!sudo -u postgres psql

In [None]:
# ALTER USER postgres PASSWORD '12345';

In [None]:
import psycopg2

# Administrative session on the default "postgres" database.
conn = psycopg2.connect(
    dbname="postgres",
    user="postgres",
    password="12345",  # must match the password set for the postgres role
    host="localhost",
)
# Autocommit keeps statements out of a transaction block; PostgreSQL rejects
# the CREATE DATABASE issued through this connection inside one.
conn.autocommit = True
cursor = conn.cursor()

In [None]:
# Create the application database only when it is missing, so this step does
# not fail with "database already exists" when the notebook is run again
# against a server that already has it.
cursor.execute("SELECT 1 FROM pg_database WHERE datname = %s", ("credit_card_system",))
if cursor.fetchone() is None:
    cursor.execute("CREATE DATABASE credit_card_system")
# The administrative session is not used after this point.
cursor.close()
conn.close()

In [None]:
# Open a session on the application database; the cells that follow reuse the
# module-level `conn` and `cursor` created here.
conn = psycopg2.connect(
    dbname="credit_card_system",
    user="postgres",
    password="12345",
    host="localhost",
)
cursor = conn.cursor()

In [None]:
# SQL statements to create tables
create_table_statements = [
    """
    CREATE TABLE branch (
        branch_id INTEGER,
        branch_name TEXT,
        branch_address TEXT,
        branch_phone TEXT,
        branch_manager TEXT,
        branch_email TEXT,
        established_date DATE,
        number_of_employees INTEGER,
        PRIMARY KEY (branch_id)
    )
    """,
    """
    CREATE TABLE category (
        category_id INTEGER,
        category_name TEXT,
        PRIMARY KEY (category_id)
    )
    """,

    """
    CREATE TABLE users (
        user_id INTEGER,
        user_name TEXT,
        user_email TEXT,
        user_address TEXT,
        user_phone TEXT,
        date_of_birth DATE,
        registration_date DATE,
        status TEXT,
        branch_id INTEGER,
        PRIMARY KEY (user_id),
        FOREIGN KEY(branch_id) REFERENCES branch (branch_id)
    )
    """,

    """
    CREATE TABLE credit_card (
        card_id INTEGER,
        user_id INTEGER,
        card_number TEXT,
        card_type TEXT,
        expiry_date DATE,
        cvv INTEGER,
        issue_date DATE,
        credit_limit REAL,
        current_balance REAL,
        statement_balance REAL,
        PRIMARY KEY (card_id),
        FOREIGN KEY(user_id) REFERENCES users (user_id)
    )
    """,

    """
    CREATE TABLE transactions (
        transaction_id INTEGER,
        card_id INTEGER,
        transaction_date DATE,
        amount REAL,
        merchant TEXT,
        category_id INTEGER,
        transaction_type TEXT,
        transaction_status TEXT,
        description TEXT,
        PRIMARY KEY (transaction_id),
        FOREIGN KEY(card_id) REFERENCES credit_card (card_id),
        FOREIGN KEY(category_id) REFERENCES category (category_id)
    )
    """,

    """
    CREATE TABLE credit_card_financial (
        financial_id INTEGER,
        card_id INTEGER,
        overdue_charges REAL,
        loan_amount REAL,
        emi_amount REAL,
        emi_due_date DATE,
        interest_rate REAL,
        payment_due_date DATE,
        minimum_payment REAL,
        PRIMARY KEY (financial_id),
        FOREIGN KEY(card_id) REFERENCES credit_card (card_id)
    )
    """,

    """
    CREATE TABLE reward (
        reward_id INTEGER,
        transaction_id INTEGER,
        points_earned INTEGER,
        points_redeemed INTEGER,
        current_balance INTEGER,
        PRIMARY KEY (reward_id),
        FOREIGN KEY(transaction_id) REFERENCES transactions (transaction_id)
    )
    """
]


# Execute every CREATE TABLE statement as a single transaction. The
# connection's context manager commits when the block finishes and rolls back
# if any statement raises, so a failure does not leave the session stuck in an
# aborted transaction. It does not close the connection.
with conn:
    for statement in create_table_statements:
        cursor.execute(statement)

In [None]:
import pandas as pd

# Function to load data from CSV to a table
def load_csv_to_table(csv_file_path, table_name):
    """Insert every row of a CSV file into ``table_name``.

    The CSV is read with pandas and each row is sent through the module-level
    ``cursor`` as one ``INSERT INTO <table> VALUES (...)`` statement, with the
    values in CSV column order. Nothing is committed here; the caller commits.

    Args:
        csv_file_path: path of the CSV file to read.
        table_name: name of the target table. It is placed into the SQL text
            as-is, so it must come from trusted code, not from user input.

    Returns:
        None.
    """
    data = pd.read_csv(csv_file_path)
    # The statement is the same for every row, so build it once.
    insert_query = "INSERT INTO {} VALUES %s".format(table_name)

    # Casting to object turns numpy scalars into plain Python values the
    # database driver can adapt; missing cells (NaN) become None so they are
    # stored as NULL rather than as a NaN value.
    rows = [
        tuple(None if pd.isna(value) else value for value in record)
        for record in data.astype(object).itertuples(index=False, name=None)
    ]
    for row in rows:
        cursor.execute(insert_query, (row,))

In [None]:
# Load data from CSV files in the correct order
load_csv_to_table('branch.csv', 'branch')
load_csv_to_table('category.csv', 'category')
load_csv_to_table('user.csv', 'users')
load_csv_to_table('credit_card.csv', 'credit_card')
load_csv_to_table('transactions.csv', 'transactions')
load_csv_to_table('credit_card_financial.csv', 'credit_card_financial')
load_csv_to_table('reward.csv', 'reward')

conn.commit()

In [None]:
def query_db(query):
  """Run ``query`` on the module-level connection and return its rows.

  Any transaction left open (or aborted by an earlier failing statement) is
  rolled back first, so one bad query does not block the following ones.

  Args:
    query: SQL text to execute.

  Returns:
    The list of row tuples from ``fetchall()``, or an empty list when the
    statement produces no result set.
  """
  conn.rollback()
  # The context manager closes the cursor even when execute() raises.
  with conn.cursor() as cur:
    cur.execute(query)
    # description is None for statements that return no rows; fetchall()
    # would raise in that case.
    if cur.description is None:
      return []
    return cur.fetchall()


In [None]:
# query_db('SELECT user_id FROM users;')

# Installation

In [None]:
import sqlite3
import time
import pickle
import re

In [None]:
!git clone https://github.com/loni9164/text_sql.git

In [None]:
!pip install langchain langchain-experimental
!pip install -q  langchain
!pip install sentence-transformers
!pip install chromadb

In [None]:
# from langchain.llms import CTransformers
from langchain import PromptTemplate, LLMChain
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.utilities import SQLDatabase
from langchain.prompts import PromptTemplate
from langchain_experimental.sql import SQLDatabaseChain

# DB connection

In [None]:
# Connection settings for the credit_card_system database.
db_user, db_password = "postgres", "12345"
db_host, db_name = "localhost", "credit_card_system"

# Database URI handed to SQLDatabase.from_uri.
connection_string = "postgresql://{}:{}@{}/{}".format(db_user, db_password, db_host, db_name)
db = SQLDatabase.from_uri(connection_string)
table_info = db.table_info
# print(table_info)

# Load LLM model

In [None]:
!pip3 install transformers optimum
!pip3 install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu121/  # Use cu117 if on CUDA 11.7

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Hugging Face repository id of the checkpoint to load (the name suggests a
# GPTQ-quantised SQLCoder2 model — confirm on the model card).
model_name_or_path = "TheBloke/sqlcoder2-GPTQ"
# To use a different branch, change revision
# For example: revision="gptq-4bit-128g-actorder_True"
# trust_remote_code=False: do not run custom model code shipped with the repository.
model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
                                             device_map="auto",
                                             trust_remote_code=False,
                                             revision="main")

In [None]:
from auto_gptq import exllama_set_max_input_length
# Raise the model's maximum input length to 3000; the prompt built later embeds
# the full schema. NOTE(review): unit presumed to be tokens — confirm against
# the auto-gptq documentation.
model = exllama_set_max_input_length(model, max_input_length=3000)
# Tokenizer from the same repository as the model.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

In [None]:
# print("\n\n*** Generate:")

# input_ids = tokenizer(prompt_template, return_tensors='pt').input_ids.cuda()
# output = model.generate(inputs=input_ids, temperature=0, do_sample=True, top_p=0.95, top_k=40, max_new_tokens=512)
# print(tokenizer.decode(output[0]))

# Inference can also be done using transformers' pipeline

# print("*** Pipeline:")
# Text-generation pipeline used both directly (gen_sql_wto_longchain) and
# through the LangChain HuggingFacePipeline wrapper.
# NOTE(review): do_sample=True makes generation stochastic, so the same
# question may yield different SQL between runs — confirm this is intended.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.1,
    top_p=0.95,
    top_k=40,
    repetition_penalty=1.1
)

# print(pipe(prompt_template)[0]['generated_text'])

# Query generation without LangChain

In [53]:
def update_prompt_template(prompt):
  """Build the SQL-generation prompt for a natural-language question.

  The question is interpolated twice into a fixed template that also embeds
  the CREATE TABLE statements of the credit-card schema. The template ends
  with an opening ```sql fence.

  Args:
    prompt: the question text to insert into the template.

  Returns:
    The filled-in prompt string.
  """
  # The whitespace inside this literal is part of the prompt text sent to the
  # model; keep it unchanged when editing.
  template = f'''
  ## Task
  Generate a SQL query to answer the following question:
  `{prompt}`

  ### Database Schema
  This query will run on a database whose schema is represented in this string:
  CREATE TABLE branch (
        branch_id INTEGER,
        branch_name TEXT,
        branch_address TEXT,
        branch_phone TEXT,
        branch_manager TEXT,
        branch_email TEXT,
        established_date DATE,
        number_of_employees INTEGER,
        PRIMARY KEY (branch_id)
    )


    CREATE TABLE category (
        category_id INTEGER,
        category_name TEXT,
        PRIMARY KEY (category_id)
    )

    CREATE TABLE users (
        user_id INTEGER,
        user_name TEXT,
        user_email TEXT,
        user_address TEXT,
        user_phone TEXT,
        date_of_birth DATE,
        registration_date DATE,
        status TEXT,
        branch_id INTEGER,
        PRIMARY KEY (user_id),
        FOREIGN KEY(branch_id) REFERENCES branch (branch_id)
    )


    CREATE TABLE credit_card (
        card_id INTEGER,
        user_id INTEGER,
        card_number TEXT,
        card_type TEXT,
        expiry_date DATE,
        cvv INTEGER,
        issue_date DATE,
        credit_limit REAL,
        current_balance REAL,
        statement_balance REAL,
        PRIMARY KEY (card_id),
        FOREIGN KEY(user_id) REFERENCES users (user_id)
    )


    CREATE TABLE transactions (
        transaction_id INTEGER,
        card_id INTEGER,
        transaction_date DATE,
        amount REAL,
        merchant TEXT,
        category_id INTEGER,
        transaction_type TEXT,
        transaction_status TEXT,
        description TEXT,
        PRIMARY KEY (transaction_id),
        FOREIGN KEY(card_id) REFERENCES credit_card (card_id),
        FOREIGN KEY(category_id) REFERENCES category (category_id)
    )


    CREATE TABLE credit_card_financial (
        financial_id INTEGER,
        card_id INTEGER,
        overdue_charges REAL,
        loan_amount REAL,
        emi_amount REAL,
        emi_due_date DATE,
        interest_rate REAL,
        payment_due_date DATE,
        minimum_payment REAL,
        PRIMARY KEY (financial_id),
        FOREIGN KEY(card_id) REFERENCES credit_card (card_id)
    )

    CREATE TABLE reward (
        reward_id INTEGER,
        transaction_id INTEGER,
        points_earned INTEGER,
        points_redeemed INTEGER,
        current_balance INTEGER,
        PRIMARY KEY (reward_id),
        FOREIGN KEY(transaction_id) REFERENCES transactions (transaction_id)
    )

  ### SQL
  Given the database schema, here is the SQL query that answers `{prompt}`:
  ```sql
  '''
  return template


In [99]:
def gen_sql_wto_longchain(prompt):
  """Generate SQL for a question with the raw pipeline and execute it.

  Args:
    prompt: natural-language question.

  Returns:
    Whatever ``query_db`` returns for the generated SQL.
  """
  updated_prompt = update_prompt_template(prompt)
  generated_text = pipe(updated_prompt)[0]['generated_text']
  # The prompt itself ends with an opening ```sql fence, so the model's SQL is
  # the text after the last such fence.
  completion = generated_text.split("```sql")[-1]
  # Stop at a closing fence, if the model emitted one, so that neither the
  # fence nor any commentary after it is sent to the database.
  sql_query = completion.split("```")[0].strip()
  print(sql_query)
  return query_db(sql_query)

In [67]:
gen_sql_wto_longchain('how many users do we have')



'SELECT COUNT(DISTINCT users.user_id) AS total_users FROM users;'

# Test queries

In [None]:
# Natural-language questions used to compare the SQL-generation approaches.
questions = [
    'What is the total number of active credit cards?',
    "How many transactions were made in the 'Shopping' category?",
    "Which user has the most number of credit cards?",
    "Find the total amount spent on fuel by all users.",
    "What is the oldest branch of the bank?",
    "Find the user with the oldest credit card.",
    "How many rewards were earned for transactions above $1000?",
    "Which category has the lowest average transaction amount?",
    "What is the total outstanding balance for all credit cards?",
    "How many branches have less than 20 employees?",
    "List the top 3 users by total transaction amount.",
    "What is the average number of transactions per user?",
]

# Reference answers, aligned by position with `questions`.
answers = [
    "6",
    "93",
    "Michael Baldwin",
    "23960.57",
    "Branch 3",
    "Gabrielle Anderson",
    "0",
    "Shopping",
    "14090.2296532",
    "2",
    "Gabrielle Anderson, Michael Baldwin",
    "166.66666666666666",
]

# One record per question, pairing its text with the reference answer.
test_data = [
    dict(qury_text=text, result=expected)
    for text, expected in zip(questions, answers)
]

# LangChain

## LangChain template

In [68]:
# Prompt template for LangChain. This is deliberately a plain string, not an
# f-string: `{input}` must stay a literal placeholder to be filled in later.
# With an f prefix Python substitutes the repr of the builtin `input` instead.
template = '''
  ## Task
  Generate a SQL query to answer the following question:
  `{input}`

  ### Database Schema
  This query will run on a database whose schema is represented in this string:
  CREATE TABLE branch (
        branch_id INTEGER,
        branch_name TEXT,
        branch_address TEXT,
        branch_phone TEXT,
        branch_manager TEXT,
        branch_email TEXT,
        established_date DATE,
        number_of_employees INTEGER,
        PRIMARY KEY (branch_id)
    )


    CREATE TABLE category (
        category_id INTEGER,
        category_name TEXT,
        PRIMARY KEY (category_id)
    )

    CREATE TABLE users (
        user_id INTEGER,
        user_name TEXT,
        user_email TEXT,
        user_address TEXT,
        user_phone TEXT,
        date_of_birth DATE,
        registration_date DATE,
        status TEXT,
        branch_id INTEGER,
        PRIMARY KEY (user_id),
        FOREIGN KEY(branch_id) REFERENCES branch (branch_id)
    )


    CREATE TABLE credit_card (
        card_id INTEGER,
        user_id INTEGER,
        card_number TEXT,
        card_type TEXT,
        expiry_date DATE,
        cvv INTEGER,
        issue_date DATE,
        credit_limit REAL,
        current_balance REAL,
        statement_balance REAL,
        PRIMARY KEY (card_id),
        FOREIGN KEY(user_id) REFERENCES users (user_id)
    )


    CREATE TABLE transactions (
        transaction_id INTEGER,
        card_id INTEGER,
        transaction_date DATE,
        amount REAL,
        merchant TEXT,
        category_id INTEGER,
        transaction_type TEXT,
        transaction_status TEXT,
        description TEXT,
        PRIMARY KEY (transaction_id),
        FOREIGN KEY(card_id) REFERENCES credit_card (card_id),
        FOREIGN KEY(category_id) REFERENCES category (category_id)
    )


    CREATE TABLE credit_card_financial (
        financial_id INTEGER,
        card_id INTEGER,
        overdue_charges REAL,
        loan_amount REAL,
        emi_amount REAL,
        emi_due_date DATE,
        interest_rate REAL,
        payment_due_date DATE,
        minimum_payment REAL,
        PRIMARY KEY (financial_id),
        FOREIGN KEY(card_id) REFERENCES credit_card (card_id)
    )

    CREATE TABLE reward (
        reward_id INTEGER,
        transaction_id INTEGER,
        points_earned INTEGER,
        points_redeemed INTEGER,
        current_balance INTEGER,
        PRIMARY KEY (reward_id),
        FOREIGN KEY(transaction_id) REFERENCES transactions (transaction_id)
    )

  ### SQL
  Given the database schema, here is the SQL query that answers `{input}`:
  ```sql
  '''

# Integration

In [None]:
from langchain.llms import HuggingFacePipeline
from langchain import PromptTemplate, LLMChain

In [69]:
llm = HuggingFacePipeline(pipeline=pipe)

## Dummy template

In [105]:
# Prompt text used by `db_chain`. `{table_info}` and `{input}` are the two
# variables declared on the PromptTemplate that is built from this string.
PROMPT_SUFFIX = """
Only use the following tables:
{table_info}
if the sql query result is None, return the final output as "Details not found"
Limit the final output to the single sentance.

Question: {input}"""

## Chain

In [106]:
# Chain over `llm` and `db` that uses the custom PROMPT_SUFFIX prompt. Its
# verbose trace (see the outputs below) shows the generated SQLQuery, the
# SQLResult and the final Answer.
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, return_sql=False,
                                     prompt=PromptTemplate(input_variables=["input", "table_info"],
                                     template=PROMPT_SUFFIX))

In [107]:
db_chain.run('what is the name of user for userid 2')



[1m> Entering new SQLDatabaseChain chain...[0m
what is the name of user for userid 2
SQLQuery:



[32;1m[1;3mSELECT users.user_name FROM users WHERE users.user_id = 2;[0m
SQLResult: [33;1m[1;3m[('Michael Baldwin',)][0m
Answer:



[32;1m[1;3mMichael Baldwin[0m
[1m> Finished chain.[0m


'Michael Baldwin'

In [111]:
db_chain.run(test_data[6]['qury_text'])



[1m> Entering new SQLDatabaseChain chain...[0m
How many rewards were earned for transactions above $1000?
SQLQuery:



[32;1m[1;3mSELECT SUM(reward.points_earned::INTEGER) AS total_rewards FROM transactions JOIN reward ON transactions.transaction_id = reward.transaction_id WHERE transactions.amount > 1000;[0m
SQLResult: [33;1m[1;3m[(None,)][0m
Answer:



[32;1m[1;3m44[0m
[1m> Finished chain.[0m


'44'

In [103]:
# Alternative prompt for `db_chain2`. The PromptTemplate built from it declares
# the variables "input" and "table_info", so both placeholders must appear
# here; without `{input}` the question never reaches the model.
template = '''
Only use the following tables:
{table_info}
if the sql query is None, then return the final output as Not found
Restrict the final output to the single line.

Question: {input}'''

In [None]:
db_chain2 = SQLDatabaseChain.from_llm(llm, db, verbose=True, return_sql=False,
                                     prompt=PromptTemplate(input_variables=["input", "table_info"],
                                     template=template))

db_chain2.run('what is the name of user for userid 2')

In [None]:
db_chain3 = SQLDatabaseChain.from_llm(llm, db, verbose=True, return_sql=False)


db_chain3.run('what is the name of user for userid 2')

In [None]:
gen_sql_wto_longchain('how many cards do we have')

In [None]:
# Manual check of the fuel-spend aggregate. The previous text was not valid
# PostgreSQL: it used ON where FROM is required, referenced a column
# "transactiontype" that the schema does not define (it is transaction_type),
# and wrote the pattern in double quotes, which PostgreSQL reads as an
# identifier rather than a string literal.
db.run("SELECT SUM(transactions.amount) AS total_amount_spent FROM transactions WHERE transactions.transaction_type ILIKE '%fuel%'")

In [None]:
db_chain.run("How many employees are there?")

In [98]:
res_lst = []       # one dict per question holding each approach's answer
failed_cases = []  # (question index, approach label, exception) per failure

# Run every test question through the three approaches and print the answers
# next to the reference answer.
for i, item in enumerate(test_data):
  question = item['qury_text']
  outcome = {'index': i, 'question': question, 'expected': item['result']}

  print(i)
  print("Question:")
  print(question)
  print('--------------------')
  print(item['result'])
  print('--------------------')
  print("Gen_sql_wto_longchain:")
  # Guarded like the chains below so one bad generated query does not abort
  # the whole comparison run.
  try:
    res = gen_sql_wto_longchain(question)
    outcome['Gen_sql_wto_longchain'] = res
    print(res)
  except Exception as exc:
    failed_cases.append((i, 'Gen_sql_wto_longchain', exc))
    print(f"Gen_sql_wto_longchain Failed: {exc!r}")

  for label, chain in (("Longchain1", db_chain), ("Longchain3", db_chain3)):
    print('--------------------')
    print(label)
    # Catch Exception rather than using a bare except, so KeyboardInterrupt
    # can still stop the run, and report the failure under the right label.
    try:
      res = chain.run(question)
      outcome[label] = res
      print(res)
    except Exception as exc:
      failed_cases.append((i, label, exc))
      print(f"{label} Failed: {exc!r}")

  res_lst.append(outcome)
  print('***********************')
  print('\n')

0
Question:
What is the total number of active credit cards?
--------------------
6
--------------------
Gen_sql_wto_longchain:




[(6,)]
--------------------
Longchain1


[1m> Entering new SQLDatabaseChain chain...[0m
What is the total number of active credit cards?
SQLQuery:



[32;1m[1;3mSELECT COUNT(*) AS active_cards FROM credit_card WHERE card_id IN (SELECT card_id FROM credit_card_financial WHERE CURRENT_BALANCE > 0);[0m
SQLResult: [33;1m[1;3m[(6,)][0m
Answer:



[32;1m[1;3mThere are six active credit cards.[0m
[1m> Finished chain.[0m
There are six active credit cards.
--------------------
Longchain3


[1m> Entering new SQLDatabaseChain chain...[0m
What is the total number of active credit cards?
SQLQuery:



[32;1m[1;3mSELECT COUNT(DISTINCT card_id) AS number_of_cards FROM credit_card WHERE expiry_date > CURRENT_DATE AND status = 'active';[0mLongchain2 Failed
***********************


1
Question:
How many transactions were made in the 'Shopping' category?
--------------------
93
--------------------
Gen_sql_wto_longchain:
[(93,)]
--------------------
Longchain1


[1m> Entering new SQLDatabaseChain chain...[0m
How many transactions were made in the 'Shopping' category?
SQLQuery:



[32;1m[1;3mSELECT COUNT(transaction_id) AS number_of_transactions FROM transactions JOIN category ON transactions.category_id = category.category_id WHERE category.category_name ilike '%Shopping%'[0m
SQLResult: [33;1m[1;3m[(93,)][0m
Answer:



[32;1m[1;3m93[0m
[1m> Finished chain.[0m
93
--------------------
Longchain3


[1m> Entering new SQLDatabaseChain chain...[0m
How many transactions were made in the 'Shopping' category?
SQLQuery:



[32;1m[1;3mSELECT COUNT(transaction_id) AS number_of_transactions FROM transactions WHERE category_id = (SELECT category_id FROM category WHERE category_name ilike '%Shopping%');[0m
SQLResult: [33;1m[1;3m[(93,)][0m
Answer:



[32;1m[1;3mThere were 93 transactions done within the Shopping category across different branches connected through different cards registered under different users connected via different branches connected through different branches connected through different branches connected through different branches connected through different branches connected through different branches connected through different branches connected through different branches connected through different branches connected through different branches connected through different branches connected through different branches connected through different branches connected through different branches connected through different branches connected through different branches connected through different branches connected through different branches connected through different branches connected through different branches connected through different branches connected through different branches connected through diffe



[32;1m[1;3mSELECT COUNT(DISTINCT card_id::TEXT::INT) AS number_of_cards, user_id::TEXT::INT AS user_id FROM credit_card GROUP BY user_id ORDER BY number_of_cards DESC LIMIT 1;[0m
SQLResult: [33;1m[1;3m[(3, 1)][0m
Answer:



[32;1m[1;3mThe user with id 1 has the most number of credit cards registered - 3[0m
[1m> Finished chain.[0m
The user with id 1 has the most number of credit cards registered - 3
--------------------
Longchain3


[1m> Entering new SQLDatabaseChain chain...[0m
Which user has the most number of credit cards?
SQLQuery:



[32;1m[1;3mSELECT users.user_name, COUNT(DISTINCT credit_card.card_id) AS card_count FROM users JOIN credit_card ON users.user_id = credit_card.user_id GROUP BY users.user_name ORDER BY card_count DESC NULLS LAST LIMIT 1;[0m
SQLResult: [33;1m[1;3m[('Gabrielle Anderson', 3)][0m
Answer:



[32;1m[1;3mGabrielle Anderson has the most number of credit cards with 3.[0m
[1m> Finished chain.[0m
Gabrielle Anderson has the most number of credit cards with 3.
***********************


3
Question:
Find the total amount spent on fuel by all users.
--------------------
23960.57
--------------------
Gen_sql_wto_longchain:
[(249378.48,)]
--------------------
Longchain1


[1m> Entering new SQLDatabaseChain chain...[0m
Find the total amount spent on fuel by all users.
SQLQuery:



[32;1m[1;3mSELECT SUM(transactions.amount::FLOAT) AS total_amount FROM transactions JOIN category ON transactions.category_id = category.category_id WHERE category.category_name ILIKE '%fuel%' GROUP BY transactions.user_id;[0mLongchain1 Failed
--------------------
Longchain3


[1m> Entering new SQLDatabaseChain chain...[0m
Find the total amount spent on fuel by all users.
SQLQuery:



[32;1m[1;3mSELECT SUM(transactions.amount::FLOAT) AS total_amount FROM transactions JOIN category ON transactions.category_id = category.category_id WHERE category.category_name ilike '%fuel%' AND transactions.transaction_status ILIKE '%success%';[0m
SQLResult: [33;1m[1;3m[(None,)][0m
Answer:



[32;1m[1;3mTotal amount spent during successful transactions categorized under Fuel is $4744.47.[0m
[1m> Finished chain.[0m
Total amount spent during successful transactions categorized under Fuel is $4744.47.
***********************


4
Question:
What is the oldest branch of the bank?
--------------------
Branch 3
--------------------
Gen_sql_wto_longchain:
[('1',)]
--------------------
Longchain1


[1m> Entering new SQLDatabaseChain chain...[0m
What is the oldest branch of the bank?
SQLQuery:



[32;1m[1;3mSELECT MIN(branch::established_date::TEXT::DATE) AS oldest_established_date FROM branch;[0mLongchain1 Failed
--------------------
Longchain3


[1m> Entering new SQLDatabaseChain chain...[0m
What is the oldest branch of the bank?
SQLQuery:



[32;1m[1;3mSELECT branch_name FROM branch ORDER BY established_Date ASC LIMIT 1;[0m
SQLResult: [33;1m[1;3m[('Branch 3',)][0m
Answer:



[32;1m[1;3mBranch 3[0m
[1m> Finished chain.[0m
Branch 3
***********************


5
Question:
Find the user with the oldest credit card.
--------------------
Gabrielle Anderson
--------------------
Gen_sql_wto_longchain:
[('Gabrielle Anderson',)]
--------------------
Longchain1


[1m> Entering new SQLDatabaseChain chain...[0m
Find the user with the oldest credit card.
SQLQuery:



[32;1m[1;3mSELECT * FROM users WHERE user_id IN (SELECT user_id FROM credit_card ORDER BY expiry_date ASC LIMIT 1);[0m
SQLResult: [33;1m[1;3m[(1, 'Gabrielle Anderson', 'jacobmoore@hotmail.com', '778 Compton Pine Suite 171, Dunlapmouth, AL 57452', '2527317877', datetime.date(2010, 5, 29), datetime.date(2020, 1, 27), 'Active', 1)][0m
Answer:



[32;1m[1;3mGabrielle Anderson[0m
[1m> Finished chain.[0m
Gabrielle Anderson
--------------------
Longchain3


[1m> Entering new SQLDatabaseChain chain...[0m
Find the user with the oldest credit card.
SQLQuery:



[32;1m[1;3mSELECT * FROM users WHERE user_id IN (SELECT user_id FROM credit_card ORDER BY expiry_date ASC LIMIT 1);[0m
SQLResult: [33;1m[1;3m[(1, 'Gabrielle Anderson', 'jacobmoore@hotmail.com', '778 Compton Pine Suite 171, Dunlapmouth, AL 57452', '2527317877', datetime.date(2010, 5, 29), datetime.date(2020, 1, 27), 'Active', 1)][0m
Answer:



[32;1m[1;3mThe oldest credit card expires on 2024-10-31 expiring for <NAME>.[0m
[1m> Finished chain.[0m
The oldest credit card expires on 2024-10-31 expiring for <NAME>.
***********************


6
Question:
How many rewards were earned for transactions above $1000?
--------------------
0
--------------------
Gen_sql_wto_longchain:
[(0,)]
--------------------
Longchain1


[1m> Entering new SQLDatabaseChain chain...[0m
How many rewards were earned for transactions above $1000?
SQLQuery:



[32;1m[1;3mSELECT COUNT(DISTINCT r.reward_id) AS total_rewards FROM reward r JOIN transactions t ON r.transaction_id = t.transaction_id WHERE t.amount > 1000;[0m
SQLResult: [33;1m[1;3m[(0,)][0m
Answer:



[32;1m[1;3m0[0m
[1m> Finished chain.[0m
0
--------------------
Longchain3


[1m> Entering new SQLDatabaseChain chain...[0m
How many rewards were earned for transactions above $1000?
SQLQuery:



[32;1m[1;3mSELECT SUM(reward.points_earned::INTEGER) AS total_rewards FROM transactions JOIN reward ON transactions.transaction_id = reward.transaction_id WHERE transactions.amount > 1000;[0m
SQLResult: [33;1m[1;3m[(None,)][0m
Answer:



[32;1m[1;3m44 points earned across different branches.[0m
[1m> Finished chain.[0m
44 points earned across different branches.
***********************


7
Question:
Which category has the lowest average transaction amount?
--------------------
Shopping
--------------------
Gen_sql_wto_longchain:
[('Shopping',)]
--------------------
Longchain1


[1m> Entering new SQLDatabaseChain chain...[0m
Which category has the lowest average transaction amount?
SQLQuery:



[32;1m[1;3mSELECT c.category_name, AVG(t.amount::FLOAT) AS avg_amount FROM transactions t JOIN category c ON t.category_id = c.category_id GROUP BY c.category_name ORDER BY avg_amount ASC LIMIT 1;[0m
SQLResult: [33;1m[1;3m[('Shopping', 227.74870998628677)][0m
Answer:



[32;1m[1;3mShopping[0m
[1m> Finished chain.[0m
Shopping
--------------------
Longchain3


[1m> Entering new SQLDatabaseChain chain...[0m
Which category has the lowest average transaction amount?
SQLQuery:



[32;1m[1;3mSELECT category.category_name, AVG(transactions.amount::NUMERIC) AS avg_amount FROM transactions JOIN category ON transactions.category_id = category.category_id GROUP BY category.category_name ORDER BY avg_amount ASC LIMIT 1;[0m
SQLResult: [33;1m[1;3m[('Shopping', Decimal('227.7487096774193548'))][0m
Answer:



[32;1m[1;3mShopping[0m
[1m> Finished chain.[0m
Shopping
***********************


8
Question:
What is the total outstanding balance for all credit cards?
--------------------
14090.2296532
--------------------
Gen_sql_wto_longchain:
[(14090.2295,)]
--------------------
Longchain1


[1m> Entering new SQLDatabaseChain chain...[0m
What is the total outstanding balance for all credit cards?
SQLQuery:



[32;1m[1;3mSELECT SUM(cc.current_balance::FLOAT) AS total_outstanding_balance FROM credit_card cc;[0m
SQLResult: [33;1m[1;3m[(14090.229843139648,)][0m
Answer:



[32;1m[1;3m14090.229843[0m
[1m> Finished chain.[0m
14090.229843
--------------------
Longchain3


[1m> Entering new SQLDatabaseChain chain...[0m
What is the total outstanding balance for all credit cards?
SQLQuery:



[32;1m[1;3mSELECT SUM(cc.current_balance::FLOAT) AS total_outstanding_balance FROM credit_card cc;[0m
SQLResult: [33;1m[1;3m[(14090.229843139648,)][0m
Answer:



[32;1m[1;3mThe total outstanding balance across all credit cards is $14090.22.[0m
[1m> Finished chain.[0m
The total outstanding balance across all credit cards is $14090.22.
***********************


9
Question:
How many branches have less than 20 employees?
--------------------
2
--------------------
Gen_sql_wto_longchain:
[(2,)]
--------------------
Longchain1


[1m> Entering new SQLDatabaseChain chain...[0m
How many branches have less than 20 employees?
SQLQuery:



[32;1m[1;3mSELECT COUNT(DISTINCT branch_id) AS number_of_branches FROM branch WHERE number_of_employees < 20;[0m
SQLResult: [33;1m[1;3m[(2,)][0m
Answer:



[32;1m[1;3mThere are 2 branches with fewer than 20 employees.[0m
[1m> Finished chain.[0m
There are 2 branches with fewer than 20 employees.
--------------------
Longchain3


[1m> Entering new SQLDatabaseChain chain...[0m
How many branches have less than 20 employees?
SQLQuery:



[32;1m[1;3mSELECT COUNT(*) AS number_of_branches FROM branch WHERE number_of_employees < 20;[0m
SQLResult: [33;1m[1;3m[(2,)][0m
Answer:



[32;1m[1;3mThere are 2 branches with fewer than 20 employees.[0m
[1m> Finished chain.[0m
There are 2 branches with fewer than 20 employees.
***********************


10
Question:
List the top 3 users by total transaction amount.
--------------------
Gabrielle Anderson, Michael Baldwin
--------------------
Gen_sql_wto_longchain:
[('Gabrielle Anderson', 128059.26), ('Michael Baldwin', 121319.25)]
--------------------
Longchain1


[1m> Entering new SQLDatabaseChain chain...[0m
List the top 3 users by total transaction amount.
SQLQuery:



[32;1m[1;3mSELECT u.user_name, SUM(r.points_earned::INTEGER - r.points_redeemed::INTEGER) AS total_change FROM reward r JOIN transactions t ON r.transaction_id = t.transaction_id JOIN credit_card c ON t.card_id = c.card_id JOIN users u ON c.user_id = u.user_id GROUP BY u.user_name ORDER BY total_change DESC NULLS LAST LIMIT 3;[0m
SQLResult: [33;1m[1;3m[('Gabrielle Anderson', 13359), ('Michael Baldwin', 12503)][0m
Answer:



[32;1m[1;3mGabrielle Anderson, Michael Baldwin[0m
[1m> Finished chain.[0m
Gabrielle Anderson, Michael Baldwin
--------------------
Longchain3


[1m> Entering new SQLDatabaseChain chain...[0m
List the top 3 users by total transaction amount.
SQLQuery:



[32;1m[1;3mSELECT u.first_name ||'' || u.last_name AS full_name, SUM(t.amount::FLOAT) AS total_amount FROM users AS u JOIN credit_card AS c ON u.user_id = c.user_id JOIN transactions AS t ON c.card_id = t.card_id GROUP BY u.first_name, u.last_name ORDER BY total_amount DESC NULLS LAST LIMIT 3;[0mLongchain2 Failed
***********************


11
Question:
What is the average number of transactions per user?
--------------------
166.66666666666666
--------------------
Gen_sql_wto_longchain:
[(Decimal('166.6666666666666667'),)]
--------------------
Longchain1


[1m> Entering new SQLDatabaseChain chain...[0m
What is the average number of transactions per user?
SQLQuery:



[32;1m[1;3mSELECT AVG(transactions_per_user) AS avg_transactions_per_user FROM (SELECT card_id, COUNT(transaction_id) AS transactions_per_user FROM transactions GROUP BY card_id) AS subquery;[0m
SQLResult: [33;1m[1;3m[(Decimal('166.6666666666666667'),)][0m
Answer:



[32;1m[1;3m166.6666666666666667[0m
[1m> Finished chain.[0m
166.6666666666666667
--------------------
Longchain3


[1m> Entering new SQLDatabaseChain chain...[0m
What is the average number of transactions per user?
SQLQuery:



[32;1m[1;3mSELECT AVG(transactions_per_user::FLOAT) AS avg_transactions_per_user FROM (SELECT card_id, COUNT(DISTINCT transaction_id) AS transactions_per_user FROM transactions GROUP BY card_id) AS subquery;[0m
SQLResult: [33;1m[1;3m[(166.66666666666666,)][0m
Answer:



[32;1m[1;3mThe average number of transactions per user across branches is 166.66666666666666.[0m
[1m> Finished chain.[0m
The average number of transactions per user across branches is 166.66666666666666.
***********************




In [None]:
import langchain

# Disable LangChain's module-level debug flag before running the chain.
langchain.debug = False

# Fixed typo in the question sent to the model ("Few many" -> "How many").
question = "How many transactions do we have"

# Run the chain with the question
output = db_chain.run(query=question, few_shot_examples=few_shot_examples)

# Disabled: unpacking `output` as a dict with result/answer keys.
# # Access the result
# result = output["SQLiteResult"]
# answer = output["Answer"]

# # Print the result and answer
# print("SQLite Result:", result)
# print("Answer:", answer)

In [None]:
# llm_chain.run("What is the total number of active credit cards?")

# Test sqlcoder

In [None]:
# Benchmark questions paired with their expected answers. The pairs are written
# side by side so a question and its answer cannot drift apart, then split back
# into the two parallel lists `questions` and `answers`.
questions, answers = map(list, zip(*[
    ('What is the total number of active credit cards?', "6"),
    ("How many transactions were made in the 'Shopping' category?", "93"),
    ("Which user has the most number of credit cards?", "Michael Baldwin"),
    ("Find the total amount spent on fuel by all users.", "23960.57"),
    ("What is the oldest branch of the bank?", "Branch 3"),
    ("Find the user with the oldest credit card.", "Gabrielle Anderson"),
    ("How many rewards were earned for transactions above $1000?", "0"),
    ("Which category has the lowest average transaction amount?", "Shopping"),
    ("What is the total outstanding balance for all credit cards?", "14090.2296532"),
    ("How many branches have less than 20 employees?", "2"),
    ("List the top 3 users by total transaction amount.", "Gabrielle Anderson, Michael Baldwin"),
    ("What is the average number of transactions per user?", "166.66666666666666"),
]))

# One record per question; the key spelling 'qury_text' is kept as-is because
# code outside this cell may read it.
test_data = [
    dict(qury_text=question_text, result=expected_answer)
    for question_text, expected_answer in zip(questions, answers)
]

In [None]:
# Build the model prompt for one sample question via `update_prompt_template`
# (not defined in this cell; presumably declared earlier in the notebook).
prompt = "What is the total number of active credit cards?"
prompt_template = update_prompt_template(prompt)
# print(prompt_template)

In [None]:
# Load the pickled few-shot examples from the repo checkout into `few_shot_examples`.
# NOTE(review): pickle.load can execute arbitrary code while deserializing;
# only load this file from a trusted source.
with open('text_sql/few_shot_examples', 'rb') as f:
  few_shot_examples = pickle.load(f)

In [None]:
# Number of entries in `few_shots`.
# NOTE(review): `few_shots` is not assigned in the visible cells (the pickle cell
# binds `few_shot_examples`); presumably defined earlier -- confirm on a fresh kernel.
len(few_shots)

In [None]:
# Show the first few-shot entry to inspect its structure.
few_shots[0]

In [None]:
# For every few-shot example, print the reference question/SQL/answer and then
# the SQL (and its query result) obtained from three sources:
#   1. the stored `sqlcoder_34b` query of the example,
#   2. a fresh generation through `llm_pipe`,
#   3. the LangChain `chain`.
# Each source is attempted independently so one failure does not stop the loop.
# Failures now catch `Exception` (instead of a bare `except:`) so that
# KeyboardInterrupt still aborts the run, and the error itself is printed.
# NOTE(review): `chain` is created by `create_sql_query_chain` in a later cell of
# this section; that cell must have been executed first -- confirm cell order.
for i, item in enumerate(few_shots):
  print(i)
  question = item['Question']
  print('Question:')
  print(question)
  print('------------------------------------------')
  print('SQL query:')
  print(item['SQLQuery'])
  print('------------------------------------------')
  print('Expected result:')
  print(item['Answer'])
  print('------------------------------------------')
  print('sqlcoder_34b:')
  try:
    sql_query_formatted = format_to_sqlite(item['sqlcoder_34b'])
    print(sql_query_formatted)
    print('------')
    print(query_db(sql_query_formatted))
    print('------------------------------------------')
  except Exception as exc:
    print('Failed', i)
    print(f'  {type(exc).__name__}: {exc}')
  print('sqlcoder2 result:')
  try:
    prompt_template = update_prompt_template(question)
    sql_query = llm_pipe(prompt_template)
    sql_query = format_to_sqlite(sql_query)
    print(sql_query)
    print('------')
    print(query_db(sql_query))
  except Exception as exc:
    print('Failed', i)
    print(f'  {type(exc).__name__}: {exc}')
  print("Longchain:")
  try:
    sql_query = chain.invoke({"question": question})
    sql_query = format_to_sqlite(sql_query)
    print(sql_query)
    print('------')
    print(db.run(sql_query))
    print('------')
  except Exception as exc:
    print('Failed', i)
    print(f'  {type(exc).__name__}: {exc}')
  print('****************************************************************************************')
  print('\n')

In [None]:
# Display the current value of `sql_query` (last assigned inside the comparison loop).
sql_query

In [None]:
# Build a SQLDatabaseChain from `llm` and `db` with verbose output enabled.
# NOTE(review): the `SQLDatabaseChain` import visible in this section sits in a later
# cell; presumably it is also imported earlier in the notebook -- confirm.
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)

In [None]:
# Smoke-test the chain with a single natural-language question.
db_chain.run( "What is the total number of  credit cards?")

In [None]:
# List the attributes available on the chain object (exploratory).
dir(db_chain)

In [None]:
# Create `chain` with `create_sql_query_chain`, using the same `llm` and `db`.
from langchain.chains import create_sql_query_chain
chain = create_sql_query_chain(llm, db)

In [None]:
# Generate SQL for one question with the query chain, normalise it with
# `format_to_sqlite`, print it, and execute it through `query_db`.
response = chain.invoke({"question": "What is the total number of active credit cards?"})
# Fix: format the SQL that was just generated (`response`). The original passed
# `query`, a name never assigned in this cell, leaving `response` unused.
sql_query_formatted = format_to_sqlite(response)
print(sql_query_formatted)
query_db(sql_query_formatted)

In [None]:
# Execute the formatted SQL through `db.run`.
db.run(sql_query_formatted)

In [None]:
# Disabled alternative LLM backend: Google PaLM through LangChain.
# SECURITY: a literal Google API key used to be hardcoded on the `api_key` line.
# It has been redacted here; treat the committed key as compromised and revoke it.
# Read the key from the environment instead of embedding it in the notebook.
# from langchain.llms import GooglePalm

# api_key = os.environ["GOOGLE_API_KEY"]

# llm = GooglePalm(google_api_key=api_key, temperature=0.2)

In [None]:
# Import the LangChain SQL helpers and (re)build `db_chain` from the current
# `llm` and `db`, with verbose output enabled.
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)

In [None]:
# Ask the chain one of the benchmark questions.
db_chain.run("Find the total amount spent on fuel by all users.")

# Few shot learning

We will use few-shot learning to fix the issues we have seen so far.

### Creating Semantic Similarity Based example selector

- Create embeddings for the `few_shots` examples
- Store the embeddings in Chroma DB
- Retrieve the most semantically similar examples from the vector store

In [None]:
from langchain.embeddings import HuggingFaceEmbeddings

# Sentence-embedding model used to index the few-shot examples.
embeddings = HuggingFaceEmbeddings(
    model_name='sentence-transformers/all-MiniLM-L6-v2',
)

# One text per example: all of the example's values joined with single spaces.
to_vectorize = [
    " ".join(shot.values())
    for shot in few_shots
]

In [None]:
# Inspect the flattened example texts that will be embedded.
to_vectorize

In [None]:
# Build a Chroma vector store from the flattened texts using `embeddings`;
# each original `few_shots` entry is passed along as that text's metadata.
from langchain.vectorstores import Chroma
vectorstore = Chroma.from_texts(to_vectorize, embeddings, metadatas=few_shots)

In [None]:
from langchain.prompts import SemanticSimilarityExampleSelector

# Example selector backed by `vectorstore`.
# NOTE(review): k=2 presumably caps the selection at the two closest examples --
# confirm against the LangChain documentation.
example_selector = SemanticSimilarityExampleSelector(
    vectorstore=vectorstore,
    k=2,
)

In [None]:
# Sanity check: ask the selector for examples matching a sample question.
example_selector.select_examples({"Question": "How many total rewards do we have"})

In [None]:
# Instruction text later passed as the `prefix` of the few-shot prompt.
# `{top_k}` is intentionally left as a template placeholder (it is listed among the
# few-shot prompt's input variables) and must not be formatted here.
sqlite_prompt = """You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today". Ensure final query is converted into sqlite.

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here"""

In [None]:
from langchain.prompts import FewShotPromptTemplate
from langchain.chains.sql_database.prompt import PROMPT_SUFFIX

In [None]:
# Show LangChain's stock prompt suffix (used below as the few-shot prompt's suffix).
print(PROMPT_SUFFIX)

In [None]:
from langchain.chains.sql_database import prompt

In [None]:
# List the names exported by the `langchain.chains.sql_database.prompt` module.
# NOTE(review): `prompt` is also bound to a question string in an earlier cell of
# this section, so the name is reused for two different kinds of object.
dir(prompt)

In [None]:
from langchain.chains.sql_database.prompt import SQLITE_PROMPT

In [None]:
# Show LangChain's built-in SQLite prompt object for reference.
print(SQLITE_PROMPT)

In [None]:
# Print a SQLite prompt template pasted as an escaped literal so the newlines
# render; `{top_k}`, `{table_info}` and `{input}` are unfilled placeholders.
print('You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.\nPay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\nPay attention to use date(\'now\') function to get the current date, if the question involves "today".\n\nUse the following format:\n\nQuestion: Question here\nSQLQuery: SQL Query to run\nSQLResult: Result of the SQLQuery\nAnswer: Final answer here\n\nOnly use the following tables:\n{table_info}\n\nQuestion: {input}')

### Setting up PromptTemplate using input variables

In [None]:
from langchain.prompts.prompt import PromptTemplate

# Template that renders one few-shot example as four labelled lines:
# Question / SQLQuery / SQLResult / Answer.
# NOTE(review): each selected example presumably must provide all four keys -- confirm.
example_prompt = PromptTemplate(
    input_variables=["Question", "SQLQuery", "SQLResult","Answer",],
    template="\nQuestion: {Question}\nSQLQuery: {SQLQuery}\nSQLResult: {SQLResult}\nAnswer: {Answer}",
)

In [None]:
# Re-print the stock suffix (same call as an earlier cell) before assembling the prompt.
print(PROMPT_SUFFIX)

In [None]:
# Assemble the few-shot prompt: `sqlite_prompt` as prefix, examples chosen by
# `example_selector` and rendered with `example_prompt`, then PROMPT_SUFFIX.
few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    prefix=sqlite_prompt,
    suffix=PROMPT_SUFFIX,
    input_variables=["input", "table_info", "top_k"], #These variables are used in the prefix and suffix
)

In [None]:
# Wrap the `pipe` pipeline object as a LangChain LLM and bind it to `llm`.
# NOTE(review): earlier cells in this section already use `llm`, so which model
# they see depends on execution order -- confirm on a fresh kernel.
llm = HuggingFacePipeline(pipeline=pipe)

In [None]:
# Ad-hoc probe: ask the model directly to rewrite a PostgreSQL-flavoured query
# (`::` casts, `ilike`) as SQLite. The SQL text is passed verbatim, unchanged.
llm('''

Convert this into SQLITE query
SELECT COUNT(DISTINCT TransactionId) AS NumberOfTransactions, SUM(Amount) AS TotalSpentFROM transactions JOIN credit_card ON transactions.cardid::integer = credit_card.cardid::integer WHERE credit_card.userid::integer = 1 AND transactions.categoryid::integer = (SELECT categoryid::integer FROM category WHERE categoryname::text ilike '%food%');
''')

In [None]:
# Second chain: same `llm` and `db`, but driven by the custom few-shot prompt.
db_chain2 = SQLDatabaseChain.from_llm(llm, db, verbose=True, prompt=few_shot_prompt)

In [None]:
# Try the few-shot chain on one benchmark question.
db_chain2.run("Which user has the most number of credit cards?")

In [None]:
# Run a hand-written query for the same question directly through `query_db`.
query_db('SELECT u."username", COUNT(cc."cardid") AS card_count FROM "user" u JOIN credit_card cc ON u."userid" = cc."userid" GROUP BY u."username" ORDER BY card_count DESC LIMIT 1')

In [None]:
# Keep the chain's answer to a date-relative question in `resp`.
resp = db_chain2.run("What is the total amount of transactions completed in the last month")

# Test

In [None]:
# Hand-written evaluation set: one dict per question with the keys
# 'Question', 'SQLQuery', 'SQLResult' and 'Answer' (in that order).
# 'SQLResult' holds the same placeholder text for every entry, so the entries
# are written as (question, reference SQL, answer label) tuples and expanded here.
# Commented-out tuples are entries that are currently excluded from the set.
sql_test = [
    {
        'Question': question_text,
        'SQLQuery': reference_sql,
        'SQLResult': "Result of the SQL query",
        'Answer': answer_label,
    }
    for question_text, reference_sql, answer_label in (
        ("What is the total number of active credit cards?",
         "SELECT COUNT(*) FROM credit_card WHERE ExpiryDate > CURRENT_DATE;",
         'Total number of active credit cards'),
        ("How many transactions were made in the 'Shopping' category?",
         "SELECT COUNT(*) FROM transactions WHERE CategoryId = (SELECT CategoryId FROM category WHERE CategoryName = 'Shopping');",
         'Number of transactions in Shopping category'),
        ("Which user has the most number of credit cards?",
         "SELECT u.UserName FROM user u JOIN credit_card cc ON u.UserId = cc.UserId GROUP BY u.UserId ORDER BY COUNT(cc.CardId) DESC LIMIT 1;",
         'User with the most number of credit cards'),
        ("Find the total amount spent on fuel by all users.",
         "SELECT SUM(Amount) FROM transactions WHERE CategoryId = (SELECT CategoryId FROM category WHERE CategoryName = 'Fuel');",
         'Total amount spent on fuel'),
        ("What is the oldest branch of the bank?",
         "SELECT BranchName FROM branch WHERE EstablishedDate = (SELECT MIN(EstablishedDate) FROM branch);",
         'Oldest branch of the bank'),
        ("Find the user with the oldest credit card.",
         "SELECT u.UserName FROM user u JOIN credit_card cc ON u.UserId = cc.UserId WHERE cc.IssueDate = (SELECT MIN(IssueDate) FROM credit_card);",
         'User with the oldest credit card'),
        ("How many rewards were earned for transactions above $1000?",
         "SELECT COUNT(r.RewardId) FROM reward r JOIN transactions t ON r.TransactionId = t.TransactionId WHERE t.Amount > 1000;",
         'Number of rewards for transactions above $1000'),
        # ("List all users who have made a transaction in the last month.",
        #  "SELECT DISTINCT u.UserName FROM user u JOIN credit_card cc ON u.UserId = cc.UserId JOIN transactions t ON cc.CardId = t.CardId WHERE t.TransactionDate >= CURRENT_DATE - INTERVAL '1 month';",
        #  'List of users with transactions in the last month'),
        ("Which category has the lowest average transaction amount?",
         "SELECT c.CategoryName FROM category c JOIN transactions t ON c.CategoryId = t.CategoryId GROUP BY c.CategoryId ORDER BY AVG(t.Amount) ASC LIMIT 1;",
         'Category with the lowest average transaction amount'),
        ("What is the total outstanding balance for all credit cards?",
         "SELECT SUM(CurrentBalance) FROM credit_card;",
         'Total outstanding balance for all credit cards'),
        ("How many branches have less than 20 employees?",
         "SELECT COUNT(*) FROM branch WHERE NumberOfEmployees < 20;",
         'Number of branches with less than 20 employees'),
        ("List the top 3 users by total transaction amount.",
         "SELECT u.UserName FROM user u JOIN credit_card cc ON u.UserId = cc.UserId JOIN transactions t ON cc.CardId = t.CardId GROUP BY u.UserId ORDER BY SUM(t.Amount) DESC LIMIT 3;",
         'Top 3 users by total transaction amount'),
        # ("Find the most popular transaction category.",
        #  "SELECT c.CategoryName FROM category c JOIN transactions t ON c.CategoryId = t.CategoryId GROUP BY c.CategoryId ORDER BY COUNT(*) DESC LIMIT 1;",
        #  'Most popular transaction category'),
        # ("How many credit cards were issued in the last year?",
        #  "SELECT COUNT(*) FROM credit_card WHERE IssueDate >= CURRENT_DATE - INTERVAL '1 year';",
        #  'Number of credit cards issued in the last year'),
        ("What is the average number of transactions per user?",
         "SELECT AVG(TransactionCount) FROM (SELECT COUNT(*) as TransactionCount FROM transactions GROUP BY CardId) as TransactionPerUser;",
         'Average number of transactions per user'),
    )
]

In [None]:
# Number of active (non-commented) entries in the evaluation set.
len(sql_test)