# Fine-Tuning CodeLlama on Gretel's Synthetic Text-to-SQL Dataset with Amazon SageMaker JumpStart

This notebook demonstrates how to use the SageMaker Python SDK to fine-tune the pre-trained CodeLlama-13B model on [Gretel's synthetic text-to-sql](https://huggingface.co/datasets/gretelai/synthetic_text_to_sql) dataset. 

The training job in this notebook requires an `ml.g5.24xlarge` instance. If you receive an error stating that you've exceeded your quota, use the Service Quotas console to request an increase. For instructions, see [Requesting a quota increase](https://docs.aws.amazon.com/servicequotas/latest/userguide/request-quota-increase.html).

## Setup

### Install Necessary Packages
Please restart the kernel after executing the cell below for the first time.

In [None]:
%pip install --upgrade --quiet datasets transformers func_timeout

### Import Libraries

In [None]:
import json
import os
import pandas as pd
from datasets import load_dataset
from IPython.display import display
from IPython.display import HTML
from sagemaker import Session
from sagemaker.jumpstart.estimator import JumpStartEstimator
from sagemaker.s3 import S3Uploader
from transformers import AutoTokenizer
from tqdm.notebook import tqdm

### Select Model
Select your desired model ID. You can search for available models in the [Built-in Algorithms with pre-trained Model Table](https://sagemaker.readthedocs.io/en/stable/doc_utils/pretrainedmodels.html).

In [None]:
model_id = "meta-textgeneration-llama-codellama-13b"

## Dataset preparation

### Load Dataset
We use the [synthetic text-to-SQL dataset](https://huggingface.co/datasets/gretelai/synthetic_text_to_sql) provided by Gretel: a rich collection of high-quality synthetic Text-to-SQL samples, designed and generated using Gretel Navigator and released under the Apache 2.0 license.

The dataset contains 105,851 records, partitioned into 100,000 training and 5,851 test records. Of its ~23M tokens, ~12M are SQL tokens. The SQL queries span 100 distinct domains/verticals and a variety of SQL tasks (data definition, retrieval, manipulation, analytics and reporting) across a wide range of complexity levels, including subqueries, single joins, multiple joins, aggregations, window functions, and set operations.


In [None]:
gretel_text_to_sql = load_dataset("gretelai/synthetic_text_to_sql")
gretel_text_to_sql["train"].to_json("train.jsonl")
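To check the coverage described above on your own copy, you can tally records by domain and complexity level. The sketch below runs on a few hypothetical in-memory rows so it works offline; with the real data you would pass `gretel_text_to_sql["train"]` instead, since iterating a `datasets.Dataset` yields one dict per row. The column names `domain` and `sql_complexity` are assumed from the dataset card, and `coverage_summary` is a hypothetical helper.

```python
from collections import Counter
from typing import Iterable, Mapping


def coverage_summary(records: Iterable[Mapping[str, str]]) -> dict:
    """Count records per domain and per SQL complexity level."""
    domains = Counter()
    complexities = Counter()
    total = 0
    for record in records:
        # Column names assumed from the dataset card
        domains[record["domain"]] += 1
        complexities[record["sql_complexity"]] += 1
        total += 1
    return {
        "total": total,
        "num_domains": len(domains),
        "complexity_counts": dict(complexities),
    }


# Hypothetical sample rows; replace with gretel_text_to_sql["train"] for the real data
sample = [
    {"domain": "finance", "sql_complexity": "single join"},
    {"domain": "healthcare", "sql_complexity": "aggregation"},
    {"domain": "finance", "sql_complexity": "subqueries"},
]
summary = coverage_summary(sample)
print(summary)
```

On the full training split, `num_domains` should come out close to the 100 domains stated above.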

### Prompt Template
Create a template for using the data in an instruction format for the training job. This template will also be used during model inference.

In [None]:
template = {
    "prompt": (
        "[INST] Write a SQL query that answers the following question based on the given database schema and any additional information provided. Use SQLite syntax.\n\n"
        "[SCHEMA] {sql_context}\n\n"
        "[KNOWLEDGE] This is an '{sql_task_type}' task, commonly used for {sql_task_type_description}. In the domain of {domain}, which involves {domain_description}.\n\n"
        "[QUESTION] {sql_prompt}\n\n"
        "[/INST]"
    ),
    "completion": "```{sql}```\n\n\n{sql_explanation}\n",
}

with open("template.json", "w") as f:
    json.dump(template, f)

display(template)
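Before uploading, it can help to confirm that every `{placeholder}` in the template is a column in the dataset and to see what one rendered training example looks like. The sketch below repeats the template from the cell above, extracts its field names with `string.Formatter().parse`, and fills it from a hypothetical record (the values are made up; only the keys follow the dataset's column names). The helpers `placeholders` and `render` are illustrative, not part of JumpStart.

```python
import string

# Same template as the cell above
template = {
    "prompt": (
        "[INST] Write a SQL query that answers the following question based on the given database schema and any additional information provided. Use SQLite syntax.\n\n"
        "[SCHEMA] {sql_context}\n\n"
        "[KNOWLEDGE] This is an '{sql_task_type}' task, commonly used for {sql_task_type_description}. In the domain of {domain}, which involves {domain_description}.\n\n"
        "[QUESTION] {sql_prompt}\n\n"
        "[/INST]"
    ),
    "completion": "```{sql}```\n\n\n{sql_explanation}\n",
}


def placeholders(text: str) -> set:
    """Return the {field} names used in a str.format-style template."""
    return {name for _, name, _, _ in string.Formatter().parse(text) if name}


def render(record: dict) -> tuple:
    """Fill the prompt and completion templates from one dataset record."""
    return (
        template["prompt"].format(**record),
        template["completion"].format(**record),
    )


# Hypothetical record using the dataset's column names
record = {
    "sql_context": "CREATE TABLE accounts (id INTEGER, balance REAL);",
    "sql_task_type": "analytics and reporting",
    "sql_task_type_description": "generating reports, dashboards, and analytical insights",
    "domain": "finance",
    "domain_description": "hypothetical description of the finance domain",
    "sql_prompt": "What is the total balance across all accounts?",
    "sql": "SELECT SUM(balance) FROM accounts;",
    "sql_explanation": "Sums the balance column of the accounts table.",
}

# Fields the template needs that the record does not provide
required = placeholders(template["prompt"]) | placeholders(template["completion"])
missing = required - set(record)
print("Missing fields:", missing or "none")

prompt, completion = render(record)
print(prompt)
print(completion)
```

Running the same check against `gretel_text_to_sql["train"].column_names` catches a misspelled placeholder before the training job fails on it.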

### Upload Training Data and Template to S3

In [None]:
session = Session()
output_bucket = session.default_bucket()
local_data_file = "train.jsonl"
train_data_location = f"s3://{output_bucket}/gretel_text_to_sql"
S3Uploader.upload(local_data_file, train_data_location)
S3Uploader.upload("template.json", train_data_location)
print(f"Training data: {train_data_location}")

## Model Training

### Set Hyperparameters
Define the hyperparameters for fine-tuning the model. By default, JumpStart fine-tunes these models via domain adaptation, so you must enable instruction tuning through the `instruction_tuned` hyperparameter.


In [None]:
hyperparameters = {
    "instruction_tuned": "True",
    "epoch": "1",
    "learning_rate": "0.0002",
    "lora_r": "8",
    "lora_alpha": "32",
    "lora_dropout": "0.05",
    "int8_quantization": "False",
    "enable_fsdp": "True",
    "per_device_train_batch_size": "4",
    "per_device_eval_batch_size": "2",
    "max_train_samples": "-1",
    "max_val_samples": "-1",
    "max_input_length": "512",
    "validation_split_ratio": "0.2",
    "train_data_split_seed": "0",
}

# Set up the estimator for the selected model ID
estimator = JumpStartEstimator(
    model_id=model_id,
    environment={"accept_eula": "true"},  # Accept the EULA for gated models
    disable_output_compression=True,
    hyperparameters=hyperparameters,
    instance_type="ml.g5.24xlarge",  # Training instance required by this notebook
    sagemaker_session=session,
)
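To get a feel for the size of the job, the arithmetic below estimates the number of optimizer steps and the LoRA scaling factor implied by these hyperparameters. It assumes (not confirmed here) that JumpStart holds out `validation_split_ratio` of the training file for validation, runs one data-parallel worker per GPU on the 4-GPU `ml.g5.24xlarge`, and uses no gradient accumulation; treat the result as a rough estimate.

```python
# Back-of-the-envelope training arithmetic for the hyperparameters above.
# Assumptions: one worker per GPU, no gradient accumulation, and one
# training example per record (no packing or splitting of long records).
import math

train_records = 100_000          # Gretel train split size
validation_split_ratio = 0.2     # fraction assumed to be held out for validation
gpus = 4                         # ml.g5.24xlarge has 4 NVIDIA A10G GPUs
per_device_train_batch_size = 4
epochs = 1
lora_r, lora_alpha = 8, 32

train_samples = round(train_records * (1 - validation_split_ratio))
global_batch_size = gpus * per_device_train_batch_size
steps = math.ceil(train_samples / global_batch_size) * epochs

# In standard LoRA, the low-rank update is scaled by alpha / r
lora_scaling = lora_alpha / lora_r

print(f"Training samples: {train_samples}")
print(f"Global batch size: {global_batch_size}")
print(f"Optimizer steps: {steps}")
print(f"LoRA scaling factor: {lora_scaling}")
```

Because the update is multiplied by `lora_alpha / lora_r`, raising `lora_r` without also raising `lora_alpha` reduces the scale applied to each LoRA update.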

In [None]:
estimator.fit({"training": train_data_location})

Training may take several hours to complete. During this period, the Python kernel might time out and disconnect; however, the training job continues to run in SageMaker uninterrupted.

If you encounter a disconnection, you can still proceed with deploying your trained model by using the training job name. Follow these steps to locate and use your training job name:

1. Navigate to the AWS Management Console.
2. Select SageMaker.
3. Go to Training Jobs under the Training section.
4. Locate your specific training job and copy the training job name.

Once you have the training job name, you can use the following Python code to attach to the existing training job, monitor the logs, and deploy your model:

```python
from sagemaker.jumpstart.estimator import JumpStartEstimator

# Replace '<<training_job_name>>' with your actual training job name
training_job_name = '<<training_job_name>>'

# Attach to the existing training job
attached_estimator = JumpStartEstimator.attach(training_job_name)

# Optional: View logs
attached_estimator.logs()

# Deploy the trained model and keep a handle for the cells below
predictor = attached_estimator.deploy()
```
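If the console is inconvenient, the sketch below looks up the most recently created training job whose name contains a given substring, using the SageMaker `ListTrainingJobs` API through a boto3 client. The substring `"codellama-13b"` in the usage comment is an assumption about how JumpStart names the job, so check the console once to confirm the pattern. `latest_training_job_name` is a hypothetical helper; the client is passed in so the function can be exercised without AWS credentials.

```python
from typing import Any, Optional


def latest_training_job_name(sagemaker_client: Any, name_contains: str) -> Optional[str]:
    """Return the newest training job whose name contains `name_contains`, or None."""
    response = sagemaker_client.list_training_jobs(
        NameContains=name_contains,
        SortBy="CreationTime",
        SortOrder="Descending",
        MaxResults=1,
    )
    summaries = response.get("TrainingJobSummaries", [])
    return summaries[0]["TrainingJobName"] if summaries else None


# Usage with a real client (requires AWS credentials; the name filter is an assumption):
#   import boto3
#   name = latest_training_job_name(boto3.client("sagemaker"), "codellama-13b")
#   attached_estimator = JumpStartEstimator.attach(name)
```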

## Model Deployment and Invocation

Deploy the fine-tuned model to an endpoint directly from the estimator and invoke the model.

In [None]:
predictor = estimator.deploy()

In [None]:
def prompt_and_predict(prompt_input, parameters):
    """
    Generates a SQL query based on the given prompt input using a language model.

    Args:
        prompt_input (dict): A dictionary containing the prompt input with the following keys:
            - schema (str): The database schema.
            - question (str): The question to be answered.
            - knowledge (str): Additional knowledge or context.
            - database (str): Name of the database.
        parameters (dict): Additional parameters for model prediction.

    Returns:
        str: The generated SQL query along with the database name, separated by '----- bird -----'.
    """
    # Extract inputs from the prompt_input dictionary
    sql_context = prompt_input["schema"]  # Extract schema
    sql_prompt = prompt_input["question"]  # Extract question
    knowledge = prompt_input["knowledge"]  # Extract knowledge
    database = prompt_input["database"]  # Extract database

    # Construct the prompt string with schema, knowledge, and question
    prompt = (
        f"[INST] Write a SQL query that answers the following question based on the given database schema and any additional information provided. Use SQLite syntax.\n\n"
        f"[SCHEMA] {sql_context}\n\n"
        f"[KNOWLEDGE] {knowledge}\n\n"
        f"[QUESTION] {sql_prompt}\n\n"
        f"[/INST]"
    )

    # Prepare payload for prediction with prompt and parameters
    payload = {"inputs": prompt, "parameters": parameters}

    # Get prediction from the model
    response = predictor.predict(payload)
    response = response[0] if isinstance(response, list) else response

    # Split the generated text into code blocks
    code_blocks = response["generated_text"].strip().split("```")

    # Keep the last code block that contains a SELECT statement
    sql_query = ""
    for code_block in code_blocks:
        if "SELECT" in code_block:
            sql_query = code_block

    # Clean up the SQL query
    sql_query = sql_query.replace("\n", " ").strip()

    # Construct the output string with SQL query and database
    output = f"{sql_query}\t----- bird -----\t{database}"

    return output
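Because the parsing logic in `prompt_and_predict` depends only on the generated text, it can be checked offline before you spend endpoint time. The sketch below reproduces that logic as two standalone functions (`extract_sql` and `to_bird_line` are hypothetical names, not part of any library) and applies them to a made-up model response shaped like the training completion template.

```python
# Offline sketch of the post-processing in prompt_and_predict; no endpoint needed.
BIRD_SEPARATOR = "\t----- bird -----\t"


def extract_sql(generated_text: str) -> str:
    """Return the last ```-delimited block containing SELECT, flattened to one line."""
    sql_query = ""
    for code_block in generated_text.strip().split("```"):
        if "SELECT" in code_block:
            sql_query = code_block
    return sql_query.replace("\n", " ").strip()


def to_bird_line(generated_text: str, database: str) -> str:
    """Format one prediction the same way prompt_and_predict does."""
    return f"{extract_sql(generated_text)}{BIRD_SEPARATOR}{database}"


# Hypothetical model output shaped like the completion template
fake_output = (
    "```SELECT name\nFROM schools\nWHERE county = 'Alameda'```\n\n\n"
    "Selects school names in one county.\n"
)
line = to_bird_line(fake_output, "california_schools")
print(line)
```

The `"SELECT"` check is case-sensitive, so a response that writes `select` in lowercase produces an empty query string.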

In [None]:
# Define parameters for model prediction
parameters = {
    "max_new_tokens": 256,  # Maximum number of tokens to generate
    "top_p": 0.9,  # Top-p sampling probability
    "temperature": 0.1,  # Sampling temperature
    "decoder_input_details": True,  # Include decoder input details
    "details": True,  # Include additional details
}

# Load prompts from a JSON file
list_of_prompts = pd.read_json(
    "https://gretel-public-website.s3.us-west-2.amazonaws.com/bird-bench/prompts-dev.json"
)

# Iterate over each prompt and predict a response
responses = []
for question in tqdm(list_of_prompts.columns, desc="Prompting model"):
    prompt_values = list_of_prompts[question].to_dict()
    response = prompt_and_predict(prompt_values, parameters)
    responses.append(response)

# Convert responses to DataFrame
df = pd.DataFrame(responses, columns=["response"])

In [None]:
# Define the directory path for saving predicted SQL responses
predicted_sql_dir = f"{model_id}-FT-gretel-text-to-sql"

# Create the directory if it doesn't exist
os.makedirs(predicted_sql_dir, exist_ok=True)

# Define the file path for saving the JSON file
predicted_sql_path = f"{predicted_sql_dir}/predict_dev.json"

# Save the responses DataFrame to a JSON file
df["response"].to_json(predicted_sql_path, orient="index", indent=4)

# Print the path where the output is saved
print(f"Output saved to {predicted_sql_path}")
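The saved file maps each question's row index (as a string key, which is how `Series.to_json(orient="index")` writes it) to a `"<sql>\t----- bird -----\t<database>"` string. The sketch below writes a file with that layout using hypothetical predictions and splits each entry back into its SQL and database name, which is a quick sanity check before running the BIRD evaluation. It uses only the standard library; `load_predictions` is a hypothetical helper and only approximates what the evaluation script expects.

```python
import json
import os
import tempfile

BIRD_SEPARATOR = "\t----- bird -----\t"

# Hypothetical predictions in the same layout as predict_dev.json
predictions = {
    "0": f"SELECT COUNT(*) FROM schools{BIRD_SEPARATOR}california_schools",
    "1": f"SELECT MAX(amount) FROM trans{BIRD_SEPARATOR}financial",
}


def load_predictions(path: str) -> list:
    """Read a predict_dev.json file and split each entry into (sql, database)."""
    with open(path) as f:
        data = json.load(f)
    # Keys are string indices; sort numerically so the order matches the dev set
    return [tuple(data[key].split(BIRD_SEPARATOR)) for key in sorted(data, key=int)]


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "predict_dev.json")
    with open(path, "w") as f:
        json.dump(predictions, f, indent=4)
    parsed = load_predictions(path)

print(parsed)
```

Pointing `load_predictions` at `predicted_sql_path` confirms that every entry contains exactly one separator.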

## Run the BIRD Benchmark

To evaluate the generated SQL queries with the BIRD benchmark, follow the steps in the [BIRD repository](https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/bird):

1. Clone the repository: https://github.com/AlibabaResearch/DAMO-ConvAI.git
2. Download the dev dataset from https://bird-bench.github.io/
3. Run the evaluation script.


In [None]:
import os

# Clone the repository if the directory does not exist
if not os.path.isdir("DAMO-ConvAI"):
    !git clone https://github.com/AlibabaResearch/DAMO-ConvAI.git

base_bird_dir = "DAMO-ConvAI/bird/llm"

if not os.path.exists(f"{base_bird_dir}/data/dev.zip"):
    !wget -P {base_bird_dir}/data https://bird-bench.oss-cn-beijing.aliyuncs.com/dev.zip --no-check-certificate
    # Unzip the downloaded file in the same directory
    !unzip {base_bird_dir}/data/dev.zip -d {base_bird_dir}/data
    !unzip {base_bird_dir}/data/dev/dev_databases.zip -d {base_bird_dir}/data/dev
    # Copy the ground-truth SQL to dev_gold.sql, the file name the evaluation script expects
    !cp {base_bird_dir}/data/dev/dev.sql {base_bird_dir}/data/dev/dev_gold.sql

In [None]:
import shutil
import os

# Optional: evaluate a previously saved predictions file instead of the
# predictions generated above. The cell does nothing if the file is absent.
source_file = "codellama-13B-FT-Gretel.json"

if os.path.exists(source_file):
    # Use the file name (without extension) as the predictions directory
    predicted_sql_dir, _ = os.path.splitext(source_file)
    target_file = os.path.join(predicted_sql_dir, "predict_dev.json")

    # Ensure the target directory exists, create if it doesn't
    os.makedirs(predicted_sql_dir, exist_ok=True)

    # Copy the source file to the path the evaluation script reads
    shutil.copy(source_file, target_file)

In [None]:
import subprocess


# Define the common arguments
eval_args = [
    "--db_root_path",
    f"{base_bird_dir}/data/dev/dev_databases/",
    "--predicted_sql_path",
    predicted_sql_dir + "/",
    "--data_mode",
    "dev",
    "--ground_truth_path",
    f"{base_bird_dir}/data/dev/",
    "--num_cpus",
    "4",
    "--diff_json_path",
    f"{base_bird_dir}/data/dev/dev.json",
]

print("Execution Accuracy (EX) metric on Dev with Knowledge")
command = ["python3", "-u", f"{base_bird_dir}/src/evaluation.py"] + eval_args
output = subprocess.run(command, capture_output=True, text=True)
print(output.stdout)

print("Valid Efficiency Score (VES) metric on Dev with Knowledge")
command = [
    "python3",
    "-u",
    f"{base_bird_dir}/src/evaluation_ves.py",
] + eval_args
output = subprocess.run(command, capture_output=True, text=True)
print(output.stdout)
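Execution Accuracy (EX) judges a prediction by the rows it returns, not by its text. The sketch below illustrates that idea on a hypothetical in-memory SQLite table: a prediction scores 1 when the set of rows it returns equals the set returned by the gold query (so row order and duplicates are ignored), and 0 otherwise, including when it fails to execute. This is a simplified illustration based on our understanding of the metric, not the official script; the official `evaluation.py` additionally handles per-query timeouts and parallel execution. `execution_match` is a hypothetical helper.

```python
import sqlite3


def execution_match(conn: sqlite3.Connection, predicted_sql: str, gold_sql: str) -> int:
    """Return 1 if both queries produce the same set of rows, else 0."""
    try:
        predicted_rows = conn.execute(predicted_sql).fetchall()
    except sqlite3.Error:
        return 0  # A query that fails to execute counts as incorrect
    gold_rows = conn.execute(gold_sql).fetchall()
    return int(set(predicted_rows) == set(gold_rows))


# Hypothetical toy database standing in for one BIRD dev database
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE schools (name TEXT, county TEXT)")
conn.executemany(
    "INSERT INTO schools VALUES (?, ?)",
    [("A", "Alameda"), ("B", "Alameda"), ("C", "Fresno")],
)

# (predicted, gold) pairs: same rows in a different order, wrong rows, invalid column
pairs = [
    ("SELECT name FROM schools WHERE county = 'Alameda'",
     "SELECT name FROM schools WHERE county = 'Alameda' ORDER BY name"),
    ("SELECT COUNT(*) FROM schools",
     "SELECT COUNT(name) FROM schools WHERE county = 'Fresno'"),
    ("SELECT nam FROM schools", "SELECT name FROM schools"),
]
scores = [execution_match(conn, pred, gold) for pred, gold in pairs]
ex = 100 * sum(scores) / len(scores)
print(scores, f"EX = {ex:.2f}%")
```

Comparing result sets means two syntactically different queries can both count as correct, which is why EX is usually preferred over exact string matching for text-to-SQL.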

## Clean up
Make sure to clean up resources to avoid unnecessary charges.

In [None]:
predictor.delete_model()
predictor.delete_endpoint()