<a href="https://colab.research.google.com/github/loni9164/text_sql/blob/main/sql_starcoder_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Create PostgreSQL

In [None]:
!apt-get -qq update
!apt-get -qq -y install postgresql
!service postgresql start

In [None]:
!pip install psycopg2-binary

In [None]:
!service postgresql restart

In [None]:
!sudo -u postgres psql

In [None]:
# ALTER USER postgres PASSWORD '12345';

In [None]:
import psycopg2
conn = psycopg2.connect(
    host="localhost",
    user="postgres",
    password="12345",  # Use the correct password here
    dbname="postgres"
)
conn.autocommit = True
cursor = conn.cursor()

In [None]:
# Creating a new database
cursor.execute("CREATE DATABASE credit_card_system")
cursor.close()
conn.close()

In [None]:
# Connecting to the new database
conn = psycopg2.connect(
    host="localhost",
    user="postgres",
    password="12345",
    dbname="credit_card_system"
)
cursor = conn.cursor()

In [None]:
# SQL statements to create tables
create_table_statements = [
    """
    CREATE TABLE branch (
        "BranchId" INTEGER,
        "BranchName" TEXT,
        "BranchAddress" TEXT,
        "BranchPhone" TEXT,
        "BranchManager" TEXT,
        "BranchEmail" TEXT,
        "EstablishedDate" DATE,
        "NumberOfEmployees" INTEGER,
        PRIMARY KEY ("BranchId")
    )
    """,
    """
    CREATE TABLE category (
        "CategoryId" INTEGER,
        "CategoryName" TEXT,
        PRIMARY KEY ("CategoryId")
    )
    """
    ,

    """
    CREATE TABLE users (
    "UserId" INTEGER,
    "UserName" TEXT,
    "UserEmail" TEXT,
    "UserAddress" TEXT,
    "UserPhone" TEXT,
    "DateOfBirth" DATE,
    "RegistrationDate" DATE,
    "Status" TEXT,
    "BranchId" INTEGER,
    PRIMARY KEY ("UserId"),
    FOREIGN KEY("BranchId") REFERENCES branch ("BranchId")
    )
    """
    ,
    """
    CREATE TABLE credit_card (
    "CardId" INTEGER,
    "UserId" INTEGER,
    "CardNumber" TEXT,
    "CardType" TEXT,
    "ExpiryDate" DATE,
    "CVV" INTEGER,
    "IssueDate" DATE,
    "CreditLimit" REAL,
    "CurrentBalance" REAL,
    "StatementBalance" REAL,
    PRIMARY KEY ("CardId"),
    FOREIGN KEY("UserId") REFERENCES users ("UserId")
  )
    """,
    """
    CREATE TABLE transactions (
        "TransactionId" INTEGER,
        "CardId" INTEGER,
        "TransactionDate" DATE,
        "Amount" REAL,
        "Merchant" TEXT,
        "CategoryId" INTEGER,
        "TransactionType" TEXT,
        "TransactionStatus" TEXT,
        "Description" TEXT,
        PRIMARY KEY ("TransactionId"),
        FOREIGN KEY("CardId") REFERENCES credit_card ("CardId"),
        FOREIGN KEY("CategoryId") REFERENCES category ("CategoryId")
    )
    """,
    """
    CREATE TABLE credit_card_financial (
        "FinancialId" INTEGER,
        "CardId" INTEGER,
        "OverdueCharges" REAL,
        "LoanAmount" REAL,
        "EMIAmount" REAL,
        "EMIDueDate" DATE,
        "InterestRate" REAL,
        "PaymentDueDate" DATE,
        "MinimumPayment" REAL,
        PRIMARY KEY ("FinancialId"),
        FOREIGN KEY("CardId") REFERENCES credit_card ("CardId")
    )
    """,
    """
    CREATE TABLE reward (
        "RewardId" INTEGER,
        "TransactionId" INTEGER,
        "PointsEarned" INTEGER,
        "PointsRedeemed" INTEGER,
        "CurrentBalance" INTEGER,
        PRIMARY KEY ("RewardId"),
        FOREIGN KEY("TransactionId") REFERENCES transactions ("TransactionId")
    )
    """
]


# Execute each CREATE TABLE statement
for statement in create_table_statements:
    cursor.execute(statement)

conn.commit()

In [None]:
import pandas as pd

# Function to load data from CSV to a table
def load_csv_to_table(csv_file_path, table_name):
    data = pd.read_csv(csv_file_path)
    for i, row in data.iterrows():
        insert_query = "INSERT INTO {} VALUES %s".format(table_name)
        cursor.execute(insert_query, (tuple(row),))

In [None]:
# Load data from CSV files in the correct order
load_csv_to_table('branch.csv', 'branch')
load_csv_to_table('category.csv', 'category')
load_csv_to_table('user.csv', 'users')
load_csv_to_table('credit_card.csv', 'credit_card')
load_csv_to_table('transactions.csv', 'transactions')
load_csv_to_table('credit_card_financial.csv', 'credit_card_financial')
load_csv_to_table('reward.csv', 'reward')

conn.commit()

In [60]:
def query_db(query):
  cursor = conn.cursor()
  cursor.execute(query)
  return cursor.fetchall()
  cursor.close()

In [None]:
query_db("SELECT * from users")

# Installation

In [None]:
import sqlite3
import time
import pickle
import re

In [None]:
!git clone https://github.com/loni9164/text_sql.git

In [None]:
!pip install langchain langchain-experimental
!pip install -q  langchain
!pip install sentence-transformers
!pip install chromadb

In [None]:
# from langchain.llms import CTransformers
from langchain import PromptTemplate, LLMChain
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.utilities import SQLDatabase
from langchain.prompts import PromptTemplate
from langchain_experimental.sql import SQLDatabaseChain

In [None]:
db_path = 'text_sql/credit_card_system.db'

In [None]:
# Connecting to the new database
conn = psycopg2.connect(
    host="localhost",
    user="postgres",
    password="12345",
    dbname="credit_card_system"
)
cursor = conn.cursor()

In [None]:
db_user = "postgres"
db_password = "12345"
db_host = "localhost"
db_name = "credit_card_system"

connection_string = f"postgresql://{db_user}:{db_password}@{db_host}/{db_name}"
db = SQLDatabase.from_uri(connection_string)
table_info = db.table_info
# print(table_info)

# Load LLM model

In [None]:
!pip3 install transformers optimum
!pip3 install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu121/  # Use cu117 if on CUDA 11.7

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_name_or_path = "TheBloke/sqlcoder2-GPTQ"
# To use a different branch, change revision
# For example: revision="gptq-4bit-128g-actorder_True"
model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
                                             device_map="auto",
                                             trust_remote_code=False,
                                             revision="main")

In [None]:
from auto_gptq import exllama_set_max_input_length
model = exllama_set_max_input_length(model, max_input_length=3000)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

In [None]:
# print("\n\n*** Generate:")

# input_ids = tokenizer(prompt_template, return_tensors='pt').input_ids.cuda()
# output = model.generate(inputs=input_ids, temperature=0, do_sample=True, top_p=0.95, top_k=40, max_new_tokens=512)
# print(tokenizer.decode(output[0]))

# Inference can also be done using transformers' pipeline

# print("*** Pipeline:")
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.1,
    top_p=0.95,
    top_k=40,
    repetition_penalty=1.1
)

# print(pipe(prompt_template)[0]['generated_text'])

# Query functions and Column mapping

In [36]:
def query_db(query):
  # Connect to SQLite database
  conn = sqlite3.connect(db_path)
  cursor = conn.cursor()

  # Execute a query
  cursor.execute(query)

  # Fetch and print results
  return cursor.fetchall()

In [37]:
def get_column_names(db_path):
    columns_list = []
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    # Get the list of all tables in the database
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = cursor.fetchall()

    # For each table, get the column names
    for table in tables:
        # print(f"Columns for table {table[0]}:")
        cursor.execute(f"PRAGMA table_info({table[0]});")
        columns = cursor.fetchall()
        for column in columns:
          # print(column[1])
          columns_list.append(column[1])
          # print(column[1])  # column name is in the second position
    conn.close()
    return columns_list

In [38]:
db_columns = get_column_names(db_path)
db_columns = [x.lower() for x in db_columns]

NameError: ignored

In [None]:
db_columns_map = ['branch_id',
 'branch_name',
 'branch_address',
 'branch_phone',
 'branch_manager',
 'branch_email',
 'established_date',
 'number_of_employees',
 'user_id',
 'user_name',
 'user_email',
 'user_address',
 'user_phone',
 'date_of_birth',
 'registration_date',
 'status',
 'branch_id',
 'category_id',
 'category_name',
 'card_id',
 'user_id',
 'card_number',
 'card_type',
 'expiry_date',
 'cvv',
 'issue_date',
 'credit_limit',
 'current_balance',
 'statement_balance',
 'transaction_id',
 'card_id',
 'transaction_date',
 'amount',
 'merchant',
 'category_id',
 'transaction_type',
 'transaction_status',
 'description',
 'financial_id',
 'card_id',
 'overdue_charges',
 'loan_amount',
 'emi_amount',
 'emi_duedate',
 'interest_rate',
 'payment_duedate',
 'minimum_payment',
 'reward_id',
 'transaction_id',
 'points_earned',
 'points_redeemed',
 'current_balance']

In [None]:
column_map = {x:y for x, y in zip(db_columns_map, db_columns)}

In [None]:
def format_to_sqlite(query):
    # Example conversions
    query = re.sub(r'[\s\n]+', ' ', query)
    query = query.replace('ILIKE', 'LIKE')
    query = query.replace('ilike', 'LIKE')
    query = query.replace('::text', '')
    query = query.replace('::integer', '')
    query = query.replace('boolean_expression', 'case when boolean_expression then 1 else 0 end')
    query = query.replace("EXTRACT(YEAR FROM", "strftime('%Y',")
    for col_query, col_db in column_map.items():
        if col_query in query:
            query = query.replace(col_query, col_db)
    # Add more conversions as per your need
    query = query.strip()
    return query

# Test queries

In [39]:
questions = [
    'What is the total number of active credit cards?',
    "How many transactions were made in the 'Shopping' category?",
    "Which user has the most number of credit cards?",
    "Find the total amount spent on fuel by all users.",
    "What is the oldest branch of the bank?",
    "Find the user with the oldest credit card.",
    "How many rewards were earned for transactions above $1000?",
    "Which category has the lowest average transaction amount?",
    "What is the total outstanding balance for all credit cards?",
    "How many branches have less than 20 employees?",
    "List the top 3 users by total transaction amount.",
    "What is the average number of transactions per user?"
]


answers = ["6", "93", "Michael Baldwin", "23960.57", "Branch 3", "Gabrielle Anderson", "0", "Shopping", "14090.2296532", "2", "Gabrielle Anderson, Michael Baldwin", "166.66666666666666"]

test_data = [{'qury_text':q, 'result': a} for q ,a in zip(questions, answers)]

# longchain Template

In [40]:
PROMPT_SUFFIX='''## Task
  Generate a SQlite query to answer the following question:
  {input}
  please Note that columns names
  ### Database Schema
  ### Refer below database Schema for columns name
	CREATE TABLE branch (
		"BranchId" INTEGER,
		"BranchName" TEXT,
		"BranchAddress" TEXT,
		"BranchPhone" TEXT,
		"BranchManager" TEXT,
		"BranchEmail" TEXT,
		"EstablishedDate" DATE,
		"NumberOfEmployees" INTEGER,
		PRIMARY KEY ("BranchId")
	)


	CREATE TABLE category (
		"CategoryId" INTEGER,
		"CategoryName" TEXT,
		PRIMARY KEY ("CategoryId")
	)


	CREATE TABLE credit_card (
		"CardId" INTEGER,
		"UserId" INTEGER,
		"CardNumber" TEXT,
		"CardType" TEXT,
		"ExpiryDate" DATE,
		"CVV" INTEGER,
		"IssueDate" DATE,
		"CreditLimit" REAL,
		"CurrentBalance" REAL,
		"StatementBalance" REAL,
		PRIMARY KEY ("CardId"),
		FOREIGN KEY("UserId") REFERENCES user ("UserId")
	)


	CREATE TABLE credit_card_financial (
		"FinancialId" INTEGER,
		"CardId" INTEGER,
		"OverdueCharges" REAL,
		"LoanAmount" REAL,
		"EMIAmount" REAL,
		"EMIDueDate" DATE,
		"InterestRate" REAL,
		"PaymentDueDate" DATE,
		"MinimumPayment" REAL,
		PRIMARY KEY ("FinancialId"),
		FOREIGN KEY("CardId") REFERENCES credit_card ("CardId")
	)


	CREATE TABLE reward (
		"RewardId" INTEGER,
		"TransactionId" INTEGER,
		"PointsEarned" INTEGER,
		"PointsRedeemed" INTEGER,
		"CurrentBalance" INTEGER,
		PRIMARY KEY ("RewardId"),
		FOREIGN KEY("TransactionId") REFERENCES transactions ("TransactionId")
	)


	CREATE TABLE transactions (
		"TransactionId" INTEGER,
		"CardId" INTEGER,
		"TransactionDate" DATE,
		"Amount" REAL,
		"Merchant" TEXT,
		"CategoryId" INTEGER,
		"TransactionType" TEXT,
		"TransactionStatus" TEXT,
		"Description" TEXT,
		PRIMARY KEY ("TransactionId"),
		FOREIGN KEY("CardId") REFERENCES credit_card ("CardId"),
		FOREIGN KEY("CategoryId") REFERENCES category ("CategoryId")
	)


	CREATE TABLE user (
		"UserId" INTEGER, -- Unique ID for each user
		"UserName" TEXT, -- Name of the user
		"UserEmail" TEXT, -- Email ID of the user
		"UserAddress" TEXT, -- User address
		"UserPhone" TEXT, -- User phone number
		"DateOfBirth" DATE,
		"RegistrationDate" DATE,
		"Status" TEXT,
		"BranchId" INTEGER,
		PRIMARY KEY ("UserId"),
		FOREIGN KEY("BranchId") REFERENCES branch ("BranchId")
	)


  ### SQL
	Ensure the output is traslated to the SQLite query.
	Correct the column name in the final query
  Given the database schema, here is the SQL query that answers `{input}`:
  ```sql
  '''

# Longchain integration

In [41]:
from langchain.llms import HuggingFacePipeline
from langchain import PromptTemplate, LLMChain

In [42]:
llm = HuggingFacePipeline(pipeline=pipe)
# from langchain.chains import create_sql_query_chain
# chain = create_sql_query_chain(llm, db)

In [44]:
with open('text_sql/few_shot_examples', 'rb') as f:
  few_shot_examples = pickle.load(f)

In [45]:
few_shot_examples[11]

{'Question': 'Which category has the highest number of transactions?',
 'SQLQuery': 'SELECT CategoryId, COUNT(*) as NumberOfTransactions FROM transactions GROUP BY CategoryId ORDER BY NumberOfTransactions DESC LIMIT 1;',
 'SQLResult': [(5, 117)],
 'Answer': '5'}

In [46]:
from langchain.prompts import PromptTemplate

template = """Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
Use the following format:

Question: "Question here"
SQLiteQuery: "SQLite Query to run"
SQLiteResult: "Result of the SQLite Query"
Answer: "Final answer here"

Only use the following tables:

{table_info}.

Some examples of SQLite queries that correspond to questions are:

{few_shot_examples}

Question: {query}"""

prompt_template = PromptTemplate(
    input_variables=["query", "few_shot_examples", "table_info"],
    template=template,
)

In [47]:
from langchain.prompts import PromptTemplate

template = """Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
Use the following format:

The output of LLM is PostgreSQL, please convert this into the SQLite query as our database SQLite

Question: "Question here"
SQLiteQuery: "SQLite Query to run"
SQLiteResult: "Result of the SQLite Query"
Answer: "Final answer here"

Only use the following tables:

{table_info}.

Some examples of SQLite queries that correspond to questions are:

{few_shot_examples}

Question: {input}"""

prompt_template = PromptTemplate(
    input_variables=["input", "few_shot_examples", "table_info"],
    template=template,
)

In [77]:
PROMPT_SUFFIX = """Given an input question, first create a syntactically correct PostgreSQL query to run, then look at the results of the query and return the answer.
Use the following format:

Only use the following tables:

	CREATE TABLE branch (
        "BranchId" INTEGER,
        "BranchName" TEXT,
        "BranchAddress" TEXT,
        "BranchPhone" TEXT,
        "BranchManager" TEXT,
        "BranchEmail" TEXT,
        "EstablishedDate" DATE,
        "NumberOfEmployees" INTEGER,
        PRIMARY KEY ("BranchId")
    ),
    CREATE TABLE category (
        "CategoryId" INTEGER,
        "CategoryName" TEXT,
        PRIMARY KEY ("CategoryId")
    ),

    CREATE TABLE users (
    "UserId" INTEGER,
    "UserName" TEXT,
    "UserEmail" TEXT,
    "UserAddress" TEXT,
    "UserPhone" TEXT,
    "DateOfBirth" DATE,
    "RegistrationDate" DATE,
    "Status" TEXT,
    "BranchId" INTEGER,
    PRIMARY KEY ("UserId"),
    FOREIGN KEY("BranchId") REFERENCES branch ("BranchId")
    ),

    CREATE TABLE credit_card (
    "CardId" INTEGER,
    "UserId" INTEGER,
    "CardNumber" TEXT,
    "CardType" TEXT,
    "ExpiryDate" DATE,
    "CVV" INTEGER,
    "IssueDate" DATE,
    "CreditLimit" REAL,
    "CurrentBalance" REAL,
    "StatementBalance" REAL,
    PRIMARY KEY ("CardId"),
    FOREIGN KEY("UserId") REFERENCES users ("UserId")
  ),
    CREATE TABLE transactions (
        "TransactionId" INTEGER,
        "CardId" INTEGER,
        "TransactionDate" DATE,
        "Amount" REAL,
        "Merchant" TEXT,
        "CategoryId" INTEGER,
        "TransactionType" TEXT,
        "TransactionStatus" TEXT,
        "Description" TEXT,
        PRIMARY KEY ("TransactionId"),
        FOREIGN KEY("CardId") REFERENCES credit_card ("CardId"),
        FOREIGN KEY("CategoryId") REFERENCES category ("CategoryId")
    ),
    CREATE TABLE credit_card_financial (
        "FinancialId" INTEGER,
        "CardId" INTEGER,
        "OverdueCharges" REAL,
        "LoanAmount" REAL,
        "EMIAmount" REAL,
        "EMIDueDate" DATE,
        "InterestRate" REAL,
        "PaymentDueDate" DATE,
        "MinimumPayment" REAL,
        PRIMARY KEY ("FinancialId"),
        FOREIGN KEY("CardId") REFERENCES credit_card ("CardId")
    ),
    CREATE TABLE reward (
        "RewardId" INTEGER,
        "TransactionId" INTEGER,
        "PointsEarned" INTEGER,
        "PointsRedeemed" INTEGER,
        "CurrentBalance" INTEGER,
        PRIMARY KEY ("RewardId"),
        FOREIGN KEY("TransactionId") REFERENCES transactions ("TransactionId")
    )

Question: {input}"""

In [86]:
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, return_sql=True,
                                     prompt=PromptTemplate(input_variables=["input", "table_info"],
                                     template=PROMPT_SUFFIX))

In [55]:
test_data[3]['qury_text']

'Find the total amount spent on fuel by all users.'

In [87]:
db_chain('how many cards do we have')



[1m> Entering new SQLDatabaseChain chain...[0m
how many cards do we have
SQLQuery:




[1m> Finished chain.[0m


{'query': 'how many cards do we have',
 'result': 'SELECT COUNT(DISTINCT cardid) AS number_of_cards FROM credit_card;'}

In [97]:
res_lst

['SELECT COUNT(DISTINCT cardid) AS active_cards FROM credit_card WHERE expirydate > CURRENT_DATE AND currentbalance > 0;',
 "SELECT COUNT(transactionid) AS number_of_transactions FROM transactions JOIN category ON transactions.categoryid = category.categoryid WHERE category.categoryname ilike '%Shopping%'",
 'SELECT users."userid", COUNT(credit_card."cardid") AS card_count FROM users JOIN credit_card ON users."userid" = credit_card."userid" GROUP BY users."userid" ORDER BY card_count DESC LIMIT 1;',
 "SELECT SUM(transactions.amount) AS total_amount FROM transactions JOIN category ON transactions.categoryid = category.categoryid WHERE category.categoryname ilike '%fuel%';",
 'SELECT MIN(branch."branchid") AS id, branch."branchname", branch."branchaddress", branch."branchphone", branch."manager", branch."email", branch."establisheddate", branch."numberofemployees" FROM branch;',
 'SELECT * FROM users WHERE userid IN (SELECT MIN(userid) AS min_user_id FROM credit_card);',
 'SELECT SUM(rew

In [124]:
cursor.execute("ROLLBACK")
query_db('''
SELECT COUNT(*)
FROM credit_card c
WHERE c.expirydate > CURRENT_DATE;
''')

UndefinedColumn: ignored

In [88]:
db_chain(test_data[3]['qury_text'])



[1m> Entering new SQLDatabaseChain chain...[0m
Find the total amount spent on fuel by all users.
SQLQuery:




[1m> Finished chain.[0m


{'query': 'Find the total amount spent on fuel by all users.',
 'result': "SELECT SUM(transactions.amount) AS total_amount FROM transactions JOIN category ON transactions.categoryid = category.categoryid WHERE category.categoryname ilike '%fuel%';"}

In [None]:
db.run('SELECT SUM(transactions."amount") AS total_amount_spent ON transactions WHERE transactions."transactiontype" LIKE "%fuel%"')

In [None]:
db_chain.run("How many employees are there?")

In [96]:
res_lst = []

for item in test_data:
  print("Question:")
  question = item['qury_text']
  print(question)
  print('--------------------')
  print(item['result'])
  print('--------------------')
  res = db_chain.run(question)
  res_lst.append(res)
  print(res)
  print('***********************')
  print('\n')

Question:
What is the total number of active credit cards?
--------------------
6
--------------------


[1m> Entering new SQLDatabaseChain chain...[0m
What is the total number of active credit cards?
SQLQuery:




[1m> Finished chain.[0m
SELECT COUNT(DISTINCT cardid) AS active_cards FROM credit_card WHERE expirydate > CURRENT_DATE AND currentbalance > 0;
***********************


Question:
How many transactions were made in the 'Shopping' category?
--------------------
93
--------------------


[1m> Entering new SQLDatabaseChain chain...[0m
How many transactions were made in the 'Shopping' category?
SQLQuery:




[1m> Finished chain.[0m
SELECT COUNT(transactionid) AS number_of_transactions FROM transactions JOIN category ON transactions.categoryid = category.categoryid WHERE category.categoryname ilike '%Shopping%'
***********************


Question:
Which user has the most number of credit cards?
--------------------
Michael Baldwin
--------------------


[1m> Entering new SQLDatabaseChain chain...[0m
Which user has the most number of credit cards?
SQLQuery:




[1m> Finished chain.[0m
SELECT users."userid", COUNT(credit_card."cardid") AS card_count FROM users JOIN credit_card ON users."userid" = credit_card."userid" GROUP BY users."userid" ORDER BY card_count DESC LIMIT 1;
***********************


Question:
Find the total amount spent on fuel by all users.
--------------------
23960.57
--------------------


[1m> Entering new SQLDatabaseChain chain...[0m
Find the total amount spent on fuel by all users.
SQLQuery:




[1m> Finished chain.[0m
SELECT SUM(transactions.amount) AS total_amount FROM transactions JOIN category ON transactions.categoryid = category.categoryid WHERE category.categoryname ilike '%fuel%';
***********************


Question:
What is the oldest branch of the bank?
--------------------
Branch 3
--------------------


[1m> Entering new SQLDatabaseChain chain...[0m
What is the oldest branch of the bank?
SQLQuery:




[1m> Finished chain.[0m
SELECT MIN(branch."branchid") AS id, branch."branchname", branch."branchaddress", branch."branchphone", branch."manager", branch."email", branch."establisheddate", branch."numberofemployees" FROM branch;
***********************


Question:
Find the user with the oldest credit card.
--------------------
Gabrielle Anderson
--------------------


[1m> Entering new SQLDatabaseChain chain...[0m
Find the user with the oldest credit card.
SQLQuery:




[1m> Finished chain.[0m
SELECT * FROM users WHERE userid IN (SELECT MIN(userid) AS min_user_id FROM credit_card);
***********************


Question:
How many rewards were earned for transactions above $1000?
--------------------
0
--------------------


[1m> Entering new SQLDatabaseChain chain...[0m
How many rewards were earned for transactions above $1000?
SQLQuery:




[1m> Finished chain.[0m
SELECT SUM(reward.points_earned) AS total_rewards FROM reward JOIN transactions ON reward.transactionid = transactions.transactionid WHERE transactions.amount > 1000;
***********************


Question:
Which category has the lowest average transaction amount?
--------------------
Shopping
--------------------


[1m> Entering new SQLDatabaseChain chain...[0m
Which category has the lowest average transaction amount?
SQLQuery:




[1m> Finished chain.[0m
SELECT category.categoryname, AVG(transactions.amount::FLOAT) AS avg_amount FROM transactions JOIN category ON transactions.categoryid = category.categoryid GROUP BY category.categoryname ORDER BY avg_amount ASC LIMIT 1;
***********************


Question:
What is the total outstanding balance for all credit cards?
--------------------
14090.2296532
--------------------


[1m> Entering new SQLDatabaseChain chain...[0m
What is the total outstanding balance for all credit cards?
SQLQuery:




[1m> Finished chain.[0m
SELECT SUM(cc.currentbalance::FLOAT) AS total_outstanding_balance FROM credit_card cc;
***********************


Question:
How many branches have less than 20 employees?
--------------------
2
--------------------


[1m> Entering new SQLDatabaseChain chain...[0m
How many branches have less than 20 employees?
SQLQuery:




[1m> Finished chain.[0m
SELECT COUNT(DISTINCT b."branchid") AS number_of_branches FROM branch b JOIN users u ON CAST(b."branchmanager" AS integer) = u."userid" WHERE CAST(u."status" AS integer) < 2 AND CAST(b."numberofemployees" AS integer) < 20;
***********************


Question:
List the top 3 users by total transaction amount.
--------------------
Gabrielle Anderson, Michael Baldwin
--------------------


[1m> Entering new SQLDatabaseChain chain...[0m
List the top 3 users by total transaction amount.
SQLQuery:




[1m> Finished chain.[0m
SELECT u."user_name", SUM(r.points_earned::integer) AS points_earned FROM users u JOIN reward r ON CAST(u.user_id::text AS integer) = r.transactionid GROUP BY u."user_name" ORDER BY points_earned DESC NULLS LAST LIMIT 3;
***********************


Question:
What is the average number of transactions per user?
--------------------
166.66666666666666
--------------------


[1m> Entering new SQLDatabaseChain chain...[0m
What is the average number of transactions per user?
SQLQuery:




[1m> Finished chain.[0m
SELECT AVG(transactions) AS avg_transactions FROM (SELECT cardid::integer, COUNT(transactionid::integer) AS transactions FROM transactions GROUP BY cardid::integer) AS sub;
***********************




In [None]:
import langchain
langchain.debug = False
question = "Few many transactions do we have"

# Run the chain with the question
output = db_chain.run(query=question, few_shot_examples=few_shot_examples)

# # Access the result
# result = output["SQLiteResult"]
# answer = output["Answer"]

# # Print the result and answer
# print("SQLite Result:", result)
# print("Answer:", answer)

In [None]:
# llm_chain.run("What is the total number of active credit cards?")

# Columns_map

In [None]:
def get_column_names(db_path):
    columns_list = []
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    # Get the list of all tables in the database
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    tables = cursor.fetchall()

    # For each table, get the column names
    for table in tables:
        # print(f"Columns for table {table[0]}:")
        cursor.execute(f"PRAGMA table_info({table[0]});")
        columns = cursor.fetchall()
        for column in columns:
          # print(column[1])
          columns_list.append(column[1])
          # print(column[1])  # column name is in the second position
    conn.close()
    return columns_list

In [None]:
db_columns = get_column_names(db_path)
db_columns = [x.lower() for x in db_columns]

In [None]:
db_columns_map = ['branch_id',
 'branch_name',
 'branch_address',
 'branch_phone',
 'branch_manager',
 'branch_email',
 'established_date',
 'number_of_employees',
 'user_id',
 'user_name',
 'user_email',
 'user_address',
 'user_phone',
 'date_of_birth',
 'registration_date',
 'status',
 'branch_id',
 'category_id',
 'category_name',
 'card_id',
 'user_id',
 'card_number',
 'card_type',
 'expiry_date',
 'cvv',
 'issue_date',
 'credit_limit',
 'current_balance',
 'statement_balance',
 'transaction_id',
 'card_id',
 'transaction_date',
 'amount',
 'merchant',
 'category_id',
 'transaction_type',
 'transaction_status',
 'description',
 'financial_id',
 'card_id',
 'overdue_charges',
 'loan_amount',
 'emi_amount',
 'emi_duedate',
 'interest_rate',
 'payment_duedate',
 'minimum_payment',
 'reward_id',
 'transaction_id',
 'points_earned',
 'points_redeemed',
 'current_balance']

In [None]:
column_map = {x:y for x, y in zip(db_columns_map, db_columns)}

# Test sqlcoder

In [None]:
def query_db(query):
  # Connect to SQLite database
  conn = sqlite3.connect(db_path)
  cursor = conn.cursor()

  # Execute a query
  cursor.execute(query)

  # Fetch and print results
  return cursor.fetchall()

In [None]:
questions = [
    'What is the total number of active credit cards?',
    "How many transactions were made in the 'Shopping' category?",
    "Which user has the most number of credit cards?",
    "Find the total amount spent on fuel by all users.",
    "What is the oldest branch of the bank?",
    "Find the user with the oldest credit card.",
    "How many rewards were earned for transactions above $1000?",
    "Which category has the lowest average transaction amount?",
    "What is the total outstanding balance for all credit cards?",
    "How many branches have less than 20 employees?",
    "List the top 3 users by total transaction amount.",
    "What is the average number of transactions per user?"
]


answers = ["6", "93", "Michael Baldwin", "23960.57", "Branch 3", "Gabrielle Anderson", "0", "Shopping", "14090.2296532", "2", "Gabrielle Anderson, Michael Baldwin", "166.66666666666666"]

test_data = [{'qury_text':q, 'result': a} for q ,a in zip(questions, answers)]

In [None]:
def format_to_sqlite(query):
    # Example conversions
    query = re.sub(r'[\s\n]+', ' ', query)
    query = query.replace('ILIKE', 'LIKE')
    query = query.replace('ilike', 'LIKE')
    query = query.replace('::text', '')
    query = query.replace('::integer', '')
    query = query.replace('boolean_expression', 'case when boolean_expression then 1 else 0 end')
    query = query.replace("EXTRACT(YEAR FROM", "strftime('%Y',")
    for col_query, col_db in column_map.items():
        if col_query in query:
            query = query.replace(col_query, col_db)
    # Add more conversions as per your need
    query = query.strip()
    return query

In [None]:
sql_query ='''
SELECT u.user_name,
       COUNT(c.card_id) AS card_count
FROM user u
JOIN credit_card c ON u.user_id = c.user_id
GROUP BY u.user_name
ORDER BY card_count DESC
LIMIT 1;
'''
print(sql_query)

In [None]:
sql_query_formatted = format_to_sqlite(sql_query)
print(sql_query_formatted)

In [None]:
print(query_db(sql_query_formatted))

In [None]:
prompt = "What is the total number of active credit cards?"
prompt_template = update_prompt_template(prompt)
# print(prompt_template)

In [None]:
with open('/content/few_shots.pkl', 'rb') as f:
  few_shots = pickle.load(f)

In [None]:
len(few_shots)

In [None]:
few_shots[0]

In [None]:
for i, item in enumerate(few_shots):
  print(i)
  question = item['Question']
  print('Question:')
  print(question)
  print('------------------------------------------')
  print('SQL query:')
  print(item['SQLQuery'])
  print('------------------------------------------')
  print('Expected result:')
  print(item['Answer'])
  print('------------------------------------------')
  print('sqlcoder_34b:')
  try:
    sql_query_formatted = format_to_sqlite(item['sqlcoder_34b'])
    print(sql_query_formatted)
    print('------')
    print(query_db(sql_query_formatted))
    print('------------------------------------------')
  except:
    print('Failed', i)
  print('sqlcoder2 result:')
  try:
    prompt_template = update_prompt_template(question)
    sql_query = llm_pipe(prompt_template)
    sql_query = format_to_sqlite(sql_query)
    print(sql_query)
    print('------')
    print(query_db(sql_query))
  except:
    print('Failed', i)
  print("Longchain:")
  try:
    sql_query = chain.invoke({"question": question})
    sql_query = format_to_sqlite(sql_query)
    print(sql_query)
    print('------')
    print(db.run(sql_query))
    print('------')
  except:
    print('Failed', i)
  print('****************************************************************************************')
  print('\n')

In [None]:
sql_query

In [None]:
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)

In [None]:
db_chain.run( "What is the total number of  credit cards?")

In [None]:
dir(db_chain)

In [None]:
from langchain.chains import create_sql_query_chain
chain = create_sql_query_chain(llm, db)

In [None]:
response = chain.invoke({"question": "What is the total number of active credit cards?"})
sql_query_formatted = format_to_sqlite(query)
print(sql_query_formatted)
query_db(sql_query_formatted)

In [None]:
db.run(sql_query_formatted)

In [None]:
db_chain.generate_query("What is the total number of active credit cards?")

In [None]:
resp = llm_chain.run({"prompt": sql_query})
resp

In [None]:
# from langchain.llms import GooglePalm

# api_key = 'AIzaSyCEP46MCrbkUR0AENTGOJzqYRXTm6NUd7Q'

# llm = GooglePalm(google_api_key=api_key, temperature=0.2)

In [None]:
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)

In [None]:
db_chain.run("Find the total amount spent on fuel by all users.")

# Few shot learning

We will use few shot learning to fix issues we have seen so far

### Creating Semantic Similarity Based example selector

- create embedding on the few_shots
- Store the embeddings in Chroma DB
- Retrieve the the top most Semantically close example from the vector store

In [None]:
from langchain.embeddings import HuggingFaceEmbeddings

embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')

to_vectorize = [" ".join(example.values()) for example in few_shots]

In [None]:
to_vectorize

In [None]:
from langchain.vectorstores import Chroma
vectorstore = Chroma.from_texts(to_vectorize, embeddings, metadatas=few_shots)

In [None]:
from langchain.prompts import SemanticSimilarityExampleSelector

example_selector = SemanticSimilarityExampleSelector(
    vectorstore=vectorstore,
    k=2,
)

In [None]:
example_selector.select_examples({"Question": "How many total rewards do we have"})

In [None]:
sqlite_prompt = """You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today". Ensure final query is converted into sqlite.

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here"""

In [None]:
from langchain.prompts import FewShotPromptTemplate
from langchain.chains.sql_database.prompt import PROMPT_SUFFIX

In [None]:
print(PROMPT_SUFFIX)

In [None]:
from langchain.chains.sql_database import prompt

In [None]:
dir(prompt)

In [None]:
from langchain.chains.sql_database.prompt import SQLITE_PROMPT

In [None]:
print(SQLITE_PROMPT)

In [None]:
print('You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.\nPay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\nPay attention to use date(\'now\') function to get the current date, if the question involves "today".\n\nUse the following format:\n\nQuestion: Question here\nSQLQuery: SQL Query to run\nSQLResult: Result of the SQLQuery\nAnswer: Final answer here\n\nOnly use the following tables:\n{table_info}\n\nQuestion: {input}')

### Setting up PromptTemplete using input variables

In [None]:
from langchain.prompts.prompt import PromptTemplate

example_prompt = PromptTemplate(
    input_variables=["Question", "SQLQuery", "SQLResult","Answer",],
    template="\nQuestion: {Question}\nSQLQuery: {SQLQuery}\nSQLResult: {SQLResult}\nAnswer: {Answer}",
)

In [None]:
print(PROMPT_SUFFIX)

In [None]:
few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    prefix=sqlite_prompt,
    suffix=PROMPT_SUFFIX,
    input_variables=["input", "table_info", "top_k"], #These variables are used in the prefix and suffix
)

In [None]:
llm = HuggingFacePipeline(pipeline=pipe)

In [None]:
llm('''

Convert this into SQLITE query
SELECT COUNT(DISTINCT TransactionId) AS NumberOfTransactions, SUM(Amount) AS TotalSpentFROM transactions JOIN credit_card ON transactions.cardid::integer = credit_card.cardid::integer WHERE credit_card.userid::integer = 1 AND transactions.categoryid::integer = (SELECT categoryid::integer FROM category WHERE categoryname::text ilike '%food%');
''')

In [None]:
db_chain2 = SQLDatabaseChain.from_llm(llm, db, verbose=True, prompt=few_shot_prompt)

In [None]:
db_chain2.run("Which user has the most number of credit cards?")

In [None]:
query_db('SELECT u."username", COUNT(cc."cardid") AS card_count FROM "user" u JOIN credit_card cc ON u."userid" = cc."userid" GROUP BY u."username" ORDER BY card_count DESC LIMIT 1')

In [None]:
resp = db_chain2.run("What is the total amount of transactions completed in the last month")

# Test

In [None]:
sql_test = ([
    {'Question': "What is the total number of active credit cards?",
     'SQLQuery': "SELECT COUNT(*) FROM credit_card WHERE ExpiryDate > CURRENT_DATE;",
     'SQLResult': "Result of the SQL query",
     'Answer': 'Total number of active credit cards'},
    {'Question': "How many transactions were made in the 'Shopping' category?",
     'SQLQuery': "SELECT COUNT(*) FROM transactions WHERE CategoryId = (SELECT CategoryId FROM category WHERE CategoryName = 'Shopping');",
     'SQLResult': "Result of the SQL query",
     'Answer': 'Number of transactions in Shopping category'},
    {'Question': "Which user has the most number of credit cards?",
     'SQLQuery': "SELECT u.UserName FROM user u JOIN credit_card cc ON u.UserId = cc.UserId GROUP BY u.UserId ORDER BY COUNT(cc.CardId) DESC LIMIT 1;",
     'SQLResult': "Result of the SQL query",
     'Answer': 'User with the most number of credit cards'},
    {'Question': "Find the total amount spent on fuel by all users.",
     'SQLQuery': "SELECT SUM(Amount) FROM transactions WHERE CategoryId = (SELECT CategoryId FROM category WHERE CategoryName = 'Fuel');",
     'SQLResult': "Result of the SQL query",
     'Answer': 'Total amount spent on fuel'},
    {'Question': "What is the oldest branch of the bank?",
     'SQLQuery': "SELECT BranchName FROM branch WHERE EstablishedDate = (SELECT MIN(EstablishedDate) FROM branch);",
     'SQLResult': "Result of the SQL query",
     'Answer': 'Oldest branch of the bank'},
    {'Question': "Find the user with the oldest credit card.",
     'SQLQuery': "SELECT u.UserName FROM user u JOIN credit_card cc ON u.UserId = cc.UserId WHERE cc.IssueDate = (SELECT MIN(IssueDate) FROM credit_card);",
     'SQLResult': "Result of the SQL query",
     'Answer': 'User with the oldest credit card'},
    {'Question': "How many rewards were earned for transactions above $1000?",
     'SQLQuery': "SELECT COUNT(r.RewardId) FROM reward r JOIN transactions t ON r.TransactionId = t.TransactionId WHERE t.Amount > 1000;",
     'SQLResult': "Result of the SQL query",
     'Answer': 'Number of rewards for transactions above $1000'},
    # {'Question': "List all users who have made a transaction in the last month.",
    #  'SQLQuery': "SELECT DISTINCT u.UserName FROM user u JOIN credit_card cc ON u.UserId = cc.UserId JOIN transactions t ON cc.CardId = t.CardId WHERE t.TransactionDate >= CURRENT_DATE - INTERVAL '1 month';",
    #  'SQLResult': "Result of the SQL query",
    #  'Answer': 'List of users with transactions in the last month'},
    {'Question': "Which category has the lowest average transaction amount?",
     'SQLQuery': "SELECT c.CategoryName FROM category c JOIN transactions t ON c.CategoryId = t.CategoryId GROUP BY c.CategoryId ORDER BY AVG(t.Amount) ASC LIMIT 1;",
     'SQLResult': "Result of the SQL query",
     'Answer': 'Category with the lowest average transaction amount'},
    {'Question': "What is the total outstanding balance for all credit cards?",
     'SQLQuery': "SELECT SUM(CurrentBalance) FROM credit_card;",
     'SQLResult': "Result of the SQL query",
     'Answer': 'Total outstanding balance for all credit cards'},
    {'Question': "How many branches have less than 20 employees?",
     'SQLQuery': "SELECT COUNT(*) FROM branch WHERE NumberOfEmployees < 20;",
     'SQLResult': "Result of the SQL query",
     'Answer': 'Number of branches with less than 20 employees'},
    {'Question': "List the top 3 users by total transaction amount.",
     'SQLQuery': "SELECT u.UserName FROM user u JOIN credit_card cc ON u.UserId = cc.UserId JOIN transactions t ON cc.CardId = t.CardId GROUP BY u.UserId ORDER BY SUM(t.Amount) DESC LIMIT 3;",
     'SQLResult': "Result of the SQL query",
     'Answer': 'Top 3 users by total transaction amount'},
    # {'Question': "Find the most popular transaction category.",
    #  'SQLQuery': "SELECT c.CategoryName FROM category c JOIN transactions t ON c.CategoryId = t.CategoryId GROUP BY c.CategoryId ORDER BY COUNT(*) DESC LIMIT 1;",
    #  'SQLResult': "Result of the SQL query",
    #  'Answer': 'Most popular transaction category'},
    # {'Question': "How many credit cards were issued in the last year?",
    #  'SQLQuery': "SELECT COUNT(*) FROM credit_card WHERE IssueDate >= CURRENT_DATE - INTERVAL '1 year';",
    #  'SQLResult': "Result of the SQL query",
    #  'Answer': 'Number of credit cards issued in the last year'},
    {'Question': "What is the average number of transactions per user?",
     'SQLQuery': "SELECT AVG(TransactionCount) FROM (SELECT COUNT(*) as TransactionCount FROM transactions GROUP BY CardId) as TransactionPerUser;",
     'SQLResult': "Result of the SQL query",
     'Answer': 'Average number of transactions per user'}
])

In [None]:
len(sql_test)

In [None]:
for item in sql_test:
  qns = item['Question']
  print(qns)
  sql = item['SQLQuery']
  print(sql)
  print(db_chain2.run(sql))
  try:
    print(db_chain.run(qns))
  except:
    print('failed db_chain')
  try:
    print(db_chain2.run(qns))
  except:
    print('failed db_chain2')
    # item['Answer'] = ans
    # few_shots.append(item)
  print('-----------------------------')