# NOTE:  THIS NOTEBOOK WILL TAKE ABOUT 20 MINUTES TO COMPLETE.

# PLEASE BE PATIENT.

## TODOs:
* set up eval and train metric tracking
* set up debugger
* remove comments in script

# Fine-Tuning a Generative Model 

In the previous section, we performed the feature engineering to create embeddings from the `reviews_body` text using a pre-trained model. We split the dataset into train, validation, and test files. To optimize for fine-tuning, we saved the files in Arrow format. Now, let’s fine-tune the model with data from the Amazon Customer Reviews Dataset.

In [None]:
import boto3
import botocore.config
import pandas as pd
import sagemaker

# SageMaker session, plus the default bucket and execution role it resolves.
sess = sagemaker.Session()
bucket = sess.default_bucket()
role = sagemaker.get_execution_role()

# Region comes from the default boto3 session configuration.
region = boto3.Session().region_name

# Extra text appended to the User-Agent of calls made with this config.
config = botocore.config.Config(user_agent_extra="dsoaws/2.0")

# Low-level SageMaker client bound to the detected region.
sm = boto3.Session().client(
    service_name="sagemaker",
    region_name=region,
    config=config,
)

# _PRE-REQUISITE: You need to have successfully run the notebooks in the `PREPARE` section before you continue with this notebook._

In [None]:
# Restore the variable from IPython's %store (presumably saved by the PREPARE notebooks).
%store -r processed_train_data_s3_uri

In [None]:
# Referencing the bare name raises NameError when the %store restore did not
# define it; in that case show a banner instead of a traceback.
try:
    processed_train_data_s3_uri
except NameError:
    print(
        "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++",
        "[ERROR] Please run the notebooks in the PREPARE section before you continue.",
        "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++",
        sep="\n",
    )

In [None]:
# Show the restored S3 location of the processed training data.
print(processed_train_data_s3_uri)

In [None]:
# Restore the variable from IPython's %store (presumably saved by the PREPARE notebooks).
%store -r processed_validation_data_s3_uri

In [None]:
# Referencing the bare name raises NameError when the %store restore did not
# define it; in that case show a banner instead of a traceback.
try:
    processed_validation_data_s3_uri
except NameError:
    print(
        "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++",
        "[ERROR] Please run the notebooks in the PREPARE section before you continue.",
        "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++",
        sep="\n",
    )

In [None]:
# Show the restored S3 location of the processed validation data.
print(processed_validation_data_s3_uri)

In [None]:
# Restore the variable from IPython's %store (presumably saved by the PREPARE notebooks).
%store -r processed_test_data_s3_uri

In [None]:
# Referencing the bare name raises NameError when the %store restore did not
# define it; in that case show a banner instead of a traceback.
try:
    processed_test_data_s3_uri
except NameError:
    print(
        "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++",
        "[ERROR] Please run the notebooks in the PREPARE section before you continue.",
        "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++",
        sep="\n",
    )

In [None]:
# Show the restored S3 location of the processed test data.
print(processed_test_data_s3_uri)

In [None]:
# Restore the variable from IPython's %store (presumably saved by the PREPARE notebooks).
%store -r model_checkpoint

In [None]:
# Referencing the bare name raises NameError when the %store restore did not
# define it; in that case show a banner instead of a traceback.
try:
    model_checkpoint
except NameError:
    print(
        "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++",
        "[ERROR] Please run the notebooks in the PREPARE section before you continue.",
        "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++",
        sep="\n",
    )

In [None]:
# Show the restored model checkpoint identifier.
print(model_checkpoint)

In [None]:
# %store -r dataset_templates_name

In [None]:
# try:
#     dataset_templates_name
# except NameError:
#     print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
#     print("[ERROR] Please run the notebooks in the PREPARE section before you continue.")
#     print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")

In [None]:
# print(dataset_templates_name)

In [None]:
# %store -r prompt_template_name

In [None]:
# try:
#     prompt_template_name
# except NameError:
#     print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
#     print("[ERROR] Please run the notebooks in the PREPARE section before you continue.")
#     print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")

In [None]:
# print(prompt_template_name)

# Specify the Dataset in S3
We are using the train, validation, and test splits created in the previous section.

In [None]:
# Print the training-data S3 prefix, then list the objects stored under it.
print(processed_train_data_s3_uri)

!aws s3 ls $processed_train_data_s3_uri/

In [None]:
# Print the validation-data S3 prefix, then list the objects stored under it.
print(processed_validation_data_s3_uri)

!aws s3 ls $processed_validation_data_s3_uri/

In [None]:
# Print the test-data S3 prefix, then list the objects stored under it.
print(processed_test_data_s3_uri)

!aws s3 ls $processed_test_data_s3_uri/

# Specify S3 Input Data

In [None]:
from sagemaker.inputs import TrainingInput

# One TrainingInput per split; these are later passed to estimator.fit()
# under the "train", "validation" and "test" keys.
s3_input_train_data = TrainingInput(s3_data=processed_train_data_s3_uri)
s3_input_validation_data = TrainingInput(s3_data=processed_validation_data_s3_uri)
s3_input_test_data = TrainingInput(s3_data=processed_test_data_s3_uri)

# Dump the resolved configuration of each input, one per line.
print(
    s3_input_train_data.config,
    s3_input_validation_data.config,
    s3_input_test_data.config,
    sep="\n",
)

# Setup Hyper-Parameters for Fine-Tuning

In [None]:
# Reminder of which checkpoint the hyper-parameters below will be applied to.
print(model_checkpoint)

In [None]:
# --- Optimisation ---------------------------------------------------------
epochs = 1  # increase this if you want to train for a longer period
learning_rate = 1e-5
weight_decay = 1e-2

# --- Batch sizes and step budgets (same value for every split) ------------
train_batch_size = validation_batch_size = test_batch_size = 4
train_steps_per_epoch = validation_steps = test_steps = 10

# --- Training infrastructure ----------------------------------------------
train_instance_type = "ml.c5.9xlarge"
train_instance_count = 1
train_volume_size = 1024
input_mode = "FastFile"

# --- Optional tooling, all off by default ---------------------------------
enable_sagemaker_debugger = False
enable_checkpointing = False
enable_tensorboard = False

# --- Optional evaluation phases, all off by default -----------------------
run_validation = False
run_test = False
run_sample_predictions = False

# Setup Metrics To Track Model Performance

These sample log lines...
```
45/50 [=====>..] - ETA: 3s - loss: 0.425 - accuracy: 0.881
50/50 [=======>] - ETA: 0s - val_loss: 0.407 - val_accuracy: 0.885
```
...will produce the following 4 metrics in CloudWatch:

`train:loss` = 0.425

`train:accuracy` = 0.881

`validation:loss` = 0.407

`validation:accuracy` = 0.885

In [None]:
# Name/regex pairs passed to the estimator as `metric_definitions`; each regex
# captures one numeric value from a training log line.
#
# Every pattern starts with \b (word boundary). Without it, "loss: " and
# "accuracy: " also match inside "val_loss: " and "val_accuracy: ", so the
# validation numbers would be reported under the train:* metrics as well.
# "_" is a word character, so there is no boundary between "val_" and "loss".
metrics_definitions = [
    {"Name": "train:loss", "Regex": r"\bloss: ([0-9\.]+)"},
    {"Name": "train:accuracy", "Regex": r"\baccuracy: ([0-9\.]+)"},
    {"Name": "validation:loss", "Regex": r"\bval_loss: ([0-9\.]+)"},
    {"Name": "validation:accuracy", "Regex": r"\bval_accuracy: ([0-9\.]+)"},
]

# Setup SageMaker Debugger
Define Debugger Rules as described here:  https://docs.aws.amazon.com/sagemaker/latest/dg/debugger-built-in-rules.html

In [None]:
from sagemaker.debugger import Rule
from sagemaker.debugger import rule_configs
from sagemaker.debugger import ProfilerRule
from sagemaker.debugger import CollectionConfig
from sagemaker.debugger import DebuggerHookConfig

# Actions a rule may trigger. None are enabled; candidates to pass in are
# rule_configs.StopTraining() and rule_configs.Email("").
actions = rule_configs.ActionList()

# Active rules: only the profiler report.
rules = [ProfilerRule.sagemaker(rule_configs.ProfilerReport())]

# Additional profiler rules -- uncomment to enable:
# rules += [
#     ProfilerRule.sagemaker(rule_configs.BatchSize()),
#     ProfilerRule.sagemaker(rule_configs.CPUBottleneck()),
#     ProfilerRule.sagemaker(rule_configs.GPUMemoryIncrease()),
#     ProfilerRule.sagemaker(rule_configs.IOBottleneck()),
#     ProfilerRule.sagemaker(rule_configs.LoadBalancing()),
#     ProfilerRule.sagemaker(rule_configs.LowGPUUtilization()),
#     ProfilerRule.sagemaker(rule_configs.OverallSystemUsage()),
#     ProfilerRule.sagemaker(rule_configs.StepOutlier()),
# ]

# Debugger rules -- uncomment to enable. Both save the "losses" and "metrics"
# collections every 10 steps and share the `actions` list above:
# saved_collections = [
#     CollectionConfig(name=collection, parameters={"save_interval": "10"})
#     for collection in ("losses", "metrics")
# ]
# rules += [
#     Rule.sagemaker(
#         base_config=rule_configs.loss_not_decreasing(),
#         rule_parameters={
#             "collection_names": "losses,metrics",
#             "use_losses_collection": "true",
#             "num_steps": "10",
#             "diff_percent": "50",
#         },
#         collections_to_save=saved_collections,
#         actions=actions,
#     ),
#     Rule.sagemaker(
#         base_config=rule_configs.overtraining(),
#         rule_parameters={
#             "collection_names": "losses,metrics",
#             "patience_train": "10",
#             "patience_validation": "10",
#             "delta": "0.5",
#         },
#         collections_to_save=saved_collections,
#         actions=actions,
#     ),
# ]

# Hook settings: save interval, plus TensorBoard export and its directory.
hook_config = DebuggerHookConfig(
    hook_parameters=dict(
        save_interval="10",  # number of steps
        export_tensorboard="true",
        tensorboard_dir="hook_tensorboard/",
    )
)

## Specify a Debugger profiler configuration

The following configuration will capture system metrics every 500 milliseconds. The system metrics include per-CPU and per-GPU utilization, per-CPU and per-GPU memory utilization, as well as I/O and network.

Debugger will capture detailed profiling information for 10 steps, starting at step 5. This information includes Horovod metrics, dataloading, preprocessing, operators running on CPU and GPU.

In [None]:
from sagemaker.debugger import ProfilerConfig, FrameworkProfile

# System metrics are sampled every 500 ms; framework profiling is configured
# with start_step=5 and num_steps=10 and writes under /opt/ml/output/profiler/.
profiler_config = ProfilerConfig(
    framework_profile_params=FrameworkProfile(
        local_path="/opt/ml/output/profiler/",
        start_step=5,
        num_steps=10,
    ),
    system_monitor_interval_millis=500,
)

# Specify Checkpoint S3 Location
This is used for Spot Instances Training.  If nodes are replaced, the new node will start training from the latest checkpoint.

In [None]:
import uuid

# A random UUID keeps each run's checkpoints under their own S3 prefix.
checkpoint_s3_prefix = f"checkpoints/{uuid.uuid4()}"
checkpoint_s3_uri = f"s3://{bucket}/{checkpoint_s3_prefix}/"

print(checkpoint_s3_uri)

# Setup Our Script to Run on SageMaker
Prepare our model to run on the managed SageMaker service

In [None]:
# Display the training script with syntax highlighting.
!pygmentize src/train.py

In [None]:
from sagemaker.pytorch import PyTorch

estimator = PyTorch(
    # Training script and the directory it lives in.
    source_dir="src",
    entry_point="train.py",
    # Framework / Python versions.
    framework_version="1.13",
    py_version="py39",
    # Permissions and infrastructure.
    role=role,
    instance_type=train_instance_type,
    instance_count=train_instance_count,
    volume_size=train_volume_size,
    input_mode=input_mode,
    checkpoint_s3_uri=checkpoint_s3_uri,
    # Hyperparameters for the training job, taken from the notebook settings.
    hyperparameters=dict(
        epochs=epochs,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        train_batch_size=train_batch_size,
        validation_batch_size=validation_batch_size,
        test_batch_size=test_batch_size,
        train_steps_per_epoch=train_steps_per_epoch,
        validation_steps=validation_steps,
        test_steps=test_steps,
        model_checkpoint=model_checkpoint,
        # dataset_templates_name=dataset_templates_name,
        # prompt_template_name=prompt_template_name,
        enable_checkpointing=enable_checkpointing,
        enable_tensorboard=enable_tensorboard,
        run_validation=run_validation,
        run_test=run_test,
        run_sample_predictions=run_sample_predictions,
    ),
    # Regexes that turn log lines into metrics.
    metric_definitions=metrics_definitions,
    # Debugger / profiler wiring is currently disabled:
    # rules=rules,
    # debugger_hook_config=hook_config,
    # profiler_config=profiler_config,
)

# Train the Model on SageMaker

In [None]:
# Start the training job without blocking (wait=False); a later cell waits
# for completion. The dict keys name the input channels.
estimator.fit(
    inputs={
        "train": s3_input_train_data,
        "validation": s3_input_validation_data,
        "test": s3_input_test_data,
    },
    wait=False,
)

In [None]:
# Keep the job name; later cells use it for console links and %store.
training_job_name = estimator.latest_training_job.name
print(f"Training Job Name:  {training_job_name}")

In [None]:
# `display` is imported from IPython.display; the IPython.core.display path
# for it is deprecated.
from IPython.display import HTML, display

# Link to this training job in the SageMaker console. target="_blank" opens a
# new tab; the former "blank" was a named window shared by every link.
display(
    HTML(
        '<b>Review <a target="_blank" href="https://console.aws.amazon.com/sagemaker/home?region={region}#/jobs/{job}">Training Job</a> After About 5 Minutes</b>'.format(
            region=region, job=training_job_name
        )
    )
)

In [None]:
# `display` is imported from IPython.display; the IPython.core.display path
# for it is deprecated.
from IPython.display import HTML, display

# Link to this job's CloudWatch log streams. target="_blank" opens a new tab;
# the former "blank" was a named window shared by every link.
display(
    HTML(
        '<b>Review <a target="_blank" href="https://console.aws.amazon.com/cloudwatch/home?region={region}#logStream:group=/aws/sagemaker/TrainingJobs;prefix={job};streamFilter=typeLogStreamPrefix">CloudWatch Logs</a> After About 5 Minutes</b>'.format(
            region=region, job=training_job_name
        )
    )
)

In [None]:
# `display` is imported from IPython.display; the IPython.core.display path
# for it is deprecated.
from IPython.display import HTML, display

# Link to the S3 prefix named after the training job. target="_blank" opens a
# new tab; the former "blank" was a named window shared by every link.
display(
    HTML(
        '<b>Review <a target="_blank" href="https://s3.console.aws.amazon.com/s3/buckets/{bucket}/{prefix}/?region={region}&tab=overview">S3 Output Data</a> After The Training Job Has Completed</b>'.format(
            bucket=bucket, prefix=training_job_name, region=region
        )
    )
)

In [None]:
# `display` is imported from IPython.display; the IPython.core.display path
# for it is deprecated.
from IPython.display import HTML, display

# Link to the S3 checkpoint prefix of this run. target="_blank" opens a new
# tab; the former "blank" was a named window shared by every link.
display(
    HTML(
        '<b>Review <a target="_blank" href="https://s3.console.aws.amazon.com/s3/buckets/{bucket}/{prefix}/?region={region}&tab=overview">S3 Checkpoint Data</a> After The Training Job Has Completed</b>'.format(
            bucket=bucket, prefix=checkpoint_s3_prefix, region=region
        )
    )
)

In [None]:
%%time

# Block until the latest training job finishes; logs=False means the job's
# logs are not streamed into the notebook while waiting.
estimator.latest_training_job.wait(logs=False)

# Wait Until the ^^ Training Job ^^ Completes Above!

# Display Training Job Metrics

In [None]:
# Show the metrics recorded for this training job as a DataFrame.
estimator.training_job_analytics.dataframe()

# [INFO] _Feel free to continue to the next workshop section while this notebook is running._

In [None]:
# Persist the job name in IPython's %store so other notebooks can restore it.
%store training_job_name

In [None]:
!aws s3 cp s3://$bucket/$training_job_name/output/model.tar.gz ./model.tar.gz

In [None]:
# Create ./model/ if it does not exist, then unpack the downloaded archive into it.
!mkdir -p ./model/
!tar -xvzf ./model.tar.gz -C ./model/

# Analyze Debugger Rules

In [None]:
# Summarise the rule evaluation jobs attached to the latest training job
# (presumably the Debugger/Profiler rule statuses -- confirm against the SDK docs).
estimator.latest_training_job.rule_job_summary()

In [None]:
# With no arguments, %store lists every variable currently persisted.
%store