# SQLQueryChain - Mondial - GPT 4


In [None]:

from langchain.chat_models import ChatOpenAI
from langchain.chains import create_sql_query_chain
from urllib.parse import quote  
from langchain.callbacks import get_openai_callback

from dotenv import load_dotenv
import os
import sys
import json
import time
load_dotenv()

experiment_path = '..\..'
path = os.path.abspath('')
module_path = os.path.join(path, experiment_path)
print(module_path)
if module_path not in sys.path:
    sys.path.append(module_path+"\\functions")


from sqldatabase_langchain_utils import SQLDatabaseLangchainUtils


# Schema

In [12]:
SCHEMA = 'mondial_gpt'
PREFIX = 'mondial'

FILE_NAME_RESULT = f"results/12_sql_queries_chatgpt4_{SCHEMA}_fk.json"

In [13]:
def save_queries(queries):
    data = {"queries":queries}
    with open(FILE_NAME_RESULT, "w") as arquivo_json:
        json.dump(data, arquivo_json, indent=4) 
        
def read_queries():
    with open(FILE_NAME_RESULT, encoding='utf-8', errors='ignore') as json_data:
        data = json.load(json_data, strict=False)
    queries = data["queries"]
    return queries



## Conexão

In [None]:
json_file_path = f"../../datasets/{SCHEMA}_db_connection.json"
with open(json_file_path, encoding='utf-8', errors='ignore') as json_data:
    db_connection = json.load(json_data, strict=False)

db_connection

### Utilizando o SQLDatabase para pegar todas as informações do database

In [15]:
db = SQLDatabaseLangchainUtils(db_connection=db_connection)

exclusao = [
    f"{SCHEMA}_tmdp",
    f"{SCHEMA}_tmdpmap",
    f"{SCHEMA}_tmds",
    f"{SCHEMA}_tmjmap",
    f"{SCHEMA}_tpv",
    f"{SCHEMA}_tmdc",
    f"{SCHEMA}_tmdcmap",
    f"{SCHEMA}_tmdej",
    f"{SCHEMA}_log_action",
    f"{SCHEMA}_log_error",
    f"{SCHEMA}_favorite_item", 
    f"{SCHEMA}_favorite_query",
    f"{SCHEMA}_favorite_tag",
    f"{SCHEMA}_favorite_tag_item",
    f"{SCHEMA}_favorite_visualization",
    f"{SCHEMA}_dashboard",
    f"{SCHEMA}_history",
    "teste_cliente",
    "teste_fornecedor",
    "teste_funcionario"
]

include_tables = [s for s in db.get_table_names() if not s.startswith(PREFIX) and s not in exclusao]
db = SQLDatabaseLangchainUtils(db_connection=db_connection, include_tables=include_tables)
db.get_table_names()

['airport',
 'borders',
 'city',
 'citylocalname',
 'cityothername',
 'citypops',
 'continent',
 'country',
 'countrylocalname',
 'countryothername',
 'countrypops',
 'desert',
 'economy',
 'encompasses',
 'ethnicgroup',
 'geo_desert',
 'geo_estuary',
 'geo_island',
 'geo_lake',
 'geo_mountain',
 'geo_river',
 'geo_sea',
 'geo_source',
 'island',
 'islandin',
 'ismember',
 'lake',
 'lakeonisland',
 'language',
 'located',
 'locatedon',
 'mergeswith',
 'mountain',
 'mountainonisland',
 'organization',
 'politics',
 'population',
 'province',
 'provincelocalname',
 'provinceothername',
 'provpops',
 'religion',
 'river',
 'riveronisland',
 'riverthrough',
 'sea']

In [16]:
len(db.get_table_names())

46

## Criando o prompt

In [17]:
from langchain.prompts.prompt import PromptTemplate

f = open(f"prompts/prompt_template_sql_query_chain.txt", "r")
prompt_template = f.read()
f.close()


PROMPT = PromptTemplate(
    input_variables=["input", "table_info", "top_k"], template=prompt_template
)

print(PROMPT)

input_variables=['input', 'table_info', 'top_k'] output_parser=None partial_variables={} template='You are an Oracle SQL expert. Given an input question, first create a syntactically correct Oracle SQL query to run, then look at the results of the query and return the answer to the input question.\nUnless the user specifies in the question a specific number of examples to obtain, don\'t query for at {top_k} most results or any using the FETCH FIRST n ROWS ONLY clause as per Oracle SQL. You can order the results to return the most informative data in the database.\nNever query for all columns from a table. You must query only the columns that are needed to answer the question.\nPay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.\nPay attention to use TRUNC(SYSDATE) function to get the current date, if the question involves "today". \n\nSome hints:\n- Don\

## Criando o Chain para gerar SQL

In [None]:
query_chain  = create_sql_query_chain(ChatOpenAI(temperature=0, model_name='gpt-4'), db.db, prompt=PROMPT)
query_chain 



## Preparando as consultas em linguagem natural para rodar no LLM

In [19]:

json_file_path = f"../../datasets/{PREFIX}/queries_{PREFIX}.json"
with open(json_file_path, encoding='utf-8', errors='ignore') as json_data:
    queries = json.load(json_data, strict=False)
queries = queries['queries']
queries

[{'id': '1',
  'question': 'What is the area of Thailand?',
  'query_string': '',
  'type': 'simple'},
 {'id': '2',
  'question': 'What are the provinces with an area greater than 10000?',
  'query_string': '',
  'type': 'simple'},
 {'id': '3',
  'question': 'What are the languages spoken in Poland?',
  'query_string': '',
  'type': 'medium'},
 {'id': '4',
  'question': 'How deep is Lake Kariba?',
  'query_string': '',
  'type': 'simple'},
 {'id': '5',
  'question': 'What is the total of provinces of Netherlands?',
  'query_string': '',
  'type': 'complex'},
 {'id': '6',
  'question': 'What is the percentage of religious people are hindu in thailand?',
  'query_string': '',
  'type': 'complex'},
 {'id': '7',
  'question': 'List the number of provinces each river flows through.',
  'query_string': '',
  'type': 'medium'},
 {'id': '8',
  'question': 'Find all countries that became independent between 8/1/1910 and 8/1/1950.',
  'query_string': '',
  'type': 'complex'},
 {'id': '9',
  'que

# Rodando as consultas no LLM para gerar SQL

In [11]:


# A cada X consultas, vai ser gerado um delay de 10s para evitar o bloqueio da API.

count = 0

for instance in queries:
    with get_openai_callback() as cb:
        start_time = time.time()
        sql_query = query_chain.invoke({"question":instance["question"]})
        end_time = time.time()
        instance["query_string"] = sql_query
        instance['total_tokens'] = cb.total_tokens
        instance['prompt_tokens'] = cb.prompt_tokens
        instance['completion_tokens'] = cb.completion_tokens
        instance['total_cost'] = cb.total_cost
        instance['time'] = end_time - start_time
        print(instance['id'], instance['question'], sql_query, instance['time'], instance['total_cost'])
    save_queries(queries)
    time.sleep(2)
queries

  for tbl in self._metadata.sorted_tables


1 What is the area of Thailand? SELECT area FROM country WHERE name = 'Thailand' 3.32186222076416 0.17139
2 What are the provinces with an area greater than 10000? SELECT name, area FROM province WHERE area > 10000 FETCH FIRST 5 ROWS ONLY 3.8562686443328857 0.17205
3 What are the languages spoken in Poland? SELECT name FROM language WHERE country = (SELECT code FROM country WHERE name = 'Poland') 3.2713685035705566 0.1719
4 How deep is Lake Kariba? SELECT depth FROM lake WHERE name = 'Lake Kariba'; 2.834017038345337 0.17145
5 What is the total of provinces of Netherlands? SELECT COUNT(name) FROM province WHERE country = (SELECT code FROM country WHERE name = 'Netherlands') 2.600212574005127 0.17204999999999998
6 What is the percentage of religious people are hindu in thailand? SELECT percentage FROM religion WHERE country = 'THA' AND name = 'Hindu' 2.238055467605591 0.17208
7 List the number of provinces each river flows through. SELECT river, COUNT(province) AS num_provinces
FROM geo_

[{'id': '1',
  'question': 'What is the area of Thailand?',
  'query_string': "SELECT area FROM country WHERE name = 'Thailand'",
  'type': 'simple',
  'total_tokens': 5702,
  'prompt_tokens': 5691,
  'completion_tokens': 11,
  'total_cost': 0.17139,
  'time': 3.32186222076416},
 {'id': '2',
  'question': 'What are the provinces with an area greater than 10000?',
  'query_string': 'SELECT name, area FROM province WHERE area > 10000 FETCH FIRST 5 ROWS ONLY',
  'type': 'simple',
  'total_tokens': 5716,
  'prompt_tokens': 5697,
  'completion_tokens': 19,
  'total_cost': 0.17205,
  'time': 3.8562686443328857},
 {'id': '3',
  'question': 'What are the languages spoken in Poland?',
  'query_string': "SELECT name FROM language WHERE country = (SELECT code FROM country WHERE name = 'Poland')",
  'type': 'medium',
  'total_tokens': 5711,
  'prompt_tokens': 5692,
  'completion_tokens': 19,
  'total_cost': 0.1719,
  'time': 3.2713685035705566},
 {'id': '4',
  'question': 'How deep is Lake Kariba?

## Prompt Gerado pelo Langchain

In [12]:
sql_query_chain_prompt = query_chain.middle[0].template.format(table_info=db.get_table_info(), top_k=0, input="{input}")
print(sql_query_chain_prompt)

You are an Oracle SQL expert. Given an input question, first create a syntactically correct Oracle SQL query to run, then look at the results of the query and return the answer to the input question.
Unless the user specifies in the question a specific number of examples to obtain, don't query for at 0 most results or any using the FETCH FIRST n ROWS ONLY clause as per Oracle SQL. You can order the results to return the most informative data in the database.
Never query for all columns from a table. You must query only the columns that are needed to answer the question.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use TRUNC(SYSDATE) function to get the current date, if the question involves "today". 

Some hints:
- Don't use dounle quotes in column name
- 
Example:
`SELECT "column_name" FROM table` should be `SELECT column_name FR

#### Fixing queries

In [20]:
to_fix = [40,59,62,72,85,99]
for pos in to_fix:
    instance = queries[pos]
    q = read_queries()
    with get_openai_callback() as cb:
            start_time = time.time()
            sql_query = query_chain.invoke({"question":instance["question"]})
            end_time = time.time()
            instance["query_string"] = sql_query
            instance['total_tokens'] = cb.total_tokens
            instance['prompt_tokens'] = cb.prompt_tokens
            instance['completion_tokens'] = cb.completion_tokens
            instance['total_cost'] = cb.total_cost
            instance['time'] = end_time - start_time
            q[pos] = instance
            print(instance['id'], instance['question'], instance["query_string"], instance['time'], instance['total_cost'])
            save_queries(q)

41 How many countries that are close to the Mediterranean Sea? SELECT COUNT(DISTINCT country) FROM geo_sea WHERE sea = 'Mediterranean Sea' 1.341437578201294 0.17204999999999998
60  List the names of capital cities which are the base for organizations in alphabetical order SELECT DISTINCT city.name 
FROM city 
JOIN organization 
ON city.name = organization.city 
ORDER BY city.name ASC 3.0097904205322266 0.17234999999999998
63 Show the inflation rate of countries that are washed by the Arabian Sea SELECT c.name, e.inflation
FROM country c
JOIN economy e ON c.code = e.country
JOIN geo_sea gs ON c.code = gs.country
WHERE gs.sea = 'Arabian Sea' 3.581909418106079 0.17357999999999998
73 What area is the largest continent? SELECT name, MAX(area) as max_area FROM continent 2.2166857719421387 0.17132999999999998
86 What are the 3 airports with the largest name? SELECT name FROM airport ORDER BY LENGTH(name) DESC FETCH FIRST 3 ROWS ONLY 2.4312219619750977 0.17180999999999996
100 What is the perce