# Fine-tune an LLM with PyTorch FSDP and QLoRA on Amazon SageMaker AI using ModelTrainer

In this notebook, we fine-tune an LLM on Amazon SageMaker AI, using Python scripts and the SageMaker ModelTrainer to run the training job.

## Prerequisites

In [None]:
! pip install -r ./scripts/requirements.txt --upgrade
! pip install -U sagemaker

***

## Set up the configuration file path

In [None]:
import os
import sys

module_path = os.path.abspath(os.path.join("../.."))

if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
import os

model_id = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"

os.environ["model_id"] = model_id

***

## Visualize and upload the dataset

We are going to load the [FreedomIntelligence/medical-o1-reasoning-SFT](https://huggingface.co/datasets/FreedomIntelligence/medical-o1-reasoning-SFT) dataset.

In [None]:
import sagemaker

In [None]:
sagemaker_session = sagemaker.Session()
bucket_name = sagemaker_session.default_bucket()
default_prefix = sagemaker_session.default_bucket_prefix

In [None]:
from datasets import load_dataset
import pandas as pd

dataset = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", "en", split="train[:1000]")

df = pd.DataFrame(dataset)

df.head()

In [None]:
from sklearn.model_selection import train_test_split

train, val = train_test_split(df, test_size=0.1, random_state=42)

print("Number of train elements: ", len(train))
print("Number of validation elements: ", len(val))

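Create a prompt template that formats each sample as a chat conversation (system, user, assistant), with the chain-of-thought reasoning wrapped in `<think>` tags ahead of the final response.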

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_id)

def prepare_dataset(sample):
    system_text = (
        "You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning.\n"
        "Below is an instruction that describes a task, paired with an input that provides further context.\n"
        "Write a response that appropriately completes the request.\n"
        "Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response."
    )

    messages = []
    messages.append({"role": "system", "content": system_text})
    messages.append({"role": "user", "content": sample["Question"]})

    # Use different tags that won't be detected by the template
    messages.append(
        {
            "role": "assistant",
            "content": f"\n[REASONING_START]\n{sample['Complex_CoT']}\n[REASONING_END]\n{sample['Response']}",
        }
    )

    formatted_text = tokenizer.apply_chat_template(messages, tokenize=False)

    # Replace with actual think tags after template processing
    formatted_text = formatted_text.replace("[REASONING_START]", "<think>")
    formatted_text = formatted_text.replace("[REASONING_END]", "</think>")

    sample["text"] = formatted_text

    return sample
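
The placeholder tags in `prepare_dataset` matter when a chat template post-processes assistant messages that contain `</think>`. The sketch below illustrates the idea with `toy_chat_template`, a hypothetical stand-in that drops everything before `</think>` in assistant turns. This behavior is an assumption made for illustration, not a copy of the actual model template, and `format_with_placeholders` is likewise a made-up helper name.

```python
def toy_chat_template(messages):
    """Hypothetical stand-in for a chat template that strips reasoning.

    Assumption: assistant content before a literal '</think>' is discarded,
    as a reasoning-model template might do for conversation history.
    """
    parts = []
    for message in messages:
        content = message["content"]
        if message["role"] == "assistant" and "</think>" in content:
            content = content.split("</think>")[-1]
        parts.append(f"<|{message['role']}|>{content}")
    return "".join(parts)


def format_with_placeholders(question, reasoning, answer, template=toy_chat_template):
    """Apply the template with neutral tags, then swap in the real think tags."""
    messages = [
        {"role": "user", "content": question},
        {
            "role": "assistant",
            "content": f"[REASONING_START]{reasoning}[REASONING_END]{answer}",
        },
    ]
    text = template(messages)
    text = text.replace("[REASONING_START]", "<think>")
    return text.replace("[REASONING_END]", "</think>")


# Passing real think tags through the toy template loses the reasoning.
direct = toy_chat_template(
    [
        {"role": "user", "content": "Q"},
        {"role": "assistant", "content": "<think>steps</think>A"},
    ]
)

# Using placeholders keeps the reasoning in the training text.
safe = format_with_placeholders("Q", "steps", "A")

print(direct)
print(safe)
```

With the toy template, `direct` contains only the final answer, while `safe` keeps the reasoning between `<think>` and `</think>`.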

Apply the prompt template to the train and validation splits, dropping the original columns so that only the formatted `text` field remains. We also print one random training sample to check the result.

In [None]:
from datasets import Dataset, DatasetDict
from random import randint

train_dataset = Dataset.from_pandas(train)
val_dataset = Dataset.from_pandas(val)

dataset = DatasetDict({"train": train_dataset, "val": val_dataset})

train_dataset = dataset["train"].map(
    prepare_dataset, remove_columns=list(train_dataset.features)
)

# randint is inclusive on both ends, so subtract 1 to stay within bounds
print(train_dataset[randint(0, len(train_dataset) - 1)]["text"])

val_dataset = dataset["val"].map(
    prepare_dataset, remove_columns=list(val_dataset.features)
)

### Upload to Amazon S3

In [None]:
import boto3
import shutil
import sagemaker

In [None]:
sagemaker_session = sagemaker.Session()
s3_client = boto3.client('s3')

bucket_name = sagemaker_session.default_bucket()
default_prefix = sagemaker_session.default_bucket_prefix

In [None]:
# save train_dataset to s3 using our SageMaker session
if default_prefix:
    input_path = f"{default_prefix}/datasets/llm-fine-tuning-modeltrainer-sft-batch"
else:
    input_path = "datasets/llm-fine-tuning-modeltrainer-sft-batch"

# Save datasets to s3
# Only a small subset of the dataset is used, due to limited compute resources for the workshop
train_dataset.to_json("./data/train/dataset.json", orient="records")
val_dataset.to_json("./data/val/dataset.json", orient="records")

s3_client.upload_file("./data/train/dataset.json", bucket_name, f"{input_path}/train/dataset.json")
train_dataset_s3_path = f"s3://{bucket_name}/{input_path}/train/dataset.json"
s3_client.upload_file("./data/val/dataset.json", bucket_name, f"{input_path}/val/dataset.json")
val_dataset_s3_path = f"s3://{bucket_name}/{input_path}/val/dataset.json"

shutil.rmtree("./data")

print("Training data uploaded to:")
print(train_dataset_s3_path)
print(val_dataset_s3_path)
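
The optional-prefix check used above to build `input_path` can be factored into a small helper. This is a minimal sketch: `build_s3_key` and `build_s3_uri` are hypothetical names (not part of the SageMaker SDK), and `example-bucket` and `team-a` are placeholder values.

```python
def build_s3_key(*parts, prefix=None):
    """Join path segments into an S3 key, prepending an optional prefix."""
    segments = [prefix, *parts] if prefix else list(parts)
    # Strip stray slashes so segments join with exactly one separator.
    return "/".join(segment.strip("/") for segment in segments if segment)


def build_s3_uri(bucket, key):
    """Return the s3:// URI for a bucket and key."""
    return f"s3://{bucket}/{key}"


dataset_dir = "datasets/llm-fine-tuning-modeltrainer-sft-batch"

key_without_prefix = build_s3_key(dataset_dir, "train", "dataset.json")
key_with_prefix = build_s3_key(dataset_dir, "train", "dataset.json", prefix="team-a")

print(build_s3_uri("example-bucket", key_without_prefix))
print(build_s3_uri("example-bucket", key_with_prefix))
```

In the notebook, `prefix` would be `sagemaker_session.default_bucket_prefix`, which may be `None`.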

***

## Model fine-tuning

We are now ready to fine-tune our model. We will use the [Trainer](https://huggingface.co/docs/transformers/main_classes/trainer) from transformers to fine-tune our model. We prepared a script, [train.py](./scripts/train.py), which loads the dataset from disk, prepares the model and tokenizer, and starts the training.

For configuration we use `TrlParser`, which allows us to provide hyperparameters in a `yaml` file. This file is uploaded to S3 and passed to the Amazon SageMaker training job as an input channel, just like our datasets. Below is the config file for fine-tuning the model on `ml.g6.12xlarge`. We save the config file as `args.yaml` and upload it to S3.

In [None]:
%%bash

cat > ./args.yaml <<EOF
model_id: "${model_id}"       # Hugging Face model id
# sagemaker specific parameters
output_dir: "/opt/ml/model"                       # path from which SageMaker uploads the model
train_dataset_path: "/opt/ml/input/data/train/"   # path where SageMaker places the train channel
test_dataset_path: "/opt/ml/input/data/val/"      # path where SageMaker places the val channel
# training parameters
lora_r: 8
lora_alpha: 16
lora_dropout: 0.1                 
learning_rate: 2e-4                    # learning rate
num_train_epochs: 1                    # number of training epochs
per_device_train_batch_size: 2         # batch size per device during training
per_device_eval_batch_size: 1          # batch size for evaluation
gradient_accumulation_steps: 2         # number of micro-batches accumulated before each optimizer update
gradient_checkpointing: true           # use gradient checkpointing
bf16: true                             # use bfloat16 precision
tf32: false                            # use tf32 precision
fsdp: "full_shard auto_wrap offload"
fsdp_config: 
    backward_prefetch: "backward_pre"
    cpu_ram_efficient_loading: true
    offload_params: true
    forward_prefetch: false
    use_orig_params: true
merge_weights: true                    # merge weights in the base model
EOF
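
The batch-related settings in `args.yaml` combine into an effective global batch size per optimizer update. A minimal sketch of the arithmetic, assuming one training process per GPU on an instance with 4 GPUs (the GPU count is an assumption about the chosen instance type, and `effective_batch_size` is a hypothetical helper name):

```python
def effective_batch_size(per_device_batch_size, gradient_accumulation_steps, num_devices):
    """Number of samples that contribute to one optimizer update across all devices."""
    return per_device_batch_size * gradient_accumulation_steps * num_devices


# Values taken from args.yaml; num_devices=4 is an assumption (one process per GPU).
global_batch = effective_batch_size(
    per_device_batch_size=2,
    gradient_accumulation_steps=2,
    num_devices=4,
)

print(global_batch)  # 16
```

Raising `gradient_accumulation_steps` increases the effective batch size without increasing per-device memory use, at the cost of fewer optimizer updates per epoch.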

Let's upload the config file to S3.

In [None]:
import os
from sagemaker.s3 import S3Uploader

if default_prefix:
    input_path = f"s3://{bucket_name}/{default_prefix}/datasets/llm-fine-tuning-modeltrainer-sft-batch"
else:
    input_path = f"s3://{bucket_name}/datasets/llm-fine-tuning-modeltrainer-sft-batch"

# upload the model yaml file to s3
model_yaml = "args.yaml"
train_config_s3_path = S3Uploader.upload(local_path=model_yaml, desired_s3_uri=f"{input_path}/config")

os.remove("./args.yaml")

print("Training config uploaded to:")
print(train_config_s3_path)

## Create the ModelTrainer

The `ModelTrainer` defined below will be used to submit the jobs.

#### Get PyTorch image_uri

We are going to use the native PyTorch container image, pre-built for Amazon SageMaker.

In [None]:
import sagemaker
from sagemaker.config import load_sagemaker_config

In [None]:
sagemaker_session = sagemaker.Session()

bucket_name = sagemaker_session.default_bucket()
default_prefix = sagemaker_session.default_bucket_prefix
configs = load_sagemaker_config()

In [None]:
instance_type = "ml.g6.12xlarge"
instance_count = 1

instance_type

In [None]:
image_uri = sagemaker.image_uris.retrieve(
    framework="pytorch",
    region=sagemaker_session.boto_session.region_name,
    version="2.6",
    instance_type=instance_type,
    image_scope="training"
)

image_uri

In [None]:
model_id = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"

In [None]:
from sagemaker.modules.configs import Compute, OutputDataConfig, SourceCode, StoppingCondition
from sagemaker.modules.distributed import Torchrun
from sagemaker.modules.train import ModelTrainer

role = sagemaker.get_execution_role()

# Define the script to be run
source_code = SourceCode(
    source_dir="./scripts",
    requirements="requirements.txt",
    entry_script="train.py",
)

# Define the compute
compute_configs = Compute(
    instance_type=instance_type,
    instance_count=instance_count,
    keep_alive_period_in_seconds=0
)

# define Training Job Name
job_name = f"train-{model_id.split('/')[-1].replace('.', '-')}-sft-batch"

# define OutputDataConfig path
if default_prefix:
    output_path = f"s3://{bucket_name}/{default_prefix}/{job_name}"
else:
    output_path = f"s3://{bucket_name}/{job_name}"

# Define the ModelTrainer
model_trainer = ModelTrainer(
    training_image=image_uri,
    source_code=source_code,
    base_job_name=job_name,
    compute=compute_configs,
    distributed=Torchrun(),
    stopping_condition=StoppingCondition(max_runtime_in_seconds=7200),
    hyperparameters={
        "config": "/opt/ml/input/data/config/args.yaml"  # path to TRL config which was uploaded to s3
    },
    output_data_config=OutputDataConfig(s3_output_path=output_path),
    role=role,
)
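
The job name above is derived from the model ID by replacing dots with hyphens. SageMaker training job names are limited to 63 characters and may contain only alphanumeric characters and hyphens, so it can help to validate derived names before submitting. A minimal sketch: `make_job_name` and `is_valid_job_name` are hypothetical helpers, and the pattern reflects the naming constraint as documented for the `CreateTrainingJob` API.

```python
import re

# Naming constraint for TrainingJobName: alphanumeric start, hyphens allowed inside.
JOB_NAME_PATTERN = re.compile(r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}")
MAX_JOB_NAME_LENGTH = 63


def make_job_name(model_id, suffix="sft-batch"):
    """Build the job name the same way the notebook does."""
    model_name = model_id.split("/")[-1].replace(".", "-")
    return f"train-{model_name}-{suffix}"


def is_valid_job_name(name):
    """Check the length limit and allowed characters."""
    if len(name) > MAX_JOB_NAME_LENGTH:
        return False
    return JOB_NAME_PATTERN.fullmatch(name) is not None


name = make_job_name("deepseek-ai/DeepSeek-R1-Distill-Llama-8B")
print(name, is_valid_job_name(name))

# Underscores are not replaced by the notebook's logic and would be rejected.
print(is_valid_job_name("train-my_model-sft-batch"))
```

Model IDs that contain underscores or very long names would need extra handling before being used in a job name.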

In [None]:
from sagemaker.modules.configs import InputData

# Pass the input data
train_input = InputData(
    channel_name="train",
    data_source=train_dataset_s3_path, # S3 path where training data is stored
)

val_input = InputData(
    channel_name="val",
    data_source=val_dataset_s3_path,  # S3 path where validation data is stored
)

config_input = InputData(
    channel_name="config",
    data_source=train_config_s3_path, # S3 path where the training config is stored
)

# Check input channels configured
TRAINING_INPUTS = [train_input, val_input, config_input]
TRAINING_INPUTS
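
Each input channel is made available inside the training container under `/opt/ml/input/data/<channel_name>`, which is why `args.yaml` and the `config` hyperparameter reference those paths. The sketch below maps the channel names used in this notebook to their container paths; `channel_path` is a hypothetical helper written for illustration.

```python
from pathlib import PurePosixPath

# Root directory where SageMaker places input channels inside the container.
SM_INPUT_DATA_ROOT = PurePosixPath("/opt/ml/input/data")


def channel_path(channel_name, filename=None):
    """Return the container path for a channel, optionally for a file inside it."""
    path = SM_INPUT_DATA_ROOT / channel_name
    return str(path / filename) if filename else str(path)


for name in ["train", "val", "config"]:
    print(f"{name:>6} -> {channel_path(name)}")

# Matches the value passed in the ModelTrainer hyperparameters.
config_file = channel_path("config", "args.yaml")
print(config_file)
```

If a channel is renamed, the matching path in `args.yaml` or in the hyperparameters has to change with it.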

***

## Queue Some Training Jobs
This section and the following are intended to be used interactively so that you can explore how to use the SageMaker Python SDK to submit jobs to your Batch queues. Let's start by selecting which queue to submit to.

### Select the Queue to Use

In [None]:
from sagemaker.aws_batch.training_queue import TrainingQueue

# Set the queue type to use for your job submission
SMTJ_BATCH_QUEUE = "ml-g6-12xlarge-queue"

# Construct the queue object using the SageMaker Python SDK
queue = TrainingQueue(SMTJ_BATCH_QUEUE)
print(f"Using queue: {queue.queue_name}")

### Submit your jobs
In the next cell, we submit two training jobs to the queue:

1. LOW PRIORITY
2. MEDIUM PRIORITY

We use the queue's `submit` method to submit each job.

In [None]:
job_name_1 = job_name + "-low-pri"
queued_job_1 = queue.submit(
    model_trainer, TRAINING_INPUTS, job_name_1, priority=5, share_identifier="LOWPRI"
)

job_name_2 = job_name + "-mid-pri"
queued_job_2 = queue.submit(
    model_trainer, TRAINING_INPUTS, job_name_2, priority=3, share_identifier="MIDPRI"
)
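
When more jobs are involved, the repeated `submit` calls can be driven from a list of specifications. A minimal sketch, assuming only the `submit` call shape used above: `JobSpec` and `submit_all` are hypothetical helpers, and `_RecordingQueue` is a stand-in that lets the helper run without AWS access. With the real SDK you would pass the `TrainingQueue` object instead.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class JobSpec:
    suffix: str
    priority: int
    share_identifier: str


def submit_all(queue, trainer, inputs, base_name, specs):
    """Submit one job per spec and return a {job_name: queued_job} mapping."""
    submitted = {}
    for spec in specs:
        name = f"{base_name}-{spec.suffix}"
        submitted[name] = queue.submit(
            trainer,
            inputs,
            name,
            priority=spec.priority,
            share_identifier=spec.share_identifier,
        )
    return submitted


class _RecordingQueue:
    """Stand-in queue that records calls instead of contacting AWS."""

    def __init__(self):
        self.calls = []

    def submit(self, trainer, inputs, job_name, priority, share_identifier):
        self.calls.append((job_name, priority, share_identifier))
        return job_name


specs = [JobSpec("low-pri", 5, "LOWPRI"), JobSpec("mid-pri", 3, "MIDPRI")]
demo_queue = _RecordingQueue()
jobs = submit_all(demo_queue, trainer=None, inputs=[], base_name="train-demo", specs=specs)

print(list(jobs))
```

Keeping the specifications in one list makes it easier to see at a glance which share identifier and priority each job uses.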

## Display the Status of Running and 'In Queue' Jobs
We can use the job queue list and job queue snapshot APIs to programmatically view a snapshot of the jobs that the queue will run next. Keep in mind that for fair-share queues this ordering is dynamic and occasionally needs to be refreshed as new jobs are submitted to the queue or as share usage changes over time.

In [None]:
from smtj_batch_utils.queue_utils import print_queue_state

print_queue_state(queue)

### Submit an additional job
In the next cell, we are going to submit an additional job to the queue, by using the API `submit`

In [None]:
job_name_3 = job_name + "-high-pri"
queued_job_3 = queue.submit(
    model_trainer, TRAINING_INPUTS, job_name_3, priority=1, share_identifier="HIGHPRI"
)

## Display the Status of Running and 'In Queue' Jobs
The queue should now show an additional runnable job. Because the last job was submitted with the `HIGHPRI` share identifier, the scheduler is expected to run it before the `MIDPRI` and `LOWPRI` jobs.

In [None]:
from smtj_batch_utils.queue_utils import print_queue_state

print_queue_state(queue)

## Cancel a Job in the Queue
This next cell shows how to cancel jobs that are still waiting in the queue (status `RUNNABLE`).

In [None]:
# Terminate every job that is still waiting in the queue
runnable_jobs = queue.list_jobs(status="RUNNABLE")
for job in runnable_jobs:
    print(f"Cancelling job: {job.describe().get('jobName', '')}")
    job.terminate()