## SageMaker training with FSx

This sample shows how to:

- Set up FSx
- Associate data in S3 with FSx
- Run a SageMaker training job using data from the FSx mount
- Save artifacts into FSx, which are automatically pushed to S3
- Tear down the infrastructure

Check out this [blog](https://aws.amazon.com/blogs/machine-learning/choose-the-best-data-source-for-your-amazon-sagemaker-training-job/) to verify whether FSx is needed for your use case, so that you can save on operational costs.

**Please make sure the CIDR block in setup/cfn-nlp.yaml does not conflict with your existing VPC. You can also change the FSx storage capacity (currently set at 1.2 TB) depending on your datasets.**

In [None]:
# Imports
import os
import time
import boto3
from sagemaker import get_execution_role
from sagemaker.pytorch import PyTorch
from sagemaker.inputs import FileSystemInput

# Inputs
# NOTE: these must be defined before the clients below, because the clients
# are created with `region_name=region`.
region = "us-east-1"  # update this if your region is different
region_az = "us-east-1c"  # customize this as needed. Your FSx will be set up in a subnet in this AZ
cfn_stack_name = 'fsx-training'  # cloudformation stack name

s3_data_bucket = 's3://sagemaker-us-east-1-988346548731'  # s3 bucket for training artifacts and datasets
s3_data_train_prefix = 'train'  # s3 training data set
s3_data_model_prefix = 'model_dir' # s3 path to save model
s3_data_checkpoint_prefix = 'checkpoint_dir'  # s3 path to save model checkpoints
fsx_file_system_path = 'fsx-train'  # this is file system path on FSx for the data, can be anything

# Clients (both bound to the region chosen above)
cfn_client = boto3.client("cloudformation", region_name=region)
fsx_client = boto3.client("fsx", region_name=region)

In [None]:
# Launch the CloudFormation stack described by setup/cfn-nlp.yaml,
# passing the chosen availability zone as the template's 'AZ' parameter.
with open("setup/cfn-nlp.yaml", "r") as f:
    template_body = f.read()

stack_parameters = [{'ParameterKey': 'AZ', 'ParameterValue': region_az}]

create_stack_response = cfn_client.create_stack(
    StackName=cfn_stack_name,
    TemplateBody=template_body,
    Parameters=stack_parameters,
)

# Display the raw API response as the cell output
create_stack_response

In [None]:
# Poll CloudFormation every 30 seconds until the stack leaves
# CREATE_IN_PROGRESS; creation takes ~10 minutes to complete.
stack_id = create_stack_response['StackId']

status = "CREATE_IN_PROGRESS"
while status == "CREATE_IN_PROGRESS":
    response = cfn_client.describe_stacks(StackName=stack_id)
    status = response['Stacks'][0]['StackStatus']
    if status == "CREATE_IN_PROGRESS":
        print("Create in progress. Waiting..")
        time.sleep(30)

# Any terminal status other than CREATE_COMPLETE is reported as a failure
if status == "CREATE_COMPLETE":
    print("Stack created!")
else:
    print("Error creating stack - check the CFN console")

In [None]:
# Read the resource IDs that the stack exposes through its outputs
describe_response = cfn_client.describe_stacks(StackName=stack_id)
outputs = describe_response['Stacks'][0]['Outputs']

# Pick out the three outputs used by the rest of the notebook
for entry in outputs:
    key = entry['OutputKey']
    if key == 'sg':
        sec_group = entry['OutputValue']
    elif key == 'outputfsx':
        fsx_id = entry['OutputValue']
    elif key == 'privatesubnet':
        private_subnet_id = entry['OutputValue']

# Look up the Lustre mount name of the file system; it is used later to
# build the directory path for the training input
fsx_response = fsx_client.describe_file_systems(FileSystemIds=[fsx_id])
lustre_config = fsx_response['FileSystems'][0]['LustreConfiguration']
fsx_mount = lustre_config['MountName']

print("FSx ID:", fsx_id)
print("Security Group ID:", sec_group)
print("Private Subnet ID:", private_subnet_id)
print("FSx Mount path:", fsx_mount)

In [None]:
# Create a data repository association with S3 to load data
# and persist changes back to S3 to save training artifacts.
#
# AutoImportPolicy keeps the file system in sync with changes made in S3;
# AutoExportPolicy pushes files created/changed/deleted on FSx back to S3.
# Without the export policy, artifacts written to FSx by the training job
# would not be exported to the bucket automatically.
sync_events = ['NEW', 'CHANGED', 'DELETED']

fsx_s3_response = fsx_client.create_data_repository_association(
    FileSystemId=fsx_id,
    FileSystemPath=f"/{fsx_file_system_path}",
    DataRepositoryPath=s3_data_bucket,
    BatchImportMetaDataOnCreate=True,
    S3={
        "AutoImportPolicy": {
            "Events": sync_events
        },
        "AutoExportPolicy": {
            "Events": sync_events
        }
    }
)

# Display the raw API response as the cell output
fsx_s3_response

In [None]:
# Wait for association to be complete.
# Polls the association lifecycle every 30 seconds until it is no longer
# CREATING; any status other than AVAILABLE is reported as an error.
association_id = fsx_s3_response['Association']['AssociationId']

while True:
    fsx_s3_assoc = fsx_client.describe_data_repository_associations(
        AssociationIds=[association_id]
    )
    fsx_status = fsx_s3_assoc['Associations'][0]['Lifecycle']

    if fsx_status == "CREATING":
        print("Create in progress. Waiting..")
        time.sleep(30)
    elif fsx_status == "AVAILABLE":
        print("FSx - S3 association complete!")
        break
    else:
        print("Error creating the association, with status", fsx_status)
        break

Now, we create a SageMaker training job that uses FSx as the training input. For details on the training job parameters, see the [CreateTrainingJob API](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html).

In [None]:
# Paths inside the training container: the FSx "train" channel is mounted at
# SM_TRAIN_DIR, so data and training artifacts can all live on FSx.
SM_TRAIN_DIR = "/opt/ml/input/data/train"  # path where fsx is mounted in the training container
hyperparameters = {
    "checkpoint-dir": f"{SM_TRAIN_DIR}/{s3_data_checkpoint_prefix}",
    "model-dir": f"{SM_TRAIN_DIR}/{s3_data_model_prefix}",
    "training-dir": f"{SM_TRAIN_DIR}/{s3_data_train_prefix}",
}

# FSx directory served to the job: /<mount name>/<file system path>
fsx_directory_path = f"/{fsx_mount}/{fsx_file_system_path}"

fsx_input = FileSystemInput(
    file_system_id=fsx_id,
    file_system_type='FSxLustre',
    directory_path=fsx_directory_path,
    # read-write, because model artifacts are saved back to fsx
    file_system_access_mode="rw",
)
data_channels = {"train": fsx_input}

In [None]:
# Configure the PyTorch estimator and start the training job
base_job_name = 'sagemaker-fsx-mount-sample'
instance_count = 1
instance_type = "ml.m5.xlarge"

estimator = PyTorch(
    # training script and framework
    entry_point="train.py",
    source_dir=os.getcwd(),
    framework_version="1.8.1",
    py_version="py36",
    hyperparameters=hyperparameters,
    # job identity and compute
    base_job_name=base_job_name,
    role=get_execution_role(),
    instance_type=instance_type,
    instance_count=instance_count,
    max_retry_attempts=30,
    # checkpoints are written to the FSx-backed path, not to an S3 URI
    checkpoint_s3_uri=None,
    checkpoint_local_path=hyperparameters["checkpoint-dir"],
    # networking: give SageMaker Training Jobs access to FSx resources in your Amazon VPC
    subnets=[private_subnet_id],
    security_group_ids=[sec_group],
)

estimator.fit(inputs=data_channels)

### Clean up resources

To avoid incurring costs, tear down the CloudFormation stack; this deletes the VPC, the FSx file system, and the associated resources.

In [None]:
# Tear down the CloudFormation stack created earlier in this notebook
delete_response = cfn_client.delete_stack(StackName=stack_id)

# Display the raw API response as the cell output
delete_response