In [2]:
%%capture
!pip install -q jsonlines
!pip install -q datasets transformers torch evaluate
!pip install -q rouge_score

In [3]:
%%capture
!python -m spacy download pt_core_news_sm
!python -m spacy download en_core_web_sm

**Monta o Drive**

In [4]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


**Define as constantes globais**

In [5]:
# Nome do dataset alvo, dentre os datasets listados em SQL_DATA_INFO
DATASET_TARGET = "spider-en"

# idioma alvo do treinamento para filtro do dataset
LANGUAGE_TARGET = "EN"

# MODEL = 'google/flan-t5-large' # quebrou com outofmemory no pytorch na linha de treinamento

# TODO reduzir o batch size para testar
# MODEL = 'google/flan-t5-base' # quebrou com outofmemory no pytorch na linha de treinamento
MODEL = 'google/flan-t5-small'

# controla se deve ou não salvar as épocas durante o treinamento
SHOULD_SAVE_EPOCH = False

# numero de épocas a ser treinado
NUM_EPOCHS = 10
BATCH_SIZE = 10
USE_FP16 = False

BASE = "/content/drive/MyDrive/Mestrado/Projeto"

DATA_PATH = f"{BASE}/data"
DATA_OUTPUT_PATH = f"{DATA_PATH}/{DATASET_TARGET}-ajusted"

MODELS_PATH = f"{BASE}/models"
TRAINNING_PATH = f"{BASE}/training"

OUTPUT_MODEL = f"{MODELS_PATH}/{DATASET_TARGET}-{MODEL.replace('/', '-')}"


PREFIX_ANNOTATED = "annotated-"
PREFIX_PROCESSED = "processed-"
PREFIX_ONE_SHOT = "processed-one-shot-"


SQL_DATA_INFO = {
    "spider-en-pt" : {
        "name":"spider-en-pt",
        "languages": ['EN', 'PT'],
        "train_tables":"tables.json",
        "dev_tables":"tables.json",
        "eval_tables":"tables.json",
        "train_file": "train_spider.json",
        "evaluate_file": "train_others.json",
        "dev_file": "dev.json",
        "db_id_name": "db_id",
        "output_name": "query",
        "is_multiple_turn": False,
    },
    "spider-en" : {
        "name":"spider-en",
        "languages": ['EN'],
        "train_tables":"tables.json",
        "dev_tables":"tables.json",
        "eval_tables":"tables.json",
        "train_file": "train_spider.json",
        "evaluate_file": "train_others.json",
        "dev_file": "dev.json",
        "db_id_name": "db_id",
        "output_name": "query",
        "is_multiple_turn": False,
    },
    "spider-pt" : {
        "name":"spider-pt",
        "languages": ['PT'],
        "train_tables":"tables.json",
        "dev_tables":"tables.json",
        "eval_tables":"tables.json",
        "train_file": "train_spider.json",
        "evaluate_file": "train_others.json",
        "dev_file": "dev.json",
        "db_id_name": "db_id",
        "output_name": "query",
        "is_multiple_turn": False,
    },
    # "bird" : {
    #     "name": "bird",
    #     "train_file": "train/train.json",
    #     "evaluate_file": "",
    #     "dev_file": "dev/dev.json",
    #     "train_tables": "train/train_tables.json",
    #     "eval_tables": "train/train_tables.json",
    #     "dev_tables": "dev/dev_tables.json",
    #     "db_id_name": "db_id",
    #     "output_name": "SQL",
    #     "is_multiple_turn": False,
    # }
}

INSTRUCTION_PROMPT = """\
I want you to act as a SQL terminal in front of an example database, \
you need only to return the sql command to me.Below is an instruction that describes a task, \
Write a response that appropriately completes the request.\n"
##Instruction:\n{}\n"""

INSTRUCTION_ONE_SHOT_PROMPT = """\
I want you to act as a SQL terminal in front of an example database. \
You need only to return the sql command to me. \
First, I will show you few examples of an instruction followed by the correct SQL response. \
Then, I will give you a new instruction, and you should write the SQL response that appropriately completes the request.\
\n### Example1 Instruction:
The database contains tables such as employee, salary, and position. \
Table employee has columns such as employee_id, name, age, and position_id. employee_id is the primary key. \
Table salary has columns such as employee_id, amount, and date. employee_id is the primary key. \
Table position has columns such as position_id, title, and department. position_id is the primary key. \
The employee_id of salary is the foreign key of employee_id of employee. \
The position_id of employee is the foreign key of position_id of position.\
\n### Example1 Input:\nList the names and ages of employees in the 'Engineering' department.\n\
\n### Example1 Response:\nSELECT employee.name, employee.age FROM employee JOIN position ON employee.position_id = position.position_id WHERE position.department = 'Engineering';\
\n###New Instruction:\n{}\n"""

INPUT_PROMPT = "###Input:\n{}\n\n###Response:"


## Pipeline T5

### Processa o Dataset

In [6]:
# Adiciona a informação do idioma para cada amostra
from typing import List
import json
import re
import os

class DatasetMarker:
  def __init__(self, dataset):
     self.dataset = dataset

  def process(self):
    """ Processa o dataset """

    if self.dataset is not None:
      for file in [self.dataset['train_file'], self.dataset['evaluate_file'], self.dataset['dev_file']]:
        self.__process_file(file)

  def __complexity_discover_of_query(self, sql_query):
    """Classifica o nível de dificuldade do SQL com base nos critérios do Spider."""

    # Contar o número de colunas no SELECT
    select_match = re.search(
        r"\bSELECT\b\s+(.*?)(\bFROM\b)", sql_query, re.IGNORECASE | re.DOTALL
    )
    if select_match:
        select_columns = select_match.group(1).split(",")
        num_select = len([col.strip() for col in select_columns if col.strip()])
    else:
        num_select = 0

    # Contar o número de condições no WHERE
    where_conditions = re.findall(
        r"\bWHERE\b(.*?)(\bGROUP BY\b|\bORDER BY\b|$)",
        sql_query,
        re.IGNORECASE | re.DOTALL,
    )
    num_where = 0
    if where_conditions:
        where_clause = where_conditions[0][0]
        num_where = (
            len(re.findall(r"AND|OR", where_clause, re.IGNORECASE))
            if where_clause.strip()
            else 0
        )

    # Contar o número de colunas no GROUP BY
    group_by_match = re.search(
        r"\bGROUP BY\b\s+(.*?)(\bORDER BY\b|$)",
        sql_query,
        re.IGNORECASE | re.DOTALL,
    )
    if group_by_match:
        group_by_columns = group_by_match.group(1).split(",")
        num_group_by = len([col.strip() for col in group_by_columns if col.strip()])
    else:
        num_group_by = 0

    # Contar o número de colunas no ORDER BY
    order_by_match = re.search(
        r"\bORDER BY\b\s+(.*?)(LIMIT|$)", sql_query, re.IGNORECASE | re.DOTALL
    )
    if order_by_match:
        order_by_columns = order_by_match.group(1).split(",")
        num_order_by = len([col.strip() for col in order_by_columns if col.strip()])
    else:
        num_order_by = 0

    # Contar subconsultas com base nos parênteses
    num_nested = len(re.findall(r"\(SELECT\b", sql_query, re.IGNORECASE))

    # Contar o número de junções (JOIN)
    num_joins = len(re.findall(r"\bJOIN\b", sql_query, re.IGNORECASE))

    # Verificar a presença de EXCEPT, INTERSECT e UNION
    has_except = bool(re.search(r"\bEXCEPT\b", sql_query, re.IGNORECASE))
    has_intersect = bool(re.search(r"\bINTERSECT\b", sql_query, re.IGNORECASE))
    has_union = bool(re.search(r"\bUNION\b", sql_query, re.IGNORECASE))

    # Critério especial para subconsultas com JOIN
    has_nested_join = bool(
        re.search(r"\(SELECT\b.*?\bJOIN\b", sql_query, re.IGNORECASE | re.DOTALL)
    )

    # Classificação de dificuldade com base nos critérios do Spider
    if has_union:
        return "extra hard"  # `UNION` é sempre "extra hard"
    elif has_nested_join or num_nested > 1:
        return "extra hard"  # Subconsulta com JOIN ou múltiplas subconsultas
    elif (
        num_select <= 1
        and num_where <= 1
        and num_group_by == 0
        and num_order_by == 0
        and num_nested == 0
        and num_joins == 0
        and not (has_except or has_intersect)
    ):
        return "easy"
    elif (
        num_select <= 3
        and num_where <= 2
        and num_group_by <= 1
        and num_order_by <= 1
        and num_nested == 0
        and num_joins <= 1
        and not (has_except or has_intersect)
    ):
        return "medium"
    elif (
        num_group_by > 1
        or num_order_by > 1
        or num_nested > 0
        or num_where > 2
        or num_joins > 1
        or has_except
        or has_intersect
    ):
        return "hard"
    else:
        return "extra hard"

  def __process_file(self, file_name):
    """ """

    with open(os.path.join(DATA_PATH, self.dataset["name"], file_name), 'r') as file:
      dataset = json.load(file)

      print(f"Arquivo a ser tratado: {file_name}\n\n")
      print("Quantidade de amostras no arquivo: ", len(dataset))

      qtd_by_language = int(len(dataset) / len(self.dataset['languages']))

      print(f"Quantidade de amostras por idioma: {qtd_by_language}\n\n")

      language_indicator = 0
      data_indicator = 0

      for i, data in enumerate(dataset):
        # descobre a complexidade da amostra
        data['complexity'] = self.__complexity_discover_of_query(data['query'])

        # descobre o idioma da amostra
        if data_indicator < qtd_by_language:
          data['language'] = self.dataset['languages'][language_indicator]
        else:
          language_indicator += 1
          data_indicator = 0
          data['language'] = self.dataset['languages'][language_indicator]

        data_indicator += 1

      if len(self.dataset['languages']) > 1:
          print("Amostras da fronteira:")
          print(
              f"* {dataset[qtd_by_language - 1]['language']}: {dataset[qtd_by_language - 1]['question']}"
          )
          print(
              f"* {dataset[qtd_by_language]['language']}: {dataset[qtd_by_language]['question']}"
          )

      out_file_name = PREFIX_ANNOTATED + file_name

      self.__write_dataset_in_file(out_file_name, dataset)

      print(f"\n\nArquivo \"{out_file_name}\" tratado e salvo com sucesso!\n")
      print("=============================================================\n")

  def __write_dataset_in_file(self, file_name, data):
    """ Escreve o dataset tratado em um arquivo """

    # cria o diretorio se não exitir
    os.makedirs(f"{DATA_OUTPUT_PATH}", exist_ok=True)

    # reescreve o arquivo com as devidas alterações
    with open(f"{DATA_OUTPUT_PATH}/{file_name}", 'w') as file:
        json.dump(data, file)


In [None]:
# marker = DatasetMarker(SQL_DATA_INFO["spider-en-pt"])
# marker.process()

In [10]:
import numpy as np

class DatasetReport:
    def __init__(self, dataset, files):
        self.dataset = dataset
        self.files = files

    def __report(self, file_name):
        words_in = []
        words_out = []

        with open(file_name, 'r') as file:
            dataset = json.load(file)

            print(f"Quantidade de amostras: {len(dataset)}")

            qtd_easy = 0
            qtd_medium = 0
            qtd_hard = 0
            qtd_extra_hard = 0

            for i, data in enumerate(dataset):
              if data['difficulty'] == 'easy':
                qtd_easy += 1
              elif data['difficulty'] == 'medium':
                qtd_medium += 1
              elif data['difficulty'] == 'hard':
                qtd_hard += 1
              elif data['difficulty'] == 'extra hard':
                qtd_extra_hard += 1

              words_in.append(data['count_words_in'])
              words_out.append(data['count_words_out'])



        print(f"\nQuantidade de amostras fáceis: {qtd_easy} = {round((qtd_easy/len(dataset)) * 100, 2)}%")
        print(f"Quantidade de amostras médias: {qtd_medium} = {round((qtd_medium/len(dataset)) * 100, 2)}%")
        print(f"Quantidade de amostras difíceis: {qtd_hard} = {round((qtd_hard/len(dataset)) * 100, 2)}%")
        print(f"Quantidade de amostras extra difíceis: {qtd_extra_hard} = {round((qtd_extra_hard/len(dataset)) * 100, 2)}%")

        print(f"\nQuantidade mínima de palavras no input: {np.min(words_in)}")
        print(f"Quantidade máxima de palavras no input: {np.max(words_in)}")
        print(f"Quantidade média de palavras no input: {np.mean(words_in)}")
        print(f"Quantidade total de palavras no input: {np.sum(words_in)}")
        print(f"Desvio padrão: {np.std(words_in)}")

        print(f"\nQuantidade mínima de palavras no output: {np.min(words_out)}")
        print(f"Quantidade máxima de palavras no output: {np.max(words_out)}")
        print(f"Quantidade média de palavras no output: {np.mean(words_out)}")
        print(f"Quantidade total de palavras no output: {np.sum(words_out)}")
        print(f"Desvio padrão: {np.std(words_out)}")
        print("=======================================\n\n")

    def report(self):
        """report"""
        print("\n\n=======================================")
        print("Relatório de processamento do dataset.")
        print(f"Dataset: {dataset['name']}")

        for file in self.files:
          print(f"\nArquivo: {file}")
          self.__report(file)


In [11]:
from genericpath import exists
from tqdm import tqdm
import jsonlines
import argparse
import json
import os
import re
import spacy

class ProcessDataset:
  def __init__(self, dataset, train_file, eval_file, dev_file, num_shot=0, code_representation=False):
    self.dataset = dataset
    self.num_shot = num_shot
    self.code_representation = code_representation

    self.train_file = train_file
    self.eval_file = eval_file
    self.dev_file = dev_file

    self.nlp_en = spacy.load("en_core_web_sm")
    self.nlp_pt = spacy.load("pt_core_news_sm")

  def __def_verifica_anotacao(self):
        exists_train = os.path.isfile(os.path.join(DATA_OUTPUT_PATH, f"{PREFIX_ANNOTATED}{self.dataset['train_file']}"))
        exists_eval = os.path.isfile(os.path.join(DATA_OUTPUT_PATH, f"{PREFIX_ANNOTATED}{self.dataset['evaluate_file']}"))
        exists_dev = os.path.isfile(os.path.join(DATA_OUTPUT_PATH, f"{PREFIX_ANNOTATED}{self.dataset['dev_file']}"))

        if exists_train and exists_eval and exists_dev:
          return True

        return False

  def __count_words(self, text, language):
        return len(self.nlp_en(text) if language == "EN" else self.nlp_pt(text))

  def __decode_json_file(
        self,
        data_file_list,
        table_file,
        db_folder_path,
        db_id_name,
        output_name,
        is_multiple_turn=False,
    ):
        """
        @TODO:
            1. Colocar o prompt relacionado no arquivo de configuração
            2. Colocar as informações dos campos de diferentes fontes de dados no arquivo de configuração
        """

        if table_file.endswith(".jsonl"):
            tables = jsonlines.open(table_file)
            datas = []
            for data_file in data_file_list:
                datas.extend(jsonlines.open(data_file))

        elif table_file.endswith(".json"):
            with open(table_file) as table:
                tables = json.load(table)
                datas = []
                for data_file in data_file_list:
                    with open(data_file) as data:
                        datas.extend(json.load(data))
        else:
            print("Unsupported file types")
            raise ValueError("Unsupported file types")

        # Primeiro, processe corretamente as tabelas e colunas do db_id
        db_dict = {}
        for item in tables:
            tables = item["table_names_original"]
            coloumns = item["column_names_original"][1:]
            primary_key = item["primary_keys"]
            foreign_keys = item["foreign_keys"]

            source = (
                item["db_id"] + " contains tables such as " + ", ".join(tables) + ". "
            )

            for i, name in enumerate(tables):
                data = [coloumn[1] for coloumn in coloumns if coloumn[0] == i]
                source += (
                    "Table " + name + " has columns such as " + ", ".join(data) + ". "
                )

                # get primary key info
                for j in range(len(primary_key)):
                    if type(primary_key[j]) == int:
                        if coloumns[primary_key[j] - 1][0] == i:
                            source += (
                                coloumns[primary_key[j] - 1][1]
                                + " is the primary key."
                                + "\n"
                            )

                    # combination primary key
                    elif type(primary_key[j]) == list:
                        combine_p = "The combination of ("
                        keys = []

                        for k in range(len(primary_key[j])):
                            if coloumns[primary_key[j][k] - 1][0] == i:
                                keys.append(coloumns[primary_key[j][k] - 1][1])

                        source += (
                            combine_p
                            + ", ".join(keys)
                            + ") are the primary key."
                            + "\n"
                        )
                    else:
                        print("not support type", type(primary_key[j]))
                        continue

            # get foreign key info
            for key in foreign_keys:
                source += (
                    "The "
                    + coloumns[key[0] - 1][1]
                    + " of "
                    + tables[coloumns[key[0] - 1][0]]
                    + " is the foreign key of "
                    + coloumns[key[1] - 1][1]
                    + " of "
                    + tables[coloumns[key[1] - 1][0]]
                    + ".\n"
                )

            db_dict[item["db_id"]] = source

        res = []
        base_instruction = INSTRUCTION_PROMPT

        if self.num_shot == 1:
            base_instruction = INSTRUCTION_ONE_SHOT_PROMPT

        for data in tqdm(datas):
            if data[db_id_name] in db_dict.keys():
                if is_multiple_turn:  # Múltiplas rodadas
                    history = []

                    for interaction in data["interaction"]:
                        sql_query = interaction[output_name]

                        input = INPUT_PROMPT.format(interaction["utterance"])
                        context = db_dict[data[db_id_name]]

                        input_data = {
                            "db_id": data[db_id_name],
                            "instruction": base_instruction.format(
                                context
                            ),
                            "context": context,
                            "input": input,
                            "language": data["language"],
                            "output": sql_query,
                            "difficulty": data["complexity"],
                            "history": history,
                            "count_words_in": self.__count_words(input.replace("\n\n###Response:", "\n\n###Context:\n") + context + "\n\n###Response:", data["language"]),
                            "count_words_out": self.__count_words(sql_query, data["language"]),
                        }

                        res.append(input_data)
                        history.append(
                            (
                                INPUT_PROMPT.format(interaction["utterance"]),
                                interaction[output_name],
                            )
                        )
                else:  # Rodada única
                    sql_query = data[output_name]

                    if self.code_representation:
                        db_path = os.path.join(db_folder_path, data[db_id_name])
                        sql_file_path = next(
                            (
                                file
                                for file in os.listdir(db_path)
                                if file.endswith(".sql")
                            ),
                            None,
                        )

                        if sql_file_path is None:
                            continue  # Encerrar a iteração antecipadamente

                        schema_file_path = os.path.join(db_path, sql_file_path)

                        with open(schema_file_path, "r", encoding="utf8") as file:
                            schema_content = file.read()

                        create_statements = re.findall(
                            r"CREATE\s.*?;", schema_content, re.DOTALL | re.IGNORECASE
                        )

                        input = INPUT_PROMPT.format(data["question"])

                        input_data = {
                            "db_id": data[db_id_name],
                            "instruction": INSTRUCTION_PROMPT.format(create_statements),
                            "context": create_statements,
                            "input": input,
                            "language": data["language"],
                            "output": sql_query,
                            "difficulty": data["complexity"],
                            "history": [],
                            "count_words_in": self.__count_words(input.replace("\n\n###Response:", "\n\n###Context:\n") + create_statements + "\n\n###Response:", data["language"]),
                            "count_words_out": self.__count_words(sql_query, data["language"]),
                        }
                        res.append(input_data)
                    else:
                        input = INPUT_PROMPT.format(data["question"])
                        context = db_dict[data[db_id_name]]
                        input_data = {
                            "db_id": data[db_id_name],
                            "instruction": base_instruction.format(
                                context
                            ),
                            "context": context,
                            "input": input,
                            "language": data["language"],
                            "output": sql_query,
                            "difficulty": data["complexity"],
                            "history": [],
                            "count_words_in": self.__count_words(input.replace("\n\n###Response:", "\n\n###Context:\n") + context + "\n\n###Response:", data["language"]),
                            "count_words_out": self.__count_words(sql_query, data["language"]),
                        }
                        res.append(input_data)
        return res

  def process(self, report=True):
        """process"""
        print("Iniciando processador do dataset.")

        if not self.__def_verifica_anotacao():
            print("Dataset ainda não foi anotado. Efetuando anotação...")
            marker = DatasetMarker(self.dataset)
            marker.process()

        print("Dataset devidamente anotado.")

        print("\nProcessando o dataset...")

        train_data = []
        eval_data = []
        dev_data = []

        for data_info in SQL_DATA_INFO.values():
            if data_info["name"] != DATASET_TARGET:
                continue

            tfile = data_info["train_file"]
            efile = data_info["evaluate_file"]
            dfile = data_info["dev_file"]

            train_data_file_list = [
                os.path.join(DATA_OUTPUT_PATH, f"{PREFIX_ANNOTATED}{tfile}")
            ]

            train_data.extend(
                self.__decode_json_file(
                    data_file_list=train_data_file_list,
                    table_file=os.path.join(
                        DATA_PATH,
                        data_info["name"],
                        data_info["train_tables"],
                    ),
                    db_folder_path=os.path.join(
                        DATA_PATH,
                        data_info["name"],
                        "database",
                    ),
                    db_id_name=data_info["db_id_name"],
                    output_name=data_info["output_name"],
                    is_multiple_turn=data_info["is_multiple_turn"],
                )
            )

            eval_data_file_list = [
                os.path.join(DATA_OUTPUT_PATH, f"{PREFIX_ANNOTATED}{efile}")
            ]

            eval_data.extend(
                self.__decode_json_file(
                    data_file_list=eval_data_file_list,
                    table_file=os.path.join(
                        DATA_PATH,
                        data_info["name"],
                        data_info["eval_tables"],
                    ),
                    db_folder_path=os.path.join(
                        DATA_PATH,
                        data_info["name"],
                        "database",
                    ),
                    db_id_name=data_info["db_id_name"],
                    output_name=data_info["output_name"],
                    is_multiple_turn=data_info["is_multiple_turn"],
                )
            )

            dev_data_file_list = [
                os.path.join(DATA_OUTPUT_PATH, f"{PREFIX_ANNOTATED}{dfile}")
            ]

            dev_data.extend(
                self.__decode_json_file(
                    data_file_list=dev_data_file_list,
                    table_file=os.path.join(
                        DATA_PATH,
                        data_info["name"],
                        data_info["dev_tables"],
                    ),
                    db_folder_path=os.path.join(
                        DATA_PATH,
                        data_info["name"],
                        "database",
                    ),
                    db_id_name=data_info["db_id_name"],
                    output_name=data_info["output_name"],
                    is_multiple_turn=data_info["is_multiple_turn"],
                )
            )

        if train_data:
            with open(self.train_file, "w", encoding="utf-8") as s:
                json.dump(train_data, s, indent=4, ensure_ascii=False)

        if eval_data:
            with open(self.eval_file, "w", encoding="utf-8") as s:
                json.dump(eval_data, s, indent=4, ensure_ascii=False)

        if dev_data:
            with open(self.dev_file, "w", encoding="utf-8") as s:
                json.dump(dev_data, s, indent=4, ensure_ascii=False)

        if not train_data and not eval_data and not dev_data:
            print("Nenhum dataset foi processado.")
            return

        if report:
            print("Dataset processado com sucesso!")
            reporter = DatasetReport(self.dataset, [self.train_file, self.eval_file, self.dev_file])
            reporter.report()




In [12]:
# if __name__ == "__main__":
# parser = argparse.ArgumentParser()
# parser.add_argument(
#     "--code_representation", help="Enable code representation", default=False
# )
# args = parser.parse_args()

dataset = SQL_DATA_INFO[DATASET_TARGET]

print(f"Iniciando processamento do dataset {DATASET_TARGET}")


train_file = os.path.join(DATA_OUTPUT_PATH, f"{PREFIX_PROCESSED}{dataset['train_file']}")
eval_file = os.path.join(DATA_OUTPUT_PATH, f"{PREFIX_PROCESSED}{dataset['evaluate_file']}")
dev_file = os.path.join(DATA_OUTPUT_PATH, f"{PREFIX_PROCESSED}{dataset['dev_file']}")

process = ProcessDataset(
    dataset=dataset,
    train_file=train_file,
    eval_file=eval_file,
    dev_file=dev_file,
    code_representation=False, # args.code_representation,
)
process.process()

print(f"Iniciando processamento do dataset {DATASET_TARGET} com One Shot Learning")

onse_shot_train_file = os.path.join(DATA_OUTPUT_PATH, f"{PREFIX_ONE_SHOT}{dataset['train_file']}")
onse_shot_eval_file = os.path.join(DATA_OUTPUT_PATH, f"{PREFIX_ONE_SHOT}{dataset['evaluate_file']}")
onse_shot_dev_file = os.path.join(DATA_OUTPUT_PATH, f"{PREFIX_ONE_SHOT}{dataset['dev_file']}")

process = ProcessDataset(
    dataset=dataset,
    train_file=onse_shot_train_file,
    eval_file=onse_shot_eval_file,
    dev_file=onse_shot_dev_file,
    num_shot=1,
    code_representation=False, # args.code_representation,
)
process.process(report=False)

print(f"Finalizado processamento do Dataset!")

Iniciando processamento do dataset spider-en




Iniciando processador do dataset.
Dataset devidamente anotado.

Processando o dataset...


100%|██████████| 7000/7000 [06:01<00:00, 19.36it/s]
100%|██████████| 1659/1659 [01:45<00:00, 15.73it/s]
100%|██████████| 1034/1034 [00:41<00:00, 24.72it/s]


Dataset processado com sucesso!


Relatório de processamento do dataset.
Dataset: spider-en

Arquivo: /content/drive/MyDrive/Mestrado/Projeto/data/spider-en-ajusted/processed-train_spider.json
Quantidade de amostras: 7000

Quantidade de amostras fáceis: 1543 = 22.04%
Quantidade de amostras médias: 3808 = 54.4%
Quantidade de amostras difíceis: 1485 = 21.21%
Quantidade de amostras extra difíceis: 164 = 2.34%

Quantidade mínima de palavras no input: 73
Quantidade máxima de palavras no input: 1240
Quantidade média de palavras no input: 272.45285714285717
Quantidade total de palavras no input: 1907170
Desvio padrão: 199.48813586020498

Quantidade mínima de palavras no output: 4
Quantidade máxima de palavras no output: 112
Quantidade média de palavras no output: 20.78757142857143
Quantidade total de palavras no output: 145513
Desvio padrão: 12.224081144050318



Arquivo: /content/drive/MyDrive/Mestrado/Projeto/data/spider-en-ajusted/processed-train_others.json
Quantidade de amostras: 1659

Q

100%|██████████| 7000/7000 [05:54<00:00, 19.72it/s]
100%|██████████| 1659/1659 [01:42<00:00, 16.19it/s]
100%|██████████| 1034/1034 [00:42<00:00, 24.40it/s]


Finalizado processamento do Dataset!


###FINETUNNING T5-SMALL



In [None]:
import os
import torch
from datetime import datetime
from google.colab import userdata
from datasets import load_dataset, Dataset
from transformers import (
    Trainer,
    TrainingArguments,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
    EarlyStoppingCallback,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    TrainerCallback
)
import evaluate
import numpy as np
import json
import pandas as pd
from datasets import DatasetDict

import nltk
nltk.download('punkt_tab')

os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN')


[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


In [None]:
def load_dataset(path, language=None):
    with open(path, 'r', encoding='utf-8') as arquivo:
        data = pd.read_json(arquivo)
        if language is None:
            return data
        else:
            # Filtra os dados onde 'language' é igual ao parâmetro fornecido
            filtered_data = data[data['language'] == language]
            return filtered_data


In [None]:
# transforma os dataframes em Dataset para a devida utilização no treinamento
dataset_dict = {
    'train': Dataset.from_pandas(
        load_dataset(
            os.path.join(DATA_OUTPUT_PATH, f"{PREFIX_PROCESSED}train_spider.json"), language=LANGUAGE_TARGET
        )
    ),
    'validation': Dataset.from_pandas(
        load_dataset(
            os.path.join(DATA_OUTPUT_PATH, f"{PREFIX_PROCESSED}train_others.json"), language=LANGUAGE_TARGET
        )
    ),
    'test': Dataset.from_pandas(
        load_dataset(
            os.path.join(DATA_OUTPUT_PATH, f"{PREFIX_PROCESSED}dev.json"), language=LANGUAGE_TARGET
        )
    ),
}
dataset = DatasetDict(dataset_dict)

In [None]:
dataset

In [None]:
dataset['train'][0]

In [None]:
# 2. Inicializar o modelo e tokenizer do T5-small
tokenizer = AutoTokenizer.from_pretrained(MODEL)

model = AutoModelForSeq2SeqLM.from_pretrained(MODEL)

tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.40k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/308M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [None]:
# 3. Processar o dataset
def preprocess_data(examples):
    inputs = [
        # ajusta para o t5 invertendo o posicionamento do contexto com o input
        input.replace("\n\n###Response:", "\n\n###Context:\n") + context + "\n\n###Response:"
        for context, input in zip(examples['context'], examples['input'])
    ]
    targets = [sql for sql in examples['output']]

    # Tokenize inputs e targets
    model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding="max_length")
    labels = tokenizer(targets, max_length=128, truncation=True, padding="max_length").input_ids

    model_inputs['labels'] = labels

    return model_inputs

# Aplicar o preprocessamento
tokenized_datasets = dataset.map(
    preprocess_data,
    batched=True,
    remove_columns=dataset["train"].column_names
)


Map:   0%|          | 0/7000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1659 [00:00<?, ? examples/s]

Map:   0%|          | 0/1034 [00:00<?, ? examples/s]

In [None]:
tokenized_datasets

In [None]:
# 4. Definir as métricas de avaliação
rouge = evaluate.load('rouge')

nltk.download('punkt')

def compute_metrics(eval_pred):

    predictions, labels = eval_pred

    # Ensure predictions and labels are within the allowable range
    predictions = np.clip(predictions, a_min=0, a_max=tokenizer.vocab_size - 1)
    labels = np.clip(labels, a_min=0, a_max=tokenizer.vocab_size - 1)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = ['\n'.join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ['\n'.join(nltk.sent_tokenize(label_.strip())) for label_ in decoded_labels]

    # print('\n', decoded_preds)
    # print('\n', decoded_labels)

    result = rouge.compute(predictions=decoded_preds,
                            references=decoded_labels, use_stemmer=True)

    # print('\n', result)

    result = {key: value for key, value in result.items()}
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result['gen_len'] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}


Downloading builder script:   0%|          | 0.00/6.27k [00:00<?, ?B/s]

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [None]:
class SaveModelByEpochCallback(TrainerCallback):
    def on_epoch_end(self, args, state, control, **kwargs):
        if SHOULD_SAVE_EPOCH:
            epoch = int(state.epoch)

            output_dir_epoch = os.path.join(args.output_dir, "epochs", f"epoch-{epoch}")

            os.makedirs(output_dir_epoch, exist_ok=True)

            # print(kwargs)

            kwargs['model'].save_pretrained(output_dir_epoch)
            kwargs['processing_class'].save_pretrained(output_dir_epoch)

            print(f"Modelo da época {epoch} salvo em {output_dir_epoch}")

In [None]:
# 5. Definir os argumentos de treinamento
logging_eval_steps = len(tokenized_datasets['train']) // BATCH_SIZE

train_args = Seq2SeqTrainingArguments(
      output_dir=OUTPUT_MODEL,
      num_train_epochs=NUM_EPOCHS,
      learning_rate=1e-5, #5.6e-5
      per_device_train_batch_size=BATCH_SIZE,
      per_device_eval_batch_size=BATCH_SIZE,
      weight_decay=0.01,
      eval_steps=logging_eval_steps,
      logging_steps=logging_eval_steps,
      eval_strategy='epoch',
      predict_with_generate=True,
      report_to="none",
      save_total_limit=1,
      save_strategy='epoch',
      load_best_model_at_end=True,
      metric_for_best_model='rougeL',
      greater_is_better=True,
      push_to_hub=False,
      fp16=USE_FP16
  )

In [None]:
# import os
# os.environ["WANDB_DISABLED"] = "true"

torch.cuda.empty_cache() # limpa o cache do CUDA

# model_path = os.path.join(MODELS_PATH, f'{MODEL}_{DATASET_TARGET}')
# output_dir = f'{TRAINNING_PATH}/{MODEL}_{DATASET_TARGET}'

os.makedirs(OUTPUT_MODEL, exist_ok=True)
# os.makedirs(output_dir, exist_ok=True)

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

train_encoded_dataset = tokenized_datasets["train"]
validation_encoded_dataset = tokenized_datasets["validation"]

# 6. Inicializar o Trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=train_args,
    train_dataset=train_encoded_dataset,
    eval_dataset=validation_encoded_dataset,
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,

    callbacks=[
        EarlyStoppingCallback(
            early_stopping_patience=5
        ),
        SaveModelByEpochCallback()
    ]
)

# 7. Iniciar o treinamento
if os.path.exists(OUTPUT_MODEL) and len(os.listdir(OUTPUT_MODEL)) > 0:
    trainer.train(resume_from_checkpoint=True)
else:
    trainer.train()

trainer.evaluate()

# 8. Salvar o modelo fine-tuned
trainer.save_model(OUTPUT_MODEL)

print('\n\n***Finetunning Complete!***')

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
1,7.8076,1.73831,0.111,0.0417,0.1018,0.1018,7.912
2,1.2214,0.9204,0.3448,0.1479,0.3209,0.3213,18.9204
3,0.4582,0.808187,0.341,0.1507,0.3192,0.3195,18.4473
4,0.2964,0.767738,0.3589,0.1631,0.3358,0.3357,18.3466
5,0.2566,0.761554,0.3629,0.1707,0.3404,0.3403,18.0657
6,0.2385,0.753485,0.3702,0.1787,0.3462,0.3463,18.2158
7,0.2282,0.753689,0.3808,0.183,0.3534,0.3538,18.3267
8,0.2203,0.752394,0.3863,0.1908,0.3588,0.3588,18.305
9,0.2173,0.751012,0.3865,0.1919,0.3595,0.3597,18.305
10,0.2147,0.750462,0.3887,0.1919,0.3608,0.361,18.3424


Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Tr

Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Tr



***Finetunning Complete!***


**Testa o modelo treinado**

In [None]:
def use_ajusted_model(model_path, text):
  # Carregar o tokenizer
  tokenizer = AutoTokenizer.from_pretrained(model_path)

  # Carregar o modelo
  model = AutoModelForSeq2SeqLM.from_pretrained(model_path)

  input_ids = tokenizer(text, return_tensors="pt", max_length=512, truncation=True, padding="max_length").input_ids

  # Gerar a saída
  output_ids = model.generate(input_ids, max_new_tokens=128)

  # Decodificar a saída
  saida = tokenizer.decode(output_ids[0], skip_special_tokens=True)

  return saida

In [None]:
type(dataset["test"])

In [None]:
from typing import Literal
from pydantic import BaseModel

class Prediction(BaseModel):
    db_id: str
    difficulty: Literal["easy", "medium", "hard", "extra hard"]
    instruction: str
    nl: str
    sql_expected: str
    sql_predicted: str


In [None]:
import json
import os
from tqdm import tqdm

model_path = os.path.join(OUTPUT_MODEL, "")
predictions = []

progress_bar = tqdm(total=len(dataset["test"]), desc="Gerando Predições...", colour="red")
for data in dataset["test"]:
    # print(f"Complexidade: {data['difficulty']}")
    # print(f"Query: {data['instruction'] + data['input']}")
    # print(f"Resposta esperada: {data['output']}")
    # print(f"Resposta gerada: {use_ajusted_model(model_path, data['instruction'] + data['input'])}")
    # print("\n################################\n")

    predictions.append(
        Prediction(
            db_id=data['db_id'],
            difficulty=data['difficulty'],
            instruction=data['instruction'],
            nl=data['input'],
            sql_expected=data['output'],
            sql_predicted=use_ajusted_model(model_path, data['instruction'] + data['input'])
        ).model_dump()
    )

    progress_bar.update(1)

# Escrever o JSON em um arquivo
with open(os.path.join(OUTPUT_MODEL, 'predictions.json'), 'w') as f:
    f.write(json.dumps(predictions, indent=4))

print("Predictions saved!")


Gerando Predições...:  49%|[31m████▉     [0m| 507/1034 [18:11<18:29,  2.11s/it]Token indices sequence length is longer than the specified maximum sequence length for this model (963 > 512). Running this sequence through the model will result in indexing errors
Gerando Predições...:  49%|[31m████▉     [0m| 508/1034 [18:16<24:14,  2.76s/it]Token indices sequence length is longer than the specified maximum sequence length for this model (962 > 512). Running this sequence through the model will result in indexing errors
Gerando Predições...:  49%|[31m████▉     [0m| 509/1034 [18:20<27:29,  3.14s/it]Token indices sequence length is longer than the specified maximum sequence length for this model (959 > 512). Running this sequence through the model will result in indexing errors
Gerando Predições...:  49%|[31m████▉     [0m| 510/1034 [18:22<24:19,  2.79s/it]Token indices sequence length is longer than the specified maximum sequence length for this model (957 > 512). Running this sequen

Predictions saved!


### Exemplo uso T5

In [None]:
import torch
import numpy as np
import evaluate
import os
import nltk
import time

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, EarlyStoppingCallback
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer


nltk.download('punkt')
rouge = evaluate.load('rouge')

def preprocess_function(examples, max_input_len_, max_target_len_, tokenizer_):
    model_inputs = tokenizer_(examples['text'], max_length=max_input_len_, truncation=True)
    labels = tokenizer_(examples['summary'], max_length=max_target_len_, truncation=True)
    model_inputs['labels'] = labels['input_ids']
    return model_inputs


def compute_eval_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    decoded_preds = ['\n'.join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ['\n'.join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    result = rouge.compute(predictions=decoded_preds,
                           references=decoded_labels,
                           use_stemmer=False)
    result = {key: value for key, value in result.items()}
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result['gen_len'] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}


if __name__ == '__main__':

    is_turn_off_computer = False

    # model_name = 'ptt5_small'
    # model_name = 'ptt5_base'
    # model_name = 'ptt5_large'

    # model_name = 'flan_t5_small'
    # model_name = 'flan_t5_base'
    model_name = 'flan_t5_large'

    # model_name = 'ptt5_v2_small'
    # model_name = 'ptt5_v2_base'
    # model_name = 'ptt5_v2_large'

    dataset_name = 'recognasumm'
    # dataset_name = 'xlsum'

    use_fp16 = False

    models_dir = '../../data/models'
    training_dir = '../../data/training'

    n_examples = -1

    num_epochs = 20

    max_input_len = 512
    max_summary_len = 150

    batch_size = 32

    if model_name == 'flan_t5_base' or model_name == 'ptt5_v2_base':
        batch_size = 8
    elif model_name == 'flan_t5_large':
        batch_size = 3
    elif '_large' in model_name:
        batch_size = 4

    model_checkpoint = None

    if model_name == 'flan_t5_small':
        model_checkpoint = 'google/flan-t5-small'
    elif model_name == 'flan_t5_base':
        model_checkpoint = 'google/flan-t5-base'
    elif model_name == 'flan_t5_large':
        model_checkpoint = 'google/flan-t5-large'
    elif model_name == 'ptt5_small':
        model_checkpoint = 'unicamp-dl/ptt5-small-portuguese-vocab'
    elif model_name == 'ptt5_base':
        model_checkpoint = 'unicamp-dl/ptt5-base-portuguese-vocab'
    elif model_name == 'ptt5_large':
        model_checkpoint = 'unicamp-dl/ptt5-large-portuguese-vocab'
    elif model_name == 'ptt5_v2_small':
        model_checkpoint = 'unicamp-dl/ptt5-v2-small'
    elif model_name == 'ptt5_v2_base':
        model_checkpoint = 'unicamp-dl/ptt5-v2-base'
    elif model_name == 'ptt5_v2_large':
        model_checkpoint = 'unicamp-dl/ptt5-v2-large'
    else:
        print(f'\nError. Model Name {model_name} not found!')
        exit(-1)

    model_path = os.path.join(models_dir, f'{model_name}_{dataset_name}')

    output_dir = f'{training_dir}/{model_name}_{dataset_name}'

    os.makedirs(model_path, exist_ok=True)

    os.makedirs(output_dir, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print(f'\nDevice: {device} -- Use FP16: {use_fp16} -- Batch size: {batch_size} -- '
          f'Turn Off Computer: {is_turn_off_computer}')

    print(f'\nModel: {model_name} -- {model_checkpoint}')

    if dataset_name == 'xlsum':
        dataset = load_dataset('csebuetnlp/xlsum', 'portuguese')
    elif dataset_name == 'recognasumm':
        dataset = load_dataset("recogna-nlp/recognasumm")
        dataset = dataset.rename_column("index", "id")
        dataset = dataset.rename_column("Noticia", "text")
        dataset = dataset.rename_column("Sumario", "summary")
    else:
        print(f'\nError. DATASET Name {dataset_name} Invalid!')
        exit(-1)

    # dataset = dataset.filter(lambda example: len(example['summary'].split()) >= 25)

    if n_examples > 0:
        train_dataset = dataset['train'].select(range(n_examples))
        validation_dataset = dataset['validation'].select(range(n_examples))
    else:
        train_dataset = dataset['train']
        validation_dataset = dataset['validation']

    print(f'\nTrain: {len(train_dataset)}')
    print(f'Validation: {len(validation_dataset)}\n')

    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, legacy=False)

    train_encoded_dataset = train_dataset.map(
        preprocess_function, batched=True, fn_kwargs={
            'max_input_len_': max_input_len, 'max_target_len_': max_summary_len,
            'tokenizer_': tokenizer})

    validation_encoded_dataset = validation_dataset.map(
        preprocess_function, batched=True, fn_kwargs={
            'max_input_len_': max_input_len, 'max_target_len_': max_summary_len,
            'tokenizer_': tokenizer})

    model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

    logging_eval_steps = len(train_encoded_dataset) // batch_size

    train_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_epochs,
        learning_rate=5.6e-5,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        weight_decay=0.01,
        eval_steps=logging_eval_steps,
        logging_steps=logging_eval_steps,
        evaluation_strategy='epoch',
        predict_with_generate=True,
        save_total_limit=1,
        save_strategy='epoch',
        load_best_model_at_end=True,
        metric_for_best_model='rougeL',
        greater_is_better=True,
        push_to_hub=False,
        fp16=use_fp16
    )

    trainer = Seq2SeqTrainer(
        model=model,
        args=train_args,
        train_dataset=train_encoded_dataset,
        eval_dataset=validation_encoded_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_eval_metrics,
        callbacks=[
            EarlyStoppingCallback(
                early_stopping_patience=5
            )
        ]
    )

    if os.path.exists(output_dir) and len(os.listdir(output_dir)) > 0:
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()

    trainer.evaluate()

    trainer.save_model(model_path)

    print('\n\n***Finetunning Complete!***')

    if is_turn_off_computer:
        print('\nTurning off computer ...')
        time.sleep(2 * 60)
        os.system('shutdown -h now')

## categorização do dataset spider de acordo com a complexidade(dificuldade)

In [None]:
import re

def get_sql_difficulty(sql_query):
    """Classifica o nível de dificuldade do SQL com base nos critérios do Spider."""
    # Contar o número de colunas no SELECT
    select_match = re.search(r"\bSELECT\b\s+(.*?)(\bFROM\b)", sql_query, re.IGNORECASE | re.DOTALL)
    if select_match:
        select_columns = select_match.group(1).split(",")
        num_select = len([col.strip() for col in select_columns if col.strip()])
    else:
        num_select = 0

    # Contar o número de condições no WHERE
    where_conditions = re.findall(r"\bWHERE\b(.*?)(\bGROUP BY\b|\bORDER BY\b|$)", sql_query, re.IGNORECASE | re.DOTALL)
    num_where = 0
    if where_conditions:
        where_clause = where_conditions[0][0]
        num_where = len(re.findall(r"AND|OR", where_clause, re.IGNORECASE)) if where_clause.strip() else 0

    # Contar o número de colunas no GROUP BY
    group_by_match = re.search(r"\bGROUP BY\b\s+(.*?)(\bORDER BY\b|$)", sql_query, re.IGNORECASE | re.DOTALL)
    if group_by_match:
        group_by_columns = group_by_match.group(1).split(",")
        num_group_by = len([col.strip() for col in group_by_columns if col.strip()])
    else:
        num_group_by = 0

    # Contar o número de colunas no ORDER BY
    order_by_match = re.search(r"\bORDER BY\b\s+(.*?)(LIMIT|$)", sql_query, re.IGNORECASE | re.DOTALL)
    if order_by_match:
        order_by_columns = order_by_match.group(1).split(",")
        num_order_by = len([col.strip() for col in order_by_columns if col.strip()])
    else:
        num_order_by = 0

    # Contar subconsultas com base nos parênteses
    num_nested = len(re.findall(r"\(SELECT\b", sql_query, re.IGNORECASE))

    # Contar o número de junções (JOIN)
    num_joins = len(re.findall(r"\bJOIN\b", sql_query, re.IGNORECASE))

    # Verificar a presença de EXCEPT, INTERSECT e UNION
    has_except = bool(re.search(r"\bEXCEPT\b", sql_query, re.IGNORECASE))
    has_intersect = bool(re.search(r"\bINTERSECT\b", sql_query, re.IGNORECASE))
    has_union = bool(re.search(r"\bUNION\b", sql_query, re.IGNORECASE))

    # Critério especial para subconsultas com JOIN
    has_nested_join = bool(
        re.search(r"\(SELECT\b.*?\bJOIN\b", sql_query, re.IGNORECASE | re.DOTALL)
    )

    # Classificação de dificuldade com base nos critérios do Spider
    if has_union:
        return "extra hard"  # `UNION` é sempre "extra hard"
    elif has_nested_join or num_nested > 1:
        return "extra hard"  # Subconsulta com JOIN ou múltiplas subconsultas
    elif (
        num_select <= 1
        and num_where <= 1
        and num_group_by == 0
        and num_order_by == 0
        and num_nested == 0
        and num_joins == 0
        and not (has_except or has_intersect)
    ):
        print(f"Num Select: {num_select}\nNum Joins: {num_joins}\nNum Where: {num_where}\nNum Group By: {num_group_by}\nNum Order By: {num_order_by}\nNum Nested: {num_nested}")
        return "easy"
    elif (
        num_select <= 3
        and num_where <= 2
        and num_group_by <= 1
        and num_order_by <= 1
        and num_nested == 0
        and num_joins <= 1
        and not (has_except or has_intersect)
    ):
        print(f"Num Select: {num_select}\nNum Joins: {num_joins}\nNum Where: {num_where}\nNum Group By: {num_group_by}\nNum Order By: {num_order_by}\nNum Nested: {num_nested}")
        return "medium"
    elif (
        num_group_by > 1
        or num_order_by > 1
        or num_nested > 0
        or num_where > 2
        or num_joins > 1
        or has_except
        or has_intersect
    ):
        print(f"Num Select: {num_select}\nNum Joins: {num_joins}\nNum Where: {num_where}\nNum Group By: {num_group_by}\nNum Order By: {num_order_by}\nNum Nested: {num_nested}")
        return "hard"
    else:
        return "extra hard"


In [None]:
get_sql_difficulty("SELECT name, price FROM products WHERE catergory = 'Eletronics' AND voltage = 110")

In [None]:
get_sql_difficulty("SELECT AVG(price) FROM products WHERE category = 'Electronics'")

In [None]:
get_sql_difficulty("SELECT category, COUNT(*) FROM products GROUP BY category")

In [None]:
get_sql_difficulty("SELECT product_name, price FROM products")

In [None]:
# prompt: gere um exemplo de uma consulta sql que realize uma junção a esquerda colocando alias em cada tabela
get_sql_difficulty("SELECT c.customer_name, o.order_id FROM customers c LEFT JOIN orders o ON c.customer_id = o.customer_id")

In [None]:
# prompt: gere um exemplo de uma consulta sql que realize uma junção a esquerda e uma junção a direita com uma subconsulta colocando alias em cada tabela
sql1 = ("SELECT "
    + "c.customer_id, "
    + "c.customer_name, "
    + "o.order_id "
+ "FROM "
    + "customers c "
+ "LEFT JOIN "
    + "(SELECT order_id, customer_id FROM orders WHERE order_date > '2023-01-01') o "
+ "ON "
    + "c.customer_id = o.customer_id; ")



sql2 = ("SELECT"
    + "o.order_id, "
    + "o.order_date, "
    + "c.customer_name "
+ "FROM "
    + "(SELECT order_id, order_date, customer_id FROM orders WHERE order_status = 'Shipped') o "
+ "RIGHT JOIN "
    + "customers c "
+ "ON "
    + "o.customer_id = c.customer_id; ")

print(get_sql_difficulty(sql1))
print("\n")
print(get_sql_difficulty(sql2))

In [None]:
# prompt: gere uma consulta sql para retornar todos os alunos que estão matriculado em matemática
print(get_sql_difficulty("SELECT s.student_name FROM Students s JOIN Enrollments e ON s.student_id = e.student_id JOIN Courses c ON e.course_id = c.course_id WHERE c.course_name = 'Matemática'"))

In [None]:
get_sql_difficulty("SELECT name ,  born_state ,  age FROM head ORDER BY age")

In [None]:
get_sql_difficulty("SELECT T1.country_name FROM countries AS T1 JOIN continents AS T2 ON T1.continent = T2.cont_id JOIN car_makers AS T3 ON T1.country_id = T3.country WHERE T2.continent = 'Europe' GROUP BY T1.country_name HAVING COUNT(*) >= 3")

In [None]:
get_sql_difficulty("SELECT T2.name, COUNT(*) FROM concert AS T1 JOIN stadium AS T2 ON T1.stadium_id = T2.stadium_id GROUP BY T1.stadium_id")

In [None]:
get_sql_difficulty("SELECT COUNT(*) FROM cars_data WHERE cylinders > 4")

In [None]:
get_sql_difficulty("SELECT AVG(life_expectancy) FROM country WHERE name NOT IN (SELECT T1.name FROM country AS T1 JOIN country_language AS T2 ON T1.code = T2.country_code WHERE T2.language = 'English' AND T2.is_official = 'T')")