# Experimento: C3 - Mondial - Foreign Key - GPT 4

In [None]:

from langchain.chat_models import ChatOpenAI
from urllib.parse import quote  

from dotenv import load_dotenv
import time
import os
import sys
import json
load_dotenv()

from c3_clear_prompting import generate_clear_prompting
from c3_calibration_with_hints import generate_calibration_with_hints
from c3_generate_sql import generate_sql

# Relative location of the experiment root (two directory levels up). Built
# with os.path so the separator matches the host OS and no backslash has to be
# embedded in a non-raw string literal.
experiment_path = os.path.join(os.pardir, os.pardir)
path = os.path.abspath('')
module_path = os.path.join(path, experiment_path)
print(module_path)

# Make the experiment's "functions" directory importable. The membership test
# uses the exact directory that gets appended, so re-running this cell does
# not keep adding duplicate entries to sys.path.
functions_path = os.path.join(module_path, "functions")
if functions_path not in sys.path:
    sys.path.append(functions_path)


from dataset_utils import DatasetEvaluator
from sqldatabase_langchain_utils import SQLDatabaseLangchainUtils

In [12]:
# Schema identifier: used in the connection-settings file name, in the names
# of the excluded tables and in the results file name.
SCHEMA = 'mondial_gpt'
# Dataset prefix: used in the dataset folder / query file name and as the
# table-name prefix that is filtered out when selecting tables.
PREFIX = 'mondial'
# JSON file that save_queries() writes and read_queries() reads.
FILE_NAME_RESULT = f"results/6_c3_queries_chatgpt4_{SCHEMA}_fk.json"

In [13]:
def save_queries(queries):
    """Write ``queries`` to ``FILE_NAME_RESULT`` as JSON.

    The value is stored under a top-level "queries" key, indented with four
    spaces. The file is opened in "w" mode, so any previous content is
    replaced.
    """
    with open(FILE_NAME_RESULT, "w") as result_file:
        json.dump({"queries": queries}, result_file, indent=4)
        
def read_queries():
    """Return the value stored under "queries" in ``FILE_NAME_RESULT``.

    The file is decoded as UTF-8 with undecodable bytes dropped, and parsed
    with ``strict=False`` so control characters inside JSON strings are
    accepted.
    """
    with open(FILE_NAME_RESULT, encoding='utf-8', errors='ignore') as result_file:
        return json.load(result_file, strict=False)["queries"]


## Conexão com o banco

In [None]:
# Load the connection settings for SCHEMA from the experiment's datasets
# folder. Undecodable bytes are dropped and strict=False lets control
# characters inside JSON strings through.
json_file_path = f"{experiment_path}/datasets/{SCHEMA}_db_connection.json"
with open(json_file_path, encoding='utf-8', errors='ignore') as json_data:
    db_connection = json.load(json_data, strict=False)
# Bare expression: shown as the cell output in the notebook.
db_connection

In [15]:
# Connect without a table filter first, to be able to list the table names.
db = SQLDatabaseLangchainUtils(db_connection=db_connection)

# Table names that must not be offered to the model: schema-prefixed tables
# followed by three "teste_*" tables.
exclusao = [
    f"{SCHEMA}_{suffix}"
    for suffix in (
        "tmdp",
        "tmdpmap",
        "tmds",
        "tmjmap",
        "tpv",
        "tmdc",
        "tmdcmap",
        "tmdej",
        "log_action",
        "log_error",
        "favorite_item",
        "favorite_query",
        "favorite_tag",
        "favorite_tag_item",
        "favorite_visualization",
        "dashboard",
        "history",
    )
] + [
    "teste_cliente",
    "teste_fornecedor",
    "teste_funcionario",
]

# Keep every table that neither starts with PREFIX nor appears in exclusao.
include_tables = [
    table
    for table in db.get_table_names()
    if not (table.startswith(PREFIX) or table in exclusao)
]
# Reconnect, this time restricted to the selected tables.
db = SQLDatabaseLangchainUtils(db_connection=db_connection, include_tables=include_tables)
db.get_table_names()

['airport',
 'borders',
 'city',
 'citylocalname',
 'cityothername',
 'citypops',
 'continent',
 'country',
 'countrylocalname',
 'countryothername',
 'countrypops',
 'desert',
 'economy',
 'encompasses',
 'ethnicgroup',
 'geo_desert',
 'geo_estuary',
 'geo_island',
 'geo_lake',
 'geo_mountain',
 'geo_river',
 'geo_sea',
 'geo_source',
 'island',
 'islandin',
 'ismember',
 'lake',
 'lakeonisland',
 'language',
 'located',
 'locatedon',
 'mergeswith',
 'mountain',
 'mountainonisland',
 'organization',
 'politics',
 'population',
 'province',
 'provincelocalname',
 'provinceothername',
 'provpops',
 'religion',
 'river',
 'riveronisland',
 'riverthrough',
 'sea']

## C3 - Function

In [16]:
def run_c3(question, db, model='gpt-4', add_fk = True, callback=None):
    """Run the three C3 steps for one question and return the generated SQL.

    Args:
        question: Natural-language question to translate.
        db: Database wrapper handed to the prompting and SQL-generation steps.
        model: Model name given to both ``ChatOpenAI`` instances.
        add_fk: Forwarded to ``generate_clear_prompting``.
        callback: Forwarded to ``generate_clear_prompting`` and
            ``generate_sql``; this notebook passes ``tracking_token`` here.

    Returns:
        Whatever ``generate_sql`` returns (used as the SQL string by callers
        in this notebook).
    """
    # Prompting step: temperature 0.7, 10 completions per request.
    prompting_llm = ChatOpenAI(model_name=model, temperature=0.7, n=10)
    clear_prompting = generate_clear_prompting(
        question, db, prompting_llm, add_fk=add_fk, callback=callback
    )
    messages = generate_calibration_with_hints(clear_prompting)

    # SQL step: default temperature, 20 completions per request.
    sql_llm = ChatOpenAI(model_name=model, n=20)
    return generate_sql(messages, sql_llm, db, question, callback=callback)

## Preparando as consultas 

In [17]:

# Read the benchmark questions for PREFIX and keep the list stored under the
# "queries" key of the dataset file.
json_file_path = f"../../datasets/{PREFIX}/queries_{PREFIX}.json"
with open(json_file_path, encoding='utf-8', errors='ignore') as json_data:
    queries = json.load(json_data, strict=False)['queries']
# Bare expression: shown as the cell output in the notebook.
queries

[{'id': '1',
  'question': 'What is the area of Thailand?',
  'query_string': '',
  'type': 'simple'},
 {'id': '2',
  'question': 'What are the provinces with an area greater than 10000?',
  'query_string': '',
  'type': 'simple'},
 {'id': '3',
  'question': 'What are the languages spoken in Poland?',
  'query_string': '',
  'type': 'medium'},
 {'id': '4',
  'question': 'How deep is Lake Kariba?',
  'query_string': '',
  'type': 'simple'},
 {'id': '5',
  'question': 'What is the total of provinces of Netherlands?',
  'query_string': '',
  'type': 'complex'},
 {'id': '6',
  'question': 'What is the percentage of religious people are hindu in thailand?',
  'query_string': '',
  'type': 'complex'},
 {'id': '7',
  'question': 'List the number of provinces each river flows through.',
  'query_string': '',
  'type': 'medium'},
 {'id': '8',
  'question': 'Find all countries that became independent between 8/1/1910 and 8/1/1950.',
  'query_string': '',
  'type': 'complex'},
 {'id': '9',
  'que

## Tracking token usage

In [18]:
# Token-usage records collected while one question is being processed.
track_token = []


def tracking_token(cb=None, reset=False):
    """Store one token-usage record, or discard everything stored so far.

    Args:
        cb: Record to store. Judging by how the records are read back in this
            notebook, it is expected to be a dict mapping a stage name to a
            usage object (TODO confirm against the callers). ``None`` is
            ignored.
        reset: When true, ``track_token`` is rebound to a new empty list and
            ``cb`` is not kept.
    """
    global track_token
    if reset:
        # A reset discards the whole collection, so there is no point in
        # appending cb first.
        track_token = []
        return
    if cb is not None:
        # Skipping the None default keeps the list free of entries that
        # have no .keys() for convert_to_dict_tracking_token() to read.
        track_token.append(cb)

In [19]:

def convert_to_dict_tracking_token():
    """Turn the records in ``track_token`` into a plain nested dict.

    Each record is read as a mapping from a key to an object exposing
    ``total_tokens``, ``total_cost``, ``prompt_tokens`` and
    ``completion_tokens``.

    Returns:
        dict: one entry per key, holding those four values. If a key occurs
        in more than one record, the last record wins.
    """
    usage_by_key = {}
    for record in track_token:
        for key, usage in record.items():
            usage_by_key[key] = {
                'total_tokens': usage.total_tokens,
                'total_cost': usage.total_cost,
                'prompt_tokens': usage.prompt_tokens,
                'completion_tokens': usage.completion_tokens,
            }
    return usage_by_key

## Executando o método

In [12]:
# Timer for the whole batch. The per-question timer below uses its own names
# so this value is no longer overwritten inside the loop.
start_time = time.time()
tracking_token(reset=True)
errors = []  # positions in `queries` whose processing raised
for query_index, instance in enumerate(queries):
    try:
        query_start = time.time()
        sql = run_c3(instance["question"], db, callback=tracking_token)
        query_end = time.time()
        instance["query_string"] = sql
        instance["token_usage"] = convert_to_dict_tracking_token()
        instance['time'] = query_end - query_start
        # Save after every question so finished results survive a later failure.
        save_queries(queries)
        print(instance['id'], instance['question'], instance["query_string"], instance['time'])
        print("-----")
    except Exception as exc:
        # Best effort: report the failure and continue with the next
        # question. Catching Exception instead of using a bare except lets
        # KeyboardInterrupt stop the batch.
        print("Error", query_index, repr(exc))
        errors.append(query_index)
    finally:
        # Each question starts with an empty usage list.
        tracking_token(reset=True)


end_time = time.time()



  return self.db._metadata.sorted_tables


Column recall attempt: 1
1 What is the area of Thailand? SELECT area FROM country WHERE name = 'Thailand'; 78.29437732696533
-----
Column recall attempt: 1
2 What are the provinces with an area greater than 10000? SELECT name FROM province WHERE area > 10000; 59.322205781936646
-----
Column recall attempt: 1
3 What are the languages spoken in Poland? SELECT name FROM language WHERE country = (SELECT code FROM country WHERE name = 'Poland'); 51.87317085266113
-----
Column recall attempt: 1
4 How deep is Lake Kariba? SELECT depth FROM lake WHERE name = 'Lake Kariba'; 48.967124462127686
-----
Column recall attempt: 1
5 What is the total of provinces of Netherlands? SELECT COUNT(name) FROM province WHERE country = (SELECT code FROM country WHERE name = 'Netherlands') 46.253724575042725
-----
Column recall attempt: 1
6 What is the percentage of religious people are hindu in thailand? SELECT percentage FROM religion WHERE name = 'hindu' AND country = 'thailand'; 53.08298087120056
-----
Colum

In [11]:
# Bare expression: displays the current contents of `queries` as cell output.
queries

[{'id': '1',
  'question': 'What is the area of Thailand?',
  'query_string': "SELECT area FROM mondial_country WHERE name = 'Thailand';",
  'type': 'simple',
  'token_usage': {'table_recall': {'total_tokens': 2974,
    'total_cost': 0.15081,
    'prompt_tokens': 921,
    'completion_tokens': 2053},
   'column_recall': {'total_tokens': 3592,
    'total_cost': 0.20727,
    'prompt_tokens': 275,
    'completion_tokens': 3317},
   'sql_generation': {'total_tokens': 771,
    'total_cost': 0.03033,
    'prompt_tokens': 531,
    'completion_tokens': 240}}},
 {'id': '2',
  'question': 'What are the provinces with an area greater than 10000?',
  'query_string': 'SELECT name FROM mondial_province WHERE area > 10000;',
  'type': 'simple',
  'token_usage': {'table_recall': {'total_tokens': 2496,
    'total_cost': 0.12194999999999999,
    'prompt_tokens': 927,
    'completion_tokens': 1569},
   'column_recall': {'total_tokens': 2478,
    'total_cost': 0.1422,
    'prompt_tokens': 216,
    'complet

#### Fixing query

In [None]:
# Re-run one question (0-based position in `queries`) and overwrite its entry
# in the saved results file.
pos = 40
instance = queries[pos]
q = read_queries()
# Start from an empty usage list so records left over from an earlier run
# (possibly stored under different keys) do not end up in this entry.
tracking_token(reset=True)
start_time = time.time()
sql = run_c3(instance["question"], db, callback=tracking_token)
end_time = time.time()
instance["query_string"] = sql
instance["token_usage"] = convert_to_dict_tracking_token()
instance['time'] = end_time - start_time
q[pos] = instance
save_queries(q)
print(instance['id'], instance['question'], instance["query_string"], instance['time'])



In [20]:
# Re-run several questions (0-based positions in `queries`) and overwrite
# their entries in the saved results file, one at a time.
to_fix = [59,62,72,85,99]
for pos in to_fix:
    instance = queries[pos]
    q = read_queries()
    # Empty the usage list before each run so records from the previous
    # question (or an earlier cell) are not mixed into this entry.
    tracking_token(reset=True)
    start_time = time.time()
    sql = run_c3(instance["question"], db, callback=tracking_token)
    end_time = time.time()
    instance["query_string"] = sql
    instance["token_usage"] = convert_to_dict_tracking_token()
    instance['time'] = end_time - start_time
    q[pos] = instance
    save_queries(q)
    print(instance['id'], instance['question'], instance["query_string"], instance['time'])

Column recall attempt: 1
60  List the names of capital cities which are the base for organizations in alphabetical order SELECT DISTINCT city.name FROM city, organization WHERE city.name = organization.city AND city.name = country.capital ORDER BY city.name; 55.844616174697876
Column recall attempt: 1
63 Show the inflation rate of countries that are washed by the Arabian Sea SELECT inflation FROM economy WHERE country IN (SELECT country FROM geo_sea WHERE sea = 'Arabian Sea') 78.05257797241211
Column recall attempt: 1
73 What area is the largest continent? SELECT MAX(area) FROM continent; 44.964357137680054
Column recall attempt: 1
86 What are the 3 airports with the largest name? SELECT name FROM airport ORDER BY LENGTH(name) DESC FETCH FIRST 3 ROWS ONLY; 29.203540802001953
Column recall attempt: 1
100 What is the percentage of industries in relation to Japan's economy? SELECT SELECT COUNT(DISTINCT industry) / (SELECT COUNT(*) FROM economy WHERE country = 'JPN') AS percentage FROM eco

## Teste de uma consulta no GPT-4

In [10]:
# Sample question used by the single-question test cells that follow.
question = 'What is the area of Thailand?'

In [11]:
# Run the whole pipeline once for the sample question; usage records are
# collected through tracking_token.
sql = run_c3(question, db, callback=tracking_token)

<function tracking_token at 0x0000011FA4666EF0>


In [12]:
# Show the generated SQL and the raw usage records collected for it.
print(sql)
print(track_token)

SELECT area FROM mondial_country WHERE name = 'Thailand';
[{'table_recall': Tokens Used: 3099
	Prompt Tokens: 829
	Completion Tokens: 2270
Successful Requests: 1
Total Cost (USD): $0.16107}, {'column_recall': Tokens Used: 3019
	Prompt Tokens: 268
	Completion Tokens: 2751
Successful Requests: 1
Total Cost (USD): $0.17309999999999998}, {'sql_gneration': Tokens Used: 771
	Prompt Tokens: 531
	Completion Tokens: 240
Successful Requests: 1
Total Cost (USD): $0.03033}]


In [13]:
# Discard the usage records collected so far.
tracking_token(reset=True)

In [7]:

# Step-by-step run: same LLM settings as the first ChatOpenAI in run_c3, but
# no token-usage callback is passed.
model = 'gpt-4'
llm = ChatOpenAI(model_name = model, temperature=0.7, n=10)
clear_prompting = generate_clear_prompting(question, db, llm, add_fk=True)
print(clear_prompting)



### Complete oracle SQL query only and with no explanation, and do not select extra columns that are not explicitly requested in the query. 
### Sqlite SQL tables, with their properties: 
#
# mondial_city (name, country, province, population)
# mondial_continent (name, area, meta_repcol)
# mondial_country (name, area, code, capital)
# mondial_province (name, country, population, area)

#
### What is the area of Thailand?
SELECT


In [9]:
# Build the message list from the prompt produced in the previous cell.
messages = generate_calibration_with_hints(clear_prompting)
# Bare expression: shown as the cell output in the notebook.
messages

[[SystemMessage(content="\n                You are now an excellent SQL writer, first I'll give you some tips and examples, and I need you to \n                remember the tips, and do not make same mistakes\n                ", additional_kwargs={}),
  HumanMessage(content="\n            Tips 1:\n            Question: Which A has most number of B?\n            Gold SQL: select A from B group by A order by count (*) desc fetch first 1 rows only;\n            Notice that the Gold SQL doesn't select COUNT(*) because the question only wants to know the A and\n            the number should be only used in ORDER BY clause, there are many questions asks in this way, and I\n            need you to remember this in the the following questions.        \n            ", additional_kwargs={}, example=False),
  AIMessage(content="\n                Thank you for the tip! I'll keep in mind that when the question only asks for a certain field, I should not\n                include the COUNT(*) in the 

In [10]:
# Same LLM settings as the second ChatOpenAI in run_c3 (n=20, default
# temperature); again no token-usage callback is passed.
llm = ChatOpenAI(model_name = model, n=20)
sql = generate_sql(messages, llm, db, question)

In [11]:
# Bare expression: displays the generated SQL as cell output.
sql

"SELECT area FROM mondial_country WHERE name = 'Thailand';"