# Experimento: DIN - Mondial - Foreign Key - GPT-4

In [None]:

from langchain.chat_models import ChatOpenAI
from urllib.parse import quote  

from dotenv import load_dotenv
import time
import os
import sys
import json
load_dotenv()


from din_schema_linking import schema_linking_module
from din_classification import classification_module
from din_generating_sql_by_type import generating_sql_by_type_module
from din_self_correction import self_correction_module

# Repository root relative to the notebook's working directory. Built with
# os.path.join so the relative path is valid on both Windows and POSIX.
experiment_path = os.path.join('..', '..')
path = os.path.abspath('')
module_path = os.path.join(path, experiment_path)
print(module_path)
# Make the shared helpers in <repo>/functions importable. The membership check
# is done on the exact directory being appended so that re-running this cell
# does not keep adding duplicate sys.path entries.
functions_path = os.path.join(module_path, "functions")
if functions_path not in sys.path:
    sys.path.append(functions_path)

from sqldatabase_langchain_utils import SQLDatabaseLangchainUtils

In [2]:
# Experiment configuration: the database schema, the table-name prefix used to
# filter tables out, and the JSON file where generated queries are accumulated.
SCHEMA, PREFIX = 'mondial_gpt', 'mondial'
FILE_NAME_RESULT = f"results/6_din_queries_chatgpt4_{SCHEMA}_fk.json"

In [3]:
def save_queries(queries, path=None):
    """Persist *queries* as ``{"queries": queries}`` in a JSON file.

    Args:
        queries: list of query records to write.
        path: destination file; defaults to ``FILE_NAME_RESULT``.
    """
    if path is None:
        path = FILE_NAME_RESULT
    # Create the output directory (e.g. ``results/``) on first use so the
    # experiment loop does not fail on its very first save.
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    # Write to a temporary file and atomically swap it in: this file is
    # rewritten after every query, and an interruption mid-dump must not
    # leave the previously accumulated results truncated.
    tmp_path = path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as arquivo_json:
        json.dump({"queries": queries}, arquivo_json, indent=4)
    os.replace(tmp_path, path)


def read_queries(path=None):
    """Load and return the ``"queries"`` list from a results JSON file.

    Args:
        path: source file; defaults to ``FILE_NAME_RESULT``.

    Raises:
        KeyError: if the JSON object has no ``"queries"`` key.
    """
    if path is None:
        path = FILE_NAME_RESULT
    # errors='ignore' drops undecodable bytes; strict=False lets json accept
    # control characters inside strings.
    with open(path, encoding='utf-8', errors='ignore') as json_data:
        data = json.load(json_data, strict=False)
    return data["queries"]


## Conexão com o banco

In [None]:
# Read the database connection settings for the configured schema.
json_file_path = f"{experiment_path}/datasets/{SCHEMA}_db_connection.json"
with open(json_file_path, encoding='utf-8', errors='ignore') as connection_file:
    raw_settings = connection_file.read()
db_connection = json.loads(raw_settings, strict=False)
db_connection  # last expression of the cell: the notebook displays it

In [5]:
# First connection: used only to list every table in the database.
db = SQLDatabaseLangchainUtils(db_connection=db_connection)

# Auxiliary tables (by their names: logs, favorites, dashboards, history and
# similar bookkeeping) plus test tables, excluded from the schema DIN sees.
exclusao = [
    f"{SCHEMA}_tmdp",
    f"{SCHEMA}_tmdpmap",
    f"{SCHEMA}_tmds",
    f"{SCHEMA}_tmjmap",
    f"{SCHEMA}_tpv",
    f"{SCHEMA}_tmdc",
    f"{SCHEMA}_tmdcmap",
    f"{SCHEMA}_tmdej",
    f"{SCHEMA}_log_action",
    f"{SCHEMA}_log_error",
    f"{SCHEMA}_favorite_item", 
    f"{SCHEMA}_favorite_query",
    f"{SCHEMA}_favorite_tag",
    f"{SCHEMA}_favorite_tag_item",
    f"{SCHEMA}_favorite_visualization",
    f"{SCHEMA}_dashboard",
    f"{SCHEMA}_history",
    "teste_cliente",
    "teste_fornecedor",
    "teste_funcionario"
]

# Keep only tables that neither start with PREFIX nor are explicitly
# excluded, then reconnect restricted to that subset.
excluded_tables = set(exclusao)
include_tables = [
    table_name
    for table_name in db.get_table_names()
    if not (table_name.startswith(PREFIX) or table_name in excluded_tables)
]
db = SQLDatabaseLangchainUtils(db_connection=db_connection, include_tables=include_tables)
db.get_table_names()  # notebook display of the tables that remain visible

  self._metadata.reflect(
  self._metadata.reflect(
  self._metadata.reflect(


['airport',
 'borders',
 'city',
 'citylocalname',
 'cityothername',
 'citypops',
 'continent',
 'country',
 'countrylocalname',
 'countryothername',
 'countrypops',
 'desert',
 'economy',
 'encompasses',
 'ethnicgroup',
 'geo_desert',
 'geo_estuary',
 'geo_island',
 'geo_lake',
 'geo_mountain',
 'geo_river',
 'geo_sea',
 'geo_source',
 'island',
 'islandin',
 'ismember',
 'lake',
 'lakeonisland',
 'language',
 'located',
 'locatedon',
 'mergeswith',
 'mountain',
 'mountainonisland',
 'organization',
 'politics',
 'population',
 'province',
 'provincelocalname',
 'provinceothername',
 'provpops',
 'religion',
 'river',
 'riveronisland',
 'riverthrough',
 'sea']

## DIN - Function

In [6]:
def run_din(question, db, model_name = 'gpt-4', debbug = False, callback=None):
    """Run the DIN-SQL pipeline for one natural-language question.

    Stages, in order: schema linking, classification, SQL generation by
    predicted class, and self-correction.

    Args:
        question: natural-language question to translate into SQL.
        db: database wrapper handed to every DIN module.
        model_name: OpenAI chat model used by both LLM instances.
        debbug: when truthy, print the intermediate results of each stage.
        callback: forwarded to every DIN module (this notebook passes
            ``tracking_token`` to collect token usage).

    Returns:
        The SQL produced by the self-correction stage.
    """
    shared_kwargs = {'top_p': 1.0, 'frequency_penalty': 0.0, 'presence_penalty': 0.0}

    # Each model gets its own model_kwargs dict, so the stop sequences of one
    # can never leak into the other regardless of whether ChatOpenAI copies
    # the dict it receives.
    llm = ChatOpenAI(model_name=model_name, temperature=0.0, n=1, max_tokens=600,
                     model_kwargs={**shared_kwargs, 'stop': ['Q:']})
    # Self-correction model: smaller token budget and stops at '#', ';' or a
    # blank line.
    llm_fix = ChatOpenAI(model_name=model_name, temperature=0.0, n=1, max_tokens=350,
                         model_kwargs={**shared_kwargs, 'stop': ['#', ';', '\n\n']})

    schema_links = schema_linking_module(db, llm, question, callback=callback)

    classification = classification_module(db, llm, question, schema_links, callback=callback)

    SQL = generating_sql_by_type_module(db, llm, question, schema_links, classification, callback=callback)

    SQL_FINAL = self_correction_module(db, llm_fix, question, SQL, callback=callback)

    if debbug:
        print("Schema Links =", schema_links)
        print('Class: ', classification['predicted_class'])
        print('SQL >', SQL)
        print('SQL Final >', SQL_FINAL)
        print()
        print('-------')
    return SQL_FINAL

## Preparando as consultas 

In [7]:

# Load the benchmark questions for this dataset, keeping only the "queries" list.
json_file_path = f"../../datasets/{PREFIX}/queries_{PREFIX}.json"
with open(json_file_path, encoding='utf-8', errors='ignore') as dataset_file:
    queries = json.load(dataset_file, strict=False)["queries"]
queries  # notebook display of the loaded questions

[{'id': '1',
  'question': 'What is the area of Thailand?',
  'query_string': '',
  'type': 'simple'},
 {'id': '2',
  'question': 'What are the provinces with an area greater than 10000?',
  'query_string': '',
  'type': 'simple'},
 {'id': '3',
  'question': 'What are the languages spoken in Poland?',
  'query_string': '',
  'type': 'medium'},
 {'id': '4',
  'question': 'How deep is Lake Kariba?',
  'query_string': '',
  'type': 'simple'},
 {'id': '5',
  'question': 'What is the total of provinces of Netherlands?',
  'query_string': '',
  'type': 'complex'},
 {'id': '6',
  'question': 'What is the percentage of religious people are hindu in thailand?',
  'query_string': '',
  'type': 'complex'},
 {'id': '7',
  'question': 'List the number of provinces each river flows through.',
  'query_string': '',
  'type': 'medium'},
 {'id': '8',
  'question': 'Find all countries that became independent between 8/1/1910 and 8/1/1950.',
  'query_string': '',
  'type': 'complex'},
 {'id': '9',
  'que

## Tracking token usage

In [8]:
# Token-usage records collected during one DIN run (one entry per callback call).
track_token = []


def tracking_token(cb=None, reset=False):
    """Record a token-usage entry, or clear all recorded entries.

    Used as the ``callback`` passed into ``run_din``.

    Args:
        cb: entry to record; ``None`` is ignored because
            ``convert_to_dict_tracking_token`` cannot process it.
        reset: when true, discard every recorded entry (and do not record
            ``cb``).
    """
    if reset:
        # Clear in place so every existing reference to the list sees the reset.
        track_token.clear()
        return
    if cb is not None:
        track_token.append(cb)

In [9]:

def convert_to_dict_tracking_token(entries=None):
    """Summarise recorded token-usage entries as plain, JSON-serialisable dicts.

    Args:
        entries: list of mappings from a key to an object exposing
            ``total_tokens``, ``total_cost``, ``prompt_tokens`` and
            ``completion_tokens``; defaults to the global ``track_token``.

    Returns:
        ``{key: {"total_tokens", "total_cost", "prompt_tokens",
        "completion_tokens"}}``. When the same key appears in several entries,
        the last occurrence wins.
    """
    if entries is None:
        entries = track_token
    return {
        key: {
            'total_tokens': usage.total_tokens,
            'total_cost': usage.total_cost,
            'prompt_tokens': usage.prompt_tokens,
            'completion_tokens': usage.completion_tokens,
        }
        for entry in entries
        for key, usage in entry.items()
    }

## Executando o método

In [10]:
# Run DIN over every benchmark question, saving the results after each one.
tracking_token(reset=True)
errors = []  # indices into `queries` whose run raised
for query_index, instance in enumerate(queries):
    try:
        # perf_counter is monotonic, so durations are not affected by clock changes.
        start_time = time.perf_counter()
        sql = run_din(instance["question"], db, debbug=False, callback=tracking_token)
        end_time = time.perf_counter()
        instance["query_string"] = sql
        instance["token_usage"] = convert_to_dict_tracking_token()
        instance['time'] = end_time - start_time
        print(instance['id'], instance['question'], instance["query_string"], instance['time'])
        save_queries(queries)
    except Exception as exc:
        # Record the failure and continue with the next question. Catching
        # Exception rather than using a bare except keeps Ctrl+C able to stop
        # the run.
        print("Error", exc)
        errors.append(query_index)
    finally:
        # Each question starts with an empty token log; the pause spaces out
        # successive API calls.
        tracking_token(reset=True)
        time.sleep(2)

       

In [None]:
# tracking_token(reset=True)
# errors = []
# query_index = 0
# count = 0
# queries = read_queries()
# start = 84
# end = len(queries)
# for i in range(start, end):
#     instance = queries[i]
#     try:
#         start_time = time.time()
#         sql = run_din(instance["question"], db, debbug=False, callback=tracking_token)
#         end_time = time.time()
#         instance["query_string"] = sql
#         instance["token_usage"] = convert_to_dict_tracking_token()
#         instance['time'] = end_time - start_time
#         print(instance['id'], instance['question'], instance["query_string"], instance['time'])
#         save_queries(queries)
#     except:
#         print("Error")
#         errors.append(query_index)
#         pass
#     finally:
#         tracking_token(reset=True)
#         time.sleep(2)
    
    
    

#### Fixing query

In [None]:

# Re-run DIN for one question and overwrite its entry in the results file.
pos = 40
instance = queries[pos]
q = read_queries()
# Start from an empty token log; otherwise entries left over from earlier
# runs would be merged into this question's token_usage.
tracking_token(reset=True)
try:
    start_time = time.perf_counter()
    sql = run_din(instance["question"], db, callback=tracking_token)
    end_time = time.perf_counter()
    instance["query_string"] = sql
    instance["token_usage"] = convert_to_dict_tracking_token()
    instance['time'] = end_time - start_time
finally:
    tracking_token(reset=True)
q[pos] = instance
save_queries(q)
print(instance['id'], instance['question'], instance["query_string"], instance['time'])



In [12]:
# Re-run DIN for the listed positions and overwrite those entries in the results file.
to_fix = [59, 62, 72, 85, 99]
# Read once: every iteration updates `q` in memory and saves it, so re-reading
# the file on each pass would return the same data.
q = read_queries()
for pos in to_fix:
    instance = queries[pos]
    # Fresh token log per question so usage from earlier iterations does not
    # leak into later entries.
    tracking_token(reset=True)
    try:
        start_time = time.perf_counter()
        sql = run_din(instance["question"], db, callback=tracking_token)
        end_time = time.perf_counter()
        instance["query_string"] = sql
        instance["token_usage"] = convert_to_dict_tracking_token()
        instance['time'] = end_time - start_time
    finally:
        tracking_token(reset=True)
    q[pos] = instance
    save_queries(q)
    print(instance['id'], instance['question'], instance["query_string"], instance['time'])

-- Classification -- 
generations=[[ChatGeneration(text='The SQL query for the question "List the names of capital cities which are the base for organizations in alphabetical order" needs these tables = [city, organization], so we need JOIN.\nPlus, it doesn\'t require nested queries with (INTERSECT, UNION, EXCEPT, IN, NOT IN), and we need the answer to the questions = [""].\nSo, we need JOIN and don\'t need nested queries, then the the SQL query can be classified as "NON-NESTED".\nLabel: "NON-NESTED"', generation_info={'finish_reason': 'stop'}, message=AIMessage(content='The SQL query for the question "List the names of capital cities which are the base for organizations in alphabetical order" needs these tables = [city, organization], so we need JOIN.\nPlus, it doesn\'t require nested queries with (INTERSECT, UNION, EXCEPT, IN, NOT IN), and we need the answer to the questions = [""].\nSo, we need JOIN and don\'t need nested queries, then the the SQL query can be classified as "NON-NES

In [11]:
errors  # display the indices into `queries` whose DIN run raised in the main loop


[]

In [None]:
# queries = read_queries()
# errors = []
# query_index = 22
# for i in range(query_index, len(queries)):
#     instance = queries[i]
#     query_index +=1
#     try:
#         sql = run_din(instance["question"], db, debbug=True)
#         instance["query_string"] = sql
#         save_queries(queries)
#         print(instance["id"], instance["question"], sql)
#         time.sleep(5)
#     except:
#         print("Error")
#         errors.append(query_index)
#         pass

In [8]:
queries  # display the current question records, including generated query_string values


[{'id': '1',
  'question': 'What is the area of Thailand?',
  'query_string': "SELECT area  FROM  mondial_country  WHERE  name = 'Thailand'",
  'type': 'simple'},
 {'id': '2',
  'question': 'What are the provinces with an area more than 10000?',
  'query_string': 'SELECT name  FROM  mondial_province  WHERE  area > 10000',
  'type': 'simple'},
 {'id': '3',
  'question': 'What are the languages spoken in Poland?',
  'query_string': "SELECT name  FROM  mondial_language  WHERE  country = 'Poland'",
  'type': 'medium'},
 {'id': '4',
  'question': 'How deep is Lake Kariba?',
  'query_string': "SELECT depth  FROM  mondial_lake  WHERE  name = 'Lake Kariba'",
  'type': 'simple'},
 {'id': '5',
  'question': 'What is the total of provinces of Netherlands?',
  'query_string': "SELECT COUNT(DISTINCT province)  FROM mondial_province  WHERE country = 'Netherlands'",
  'type': 'medium'},
 {'id': '6',
  'question': 'What percentage of religious people are hindu in thailand?',
  'query_string': "SELECT 

## Testando o DIN para uma consulta sem FK no Mondial

### LLM

In [6]:
# Models for the step-by-step walkthrough below, configured like run_din:
# `llm` for schema linking, classification and generation, and `llm_fix` for
# self-correction.
model_name = 'gpt-4'

# The outputs recorded in the following cells were produced for this question.
question = "What is the area of Thailand?"

shared_kwargs = {'top_p': 1.0, 'frequency_penalty': 0.0, 'presence_penalty': 0.0}
# Separate model_kwargs dicts so the two models cannot share or overwrite each
# other's stop sequences.
llm = ChatOpenAI(model_name=model_name, temperature=0.0, n=1, max_tokens=600,
                 model_kwargs={**shared_kwargs, 'stop': ['Q:']})
llm_fix = ChatOpenAI(model_name=model_name, temperature=0.0, n=1, max_tokens=350,
                     model_kwargs={**shared_kwargs, 'stop': ['#', ';', '\n\n']})



### Schema Linking

In [9]:
# DIN stage 1: schema linking for `question`.
schema_links = schema_linking_module(
    db,
    llm,
    question,
)
print("Schema Links =", schema_links)

Schema Links = [mondial_country.area]


### Classification

In [10]:
# DIN stage 2: classify the question using the schema links found above.
classification = classification_module(
    db,
    llm,
    question,
    schema_links,
)
print(classification)

{'predicted_class': '"EASY"', 'classification_label': 'The SQL query for the question "What is the area of Thailand?" needs the table [mondial_country], so we don\'t need JOIN.\nPlus, it doesn\'t require nested queries with (INTERSECT, UNION, EXCEPT, IN, NOT IN), and we need the answer to the questions = [""].\nSo, we don\'t need JOIN and don\'t need nested queries, then the SQL query can be classified as "EASY".\nLabel: "EASY"'}


### Generate SQL

In [11]:
# DIN stage 3: generate SQL using the prompt selected by the predicted class.
SQL = generating_sql_by_type_module(
    db,
    llm,
    question,
    schema_links,
    classification,
)
print(SQL)


SELECT area FROM mondial_country WHERE name = 'Thailand'
SELECT area FROM mondial_country WHERE name = 'Thailand'


### Self-Correction

In [12]:
# DIN stage 4: self-correction. Uses llm_fix (max_tokens=350, stop at '#', ';'
# or a blank line), the same model run_din hands to self_correction_module.
SQL_FINAL = self_correction_module(db, llm_fix, question, SQL)
print(SQL_FINAL)

SELECT area FROM mondial_country WHERE name = 'Thailand'
SELECT area FROM mondial_country WHERE name = 'Thailand'
