<a href="https://colab.research.google.com/github/loni9164/text_sql/blob/main/sql_starcoder_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Create PostgreSQL

In [None]:
!apt-get -qq update
!apt-get -qq -y install postgresql
!service postgresql start

In [None]:
!pip install psycopg2-binary

In [3]:
!service postgresql restart

 * Restarting PostgreSQL 14 database server
   ...done.


In [10]:
!sudo -u postgres psql -U postgres -c "ALTER USER postgres PASSWORD '12345';"

ALTER ROLE


In [11]:
import psycopg2
conn = psycopg2.connect(
    host="localhost",
    user="postgres",
    password="12345",  # Use the correct password here
    dbname="postgres"
)
conn.autocommit = True
cursor = conn.cursor()

In [12]:
# Creating a new database
cursor.execute("CREATE DATABASE credit_card_system")
cursor.close()
conn.close()

In [13]:
# Connecting to the new database
conn = psycopg2.connect(
    host="localhost",
    user="postgres",
    password="12345",
    dbname="credit_card_system"
)
cursor = conn.cursor()

In [14]:
# SQL statements to create tables
# DDL for the seven tables of the credit-card schema. The list order matters:
# each table appears after the tables that its FOREIGN KEY clauses reference
# (branch before users, users before credit_card, credit_card before
# transactions and credit_card_financial, category before transactions,
# transactions before reward).
# NOTE(review): monetary columns are declared REAL (a floating-point type) —
# confirm that rounding error is acceptable for balances and amounts.
# NOTE(review): cvv is INTEGER, so a value with leading zeros would lose them —
# confirm against the CSV data.
create_table_statements = [
    """
    CREATE TABLE branch (
        branch_id INTEGER,
        branch_name TEXT,
        branch_address TEXT,
        branch_phone TEXT,
        branch_manager TEXT,
        branch_email TEXT,
        established_date DATE,
        number_of_employees INTEGER,
        PRIMARY KEY (branch_id)
    )
    """,
    """
    CREATE TABLE category (
        category_id INTEGER,
        category_name TEXT,
        PRIMARY KEY (category_id)
    )
    """,

    """
    CREATE TABLE users (
        user_id INTEGER,
        user_name TEXT,
        user_email TEXT,
        user_address TEXT,
        user_phone TEXT,
        date_of_birth DATE,
        registration_date DATE,
        status TEXT,
        branch_id INTEGER,
        PRIMARY KEY (user_id),
        FOREIGN KEY(branch_id) REFERENCES branch (branch_id)
    )
    """,

    """
    CREATE TABLE credit_card (
        card_id INTEGER,
        user_id INTEGER,
        card_number TEXT,
        card_type TEXT,
        expiry_date DATE,
        cvv INTEGER,
        issue_date DATE,
        credit_limit REAL,
        current_balance REAL,
        statement_balance REAL,
        PRIMARY KEY (card_id),
        FOREIGN KEY(user_id) REFERENCES users (user_id)
    )
    """,

    """
    CREATE TABLE transactions (
        transaction_id INTEGER,
        card_id INTEGER,
        transaction_date DATE,
        amount REAL,
        merchant TEXT,
        category_id INTEGER,
        transaction_type TEXT,
        transaction_status TEXT,
        description TEXT,
        PRIMARY KEY (transaction_id),
        FOREIGN KEY(card_id) REFERENCES credit_card (card_id),
        FOREIGN KEY(category_id) REFERENCES category (category_id)
    )
    """,

    """
    CREATE TABLE credit_card_financial (
        financial_id INTEGER,
        card_id INTEGER,
        overdue_charges REAL,
        loan_amount REAL,
        emi_amount REAL,
        emi_due_date DATE,
        interest_rate REAL,
        payment_due_date DATE,
        minimum_payment REAL,
        PRIMARY KEY (financial_id),
        FOREIGN KEY(card_id) REFERENCES credit_card (card_id)
    )
    """,

    """
    CREATE TABLE reward (
        reward_id INTEGER,
        transaction_id INTEGER,
        points_earned INTEGER,
        points_redeemed INTEGER,
        current_balance INTEGER,
        PRIMARY KEY (reward_id),
        FOREIGN KEY(transaction_id) REFERENCES transactions (transaction_id)
    )
    """
]


# Execute every CREATE TABLE statement inside one transaction. If any
# statement fails, roll back so the connection is not left in an
# aborted-transaction state (which would make every later cursor.execute()
# on this connection fail), then re-raise so the error stays visible.
try:
    for statement in create_table_statements:
        cursor.execute(statement)
except Exception:
    conn.rollback()
    raise
else:
    conn.commit()

In [15]:
import pandas as pd

# Function to load data from CSV to a table
def load_csv_to_table(csv_file_path, table_name, cur=None):
    """Insert every row of a CSV file into an existing table.

    Rows are inserted positionally (``INSERT INTO <table> VALUES (...)``), so
    the CSV columns must be in the same order as the table's columns. Nothing
    is committed here; the caller commits the transaction.

    Args:
        csv_file_path: Path of the CSV file; read with ``pandas.read_csv``.
        table_name: Target table, optionally schema-qualified with dots. It is
            interpolated into the SQL text, so only plain identifiers are
            accepted.
        cur: DB-API cursor to execute on. Defaults to the notebook-level
            ``cursor``.

    Raises:
        ValueError: If ``table_name`` is not a plain (dotted) identifier.
    """
    # The table name cannot be sent as a bound parameter, so reject anything
    # that is not a simple identifier before building the SQL string.
    if not all(part.isidentifier() for part in table_name.split(".")):
        raise ValueError("Invalid table name: {!r}".format(table_name))

    db_cursor = cursor if cur is None else cur
    data = pd.read_csv(csv_file_path)
    # Loop-invariant: build the statement once instead of once per row.
    insert_query = "INSERT INTO {} VALUES %s".format(table_name)

    # itertuples keeps each column's own dtype (iterrows coerces a whole row to
    # one common dtype) and yields plain Python scalars.
    for row in data.itertuples(index=False, name=None):
        # Empty CSV fields are read as NaN; send them as NULL instead.
        values = tuple(None if pd.isna(value) else value for value in row)
        db_cursor.execute(insert_query, (values,))

In [None]:
!git clone https://github.com/loni9164/text_sql.git

In [17]:
# Populate the tables in dependency order (referenced tables before the tables
# that reference them). Each table is fed from the CSV file of the same name
# in the cloned repository.
for table in (
    'branch',
    'category',
    'users',
    'credit_card',
    'transactions',
    'credit_card_financial',
    'reward',
):
    load_csv_to_table('text_sql/csv_files/{}.csv'.format(table), table)

conn.commit()

In [18]:
def query_db(query, connection=None):
  """Run a SQL statement and return all rows of its result.

  Whatever transaction is currently open on the connection is rolled back
  first, so a statement that failed earlier cannot block this one. Note that
  this also discards any uncommitted changes on that connection.

  Args:
    query: SQL text to execute; it is sent to the server unchanged.
    connection: DB-API connection to use. Defaults to the notebook-level
      ``conn``.

  Returns:
    The rows returned by ``cursor.fetchall()``.
  """
  active_conn = conn if connection is None else connection
  # Use the driver's rollback API instead of executing a literal "ROLLBACK".
  active_conn.rollback()
  # The context manager closes the cursor on exit; previously a new cursor
  # was opened on every call and never closed.
  with active_conn.cursor() as cur:
    cur.execute(query)
    return cur.fetchall()

In [None]:
# query_db('SELECT user_id FROM users;')

# Installation

In [19]:
import sqlite3
import time
import pickle
import re

In [None]:
!pip install langchain langchain-experimental
!pip install -q  langchain
!pip install sentence-transformers
!pip install chromadb

In [21]:
# from langchain.llms import CTransformers
from langchain import PromptTemplate, LLMChain
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.utilities import SQLDatabase
from langchain.prompts import PromptTemplate
from langchain_experimental.sql import SQLDatabaseChain

# DB connection

In [22]:
from urllib.parse import quote_plus

# Connection settings for the PostgreSQL database created earlier in this
# notebook.
db_user = "postgres"
db_password = "12345"
db_host = "localhost"
db_name = "credit_card_system"

# Percent-encode the credentials: characters such as '@', ':' or '/' in a user
# name or password would otherwise be parsed as URL delimiters. For the values
# above the resulting string is unchanged.
connection_string = (
    f"postgresql://{quote_plus(db_user)}:{quote_plus(db_password)}"
    f"@{db_host}/{db_name}"
)
db = SQLDatabase.from_uri(connection_string)
table_info = db.table_info
# print(table_info)

# Load LLM model

In [None]:
!pip3 install transformers optimum
!pip3 install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu121/  # Use cu117 if on CUDA 11.7

In [24]:
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_name_or_path = "TheBloke/sqlcoder2-GPTQ"
# To use a different branch, change revision
# For example: revision="gptq-4bit-128g-actorder_True"
model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
                                             device_map="auto",
                                             trust_remote_code=False,
                                             revision="main")

config.json:   0%|          | 0.00/1.44k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/9.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

In [25]:
from auto_gptq import exllama_set_max_input_length
model = exllama_set_max_input_length(model, max_input_length=3000)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

tokenizer_config.json:   0%|          | 0.00/4.04k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/777k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/442k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.06M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/532 [00:00<?, ?B/s]

In [26]:
# print("\n\n*** Generate:")

# input_ids = tokenizer(prompt_template, return_tensors='pt').input_ids.cuda()
# output = model.generate(inputs=input_ids, temperature=0, do_sample=True, top_p=0.95, top_k=40, max_new_tokens=512)
# print(tokenizer.decode(output[0]))

# Inference can also be done using transformers' pipeline

# print("*** Pipeline:")
# Text-generation pipeline built from the model and tokenizer loaded above;
# `pipe` is what the SQL-generation helpers and the LangChain wrapper call.
# NOTE(review): do_sample=True means the generated SQL may differ between runs
# even at temperature=0.1 — confirm that non-deterministic output is intended
# for the comparisons later in the notebook.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    do_sample=True,
    temperature=0.1,
    top_p=0.95,
    top_k=40,
    repetition_penalty=1.1
)

# print(pipe(prompt_template)[0]['generated_text'])

# Query generation without LangChain

In [27]:
def update_prompt_template(prompt):
  """Build the text-to-SQL prompt for one natural-language question.

  The question is inserted twice: once in the task statement and once in the
  sentence that introduces the answer. The database schema is embedded
  verbatim as CREATE TABLE statements, and the prompt ends with an opening
  "```sql" fence.

  Args:
    prompt: The question to answer, inserted into the text unchanged.

  Returns:
    The complete prompt string.
  """
  # NOTE(review): this schema text repeats the DDL in create_table_statements;
  # the two must be kept in sync by hand.
  template = f'''
  ## Task
  Generate a SQL query to answer the following question:
  `{prompt}`

  ### Database Schema
  This query will run on a database whose schema is represented in this string:
  CREATE TABLE branch (
        branch_id INTEGER,
        branch_name TEXT,
        branch_address TEXT,
        branch_phone TEXT,
        branch_manager TEXT,
        branch_email TEXT,
        established_date DATE,
        number_of_employees INTEGER,
        PRIMARY KEY (branch_id)
    )


    CREATE TABLE category (
        category_id INTEGER,
        category_name TEXT,
        PRIMARY KEY (category_id)
    )

    CREATE TABLE users (
        user_id INTEGER,
        user_name TEXT,
        user_email TEXT,
        user_address TEXT,
        user_phone TEXT,
        date_of_birth DATE,
        registration_date DATE,
        status TEXT,
        branch_id INTEGER,
        PRIMARY KEY (user_id),
        FOREIGN KEY(branch_id) REFERENCES branch (branch_id)
    )


    CREATE TABLE credit_card (
        card_id INTEGER,
        user_id INTEGER,
        card_number TEXT,
        card_type TEXT,
        expiry_date DATE,
        cvv INTEGER,
        issue_date DATE,
        credit_limit REAL,
        current_balance REAL,
        statement_balance REAL,
        PRIMARY KEY (card_id),
        FOREIGN KEY(user_id) REFERENCES users (user_id)
    )


    CREATE TABLE transactions (
        transaction_id INTEGER,
        card_id INTEGER,
        transaction_date DATE,
        amount REAL,
        merchant TEXT,
        category_id INTEGER,
        transaction_type TEXT,
        transaction_status TEXT,
        description TEXT,
        PRIMARY KEY (transaction_id),
        FOREIGN KEY(card_id) REFERENCES credit_card (card_id),
        FOREIGN KEY(category_id) REFERENCES category (category_id)
    )


    CREATE TABLE credit_card_financial (
        financial_id INTEGER,
        card_id INTEGER,
        overdue_charges REAL,
        loan_amount REAL,
        emi_amount REAL,
        emi_due_date DATE,
        interest_rate REAL,
        payment_due_date DATE,
        minimum_payment REAL,
        PRIMARY KEY (financial_id),
        FOREIGN KEY(card_id) REFERENCES credit_card (card_id)
    )

    CREATE TABLE reward (
        reward_id INTEGER,
        transaction_id INTEGER,
        points_earned INTEGER,
        points_redeemed INTEGER,
        current_balance INTEGER,
        PRIMARY KEY (reward_id),
        FOREIGN KEY(transaction_id) REFERENCES transactions (transaction_id)
    )

  ### SQL
  Given the database schema, here is the SQL query that answers `{prompt}`:
  ```sql
  '''
  return template


In [28]:
def gen_sql_wto_longchain(prompt):
  """Generate SQL for a question with the bare pipeline and execute it.

  The question is wrapped by ``update_prompt_template``, sent to ``pipe``,
  and the SQL extracted from the generated text is printed and then run
  through ``query_db``.

  Args:
    prompt: Natural-language question.

  Returns:
    Whatever ``query_db`` returns for the generated statement.
  """
  updated_prompt = update_prompt_template(prompt)
  generated_text = pipe(updated_prompt)[0]['generated_text']
  # The SQL is whatever follows the last "```sql" fence. If the model also
  # emits a closing "```", drop it and anything after it so that only the
  # statement itself is sent to the database.
  sql_query = generated_text.split("```sql")[-1].split("```")[0].strip()
  print(sql_query)
  return query_db(sql_query)

In [29]:
gen_sql_wto_longchain('how many users do we have')



SELECT COUNT(DISTINCT users.user_id) AS total_users FROM users;


[(1000,)]

# Test queries

In [30]:
# Benchmark questions asked of the credit-card database.
questions = [
    "What is the total number of active credit cards?",
    "How many transactions were made in the 'Shopping' category?",
    "Which user has the most number of credit cards?",
    "Find the total amount spent on fuel by all users.",
    "What is the oldest branch of the bank?",
    "Find the user with the oldest credit card.",
    "How many rewards were earned for transactions above $1000?",
    "Which category has the lowest average transaction amount?",
    "What is the total outstanding balance for all credit cards?",
    "How many branches have less than 20 employees?",
    "List the top 3 users by total transaction amount.",
    "What is the average number of transactions per user?",
]

# Expected answers, aligned by position with `questions`.
answers = [
    "6",
    "93",
    "Michael Baldwin",
    "23960.57",
    "Branch 3",
    "Gabrielle Anderson",
    "0",
    "Shopping",
    "14090.2296532",
    "2",
    "Gabrielle Anderson, Michael Baldwin",
    "166.66666666666666",
]

# One record per benchmark case. The 'qury_text' key spelling is what the
# cells that read test_data look up, so it is kept unchanged.
test_data = [
    dict(qury_text=question_text, result=expected)
    for question_text, expected in zip(questions, answers)
]

# LangChain

## LangChain template

In [31]:
# Prompt template text with an `{input}` placeholder for the question.
# This must be a plain string, not an f-string: with an f-prefix Python
# evaluates `{input}` immediately and embeds the repr of the built-in `input`
# function, so no placeholder is left for the prompt template to fill in.
template = '''
  ## Task
  Generate a SQL query to answer the following question:
  `{input}`

  ### Database Schema
  This query will run on a database whose schema is represented in this string:
  CREATE TABLE branch (
        branch_id INTEGER,
        branch_name TEXT,
        branch_address TEXT,
        branch_phone TEXT,
        branch_manager TEXT,
        branch_email TEXT,
        established_date DATE,
        number_of_employees INTEGER,
        PRIMARY KEY (branch_id)
    )


    CREATE TABLE category (
        category_id INTEGER,
        category_name TEXT,
        PRIMARY KEY (category_id)
    )

    CREATE TABLE users (
        user_id INTEGER,
        user_name TEXT,
        user_email TEXT,
        user_address TEXT,
        user_phone TEXT,
        date_of_birth DATE,
        registration_date DATE,
        status TEXT,
        branch_id INTEGER,
        PRIMARY KEY (user_id),
        FOREIGN KEY(branch_id) REFERENCES branch (branch_id)
    )


    CREATE TABLE credit_card (
        card_id INTEGER,
        user_id INTEGER,
        card_number TEXT,
        card_type TEXT,
        expiry_date DATE,
        cvv INTEGER,
        issue_date DATE,
        credit_limit REAL,
        current_balance REAL,
        statement_balance REAL,
        PRIMARY KEY (card_id),
        FOREIGN KEY(user_id) REFERENCES users (user_id)
    )


    CREATE TABLE transactions (
        transaction_id INTEGER,
        card_id INTEGER,
        transaction_date DATE,
        amount REAL,
        merchant TEXT,
        category_id INTEGER,
        transaction_type TEXT,
        transaction_status TEXT,
        description TEXT,
        PRIMARY KEY (transaction_id),
        FOREIGN KEY(card_id) REFERENCES credit_card (card_id),
        FOREIGN KEY(category_id) REFERENCES category (category_id)
    )


    CREATE TABLE credit_card_financial (
        financial_id INTEGER,
        card_id INTEGER,
        overdue_charges REAL,
        loan_amount REAL,
        emi_amount REAL,
        emi_due_date DATE,
        interest_rate REAL,
        payment_due_date DATE,
        minimum_payment REAL,
        PRIMARY KEY (financial_id),
        FOREIGN KEY(card_id) REFERENCES credit_card (card_id)
    )

    CREATE TABLE reward (
        reward_id INTEGER,
        transaction_id INTEGER,
        points_earned INTEGER,
        points_redeemed INTEGER,
        current_balance INTEGER,
        PRIMARY KEY (reward_id),
        FOREIGN KEY(transaction_id) REFERENCES transactions (transaction_id)
    )

  ### SQL
  Given the database schema, here is the SQL query that answers `{input}`:
  ```sql
  '''

# Integration

In [32]:
from langchain.llms import HuggingFacePipeline
from langchain import PromptTemplate, LLMChain

In [33]:
llm = HuggingFacePipeline(pipeline=pipe)

## Dummy template

In [34]:
# Custom prompt text handed to SQLDatabaseChain via PromptTemplate in the
# "Chain" cells. The schema is hard-coded in the text and the only placeholder
# is {input}.
# NOTE(review): "Use the following format:" is not followed by any format
# description — confirm whether a Question/SQLQuery/SQLResult/Answer block was
# dropped by accident.
# NOTE(review): a later cell imports langchain's own PROMPT_SUFFIX under this
# same name, which replaces this value for any cell run afterwards.
PROMPT_SUFFIX = """Given an input question, first create a syntactically correct PostgreSQL query to run, then look at the results of the query and return the answer.
Use the following format:

Only use the following tables:

	CREATE TABLE branch (
        branch_id INTEGER,
        branch_name TEXT,
        branch_address TEXT,
        branch_phone TEXT,
        branch_manager TEXT,
        branch_email TEXT,
        established_date DATE,
        number_of_employees INTEGER,
        PRIMARY KEY (branch_id)
    )


    CREATE TABLE category (
        category_id INTEGER,
        category_name TEXT,
        PRIMARY KEY (category_id)
    )

    CREATE TABLE users (
        user_id INTEGER,
        user_name TEXT,
        user_email TEXT,
        user_address TEXT,
        user_phone TEXT,
        date_of_birth DATE,
        registration_date DATE,
        status TEXT,
        branch_id INTEGER,
        PRIMARY KEY (user_id),
        FOREIGN KEY(branch_id) REFERENCES branch (branch_id)
    )


    CREATE TABLE credit_card (
        card_id INTEGER,
        user_id INTEGER,
        card_number TEXT,
        card_type TEXT,
        expiry_date DATE,
        cvv INTEGER,
        issue_date DATE,
        credit_limit REAL,
        current_balance REAL,
        statement_balance REAL,
        PRIMARY KEY (card_id),
        FOREIGN KEY(user_id) REFERENCES users (user_id)
    )


    CREATE TABLE transactions (
        transaction_id INTEGER,
        card_id INTEGER,
        transaction_date DATE,
        amount REAL,
        merchant TEXT,
        category_id INTEGER,
        transaction_type TEXT,
        transaction_status TEXT,
        description TEXT,
        PRIMARY KEY (transaction_id),
        FOREIGN KEY(card_id) REFERENCES credit_card (card_id),
        FOREIGN KEY(category_id) REFERENCES category (category_id)
    )


    CREATE TABLE credit_card_financial (
        financial_id INTEGER,
        card_id INTEGER,
        overdue_charges REAL,
        loan_amount REAL,
        emi_amount REAL,
        emi_due_date DATE,
        interest_rate REAL,
        payment_due_date DATE,
        minimum_payment REAL,
        PRIMARY KEY (financial_id),
        FOREIGN KEY(card_id) REFERENCES credit_card (card_id)
    )


    CREATE TABLE reward (
        reward_id INTEGER,
        transaction_id INTEGER,
        points_earned INTEGER,
        points_redeemed INTEGER,
        current_balance INTEGER,
        PRIMARY KEY (reward_id),
        FOREIGN KEY(transaction_id) REFERENCES transactions (transaction_id)
    )

Question: {input}"""

## Chain

In [49]:
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, return_sql=False,
                                     prompt=PromptTemplate(input_variables=["input", "table_info"],
                                     template=PROMPT_SUFFIX))

In [None]:
db_chain.run('what is the name of user for userid 2')

In [None]:
db_chain.run(test_data[6]['qury_text'])

In [None]:
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, return_sql=False,
                                     prompt=PromptTemplate(input_variables=["input", "table_info"],
                                     template=PROMPT_SUFFIX))

db_chain.run('how many cards do we have')

In [None]:
db_chain3 = SQLDatabaseChain.from_llm(llm, db, verbose=True, return_sql=False)


db_chain3.run('what is the name of user for userid 2')

In [40]:
gen_sql_wto_longchain('how many cards do we have')

SELECT COUNT(DISTINCT card_id) AS number_of_cards FROM credit_card;


[(6,)]

In [None]:
db_chain.run("How many employees are there?")

# Test sqlcoder without few-shot learning

In [50]:
db_chain.run('how many cards do we have')



[1m> Entering new SQLDatabaseChain chain...[0m
how many cards do we have
SQLQuery:



[32;1m[1;3mSELECT COUNT(DISTINCT card_id) AS total_cards FROM credit_card;[0m
SQLResult: [33;1m[1;3m[(6,)][0m
Answer:



[32;1m[1;3m6[0m
[1m> Finished chain.[0m


'6'

# Comparing the base sqlcoder model with and without a custom prompt

In [None]:
res_lst = []
failed_cases = []

# For every benchmark case, print the question and expected answer, then the
# output of each of the three generators. A failure in one generator is
# reported and the comparison continues with the next one.
# `except Exception` (not a bare `except:`) is used so that interrupting the
# kernel still stops this long-running loop, and the error is printed so the
# reason for a failure is not lost.
for i, item in enumerate(test_data):
  print(i)
  print("Question:")
  question = item['qury_text']
  print(question)
  print('--------------------')
  print(item['result'])
  print('--------------------')
  print("Gen_sql_wto_longchain:")
  try:
    print(gen_sql_wto_longchain(question))
  except Exception as exc:
    print("Gen_sql_wto_longchain Failed:", repr(exc))
  print('--------------------')
  print("Longchain1")
  try:
    res = db_chain.run(question)
    print(res)
  except Exception as exc:
    print("Longchain1 Failed:", repr(exc))
  print('--------------------')
  print("Longchain3")
  try:
    res = db_chain3.run(question)
    print(res)
  except Exception as exc:
    # The label now matches the chain being run (it used to say "Longchain2").
    print("Longchain3 Failed:", repr(exc))

  print('***********************')
  print('\n')

In [116]:
# Load the few-shot examples from the cloned text_sql repository.
# NOTE(security): pickle.load can execute arbitrary code contained in the
# file; only load it from a source you trust.
with open('text_sql/few_shot_examples', 'rb') as f:
  few_shot_examples = pickle.load(f)

In [118]:
len(few_shot_examples)

18

In [119]:
few_shot_examples[0]

{'Question': 'List all the cards going to expire by sep 2026',
 'SQLQuery': "SELECT CardNumber FROM credit_card WHERE ExpiryDate <= '2026-09-30';",
 'SQLResult': [('5040000000001',),
  ('35300000000000000',),
  ('52300000000000000',)],
 'Answer': '5040000000001, 35300000000000000, 52300000000000000'}

In [None]:
# For each few-shot example, print the reference query and answer, then try
# three ways of producing and running a query for the same question.
# NOTE(review): `few_shots`, `format_to_sqlite`, `llm_pipe` and `chain` are not
# defined in any cell visible in this notebook — confirm they are created
# before this cell is run on a fresh kernel.
# Each attempt catches Exception (not a bare `except:`) so that interrupting
# the kernel still stops the loop, and prints the error so the reason for a
# failure is visible.
for i, item in enumerate(few_shots):
  print(i)
  question = item['Question']
  print('Question:')
  print(question)
  print('------------------------------------------')
  print('SQL query:')
  print(item['SQLQuery'])
  print('------------------------------------------')
  print('Expected result:')
  print(item['Answer'])
  print('------------------------------------------')
  print('sqlcoder_34b:')
  try:
    sql_query_formatted = format_to_sqlite(item['sqlcoder_34b'])
    print(sql_query_formatted)
    print('------')
    print(query_db(sql_query_formatted))
    print('------------------------------------------')
  except Exception as exc:
    print('Failed', i, repr(exc))
  print('sqlcoder2 result:')
  try:
    prompt_template = update_prompt_template(question)
    sql_query = llm_pipe(prompt_template)
    sql_query = format_to_sqlite(sql_query)
    print(sql_query)
    print('------')
    print(query_db(sql_query))
  except Exception as exc:
    print('Failed', i, repr(exc))
  print("Longchain:")
  try:
    sql_query = chain.invoke({"question": question})
    sql_query = format_to_sqlite(sql_query)
    print(sql_query)
    print('------')
    print(db.run(sql_query))
    print('------')
  except Exception as exc:
    print('Failed', i, repr(exc))
  print('****************************************************************************************')
  print('\n')

In [None]:
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)

In [None]:
db_chain.run( "What is the total number of  credit cards?")

In [None]:
dir(db_chain)

In [None]:
# Generate a query for one question, format it, print it and run it.
# NOTE(review): `chain` and `format_to_sqlite` are not defined in any cell
# visible in this notebook — confirm they exist before running this cell.
response = chain.invoke({"question": "What is the total number of active credit cards?"})
# Format the query that was just generated. This used to pass an undefined
# name `query`, leaving `response` unused.
sql_query_formatted = format_to_sqlite(response)
print(sql_query_formatted)
query_db(sql_query_formatted)

In [None]:
db.run(sql_query_formatted)

In [None]:
# from langchain.llms import GooglePalm

# NOTE(security): never commit a literal API key, even in a comment. The key
# previously pasted on this line must be treated as leaked and revoked; read
# the key from the environment instead.
# api_key = os.environ["GOOGLE_API_KEY"]

# llm = GooglePalm(google_api_key=api_key, temperature=0.2)

In [None]:
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)

In [None]:
db_chain.run("Find the total amount spent on fuel by all users.")

# Few shot learning

We will use few-shot learning to fix the issues we have seen so far.

### Creating Semantic Similarity Based example selector

- create embedding on the few_shots
- Store the embeddings in Chroma DB
- Retrieve the most semantically similar examples from the vector store

In [None]:
from langchain.embeddings import HuggingFaceEmbeddings

embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')

to_vectorize = [" ".join(example.values()) for example in few_shots]

In [None]:
to_vectorize

In [None]:
from langchain.vectorstores import Chroma
vectorstore = Chroma.from_texts(to_vectorize, embeddings, metadatas=few_shots)

In [None]:
from langchain.prompts import SemanticSimilarityExampleSelector

example_selector = SemanticSimilarityExampleSelector(
    vectorstore=vectorstore,
    k=2,
)

In [None]:
example_selector.select_examples({"Question": "How many total rewards do we have"})

In [None]:
# Instruction text used as the `prefix` of the few-shot prompt built below.
# NOTE(review): the text tells the model to write SQLite (LIMIT "as per
# SQLite", date('now')), while `db` is opened on a postgresql:// URI earlier
# in this notebook — confirm which dialect is intended.
sqlite_prompt = """You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use date('now') function to get the current date, if the question involves "today". Ensure final query is converted into sqlite.

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here"""

In [None]:
from langchain.prompts import FewShotPromptTemplate
from langchain.chains.sql_database.prompt import PROMPT_SUFFIX

In [None]:
print(PROMPT_SUFFIX)

In [None]:
from langchain.chains.sql_database import prompt

In [None]:
dir(prompt)

In [None]:
from langchain.chains.sql_database.prompt import SQLITE_PROMPT

In [None]:
print(SQLITE_PROMPT)

In [None]:
print('You are a SQLite expert. Given an input question, first create a syntactically correct SQLite query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per SQLite. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.\nPay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\nPay attention to use date(\'now\') function to get the current date, if the question involves "today".\n\nUse the following format:\n\nQuestion: Question here\nSQLQuery: SQL Query to run\nSQLResult: Result of the SQLQuery\nAnswer: Final answer here\n\nOnly use the following tables:\n{table_info}\n\nQuestion: {input}')

### Setting up PromptTemplate using input variables

In [None]:
from langchain.prompts.prompt import PromptTemplate

example_prompt = PromptTemplate(
    input_variables=["Question", "SQLQuery", "SQLResult","Answer",],
    template="\nQuestion: {Question}\nSQLQuery: {SQLQuery}\nSQLResult: {SQLResult}\nAnswer: {Answer}",
)

In [None]:
print(PROMPT_SUFFIX)

In [None]:
few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=example_prompt,
    prefix=sqlite_prompt,
    suffix=PROMPT_SUFFIX,
    input_variables=["input", "table_info", "top_k"], #These variables are used in the prefix and suffix
)

In [None]:
llm = HuggingFacePipeline(pipeline=pipe)

In [None]:
llm('''

Convert this into SQLITE query
SELECT COUNT(DISTINCT TransactionId) AS NumberOfTransactions, SUM(Amount) AS TotalSpentFROM transactions JOIN credit_card ON transactions.cardid::integer = credit_card.cardid::integer WHERE credit_card.userid::integer = 1 AND transactions.categoryid::integer = (SELECT categoryid::integer FROM category WHERE categoryname::text ilike '%food%');
''')

In [None]:
db_chain2 = SQLDatabaseChain.from_llm(llm, db, verbose=True, prompt=few_shot_prompt)

In [None]:
db_chain2.run("Which user has the most number of credit cards?")

In [None]:
query_db('SELECT u."username", COUNT(cc."cardid") AS card_count FROM "user" u JOIN credit_card cc ON u."userid" = cc."userid" GROUP BY u."username" ORDER BY card_count DESC LIMIT 1')

In [None]:
resp = db_chain2.run("What is the total amount of transactions completed in the last month")

# Test

In [None]:
# Hand-written question / query pairs. 'SQLResult' and 'Answer' hold
# placeholder descriptions, not actual query results.
# NOTE(review): these queries reference a `user` table and CamelCase columns
# (e.g. ExpiryDate, CategoryId, UserName), whereas the tables created earlier
# in this notebook are `users` etc. with snake_case columns (expiry_date,
# category_id, user_name). They look written for a different schema — confirm
# before running them against this database.
# NOTE(review): the last query groups by CardId although the question asks for
# transactions per user — confirm that this is intended.
sql_test = ([
    {'Question': "What is the total number of active credit cards?",
     'SQLQuery': "SELECT COUNT(*) FROM credit_card WHERE ExpiryDate > CURRENT_DATE;",
     'SQLResult': "Result of the SQL query",
     'Answer': 'Total number of active credit cards'},
    {'Question': "How many transactions were made in the 'Shopping' category?",
     'SQLQuery': "SELECT COUNT(*) FROM transactions WHERE CategoryId = (SELECT CategoryId FROM category WHERE CategoryName = 'Shopping');",
     'SQLResult': "Result of the SQL query",
     'Answer': 'Number of transactions in Shopping category'},
    {'Question': "Which user has the most number of credit cards?",
     'SQLQuery': "SELECT u.UserName FROM user u JOIN credit_card cc ON u.UserId = cc.UserId GROUP BY u.UserId ORDER BY COUNT(cc.CardId) DESC LIMIT 1;",
     'SQLResult': "Result of the SQL query",
     'Answer': 'User with the most number of credit cards'},
    {'Question': "Find the total amount spent on fuel by all users.",
     'SQLQuery': "SELECT SUM(Amount) FROM transactions WHERE CategoryId = (SELECT CategoryId FROM category WHERE CategoryName = 'Fuel');",
     'SQLResult': "Result of the SQL query",
     'Answer': 'Total amount spent on fuel'},
    {'Question': "What is the oldest branch of the bank?",
     'SQLQuery': "SELECT BranchName FROM branch WHERE EstablishedDate = (SELECT MIN(EstablishedDate) FROM branch);",
     'SQLResult': "Result of the SQL query",
     'Answer': 'Oldest branch of the bank'},
    {'Question': "Find the user with the oldest credit card.",
     'SQLQuery': "SELECT u.UserName FROM user u JOIN credit_card cc ON u.UserId = cc.UserId WHERE cc.IssueDate = (SELECT MIN(IssueDate) FROM credit_card);",
     'SQLResult': "Result of the SQL query",
     'Answer': 'User with the oldest credit card'},
    {'Question': "How many rewards were earned for transactions above $1000?",
     'SQLQuery': "SELECT COUNT(r.RewardId) FROM reward r JOIN transactions t ON r.TransactionId = t.TransactionId WHERE t.Amount > 1000;",
     'SQLResult': "Result of the SQL query",
     'Answer': 'Number of rewards for transactions above $1000'},
    # {'Question': "List all users who have made a transaction in the last month.",
    #  'SQLQuery': "SELECT DISTINCT u.UserName FROM user u JOIN credit_card cc ON u.UserId = cc.UserId JOIN transactions t ON cc.CardId = t.CardId WHERE t.TransactionDate >= CURRENT_DATE - INTERVAL '1 month';",
    #  'SQLResult': "Result of the SQL query",
    #  'Answer': 'List of users with transactions in the last month'},
    {'Question': "Which category has the lowest average transaction amount?",
     'SQLQuery': "SELECT c.CategoryName FROM category c JOIN transactions t ON c.CategoryId = t.CategoryId GROUP BY c.CategoryId ORDER BY AVG(t.Amount) ASC LIMIT 1;",
     'SQLResult': "Result of the SQL query",
     'Answer': 'Category with the lowest average transaction amount'},
    {'Question': "What is the total outstanding balance for all credit cards?",
     'SQLQuery': "SELECT SUM(CurrentBalance) FROM credit_card;",
     'SQLResult': "Result of the SQL query",
     'Answer': 'Total outstanding balance for all credit cards'},
    {'Question': "How many branches have less than 20 employees?",
     'SQLQuery': "SELECT COUNT(*) FROM branch WHERE NumberOfEmployees < 20;",
     'SQLResult': "Result of the SQL query",
     'Answer': 'Number of branches with less than 20 employees'},
    {'Question': "List the top 3 users by total transaction amount.",
     'SQLQuery': "SELECT u.UserName FROM user u JOIN credit_card cc ON u.UserId = cc.UserId JOIN transactions t ON cc.CardId = t.CardId GROUP BY u.UserId ORDER BY SUM(t.Amount) DESC LIMIT 3;",
     'SQLResult': "Result of the SQL query",
     'Answer': 'Top 3 users by total transaction amount'},
    # {'Question': "Find the most popular transaction category.",
    #  'SQLQuery': "SELECT c.CategoryName FROM category c JOIN transactions t ON c.CategoryId = t.CategoryId GROUP BY c.CategoryId ORDER BY COUNT(*) DESC LIMIT 1;",
    #  'SQLResult': "Result of the SQL query",
    #  'Answer': 'Most popular transaction category'},
    # {'Question': "How many credit cards were issued in the last year?",
    #  'SQLQuery': "SELECT COUNT(*) FROM credit_card WHERE IssueDate >= CURRENT_DATE - INTERVAL '1 year';",
    #  'SQLResult': "Result of the SQL query",
    #  'Answer': 'Number of credit cards issued in the last year'},
    {'Question': "What is the average number of transactions per user?",
     'SQLQuery': "SELECT AVG(TransactionCount) FROM (SELECT COUNT(*) as TransactionCount FROM transactions GROUP BY CardId) as TransactionPerUser;",
     'SQLResult': "Result of the SQL query",
     'Answer': 'Average number of transactions per user'}
])

In [None]:
len(sql_test)