# Пример готового pipeline


In [1]:
!pip install -q -U bitsandbytes
!pip install -q -U trl
!pip install -q -U accelerate
!pip install transformers
!pip install adapter-transformers

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.1/76.1 MB[0m [31m32.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m126.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m100.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m58.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m11.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m42.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, pipeline

In [12]:
import re
from typing import Optional

In [3]:
!unzip "/content/fine_tuned_models.zip"

Archive:  /content/fine_tuned_models.zip
   creating: my_sql_data/
   creating: pg_data/
   creating: sql_lite_data/
  inflating: my_sql_data/tokenizer.model  
  inflating: my_sql_data/tokenizer_config.json  
  inflating: my_sql_data/training_args.bin  
  inflating: my_sql_data/README.md   
  inflating: my_sql_data/adapter_config.json  
  inflating: my_sql_data/adapter_model.safetensors  
  inflating: my_sql_data/special_tokens_map.json  
  inflating: my_sql_data/tokenizer.json  
  inflating: sql_lite_data/tokenizer.model  
  inflating: sql_lite_data/tokenizer_config.json  
  inflating: sql_lite_data/training_args.bin  
  inflating: sql_lite_data/README.md  
  inflating: sql_lite_data/adapter_config.json  
  inflating: sql_lite_data/adapter_model.safetensors  
  inflating: sql_lite_data/special_tokens_map.json  
  inflating: sql_lite_data/tokenizer.json  
  inflating: pg_data/tokenizer.model  
  inflating: pg_data/tokenizer_config.json  
  inflating: pg_data/training_args.bin  
  infla

In [21]:
# Prompt template for the fine-tuned model, written with <start_of_turn>/<end_of_turn> turn markers.
# Placeholders filled through str.format: {db_name}, {natural_language}, {ch_schema_part}, {query_part}.
# The template has no placeholder for the source-database schema, and it ends with an open
# "model" turn so that generation continues from there.
PROMPT_TEMPLATE = """<start_of_turn>user
You are an intelligent AI specialized in generating SQL queries.
Your task is to translate {db_name} into Clickhouse.
Please provide the SQL query corresponding to the given prompt and context:
Prompt:
translate {db_name} into Clickhouse
Context:
Natural language: {natural_language}
{ch_schema_part}
{query_part}
<end_of_turn>
<start_of_turn>model
"""

In [17]:
# Input data

# Sample PostgreSQL query to translate. It deliberately mixes keyword casing and contains a
# line comment so that the preprocessing step has something to clean up.
# NOTE(review): the column in the SELECT list is spelled differently from `department` used in
# GROUP BY and in both schemas — it looks like a typo (possibly a non-ASCII look-alike letter); confirm.
PG_SOURCE_QUERY = """
select deparеament, COUNT(*)
from employees
WHERE status = 'active'
-- example of comment section
GROUP BY department
order by COUNT(*) DESC
LIMIT 5;
"""
# PostgreSQL DDL of the source table.
PG_DB_SCHEMA = """
CREATE TABLE Employees (
id SERIAL PRIMARY KEY,
name TEXT NOT NULL,
age INTEGER CHECK (age > 0),
status TEXT CHECK (status IN ('active', 'inactive', 'terminated')),
department TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
"""
# ClickHouse DDL of the corresponding target table.
CH_DB_SCHEMA = """
CREATE TABLE Employees
(
    id UInt32,
    name String,
    age UInt8,
    status Enum8('active' = 1, 'inactive' = 2, 'terminated' = 3),
    department String,
    created_at DateTime
)
ENGINE = MergeTree
ORDER BY (status, age);
"""

## Первый этап: предобработка входных данных, формирование промпта

In [9]:
# This list could be loaded through an available API; PostgreSQL and ClickHouse keywords are listed here.
# All entries are lowercase: lookups are expected to be done on `word.lower()`.
# Every section ends with a trailing comma. Without it Python silently concatenates adjacent
# string literals across lines (e.g. "if" "serial" -> "ifserial") and keywords are lost.

SQL_KEYWORDS = {
    # Generic SQL
    "select", "from", "where", "join", "left", "right", "inner", "outer", "on", "group", "by",
    "order", "having", "limit", "offset", "insert", "into", "update", "delete", "values",
    "create", "table", "drop", "alter", "union", "all", "distinct", "and", "or", "not", "in", "as",
    "is", "null", "like", "ilike", "case", "when", "then", "else", "end", "exists", "between",
    "desc", "asc", "true", "false", "default", "primary", "key", "foreign", "references",
    "check", "constraint", "unique", "index", "view", "materialized", "column", "database",
    "cast", "coalesce", "intersect", "except", "using", "with", "recursive", "if",

    # PostgreSQL-specific
    "serial", "bigserial", "text", "varchar", "json", "jsonb", "boolean", "bytea", "timestamp",
    "timestamptz", "date", "time", "interval", "inet", "uuid", "now", "current_timestamp",
    "returning", "ilike", "similar", "array", "unnest", "generate_series", "over", "partition",
    "range", "rows", "preceding", "following", "row_number", "rank", "dense_rank",
    "window", "lag", "lead", "nth_value", "first_value", "last_value", "filter", "do", "begin",
    "language", "plpgsql", "loop", "raise", "notice", "perform", "execute", "declare",

    # ClickHouse-specific
    "engine", "merge", "tree", "mergetree", "replacing", "aggregating", "summing",
    "versionedcollapsing", "collapsing", "distributed", "nullable", "tuple", "map", "arrayjoin",
    "sample", "prewhere", "settings", "codec", "lowcardinality", "int8", "int16", "int32",
    "int64", "uint8", "uint16", "uint32", "uint64", "float32", "float64", "decimal", "datetime",
    "datetime64", "date", "string", "enum", "enum8", "enum16", "fixedstring", "aggregate",
    "function", "if", "case", "visitparamhas", "has", "position", "splitbychar", "multiif",
    "toint32", "todatetime", "tostring", "materialize", "row_policy", "cluster", "shard",
    "replica", "zookeeper", "dictionary", "join_get", "global", "final", "any", "any_last",
    "topk", "histogram", "quantiles", "median", "modulo", "array", "group_array", "group_uniq_array",
    "uniq", "uniqexact", "uniqcombined", "running_difference", "windowfunnel", "lambda",
}

In [13]:
def preprocess_sql_query(
    query: str,
    uppercase_keywords: bool = True,
    lowercase_identifiers: bool = False,
    strip_hash_comments: bool = True,
) -> str:
    """
    Normalize a SQL statement into a single-line form suitable for a prompt.

    Steps:
    - remove comments (`-- ...`, `/* ... */` and, optionally, `# ...`);
    - normalize the case of keywords / identifiers;
    - collapse every whitespace run into a single space;
    - drop trailing semicolons.

    Quoted text ('string literals', "quoted identifiers", `backtick identifiers`) is scanned as a
    unit: comment markers inside it are not treated as comments and its letter case is never
    changed. Whitespace runs are still collapsed everywhere, including inside quoted text.

    :param query: raw SQL text.
    :param uppercase_keywords: upper-case every word found in SQL_KEYWORDS.
    :param lowercase_identifiers: lower-case every word that is NOT in SQL_KEYWORDS.
    :param strip_hash_comments: also treat `#` up to the end of line as a comment. Pass False
        for dialects where `#` is an operator (e.g. PostgreSQL `#`, `#>`, `#>>`).
    :return: the normalized single-line statement.

    NOTE(review): with uppercase_keywords=True, type/engine names listed in SQL_KEYWORDS are
    upper-cased as well (`UInt32` -> `UINT32`, `MergeTree` -> `MERGETREE`). ClickHouse is believed
    to treat these names case-sensitively — confirm this matches the data the model was tuned on.
    """
    single_quoted = r"'(?:[^'\\]|\\.|'')*'"
    double_quoted = r'"(?:[^"\\]|\\.|"")*"'
    backtick_quoted = r"`[^`]*`"
    comment = r"/\*.*?\*/|--[^\n]*"
    if strip_hash_comments:
        comment += r"|#[^\n]*"

    # One left-to-right scan: whichever construct starts first wins, so a `--` inside a string
    # literal is part of the literal and a quote inside a comment is part of the comment.
    scanner = re.compile(
        rf"(?P<quoted>{single_quoted}|{double_quoted}|{backtick_quoted})"
        rf"|(?P<comment>{comment})"
        r"|(?P<word>\b\w+\b)",
        flags=re.DOTALL,
    )
    normalize_case = uppercase_keywords or lowercase_identifiers

    def rewrite(match):
        kind = match.lastgroup
        text = match.group(0)
        if kind == "comment":
            return " "
        if kind == "quoted" or not normalize_case:
            return text
        # SQL_KEYWORDS holds lowercase entries, so the lookup must use the lowercase form.
        if text.lower() in SQL_KEYWORDS:
            return text.upper() if uppercase_keywords else text
        return text.lower() if lowercase_identifiers else text

    query = scanner.sub(rewrite, query)
    query = re.sub(r"\s+", " ", query).strip()

    return query.rstrip(";").strip()

In [15]:
# Normalize the sample PostgreSQL query; the bare name on the last line displays the result.
PREPROCESS_PG_SOURCE_QUERY = preprocess_sql_query(PG_SOURCE_QUERY)
PREPROCESS_PG_SOURCE_QUERY

"SELECT deparеament, COUNT(*) FROM employees WHERE status = 'active' GROUP BY department ORDER BY COUNT(*) DESC LIMIT 5"

In [19]:
# Normalize the PostgreSQL DDL of the source table and display it.
PREPROCESS_PG_DB_SCHEMA = preprocess_sql_query(PG_DB_SCHEMA)
PREPROCESS_PG_DB_SCHEMA

"CREATE TABLE Employees ( id SERIAL PRIMARY KEY, name TEXT NOT NULL, age INTEGER CHECK (age > 0), status TEXT CHECK (status IN ('active', 'inactive', 'terminated')), department TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP )"

In [20]:
# Normalize the ClickHouse DDL of the target table and display it.
PREPROCESS_CH_DB_SCHEMA = preprocess_sql_query(CH_DB_SCHEMA)
PREPROCESS_CH_DB_SCHEMA

"CREATE TABLE Employees ( id UINT32, name STRING, age UINT8, status ENUM8('active' = 1, 'inactive' = 2, 'terminated' = 3), department STRING, created_at DATETIME ) ENGINE = MERGETREE ORDER BY (status, age)"

После предобработки SQL: ключевые слова приведены к верхнему регистру, удалены комментарии, убраны лишние табуляции и переносы строк.

In [32]:
def build_prompt(
    db_name: str,
    schema_part: str,
    ch_schema_part: str,
    query: str,
    natural_language: Optional[str] = None,
) -> str:
    """
    Fill PROMPT_TEMPLATE and return the finished prompt text.

    :param db_name: value substituted for `{db_name}` in the template.
    :param schema_part: source-database schema. Accepted for interface symmetry but not
        inserted anywhere: the template is formatted without it.
    :param ch_schema_part: value substituted for `{ch_schema_part}`.
    :param query: value substituted for `{query_part}`.
    :param natural_language: optional description substituted for `{natural_language}`;
        None or an empty value becomes an empty string.
    :return: the formatted prompt.
    """
    template_fields = {
        "db_name": db_name,
        "natural_language": natural_language or "",
        "ch_schema_part": ch_schema_part,
        "query_part": query,
    }
    return PROMPT_TEMPLATE.format(**template_fields)

In [34]:
# Build the prompt from the preprocessed inputs; the bare name on the last line displays it.
# NOTE(review): the first argument is the literal string 'db_name', so the rendered prompt reads
# "translate db_name into Clickhouse" — presumably the source database name was intended; confirm.
ready_prompt = build_prompt('db_name', PREPROCESS_PG_DB_SCHEMA, PREPROCESS_CH_DB_SCHEMA, PREPROCESS_PG_SOURCE_QUERY)
ready_prompt

"<start_of_turn>user\nYou are an intelligent AI specialized in generating SQL queries.\nYour task is to translate db_name into Clickhouse.\nPlease provide the SQL query corresponding to the given prompt and context:\nPrompt:\ntranslate db_name into Clickhouse\nContext:\nNatural language: \nCREATE TABLE Employees ( id UINT32, name STRING, age UINT8, status ENUM8('active' = 1, 'inactive' = 2, 'terminated' = 3), department STRING, created_at DATETIME ) ENGINE = MERGETREE ORDER BY (status, age)\nSELECT deparеament, COUNT(*) FROM employees WHERE status = 'active' GROUP BY department ORDER BY COUNT(*) DESC LIMIT 5\n<end_of_turn>\n<start_of_turn>model\n"

Итоговая функция:

In [26]:
def process_sql_build_prompt(
    db_name: str,
    schema_part: str,
    ch_schema_part: str,
    query: str,
    natural_language: Optional[str] = None,
) -> str:
    """
    Preprocess the raw SQL inputs and build the prompt in one call.

    The query and both schemas are each passed through preprocess_sql_query (default options);
    the results, together with `db_name` and `natural_language`, are forwarded to build_prompt.

    :return: the prompt string produced by build_prompt.
    """
    cleaned_query, cleaned_schema, cleaned_ch_schema = (
        preprocess_sql_query(raw_sql) for raw_sql in (query, schema_part, ch_schema_part)
    )
    return build_prompt(db_name, cleaned_schema, cleaned_ch_schema, cleaned_query, natural_language)


In [44]:
# Demo of the combined helper; the bare name on the last line displays the resulting prompt.
# NOTE(review): the inputs passed here are the already-preprocessed PREPROCESS_* values, so they are
# preprocessed a second time inside the helper, and the first argument is again the literal
# string 'db_name' — presumably the raw inputs and the real source database name were intended; confirm.
preprocess_sql_build_prompt = process_sql_build_prompt('db_name', PREPROCESS_PG_DB_SCHEMA, PREPROCESS_CH_DB_SCHEMA, PREPROCESS_PG_SOURCE_QUERY)
preprocess_sql_build_prompt

"<start_of_turn>user\nYou are an intelligent AI specialized in generating SQL queries.\nYour task is to translate db_name into Clickhouse.\nPlease provide the SQL query corresponding to the given prompt and context:\nPrompt:\ntranslate db_name into Clickhouse\nContext:\nNatural language: \nCREATE TABLE Employees ( id UINT32, name STRING, age UINT8, status ENUM8('active' = 1, 'inactive' = 2, 'terminated' = 3), department STRING, created_at DATETIME ) ENGINE = MERGETREE ORDER BY (status, age)\nSELECT deparеament, COUNT(*) FROM employees WHERE status = 'active' GROUP BY department ORDER BY COUNT(*) DESC LIMIT 5\n<end_of_turn>\n<start_of_turn>model\n"

## Второй этап: генерация запроса через LLM

In [6]:
# Turn gradient tracking off globally; the cells below only call generate() on the model.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7e342d083f10>

In [5]:
# Directory with the PostgreSQL fine-tune, created by unpacking fine_tuned_models.zip.
pg_directory = '/content/pg_data'

In [7]:
# NOTE(review): `pipeline` is already imported in the first import cell, so this re-import is redundant.
from transformers import pipeline

# Load the tokenizer and the model from the unpacked fine-tune directory.
tokenizer = AutoTokenizer.from_pretrained(pg_directory)
model = AutoModelForCausalLM.from_pretrained(pg_directory)
# eval() returns the module, so as the last expression it also displays the model structure.
model.eval()


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/659 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/13.5k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/67.1M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaAttention(
          (q_proj): lora.Linear(
            (base_layer): Linear(in_features=2048, out_features=2048, bias=False)
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.05, inplace=False)
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=2048, out_features=4, bias=False)
            )
            (lora_B): ModuleDict(
              (default): Linear(in_features=4, out_features=2048, bias=False)
            )
            (lora_embedding_A): ParameterDict()
            (lora_embedding_B): ParameterDict()
            (lora_magnitude_vector): ModuleDict()
          )
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): lora.Linear(
            (base_layer): Linear(in_features=2048, out_f

In [35]:
def llm_generate_query(prompt: str, max_new_tokens: int = 128) -> str:
    """
    Run greedy generation for `prompt` and return only the newly generated text.

    Uses the notebook-level `tokenizer` and `model`.

    :param prompt: fully formatted prompt text.
    :param max_new_tokens: upper bound on the number of generated tokens.
    :return: decoded continuation (prompt excluded), stripped of surrounding whitespace.

    NOTE(review): generation stops only on the tokenizer's EOS token, not on `<end_of_turn>`,
    so the returned text may contain further turns after the model's first answer.
    """
    # Keep the inputs on the same device as the model.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[1]

    output_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        # Greedy decoding is deterministic by itself. `temperature` only applies when sampling,
        # and 0.0 is not a valid sampling temperature, so it is intentionally not passed.
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,  # stop on the end-of-sequence token
        pad_token_id=tokenizer.eos_token_id,  # needed for models without a pad token
    )

    # Drop the prompt by token count. Slicing the decoded text by len(prompt) is only correct
    # when decode(encode(prompt)) reproduces the prompt character for character.
    new_token_ids = output_ids[0][prompt_length:]
    return tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()

In [36]:
# Generate the model's continuation for the prompt built above.
ready_llm_answer = llm_generate_query(ready_prompt)



In [37]:
ready_llm_answer

"SELECT deparеament, COUNT(*) FROM employees WHERE status = 'active' GROUP BY department ORDER BY COUNT(*) DESC LIMIT 5\n<end_of_turn>\n<start_of_turn>result\n deparеament, COUNT(*)\n1, 10\n2, 9\n3, 8\n4, 7\n5, 6\n<end_of_turn>\n<start_of_turn>prompt\nSELECT deparеament, COUNT(*) FROM employees WHERE status = 'active' GROUP BY department ORDER BY COUNT(*) DESC LIMIT 5\n<end_of_turn>"

## Третий этап: постпроцессинг sql выражения после генерации LLM

In [38]:
def extract_sql_query(llm_output: str) -> str:
    """
    Extract the SQL query from the text generated by the model.

    Lookup order:
    1. the contents of a `<start_of_turn>prompt ... <end_of_turn>` block, if one is present;
    2. otherwise the text before the first `<end_of_turn>` / `<start_of_turn>` marker, i.e. the
       first generated turn (the whole output when there is no marker at all).

    The second rule keeps the function usable when the model stops after its first answer
    instead of emitting an extra "prompt" turn.

    :param llm_output: text generated by the model (without the original prompt).
    :return: the extracted query, stripped of surrounding whitespace.
    :raises ValueError: if no "prompt" block exists and the first turn is empty.
    """
    match = re.search(r"<start_of_turn>prompt\s*(.*?)\s*<end_of_turn>", llm_output, re.DOTALL | re.IGNORECASE)
    if match:
        return match.group(1).strip()

    first_turn = re.split(r"<(?:end|start)_of_turn>", llm_output, maxsplit=1, flags=re.IGNORECASE)[0].strip()
    if first_turn:
        return first_turn

    raise ValueError("SQL-запрос не найден в выводе модели.")


In [40]:
# Pull the SQL query out of the raw model output and display it.
extract_llm_answer = extract_sql_query(ready_llm_answer)
extract_llm_answer

"SELECT deparеament, COUNT(*) FROM employees WHERE status = 'active' GROUP BY department ORDER BY COUNT(*) DESC LIMIT 5"

## Четвертый этап: проверка через sqlglot

In [41]:
import sqlglot
from sqlglot.errors import ParseError

def validate_sql_with_sqlglot(query: str, dialect: str = "clickhouse") -> bool:
    """
    Check that a SQL query is syntactically valid using SQLGlot.

    :param query: SQL query text
    :param dialect: SQLGlot dialect name (e.g. 'postgres', 'clickhouse', 'mysql', ...)
    :return: True if the query parses; False if SQLGlot reports a tokenizing or parsing error
             (the error is printed)
    """
    try:
        sqlglot.parse_one(query, read=dialect)
        return True
    # Malformed input can already fail in the tokenizer (e.g. an unterminated string literal),
    # which raises TokenError rather than ParseError; both mean "not valid SQL".
    except (ParseError, sqlglot.errors.TokenError) as e:
        print(f"❌ Ошибка парсинга: {e}")
        return False


In [42]:
# Syntax-check the extracted query with the default dialect ("clickhouse").
validate_sql_with_sqlglot(extract_llm_answer)

True

## Готовый pipeline

In [45]:
# Message returned to the caller when the pipeline cannot produce a usable query.
EXAMPLE_OF_ERROR_MASSEGE = 'Модель не готова дать правильный вариант запроса...'

def ready_example_of_pipeline(
    db_name: str,
    schema_part: str,
    ch_schema_part: str,
    query: str,
    natural_language: Optional[str] = None,
) -> str:
    """
    Run the whole pipeline: preprocess -> build prompt -> generate -> extract -> validate.

    :param db_name: name of the source database, inserted into the prompt.
    :param schema_part: raw schema of the source database.
    :param ch_schema_part: raw ClickHouse schema.
    :param query: raw source SQL query.
    :param natural_language: optional natural-language description of the query.
    :return: the extracted query if it passes validation, otherwise EXAMPLE_OF_ERROR_MASSEGE.
    """
    # Forward the caller's db_name (not a fixed placeholder string) into the prompt.
    ready_prompt = process_sql_build_prompt(db_name, schema_part, ch_schema_part, query, natural_language)
    ready_llm_answer = llm_generate_query(ready_prompt)

    try:
        extract_llm_answer = extract_sql_query(ready_llm_answer)
    except ValueError:
        # No query could be found in the model output.
        return EXAMPLE_OF_ERROR_MASSEGE

    # The validator reports failures through its boolean result, so that result must be checked.
    # The broad except is deliberate: if the validator itself fails, the query cannot be vouched for.
    try:
        is_valid = validate_sql_with_sqlglot(extract_llm_answer)
    except Exception:
        is_valid = False

    if not is_valid:
        return EXAMPLE_OF_ERROR_MASSEGE
    return extract_llm_answer

In [46]:
# End-to-end run of the pipeline on the sample data.
# NOTE(review): the literal 'db_name' and the already-preprocessed PREPROCESS_* values are passed
# here — presumably the real source database name and the raw inputs were intended; confirm.
ready_example_of_pipeline('db_name', PREPROCESS_PG_DB_SCHEMA, PREPROCESS_CH_DB_SCHEMA, PREPROCESS_PG_SOURCE_QUERY)



"SELECT deparеament, COUNT(*) FROM employees WHERE status = 'active' GROUP BY department ORDER BY COUNT(*) DESC LIMIT 5"