# Подготовка датасета

In [None]:
!pip install json5 gdown sentence-transformers -q

In [1]:
import json5
from tqdm import tqdm
import pandas as pd
# from unittest import TestCase, TextTestRunner, defaultTestLoader
import re, time
# from sentence_transformers import SentenceTransformer
# from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
from utils.general import *
# import torch
from sqlalchemy import create_engine
from prompting import PromptBuilder
from sklearn.utils import shuffle
from sqlalchemy import text
from utils.dataset import prepare_column_names, structure_from_connection
import matplotlib.pyplot as plt

In [2]:
# Build an in-memory SQLite database and load the spreadsheet into it as the
# `sales` table; `conn` is the connection used by every later cell.
# NOTE(review): DataFrame.to_sql writes the pandas index as an extra column by
# default (index=True) — confirm that column is wanted in the schema.
engine = create_engine('sqlite://', echo=False)
table = pd.read_excel('2023_04_Продажи_код_артикул.xlsx')
table.to_sql(con=engine, name='sales')
conn = engine.connect()

In [3]:
# Remove spaces from the column names of the database behind `conn`
# (the query below, for example, refers to `Хозяйственнаяоперация` as one word).
prepare_column_names(conn)

True

In [4]:
# Reference SQL used by the schema-linking example below.
# It returns the 10 most frequent pairs of articles (Артикул) that share the
# same Регистратор. `t1.Артикул < t2.Артикул` keeps each unordered pair once.
# The query drops:
#   - NULL articles;
#   - articles matching LIKE 'u%' (SQLite's LIKE is ASCII case-insensitive,
#     so 'U...' is dropped too);
#   - rows whose Хозяйственнаяоперация is 'Закрытие месяца'.
# NOTE: `!=` also filters out rows where Хозяйственнаяоперация is NULL.
query = r'''SELECT 
    t1.Артикул AS артикул_1, 
    t2.Артикул AS артикул_2, 
    COUNT(*) AS совместные_продажи
FROM sales t1
JOIN sales t2 ON t1.Регистратор = t2.Регистратор
WHERE t1.Артикул < t2.Артикул
  AND t1.Артикул IS NOT NULL
  AND t2.Артикул IS NOT NULL
  AND t1.Артикул NOT LIKE 'u%'
  AND t2.Артикул NOT LIKE 'u%'
  AND t1.Хозяйственнаяоперация != 'Закрытие месяца'
  AND t2.Хозяйственнаяоперация != 'Закрытие месяца'
GROUP BY t1.Артикул, t2.Артикул
ORDER BY совместные_продажи DESC
LIMIT 10;'''

In [5]:
### Пример использования schema linking
# Schema-linking example.
# `structure` describes the database behind `conn`. `parse_literals` (from
# utils.general) matches the reference `query` against it. The linked result
# is handed to PromptBuilder, which renders the prompt schema from it.
# Judging by the recorded output, only the columns referenced by the query
# end up in the rendered schema.
structure = structure_from_connection(conn)
linked_schema = parse_literals(query, structure)

prompt = PromptBuilder(question="Oh shit, i`m sorry... Sorry for what?")
prompt.switch_schema_linking(linked_schema).add_schema_template(conn).build_prompt(3)

'### Схема таблиц:\nsales(Артикул, Регистратор, Хозяйственнаяоперация);\n\n### Вопрос: Oh shit, i`m sorry... Sorry for what?\n### SQL:\n\n'

In [None]:
# conn, dataset = load_table(r'./pauq_databases/merged_database/geo',
#                            r'./distilled-dataset/dataset/pauq_train.json', 'geo')
# db = conn.cursor()

# Препроцессинг промпта

In [None]:
# Sentence encoder shared by the benchmark: it is passed to
# PromptBuilder.add_few_shot and used by HuggingFaceModelInference.sql_similarity.
# Imported here because the module-level `sentence_transformers` import is
# commented out, which otherwise makes this line fail with NameError.
from sentence_transformers import SentenceTransformer

sentence_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

In [None]:
class HuggingFaceModelInference:
    """Benchmark a Hugging Face causal LM on text-to-SQL samples.

    The model is loaded lazily on the first ``evaluate`` call. For each sample a
    prompt is built with ``PromptBuilder`` and SQL is generated. Predicted and
    reference SQL are executed against the notebook-level ``conn``, and their
    result tables are scored with ``table_similarity(mode='flexible')``.
    """

    def __init__(self, path):
        """
        Args:
            path: Hugging Face hub id or local path of the model
                (e.g. 'defog/sqlcoder-7b-2').
        """
        self.path = path
        self.evaluated = False
        self.is_downloaded = False

    def __load_model(self):
        """Load the tokenizer and fp16 weights for ``self.path``."""
        # Local imports: the notebook-level torch/transformers imports are
        # commented out, which otherwise makes this method fail with NameError.
        import torch
        from transformers import AutoTokenizer, AutoModelForCausalLM

        self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            self.path,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            device_map='auto',
        )
        # Some checkpoints ship without a pad token; reuse EOS so generate() can pad.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def evaluate(self, dataset):
        """Run the model on every sample of ``dataset`` and store the results.

        Args:
            dataset: iterable of dict-like samples with 'question_ru' (question)
                and 'query_ru' (reference SQL) keys; it is also passed to
                ``add_few_shot`` as the pool of few-shot examples.

        Side effects: sets these attributes:
            - ``summary``: sum of similarity scores;
            - ``samples_len``: number of samples;
            - ``exec_time``: per-sample seconds spent tokenizing, generating
              and decoding;
            - ``logger``: [predicted_sql, reference_sql] pairs;
            - ``errors``: (predicted_sql, repr(exception)) for samples whose
              SQL could not be executed or scored;
            - ``evaluated``.
        """
        if not self.is_downloaded:
            self.__load_model()
            self.is_downloaded = True

        logger, exec_time, errors = [], [], []
        summary = 0
        for sample in tqdm(dataset):
            question = sample['question_ru']
            truth_sql = sample['query_ru']

            # NOTE(review): the few-shot pool is the full dataset, which includes
            # the current sample; confirm add_few_shot excludes it.
            prompt = (
                PromptBuilder(question)
                .add_schema_template(conn)
                .add_few_shot(sentence_model, question, dataset)
                .add_cell_value_referencing(conn, count=3)
                .build_prompt(1)
            )

            st = time.time()
            # Put inputs on the model's device instead of a hard-coded 'cuda'.
            inputs = self.tokenizer(prompt, return_tensors='pt').to(self.model.device)
            generate_ids = self.model.generate(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_length=2048,
                pad_token_id=self.tokenizer.pad_token_id,
            )
            # generate() returns prompt + continuation. Decode only the new
            # tokens: the prompt itself (few-shot examples, the final
            # '### SQL:' header) may contain 'SQL:' markers.
            prompt_len = inputs['input_ids'].shape[1]
            output = self.tokenizer.batch_decode(
                generate_ids[:, prompt_len:],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )[0]
            exec_time.append(time.time() - st)

            pred_sql = output.strip()
            logger.append([pred_sql, truth_sql])
            try:
                df_pred = pd.read_sql(pred_sql, conn)
                df_truth = pd.read_sql(truth_sql, conn)
                summary += table_similarity(df_pred, df_truth, mode='flexible')
            except Exception as err:
                # Non-executable SQL scores 0, as before, but the cause is kept.
                errors.append((pred_sql, repr(err)))

        self.summary = summary
        self.samples_len = len(dataset)
        self.exec_time = exec_time
        self.logger = logger
        self.errors = errors
        self.evaluated = True

    def accuracy(self):
        """Return the mean execution-result similarity over the evaluated samples.

        Returns 0.0 when the evaluated dataset was empty.

        Raises:
            Exception: if ``evaluate`` has not been run yet.
        """
        if not self.evaluated:
            raise Exception('Model was not been evaluated')
        if not self.samples_len:
            return 0.0
        return self.summary / self.samples_len

    def sql_similarity(self):
        """Return the mean embedding similarity of predicted vs reference SQL text.

        Each pair in ``self.logger`` is encoded with ``sentence_model``. Returns
        NaN when nothing was logged.

        Raises:
            Exception: if ``evaluate`` has not been run yet.
        """
        if not self.evaluated:
            raise Exception('Model was not been evaluated')
        if not self.logger:
            return float('nan')
        similarities = [
            float(sentence_model.similarity(sentence_model.encode(pred), sentence_model.encode(truth)))
            for pred, truth in self.logger
        ]
        return np.mean(similarities)

In [None]:
def dump_inference(name: str, exec_time: list, sql_sim, acc):
    """Save one model's benchmark results as a JSON5 document.

    Writes ``{name}_dump.txt`` in the current working directory. The document
    holds the model name, the per-sample execution times, and the SQL
    similarity and accuracy metrics converted to strings.
    """
    payload = {
        'name': name,
        'exec_time': exec_time,
        'sql_similarity': str(sql_sim),
        'accuracy': str(acc),
    }
    serialized = json5.dumps(payload)
    out_path = f'{name}_dump.txt'
    with open(out_path, 'w') as out_file:
        out_file.write(serialized)

## 1. SQLCoder 7b https://huggingface.co/defog/sqlcoder-7b-2

In [None]:
# Evaluate SQLCoder-7b-2 on a shuffled copy of the dataset.
# NOTE(review): `dataset` is only defined by the commented-out load_table cell.
sqlcoder = HuggingFaceModelInference(path='defog/sqlcoder-7b-2')
sqlcoder.evaluate(dataset=shuffle(dataset))

In [None]:
# SQLCoder summary: (accuracy, mean SQL-text similarity, mean seconds per sample).
(
    sqlcoder.accuracy(),
    sqlcoder.sql_similarity(),
    np.mean(sqlcoder.exec_time),
)

In [None]:
# Inspect the raw [predicted_sql, reference_sql] pairs collected by evaluate().
sqlcoder.logger

## 2. SQLThroughAI (сайт не работает?) https://sqlthroughai.com/

## 3. Chat2DB 7b

In [None]:
# chat2db = HuggingFaceModelInference('Chat2DB/Chat2DB-SQL-7B')
# chat2db.evaluate(shuffle(dataset)[:20])

In [None]:
# chat2db.accuracy(), chat2db.sql_similarity(), np.mean(chat2db.exec_time)

In [None]:
# dump_inference('Chat2DB-SQL-7B', chat2db.exec_time, chat2db.sql_similarity(), chat2db.accuracy())

## 4. SQLova (пока пропустим)

In [None]:
# !wget https://github.com/naver/sqlova/releases/download/SQLova-parameters/model_bert_best.pt

In [None]:
# from transformers import AutoTokenizer, BertModel

# model = torch.load('model_bert_best.pt', map_location='cpu', weights_only=True)

## 5. DuckDB-NSQL 7b

In [None]:
# duckdb = HuggingFaceModelInference('motherduckdb/DuckDB-NSQL-7B-v0.1')

In [None]:
# duckdb.evaluate(shuffle(dataset)[:30])

In [None]:
# duckdb.accuracy(), duckdb.sql_similarity(), np.mean(duckdb.exec_time)

In [None]:
# dump_inference('DuckDB-NSQL-7B-v0.1', duckdb.exec_time, duckdb.sql_similarity(), duckdb.accuracy())

## 6. InternLM 2.5 7b

In [None]:
#!pip install einops -q

In [None]:
# internlm = HuggingFaceModelInference('internlm/internlm2_5-7b')
# internlm.evaluate(dataset[:20])

In [None]:
# dump_inference('internlm2_5-7b', internlm.exec_time, internlm.sql_similarity(), internlm.accuracy())

## Прочее

In [None]:
# Free GPU memory between runs.
# First drop unreachable Python objects while their CUDA contexts are still
# valid. Then reset the primary CUDA context of every visible device.
# Iterating all devices avoids the IndexError that hard-coded gpus[0]/gpus[1]
# raised on single-GPU machines.
from numba import cuda
import gc

gc.collect()
for gpu in cuda.gpus:
    gpu.reset()

In [None]:
# Imported here because the module-level unittest import is commented out,
# which otherwise makes this cell fail with NameError.
from unittest import TestCase, TextTestRunner, defaultTestLoader


class TestTableSimilarity(TestCase):
    """Tests for ``table_similarity`` in its 'soft', 'strict' and 'flexible' modes.

    NOTE(review): these queries read a ``state`` table with a ``population``
    column (the geo database from the commented-out load_table cell). The
    ``conn`` opened at the top of this notebook only holds ``sales``, so run
    these against the geo database.
    """

    def _check_modes(self, df1, df2, expected):
        """Assert table_similarity(df1, df2, mode=m) ≈ value for every (m, value) in ``expected``."""
        for mode, value in expected.items():
            # subTest reports every failing mode instead of stopping at the first.
            with self.subTest(mode=mode):
                # Almost-equal: similarities can be non-integer floats (e.g. 1/81).
                self.assertAlmostEqual(table_similarity(df1, df2, mode=mode), value)

    def test_one_table(self):
        df = pd.read_sql('SELECT * FROM state', conn)
        self._check_modes(df, df, {'soft': 1, 'strict': 1, 'flexible': 1})

    def test_two_tables_with_same_rows(self):
        df1 = pd.read_sql('SELECT * FROM state', conn)
        df2 = pd.read_sql('SELECT * FROM state', conn)
        self._check_modes(df1, df2, {'soft': 1, 'strict': 1, 'flexible': 1})

    def test_same_rows_with_different_order(self):
        df1 = pd.read_sql('SELECT * FROM state', conn)
        df2 = pd.read_sql('SELECT * FROM state', conn)[::-1]
        # Only 'strict' is sensitive to row order.
        self._check_modes(df1, df2, {'soft': 1, 'strict': 0, 'flexible': 1})

    def test_with_intersection(self):
        df1 = pd.read_sql(
            'SELECT * FROM state where population <= 1904000', conn)
        df2 = pd.read_sql(
            'SELECT * FROM state where population >= 1904000', conn)
        self._check_modes(df1, df2, {'soft': 0, 'strict': 0, 'flexible': 1 / 81})

    def test_empty_tables(self):
        df1 = pd.read_sql('SELECT * FROM state where population < 0', conn)
        df2 = pd.read_sql('SELECT * FROM state where population < 0', conn)
        self._check_modes(df1, df2, {'soft': 1, 'strict': 1, 'flexible': 1})

    def test_absolutely_different_tables(self):
        df1 = pd.read_sql(
            'SELECT * FROM state where population < 1904000', conn)
        df2 = pd.read_sql(
            'SELECT * FROM state where population > 1904000', conn)
        self._check_modes(df1, df2, {'soft': 0, 'strict': 0, 'flexible': 0})


TextTestRunner().run(defaultTestLoader.loadTestsFromTestCase(TestTableSimilarity))