# 1. Setup Development Environment


!pip install -q -U trl
!pip install -q -U sagemaker
!pip install -q -U "datasets[s3]"
!pip install -q -U "huggingface_hub[cli]"


In [5]:
!pip show datasets
!pip show pandas

Name: datasets
Version: 2.20.0
Summary: HuggingFace community-driven open-source library of datasets
Home-page: https://github.com/huggingface/datasets
Author: HuggingFace Inc.
Author-email: thomas@huggingface.co
License: Apache 2.0
Location: /home/ec2-user/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages
Requires: aiohttp, dill, filelock, fsspec, huggingface-hub, multiprocess, numpy, packaging, pandas, pyarrow, pyarrow-hotfix, pyyaml, requests, tqdm, xxhash
Required-by: trl
Name: pandas
Version: 1.5.3
Summary: Powerful data structures for data analysis, time series, and statistics
Home-page: https://pandas.pydata.org
Author: The Pandas Development Team
Author-email: pandas-dev@python.org
License: BSD-3-Clause
Location: /home/ec2-user/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages
Requires: numpy, python-dateutil, pytz
Required-by: autovizwidget, bokeh, datasets, hdijupyterutils, nvgpu, sagemaker, seaborn, shap, smclarify, sparkmagic, statsmodels


In [1]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [11]:
import sagemaker
import boto3

sess = sagemaker.Session()
sagemaker_session_bucket = None
if sagemaker_session_bucket is None and sess is not None:
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client("iam")
    role = iam.get_role(RoleName="AmazonSageMaker-ExecutionRole-20230112T181165")[
        "Role"
    ]["Arn"]

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

sagemaker role arn: arn:aws:iam::395271362395:role/SagemakerStudioDemoSagema-SageMakerExecutionRole78-5I33I083KE6P
sagemaker bucket: sagemaker-us-east-1-395271362395
sagemaker session region: us-east-1


# 2. Create and prepare the dataset


In [15]:
from datasets import load_dataset

dataset = load_dataset("gretelai/synthetic_text_to_sql", split="train")
dataset = dataset.shuffle(seed=1234).select(range(12500))
dataset

Dataset({
    features: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql', 'sql_explanation'],
    num_rows: 12500
})

In [16]:
df = dataset.to_pandas()
df.head(10)

Unnamed: 0,id,domain,domain_description,sql_complexity,sql_complexity_description,sql_task_type,sql_task_type_description,sql_prompt,sql_context,sql,sql_explanation
0,65582,disability services,Comprehensive data on disability accommodation...,basic SQL,basic SQL with a simple select statement,data manipulation,"inserting, updating, or deleting records",Update the budget for the 'ASL Interpreter' se...,"CREATE TABLE Regions (RegionID INT, RegionName...",UPDATE SupportServices SET Budget = 16000 WHER...,This query updates the budget for the 'ASL Int...
1,83180,climate change,"Climate change data on climate mitigation, cli...",basic SQL,basic SQL with a simple select statement,analytics and reporting,"generating reports, dashboards, and analytical...",Find the intersection of mitigation and adapta...,"CREATE TABLE mitigation (id INT PRIMARY KEY, c...","SELECT m.action FROM mitigation m, adaptation ...",This SQL query identifies the intersection of ...
2,90518,marine biology,"Comprehensive data on marine species, oceanogr...",basic SQL,basic SQL with a simple select statement,analytics and reporting,"generating reports, dashboards, and analytical...",List all marine species with a conservation st...,"CREATE TABLE species (id INT, name VARCHAR(255...",SELECT name FROM species WHERE conservation_st...,The SQL query filters the species table for ro...
3,42346,rural development,"Agricultural innovation metrics, rural infrast...",basic SQL,basic SQL with a simple select statement,analytics and reporting,"generating reports, dashboards, and analytical...",Find the minimum budget for agricultural innov...,"CREATE TABLE agricultural_innovation (id INT, ...",SELECT MIN(budget) FROM agricultural_innovation;,The SQL query calculates the minimum budget fo...
4,86672,retail,"Retail data on circular supply chains, ethical...",single join,"only one join (specify inner, outer, cross)",analytics and reporting,"generating reports, dashboards, and analytical...",What is the maximum price of a product sold by...,"CREATE TABLE vendors(vendor_id INT, vendor_nam...",SELECT MAX(transactions.price) FROM transactio...,The SQL query calculates the maximum price of ...
5,65425,arctic research,"In-depth data on climate change, biodiversity,...",aggregation,"aggregation functions (COUNT, SUM, AVG, MIN, M...",analytics and reporting,"generating reports, dashboards, and analytical...",What is the average temperature recorded in ea...,CREATE TABLE WeatherData (Station VARCHAR(255)...,"SELECT Station, AVG(Temperature) FROM WeatherD...",This SQL query calculates the average temperat...
6,62717,arts and culture,"Audience demographics, event attendance, progr...",subqueries,"subqueries, including correlated and nested su...",data manipulation,"inserting, updating, or deleting records",Delete records of artists who have not partici...,"CREATE TABLE Artists (artist_id INT, artist_na...",DELETE FROM Artists WHERE artist_id NOT IN (SE...,This query deletes records of artists from the...
7,10921,retail,"Retail data on circular supply chains, ethical...",aggregation,"aggregation functions (COUNT, SUM, AVG, MIN, M...",analytics and reporting,"generating reports, dashboards, and analytical...",List all suppliers that have provided both rec...,"CREATE TABLE suppliers (supplier_id INT, suppl...",SELECT supplier_name FROM suppliers WHERE mate...,This query identifies all suppliers that have ...
8,12727,manufacturing,"Detailed records on ethical manufacturing, cir...",aggregation,"aggregation functions (COUNT, SUM, AVG, MIN, M...",analytics and reporting,"generating reports, dashboards, and analytical...",What is the total number of employees working ...,CREATE TABLE ethical_manufacturing (country VA...,"SELECT country, SUM(employees) as total_employ...",This query calculates the total number of empl...
9,73714,sports entertainment,"Sports team performance data, fan demographics...",single join,"only one join (specify inner, outer, cross)",analytics and reporting,"generating reports, dashboards, and analytical...",List the total number of tickets sold for home...,"CREATE TABLE teams (team_id INT, team_name VAR...","SELECT t.team_name, SUM(g.price * g.attendance...","Join teams and games tables, filter on games f..."


In [17]:
def generate_prompt(datum):
    prompt = f"""
<start_of_turn>user
You are an text to SQL query translator. Users will ask you questions in English and you will generate a SQL query based on the provided SCHEMA.
SCHEMA: {datum["sql_context"]}
{datum["sql_prompt"]}<end_of_turn>
<start_of_turn>model
{datum["sql"]}<end_of_turn>
""".strip()
    return prompt

In [18]:
from transformers import AutoTokenizer

model_id = "google/gemma-2-9b"
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    add_eos_token=True,
)

In [19]:
prompt_column = [generate_prompt(datum) for datum in dataset]
dataset = dataset.add_column("prompt", prompt_column)

In [20]:
columns_to_remove = list(dataset.features)
for k in ["prompt"]:
    columns_to_remove.remove(k)
print(f"remove columns: {columns_to_remove}")

remove columns: ['id', 'domain', 'domain_description', 'sql_complexity', 'sql_complexity_description', 'sql_task_type', 'sql_task_type_description', 'sql_prompt', 'sql_context', 'sql', 'sql_explanation']


In [21]:
dataset = dataset.map(
    lambda samples: tokenizer(samples["prompt"]),
    batched=True,
    remove_columns=columns_to_remove,
)
dataset = dataset.train_test_split(test_size=0.2)
dataset

DatasetDict({
    train: Dataset({
        features: ['prompt', 'input_ids', 'attention_mask'],
        num_rows: 10000
    })
    test: Dataset({
        features: ['prompt', 'input_ids', 'attention_mask'],
        num_rows: 2500
    })
})

In [22]:
print(dataset["train"][345]["input_ids"][-1] == tokenizer.eos_token_id)

True


In [23]:
# save train_dataset to s3 using our SageMaker session
training_input_path = f"s3://{sess.default_bucket()}/datasets/text-to-sql"

# save datasets to s3
dataset["train"].to_json(f"{training_input_path}/train_dataset.json", orient="records")
dataset["test"].to_json(f"{training_input_path}/test_dataset.json", orient="records")

print(f"Training data uploaded to:")
print(f"{training_input_path}/train_dataset.json")
print(
    f"https://s3.console.aws.amazon.com/s3/buckets/{sess.default_bucket()}/?region={sess.boto_region_name}&prefix={training_input_path.split('/', 3)[-1]}/"
)

severe performance issues, see also https://github.com/dask/dask/issues/10276

To fix, you should specify a lower version bound on s3fs, or
update the current installation.



Creating json from Arrow format:   0%|          | 0/10 [00:00<?, ?ba/s]

Creating json from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Training data uploaded to:
s3://sagemaker-us-east-1-395271362395/datasets/text-to-sql/train_dataset.json
https://s3.console.aws.amazon.com/s3/buckets/sagemaker-us-east-1-395271362395/?region=us-east-1&prefix=datasets/text-to-sql/


# 3. Fine-Tune Gemma2 with QLoRA on Amazon Sagemaker


In [24]:
from typing import Optional, Union
from dataclasses import dataclass, field, fields, asdict
from trl import SFTConfig


@dataclass
class PartialTrainingArguments:
    num_train_epochs: int
    per_device_train_batch_size: int
    gradient_accumulation_steps: int
    gradient_checkpointing: bool
    optim: str
    logging_steps: int
    save_strategy: int
    learning_rate: float
    bf16: bool
    tf32: bool
    max_grad_norm: float
    warmup_ratio: float
    lr_scheduler_type: str
    report_to: str
    output_dir: str
    fsdp: str
    fsdp_config: Optional[Union[dict, str]]
    # SFTrainer Config
    dataset_text_field: str
    packing: bool
    max_seq_length: int


@dataclass
class Hyperparameters(PartialTrainingArguments):
    # path where sagemaker will save training dataset
    train_dataset_path: str
    test_dataset_path: str
    model_id: str

    def to_dict(self):
        return asdict(self)


# Validate training arguments
training_args_fields = {field.name for field in fields(SFTConfig)}
partial_training_args_fields = {
    field.name for field in fields(PartialTrainingArguments)
}
is_subset = partial_training_args_fields.issubset(training_args_fields)
assert is_subset, "All fields in PartialTrainingArguments should be in SFTConfig"

In [25]:
from transformers.trainer_utils import FSDPOption

hyperparameters = Hyperparameters(
    ### SagemakerArguments ###
    train_dataset_path="/opt/ml/input/data/training/train_dataset.json",
    test_dataset_path="/opt/ml/input/data/training/test_dataset.json",
    model_id="google/gemma-2-9b",
    ### TrainingArguments ###
    num_train_epochs=3,  # number of training epochs
    per_device_train_batch_size=4,  # batch size per device during training
    gradient_accumulation_steps=4,  # number of steps before performing a backward/update pass
    gradient_checkpointing=True,  # use gradient checkpointing to save memory
    optim="adamw_8bit",  # use adamw_8bit optimizer
    logging_steps=10,  # log every 10 steps
    save_strategy="epoch",  # save checkpoint every epoch
    learning_rate=2e-4,  # learning rate, based on QLoRA paper
    bf16=True,  # use bfloat16 precision
    tf32=True,  # use tf32 precision
    max_grad_norm=0.3,  # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,  # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",  # use constant learning rate scheduler
    report_to="tensorboard",  # report metrics to tensorboard
    output_dir="/tmp/tun",  # Temporary output directory for model checkpoints
    fsdp=f"{FSDPOption.SHARD_GRAD_OP} {FSDPOption.AUTO_WRAP} {FSDPOption.OFFLOAD}",
    fsdp_config={
        "transformer_layer_cls_to_wrap": "GemmaDecoderLayer",
        "backward_prefetch": "backward_pre",
        "forward_prefetch": False,
        "use_orig_params": False,
        "cpu_ram_efficient_loading": True,
    },
    ### SFTrainer Config ###
    dataset_text_field="prompt",
    packing=True,
    max_seq_length=512,
).to_dict()
print(hyperparameters)

{'num_train_epochs': 1, 'per_device_train_batch_size': 12, 'gradient_accumulation_steps': 4, 'gradient_checkpointing': True, 'optim': 'paged_adamw_8bit', 'logging_steps': 10, 'save_strategy': 'epoch', 'learning_rate': 0.0002, 'bf16': True, 'tf32': True, 'max_grad_norm': 0.3, 'warmup_ratio': 0.03, 'lr_scheduler_type': 'constant', 'report_to': 'tensorboard', 'output_dir': '/tmp/tun', 'fsdp': 'shard_grad_op auto_wrap offload', 'fsdp_config': {'transformer_layer_cls_to_wrap': 'GemmaDecoderLayer', 'backward_prefetch': 'backward_pre', 'forward_prefetch': False, 'use_orig_params': False, 'cpu_ram_efficient_loading': True}, 'dataset_text_field': 'prompt', 'packing': True, 'max_seq_length': 1024, 'train_dataset_path': '/opt/ml/input/data/training/train_dataset.json', 'test_dataset_path': '/opt/ml/input/data/training/test_dataset.json', 'model_id': 'google/gemma-2-9b'}


In [26]:
from sagemaker.huggingface import HuggingFace

# define Training Job Name
job_name = f"gemma-9b-fsdp-text-to-sql"

# create the Estimator
huggingface_estimator = HuggingFace(
    entry_point="run_sft_fsdp.py",  # train script
    source_dir="./scripts",  # directory which includes all the files needed for training
    instance_type="ml.p4d.24xlarge",  # instances type used for the training job
    instance_count=1,  # the number of instances used for training
    max_run=2
    * 24
    * 3600,  # maximum runtime in seconds (days * hours * minutes * seconds)
    base_job_name=job_name,  # the name of the training job
    role=role,  # Iam role used in training job to access AWS ressources, e.g. S3
    volume_size=300,  # the size of the EBS volume in GB
    transformers_version="4.36",  # the transformers version used in the training job
    pytorch_version="2.1",  # the pytorch_version version used in the training job
    py_version="py310",  # the python version used in the training job
    hyperparameters=hyperparameters,  # the hyperparameters passed to the training job
    disable_output_compression=True,  # not compress output to save training time and cost
    environment={
        "HUGGINGFACE_HUB_CACHE": "/tmp/.cache",  # set env variable to cache models in /tmp
        "HF_TOKEN": "REPLACE_WITH_YOUR_TOKEN",  # huggingface token to access gated models, e.g. llama 2
    },
)

In [None]:
# define a data input dictonary with our uploaded s3 uris
data = {"training": training_input_path}

# starting the train job with our uploaded datasets as input
huggingface_estimator.fit(data, wait=True)

INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.
INFO:sagemaker:Creating training-job with name: gemma-9b-fsdp-text-to-sql-2024-07-20-08-36-34-388


2024-07-20 08:36:34 Starting - Starting the training job...
2024-07-20 08:36:34 Pending - Training job waiting for capacity.........................................................................