# Fine-tune the Instructor-Model for Dialogue Summarization using SageMaker Training Jobs


# NOTE:  THIS NOTEBOOK WILL TAKE ABOUT 20 MINUTES TO COMPLETE.

# PLEASE BE PATIENT.

In [2]:
import boto3
import botocore.config
import pandas as pd
import sagemaker

# SageMaker session, its default bucket, the notebook's execution role and the
# region of the default boto3 session are reused by the cells further down.
sess = sagemaker.Session()
bucket = sess.default_bucket()
role = sagemaker.get_execution_role()
region = boto3.Session().region_name

# Append a custom suffix to the user agent of calls made through `sm`.
config = botocore.config.Config(user_agent_extra="dsoaws/2.0")

# Low-level SageMaker client bound to the region resolved above.
sm = boto3.Session().client(
    service_name="sagemaker",
    region_name=region,
    config=config,
)

# _PRE-REQUISITE: You need to have successfully run the notebooks in the `PREPARE` section before you continue with this notebook._

In [3]:
%store -r processed_train_data_s3_uri

In [4]:
# A bare reference raises NameError when `%store -r` did not restore the
# variable; in that case only a banner is printed and execution continues.
try:
    processed_train_data_s3_uri
except NameError:
    banner = "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++"
    print(
        "\n".join(
            [
                banner,
                "[ERROR] Please run the notebooks in the PREPARE section before you continue.",
                banner,
            ]
        )
    )

In [5]:
# Show the restored S3 URI of the processed training split.
print(processed_train_data_s3_uri)

s3://sagemaker-us-east-1-371366150581/sagemaker-scikit-learn-2023-04-17-01-47-03-582/output/train


In [6]:
%store -r processed_validation_data_s3_uri

In [7]:
# A bare reference raises NameError when `%store -r` did not restore the
# variable; in that case only a banner is printed and execution continues.
try:
    processed_validation_data_s3_uri
except NameError:
    banner = "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++"
    print(
        "\n".join(
            [
                banner,
                "[ERROR] Please run the notebooks in the PREPARE section before you continue.",
                banner,
            ]
        )
    )

In [8]:
# Show the restored S3 URI of the processed validation split.
print(processed_validation_data_s3_uri)

s3://sagemaker-us-east-1-371366150581/sagemaker-scikit-learn-2023-04-17-01-47-03-582/output/validation


In [9]:
%store -r processed_test_data_s3_uri

In [10]:
# A bare reference raises NameError when `%store -r` did not restore the
# variable; in that case only a banner is printed and execution continues.
try:
    processed_test_data_s3_uri
except NameError:
    banner = "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++"
    print(
        "\n".join(
            [
                banner,
                "[ERROR] Please run the notebooks in the PREPARE section before you continue.",
                banner,
            ]
        )
    )

In [11]:
# Show the restored S3 URI of the processed test split.
print(processed_test_data_s3_uri)

s3://sagemaker-us-east-1-371366150581/sagemaker-scikit-learn-2023-04-17-01-47-03-582/output/test


In [12]:
%store -r model_checkpoint

In [13]:
# A bare reference raises NameError when `%store -r` did not restore the
# variable; in that case only a banner is printed and execution continues.
try:
    model_checkpoint
except NameError:
    banner = "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++"
    print(
        "\n".join(
            [
                banner,
                "[ERROR] Please run the notebooks in the PREPARE section before you continue.",
                banner,
            ]
        )
    )

In [14]:
# Show the restored model checkpoint identifier that is later passed to the
# training script as the `model_checkpoint` hyperparameter.
print(model_checkpoint)

google/flan-t5-base


# Specify the Dataset in S3
We are using the train, validation, and test splits created in the previous section.

In [21]:
# Print the training-split S3 prefix, then list the objects under it with the
# AWS CLI (IPython expands `$processed_train_data_s3_uri` before running it).
print(processed_train_data_s3_uri)

!aws s3 ls $processed_train_data_s3_uri/

s3://sagemaker-us-east-1-371366150581/sagemaker-scikit-learn-2023-04-17-01-47-03-582/output/train
2023-04-17 01:52:38    2545157 1681696352307.parquet
2023-04-17 01:52:42    2540571 1681696357396.parquet


In [22]:
# Print the validation-split S3 prefix, then list the objects under it with the
# AWS CLI (IPython expands `$processed_validation_data_s3_uri` before running it).
print(processed_validation_data_s3_uri)

!aws s3 ls $processed_validation_data_s3_uri/

s3://sagemaker-us-east-1-371366150581/sagemaker-scikit-learn-2023-04-17-01-47-03-582/output/validation
2023-04-17 01:52:38     150220 1681696352307.parquet
2023-04-17 01:52:42     150701 1681696357396.parquet


In [23]:
# Print the test-split S3 prefix, then list the objects under it with the
# AWS CLI (IPython expands `$processed_test_data_s3_uri` before running it).
print(processed_test_data_s3_uri)

!aws s3 ls $processed_test_data_s3_uri/

s3://sagemaker-us-east-1-371366150581/sagemaker-scikit-learn-2023-04-17-01-47-03-582/output/test
2023-04-17 01:52:39     153865 1681696352307.parquet
2023-04-17 01:52:43     157115 1681696357396.parquet


# Specify S3 Input Data

In [24]:
from sagemaker.inputs import TrainingInput

# One TrainingInput per dataset split, each pointing at the S3 prefix restored
# earlier in the notebook.
s3_input_train_data = TrainingInput(s3_data=processed_train_data_s3_uri)
s3_input_validation_data = TrainingInput(s3_data=processed_validation_data_s3_uri)
s3_input_test_data = TrainingInput(s3_data=processed_test_data_s3_uri)

# Print the generated channel configuration of each input, in split order.
for channel_input in (
    s3_input_train_data,
    s3_input_validation_data,
    s3_input_test_data,
):
    print(channel_input.config)

{'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': 's3://sagemaker-us-east-1-371366150581/sagemaker-scikit-learn-2023-04-17-01-47-03-582/output/train', 'S3DataDistributionType': 'FullyReplicated'}}}
{'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': 's3://sagemaker-us-east-1-371366150581/sagemaker-scikit-learn-2023-04-17-01-47-03-582/output/validation', 'S3DataDistributionType': 'FullyReplicated'}}}
{'DataSource': {'S3DataSource': {'S3DataType': 'S3Prefix', 'S3Uri': 's3://sagemaker-us-east-1-371366150581/sagemaker-scikit-learn-2023-04-17-01-47-03-582/output/test', 'S3DataDistributionType': 'FullyReplicated'}}}


# Setup Hyper-Parameters for FLAN Model

In [26]:
# --- Optimisation settings (forwarded to the training script) ---
epochs = 1  # increase this if you want to train for a longer period
learning_rate = 1e-5
weight_decay = 1e-2

# --- Per-split batch sizes ---
train_batch_size = 4
validation_batch_size = 4
test_batch_size = 4

# --- Per-split step limits ---
train_steps_per_epoch = 10
validation_steps = 10
test_steps = 10

# --- Training infrastructure (passed to the estimator) ---
train_instance_count = 1
train_instance_type = "ml.c5.9xlarge"
train_volume_size = 1024
input_mode = "FastFile"

# --- Feature switches ---
# NOTE(review): enable_sagemaker_debugger is not referenced by any other cell
# shown in this notebook.
enable_sagemaker_debugger = False
enable_checkpointing = False
enable_tensorboard = False

# --- Optional evaluation stages in the training script ---
run_validation = False
run_test = False
run_sample_predictions = False

# Setup Metrics To Track Model Performance

These sample log lines...
```
45/50 [=====>..] - ETA: 3s - loss: 0.425 - accuracy: 0.881
50/50 [=======>] - ETA: 0s - val_loss: 0.407 - val_accuracy: 0.885
```
...will produce the following 4 metrics in CloudWatch:

`train:loss` = 0.425

`train:accuracy` = 0.881

`validation:loss` = 0.407

`validation:accuracy` = 0.885

In [27]:
# Regex patterns used to extract metric values from the training job's log
# output. Each pattern has exactly one capture group holding the numeric value.
#
# The train patterns start with a word boundary (`\b`). Without it,
# "loss: (...)" also matches inside "val_loss: ..." and "accuracy: (...)"
# inside "val_accuracy: ...", so validation values would be reported under the
# train metric names as well. `_` is a word character, so `\bloss` cannot match
# in the middle of `val_loss`.
metrics_definitions = [
    {"Name": "train:loss", "Regex": "\\bloss: ([0-9\\.]+)"},
    {"Name": "train:accuracy", "Regex": "\\baccuracy: ([0-9\\.]+)"},
    {"Name": "validation:loss", "Regex": "\\bval_loss: ([0-9\\.]+)"},
    {"Name": "validation:accuracy", "Regex": "\\bval_accuracy: ([0-9\\.]+)"},
]

# Specify Checkpoint S3 Location
This is used for Spot Instances Training.  If nodes are replaced, the new node will start training from the latest checkpoint.

In [30]:
import uuid

# A fresh random UUID gives every run of this cell its own checkpoint prefix
# inside the session's default bucket.
checkpoint_s3_prefix = f"checkpoints/{uuid.uuid4()}"
checkpoint_s3_uri = f"s3://{bucket}/{checkpoint_s3_prefix}/"

print(checkpoint_s3_uri)

s3://sagemaker-us-east-1-371366150581/checkpoints/b7550ab6-30e3-46ac-9850-6566f5744d1f/


# Setup Our Script to Run on SageMaker
Prepare our model to run on the managed SageMaker service

In [55]:
# Pretty-print src/train.py, the script the estimator below uses as its
# entry point.
!pygmentize src/train.py

[34mimport[39;49;00m [04m[36margparse[39;49;00m[37m[39;49;00m
[34mimport[39;49;00m [04m[36mos[39;49;00m[37m[39;49;00m
[34mimport[39;49;00m [04m[36mjson[39;49;00m[37m[39;49;00m
[34mimport[39;49;00m [04m[36mpprint[39;49;00m[37m[39;49;00m
[37m[39;49;00m
[34mfrom[39;49;00m [04m[36mtransformers[39;49;00m [34mimport[39;49;00m AutoTokenizer, AutoModelForSeq2SeqLM, TrainingArguments, Trainer, GenerationConfig[37m[39;49;00m
[34mfrom[39;49;00m [04m[36mdatasets[39;49;00m [34mimport[39;49;00m load_dataset[37m[39;49;00m
[37m[39;49;00m
[34mdef[39;49;00m [32mlist_files[39;49;00m(startpath):[37m[39;49;00m
[37m    [39;49;00m[33m"""Helper function to list files in a directory"""[39;49;00m[37m[39;49;00m
    [34mfor[39;49;00m root, dirs, files [35min[39;49;00m os.walk(startpath):[37m[39;49;00m
        level = root.replace(startpath, [33m'[39;49;00m[33m'[39;49;00m).count(os.sep)[37m[39;49;00m
        indent = [33m'[39;49;00m[33m 

In [32]:
from sagemaker.pytorch import PyTorch

# Values forwarded to src/train.py as hyperparameters; all of them come from
# the hyper-parameter cell and the restored `model_checkpoint`.
training_hyperparameters = {
    "epochs": epochs,
    "learning_rate": learning_rate,
    "weight_decay": weight_decay,
    "train_batch_size": train_batch_size,
    "validation_batch_size": validation_batch_size,
    "test_batch_size": test_batch_size,
    "train_steps_per_epoch": train_steps_per_epoch,
    "validation_steps": validation_steps,
    "test_steps": test_steps,
    "model_checkpoint": model_checkpoint,
    "enable_checkpointing": enable_checkpointing,
    "enable_tensorboard": enable_tensorboard,
    "run_validation": run_validation,
    "run_test": run_test,
    "run_sample_predictions": run_sample_predictions,
}

# NOTE(review): `checkpoint_s3_uri` is passed even though
# `enable_checkpointing` is False in the hyper-parameter cell, and
# `enable_sagemaker_debugger` is not passed to the estimator at all --
# confirm both are intended.
estimator = PyTorch(
    entry_point="train.py",
    source_dir="src",
    role=role,
    framework_version="1.13",
    py_version="py39",
    instance_count=train_instance_count,
    instance_type=train_instance_type,
    volume_size=train_volume_size,
    input_mode=input_mode,
    checkpoint_s3_uri=checkpoint_s3_uri,
    metric_definitions=metrics_definitions,
    hyperparameters=training_hyperparameters,
)

# Train the Model on SageMaker

In [33]:
# Map each channel name to the TrainingInput built for that split.
data_channels = {
    "train": s3_input_train_data,
    "validation": s3_input_validation_data,
    "test": s3_input_test_data,
}

# wait=False returns right after the job is submitted; a later cell blocks on
# completion explicitly.
estimator.fit(inputs=data_channels, wait=False)

INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.
INFO:sagemaker:Creating training-job with name: pytorch-training-2023-04-17-18-55-56-793


In [34]:
# Keep the job name: later cells use it to build console links and to derive
# the endpoint name.
training_job_name = estimator.latest_training_job.name
print(f"Training Job Name: {training_job_name}")

Training Job Name:  pytorch-training-2023-04-17-18-55-56-793


In [35]:
# `display` and `HTML` are imported from the public `IPython.display` module;
# importing `display` from `IPython.core.display` is deprecated.
from IPython.display import display, HTML

# Render a link to this training job in the SageMaker console.
# `target="_blank"` (with the underscore) opens it in a new tab; the previous
# value "blank" was treated as a window *name*, not the new-tab keyword.
display(
    HTML(
        '<b>Review <a target="_blank" href="https://console.aws.amazon.com/sagemaker/home?region={}#/jobs/{}">Training Job</a> After About 5 Minutes</b>'.format(
            region, training_job_name
        )
    )
)

In [36]:
# `display` and `HTML` are imported from the public `IPython.display` module;
# importing `display` from `IPython.core.display` is deprecated.
from IPython.display import display, HTML

# Render a link to the CloudWatch log streams whose prefix is this training
# job's name. `target="_blank"` (with the underscore) opens it in a new tab;
# the previous value "blank" was treated as a window *name*.
display(
    HTML(
        '<b>Review <a target="_blank" href="https://console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/TrainingJobs;prefix={};streamFilter=typeLogStreamPrefix">CloudWatch Logs</a> After About 5 Minutes</b>'.format(
            region, training_job_name
        )
    )
)

In [37]:
# `display` and `HTML` are imported from the public `IPython.display` module;
# importing `display` from `IPython.core.display` is deprecated.
from IPython.display import display, HTML

# Render a link to the `<bucket>/<training job name>/` prefix in the S3
# console. `target="_blank"` (with the underscore) opens it in a new tab; the
# previous value "blank" was treated as a window *name*.
display(
    HTML(
        '<b>Review <a target="_blank" href="https://s3.console.aws.amazon.com/s3/buckets/{}/{}/?region={}&tab=overview">S3 Output Data</a> After The Training Job Has Completed</b>'.format(
            bucket, training_job_name, region
        )
    )
)

In [39]:
%%time

# Block until the training job submitted above (with wait=False) finishes.
# logs=False suppresses the job's log stream; `%%time` reports how long the
# wait took.
estimator.latest_training_job.wait(logs=False)


2023-04-17 18:55:57 Starting - Starting the training job..
2023-04-17 18:56:13 Starting - Preparing the instances for training........
2023-04-17 18:57:01 Downloading - Downloading input data...
2023-04-17 18:57:21 Training - Downloading the training image.......
2023-04-17 18:58:02 Training - Training image download completed. Training in progress.............................................
2023-04-17 19:01:48 Uploading - Uploading generated training model....................
2023-04-17 19:03:35 Completed - Training job completed
CPU times: user 345 ms, sys: 34 ms, total: 379 ms
Wall time: 7min 39s


# Deploy the Fine-Tuned Model to a Real Time Endpoint

In [65]:
# Wrap the trained estimator in a deployable model that serves requests with
# src/inference.py.
sm_model = estimator.create_model(entry_point="inference.py", source_dir="src")

# Name the endpoint after the training job so the two are easy to correlate.
endpoint_name = training_job_name.replace("pytorch-training", "summary-tuned")

# Create the real-time endpoint on a single instance.
predictor = sm_model.deploy(
    endpoint_name=endpoint_name,
    instance_type="ml.m5.2xlarge",
    initial_instance_count=1,
)

# Zero Shot Inference with the Fine-Tuned Model in a SageMaker Endpoint

In [72]:
# Zero-shot prompt: an instruction, one dialogue, and a trailing "Summary:" cue
# with no worked examples. It is sent verbatim to the endpoint in the next cell.
zero_shot_prompt = """Summarize the following conversation.

#Person1#: Tom, I've got good news for you.
#Person2#: What is it?
#Person1#: Haven't you heard that your novel has won The Nobel Prize?
#Person2#: Really? I can't believe it. It's like a dream come true. I never expected that I would win The Nobel Prize!
#Person1#: You did a good job. I'm extremely proud of you.
#Person2#: Thanks for the compliment.
#Person1#: You certainly deserve it. Let's celebrate!

Summary:"""

In [74]:
import json
from sagemaker import Predictor

# Attach a generic Predictor to the endpoint by name (this rebinds
# `predictor`, which previously held the object returned by `deploy`).
predictor = Predictor(endpoint_name=endpoint_name, sagemaker_session=sess)

# Passed as predict()'s second positional argument.
# NOTE(review): presumably forwarded as extra invoke-endpoint arguments --
# confirm against the installed SageMaker SDK version.
invoke_args = {
    "ContentType": "application/x-text",
    "Accept": "application/json",
}
response = predictor.predict(zero_shot_prompt, invoke_args)

# The response is decoded as UTF-8 and parsed as JSON before printing.
response_json = json.loads(response.decode("utf-8"))
print(response_json)

Tom's novel has won the Nobel Prize.


# Tear Down the Endpoint

In [None]:
# Tear down the endpoint that `predictor` is attached to.
predictor.delete_endpoint()