# Experimento: DIN/C3 - Mondial - Foreign Key - GPT 3.5

In [None]:

from langchain.chat_models import ChatOpenAI
from urllib.parse import quote  

from dotenv import load_dotenv
import time
import os
import sys
import json
load_dotenv()



# Make the shared ``functions`` directory, two levels above this notebook,
# importable.
path = os.path.abspath('')
# Built with os.path.join/os.pardir so the separator matches the host OS; the
# previous '..\..' literal depended on an invalid escape sequence.
experiment_path = os.path.join(os.pardir, os.pardir)
module_path = os.path.join(path, experiment_path)
print(module_path)
functions_path = os.path.join(module_path, "functions")
# Test the exact entry that gets appended, so re-running the cell does not add
# the same directory to sys.path again.
if functions_path not in sys.path:
    sys.path.append(functions_path)


from sqldatabase_langchain_utils import SQLDatabaseLangchainUtils

In [2]:
# Make the sibling ``DIN`` directory importable. os.pardir keeps the path valid
# on any OS (the previous "..\\DIN" literal only resolved on Windows).
module_path = os.path.join(path, os.pardir, "DIN")
if module_path not in sys.path:
    sys.path.append(module_path)
    
from din_classification import classification_module
from din_self_correction import self_correction_module



In [3]:
from schema_linking_din_c3 import schema_linking_din_c3
from sql_generation_with_hints_din_c3 import generating_sql_with_hints

In [4]:
# Schema identifier: used to build the DB connection file name and the names
# of the schema-specific tables that are excluded from the working table list.
SCHEMA = 'mondial_gpt'
# Dataset prefix: used to locate the queries JSON file and to filter out
# tables whose name starts with it.
PREFIX = 'mondial'
# JSON file that save_queries() writes and read_queries() reads.
FILE_NAME_RESULT = f"results/5_combining_din_c3_queries_chatgpt_{SCHEMA}_fk.json"

In [5]:
def save_queries(queries, file_name=None):
    """Write ``queries`` to a JSON file as ``{"queries": [...]}``.

    Args:
        queries: JSON-serialisable list of query records.
        file_name: Destination path. Defaults to ``FILE_NAME_RESULT``.
    """
    target = FILE_NAME_RESULT if file_name is None else file_name
    # Create the output folder on first use instead of failing with
    # FileNotFoundError after a (possibly expensive) query has already run.
    folder = os.path.dirname(target)
    if folder:
        os.makedirs(folder, exist_ok=True)
    # Explicit utf-8 so the file is written with the same encoding it is read with.
    with open(target, "w", encoding="utf-8") as arquivo_json:
        json.dump({"queries": queries}, arquivo_json, indent=4)
        
def read_queries(file_name=None):
    """Return the list stored under the ``"queries"`` key of a results file.

    Args:
        file_name: Path of the JSON file to read. Defaults to
            ``FILE_NAME_RESULT``.

    Returns:
        The value of the top-level ``"queries"`` entry.

    Raises:
        KeyError: If the file has no ``"queries"`` entry.
    """
    source = FILE_NAME_RESULT if file_name is None else file_name
    # errors='ignore' drops undecodable bytes and strict=False accepts control
    # characters inside JSON strings, so imperfect files still load.
    with open(source, encoding='utf-8', errors='ignore') as json_data:
        data = json.load(json_data, strict=False)
    return data["queries"]


## Conexão com o banco

In [None]:
# Load the database connection settings for the selected schema.
json_file_path = f"{experiment_path}/datasets/{SCHEMA}_db_connection.json"
with open(json_file_path, encoding='utf-8', errors='ignore') as json_data:
    db_connection = json.load(json_data, strict=False)
# Display only the top-level field names: echoing the whole object would save
# any credentials it contains into the notebook output.
sorted(db_connection) if isinstance(db_connection, dict) else type(db_connection).__name__

In [7]:
# First connection has no table filter; it is only used to list the table names.
db = SQLDatabaseLangchainUtils(db_connection=db_connection)

# Names removed from the working table list: schema-prefixed tables first,
# followed by three fixed "teste_*" tables.
exclusao = [
    f"{SCHEMA}_{suffix}"
    for suffix in (
        "tmdp",
        "tmdpmap",
        "tmds",
        "tmjmap",
        "tpv",
        "tmdc",
        "tmdcmap",
        "tmdej",
        "log_action",
        "log_error",
        "favorite_item",
        "favorite_query",
        "favorite_tag",
        "favorite_tag_item",
        "favorite_visualization",
        "dashboard",
        "history",
    )
] + ["teste_cliente", "teste_fornecedor", "teste_funcionario"]

# Keep a table only when it neither starts with PREFIX nor is listed above.
include_tables = [
    table_name
    for table_name in db.get_table_names()
    if not (table_name.startswith(PREFIX) or table_name in exclusao)
]
# Reconnect restricted to the selected tables and show the resulting list.
db = SQLDatabaseLangchainUtils(db_connection=db_connection, include_tables=include_tables)
db.get_table_names()

  self._metadata.reflect(
  self._metadata.reflect(
  self._metadata.reflect(


['airport',
 'borders',
 'city',
 'citylocalname',
 'cityothername',
 'citypops',
 'continent',
 'country',
 'countrylocalname',
 'countryothername',
 'countrypops',
 'desert',
 'economy',
 'encompasses',
 'ethnicgroup',
 'geo_desert',
 'geo_estuary',
 'geo_island',
 'geo_lake',
 'geo_mountain',
 'geo_river',
 'geo_sea',
 'geo_source',
 'island',
 'islandin',
 'ismember',
 'lake',
 'lakeonisland',
 'language',
 'located',
 'locatedon',
 'mergeswith',
 'mountain',
 'mountainonisland',
 'organization',
 'politics',
 'population',
 'province',
 'provincelocalname',
 'provinceothername',
 'provpops',
 'religion',
 'river',
 'riveronisland',
 'riverthrough',
 'sea']

## Preparando as consultas 

In [8]:

# Read the benchmark questions for this dataset and keep only the list stored
# under the "queries" key.
json_file_path = f"../../datasets/{PREFIX}/queries_{PREFIX}.json"
with open(json_file_path, encoding='utf-8', errors='ignore') as json_data:
    queries = json.load(json_data, strict=False)['queries']
queries

[{'id': '1',
  'question': 'What is the area of Thailand?',
  'query_string': '',
  'type': 'simple'},
 {'id': '2',
  'question': 'What are the provinces with an area greater than 10000?',
  'query_string': '',
  'type': 'simple'},
 {'id': '3',
  'question': 'What are the languages spoken in Poland?',
  'query_string': '',
  'type': 'medium'},
 {'id': '4',
  'question': 'How deep is Lake Kariba?',
  'query_string': '',
  'type': 'simple'},
 {'id': '5',
  'question': 'What is the total of provinces of Netherlands?',
  'query_string': '',
  'type': 'complex'},
 {'id': '6',
  'question': 'What is the percentage of religious people are hindu in thailand?',
  'query_string': '',
  'type': 'complex'},
 {'id': '7',
  'question': 'List the number of provinces each river flows through.',
  'query_string': '',
  'type': 'medium'},
 {'id': '8',
  'question': 'Find all countries that became independent between 8/1/1910 and 8/1/1950.',
  'query_string': '',
  'type': 'complex'},
 {'id': '9',
  'que

## Tracking token usage

In [9]:
# Token-usage payloads collected for the query currently being processed.
track_token = []
def tracking_token(cb=None, reset=False):
    """Record a token-usage payload and/or empty the collected payloads.

    Args:
        cb: Payload to record. ``None`` (the default) records nothing, so a
            call made only to reset does not leave a ``None`` entry that
            would break consumers iterating over ``track_token``.
        reset: When true, discard everything collected so far (including a
            ``cb`` passed in the same call).
    """
    if cb is not None:
        track_token.append(cb)
    if reset:
        # Clear in place so every reference to the list observes the reset.
        track_token.clear()

In [10]:

def convert_to_dict_tracking_token(entries=None):
    """Flatten recorded token-usage payloads into plain nested dictionaries.

    Args:
        entries: Iterable of mappings from a step name to an object exposing
            ``total_tokens``, ``total_cost``, ``prompt_tokens`` and
            ``completion_tokens``. Defaults to the module-level
            ``track_token`` list. ``None`` items are skipped.

    Returns:
        ``{step_name: {"total_tokens", "total_cost", "prompt_tokens",
        "completion_tokens"}}``. When a step name appears more than once, the
        last occurrence wins.
    """
    if entries is None:
        entries = track_token
    token_usage = {}
    for entry in entries:
        # A recorded None has no keys(); skip it instead of raising.
        if entry is None:
            continue
        for key in entry.keys():
            usage = entry[key]
            token_usage[key] = {
                'total_tokens': usage.total_tokens,
                'total_cost': usage.total_cost,
                'prompt_tokens': usage.prompt_tokens,
                'completion_tokens': usage.completion_tokens,
            }
    return token_usage

## DIN-C3 Function

In [11]:

def run_din_c3(question, db, model_name = 'gpt-3.5-turbo-16k', add_fk = True, callback= None):
    """Run the combined DIN/C3 pipeline for one question and return the final SQL.

    The steps run in order: schema linking, classification, SQL generation
    with hints, and self correction. Each step receives ``callback``.

    Args:
        question: Natural-language question to translate.
        db: Database wrapper passed to every pipeline step.
        model_name: OpenAI chat model used for all three LLM clients.
        add_fk: Forwarded to ``schema_linking_din_c3``.
        callback: Optional callable forwarded to every step (the notebook
            passes ``tracking_token``).

    Returns:
        The value returned by ``self_correction_module``.
    """
    # Sampling client (temperature 0.7, 10 completions per request) for schema linking.
    llm_c3 = ChatOpenAI(model_name = model_name, temperature=0.7, n=10)

    base_kwargs = {'top_p':1.0, 'frequency_penalty':0.0, 'presence_penalty':0.0}
    # Each client gets its own dict: the stop tokens added for the fixing
    # client must not reach this one through a shared, later-mutated object.
    llm_din = ChatOpenAI(model_name = model_name, temperature=0.0, n=1, max_tokens=600, model_kwargs = dict(base_kwargs))

    fix_kwargs = {**base_kwargs, 'stop': ['#', ';','\n\n']}
    llm_din_fix = ChatOpenAI(model_name = model_name, temperature=0.0, n=1, max_tokens=350, model_kwargs = fix_kwargs)

    # Clear prompting
    schema_links, tables = schema_linking_din_c3(question, db, llm_c3, llm_din, add_fk = add_fk, callback= callback)
    print(schema_links)
    # Classification
    classification = classification_module(db, llm_din, question, schema_links, tables=tables, callback= callback)
    print(classification)
    # Calibration with hints and SQL generation by type
    SQL = generating_sql_with_hints(db, llm_din, question, schema_links, classification, tables=tables, callback= callback)

    # Self correction
    SQL_FINAL = self_correction_module(db, llm_din_fix, question, SQL, callback= callback)

    return SQL_FINAL


## Executando o método

In [None]:
# Run the pipeline for every loaded question. Each success is stored on the
# query record and the whole list is saved immediately; each failure is
# printed and its position is collected in `errors`.
# NOTE(review): add_fk=False is passed here although the notebook title and the
# results file name mention foreign keys - confirm this is intended.
tracking_token(reset=True)
errors = []
queries_to_delay = 4
count = 0
query_index = -1
for query_index, instance in enumerate(queries):
    try:
        start_time = time.time()
        sql = run_din_c3(instance["question"], db, add_fk=False, callback=tracking_token)
        end_time = time.time()
        instance["query_string"] = sql
        instance["token_usage"] = convert_to_dict_tracking_token()
        instance['time'] = end_time - start_time
        save_queries(queries)
        print(instance["id"], instance["question"], sql)
    except Exception as e:
        print(str(e))
        errors.append(query_index)
    finally:
        # Sleep 30 s after every `queries_to_delay` attempts, then start the
        # next query with an empty token tracker.
        count = (count + 1) % queries_to_delay
        if count == 0:
            time.sleep(30)
        tracking_token(reset=True)
        


#### Fixing query

In [None]:

# Re-run one query (selected by position) and write it back into the saved results.
pos = 40
instance = queries[pos]
q = read_queries()
# Start from an empty tracker so token_usage only covers this run, not payloads
# left over from earlier executions.
tracking_token(reset=True)
start_time = time.time()
sql = run_din_c3(instance["question"], db, callback=tracking_token)
end_time = time.time()
instance["query_string"] = sql
instance["token_usage"] = convert_to_dict_tracking_token()
instance['time'] = end_time - start_time
q[pos] = instance
save_queries(q)
print(instance['id'], instance['question'], instance["query_string"], instance['time'])



In [None]:
# Re-run a hand-picked list of query positions and write each one back into the
# saved results file.
errors = [59,62,72,85,99]
for pos in errors:
    instance = queries[pos]
    q = read_queries()
    # Empty the tracker before each run so one query's token usage does not
    # carry over into the next query's record.
    tracking_token(reset=True)
    start_time = time.time()
    sql = run_din_c3(instance["question"], db, callback=tracking_token)
    end_time = time.time()
    instance["query_string"] = sql
    instance["token_usage"] = convert_to_dict_tracking_token()
    instance['time'] = end_time - start_time
    q[pos] = instance
    save_queries(q)
    print(instance['id'], instance['question'], instance["query_string"], instance['time'])

In [None]:
# Display the positions currently recorded as failed.
errors

[]

### Executar caso existam erros

In [None]:
# Retry every query position listed in `errors`; positions that succeed are
# removed from `errors` at the end.
# NOTE(review): add_fk=False is used here while the "Fixing query" cells rely
# on the run_din_c3 default (add_fk=True) - confirm which one is intended.
if errors:
    tracking_token(reset=True)
    queries_to_delay = 4
    count = 0
    fixed = []
    for query_index in errors:
        instance = queries[query_index]
        try:
            start_time = time.time()
            sql = run_din_c3(instance["question"], db, add_fk=False, callback=tracking_token)
            end_time = time.time()
            instance["query_string"] = sql
            instance["token_usage"] = convert_to_dict_tracking_token()
            instance['time'] = end_time - start_time
            save_queries(queries)
            print(instance["id"], instance["question"], sql)
            # Record the same index that came from `errors`; it used to be
            # incremented first, so the filter below matched the wrong entries.
            fixed.append(query_index)
        except Exception as e:
            print(str(e))
        finally:
            # Sleep 30 s after every `queries_to_delay` attempts, then start
            # the next query with an empty token tracker.
            count+=1
            if queries_to_delay==count:
                time.sleep(30)
                count = 0
            tracking_token(reset=True)

    errors = [x for x in errors if x not in fixed]

In [16]:
# Display the in-memory query records, including the generated SQL, token usage
# and timing stored by the run above.
queries

[{'id': '1',
  'question': 'What is the area of Thailand?',
  'query_string': "SELECT country.area FROM country WHERE country.name = 'Thailand'",
  'type': 'simple',
  'token_usage': {'table_recall': {'total_tokens': 5648,
    'total_cost': 0.021871,
    'prompt_tokens': 721,
    'completion_tokens': 4927},
   'column_recall': {'total_tokens': 6724,
    'total_cost': 0.026648,
    'prompt_tokens': 248,
    'completion_tokens': 6476},
   'schema_linking': {'total_tokens': 4221,
    'total_cost': 0.012746,
    'prompt_tokens': 4138,
    'completion_tokens': 83},
   'classification': {'total_tokens': 2287,
    'total_cost': 0.006958000000000001,
    'prompt_tokens': 2190,
    'completion_tokens': 97},
   'sql_generation_din_c3': {'total_tokens': 1525,
    'total_cost': 0.004588,
    'prompt_tokens': 1512,
    'completion_tokens': 13},
   'self_correction': {'total_tokens': 2193,
    'total_cost': 0.006595,
    'prompt_tokens': 2177,
    'completion_tokens': 16}}},
 {'id': '2',
  'question

## Testes com cada módulo

In [12]:
# Inputs and LLM clients for exercising each pipeline module on its own.
question = 'What are the provinces with an area greater than 10000?'

model_name = 'gpt-3.5-turbo-16k'
# NOTE(review): run_din_c3 builds llm_c3 with `model_name`; this cell hard-codes
# 'gpt-3.5-turbo' instead - confirm the difference is intended.
llm_c3 = ChatOpenAI(model_name = 'gpt-3.5-turbo', temperature=0.7, n=10)
model_kwargs = {'top_p':1.0, 'frequency_penalty':0.0, 'presence_penalty':0.0}
# Pass a copy so the stop tokens added for the fixing client cannot reach this
# client through a shared dict.
llm_din = ChatOpenAI(model_name = model_name, temperature=0.0, n=1, max_tokens=600, model_kwargs = dict(model_kwargs))

model_kwargs_fix = {**model_kwargs, 'stop': ['#', ';','\n\n']}
llm_din_fix = ChatOpenAI(model_name = model_name, temperature=0.0, n=1, max_tokens=350, model_kwargs = model_kwargs_fix)


In [13]:

# Step 1 in isolation: schema linking, called with add_fk=True and without a
# token-tracking callback.
schema_links, tables = schema_linking_din_c3(question, db, llm_c3, llm_din, add_fk = True)


# mondial_city (name, country, province, population, latitude, longitude, elevation, meta_repcol)
# mondial_citypops (city, country, province, year, population, meta_repcol)
# mondial_country (name, code, capital, province, area, population, meta_repcol)
# mondial_province (name, country, population, area, capital, capprov, meta_repcol)



In [14]:
# Inspect both values returned by the schema-linking step.
print(schema_links)
print(tables)

[mondial_province.name,mondial_province.area,10000]
['mondial_province', 'mondial_country', 'mondial_city', 'mondial_citypops']


In [15]:
# Step 2 in isolation: classify the question using the schema links and tables
# from step 1, then display the returned value.
classification = classification_module(db, llm_din, question, schema_links, tables=tables)
classification

{'predicted_class': '"EASY"',
 'classification_label': 'The SQL query for the question "What are the provinces with an area greater than 10000?" needs these tables = [mondial_province], so we don\'t need JOIN.\nPlus, it doesn\'t require nested queries with (INTERSECT, UNION, EXCEPT, IN, NOT IN), and we need the answer to the questions = [""].\nSo, we don\'t need JOIN and don\'t need nested queries, then the the SQL query can be classified as "EASY".\nLabel: "EASY"'}

In [16]:
# Step 3 in isolation: generate SQL from the schema links and the
# classification, then print it.
SQL = generating_sql_with_hints(db, llm_din, question, schema_links, classification, tables=tables)
print(SQL)

EASY
SQL Parcial:  SELECT name FROM mondial_province WHERE area > 10000
SELECT name FROM mondial_province WHERE area > 10000


In [17]:
# Pass the generated SQL to the database wrapper and display what it returns.
db.run(SQL)

'[(\'Oberösterreich\',), (\'Tirol\',), (\'Steiermark\',), (\'Niederösterreich\',), (\'Jihoceský\',), (\'Stredoceský\',), (\'Baden-Württemberg\',), (\'Bayern\',), (\'Brandenburg\',), (\'Hessen\',), (\'Mecklenburg-Vorpommern\',), (\'Niedersachsen\',), (\'Nordrhein-Westfalen\',), (\'Rheinland-Pfalz\',), (\'Sachsen\',), (\'Sachsen-Anhalt\',), (\'Schleswig-Holstein\',), (\'Thüringen\',), (\'Piemonte\',), (\'Lombardia\',), (\'Trentino-Alto Adige\',), (\'Veneto\',), (\'Emilia-Romagna\',), (\'Toscana\',), (\'Lazio\',), (\'Abruzzo\',), (\'Campania\',), (\'Puglia\',), (\'Calabria\',), (\'Sicilia\',), (\'Sardegna\',), (\'Slovenia\',), (\'Brest\',), (\'Vitebsk\',), (\'Gomel\',), (\'Grodno\',), (\'Mogilev\',), (\'Minsk\',), (\'Latvia\',), (\'Lithuania\',), (\'Dolnoslaskie\',), (\'Kujawsko-Pomorskie\',), (\'Lubelskie\',), (\'Lubuskie\',), (\'Lódzkie\',), (\'Malopolskie\',), (\'Mazowieckie\',), (\'Podkarpackie\',), (\'Podlaskie\',), (\'Pomorskie\',), (\'Slaskie\',), (\'Swietokrzyskie\',), (\'Warminsk