In [None]:
!pip install --quiet accelerate peft bitsandbytes transformers datasets

In [2]:
import torch
from peft import LoraConfig, AutoPeftModelForCausalLM, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset, Dataset
import sqlite3
import pandas as pd

In [None]:
# Mount Google Drive (re-mounting even if already mounted) so files stored
# there can be read and written under /content/drive.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT, force_remount=True)

In [None]:
# Evaluation data: questions (and their db_id) used to prompt the model.
dataset_id, dataset_split = "NESPED-GEN/cnpj", "test"
db_id = "cnpj"

dataset = load_dataset(path=dataset_id, split=dataset_split)
df = dataset.to_pandas()  # pandas view of the same split, for inspection

In [8]:
def get_model_and_tokenizer(model_id, quantize_4bit=False):
    """Load a causal LM and its tokenizer with ``from_pretrained``.

    Args:
        model_id: Hub repository id or local path.
        quantize_4bit: When True, load the weights through a 4-bit NF4
            ``BitsAndBytesConfig`` (float16 compute, no double quantization).
            Defaults to False, which loads the model in float16 without
            quantization -- the previous behaviour, where this config was
            built but never passed to the model.

    Returns:
        ``(model, tokenizer)``. The tokenizer's pad token is set to its EOS token.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # Reuse EOS as the pad token (causal-LM tokenizers often lack one).
    tokenizer.pad_token = tokenizer.eos_token

    # Only build the quantization config when it is actually requested.
    quantization_config = None
    if quantize_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=False,
        )

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        # SECURITY: allows executing Python code shipped with the model repo;
        # only load repositories you trust.
        trust_remote_code=True,
        torch_dtype=torch.float16,
        device_map="auto",
        quantization_config=quantization_config,  # None == no quantization
        use_cache=True,
    )
    return model, tokenizer

In [9]:
import os
from getpass import getpass

from huggingface_hub import login

# Never store the Hub token in the notebook: take it from the HF_TOKEN
# environment variable, or prompt for it interactively.
token = os.environ.get("HF_TOKEN") or getpass("Hugging Face token: ")
login(token=token)

In [None]:
# TODO: set the Hugging Face Hub id (or local path) of the model under test.
model_name = ""
if not model_name:
    # Fail early with a clear message instead of an opaque loader error.
    raise ValueError("model_name is empty: set it to the model id to evaluate")

model, tokenizer = get_model_and_tokenizer(model_name)

### Prompt e Função para Obter Resposta Gerada pelo Modelo **[ajustar de acordo com o modelo]**

In [11]:
from transformers import pipeline

# End-of-turn marker used by ChatML-style chat templates.
STOP_SEQUENCE = "<|im_end|>"

# Stop on the tokenizer's EOS and, when it is a single token in this
# vocabulary, on STOP_SEQUENCE as well. The pipeline's `stop_sequence`
# option is not used: it is tokenized and *replaces* `eos_token_id`, and if
# the marker splits into several tokens generation can stop on a fragment
# such as "<" (which also appears in SQL comparisons).
_stop_ids = tokenizer.encode(STOP_SEQUENCE, add_special_tokens=False)
eos_ids = [tokenizer.eos_token_id]
if len(_stop_ids) == 1 and _stop_ids[0] not in eos_ids:
    eos_ids.append(_stop_ids[0])

params = {
    "task": "text-generation",
    "eos_token_id": eos_ids,
    "pad_token_id": tokenizer.eos_token_id,
    "max_new_tokens": 250,
    # Greedy decoding. `temperature` is omitted: it is ignored (with a
    # warning) when do_sample=False.
    "do_sample": False,
    "return_full_text": False,  # return only the newly generated text
}

pipe = pipeline(model=model, tokenizer=tokenizer, **params)

Device set to use cuda:0


#### Get Schema

In [None]:
def SQLDatabase(db_path, num_examples=0):
    """Render every user table of a SQLite database as CREATE TABLE text.

    Table, column and key names are lower-cased with double quotes stripped;
    column types are upper-cased.

    Args:
        db_path: Path to the SQLite database file.
        num_examples: When > 0, each table is followed by a ``/* ... */``
            comment holding up to this many tab-separated sample rows.

    Returns:
        The schema text with trailing newlines removed.
    """
    import contextlib

    def _quote_ident(name):
        # Double-quote an identifier so names with spaces/keywords are valid SQL.
        return '"' + name.replace('"', '""') + '"'

    schema_str = ""
    # closing() releases the connection even if a query raises.
    with contextlib.closing(sqlite3.connect(db_path)) as conn:
        cursor = conn.cursor()

        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        # Names starting with "sqlite_" are reserved for SQLite's internal
        # tables (e.g. sqlite_sequence) and are not part of the user schema.
        table_names = [
            row[0] for row in cursor.fetchall()
            if not row[0].lower().startswith("sqlite_")
        ]

        for table_name in table_names:
            cursor.execute(f"PRAGMA table_info({_quote_ident(table_name)})")
            included_columns = cursor.fetchall()

            clauses = []
            for column in included_columns:
                column_name = column[1].replace('"', '').lower()
                clauses.append(f'        {column_name} {column[2].upper()}')

            # table_info's pk field (index 5) is 0 for non-key columns and the
            # 1-based position inside the primary key otherwise. Test "> 0"
            # ("== 1" drops all but the first column of a composite key) and
            # keep the key's declared column order.
            pk_columns = sorted(
                (column for column in included_columns if column[5] > 0),
                key=lambda column: column[5],
            )
            if pk_columns:
                primary_keys_str = ", ".join(
                    column[1].replace('"', '').lower() for column in pk_columns
                )
                clauses.append(f'        PRIMARY KEY ({primary_keys_str})')

            cursor.execute(f"PRAGMA foreign_key_list({_quote_ident(table_name)})")
            for fk in cursor.fetchall():
                fk_col = fk[3].replace('"', '').lower()   # child column
                target = fk[2].replace('"', '').lower()   # referenced table
                # fk[4] (referenced column) is None when the key implicitly
                # targets the parent's primary key ("REFERENCES parent").
                if fk[4] is not None:
                    target += "(" + fk[4].replace('"', '').lower() + ")"
                clauses.append(f'        FOREIGN KEY ({fk_col}) REFERENCES {target}')

            schema_str += (
                f'CREATE TABLE {table_name.lower()} (\n'
                + ",\n".join(clauses)
                + "\n);\n\n"
            )

            if num_examples > 0:
                select_list = ", ".join(_quote_ident(column[1]) for column in included_columns)
                cursor.execute(
                    f"SELECT {select_list} FROM {_quote_ident(table_name)} LIMIT ?;",
                    (num_examples,),
                )
                rows = cursor.fetchall()
                header = "\t".join(
                    column[1].lower().replace('"', '') for column in included_columns
                )
                body = "".join("\t".join(map(str, row)) + "\n" for row in rows)
                schema_str += f"/*\n{len(rows)} rows from {table_name} table:\n{header}\n{body}*/\n\n"

    return schema_str.rstrip("\n")

In [12]:
def schema_reduzido(db_path, tables_and_columns, num_examples=0):
    """Render CREATE TABLE text restricted to the linked tables/columns.

    Primary-key columns are always included. A foreign key is emitted only
    when its child column was requested. Names keep their original case.

    Args:
        db_path: Path to the SQLite database file.
        tables_and_columns: String containing a Python dict literal mapping
            table name -> list of column names (e.g. schema-linking output).
        num_examples: When > 0, each table is followed by a ``/* ... */``
            comment holding up to this many tab-separated sample rows.

    Returns:
        The schema text with trailing newlines removed. Tables that do not
        exist in the database are skipped.

    Raises:
        ValueError / SyntaxError: if ``tables_and_columns`` is not a literal.
    """
    import ast
    import contextlib

    def _quote_ident(name):
        # Double-quote an identifier so names with spaces/keywords are valid SQL.
        return '"' + name.replace('"', '""') + '"'

    # literal_eval, not eval: the string comes from model output and must
    # never be executed as code.
    tables_and_columns_dict = ast.literal_eval(tables_and_columns)

    schema_str = ""
    # closing() releases the connection even if a query raises.
    with contextlib.closing(sqlite3.connect(db_path)) as conn:
        cursor = conn.cursor()

        for table_name, columns in tables_and_columns_dict.items():
            table_name = table_name.replace('"', '')
            # Compare names case-insensitively.
            columns_to_include = [col.strip().lower().replace('"', '') for col in columns]

            cursor.execute(f"PRAGMA table_info({_quote_ident(table_name)})")
            columns_info = cursor.fetchall()
            if not columns_info:
                # Unknown table (e.g. hallucinated by the linker): an empty
                # CREATE TABLE would only mislead the model.
                continue

            included_columns = []
            pk_columns = []
            for column in columns_info:
                column_name = column[1].lower().replace('"', '')
                # pk field: 0 = not a key column, else the 1-based position
                # inside the (possibly composite) primary key.
                is_pk = column[5] > 0
                if column_name in columns_to_include or is_pk:
                    included_columns.append(column)  # primary keys are always kept
                if is_pk:
                    pk_columns.append(column)

            clauses = []
            for column in included_columns:
                column_name = column[1].replace('"', '')
                clauses.append(f'        {column_name} {column[2].upper()}')

            if pk_columns:
                pk_columns.sort(key=lambda column: column[5])  # declared key order
                primary_keys_str = ", ".join(column[1].replace('"', '') for column in pk_columns)
                clauses.append(f'        PRIMARY KEY ({primary_keys_str})')

            cursor.execute(f"PRAGMA foreign_key_list({_quote_ident(table_name)})")
            for fk in cursor.fetchall():
                fk_col = fk[3].replace('"', '')  # child column
                if fk_col.lower() not in columns_to_include:
                    continue
                target = fk[2].replace('"', '')  # referenced table
                # fk[4] (referenced column) is None for an implicit reference
                # to the parent's primary key.
                if fk[4] is not None:
                    target += "(" + fk[4].replace('"', '') + ")"
                clauses.append(f'        FOREIGN KEY ({fk_col}) REFERENCES {target}')

            schema_str += f'CREATE TABLE {table_name} (\n' + ",\n".join(clauses) + "\n);\n\n"

            if num_examples > 0:
                select_list = ", ".join(_quote_ident(column[1]) for column in included_columns)
                cursor.execute(
                    f"SELECT {select_list} FROM {_quote_ident(table_name)} LIMIT ?;",
                    (num_examples,),
                )
                rows = cursor.fetchall()
                header = "\t".join(
                    column[1].lower().replace('"', '') for column in included_columns
                )
                body = "".join("\t".join(map(str, row)) + "\n" for row in rows)
                schema_str += f"/*\n{len(rows)} rows from {table_name} table:\n{header}\n{body}*/\n\n"

    return schema_str.rstrip("\n")

In [13]:
def schema_reduzido_tabelas(db_path, table_names, num_examples=0):
    """Render full CREATE TABLE text (all columns and keys) for the given tables.

    Args:
        db_path: Path to the SQLite database file.
        table_names: Iterable of table names to render, in output order.
        num_examples: When > 0, each table is followed by a ``/* ... */``
            comment holding up to this many tab-separated sample rows.

    Returns:
        The schema text with trailing newlines removed. Tables that do not
        exist in the database are skipped.
    """
    import contextlib

    def _quote_ident(name):
        # Double-quote an identifier so names with spaces/keywords are valid SQL.
        return '"' + name.replace('"', '""') + '"'

    schema_str = ""
    # closing() releases the connection even if a query raises.
    with contextlib.closing(sqlite3.connect(db_path)) as conn:
        cursor = conn.cursor()

        for table_name in table_names:
            cursor.execute(f"PRAGMA table_info({_quote_ident(table_name)})")
            columns_info = cursor.fetchall()
            if not columns_info:
                # Unknown table (e.g. hallucinated by the linker): an empty
                # CREATE TABLE would only mislead the model.
                continue

            clauses = []
            for column in columns_info:
                column_name = column[1].replace('"', '')
                clauses.append(f'        {column_name} {column[2].upper()}')

            # pk field: 0 = not a key column, else the 1-based position inside
            # the (possibly composite) primary key; keep that order.
            pk_columns = sorted(
                (column for column in columns_info if column[5] > 0),
                key=lambda column: column[5],
            )
            if pk_columns:
                primary_keys_str = ", ".join(column[1].replace('"', '') for column in pk_columns)
                clauses.append(f'        PRIMARY KEY ({primary_keys_str})')

            cursor.execute(f"PRAGMA foreign_key_list({_quote_ident(table_name)})")
            for fk in cursor.fetchall():
                fk_col = fk[3].replace('"', '')  # child column
                target = fk[2].replace('"', '')  # referenced table
                # fk[4] (referenced column) is None for an implicit reference
                # to the parent's primary key.
                if fk[4] is not None:
                    target += "(" + fk[4].replace('"', '') + ")"
                clauses.append(f'        FOREIGN KEY ({fk_col}) REFERENCES {target}')

            schema_str += f'CREATE TABLE {table_name} (\n' + ",\n".join(clauses) + "\n);\n\n"

            if num_examples > 0:
                cursor.execute(
                    f"SELECT * FROM {_quote_ident(table_name)} LIMIT ?;",
                    (num_examples,),
                )
                rows = cursor.fetchall()
                header = "\t".join(column[1].lower().replace('"', '') for column in columns_info)
                body = "".join("\t".join(map(str, row)) + "\n" for row in rows)
                schema_str += f"/*\n{len(rows)} rows from {table_name} table:\n{header}\n{body}*/\n\n"

    return schema_str.rstrip("\n")

In [15]:
# Schema-linking switches read by get_schema():
#   schema_linking -> build the prompt schema from the per-example linking
#                     predictions instead of the full database schema
#   only_tables    -> keep every column of the linked tables (ignore the
#                     linked column lists)
schema_linking, only_tables = True, True
schema_linking_result_path = '.txt'  # TODO: path to the linking predictions
schema_linking_result_path

'/content/drive/Shareddrives/LLMs/ResultadoTestes/TestesPOCsL/CNPJ/StableCode-schemaLinking-min-v2.txt'

In [16]:
# One schema-linking prediction per line, lower-cased. Lines are looked up by
# dataset example index in get_schema(), so every line (even a blank one) is
# kept to preserve that alignment.
test_result = []

if schema_linking:
    # Explicit UTF-8: table/column names may contain non-ASCII characters.
    with open(schema_linking_result_path, 'r', encoding='utf-8') as arquivo:
        test_result = [linha.lower() for linha in arquivo]

In [17]:
def get_schema(index, db_id, schema_linking, only_tables):
    """Build the schema text shown to the model for dataset example `index`.

    Args:
        index: Position of the example; selects line `index` of `test_result`.
        db_id: Database id; the file is ``{folder}{db_id}/{db_id}.sqlite``.
        schema_linking: If False, render the full schema via ``SQLDatabase``.
        only_tables: With schema linking, render all columns of the linked
            tables (True) or only the linked columns (False).

    Returns:
        The schema text.

    Raises:
        ValueError: if the schema-linking line is not a valid Python literal.
    """
    import ast

    # TODO: folder containing one "<db_id>/<db_id>.sqlite" per database.
    database_folder_path = ""
    db_path = f'{database_folder_path}{db_id}/{db_id}.sqlite'

    if not schema_linking:
        return SQLDatabase(db_path)

    # literal_eval, not eval: each line is model output and must never be
    # executed as code.
    try:
        linked = ast.literal_eval(test_result[index])
    except (ValueError, SyntaxError) as exc:
        raise ValueError(
            f"schema-linking line {index} is not a valid literal: {test_result[index]!r}"
        ) from exc

    if only_tables:
        return schema_reduzido_tabelas(db_path, list(linked.keys()), 0)
    return schema_reduzido(db_path, str(linked), 0)

In [23]:
def generate_response(question, schema, generation_pipe=None):
    """Ask a text-generation pipeline for an SQL query answering `question`.

    Args:
        question: Natural-language question.
        schema: Schema text, inserted in a ```sql fenced block of the prompt.
        generation_pipe: Pipeline to call; defaults to the notebook-level
            ``pipe`` (the previous, implicit behaviour).

    Returns:
        ``(generated_text, n_tokens)`` where ``n_tokens`` is the length of the
        generated text re-tokenized without special tokens (an approximation
        of how many new tokens were generated).
    """
    if generation_pipe is None:
        generation_pipe = pipe

    system = "Given a user question and the schema of a database, your task is to generate an SQL query that accurately answers the question based on the provided schema."

    messages = [
        {'role': 'system', 'content': system},
        {'role': 'user', 'content': f"# Schema:\n```sql\n{schema}\n```\n\n# Question: {question}"},
    ]

    prompt = generation_pipe.tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    generated_text = generation_pipe(prompt)[0]["generated_text"]

    # add_special_tokens=False: a BOS (or EOS) added by encode() would inflate
    # the count and skew the caller's comparison with max_new_tokens.
    n_tokens = len(generation_pipe.tokenizer.encode(generated_text, add_special_tokens=False))
    return generated_text, n_tokens

In [None]:
# Destination file for the generated queries (one per line).
out_path = "...txt"  # TODO: set the real output path
out_path

In [None]:
import statistics
import time

from tqdm.auto import tqdm

# Truncation threshold kept in sync with the pipeline configuration.
max_new_tokens = params["max_new_tokens"]

result = []
response_time = []
count_tokens = []

# 'w' rather than 'a+': re-running this cell must not append a second copy
# of the predictions, which would misalign output lines with dataset indices.
with open(out_path, 'w', encoding='utf-8') as file:
    for index, example in enumerate(tqdm(dataset, desc="Test ...")):
        question = example['question']  # or example['question_en'] for English
        schema = get_schema(index, example['db_id'], schema_linking, only_tables)

        # perf_counter is the clock intended for measuring elapsed time.
        start_time = time.perf_counter()
        output, tokens_count = generate_response(question, schema)
        response_time.append(time.perf_counter() - start_time)
        count_tokens.append(tokens_count)

        output = output.replace('\n', ' ')

        # Keep the body of the first ```sql fenced block; without one, keep
        # everything before the first ``` fence.
        try:
            output = output.split('```sql')[1].split('```')[0]
        except IndexError:
            output = output.split('```')[0]

        # The generation likely hit the token budget and may be truncated.
        if tokens_count >= max_new_tokens:
            print(f'{index} -> {output}')

        output = output.replace('<|im_end|>', '')

        result.append(output)
        file.write(f"{output}\n")
        file.flush()  # keep partial results on disk if the session dies

if response_time:
    print(f'\naverage time = {statistics.mean(response_time)}')