# Подготовка датасета

In [None]:
!pip install json5 gdown sentence-transformers zss -q

In [None]:
import json5
from tqdm import tqdm
import pandas as pd
# from unittest import TestCase, TextTestRunner, defaultTestLoader
import re, time
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
import torch
from sqlalchemy import create_engine
from sklearn.utils import shuffle
from sqlalchemy import text, Connection
import matplotlib.pyplot as plt

In [None]:
import numpy as np
from sentence_transformers import util
import pandas as pd
import zipfile
import sqlparse



def find_similar_sentences(sentence_model, target_sentence : str, sentences : list[str], count : int = 3):
    """Return up to ``count`` sentences most semantically similar to ``target_sentence``.

    All candidates are embedded with a single batched ``encode`` call (instead of
    one model call per sentence) and ranked by cosine similarity to the target.
    Candidates whose similarity rounds to exactly 1.0 in float16 are treated as
    the target itself (or an exact duplicate of it) and are excluded.

    Args:
        sentence_model: object whose ``encode`` maps a string to a 1-D vector and
            a list of strings to a 2-D array (e.g. a SentenceTransformer).
        target_sentence: sentence to compare against.
        sentences: candidate sentences.
        count: maximum number of sentences to return.

    Returns:
        List of candidate sentences, most similar first.
    """
    if len(sentences) == 0:
        # Nothing to rank; also avoids encoding an empty batch.
        return []

    target = np.asarray(sentence_model.encode(target_sentence), dtype=np.float32).reshape(-1)
    candidates = np.asarray(sentence_model.encode(list(sentences)), dtype=np.float32)
    candidates = candidates.reshape(len(sentences), -1)

    # Cosine similarity; the epsilon keeps zero vectors at similarity 0.
    eps = 1e-12
    target_unit = target / max(float(np.linalg.norm(target)), eps)
    candidate_norms = np.maximum(np.linalg.norm(candidates, axis=1, keepdims=True), eps)
    sims = (candidates / candidate_norms) @ target_unit

    # Stable descending sort keeps the original order among equal scores.
    order = np.argsort(-sims, kind='stable')
    # float16 rounding detects "the same sentence": float32 self-similarity is
    # often 0.9999999 rather than exactly 1.0.
    similar = [sentences[i] for i in order if np.float16(sims[i]) != 1.0]
    return similar[:count]


def table_similarity(dataframe1 : pd.DataFrame, dataframe2 : pd.DataFrame, mode : str) -> int:
    # if dataframe1.columns.shape != dataframe2.columns.shape:
    #     return False
    # if not (dataframe1.columns == dataframe2.columns).all():
    #     return False
    
    match mode:
        case 'soft':
            return int(dataframe1.sort_index().equals(dataframe2.sort_index()))
        case 'strict':
            return int(dataframe1.equals(dataframe2))
        case 'flexible':
            hash_1 = set(pd.util.hash_pandas_object(dataframe1, index=False))
            hash_2 = set(pd.util.hash_pandas_object(dataframe2, index=False))
            intersection = hash_1 & hash_2
            union = hash_1 | hash_2

            return len(intersection) / len(union) if len(union) != 0 else 1
        case _:
            raise Exception('Incorrect mode value')
     

def unzip_file(path, path_to):
    """Extract every member of the ZIP archive at ``path`` into directory ``path_to``."""
    archive = zipfile.ZipFile(path)
    with archive:
        archive.extractall(path=path_to)


def parse_literals(sql : str, table_structure : list[dict]):
    """Schema linking: find which tables/columns of ``table_structure`` a SQL query mentions.

    Every leaf token of the first parsed statement (except punctuation and plain
    whitespace) is compared against the known table and column names. One level
    of identifier quoting ("name", `name`, [name]) is stripped before comparing,
    so quoted identifiers are linked too. Matching is exact and case-sensitive.

    Args:
        sql: SQL text; only the first statement is inspected.
        table_structure: list of ``{'table_name': str, 'columns': list[str]}`` dicts.

    Returns:
        List of ``{'table_name': ..., 'columns': [...]}`` buckets in order of first
        mention. A table mentioned only by name gets an empty column list; a column
        name present in several tables is attached to each of them. Empty input
        yields an empty list.
    """
    statements = sqlparse.parse(sql)
    if not statements:
        # Empty input: nothing to link (indexing [0] used to raise IndexError).
        return []

    def _unquote(value: str) -> str:
        # Strip one level of SQL identifier quoting: "x", `x` or [x].
        closing = {'"': '"', '`': '`', '[': ']'}.get(value[:1])
        if closing is not None and len(value) >= 2 and value.endswith(closing):
            return value[1:-1]
        return value

    skipped_types = (sqlparse.sql.T.Punctuation, sqlparse.sql.T.Whitespace)
    # flatten() yields the leaf tokens of the whole parse tree in source order.
    names = [
        _unquote(token.value)
        for token in statements[0].flatten()
        if token.ttype not in skipped_types
    ]

    table_names = {table['table_name'] for table in table_structure}
    # Keyed by table name; dict insertion order == order of first mention.
    buckets: dict[str, dict] = {}

    for name in names:
        if name in table_names:
            buckets.setdefault(name, {'table_name': name, 'columns': []})
            continue
        for table in table_structure:
            if name not in table['columns']:
                continue
            bucket = buckets.setdefault(
                table['table_name'],
                {'table_name': table['table_name'], 'columns': []},
            )
            if name not in bucket['columns']:
                bucket['columns'].append(name)

    return list(buckets.values())

In [None]:
import string
import pandas as pd
from sqlalchemy import text, Connection


class IterableDataFrame:
    """Row-wise view of a DataFrame: each row is exposed as a ``{column: value}`` dict.

    Rows are materialised once at construction time and keyed by index label.
    ``__getitem__`` looks a row up by index label; iteration and ``as_list``
    yield the rows in index order. If the index has duplicate labels, the last
    row with a given label wins (``__len__`` still reports the DataFrame's row count).
    """

    def __init__(self, df : pd.DataFrame):
        self.df = df
        self.__series = {}
        # Fetch each column once and read cells positionally: O(rows * cols),
        # instead of a boolean mask over the whole index for every cell (which
        # also returned a Series rather than a scalar for duplicate labels).
        columns = [(column, self.df[column]) for column in self.df.keys()]
        for position, idx in enumerate(self.df.index):
            self.__series[idx] = {
                column : values.iloc[position] for column, values in columns
            }

    def __len__(self):
        return self.df.shape[0]

    def as_list(self):
        """Return all row dicts as a list, in index order."""
        return list(self.__series.values())

    def __iter__(self):
        return iter(self.as_list())

    def __getitem__(self, index):
        """Return the row dict for index label ``index``."""
        return self.__series[index]


def tables_from_connection(conn : Connection):
    """Return the names of all tables in the SQLite database behind ``conn``.

    Only ``sqlite_master`` rows of type 'table' are returned (views, indexes and
    triggers are excluded). The filter runs in SQL, so a database with no schema
    objects yields ``[]``. Previously, building a DataFrame from an empty
    ``SELECT *`` result raised ``KeyError: 'type'``.
    """
    rows = conn.execute(text("SELECT name FROM sqlite_master WHERE type = 'table'")).fetchall()
    return [row[0] for row in rows]


def structure_from_connection(conn : Connection):
    """Describe every table as ``{'table_name': name, 'columns': [...]}``.

    The first column of each table is dropped (``[1:]``), as before. This is
    presumably the index column written by ``DataFrame.to_sql``; TODO confirm
    against the database.

    Column names come from the result metadata of a ``LIMIT 0`` query, so no rows
    are transferred and empty tables still report their columns. The previous
    ``SELECT *`` + DataFrame approach loaded every row and returned no columns for
    an empty table.
    """
    structure = []
    for table in tables_from_connection(conn):
        quoted = table.replace('"', '""')  # escape embedded quotes in the identifier
        result = conn.execute(text(f'SELECT * FROM "{quoted}" LIMIT 0'))
        columns = list(result.keys())[1:]
        result.close()
        structure.append(
            {
                'table_name' : table,
                'columns' : columns
            })

    return structure


def prepare_column_names(conn : Connection):
    """Rename tables and columns whose names contain punctuation or whitespace.

    An offending name is reduced to its alphanumeric characters only. Note that
    ``_`` is in ``string.punctuation``, so ``first_name`` becomes ``firstname``.
    Columns are renamed before their table, so the ALTER statements still use the
    old table name. Column names come from ``structure_from_connection``, which
    omits each table's first column, so that column is never renamed.
    Always returns True.
    """
    forbidden = set(string.punctuation).union(string.whitespace)

    def needs_cleaning(name):
        return not forbidden.isdisjoint(name)

    def cleaned(name):
        return ''.join(filter(str.isalnum, name))

    for table in structure_from_connection(conn):
        table_name = table['table_name']
        for column in filter(needs_cleaning, table['columns']):
            conn.execute(text(
                f'''ALTER TABLE "{table_name}" RENAME COLUMN "{column}" TO "{cleaned(column)}"'''
            ))
        if needs_cleaning(table_name):
            conn.execute(text(f'''ALTER TABLE "{table_name}" RENAME TO "{cleaned(table_name)}"'''))

    return True

In [None]:
import pandas as pd
import numpy as np


class PromptBuilder:
    """Fluent builder for a text-to-SQL prompt.

    Typical use:
      1. Create the builder with the natural-language question.
      2. Optionally restrict the schema with ``switch_schema_linking``.
      3. Add the needed sections with the ``add_*`` methods.
      4. Call ``build_prompt(layout)``.

    Each ``include_*`` method takes a variant number: 1 renders the section and
    2 renders nothing. ``include_few_shot``, ``include_schema_template`` and
    ``include_cell_value_referencing`` raise ``RuntimeError`` if the matching
    ``add_*`` method was not called first.
    """

    def __init__(self, question):
        self.__prompt = ''
        self.schema_linking = False
        self.__question = question
        self.__few_shot = None
        self.__schema_template = None
        self.__cell_value_referencing = None

    def switch_schema_linking(self, table_structure=None):
        """Toggle schema linking and store ``table_structure``.

        While enabled, the schema and cell-value sections use only the tables and
        columns listed in ``table_structure``, instead of querying the database.
        """
        self.table_structure = table_structure
        self.schema_linking = not self.schema_linking
        return self

    def add_few_shot(self, sentence_model, target_question, queries):
        """Prepare a few-shot section from the 3 questions most similar to ``target_question``.

        ``queries`` is an iterable of dicts with 'question' and 'query' keys. It is
        iterated twice. Examples appear in the order they occur in ``queries``.
        """
        questions = [sample['question'] for sample in queries]
        similar = find_similar_sentences(sentence_model, target_question, questions, count=3)

        lines = []
        for sample in queries:
            if sample['question'] in similar:
                lines.append(f"Q: {sample['question']}\n")
                lines.append(f"A: {sample['query']}\n")

        self.__few_shot = ''.join(lines)
        return self

    def add_schema_template(self, db_conn):
        """Prepare a ``table(col1, col2, ...);`` line per table (linked tables only, if enabled)."""
        if self.schema_linking:
            structure = self.table_structure
        else:
            structure = structure_from_connection(db_conn)

        schema_template = ''
        for table in structure:
            schema_template += f"{table['table_name']}({', '.join(table['columns'])});\n"
        self.__schema_template = schema_template
        return self

    def add_cell_value_referencing(self, db_conn, count=1):
        """Prepare a section with ``count`` randomly sampled rows per table.

        Rows are drawn with replacement using ``np.random.randint``. Under schema
        linking only the linked columns are shown. An empty table produces an
        empty example list instead of raising ``ValueError`` from ``randint(0, 0)``.
        """
        if self.schema_linking:
            tables = [table['table_name'] for table in self.table_structure]
        else:
            tables = tables_from_connection(db_conn)

        value_template = ''
        for table in tables:
            # Quote the identifier so keyword-like or unusual table names work
            # (consistent with structure_from_connection).
            quoted = table.replace('"', '""')
            pd_table = pd.read_sql(f'SELECT * FROM "{quoted}"', db_conn)
            if self.schema_linking:
                instance = [bucket for bucket in self.table_structure if bucket['table_name'] == table][0]
                pd_table = pd_table[instance['columns']]

            examples = []
            if pd_table.shape[0] > 0:
                indexes = np.random.randint(0, pd_table.shape[0], size=count)
                # Positional selection. For the RangeIndex that read_sql produces it
                # matches the old `index == idx` mask, without an O(n) scan per sample.
                for row in pd_table.iloc[indexes].to_numpy():
                    examples.append(f"[{', '.join(map(str, row))}]")

            value_template += f"{table}({', '.join(examples)});\n"

        self.__cell_value_referencing = value_template
        return self

    def include_target(self, number: int):
        """Instruction line (variant 1) or nothing (variant 2)."""
        Variations = {
            1: 'Ответь на вопрос SQLite sql-запросом и без объяснений.\n',
            2: ''
        }
        return Variations[number]

    def include_few_shot(self, number: int):
        """Few-shot section (variant 1) or nothing (variant 2); requires ``add_few_shot``."""
        if self.__few_shot is None:
            raise RuntimeError('Не добавлен few_shot')

        Variations = {
            1: f'### Примеры похожих запросов и ответы на них:\n{self.__few_shot}\n',
            2: ''
        }
        return Variations[number]

    def include_schema_template(self, number: int):
        """Schema section (variant 1) or nothing (variant 2); requires ``add_schema_template``."""
        if self.__schema_template is None:
            raise RuntimeError('Не добавлен schema_template')

        Variations = {
            1: f'### Схема таблиц:\n{self.__schema_template}\n',
            2: ''
        }
        return Variations[number]

    def include_cell_value_referencing(self, number: int):
        """Sample-data section (variant 1) or nothing (variant 2); requires ``add_cell_value_referencing``."""
        if self.__cell_value_referencing is None:
            raise RuntimeError('Не добавлен cell_value_referencing')

        Variations = {
            1: f'### Примеры данных в таблице:\n{self.__cell_value_referencing}\n',
            2: ''
        }
        return Variations[number]

    def include_question(self, number: int):
        """Question plus ``### SQL:`` cue (variant 1) or nothing (variant 2)."""
        Variations = {
            1: f'### Вопрос: {self.__question}\n### SQL:\n\n',
            2: ''
        }
        return Variations[number]

    def build_prompt(self, number: int):
        """Assemble and return the prompt for layout ``number`` (1-4).

        Layouts:
            1: instruction, few-shot, schema, cell values, question
            2: question only
            3: cell values, question
            4: question, then schema

        The prompt is rebuilt from scratch on every call. Previously each call
        appended to the previous result, so a second call duplicated the text.

        Raises:
            KeyError: for an unknown layout number.
            RuntimeError: if a section used by the layout was not added beforehand.
        """
        layouts = {
            1: (self.include_target, self.include_few_shot, self.include_schema_template,
                self.include_cell_value_referencing, self.include_question),
            2: (self.include_question,),
            3: (self.include_cell_value_referencing, self.include_question),
            # NOTE(review): the question (ending in "### SQL:") precedes the schema
            # here; kept as-is because the evaluation loop uses this layout.
            4: (self.include_question, self.include_schema_template),
        }
        self.__prompt = ''.join(section(1) for section in layouts[number])
        return self.__prompt


In [None]:
import zss
import sqlparse

# def pretty_print(node, shift):
#     print(shift + str(node))
#     shift += '    '
#     for token in node.children:
#         pretty_print(token, shift + '    ')


class SqlNode:
    """Adapter that turns a sqlparse token tree into a tree usable by ``zss``.

    A node whose exact type is ``sqlparse.sql.Token`` or ``sqlparse.sql.Identifier``
    is a leaf labelled with its text. Any other token group is labelled with its
    class name and gets one child per non-whitespace sub-token.
    """

    def __init__(self, node):
        self.raw_node = node
        self.children = []
        # Exact type check on purpose: Identifier is itself a token group, and
        # every group subclasses Token, so isinstance() would match everything.
        if type(node) in (sqlparse.sql.Token, sqlparse.sql.Identifier):
            self.label = str(node.value)
            return

        self.label = type(node).__name__
        self.children = [SqlNode(token) for token in node.tokens if not token.is_whitespace]

    def __repr__(self):
        return f'{type(self.raw_node)} {self.label}'

    @staticmethod
    def get_children(self):
        """Return the node's children; static so it can be handed to ``zss`` as a plain callback."""
        return self.children

    @staticmethod
    def get_label(self):
        """Return the node's label; static so it can be handed to ``zss`` as a plain callback."""
        return self.label


def dist_comp(node1, node2):
    """Relabel cost for ``zss``: 1 when the two labels differ, 0 when they are equal."""
    return 1 if node1 != node2 else 0


def ratio(tree1 : SqlNode, tree2 : SqlNode):
    """Normalised tree-edit similarity of two ``SqlNode`` trees.

    Computes ``1 - distance / max(node_count1, node_count2)``, where the distance
    is the Zhang-Shasha edit distance with unit relabel cost (``dist_comp``).
    The result is clamped below at 0.
    """
    distance = zss.simple_distance(tree1, tree2, SqlNode.get_children, SqlNode.get_label, dist_comp)

    def count_nodes(root):
        # Iterative depth-first walk over the whole tree.
        total = 0
        pending = [root]
        while pending:
            node = pending.pop()
            total += 1
            pending.extend(node.children)
        return total

    largest = max(count_nodes(tree1), count_nodes(tree2))
    return max(1 - distance / largest, 0)

In [None]:
#!gdown 1Xjbp207zfCaBxhPgt-STB_RxwNo2TIW2
#unzip_file('merged_database_2022-06-10.zip', 'pauq_databases')

In [None]:
# Open the evaluation SQLite database (under /kaggle/input) through SQLAlchemy.
# `conn` is reused by the schema-cleanup and evaluation cells below.
engine = create_engine('sqlite:////kaggle/input/main-package/main_database.sqlite', echo=False)
conn = engine.connect()

In [None]:
prepare_column_names(conn) # Strip spaces/punctuation from table and column names (ALTER TABLE renames); NOTE(review): no commit is issued — confirm whether the renames should persist
structure_from_connection(conn)  # last expression: shown as cell output to inspect the resulting schema

In [None]:
### Пример использования schema linking
#structure = structure_from_connection(conn)
#linked_schema = parse_literals(query, structure)

#prompt = PromptBuilder(question="Oh shit, i`m sorry... Sorry for what?")
#prompt.switch_schema_linking(linked_schema).add_schema_template(conn).build_prompt(3)

In [None]:
# Load the NL->SQL samples; each row becomes a dict (the evaluation reads its 'question' and 'query' keys).
queries = IterableDataFrame(pd.read_excel('/kaggle/input/main-package/NLSQL.xlsx'))

# Препроцессинг промпта

In [None]:
# Sentence-embedding model; HuggingFaceModelInference.sql_similarity reads this module-level name.
sentence_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

In [None]:
class HuggingFaceModelInference:
    """Run a Hugging Face causal LM over text-to-SQL samples and score its SQL.

    The model is loaded lazily on the first ``evaluate`` call. After ``evaluate``:
      - ``accuracy`` and ``sql_similarity`` report metrics;
      - ``logger`` holds one dict per sample ('question', 'pred', 'gold', plus
        'error' when execution failed);
      - ``exec_time`` holds per-sample generation times in seconds, the
        attribute the ``dump_inference`` cells read.
    """

    # Extracts the query from the "[SQL]<query>[/SQL]" format the prompt asks for.
    _SQL_PATTERN = re.compile(r'\[SQL\](.*?)\[\/SQL\]', re.DOTALL)

    def __init__(self, path):
        self.path = path
        self.evaluated = False
        self.is_downloaded = False
        self.exec_time = []

    def __load_model(self):
        self.tokenizer = AutoTokenizer.from_pretrained(self.path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
                    self.path,
                    torch_dtype=torch.float16,
                    device_map="auto",
                    max_memory={0: "10GiB", 1: "10GiB"},
                    offload_folder="./offload",
                    trust_remote_code=True
                    )

    def evaluate(self, queries : IterableDataFrame, connection : Connection):
        """Generate SQL for every sample in ``queries`` and score it against the gold SQL.

        ``queries`` may be any sized iterable of dicts with 'question' and 'query'
        keys; the notebook passes a shuffled list. Each prediction and its gold query
        are executed on ``connection`` and compared with
        ``table_similarity(mode='flexible')``. A sample whose SQL fails to execute
        contributes 0, and the error is recorded in its ``logger`` entry under 'error'.
        """
        if not self.is_downloaded:
            self.__load_model()
            self.is_downloaded = True

        self.model.eval()
        # Loop-invariant tokenizer setup, done once instead of per sample.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        logger = []
        exec_time = []
        summary = 0
        for query in tqdm(queries):
            question = query['question']
            gold_sql = query['query']

            builder = PromptBuilder(question)
            prompt = builder.add_schema_template(connection)\
                             .build_prompt(4)

            # Named `full_prompt` (not `text`) to avoid shadowing sqlalchemy's `text`.
            full_prompt = f'''You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.
                1. Return ONLY valid SQL query without any explanations
                3. Never repeat the answer
                4. Format: [SQL]<query>[/SQL]
                
                ### Instruction:
                {prompt}\n\n
                ### Response:'''

            started = time.perf_counter()
            with torch.inference_mode():
                inputs = self.tokenizer(full_prompt, return_tensors="pt").to(self.model.device)

                generate_ids = self.model.generate(
                                **inputs,
                                max_length=2048,
                                num_return_sequences=1,
                                temperature=0.2,
                                top_p=0.95,
                                do_sample=True,
                                use_cache=True
                                )

                # Decode only the newly generated tokens, not the echoed prompt.
                output = self.tokenizer.decode(
                        generate_ids[0, inputs.input_ids.shape[1]:],
                        skip_special_tokens=True
                        )
            exec_time.append(time.perf_counter() - started)

            match = self._SQL_PATTERN.search(output)
            pred_sql = match.group(1).strip() if match else "error"
            record = {'question' : question, 'pred' : pred_sql, 'gold' : gold_sql}
            logger.append(record)

            # NOTE(review): model output is executed as-is on `connection`; a
            # generated DELETE/DROP would modify the database. Consider opening it read-only.
            try:
                df_pred = pd.read_sql(pred_sql, connection)
                df_gold = pd.read_sql(gold_sql, connection)
                summary += table_similarity(df_pred, df_gold, mode='flexible')
            except Exception as exc:
                # Invalid SQL scores 0. Keep the reason instead of silently dropping it;
                # catching Exception (not a bare except) lets KeyboardInterrupt stop the run.
                record['error'] = repr(exc)

        self.summary = summary
        self.queries_count = len(queries)
        self.logger = logger
        self.exec_time = exec_time
        self.evaluated = True

    def accuracy(self):
        """Mean result-table similarity (execution accuracy) over the evaluated samples."""
        if not self.evaluated:
            raise Exception('Model was not been evaluated')
        return self.summary / self.queries_count

    def sql_similarity(self, model=None):
        """Mean embedding similarity between predicted and gold SQL strings.

        Args:
            model: sentence-embedding model with ``encode`` and ``similarity``;
                defaults to the module-level ``sentence_model``.
        """
        if not self.evaluated:
            raise Exception('Model was not been evaluated')
        if model is None:
            model = sentence_model
        # Logger entries are dicts, so read them by key (indexing with 0/1 raised KeyError).
        similarities = [
            float(model.similarity(model.encode(entry['pred']), model.encode(entry['gold'])))
            for entry in self.logger
        ]
        return np.mean(similarities)

In [None]:
def dump_inference(name: str, exec_time: list, sql_sim, acc):
    """Write a JSON5 summary of one model's evaluation to ``<name>_dump.txt``.

    ``sql_sim`` and ``acc`` are stored via ``str(...)``; ``exec_time`` is stored as given.
    """
    payload = {
        'name': name,
        'exec_time': exec_time,
        'sql_similarity': str(sql_sim),
        'accuracy': str(acc),
    }
    serialized = json5.dumps(payload)
    with open(f'{name}_dump.txt', 'w') as dump_file:
        dump_file.write(serialized)

## 1. SQLCoder 7b https://huggingface.co/defog/sqlcoder-7b-2

In [None]:
#sqlcoder = HuggingFaceModelInference('defog/sqlcoder-7b-2')
#sqlcoder.evaluate(shuffle(geo_iterable.as_list())[:3], conn)

In [None]:
#sqlcoder.accuracy()

In [None]:
#sqlcoder.logger

## DeepSeek 6.7b

In [None]:
# Evaluate DeepSeek Coder 6.7B Instruct on 10 samples drawn after shuffling
# (no random_state is passed, so the subset differs between runs).
deepseek = HuggingFaceModelInference('deepseek-ai/deepseek-coder-6.7b-instruct')
deepseek.evaluate(shuffle(queries.as_list())[:10], conn) 

In [None]:
# Release cached, currently unused GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()

## 2. SQLThroughAI (сайт не работает?) https://sqlthroughai.com/

## 3. Chat2DB 7b

In [None]:
# chat2db = HuggingFaceModelInference('Chat2DB/Chat2DB-SQL-7B')
# chat2db.evaluate(shuffle(dataset)[:20])

In [None]:
# chat2db.accuracy(), chat2db.sql_similarity(), np.mean(chat2db.exec_time)

In [None]:
# dump_inference('Chat2DB-SQL-7B', chat2db.exec_time, chat2db.sql_similarity(), chat2db.accuracy())

## 4. SQLova (пока пропустим)

In [None]:
# !wget https://github.com/naver/sqlova/releases/download/SQLova-parameters/model_bert_best.pt

In [None]:
# from transformers import AutoTokenizer, BertModel

# model = torch.load('model_bert_best.pt', map_location='cpu', weights_only=True)

## 5. DuckDB-NSQL 7b

In [None]:
# duckdb = HuggingFaceModelInference('motherduckdb/DuckDB-NSQL-7B-v0.1')

In [None]:
# duckdb.evaluate(shuffle(dataset)[:30])

In [None]:
# duckdb.accuracy(), duckdb.sql_similarity(), np.mean(duckdb.exec_time)

In [None]:
# dump_inference('DuckDB-NSQL-7B-v0.1', duckdb.exec_time, duckdb.sql_similarity(), duckdb.accuracy())

## 6. InternLM 2.5 7b

In [None]:
#!pip install einops -q

In [None]:
# internlm = HuggingFaceModelInference('internlm/internlm2_5-7b')
# internlm.evaluate(dataset[:20])

In [None]:
# dump_inference('internlm2_5-7b', internlm.exec_time, internlm.sql_similarity(), internlm.accuracy())

## Прочее

In [None]:
from numba import cuda
import gc
#cuda.devices.gpus[0].reset()
#cuda.devices.gpus[1].reset()
#gc.collect()