# 1. Setup Development Environment


In [18]:
!pip install -q -U trl
!pip install -q -U sagemaker
!pip install -q -U "datasets[s3]"
!pip install -q -U "huggingface_hub[cli]"

In [5]:
!pip show datasets
!pip show pandas

Name: datasets
Version: 2.20.0
Summary: HuggingFace community-driven open-source library of datasets
Home-page: https://github.com/huggingface/datasets
Author: HuggingFace Inc.
Author-email: thomas@huggingface.co
License: Apache 2.0
Location: /home/ec2-user/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages
Requires: aiohttp, dill, filelock, fsspec, huggingface-hub, multiprocess, numpy, packaging, pandas, pyarrow, pyarrow-hotfix, pyyaml, requests, tqdm, xxhash
Required-by: trl
Name: pandas
Version: 1.5.3
Summary: Powerful data structures for data analysis, time series, and statistics
Home-page: https://pandas.pydata.org
Author: The Pandas Development Team
Author-email: pandas-dev@python.org
License: BSD-3-Clause
Location: /home/ec2-user/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages
Requires: numpy, python-dateutil, pytz
Required-by: autovizwidget, bokeh, datasets, hdijupyterutils, nvgpu, sagemaker, seaborn, shap, smclarify, sparkmagic, statsmodels


In [6]:
# Optional alternative to the interactive login in the next cell: authenticate
# from the shell (replace YOUR_TOKEN; never commit a real token).
# !huggingface-cli login --token YOUR_TOKEN

In [5]:
from huggingface_hub import notebook_login

# Render the Hugging Face login widget in the notebook.
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [29]:
import os

import sagemaker
import boto3

# Bucket used for data and artifacts. Leave as None to use the default bucket
# of the SageMaker session, or set a bucket name explicitly.
sagemaker_session_bucket = None
if sagemaker_session_bucket is None:
    sagemaker_session_bucket = sagemaker.Session().default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    # No execution role could be resolved from the current identity, so look
    # the role up by name. The name is account-specific: override it through
    # the SAGEMAKER_EXECUTION_ROLE_NAME environment variable when needed.
    iam = boto3.client("iam")
    fallback_role_name = os.environ.get(
        "SAGEMAKER_EXECUTION_ROLE_NAME",
        "AmazonSageMaker-ExecutionRole-20230112T181165",
    )
    role = iam.get_role(RoleName=fallback_role_name)["Role"]["Arn"]

# Session bound to the chosen bucket; used by the rest of the notebook.
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

sagemaker role arn: arn:aws:iam::395271362395:role/SagemakerStudioDemoSagema-SageMakerExecutionRole78-5I33I083KE6P
sagemaker bucket: sagemaker-us-east-1-395271362395
sagemaker session region: us-east-1


# 2. Create and prepare the dataset

In [30]:
from datasets import load_dataset

# Load the "train" split of the gretelai/synthetic_text_to_sql dataset.
dataset = load_dataset("gretelai/synthetic_text_to_sql", split="train")
# Bare expression: displays the dataset summary (features and row count).
dataset

Dataset({
    features: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql', 'sql_explanation'],
    num_rows: 100000
})

In [31]:
# Convert to a pandas DataFrame for a tabular preview of the first 10 rows.
df = dataset.to_pandas()
df.head(10)

Unnamed: 0,id,domain,domain_description,sql_complexity,sql_complexity_description,sql_task_type,sql_task_type_description,sql_prompt,sql_context,sql,sql_explanation
0,5097,forestry,Comprehensive data on sustainable forest manag...,single join,"only one join (specify inner, outer, cross)",analytics and reporting,"generating reports, dashboards, and analytical...",What is the total volume of timber sold by eac...,"CREATE TABLE salesperson (salesperson_id INT, ...","SELECT salesperson_id, name, SUM(volume) as to...","Joins timber_sales and salesperson tables, gro..."
1,5098,defense industry,"Defense contract data, military equipment main...",aggregation,"aggregation functions (COUNT, SUM, AVG, MIN, M...",analytics and reporting,"generating reports, dashboards, and analytical...",List all the unique equipment types and their ...,CREATE TABLE equipment_maintenance (equipment_...,"SELECT equipment_type, SUM(maintenance_frequen...",This query groups the equipment_maintenance ta...
2,5099,marine biology,"Comprehensive data on marine species, oceanogr...",basic SQL,basic SQL with a simple select statement,analytics and reporting,"generating reports, dashboards, and analytical...",How many marine species are found in the South...,"CREATE TABLE marine_species (name VARCHAR(50),...",SELECT COUNT(*) FROM marine_species WHERE loca...,This query counts the number of marine species...
3,5100,financial services,Detailed financial data including investment s...,aggregation,"aggregation functions (COUNT, SUM, AVG, MIN, M...",analytics and reporting,"generating reports, dashboards, and analytical...",What is the total trade value and average pric...,"CREATE TABLE trade_history (id INT, trader_id ...","SELECT trader_id, stock, SUM(price * quantity)...",This query calculates the total trade value an...
4,5101,energy,Energy market data covering renewable energy s...,window functions,"window functions (e.g., ROW_NUMBER, LEAD, LAG,...",analytics and reporting,"generating reports, dashboards, and analytical...",Find the energy efficiency upgrades with the h...,"CREATE TABLE upgrades (id INT, cost FLOAT, typ...","SELECT type, cost FROM (SELECT type, cost, ROW...",The SQL query uses the ROW_NUMBER function to ...
5,5102,defense operations,"Defense data on military innovation, peacekeep...",basic SQL,basic SQL with a simple select statement,analytics and reporting,"generating reports, dashboards, and analytical...",What is the total spending on humanitarian ass...,CREATE SCHEMA if not exists defense; CREATE TA...,SELECT SUM(spending) FROM defense.eu_humanitar...,This SQL query calculates the total spending o...
6,5103,aquaculture,"Aquatic farming data, fish stock management, o...",single join,"only one join (specify inner, outer, cross)",analytics and reporting,"generating reports, dashboards, and analytical...",What is the average water temperature for each...,"CREATE TABLE SpeciesWaterTemp (SpeciesID int, ...","SELECT SpeciesName, AVG(WaterTemp) as AvgTemp ...",This query calculates the average water temper...
7,5104,nonprofit operations,"Donation records, program outcomes, volunteer ...",basic SQL,basic SQL with a simple select statement,data manipulation,"inserting, updating, or deleting records",Delete a program's outcome data,"CREATE TABLE Program_Outcomes (id INT, program...",DELETE FROM Program_Outcomes WHERE program_id ...,This query removes the record with a program_i...
8,5105,public transportation,"Extensive data on route planning, fare collect...",basic SQL,basic SQL with a simple select statement,analytics and reporting,"generating reports, dashboards, and analytical...",Find the total fare collected from passengers ...,CREATE TABLE bus_routes (route_name VARCHAR(50...,SELECT SUM(fare) FROM bus_routes WHERE route_n...,This SQL query calculates the total fare colle...
9,5106,real estate,Real estate data on inclusive housing policies...,basic SQL,basic SQL with a simple select statement,analytics and reporting,"generating reports, dashboards, and analytical...",What is the average property size in inclusive...,CREATE TABLE Inclusive_Housing (Property_ID IN...,SELECT AVG(Property_Size) FROM Inclusive_Housi...,The SQL query calculates the average property ...


In [32]:
# Shuffle with a fixed seed and keep the first 12,500 rows as the working subset.
# Note: this rebinds `dataset`, so the full dataset is no longer available under that name.
dataset = dataset.shuffle(seed=1234).select(range(12500))

In [33]:
def generate_prompt(datum):
    """Format one dataset record as a two-turn (user/model) training prompt.

    Args:
        datum: mapping providing the keys "sql_context" (schema), "sql_prompt"
            (natural-language question) and "sql" (target query).

    Returns:
        The prompt text: a user turn holding the instruction, the schema and
        the question, followed by a model turn holding the SQL answer.
    """
    # NOTE(review): the instruction wording ("an text to SQL ...") is kept
    # verbatim on purpose; this text becomes training data, so any edit here
    # changes what the model is trained on.
    user_turn = (
        "<start_of_turn>user\n"
        "You are an text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.\n"
        f"SCHEMA: {datum['sql_context']}\n"
        "\n"
        f"{datum['sql_prompt']}<end_of_turn>\n"
    )
    model_turn = f"<start_of_turn>model\n{datum['sql']}<end_of_turn>"
    return user_turn + model_turn

In [34]:
# Every column except the three that the prompt template reads gets dropped.
columns_to_keep = ('sql', 'sql_context', 'sql_prompt')
columns_to_remove = list(dataset.features)
for k in columns_to_keep:
    # list.remove raises ValueError if an expected column is missing
    columns_to_remove.remove(k)
print(f'remove columns: {columns_to_remove}')

remove columns: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_explanation']


In [35]:
# Build the formatted prompt for every row, attach it as the "prompt" column,
# then drop the columns that are no longer needed.
prompt_column = list(map(generate_prompt, dataset))
dataset = (
    dataset
    .add_column("prompt", prompt_column)
    .remove_columns(columns_to_remove)
)
dataset

Dataset({
    features: ['sql_prompt', 'sql_context', 'sql', 'prompt'],
    num_rows: 12500
})

In [36]:
# Hold out 20% of the rows as the test split. A fixed seed keeps the split
# reproducible across kernel restarts; without it the train/test membership
# changes on every run.
dataset = dataset.train_test_split(test_size=0.2, seed=1234)
dataset

DatasetDict({
    train: Dataset({
        features: ['sql_prompt', 'sql_context', 'sql', 'prompt'],
        num_rows: 10000
    })
    test: Dataset({
        features: ['sql_prompt', 'sql_context', 'sql', 'prompt'],
        num_rows: 2500
    })
})

In [37]:
# Print one formatted training example as a sanity check of the prompt template.
print(dataset["train"][345]['prompt'])

<start_of_turn>user
You are an text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.
SCHEMA: CREATE TABLE volunteers (id INT, name VARCHAR(50), reg_date DATE, location VARCHAR(30)); INSERT INTO volunteers (id, name, reg_date, location) VALUES (1, 'Alex', '2023-02-01', 'urban'), (2, 'Bella', '2023-01-15', 'rural'), (3, 'Charlie', '2023-03-05', 'suburban'), (4, 'Diana', '2022-07-20', 'rural'), (5, 'Eli', '2022-09-01', 'urban');

What is the total number of volunteers from urban areas who have volunteered in the last 6 months?<end_of_turn>
<start_of_turn>model
SELECT COUNT(*) FROM volunteers WHERE location = 'urban' AND reg_date >= DATEADD(month, -6, GETDATE());<end_of_turn>


In [38]:
# Destination prefix inside the SageMaker session's default bucket
bucket_name = sess.default_bucket()
training_input_path = f"s3://{bucket_name}/datasets/text-to-sql"

# Write both splits as JSON directly to S3
train_uri = f"{training_input_path}/train_dataset.json"
test_uri = f"{training_input_path}/test_dataset.json"
dataset["train"].to_json(train_uri, orient="records")
dataset["test"].to_json(test_uri, orient="records")

# Report the upload location, plus an S3 console link to the same prefix
s3_prefix = training_input_path.split('/', 3)[-1]
print("Training data uploaded to:")
print(train_uri)
print(
    f"https://s3.console.aws.amazon.com/s3/buckets/{bucket_name}/?region={sess.boto_region_name}&prefix={s3_prefix}/"
)

Creating json from Arrow format:   0%|          | 0/10 [00:00<?, ?ba/s]

Creating json from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Training data uploaded to:
s3://sagemaker-us-east-1-395271362395/datasets/text-to-sql/train_dataset.json
https://s3.console.aws.amazon.com/s3/buckets/sagemaker-us-east-1-395271362395/?region=us-east-1&prefix=datasets/text-to-sql/


# 3. Fine-Tune Gemma 2 with QLoRA on Amazon SageMaker


In [51]:
from dataclasses import dataclass, field, fields, asdict
from trl import SFTConfig


@dataclass
class PartialTrainingArguments:
    """Subset of trainer arguments that this notebook sets explicitly.

    All fields are required (no defaults). Each field name is expected to also
    be a field of ``trl.SFTConfig``; the validation code after the class
    definitions in this cell asserts that.
    """

    num_train_epochs: int
    per_device_train_batch_size: int
    gradient_accumulation_steps: int
    gradient_checkpointing: bool
    optim: str
    logging_steps: int
    # Strategy name such as "epoch", hence a string rather than an integer.
    save_strategy: str
    learning_rate: float
    bf16: bool
    tf32: bool
    max_grad_norm: float
    warmup_ratio: float
    lr_scheduler_type: str
    report_to: str
    output_dir: str
    # SFTTrainer config
    dataset_text_field: str
    max_seq_length: int


@dataclass
class Hyperparameters(PartialTrainingArguments):
    """Trainer arguments plus job-level settings, passed to the estimator as hyperparameters.

    The extra fields are presumably read by the entry-point training script
    (not shown in this notebook) -- confirm against that script.
    """

    # path of the training dataset file as seen by the training job
    dataset_path: str
    model_id: str
    use_qlora: bool
    merge_adapters: bool

    def to_dict(self) -> dict:
        """Return all fields, inherited ones included, as a plain dict."""
        return asdict(self)


# Validate training arguments: every field name declared on
# PartialTrainingArguments must also be a field of SFTConfig.
training_args_fields = set(f.name for f in fields(SFTConfig))
partial_training_args_fields = set(
    f.name for f in fields(PartialTrainingArguments)
)
is_subset = partial_training_args_fields <= training_args_fields
assert is_subset, "All fields in PartialTrainingArguments should be in SFTConfig"

In [52]:
# Build the hyperparameter dict handed to the SageMaker estimator.
hyperparameters = Hyperparameters(
    ### SageMaker / job-level arguments ###
    # train_dataset.json inside the "training" input channel of the job
    dataset_path="/opt/ml/input/data/training/train_dataset.json",
    # NOTE(review): the training job is named "gemma-9b-it-..." while this is
    # the checkpoint id without the "-it" suffix -- confirm which is intended.
    model_id="google/gemma-2-9b",
    use_qlora=True,
    merge_adapters=True,
    ### TrainingArguments ###
    num_train_epochs=3,  # number of training epochs
    per_device_train_batch_size=1,  # batch size per device during training
    gradient_accumulation_steps=4,  # number of batches whose gradients are accumulated before each optimizer update
    gradient_checkpointing=True,  # use gradient checkpointing to save memory
    optim="paged_adamw_8bit",  # use paged_adamw_8bit optimizer
    logging_steps=10,  # log every 10 steps
    save_strategy="epoch",  # save checkpoint every epoch
    learning_rate=2e-4,  # learning rate, based on QLoRA paper
    bf16=True,  # use bfloat16 precision
    tf32=True,  # use tf32 precision
    max_grad_norm=0.3,  # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,  # warmup ratio based on QLoRA paper
    # NOTE(review): with a plain "constant" schedule the warmup ratio above may
    # have no effect ("constant_with_warmup" is the variant that applies it) -- confirm.
    lr_scheduler_type="constant",  # use constant learning rate scheduler
    report_to="tensorboard",  # report metrics to tensorboard
    output_dir="/tmp/tun",  # Temporary output directory for model checkpoints
    ### SFTTrainer config ###
    dataset_text_field="prompt",  # dataset column that holds the formatted prompt
    max_seq_length=1024,  # maximum sequence length setting passed to the trainer
).to_dict()
print(hyperparameters)

{'num_train_epochs': 3, 'per_device_train_batch_size': 1, 'gradient_accumulation_steps': 4, 'gradient_checkpointing': True, 'optim': 'paged_adamw_8bit', 'logging_steps': 10, 'save_strategy': 'epoch', 'learning_rate': 0.0002, 'bf16': True, 'tf32': True, 'max_grad_norm': 0.3, 'warmup_ratio': 0.03, 'lr_scheduler_type': 'constant', 'report_to': 'tensorboard', 'output_dir': '/tmp/tun', 'dataset_text_field': 'prompt', 'max_seq_length': 1024, 'dataset_path': '/opt/ml/input/data/training/train_dataset.json', 'model_id': 'google/gemma-2-9b', 'use_qlora': True, 'merge_adapters': True}


In [59]:
from huggingface_hub import get_token
from sagemaker.huggingface import HuggingFace

# define Training Job Name
job_name = "gemma-9b-it-text-to-sql"

# SECURITY: never embed an access token in the notebook. The token is resolved
# at run time instead: get_token() returns the value of the HF_TOKEN
# environment variable if set, otherwise the token stored by a previous
# notebook_login() / `huggingface-cli login`.
# Any token that was previously committed in this cell must be treated as
# leaked and revoked in the Hugging Face account settings.
hf_token = get_token()
if not hf_token:
    raise RuntimeError(
        "No Hugging Face token found. Run notebook_login() or set the "
        "HF_TOKEN environment variable before creating the estimator."
    )

# create the Estimator
huggingface_estimator = HuggingFace(
    entry_point="run_sft.py",  # train script
    source_dir="./scripts",  # directory which includes all the files needed for training
    instance_type="ml.g5.8xlarge",  # instance type used for the training job
    instance_count=1,  # the number of instances used for training
    max_run=2 * 24 * 3600,  # maximum runtime in seconds (2 days * 24 hours * 3600 seconds)
    base_job_name=job_name,  # the name prefix of the training job
    role=role,  # IAM role used in training job to access AWS resources, e.g. S3
    volume_size=300,  # the size of the EBS volume in GB
    # NOTE(review): confirm that this container's transformers release supports
    # the chosen model, or that the training scripts upgrade it.
    transformers_version="4.36",  # the transformers version used in the training job
    pytorch_version="2.1",  # the pytorch version used in the training job
    py_version="py310",  # the python version used in the training job
    hyperparameters=hyperparameters,  # the hyperparameters passed to the training job
    disable_output_compression=True,  # do not compress output, to save training time and cost
    environment={
        "HUGGINGFACE_HUB_CACHE": "/tmp/.cache",  # set env variable to cache models in /tmp
        "HF_TOKEN": hf_token,  # huggingface token to access gated models
    },
)

In [None]:
# define a data input dictionary with our uploaded s3 uris
# (the key "training" is the input channel name)
data = {"training": training_input_path}

# start the training job with our uploaded datasets as input;
# wait=True keeps the cell running while the job executes (its logs stream below)
huggingface_estimator.fit(data, wait=True)

INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.
INFO:sagemaker:Creating training-job with name: gemma-9b-it-text-to-sql-2024-07-18-05-52-35-160


2024-07-18 05:52:35 Starting - Starting the training job...
2024-07-18 05:52:51 Starting - Preparing the instances for training...
2024-07-18 05:53:25 Downloading - Downloading input data...
2024-07-18 05:53:40 Downloading - Downloading the training image.....................
2024-07-18 05:57:08 Training - Training image download completed. Training in progress..[34mbash: cannot set terminal process group (-1): Inappropriate ioctl for device[0m
[34mbash: no job control in this shell[0m
[34m2024-07-18 05:57:33,139 sagemaker-training-toolkit INFO     Imported framework sagemaker_pytorch_container.training[0m
[34m2024-07-18 05:57:33,156 sagemaker-training-toolkit INFO     No Neurons detected (normal if no neurons installed)[0m
[34m2024-07-18 05:57:33,167 sagemaker_pytorch_container.training INFO     Block until all host DNS lookups succeed.[0m
[34m2024-07-18 05:57:33,169 sagemaker_pytorch_container.training INFO     Invoking user training script.[0m
[34m2024-07-18 05:57:34,61