# OracleCoder: QLoRA+ Ensembles with Schema Linking for Text-To-SQL Generation

---

<img src="https://arxiv.org/html/2402.05120v1/x1.png" width="auto" height="250px"></img>


_Created by: [Jordan Deklerk](https://github.com/jordandeklerk), [Visit My Website](https://jordandeklerk.github.io)_

<div class="alert alert-block alert-warning"> 

<b>NOTE:</b> This project is ongoing. The results here are not complete. Stay tuned for future updates on the evaluation of OracleCoder.
</div>

## Contents

- [Introduction](#Introduction)
- [Data Preparation](#Data-Preparation)
- [The Model](#The-Model)
- [Prepare the Model](#Prepare-The-Model)
  - [Bits and Bytes Config](#Bits-and-Bytes-Config)
  - [LoRA+ Config](#LoRA+-Config)
  - [Applying LoRA+ in Practice](#Applying-LoRA+-in-Practice)
- [Train the Model](#Train-The-Model)
  - [Train the Model for Schema Linking](#Train-the-Model-for-Schema-Linking)
  - [Merge LoRA+ Adapters for Ensemble Schema Linking](#Merge-LoRA+-Adapters-for-Ensemble-Schema-Linking)
  - [Inference for Schema Linking](#Inference-for-Schema-Linking)
  - [Train the Model for SQL Generation](#Train-the-Model-for-SQL-Generation)
  - [Merge LoRA+ Adapters for Ensemble SQL Generation](#Merge-LoRA+-Adapters-for-Ensemble-SQL-Generation)
  - [Inference for SQL Generation](#Inference-for-SQL-Generation)
- [Inference for the Ensemble Model on the BIRD Benchmark](#Inference-for-the-Ensemble-Model-on-the-BIRD-Benchmark)
- [Conclusion](#Conclusion)

---

## Introduction

This project introduces OracleCoder, a QLoRA+ ensemble Large Language Model (LLM) built for efficient text-to-SQL conversion. OracleCoder uses a two-phase fine-tuning approach. In the first phase, it is fine-tuned for schema linking, learning which tables and columns a question refers to. In the second phase, it is fine-tuned for SQL generation, using the linked schema as context. OracleCoder can be both trained and evaluated on a single NVIDIA A100 GPU, which makes it an efficient solution for text-to-SQL tasks. This project is based on the [DTS-SQL paper](https://arxiv.org/pdf/2402.01117.pdf) with several modifications.

Current state-of-the-art text-to-SQL models rely heavily on large proprietary language models like GPT-4, raising concerns about data privacy and cost. To address this, the DTS-SQL paper proposes decomposing the text-to-SQL task into two simpler sub-tasks: schema linking and SQL generation. By fine-tuning smaller open-source language models separately for each sub-task, the authors achieve performance comparable to much larger proprietary models.

Our goal is to further explore and extend this decomposed approach. We will investigate techniques to enhance the performance of the individual schema linking and SQL generation stages. This could involve experimenting with different model architectures, fine-tuning strategies, and incorporating additional training data or auxiliary tasks.

Additionally, we plan to explore improved methods for transferring information between the two stages, such as developing more effective schema representations or using attention mechanisms to better align the output of the schema linker with the input of the SQL generator. Specifically, we will implement the new **LoRA+** method outlined in the [LoRA+ paper](https://arxiv.org/pdf/2402.12354.pdf). The "+" in LoRA+ indicates an improvement over standard LoRA: it sets different learning rates for the two adapter matrices, which the authors report can speed up fine-tuning by up to roughly 2x at the same computational cost as LoRA.

We hope to develop a text-to-SQL system that rivals the performance of large proprietary models while using only smaller open-source components. This will help democratize access to high-quality text-to-SQL technology and mitigate concerns around privacy and cost.

## Data Preparation

In this project we will be using the [Spider dataset](https://arxiv.org/pdf/1809.08887.pdf) for training. The Spider dataset is a large-scale, complex and cross-domain semantic parsing and text-to-SQL dataset. It consists of 10,181 questions and 5,693 unique complex SQL queries on 200 databases with multiple tables, covering 138 different domains.

Spider is distinct from most previous semantic parsing datasets in several key ways:

1. The databases contain multiple tables and cover different domains, so the dataset tests a system's ability to generalize to both new SQL queries and new database schemas.

2. The queries are complex, including many SQL clauses like GROUP BY, ORDER BY, INTERSECT, nested queries, etc.

3. Different complex SQL queries and databases appear in train and test sets, rather than having the same SQL query patterns in both.

4. It is significantly larger than most previous semantic parsing datasets.

Experiments with various state-of-the-art semantic parsing models on Spider achieve only up to 12.4% exact matching accuracy, showing that the dataset presents a strong challenge for future research in this area. The dataset and task aim to improve the real-world applicability of text-to-SQL systems.

## The Model

Our fine-tuning efforts will be centered around the [m-a-p/OpenCodeInterpreter-DS-6.7B](https://huggingface.co/m-a-p/OpenCodeInterpreter-DS-6.7B) model. Check out their [website](https://opencodeinterpreter.github.io/) for a detailed exposition of the model and data. As of March 13, 2024, [m-a-p/OpenCodeInterpreter-DS-33B](https://huggingface.co/m-a-p/OpenCodeInterpreter-DS-33B) has claimed the #1 spot on the [BigCode Leaderboard](https://huggingface.co/spaces/bigcode/bigcode-models-leaderboard) and is also the best-performing (small) model (fewer than 7B parameters) on the [EvalPlus coding leaderboard](https://evalplus.github.io/leaderboard.html).

EvalPlus is a code synthesis evaluation framework that rigorously assesses the functional correctness of code generated by Large Language Models (LLMs). By extending the test-cases of the HUMANEVAL benchmark by 80 times, EvalPlus reveals that a significant amount of LLM-synthesized code, previously thought to be correct, actually fails the additional tests, highlighting the importance of comprehensive testing in evaluating the performance of LLMs for code synthesis. You can find the full details in [this paper](https://openreview.net/pdf?id=1qvx610Cu7).

---

To get started, let's install all the necessary libraries. As you can see, in addition to `transformers` and `datasets`, we'll be using `peft`, `bitsandbytes`, and `flash-attn` to optimize the training. We will use `wandb` to track training and evaluation metrics such as loss and learning rate.

In [1]:
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


In [2]:
%%capture
!pip install git+https://github.com/huggingface/transformers.git "huggingface_hub[cli]" --upgrade --quiet

In [3]:
%%capture
!pip install -q datasets peft bitsandbytes flash-attn gradio trl wandb sql_metadata

In [4]:
%%capture
!huggingface-cli login --token YOUR KEY

In [5]:
import os
import pandas as pd
import re
import random

from dataclasses import dataclass, field
from functools import reduce
from typing import Callable, Dict, List, Optional, Tuple, Union
from sql_metadata import Parser
from datasets import load_dataset

import torch
import torch.nn as nn
from torch.utils.data import Dataset

from peft.tuners import lora
from transformers.data.data_collator import DataCollator
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer import (EvalPrediction, PreTrainedModel,
                                  PreTrainedTokenizerBase, TrainerCallback)
from transformers.trainer_pt_utils import get_parameter_names
from transformers.utils import is_sagemaker_mp_enabled, logging

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    logging,
    set_seed,
    BitsAndBytesConfig,
    EarlyStoppingCallback
)

import wandb
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

set_seed(315)

import warnings
warnings.filterwarnings("ignore")

In [6]:
!python /content/drive/MyDrive/Colab_Notebooks/LLM_Fine-Tuning/src/finetuning_dataset_creator.py

tokenizer_config.json: 100% 5.18k/5.18k [00:00<00:00, 24.8MB/s]
tokenizer.json: 100% 1.37M/1.37M [00:00<00:00, 4.21MB/s]
special_tokens_map.json: 100% 462/462 [00:00<00:00, 2.79MB/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
100% 8659/8659 [13:53<00:00, 10.39it/s]
Full finetuning set errors: 246
Filtered finetuning set errors: 1
100% 1034/1034 [01:39<00:00, 10.44it/s]
Filtered validation set errors: 0
Validation set formatted errors: 0
Validation set errors: 0


In our project, we will use the `Spider` dataset, which is a benchmark natural language text-to-SQL dataset for fine-tuning instruction models to generate SQL queries from English text inputs.

You can download the dataset from [here](https://drive.google.com/uc?export=download&id=1TqleXec_OykOYFREKKtschzY29dUcVAQ). The script `finetuning_dataset_creator.py`, which preprocesses the `Spider` dataset for this project, can be found [here](https://github.com/jordandeklerk/OracleCoder).

Here is a sample from the dataset:

```json
  {
    "question": "Return me the number of authors who have papers in the VLDB conference.",

    "context": "CREATE TABLE `author` (
                    aid INT PRIMARY KEY,
                    homepage TEXT,
                    name TEXT,
                    oid INT
                  );

                  CREATE TABLE `conference` (
                    cid INT PRIMARY KEY,
                    homepage TEXT,
                    name TEXT
                  );

                  CREATE TABLE `domain` (
                    did INT PRIMARY KEY,
                    name TEXT
                  );

                  CREATE TABLE `domain_author` (
                    aid INT PRIMARY KEY REFERENCES author(aid),
                    did INT PRIMARY KEY REFERENCES domain(did)
                  );

                  CREATE TABLE `domain_conference` (
                    cid INT PRIMARY KEY REFERENCES conference(cid),
                    did INT PRIMARY KEY REFERENCES domain(did)
                  );

                  CREATE TABLE `journal` (
                    homepage TEXT,
                    jid INT PRIMARY KEY,
                    name TEXT
                  );

                  CREATE TABLE `domain_journal` (
                    did INT PRIMARY KEY REFERENCES domain(did),
                    jid INT PRIMARY KEY REFERENCES journal(jid)
                  );

                  CREATE TABLE `keyword` (
                    keyword TEXT,
                    kid INT PRIMARY KEY
                  );

                  CREATE TABLE `domain_keyword` (
                    did INT PRIMARY KEY REFERENCES domain(did),
                    kid INT PRIMARY KEY REFERENCES keyword(kid)
                  );

                  CREATE TABLE `publication` (
                    abstract TEXT,
                    cid TEXT REFERENCES conference(cid),
                    citation_num INT,
                    jid INT REFERENCES journal(jid),
                    pid INT PRIMARY KEY,
                    reference_num INT,
                    title TEXT,
                    year INT
                  );

                  CREATE TABLE `domain_publication` (
                    did INT PRIMARY KEY REFERENCES domain(did),
                    pid INT PRIMARY KEY REFERENCES publication(pid)
                  );

                  CREATE TABLE `organization` (
                    continent TEXT,
                    homepage TEXT,
                    name TEXT,
                    oid INT PRIMARY KEY
                  );

                  CREATE TABLE `publication_keyword` (
                    pid INT PRIMARY KEY REFERENCES publication(pid),
                    kid INT PRIMARY KEY REFERENCES keyword(kid)
                  );

                  CREATE TABLE `writes` (
                    aid INT PRIMARY KEY REFERENCES author(aid),
                    pid INT PRIMARY KEY REFERENCES publication(pid)
                  );

                  CREATE TABLE `cite` (
                    cited INT REFERENCES publication(pid),
                    citing INT REFERENCES publication(pid)
                  );
                  ",

    "answer": "SELECT COUNT(DISTINCT t1.name) FROM publication as t4 JOIN conference as t2 ON t4.cid  =  t2.cid JOIN writes as t3 ON t3.pid  =  t4.pid JOIN author as t1 ON t3.aid  =  t1.aid WHERE t2.name  =  "VLDB";"
  }

```

We are going to use `trl` for fine-tuning, which supports popular instruction and conversation dataset formats. This means we only need to convert our dataset to one of the supported formats and `trl` will take care of the rest. Those formats include:

- conversational format

```json
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
{"messages": [{"role": "system", "content": "You are..."}, {"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]}
```

- instruction format

```json
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
```
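
As a minimal sketch of how a single record maps onto either format, the helpers below build both shapes from a question, a schema string, and a target answer. The function names `to_conversational` and `to_instruction`, the system prompt wording, and the toy record are illustrative assumptions, not part of `trl` or of this project's preprocessing script.

```python
import json


def to_conversational(question: str, schema: str, answer: str) -> dict:
    """Build one record in the conversational ("messages") format."""
    return {
        "messages": [
            # Hypothetical system prompt: the schema is placed in the system message
            {"role": "system", "content": f"You are a text-to-SQL assistant.\n\n{schema}"},
            {"role": "user", "content": question},
            {"role": "assistant", "content": answer},
        ]
    }


def to_instruction(question: str, schema: str, answer: str) -> dict:
    """Build one record in the instruction ("prompt"/"completion") format."""
    return {"prompt": f"{schema}\n\n{question}", "completion": answer}


# Toy record, made up for illustration only
record = to_conversational(
    question="How many authors are there?",
    schema="CREATE TABLE `author` (aid INT PRIMARY KEY, name TEXT);",
    answer="SELECT COUNT(*) FROM author;",
)
print(json.dumps(record, indent=2))
```

Either shape carries the same information; the conversational format simply keeps the schema, question, and answer in separate roles so that a chat template can be applied later.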

In our example we are going to load our open-source dataset using the 🤗 Datasets library and then convert it into the conversational format, where we include the schema definition in the system message for our assistant.

Let's define some global variables.

In [7]:
MODEL="m-a-p/OpenCodeInterpreter-DS-6.7B"     # Model checkpoint on the Hugging Face Hub
SEQ_LENGTH=2100                               # Sequence length

# Training arguments
NUM_EPOCHS=1                                  # num_train_epochs
BATCH_SIZE=4                                  # batch_size
GR_ACC_STEPS=8                                # gradient_accumulation_steps
LR=5e-5                                       # learning_rate
LR_SCHEDULER_TYPE="cosine"                    # lr_scheduler_type
WEIGHT_DECAY=1e-1                             # weight_decay
EVAL_FREQ=50                                  # eval_freq
SAVE_FREQ=50                                  # save_freq
LOG_FREQ=50                                   # log_freq
OUTPUT_DIR="OCI-DS-6.7B-schema-linking"       # output_dir
BF16=True                                     # bf16
FP16=False                                    # no_fp16

# LORA
LORA_R=8                                      # lora_r
LORA_ALPHA=32                                 # lora_alpha
LORA_DROPOUT=0.1                              # lora_dropout
LORA_TARGET_MODULES="q_proj","v_proj","k_proj","o_proj","gate_proj","up_proj","down_proj","lm_head"    # lora_target_modules

# bitsandbytes config
USE_NESTED_QUANT=True                         # use_nested_quant
BNB_4BIT_COMPUTE_DTYPE="bfloat16"             # bnb_4bit_compute_dtype

SEED=315
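
As a quick sanity check on these settings, the sketch below works out the effective batch size and the approximate number of optimizer steps for one epoch. It assumes single-GPU training and takes the training-set size of 8,413 examples from the `Map` progress output shown in this notebook; the exact step count reported by the trainer may differ slightly.

```python
import math

BATCH_SIZE = 4             # per-device batch size, as defined above
GR_ACC_STEPS = 8           # gradient accumulation steps, as defined above
NUM_EPOCHS = 1
EVAL_FREQ = 50
NUM_TRAIN_EXAMPLES = 8413  # filtered training split size from the Map output
NUM_GPUS = 1               # assumption: a single GPU

# Each optimizer step sees this many examples
effective_batch = BATCH_SIZE * GR_ACC_STEPS * NUM_GPUS

# Approximate optimizer steps per epoch and in total
steps_per_epoch = math.ceil(NUM_TRAIN_EXAMPLES / effective_batch)
total_steps = steps_per_epoch * NUM_EPOCHS

# Number of evaluations triggered when evaluating every EVAL_FREQ steps
num_evals = total_steps // EVAL_FREQ

print(effective_batch, steps_per_epoch, num_evals)  # 32 263 5
```

With only about 263 optimizer steps in the epoch, evaluating and saving every 50 steps yields roughly five checkpoints to compare.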

Let's load the training and validation data that we formatted earlier.

In [8]:
data_files = {"train": "full_finetuning_dataset.csv", "validation": "validation_dataset_formatted.csv"}
dataset = load_dataset('csv', data_files=data_files)

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Let's format our dataset into the correct style for the SFT Trainer class.

In [None]:
def create_conversation(sample):
    question = sample['question']
    query = sample['query']
    database_schema = sample['database_schema']

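    # Capture every table that follows FROM or JOIN, not only the first FROM clause
    tables = re.findall(r'(?:FROM|JOIN)\s+(\w+)', query, re.IGNORECASE)
    columns = re.findall(r'SELECT\s+(.+?)\s+FROM', query, re.IGNORECASE)

    # De-duplicate while preserving order so the training targets are deterministic
    columns = ", ".join(dict.fromkeys(columns[0].split(", "))) if columns else ""
    tables = ", ".join(dict.fromkeys(tables)) if tables else ""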

    system_message = f"""Given the following SQL tables, your job is to determine the columns and tables that the question is referring to.

{database_schema}

"""
    user_message = f"# Question: {question}"

    assistant_message = f"""
```SQL
-- Columns: {columns}
-- Tables: {tables}
"""
    return {
        "messages": [
            {"role": "system", "content": system_message},
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": assistant_message}
        ]
    }
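
To see what a schema-linking target looks like for a multi-table query, here is a minimal standalone sketch that applies regular expressions to the sample query shown earlier. The helper name `extract_schema_links` is hypothetical, and its table pattern matches names after both `FROM` and `JOIN`; regex extraction of this kind is only a rough approximation of real SQL parsing and will miss cases such as subqueries or quoted identifiers.

```python
import re
from typing import List, Tuple


def extract_schema_links(query: str) -> Tuple[List[str], List[str]]:
    """Return (select expressions, table names) found in a SQL query string."""
    tables = re.findall(r"(?:FROM|JOIN)\s+(\w+)", query, re.IGNORECASE)
    select = re.findall(r"SELECT\s+(.+?)\s+FROM", query, re.IGNORECASE)
    columns = select[0].split(", ") if select else []
    # dict.fromkeys removes duplicates while keeping first-seen order
    return list(dict.fromkeys(columns)), list(dict.fromkeys(tables))


sample_query = (
    'SELECT COUNT(DISTINCT t1.name) FROM publication as t4 '
    'JOIN conference as t2 ON t4.cid  =  t2.cid '
    'JOIN writes as t3 ON t3.pid  =  t4.pid '
    'JOIN author as t1 ON t3.aid  =  t1.aid WHERE t2.name  =  "VLDB";'
)

columns, tables = extract_schema_links(sample_query)
print("-- Columns:", ", ".join(columns))
print("-- Tables:", ", ".join(tables))
```

For this query, all four joined tables appear in the target, which is the information the SQL generation stage needs from the schema linker.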

In [None]:
dataset = dataset.map(create_conversation, batched=False)

Map:   0%|          | 0/8413 [00:00<?, ? examples/s]

Map:   0%|          | 0/1034 [00:00<?, ? examples/s]

In [None]:
print(dataset["train"][0]["messages"])

[{'content': 'Given the following SQL tables, your job is to determine the columns and tables that the question is referring to.\n\nCREATE TABLE `device` (\n  Device_ID INT PRIMARY KEY,\n  Device TEXT,\n  Carrier TEXT,\n  Package_Version TEXT,\n  Applications TEXT,\n  Software_Platform TEXT\n);\nSample rows from `device`:\n1, BlackBerry Storm 9530, MTS Mobility, 5.0.0.808, 5.0.0.419, Android\n2, Apple, Verizon Wireless, 5.0.0.328, 5.0.0.328, iOS\n3, Huawei, Telus Mobility, 5.0.0.419, 5.0.0.419, Android\n\nCREATE TABLE `shop` (\n  Shop_ID INT PRIMARY KEY,\n  Shop_Name TEXT,\n  Location TEXT,\n  Open_Date TEXT,\n  Open_Year INT\n);\nSample rows from `shop`:\n1, Dinas Device, Dinas, 1 January, 2014\n2, Best Buy, Cymmer, 15 July, 2006\n3, Ferndale, Blaenllechau, 8 November, 2009\n\nCREATE TABLE `stock` (\n  Shop_ID INT PRIMARY KEY REFERENCES shop(Shop_ID),\n  Device_ID INT PRIMARY KEY REFERENCES device(Device_ID),\n  Quantity INT\n);\nSample rows from `stock`:\n1, 6, 100\n2, 6, 110\n3, 6, 

## Prepare the Model

### Bits and Bytes Config

Now that the data is prepared, it's time to load the model. We're going to load the quantized version of the model.

This will allow us to reduce memory usage, as quantization represents data with fewer bits. We'll use the `bitsandbytes` library to quantize the model, as it has a nice integration with `transformers`. All we need to do is define a `bitsandbytes` config, and then use it when loading the model.

There are different variants of 4-bit quantization, but generally, it is recommended to use NF4 quantization for better performance (`bnb_4bit_quant_type="nf4"`).

The `bnb_4bit_use_double_quant` option adds a second quantization after the first one to save an additional 0.4 bits per parameter.
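
The arithmetic behind that figure can be sketched as follows. The block sizes (64 weights per first-level constant, 256 constants per second-level constant) and the 8-bit second-level precision are assumptions taken from the description in the QLoRA paper rather than values read from this notebook, and the parameter count is a rough figure for a 6.7B model.

```python
BLOCK_SIZE = 64          # weights sharing one quantization constant (assumed)
SECOND_BLOCK_SIZE = 256  # constants sharing one second-level constant (assumed)

# Single quantization: one 32-bit constant per block of 64 weights
overhead_single = 32 / BLOCK_SIZE

# Double quantization: constants stored in 8 bits, plus one 32-bit
# second-level constant per 256 first-level constants
overhead_double = 8 / BLOCK_SIZE + 32 / (BLOCK_SIZE * SECOND_BLOCK_SIZE)

# Bits saved per parameter by the second quantization
saving = overhead_single - overhead_double

# Rough upper bound on memory saved, since not every parameter is quantized
n_params = 6.7e9
saved_megabytes = n_params * saving / 8 / 1e6

print(f"overhead without double quant: {overhead_single:.3f} bits/param")
print(f"overhead with double quant:    {overhead_double:.3f} bits/param")
print(f"saving: {saving:.3f} bits/param, about {saved_megabytes:.0f} MB")
```

Under these assumptions the saving is about 0.37 bits per parameter, which is the figure usually rounded to 0.4.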

To learn more about quantization, check out the ["Making LLMs even more accessible with bitsandbytes, 4-bit quantization and QLoRA" blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).

Once defined, pass the config to the `from_pretrained` method to load the quantized version of the model.

In [None]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from peft.tuners.lora import LoraLayer

load_in_8bit = False

# 4-bit quantization
compute_dtype = getattr(torch, BNB_4BIT_COMPUTE_DTYPE)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=USE_NESTED_QUANT,
)

device_map = {"": 0}

model = AutoModelForCausalLM.from_pretrained(
        MODEL,
        load_in_8bit=load_in_8bit,
        quantization_config=bnb_config,
        device_map=device_map,
        use_cache=False,  # We will be using gradient checkpointing
        trust_remote_code=True,
        attn_implementation="flash_attention_2",
)

tokenizer = AutoTokenizer.from_pretrained(MODEL, use_fast=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "right"

config.json:   0%|          | 0.00/716 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/25.1k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/119 [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


When using a quantized model for training, you need to call the `prepare_model_for_kbit_training()` function to preprocess the quantized model.

In [None]:
model = prepare_model_for_kbit_training(model)
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaFlashAttention2(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaLinearScalingRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )


### LoRA+ Config

Now that the quantized model is ready, we can set up a LoRA+ configuration. Like the traditional LoRA approach, LoRA+ makes fine-tuning more efficient by drastically reducing the number of trainable parameters, but it also lets us specify different learning rates for the two adapter matrices. In the paper, the authors report that this speeds up fine-tuning by up to 2x in some cases.

To train a model using LoRA+ technique, we need to wrap the base model as a `PeftModel`. This involves defining a LoRA+ configuration with `LoraConfig`, and wrapping the original model with `get_peft_model()` using the `LoraConfig`.

To learn more about the original LoRA and its parameters, refer to [PEFT documentation](https://huggingface.co/docs/peft/conceptual_guides/lora). To learn more about LoRA+, refer to the [LoRA+ paper](https://arxiv.org/pdf/2402.12354.pdf).

The authors originally created a new 🤗 `Trainer` class called the `LoraPlusTrainer`, which is a wrapper for the 🤗 `Trainer` class with a modified optimizer function to allow for different learning rates on the adapter matrices. However, we will be using the `SFTTrainer`, so we need to modify the original `LoraPlusTrainer` slightly to wrap around the `SFTTrainer` as well.

You can find the original LoRA+ code [here](https://github.com/nikhil-ghosh-berkeley/loraplus/blob/main/lora_plus.py).

#### Applying LoRA+ in Practice

LoRA+ introduces one new required hyperparameter to your optimizer (and another optional hyperparameter). Setting this hyperparameter appropriately can improve finetuning performance, especially on more challenging downstream tasks.

**LoRA+ arguments:**

`loraplus_lr_ratio`: the ratio of learning rates $\eta_B / \eta_A$, where $\eta_A$ is passed in as the optimizer learning rate (e.g., `learning_rate` or `lr`). As a rule of thumb, `loraplus_lr_ratio` should be larger when the task is more difficult and the model needs to update its features to learn well. In this case, it helps to make the learning rate slightly smaller (e.g., by a factor of 2) than typical vanilla LoRA learning rates.

`loraplus_lr_embedding`: (optional) if LoRA modules are added to embedding layers, you can specify a different learning rate for them. Default value 1e-6.

Note that `loraplus_lr_ratio` should be greater than 1, and when it is equal to 1 this is just the regular LoRA configuration. The optimal choice of `loraplus_lr_ratio` is model and task dependent and needs to be set in tandem with the optimizer learning rate.
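
To make the effect of these two arguments concrete, the sketch below computes the learning rate each parameter group would receive. The helper `loraplus_learning_rates` is hypothetical and only mirrors the grouping used by the optimizer code in this section (the `lora_A` matrices train at the base rate, the `lora_B` matrices at the base rate times the ratio). The base rate of `5e-5` matches `LR` defined earlier, while the ratio of 16 is an illustrative value, not a setting taken from this notebook.

```python
def loraplus_learning_rates(
    lr: float,
    loraplus_lr_ratio: float,
    loraplus_lr_embedding: float = 1e-6,
) -> dict:
    """Return the learning rate assigned to each LoRA+ parameter group."""
    if loraplus_lr_ratio < 1:
        raise ValueError("loraplus_lr_ratio should be at least 1")
    return {
        "groupA": lr,                               # lora_A matrices: base rate
        "groupB": lr * loraplus_lr_ratio,           # lora_B matrices: scaled rate
        "groupB_no_decay": lr * loraplus_lr_ratio,  # same rate, no weight decay
        "embedding": loraplus_lr_embedding,         # LoRA embedding layers
    }


rates = loraplus_learning_rates(lr=5e-5, loraplus_lr_ratio=16)
for group, rate in rates.items():
    print(f"{group:>16}: {rate:.1e}")
```

With a ratio of 1 every LoRA matrix receives the same rate, which recovers standard LoRA.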

In [None]:
# Modified from https://github.com/nikhil-ghosh-berkeley/loraplus/blob/main/lora_plus.py

from dataclasses import dataclass, field
from functools import reduce
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.utils.data import Dataset

from peft.tuners import lora
from transformers import Trainer, TrainingArguments
from transformers.data.data_collator import DataCollator
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer import EvalPrediction, PreTrainedModel, PreTrainedTokenizerBase, TrainerCallback
from transformers.trainer_pt_utils import get_parameter_names
from transformers.utils import is_sagemaker_mp_enabled, logging

from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

if is_sagemaker_mp_enabled():
    import smdistributed.modelparallel.torch as smp

logger = logging.get_logger(__name__)


@dataclass
class LoraPlusTrainingArguments(TrainingArguments):
    loraplus_lr_ratio: Optional[float] = field(
        default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
    )
    loraplus_lr_embedding: Optional[float] = field(
        default=1e-6,
        metadata={"help": "loraplus learning rate for lora embedding layers."},
    )


def get_module(name, opt_model):
    """
    Retrieve a module from a model using its parameter name.
    Args:
        name (str): Full name of the parameter, typically including module path.
        opt_model (torch.nn.Module): The model from which to retrieve the module.

    Returns:
        Module corresponding to the given name.
    """
    parent_idx = 2 if "lora" in name else 1
    module_names = name.split(sep=".")[:-parent_idx]
    module = reduce(getattr, module_names, opt_model)
    return module


def create_loraplus_optimizer(
    opt_model,
    optimizer_cls,
    optimizer_kwargs,
    loraplus_lr_ratio,
    loraplus_lr_embedding=None,
):
    """
    Creates an optimizer for the given model, applying LoRA-specific learning rate adjustments to different parameter groups.

    Args:
        opt_model (torch.nn.Module): The model for which the optimizer is being created.
        optimizer_cls (class): The class of the optimizer to be used (e.g., torch.optim.Adam).
        optimizer_kwargs (dict): A dictionary of keyword arguments for the optimizer's initialization.
        loraplus_lr_ratio (float): The learning rate ratio to be applied to LoRA parameters.
        loraplus_lr_embedding (float, optional): A specific learning rate for embedding parameters, with a default value if not provided.

    Returns:
        An instance of the specified optimizer class configured with the model's parameters organized into groups with custom learning rates.
    """

    assert loraplus_lr_ratio is not None, "loraplus_lr_ratio must be provided."

    if loraplus_lr_embedding is None:
        loraplus_lr_embedding = 1e-6

    decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
    decay_parameters = [name for name in decay_parameters if "bias" not in name]
    param_groups = {
        "groupA": {},
        "groupB": {},
        "groupB_no_decay": {},
        "embedding": {},
    }

    for name, param in opt_model.named_parameters():
        if not param.requires_grad:
            continue

        module = get_module(name, opt_model)
        if isinstance(module, lora.Embedding):
            param_groups["embedding"][name] = param
        elif "lora_B" in name or param.ndim == 1:
            if name in decay_parameters:
                param_groups["groupB"][name] = param
            else:
                param_groups["groupB_no_decay"][name] = param
        else:
            param_groups["groupA"][name] = param

    assigned_param_groups = ""
    for group in param_groups:
        assigned_param_groups += f"{group}\n {list(param_groups[group].keys())}\n\n"
    logger.info(assigned_param_groups)

    lr = optimizer_kwargs["lr"]
    weight_decay = optimizer_kwargs.get("weight_decay", 0.0)

    optimizer_grouped_parameters = [
        {
            "params": list(param_groups["groupA"].values()),
            "weight_decay": weight_decay,
            "lr": lr,
        },
        {
            "params": list(param_groups["embedding"].values()),
            "weight_decay": weight_decay,
            "lr": loraplus_lr_embedding,
        },
        {
            "params": list(param_groups["groupB"].values()),
            "weight_decay": weight_decay,
            "lr": lr * loraplus_lr_ratio,
        },
        {
            "params": list(param_groups["groupB_no_decay"].values()),
            "weight_decay": 0.0,
            "lr": lr * loraplus_lr_ratio,
        },
    ]

    optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
    if optimizer_cls.__name__ == "Adam8bit":
        import bitsandbytes

        manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

        skipped = 0
        for module in opt_model.modules():
            if isinstance(module, nn.Embedding):
                skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values())
                logger.info(f"skipped {module}: {skipped/2**20}M params")
                manager.register_module_override(module, "weight", {"optim_bits": 32})
                logger.debug(f"bitsandbytes: will optimize {module} in fp32")
        logger.info(f"skipped: {skipped/2**20}M params")

    return optimizer

# Wrap the SFTTrainer
class LoraPlusSFTTrainer(SFTTrainer):
    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: LoraPlusTrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional["PeftConfig"] = None,
        dataset_text_field: Optional[str] = None,
        packing: Optional[bool] = False,
        formatting_func: Optional[Callable] = None,
        max_seq_length: Optional[int] = None,
        infinite: Optional[bool] = None,
        num_of_sequences: Optional[int] = 1024,
        chars_per_token: Optional[float] = 3.6,
        dataset_num_proc: Optional[int] = None,
        dataset_batch_size: int = 1000,
        neftune_noise_alpha: Optional[float] = None,
        model_init_kwargs: Optional[Dict] = None,
        dataset_kwargs: Optional[Dict] = None,
        eval_packing: Optional[bool] = None,
    ):
        if args.loraplus_lr_ratio is not None:
            opt_model = model.module if is_sagemaker_mp_enabled() else model
            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args)
            loraplus_lr_ratio = getattr(args, "loraplus_lr_ratio", None)
            loraplus_lr_embedding = getattr(args, "loraplus_lr_embedding", None)
            optimizer = create_loraplus_optimizer(
                opt_model,
                optimizer_cls,
                optimizer_kwargs,
                loraplus_lr_ratio,
                loraplus_lr_embedding,
            )
            optimizers = (optimizer, None)

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
            peft_config=peft_config,
            dataset_text_field=dataset_text_field,
            packing=packing,
            formatting_func=formatting_func,
            max_seq_length=max_seq_length,
            infinite=infinite,
            num_of_sequences=num_of_sequences,
            chars_per_token=chars_per_token,
            dataset_num_proc=dataset_num_proc,
            dataset_batch_size=dataset_batch_size,
            neftune_noise_alpha=neftune_noise_alpha,
            model_init_kwargs=model_init_kwargs,
            dataset_kwargs=dataset_kwargs,
            eval_packing=eval_packing,
        )

In [None]:
# Set up lora
peft_config = LoraConfig(
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    r=LORA_R,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "v_proj",
        "k_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "lm_head"
    ]
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

trainable params: 20,279,296 || all params: 6,760,792,064 || trainable%: 0.29995444036777286


As you can see, by applying LoRA we now need to train only about 0.3% of the parameters, which significantly speeds up training.
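The trainable parameter count can be reproduced by hand. The sketch below assumes a LoRA rank of 8 and the layer shapes of the Llama-style base model (hidden size 4096, MLP size 11008, vocabulary size 32256, 32 decoder layers); the helper name `lora_params` is hypothetical.

```python
# Assumptions: LoRA rank 8 and the layer shapes of the base model.
RANK = 8
HIDDEN = 4096
MLP = 11008
VOCAB = 32256
NUM_LAYERS = 32

def lora_params(in_features: int, out_features: int, rank: int = RANK) -> int:
    """A is (rank x in_features) and B is (out_features x rank)."""
    return rank * (in_features + out_features)

per_layer = (
    4 * lora_params(HIDDEN, HIDDEN)   # q_proj, k_proj, v_proj, o_proj
    + 2 * lora_params(HIDDEN, MLP)    # gate_proj, up_proj
    + lora_params(MLP, HIDDEN)        # down_proj
)
trainable = NUM_LAYERS * per_layer + lora_params(HIDDEN, VOCAB)  # plus lm_head
all_params = 6_760_792_064  # total reported by print_trainable_parameters()

trainable_pct = 100 * trainable / all_params
print(f"trainable params: {trainable:,} ({trainable_pct:.4f}%)")
```

Under these assumptions the count matches the number reported by `print_trainable_parameters()`, which suggests every listed target module, including `lm_head`, received an adapter.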

## Train the Model

**Noisy Embeddings**

We will be using noisy embeddings for fine-tuning, as outlined in [NEFTune: Noisy Embeddings Improve Instruction Finetuning](https://arxiv.org/pdf/2310.05914.pdf). This method adds random noise to the embedding vectors of the training data during the forward pass of fine-tuning. Specifically, each step of NEFTune begins by sampling an instruction from the dataset and converting its tokens to embedding vectors. NEFTune then departs from standard training by adding a random noise vector to the embeddings. The noise is generated by sampling i.i.d. uniform entries, each in the range $[-1,1]$, and then scaling the entire noise vector by a factor of $\alpha / \sqrt{L d}$, where $L$ is the sequence length, $d$ is the embedding dimension, and $\alpha$ is a tunable parameter. The authors note that the scaling rule was borrowed from the adversarial ML literature and results in a random vector with an expected Euclidean magnitude of approximately $\alpha / \sqrt{3}$.
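The noise rule can be illustrated with a small standard-library sketch. This is not the trainer's implementation; the function names `neftune_noise` and `frobenius_norm`, the toy sizes, and the fixed seed are assumptions made for illustration only.

```python
import math
import random

def neftune_noise(seq_len: int, embed_dim: int, alpha: float, rng: random.Random):
    """Return a seq_len x embed_dim noise matrix scaled by alpha / sqrt(L * d)."""
    scale = alpha / math.sqrt(seq_len * embed_dim)
    return [
        [rng.uniform(-1.0, 1.0) * scale for _ in range(embed_dim)]
        for _ in range(seq_len)
    ]

def frobenius_norm(matrix) -> float:
    """Euclidean magnitude of the whole noise matrix."""
    return math.sqrt(sum(value * value for row in matrix for value in row))

rng = random.Random(0)  # fixed seed so the sketch is repeatable
alpha = 5.0
noise = neftune_noise(seq_len=32, embed_dim=64, alpha=alpha, rng=rng)
noise_norm = frobenius_norm(noise)
expected_norm = alpha / math.sqrt(3)
print(f"noise norm: {noise_norm:.3f}, alpha / sqrt(3): {expected_norm:.3f}")
```

The measured norm should land close to $\alpha / \sqrt{3}$ because each uniform entry on $[-1,1]$ has second moment $1/3$, so the squared norm of the $L d$ scaled entries is about $\alpha^2 / 3$.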

The 🤗 `SFTTrainer` provides an easy way to incorporate NEFTune: simply pass the `neftune_noise_alpha` argument to the trainer, which is the `LoraPlusSFTTrainer` in our case. The authors of the NEFTune paper show that noisy embeddings can dramatically improve fine-tuning accuracy on downstream instruction tasks in many cases, without adding computational overhead, or as they call it, a free lunch for LLM fine-tuning!

Now that we have prepared the data and optimized the model, we are ready to bring everything together and start training.

To instantiate a `LoraPlusSFTTrainer`, you need to define the training configuration as you normally would for any 🤗 trainer. The most important piece is `LoraPlusTrainingArguments`, a class that holds all the attributes that configure training, just like the usual `TrainingArguments` class.

### Train the Model for Schema Linking

First, we will fine-tune the base model for the schema linking task. Schema linking involves identifying the pertinent columns and tables in a database in response to natural language queries. It has been demonstrated to enhance cross-domain generalizability and facilitate the creation of intricate queries [(Lei et al., 2020)](https://zhixinma.github.io/papers/20-EMNLP.pdf).

#### Ensemble For LLMs

Large Language Models (LLMs) fine-tuned for specific tasks often struggle with accurately assessing uncertainties, leading to overconfidence, poor calibration, and unreliable predictions, especially on new or atypical data. A method commonly employed in the field of computer vision to mitigate such issues involves creating a deep ensemble, where a single model is trained multiple times with different initial starting points.

However, applying this technique to LLMs presents a significant obstacle due to their immense size; maintaining even one LLM in memory is difficult, let alone an ensemble of, say, five. Therefore, instead of fine-tuning multiple LLMs, we suggest ensembling Low-Rank Adapters (LoRA). These adapters enable the formation of large ensembles with nearly the same computational cost as the singular use of the base model. For more context, see [Wang et al., 2023](https://arxiv.org/pdf/2310.00035.pdf).

In our two-stage fine-tuning process, we will be keeping the LoRA+ ratio of learning rates, $\eta_A / \eta_B$, the same for each sub-model, and we will vary the learning rates of the LoRA adapters that get added to the **embedding layers** of the model. In future work, we will go back and create ensembles by adjusting the ratio $\eta_A / \eta_B$.
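To make the role of the two LoRA+ arguments concrete, here is a simplified sketch of how learning rates could be assigned per parameter group. It is not the code of `create_loraplus_optimizer`; the function `loraplus_learning_rates`, the parameter names, and the base learning rate of `5e-5` are hypothetical and only follow the PEFT naming style.

```python
def loraplus_learning_rates(param_names, base_lr, lr_ratio, lr_embedding):
    """Map each trainable parameter name to the learning rate of its group."""
    rates = {}
    for name in param_names:
        if "lora_embedding" in name:
            rates[name] = lr_embedding        # adapters attached to embedding layers
        elif "lora_B" in name:
            rates[name] = base_lr * lr_ratio  # B matrices use the scaled rate
        else:
            rates[name] = base_lr             # A matrices keep the base rate
    return rates

# Hypothetical parameter names, for illustration only.
names = [
    "layers.0.self_attn.q_proj.lora_A.default.weight",
    "layers.0.self_attn.q_proj.lora_B.default.weight",
    "embed_tokens.lora_embedding_A.default",
]
rates = loraplus_learning_rates(names, base_lr=5e-5, lr_ratio=1.25, lr_embedding=1e-3)
for name, rate in rates.items():
    print(f"{name}: {rate:.2e}")
```

In this picture, keeping `loraplus_lr_ratio` fixed while changing `loraplus_lr_embedding` alters only the embedding-adapter group, which is the knob varied across the ensemble members here.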

Now we can instantiate the `LoraPlusSFTTrainer` and call the `train` method to begin fine-tuning our LLM for the schema linking task.

In [None]:
lr_embedding_values = [1e-3, 5e-6, 1e3]

for i, lr_embedding in enumerate(lr_embedding_values):
    os.environ["WANDB_PROJECT"] = f'{OUTPUT_DIR}_{i}'
    os.environ["WANDB_API_KEY"] = 'YOUR KEY'

    train_dataset = dataset['train']
    train_dataset = train_dataset.shuffle()

    training_args = LoraPlusTrainingArguments(
        output_dir=f"jdeklerk10/{OUTPUT_DIR}_{i}",
        dataloader_drop_last=True,
        overwrite_output_dir=True,
        num_train_epochs=NUM_EPOCHS,
        evaluation_strategy="steps",
        save_strategy="steps",
        eval_steps=EVAL_FREQ,
        save_steps=SAVE_FREQ,
        logging_steps=LOG_FREQ,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        learning_rate=LR,
        lr_scheduler_type=LR_SCHEDULER_TYPE,
        warmup_ratio=0.01,
        max_grad_norm=0.3,
        group_by_length=True,
        auto_find_batch_size=False,
        gradient_accumulation_steps=GR_ACC_STEPS,
        gradient_checkpointing=True,
        save_total_limit=3,
        fp16=FP16,
        bf16=BF16,
        weight_decay=WEIGHT_DECAY,
        push_to_hub=True,
        include_tokens_per_second=True,
        loraplus_lr_ratio=1.25,  # LoRA+ ratio: learning rate of the B matrices relative to the A matrices
        loraplus_lr_embedding=lr_embedding,  # LoRA+ learning rate for the embedding matrix
        report_to='wandb',
        load_best_model_at_end=False,
    )

    trainer = LoraPlusSFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,  # use the shuffled training split created above
        eval_dataset=dataset['validation'],
        peft_config=peft_config,
        max_seq_length=2100,
        tokenizer=tokenizer,
        neftune_noise_alpha=5,  # The NEFTune authors typically use 5 for smaller models, from what I can see in the paper
        packing=False
    )

    trainer.train()

    output_dir = os.path.join("./", f"final_checkpoint_{i}")
    trainer.model.save_pretrained(output_dir)
    trainer.push_to_hub()

    wandb.finish()

| Step | Training Loss | Validation Loss |
|------|---------------|-----------------|
| 50 | 0.8424 | 0.666993 |
| 100 | 0.6032 | 0.642469 |
| 150 | 0.5435 | 0.641393 |
| 200 | 0.4812 | 0.629559 |
| 250 | 0.4668 | 0.63167 |

Run summary:

| Metric | Value |
|--------|-------|
| eval/loss | 0.63167 |
| eval/runtime | 158.5962 |
| eval/samples_per_second | 6.52 |
| eval/steps_per_second | 1.633 |
| total_flos | 3.025708008894628e+17 |
| train/epoch | 1.0 |
| train/global_step | 262.0 |
| train/grad_norm | 0.71549 |
| train/learning_rate | 0.0 |
| train/loss | 0.4668 |


| Step | Training Loss | Validation Loss |
|------|---------------|-----------------|
| 50 | 0.4609 | 0.652599 |
| 100 | 0.3921 | 0.665111 |
| 150 | 0.3486 | 0.689598 |
| 200 | 0.298 | 0.717382 |
| 250 | 0.2739 | 0.71966 |

Run summary:

| Metric | Value |
|--------|-------|
| eval/loss | 0.71966 |
| eval/runtime | 158.7069 |
| eval/samples_per_second | 6.515 |
| eval/steps_per_second | 1.632 |
| total_flos | 3.025708008894628e+17 |
| train/epoch | 1.0 |
| train/global_step | 262.0 |
| train/grad_norm | 0.986 |
| train/learning_rate | 0.0 |
| train/loss | 0.2739 |


| Step | Training Loss | Validation Loss |
|------|---------------|-----------------|
| 50 | 0.2753 | 0.733626 |
| 100 | 0.2249 | 0.781703 |
| 150 | 0.1951 | 0.838915 |
| 200 | 0.1657 | 0.81118 |
| 250 | 0.1503 | 0.859156 |

Run summary:

| Metric | Value |
|--------|-------|
| eval/loss | 0.85916 |
| eval/runtime | 158.7378 |
| eval/samples_per_second | 6.514 |
| eval/steps_per_second | 1.632 |
| total_flos | 3.025708008894628e+17 |
| train/epoch | 1.0 |
| train/global_step | 262.0 |
| train/grad_norm | 0.30248 |
| train/learning_rate | 0.0 |
| train/loss | 0.1503 |


Each run pushes its fine-tuned adapter to your Hub repository. Fine-tuning the 3 different versions of the base model took about 4 hours in total.

### Merge LoRA+ Adapters for Ensemble Schema Linking

Now we are ready to merge the LoRA adapters into a single ensemble model.
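One way to picture a weighted merge is as a weighted sum of the low-rank updates, $\Delta W = \sum_i w_i B_i A_i$. The toy sketch below shows that idealized picture with plain Python lists. It is an assumption-laden illustration: the helper names are hypothetical, and PEFT's `add_weighted_adapter` may combine the $A$ and $B$ factors differently internally depending on `combination_type`.

```python
def matmul(left, right):
    """Multiply two matrices given as lists of rows."""
    return [
        [sum(l * r for l, r in zip(row, col)) for col in zip(*right)]
        for row in left
    ]

def weighted_delta(adapters, weights):
    """Return sum_i w_i * (B_i @ A_i) for a list of (A_i, B_i) pairs."""
    total = None
    for (a_matrix, b_matrix), weight in zip(adapters, weights):
        delta = matmul(b_matrix, a_matrix)
        if total is None:
            total = [[0.0] * len(delta[0]) for _ in delta]
        for i, row in enumerate(delta):
            for j, value in enumerate(row):
                total[i][j] += weight * value
    return total

# Three toy rank-1 adapters for a 2x2 weight matrix: A is 1x2, B is 2x1.
adapters = [
    ([[1.0, 0.0]], [[1.0], [0.0]]),
    ([[0.0, 1.0]], [[0.0], [1.0]]),
    ([[1.0, 1.0]], [[1.0], [1.0]]),
]
merged_delta = weighted_delta(adapters, weights=[0.6, 0.2, 0.2])
print(merged_delta)
```

The weights play the same role as the `weights` argument in the cell that follows: a larger weight lets one adapter dominate the merged update.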

In [9]:
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

OUTPUT_DIR="OCI-DS-6.7B-schema-linking_0"  # output_dir
device='cuda'

peft_model_id = f"jdeklerk10/{OUTPUT_DIR}"
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    attn_implementation="flash_attention_2",
    quantization_config=None,
    device_map=None,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).cuda()

schema_model = PeftModel.from_pretrained(base_model, peft_model_id, adapter_name="schema_lr_0")

weighted_adapter_name = "schema_merged"
schema_model.load_adapter("jdeklerk10/OCI-DS-6.7B-schema-linking_1", adapter_name="schema_lr_1")
schema_model.load_adapter("jdeklerk10/OCI-DS-6.7B-schema-linking_2", adapter_name="schema_lr_2")
schema_model.add_weighted_adapter(
    adapters=["schema_lr_0", "schema_lr_1", "schema_lr_2"],
    weights=[0.6, 0.2, 0.2],
    adapter_name=weighted_adapter_name,
    combination_type="linear"
)
schema_model.set_adapter(weighted_adapter_name)

tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

schema_model.merge_and_unload()


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaFlashAttention2(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaLinearScalingRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
 

### Inference for Schema Linking

Let's see how our model performs on the Spider validation set for the schema linking task. It is important to understand how we will be evaluating the model.

#### Evaluation Process

The evaluation process begins by iterating over each row in the validation dataset. For each row, we extract the question, query, database schema, and database ID. We construct a user message by combining the database schema and the question, which is then tokenized and passed to the model for generating a response. The model's response is processed to extract the predicted tables.

Next, we compare the predicted tables with the reference tables obtained from the actual query. We calculate several metrics to quantify the model's performance:

1. **Accuracy**: If the predicted tables exactly match the reference tables, we count this as an accurate prediction.

2. **Filtered Accuracy**: We also calculate a filtered accuracy, which only considers a prediction as correct if all the reference tables are correctly identified, regardless of any additional predicted tables.

3. **Precision and Recall**: For each question, we calculate precision and recall. Precision measures the proportion of predicted tables that are actually relevant, while recall measures the proportion of relevant tables that are successfully predicted by the model. These metrics provide a more nuanced understanding of the model's performance.

After evaluating all the questions, we calculate the average precision, average recall, total accuracy, and filtered accuracy across the entire validation dataset. These metrics give us a comprehensive view of how well the model performs in identifying the relevant tables for a given question.

In [43]:
import re

import torch
from transformers import StoppingCriteria

class EosListStoppingCriteria(StoppingCriteria):
    """Stop generation once the last generated tokens match `eos_sequence`."""

    def __init__(self, eos_sequence=None):
        # Avoid a mutable default argument; defaults to the original stop sequence [6203]
        self.eos_sequence = [6203] if eos_sequence is None else eos_sequence

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        last_ids = input_ids[:, -len(self.eos_sequence):].tolist()
        return self.eos_sequence in last_ids

def append_string_to_file(text, file_path):
    """Append `text` as a new line to the file at `file_path`."""
    with open(file_path, 'a') as file:
        file.write(text + '\n')

def remove_spaces(text):
    """Collapse runs of whitespace into a single space."""
    return re.sub(r'\s+', ' ', text)

def call_oci(inputs):
    """Greedily generate a response and decode only the newly generated tokens."""
    output_tokens = schema_model.generate(
        inputs,
        max_new_tokens=250,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        stopping_criteria=[EosListStoppingCriteria()],
    )
    return tokenizer.decode(output_tokens[0][len(inputs[0]):], skip_special_tokens=True)

In [None]:
import re

import pandas as pd
from sql_metadata import Parser  # assumption: Parser comes from the sql_metadata package
from tqdm import tqdm

df = pd.read_csv("validation_dataset.csv")
results = []
for index, row in tqdm(df.iterrows(), total=len(df)):
  question = row['question']
  query = row['query']
  database_schema = row['database_schema']
  db_id = row['db_id']
  user_message = f"""Given the following SQL tables, your job is to determine the columns and tables that the question is referring to.
{database_schema}
###
Question: {question}
"""
  messages = [
      {"role": "user", "content": user_message.strip()}
  ]
  inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True, tokenize=True).to(device)
  response = call_oci(inputs)
  # Keep only the text before the first ";" and, if present, after "Tables:"
  if ";" in response:
    response = response.split(";")[0]
    if "Tables:" in response:
      response = response.split("Tables:")[1]
  response = re.sub(r'\s+', ' ', response).strip()
  # Skip rows whose reference query cannot be parsed
  try:
    ref_tables = ", ".join(Parser(query).tables)
  except Exception:
    continue
  print("\n")
  print(response)
  print(ref_tables)
  print("============================")
  results.append([response, ref_tables, query, question, db_id])

# Build the results DataFrame once, after the loop
new_df = pd.DataFrame(results, columns=['predicted_tables', 'reference_tables', 'query', 'question', 'db_id'])

In [4]:
import re
import pandas as pd

new_df = pd.read_csv("results/generated_schema_links.csv")

total_samples = len(new_df)
total_accuracy = 0
filtered_accuracy = 0
total_precision = 0
total_recall = 0

for index, row in new_df.iterrows():
    if not row['predicted_tables'] or pd.isna(row['predicted_tables']):
        continue
    
    # Extract table names from the predicted_tables using regular expressions
    predicted_tables = re.findall(r'(?:FROM|JOIN)\s+(\w+)', row['predicted_tables'], re.IGNORECASE)
    predicted_tables = [table.split(' ')[0] for table in predicted_tables]
    reference_tables = row['reference_tables'].split(", ")
    
    predicted_tables = [x.lower() for x in predicted_tables]
    reference_tables = [x.lower() for x in reference_tables]
    
    # Calculate accuracy
    if set(predicted_tables) == set(reference_tables):
        total_accuracy += 1
    
    # Calculate precision and recall
    true_positives = len(set(predicted_tables) & set(reference_tables))
    false_positives = len(set(predicted_tables) - set(reference_tables))
    false_negatives = len(set(reference_tables) - set(predicted_tables))
    
    if true_positives == len(reference_tables):
        filtered_accuracy += 1
    
    if len(predicted_tables) > 0:
        precision = true_positives / (true_positives + false_positives)
        recall = true_positives / (true_positives + false_negatives)
        total_precision += precision
        total_recall += recall

avg_precision = (total_precision / total_samples) * 100
avg_recall = (total_recall / total_samples) * 100

accuracy = (total_accuracy / total_samples) * 100
filtered_accuracy = (filtered_accuracy / total_samples) * 100

print("Exact Accuracy: {:.2f}%".format(accuracy))
print("Filtered Accuracy: {:.2f}%".format(filtered_accuracy))
print("Average Precision: {:.2f}%".format(avg_precision))
print("Average Recall: {:.2f}%".format(avg_recall))

Exact Accuracy: 73.31%
Filtered Accuracy: 94.68%
Average Precision: 88.21%
Average Recall: 96.79%


Let's explain these evaluation metrics in a little more detail. The code calculates four evaluation metrics to assess the performance of a table prediction model:

1. **Exact Accuracy** (`total_accuracy` in the code):
   - Measures the percentage of samples where the predicted tables exactly match the reference tables.
   - Compares the set of predicted tables with the set of reference tables for each sample.
   - Calculated as: `(total_accuracy / total_samples) * 100`

2. **Filtered Accuracy**:
   - Measures the percentage of samples where the predicted tables contain all the reference tables, regardless of additional predicted tables.
   - Counts the number of true positives (table names present in both predicted and reference tables) for each sample.
   - If true positives equal the number of reference tables, increments the `filtered_accuracy` counter.
   - Calculated as: `(filtered_accuracy / total_samples) * 100`

3. **Average Precision**:
   - Measures the proportion of predicted tables that are actually relevant (present in the reference tables).
   - Calculates true positives (table names in both predicted and reference tables) and false positives (table names in predicted tables but not in reference tables) for each sample.
   - Precision for each sample: `true_positives / (true_positives + false_positives)`
   - Average precision: `(total_precision / total_samples) * 100`

4. **Average Recall**:
   - Measures the proportion of relevant tables (reference tables) that are successfully predicted.
   - Calculates true positives (table names in both predicted and reference tables) and false negatives (table names in reference tables but not in predicted tables) for each sample.
   - Recall for each sample: `true_positives / (true_positives + false_negatives)`
   - Average recall: `(total_recall / total_samples) * 100`

These metrics provide different perspectives on the model's performance:
- Exact accuracy focuses on exact matches.
- Filtered accuracy considers predictions that include all reference tables.
- Average precision measures the proportion of relevant predicted tables.
- Average recall measures the proportion of relevant tables successfully predicted.

By calculating these metrics, you can assess the model's ability to predict the correct table names and evaluate its overall performance.
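A single worked sample makes the four metrics easier to tell apart. The sketch below uses a hypothetical helper, `table_metrics`, and made-up table names; it mirrors the set-based definitions above rather than reproducing the evaluation cell.

```python
def table_metrics(predicted, reference):
    """Compute per-sample schema-linking metrics from two lists of table names."""
    predicted_set = {name.lower() for name in predicted}
    reference_set = {name.lower() for name in reference}
    true_positives = len(predicted_set & reference_set)
    precision = true_positives / len(predicted_set) if predicted_set else 0.0
    recall = true_positives / len(reference_set) if reference_set else 0.0
    return {
        "exact_match": predicted_set == reference_set,      # counts toward exact accuracy
        "filtered_match": reference_set <= predicted_set,   # counts toward filtered accuracy
        "precision": precision,
        "recall": recall,
    }

# Hypothetical sample: the model predicts one extra table.
sample = table_metrics(
    predicted=["singer", "concert", "stadium"],
    reference=["singer", "concert"],
)
print(sample)
```

Here the extra table breaks the exact match and lowers precision to 2/3, while recall stays at 1 and the sample still counts toward filtered accuracy. This is why filtered accuracy and recall can be much higher than exact accuracy.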

### Train the Model for SQL Generation

After identifying the appropriate tables for SQL generation, the next step is to utilize a model that constructs the SQL query based on the question and the schema of the correct tables. Since we have already identified the potentially correct tables using the schema-linking model, there is no need to include all tables in the input for the SQL generation model.
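As a rough illustration of this filtering step, the sketch below keeps only the `CREATE TABLE` statements for the predicted tables. It assumes the schema is a string of `CREATE TABLE` statements separated by semicolons; the function `filter_schema` and the example schema are hypothetical and may not match how the schemas in this project are stored.

```python
import re

def filter_schema(database_schema: str, tables) -> str:
    """Keep only the CREATE TABLE statements whose table name is in `tables`."""
    wanted = {name.lower() for name in tables}
    kept = []
    # Assumption: statements are separated by ";" and contain no ";" themselves.
    for statement in database_schema.split(";"):
        match = re.search(r"CREATE\s+TABLE\s+(\w+)", statement, re.IGNORECASE)
        if match and match.group(1).lower() in wanted:
            kept.append(statement.strip() + ";")
    return "\n".join(kept)

# Hypothetical schema string and predicted tables.
schema = (
    "CREATE TABLE singer (singer_id INT, name TEXT);"
    "CREATE TABLE concert (concert_id INT, stadium_id INT);"
    "CREATE TABLE stadium (stadium_id INT, capacity INT);"
)
reduced_schema = filter_schema(schema, ["singer", "concert"])
print(reduced_schema)
```

A shorter schema leaves more of the context window for the question and reduces the number of irrelevant columns the SQL generation model has to ignore.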

In [48]:
MODEL="m-a-p/OpenCodeInterpreter-DS-6.7B"     # Model checkpoint on the Hugging Face Hub
SEQ_LENGTH=2100                               # Sequence length

# Training arguments
NUM_EPOCHS=1                                  # num_train_epochs
BATCH_SIZE=4                                  # batch_size
GR_ACC_STEPS=8                                # gradient_accumulation_steps
LR=5e-5                                       # learning_rate
LR_SCHEDULER_TYPE="cosine"                    # lr_scheduler_type
WEIGHT_DECAY=1e-1                             # weight_decay
EVAL_FREQ=50                                  # eval_freq
SAVE_FREQ=50                                  # save_freq
LOG_FREQ=50                                   # log_freq
OUTPUT_DIR="OCI-DS-6.7B-SQL-Gen"              # output_dir
BF16=True                                     # bf16
FP16=False                                    # no_fp16

# LORA
LORA_R=8                                      # lora_r
LORA_ALPHA=32                                 # lora_alpha
LORA_DROPOUT=0.1                              # lora_dropout
LORA_TARGET_MODULES="q_proj","v_proj","k_proj","o_proj","gate_proj","up_proj","down_proj","lm_head"    # lora_target_modules

# bitsandbytes config
USE_NESTED_QUANT=True                         # use_nested_quant
BNB_4BIT_COMPUTE_DTYPE="bfloat16"             # bnb_4bit_compute_dtype

SEED=315

We need to format the data for the SQL generation task. The `formatting_prompts_func` function below transforms the training dataset into a format that simulates a dialogue between a user and an assistant. The dataset comprises natural language questions, the corresponding SQL queries, and database schemas. The function iterates through the dataset and builds structured prompts: each user message starts with the task instruction, followed by the database schema and the user's question, and the assistant's response contains the corresponding SQL query.

This structured formatting is crucial for training conversational models on text-to-SQL tasks, enabling them to understand the context provided by the database schema, interpret the natural language question, and generate accurate SQL queries in response.

In [None]:
def formatting_prompts_func(training_dataset):
  output_texts = []
  for i in range(len(training_dataset['question'])):
    question = training_dataset['question'][i]
    query = training_dataset['query'][i]
    database_schema = training_dataset['database_schema'][i]
    user_message = f"""Given the following SQL tables, your job is to generate the Sqlite SQL query given the user's question.
Put your answer inside the ```sql and ``` tags.
{database_schema}
###
Question: {question}
"""
    assistant_message = f"""
```sql
{query} ;
```
<|EOT|>
"""
    messages = [
      {"role": "user", "content": user_message},
      {"role": "assistant", "content": assistant_message},
      ]
    text = tokenizer.apply_chat_template(messages, tokenize=False)
    output_texts.append(text)
  return output_texts

In [None]:
response_template = "### Response:"
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

Now we can instantiate the `LoraPlusSFTTrainer` and call the `train` method to begin fine-tuning our LLM for the SQL generation task.

In [None]:
lr_embedding_values = [1e-3, 5e-6, 1e3]

for i, lr_embedding in enumerate(lr_embedding_values):
    os.environ["WANDB_PROJECT"] = f'{OUTPUT_DIR}_{i}'
    os.environ["WANDB_API_KEY"] = 'YOUR KEY'

    train_dataset = dataset['train']
    train_dataset = train_dataset.shuffle()

    training_args = LoraPlusTrainingArguments(
        output_dir=f"jdeklerk10/{OUTPUT_DIR}_{i}",
        dataloader_drop_last=True,
        overwrite_output_dir=True,
        num_train_epochs=NUM_EPOCHS,
        evaluation_strategy="steps",
        save_strategy="steps",
        eval_steps=EVAL_FREQ,
        save_steps=SAVE_FREQ,
        logging_steps=LOG_FREQ,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        learning_rate=LR,
        lr_scheduler_type=LR_SCHEDULER_TYPE,
        warmup_ratio=0.01,
        max_grad_norm=0.3,
        group_by_length=True,
        auto_find_batch_size=False,
        gradient_accumulation_steps=GR_ACC_STEPS,
        gradient_checkpointing=True,
        save_total_limit=3,
        fp16=FP16,
        bf16=BF16,
        weight_decay=WEIGHT_DECAY,
        push_to_hub=True,
        include_tokens_per_second=True,
        loraplus_lr_ratio=1.25,  # LoRA+ ratio: learning rate of the B matrices relative to the A matrices
        loraplus_lr_embedding=lr_embedding,  # LoRA+ learning rate for the embedding matrix
        report_to='wandb',
        load_best_model_at_end=False,
    )

    trainer = LoraPlusSFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,  # use the shuffled training split created above
        eval_dataset=dataset['validation'],
        formatting_func=formatting_prompts_func,
        data_collator=collator,
        peft_config=peft_config,
        max_seq_length=2100,
        tokenizer=tokenizer,
        neftune_noise_alpha=5,  # The NEFTune authors typically use 5 for smaller models, from what I can see in the paper
        packing=False
    )

    trainer.train()

    output_dir = os.path.join("./", f"final_checkpoint_{i}")
    trainer.model.save_pretrained(output_dir)
    trainer.push_to_hub()

    wandb.finish()

| Step | Training Loss | Validation Loss |
|------|---------------|-----------------|
| 50 | 370.6866 | 0.002072 |
| 100 | 1.4088 | 0.001484 |
| 150 | 0.0 | 0.001502 |
| 200 | 4.9692 | 0.001227 |
| 250 | 4.6045 | 0.00122 |

Run summary:

| Metric | Value |
|--------|-------|
| eval/loss | 0.00122 |
| eval/runtime | 164.6106 |
| eval/samples_per_second | 6.281 |
| eval/steps_per_second | 1.573 |
| total_flos | 3.159817699558687e+17 |
| train/epoch | 1.0 |
| train/global_step | 262.0 |
| train/grad_norm | 0.0 |
| train/learning_rate | 0.0 |
| train/loss | 4.6045 |


| Step | Training Loss | Validation Loss |
|------|---------------|-----------------|
| 50 | 67.8296 | 0.000937 |
| 100 | 0.111 | 0.001069 |
| 150 | 0.0 | 0.001075 |
| 200 | 0.0759 | 0.000934 |
| 250 | 4.5094 | 0.000904 |

Run summary:

| Metric | Value |
|--------|-------|
| eval/loss | 0.0009 |
| eval/runtime | 164.7284 |
| eval/samples_per_second | 6.277 |
| eval/steps_per_second | 1.572 |
| total_flos | 3.159817699558687e+17 |
| train/epoch | 1.0 |
| train/global_step | 262.0 |
| train/grad_norm | 0.0 |
| train/learning_rate | 0.0 |
| train/loss | 4.5094 |


| Step | Training Loss | Validation Loss |
|------|---------------|-----------------|
| 50 | 5.4929 | 0.000919 |
| 100 | 0.0007 | 0.000885 |
| 150 | 0.0 | 0.000879 |
| 200 | 0.0011 | 0.000783 |
| 250 | 2.1872 | 0.000817 |

### Merge LoRA+ Adapters for Ensemble SQL Generation

Now we are ready to merge the LoRA adapters into a single ensemble model.

In [49]:
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

OUTPUT_DIR="OCI-DS-6.7B-SQL-Gen_0"              # output_dir
device='cuda'

peft_model_id = f"jdeklerk10/{OUTPUT_DIR}"
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    attn_implementation="flash_attention_2",
    quantization_config=None,
    device_map=None,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).cuda()

sql_model = PeftModel.from_pretrained(base_model, peft_model_id, adapter_name="sql_lr_0")

weighted_adapter_name = "sql_merged"
sql_model.load_adapter("jdeklerk10/OCI-DS-6.7B-SQL-Gen_1", adapter_name="sql_lr_1")
sql_model.load_adapter("jdeklerk10/OCI-DS-6.7B-SQL-Gen_2", adapter_name="sql_lr_2")
sql_model.add_weighted_adapter(
    adapters=["sql_lr_0", "sql_lr_1", "sql_lr_2"],
    weights=[0.2, 0.2, 0.6],
    adapter_name=weighted_adapter_name,
    combination_type="linear"
)
sql_model.set_adapter(weighted_adapter_name)
sql_model.merge_and_unload()


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaFlashAttention2(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaLinearScalingRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
 

In [51]:
tokenizer = AutoTokenizer.from_pretrained(peft_model_id)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


### Inference for SQL Generation

#### Evaluation Process

The evaluation process for the SQL generation task begins by iterating over each row in the validation dataset. For each row, we extract the user's question, the reference SQL query, the database schema, and the database ID. We construct a user message by combining the database schema and the question, which is then tokenized and passed to the query generation model. The model generates a response containing the predicted SQL query.

Next, we process the generated response to extract the predicted SQL query. We clean the query by removing any extra whitespace and ensuring it is in the correct format. The predicted query, along with the reference query, user's question, and database ID, are collected and stored in a DataFrame for further evaluation.
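The extraction and cleaning step can be sketched as a small standalone helper. This is an illustrative version, not the exact notebook code: the name `extract_sql` is hypothetical, and it assumes the model wraps its answer in a Markdown `sql` fence and that the first semicolon ends the statement (a semicolon inside a string literal would truncate the query).

```python
import re

FENCE = "`" * 3  # the Markdown code fence marker, built here to keep this block readable

def extract_sql(response: str) -> str:
    """Pull a single SQL statement out of a raw model response."""
    # Prefer the text after an opening sql fence, if the model produced one
    if FENCE + "sql" in response:
        response = response.split(FENCE + "sql", 1)[1]
    # Drop a closing fence and anything that follows it
    response = response.split(FENCE, 1)[0]
    # Keep only the first statement
    response = response.split(";", 1)[0]
    # Collapse all whitespace runs into single spaces
    return re.sub(r"\s+", " ", response).strip()

raw = "Here is the query:\n" + FENCE + "sql\nSELECT  name\nFROM singer;\n" + FENCE
print(extract_sql(raw))  # SELECT name FROM singer
```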

During the evaluation phase, we iterate over each row of the DataFrame and compare the predicted query with the reference query. If the predicted query is empty or invalid, we prompt the user to provide the correct SQL query manually. The predicted queries and their corresponding reference queries are stored in separate files (`Predicted.txt` and `Gold.txt`) for further analysis.

The evaluation process helps us assess the performance of our query generation model by comparing the predicted queries with the reference queries. By storing the predicted and reference queries in separate files, we can calculate various evaluation metrics such as accuracy, precision, recall, and F1 score. These metrics provide insights into how well the model is able to generate accurate SQL queries based on the given user questions and database schemas.
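As a rough illustration of how an accuracy figure could be computed from the two files, the sketch below compares normalized query strings line by line. It is only a crude string-level proxy under stated assumptions: the helper names are hypothetical, lowercasing also affects string literals, and official text-to-SQL evaluation scripts compare parsed queries or execution results rather than raw text.

```python
import re

def normalize_sql(sql: str) -> str:
    """Lowercase, collapse whitespace, and drop a trailing semicolon."""
    return re.sub(r"\s+", " ", sql).strip().rstrip(";").strip().lower()

def exact_match_accuracy(predicted_lines, gold_lines):
    """Fraction of predictions whose normalized text equals the gold query."""
    if len(predicted_lines) != len(gold_lines):
        raise ValueError("Predicted and gold files must have the same number of lines")
    if not gold_lines:
        return 0.0
    hits = 0
    for predicted, gold in zip(predicted_lines, gold_lines):
        # Each gold line has the form "query<TAB>db_id"
        gold_sql = gold.split("\t")[0]
        hits += normalize_sql(predicted) == normalize_sql(gold_sql)
    return hits / len(gold_lines)

# Hypothetical file contents, one entry per line
predicted = ["SELECT name FROM singer", "SELECT count(*) FROM concert"]
gold = ["select name  from singer\tconcert_singer", "SELECT max(year) FROM concert\tconcert_singer"]
print(exact_match_accuracy(predicted, gold))  # 0.5
```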

In [52]:
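import re
import torch
from transformers import StoppingCriteria

class EosListStoppingCriteria(StoppingCriteria):
    """Stop generation once the last generated token ids match eos_sequence."""
    def __init__(self, eos_sequence=None):
        # Avoid a mutable default argument; 6203 is the stop token id used in this notebook
        self.eos_sequence = [6203] if eos_sequence is None else eos_sequence

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        last_ids = input_ids[:,-len(self.eos_sequence):].tolist()
        return self.eos_sequence in last_ids

def append_string_to_file(text, file_path):
  # Append one line of text to the given file
  with open(file_path, 'a') as file:
      file.write(text + '\n')

def remove_spaces(text):
  # Collapse runs of whitespace into single spaces
  return re.sub(r'\s+', ' ', text)


def call_mistral(inputs):
  # Despite its name, this helper runs greedy decoding with sql_model
  # and returns only the newly generated tokens as text
  output_tokens = sql_model.generate(inputs,
                                 max_new_tokens=250,
                                 do_sample=False,
                                 pad_token_id=tokenizer.eos_token_id,
                                 eos_token_id=tokenizer.eos_token_id,
                                 stopping_criteria = [EosListStoppingCriteria()])
  return tokenizer.decode(output_tokens[0][len(inputs[0]):], skip_special_tokens=True)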

In [None]:
df = pd.read_csv("validation_dataset.csv")
results = []

for index, row in tqdm(df.iterrows(), total=len(df)):
  question = row['question']
  query = row['query']
  database_schema = row['database_schema']
  db_id = row['db_id']
  user_message = f"""Given the following SQL tables, your job is to generate the Sqlite SQL query given the user's question.
Put your answer inside the ```sql and ``` tags.
{database_schema}
###
Question: {question}
"""
  messages = [
      {"role": "user", "content": user_message.strip()}
  ]
  inputs = tokenizer.apply_chat_template(messages, return_tensors="pt",add_generation_prompt=True,tokenize = True).to(device)
  response = call_mistral(inputs)
  # Keep only the first statement, then strip the opening sql fence if present
  if ";" in response:
    response = response.split(";")[0]
  if "```sql" in response:
    response = response.split("```sql")[1]
  response = re.sub(r'\s+', ' ', response).strip()

  print("\n")
  print(response)
  print(query)
  print("============================")

  results.append([response, query, question, db_id])

# Build the results DataFrame once, after the loop has finished
new_df = pd.DataFrame(results, columns = ['generated_query','reference_query','question','db_id'])

In [55]:
for index, row in new_df.iterrows():
  print(f"Processing row {index}")
  generated_query = row['generated_query']
  if isinstance(generated_query, str) and generated_query.startswith("SELECT"):
    sql_query = generated_query
  else:
    # Missing or non-SELECT output: show it and ask for a manual correction
    print(generated_query)
    sql_query = input("give me the correct SQL query")
  append_string_to_file(remove_spaces(sql_query), "Predicted.txt")
  append_string_to_file(row['reference_query'] + "\t" + row['db_id'], "Gold.txt")

## Inference for the Ensemble Model on the BIRD Benchmark

Once the models for each task are uploaded to the Hugging Face Hub, we can use the merged LoRA+ ensemble models for end-to-end inference: schema linking first, then SQL generation.

We will be using the [BIRD-SQL](https://bird-bench.github.io/) benchmark for evaluation. BIRD (BIg Bench for LaRge-scale Database Grounded Text-to-SQL Evaluation) is a cross-domain dataset that examines the impact of extensive database contents on text-to-SQL parsing. It contains over 12,751 unique question-SQL pairs and 95 large databases with a total size of 33.4 GB, covering more than 37 professional domains such as blockchain, hockey, healthcare, and education.
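BIRD scores predictions by execution accuracy, that is, by whether the predicted query returns the same result as the gold query. A minimal sketch of that comparison, using an in-memory SQLite database, might look like the following. The table, the rows, and the helper name `execution_match` are hypothetical, and the official evaluation script should be used for reported numbers.

```python
import sqlite3

def execution_match(conn, predicted_sql: str, gold_sql: str) -> bool:
    """Return True if both queries produce the same set of rows."""
    try:
        predicted_rows = conn.execute(predicted_sql).fetchall()
    except sqlite3.Error:
        # A prediction that fails to execute counts as incorrect
        return False
    gold_rows = conn.execute(gold_sql).fetchall()
    # Compare as sets so that row order does not matter
    return set(predicted_rows) == set(gold_rows)

# Hypothetical toy database
conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE players (id INTEGER PRIMARY KEY, team TEXT, goals INTEGER);
INSERT INTO players VALUES (1, 'A', 3), (2, 'A', 5), (3, 'B', 2);
""")

print(execution_match(
    conn,
    "SELECT team FROM players WHERE goals > 2",
    "SELECT team FROM players WHERE id IN (1, 2)",
))  # True
```

Two differently written queries count as a match here because they return the same rows, which is what makes execution-based scoring more forgiving than string comparison.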

In [9]:
import torch
import re
import sqlite3
import pandas as pd
from transformers import AutoTokenizer
from transformers import AutoModelForCausalLM
from torch import cuda
from transformers import StoppingCriteria
from tqdm import tqdm
from sql_metadata import Parser
import json
import numpy as np
import os
import gc
from accelerate.utils import release_memory

In [19]:
BASE_DATASET_DIR = "YOUR_PATH/dev.json"
BASE_DABATASES_DIR =  "YOUR_PATH/dev/dev_databases"
OUTPUT_DIR = "predict_dev.json"

In [20]:
class EosListStoppingCriteria(StoppingCriteria):
    def __init__(self, eos_sequence=[6204, 185, 10897]):
        self.eos_sequence = eos_sequence

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        last_ids = input_ids[:, -len(self.eos_sequence):].tolist()
        return self.eos_sequence in last_ids

In [21]:
def append_item(query, db_id, counter, output_dir):
    try:
        with open(output_dir, 'r') as json_file:
            data_dict = json.load(json_file)
    except FileNotFoundError:
        data_dict = {}
    item_value = f"{query}\t----- bird -----\t{db_id}"
    data_dict[counter] = item_value
    with open(output_dir, 'w') as json_file:
        json.dump(data_dict, json_file, indent=4)

def flush():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

def remove_spaces(text):
    return re.sub(r'\s+', ' ', text)

def get_all_table_names(db_uri: str) -> list[str]:
    with sqlite3.connect(db_uri) as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
        table_names = cursor.fetchall()
    return [table_name[0] for table_name in table_names]

In [22]:
def get_table_schema_with_samples(db_uri: str, table_name: str, sample_limit: int = 0, columns_description: dict[str, str] = {}) -> str:
    with sqlite3.connect(db_uri) as conn:
        cursor = conn.cursor()

        # Fetch table schema
        cursor.execute(f"PRAGMA table_info(`{table_name}`);")
        columns = cursor.fetchall()
        cursor.execute(f"PRAGMA foreign_key_list(`{table_name}`);")
        foreign_keys = cursor.fetchall()
        cursor.execute(f"PRAGMA index_list(`{table_name}`);")
        primary_key_indices = cursor.fetchall()
        primary_key_columns = []

        for index_info in primary_key_indices:
            index_name = index_info[1]
            cursor.execute(f"PRAGMA index_info(`{index_name}`);")
            index_columns = cursor.fetchall()
            primary_key_columns.extend(column[2] for column in index_columns)

        # Construct CREATE TABLE statement
        schema_str = f"CREATE TABLE `{table_name}` (\n"
        for column in columns:
            column_name = column[1]
            data_type = column[2]
            schema_str += f"  {column_name} {data_type}"
            if column_name in primary_key_columns:
                schema_str += " PRIMARY KEY"
            for foreign_key in foreign_keys:
                if column_name == foreign_key[3]:
                    schema_str += f" REFERENCES {foreign_key[2]}({foreign_key[4]})"
            if column_name in columns_description:
                schema_str += f" -- '{columns_description[column_name]}'"
            schema_str += ",\n"
        schema_str = schema_str.rstrip(",\n")
        schema_str += "\n);\n"

        if sample_limit > 0:
            cursor.execute(f"SELECT * FROM `{table_name}` LIMIT {sample_limit};")
            sample_rows = cursor.fetchall()
            if sample_rows:
                schema_str += f"Sample rows from `{table_name}`:\n"
                for row in sample_rows:
                    formatted_row = ", ".join(str(item) for item in row)
                    schema_str += f"{formatted_row}\n"

    return schema_str
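
One caveat about the function above: `PRAGMA index_list` returns every index on a table, including those created for `UNIQUE` constraints, so any indexed column ends up labelled `PRIMARY KEY`. SQLite also reports primary-key membership directly in the last field of `PRAGMA table_info`. The sketch below shows that alternative on a hypothetical `users` table; the helper name `primary_key_columns` is an assumption, and switching to it would change the schema text relative to what the models saw during fine-tuning, so treat it as an option to evaluate rather than a drop-in fix.

```python
import sqlite3

def primary_key_columns(conn, table_name: str) -> list[str]:
    """Return primary-key column names in key order, using PRAGMA table_info."""
    rows = conn.execute(f"PRAGMA table_info(`{table_name}`);").fetchall()
    # Each row is (cid, name, type, notnull, dflt_value, pk);
    # pk is 0 for non-key columns, otherwise the 1-based position in the key
    key_rows = sorted((row for row in rows if row[5] > 0), key=lambda row: row[5])
    return [row[1] for row in key_rows]

# Hypothetical tables: one with a UNIQUE column, one with a composite key
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT UNIQUE, name TEXT)")
conn.execute("CREATE TABLE memberships (a INTEGER, b INTEGER, note TEXT, PRIMARY KEY (b, a))")

print(primary_key_columns(conn, "users"))        # ['id']
print(primary_key_columns(conn, "memberships"))  # ['b', 'a']
```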

In [23]:
def load_descriptions(db_path: str, table_name: str) -> dict[str, str]:
    file_path = f"{db_path}/database_description/{table_name}.csv"
    if not os.path.exists(file_path):
        return {}

    try:
        df = pd.read_csv(file_path)
    except Exception:
        return {}

    if "column_description" not in df.columns or "value_description" not in df.columns:
        return {}

    columns_description = {}
    for _, row in df.iterrows():
        if pd.notna(row["column_description"]):
            columns_description[row["original_column_name"]] = remove_spaces(row["column_description"])
            if pd.notna(row["value_description"]):
                columns_description[row["original_column_name"]] += f" has values: ({remove_spaces(row['value_description'])})"

    return columns_description

In [24]:
def generate_sql(inputs, model):
    output_tokens = model.generate(inputs, max_new_tokens=300, do_sample=False, pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id, stopping_criteria=[EosListStoppingCriteria()])
    return tokenizer.decode(output_tokens[0][len(inputs[0]):], skip_special_tokens=True)

def generate_schema(inputs, model):
    output_tokens = model.generate(inputs, max_new_tokens=250, do_sample=False, pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id, stopping_criteria=[EosListStoppingCriteria()])
    return tokenizer.decode(output_tokens[0][len(inputs[0]):], skip_special_tokens=True)

def print_tokens_with_ids(txt):
    tokens = tokenizer.tokenize(txt, add_special_tokens=False)
    token_ids = tokenizer.encode(txt, add_special_tokens=False)
    print(list(zip(tokens, token_ids)))

In [27]:
def load_data(file_path):
    return pd.read_json(file_path)

def get_user_message(database_schema, question):
    return f"""Given the following SQL tables, your job is to determine the columns and tables that the question is referring to.

    {database_schema}

    ####

    Question: {question}

    """

def generate_schema_linking(messages, tokenizer, schema_model):
    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True, tokenize=True).to(schema_model.device)
    response = generate_schema(inputs, schema_model)

    if "Tables: " in response:
        response = response.split("Tables: ")[1]

    if ";" in response:
        response = response.split(";")[0]

    schema_linking_tables = re.sub(r'\s+', ' ', response).strip()
    return schema_linking_tables

def get_database_schema(db_uri, db_path, schema_linking_tables):
    database_schema = ""

    try:
        for table in schema_linking_tables:
            table = table.replace("**", "").replace("--", "").replace("'", "").strip()
            database_schema += get_table_schema_with_samples(db_uri, table)
            database_schema += "\n"
    except Exception:
        database_schema = ""
        print(f"Table not found {schema_linking_tables}")

    if not database_schema:
        all_tables = get_all_table_names(db_uri)
        for table in all_tables:
            columns_description = load_descriptions(db_path, table)
            # Accumulate every table's schema instead of overwriting it
            database_schema += get_table_schema_with_samples(db_uri, table, 0, columns_description)
            database_schema += "\n"

    return database_schema

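def generate_sql(messages, tokenizer, sql_model):
    # This definition shadows the earlier generate_sql(inputs, model) helper,
    # so generation is done inline here instead of calling generate_sql recursively
    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True, tokenize=True).to(sql_model.device)
    output_tokens = sql_model.generate(inputs, max_new_tokens=300, do_sample=False, pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id, stopping_criteria=[EosListStoppingCriteria()])
    response = tokenizer.decode(output_tokens[0][len(inputs[0]):], skip_special_tokens=True)

    if ";" in response:
        response = response.split(";")[0]

    if "```sql" in response:
        response = response.split("```sql")[1]

    response = re.sub(r'\s+', ' ', response).strip()
    return response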

def process_row(row, tokenizer, schema_model, sql_model):
    db_id = row['db_id']
    query = row['SQL']
    question = row['question']

    if pd.notna(row['evidence']):
        question += " Hint: " + row['evidence']

    db_uri = f"{BASE_DABATASES_DIR}/{db_id}/{db_id}.sqlite"
    db_path = f"{BASE_DABATASES_DIR}/{db_id}"

    table_names = get_all_table_names(db_uri)
    database_schema = ""

    for table_name in table_names:
        columns_description = load_descriptions(db_path, table_name)
        schema = get_table_schema_with_samples(db_uri, table_name, 0, columns_description)
        database_schema += schema + "\n"

    user_message = get_user_message(database_schema, question)
    messages = [{"role": "user", "content": user_message.strip()}]

    schema_linking_tables = generate_schema_linking(messages, tokenizer, schema_model)
    print(f"Predicted schema: {schema_linking_tables}")

    try:
        print(f"Original schema: {Parser(query).tables}")
    except Exception:
        pass

    schema_linking_tables = schema_linking_tables.split(", ")
    database_schema = get_database_schema(db_uri, db_path, schema_linking_tables)

    result = {
        "question": question,
        "db_id": db_id,
        "query": query,
        "database_schema": database_schema,
    }

    return result

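def main():
    # schema_model is deleted below, so it must be declared global here;
    # otherwise Python treats it as a local name and raises UnboundLocalError
    global schema_model

    results = []
    df = load_data(BASE_DATASET_DIR)

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Schema Linking"):
        result = process_row(row, tokenizer, schema_model, sql_model)
        results.append(result)

    # Free the schema-linking model before running SQL generation
    del schema_model
    flush()

    for index, row in tqdm(enumerate(results), total=len(results), desc="Generating SQL"):
        query = row['query']
        db_id = row['db_id']
        question = row['question']
        database_schema = row['database_schema']

        # Use the SQL-generation prompt from the SQL generation inference step,
        # not the schema-linking prompt returned by get_user_message
        user_message = f"""Given the following SQL tables, your job is to generate the Sqlite SQL query given the user's question.
Put your answer inside the ```sql and ``` tags.
{database_schema}
###
Question: {question}
"""
        messages = [{"role": "user", "content": user_message.strip()}]

        response = generate_sql(messages, tokenizer, sql_model)

        if "SELECT" not in response:
            # Fall back to selecting everything from the first table in the linked schema
            match = re.search(r"CREATE TABLE `([^`]+)`", database_schema)
            if match:
                response = f"SELECT * FROM `{match.group(1)}`"

        print(f"Predicted: {response}")
        print(f"Gold: {query}")

        append_item(response, db_id, index, OUTPUT_DIR)

if __name__ == "__main__":
    main()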
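
The table-name handling that is spread across `generate_schema_linking`, `process_row`, and `get_database_schema` can also be gathered into one small parser. The sketch below is illustrative only: the helper name `parse_linked_tables` is hypothetical, and the example response format (`Tables: ...; Columns: ...`) is an assumption based on the `"Tables: "` and `";"` splits used above.

```python
def parse_linked_tables(response: str) -> list[str]:
    """Extract a clean, de-duplicated list of table names from a schema-linking response."""
    # Keep the text between "Tables: " and the first semicolon
    if "Tables: " in response:
        response = response.split("Tables: ", 1)[1]
    response = response.split(";", 1)[0]

    tables = []
    for name in response.split(","):
        # Same clean-up as get_database_schema: drop Markdown bold, comment markers, quotes
        name = name.replace("**", "").replace("--", "").replace("'", "").strip()
        if name and name not in tables:
            tables.append(name)
    return tables

# Hypothetical model response
print(parse_linked_tables("Tables: singer, concert ; Columns: singer.name"))
# ['singer', 'concert']
```

Returning an empty list for an unparseable response makes it easy to trigger the full-schema fallback explicitly instead of relying on an exception.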

## Conclusion

In this project, we set out to improve text-to-SQL generation by developing OracleCoder, which combines a QLoRA+ ensemble approach with schema linking. By decomposing the text-to-SQL task into schema linking and SQL generation sub-tasks, we can fine-tune smaller, open-source language models efficiently, with the goal of approaching the performance of larger proprietary models. This addresses concerns around data privacy and cost and broadens access to high-quality text-to-SQL technology.

Our experiments with the Spider dataset, a challenging benchmark for text-to-SQL tasks, underscored the potential of our approach. By employing a two-phase fine-tuning strategy, incorporating LoRA+ for efficient adaptation, and exploring the use of noisy embeddings, we aimed to push the boundaries of what's possible with smaller models.

The successful training of multiple models with varying learning rates for the LoRA adapters, followed by the creation of an ensemble model, demonstrates a novel way to enhance model performance and reliability. This ensemble approach, adapted for LLMs, showcases a path forward for achieving high accuracy and robustness in text-to-SQL generation without the computational overhead typically associated with large ensembles.

As we conclude this project, it's clear that the journey to perfect text-to-SQL generation is ongoing. However, OracleCoder represents a significant step forward, offering a scalable, efficient, and accessible solution. Future work will explore further optimizations, additional ensemble strategies, and the integration of more diverse datasets to continue improving performance and generalizability.

This project not only contributes to the field of semantic parsing and text-to-SQL generation but also exemplifies the power of community-driven, open-source innovation in advancing the state of the art in natural language processing and database interaction.