Read this [blog post](https://aws.amazon.com/blogs/machine-learning/choose-the-best-data-source-for-your-amazon-sagemaker-training-job/) to decide whether your use case needs FSx; if it does not, a simpler data source can save operational costs. This sample shows how to:

- Set up FSx for Lustre
- Link data from S3 to FSx
- Run a SageMaker training job using data from the FSx mount
- Save artifacts to FSx, from where they are automatically exported to S3
- Tear down the infrastructure

Requirements (temporary: the values below currently need to be changed in setup/cfn-nlp.yaml)
- An existing VPC ID
- An existing internet gateway (IGW) ID
- CIDR blocks for the new public and private subnets
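Before choosing the subnet CIDR blocks, you can check them offline. Below is a minimal sketch using Python's standard ipaddress module; the function name check_new_subnets is hypothetical, 172.31.0.0/16 is the CIDR of a default VPC, and the list of existing subnets is an assumption you would fill in yourself, for example from `aws ec2 describe-subnets --filters Name=vpc-id,Values=<your VPC ID> --query "Subnets[].CidrBlock"`.

```python
import ipaddress


def check_new_subnets(vpc_cidr, existing_subnet_cidrs, new_subnet_cidrs):
    """Return a list of problems; an empty list means the new CIDRs look usable.

    Hypothetical helper: it only checks CIDR arithmetic, not AWS-side rules
    such as the addresses AWS reserves in every subnet.
    """
    vpc = ipaddress.ip_network(vpc_cidr)
    existing = [ipaddress.ip_network(c) for c in existing_subnet_cidrs]
    new = [ipaddress.ip_network(c) for c in new_subnet_cidrs]
    problems = []
    for i, net in enumerate(new):
        # Each new subnet must lie entirely inside the VPC CIDR range
        if not net.subnet_of(vpc):
            problems.append(f"{net} is outside the VPC range {vpc}")
        # ...must not overlap any subnet that already exists in the VPC
        for other in existing:
            if net.overlaps(other):
                problems.append(f"{net} overlaps existing subnet {other}")
        # ...and must not overlap the other new subnets
        for other in new[i + 1:]:
            if net.overlaps(other):
                problems.append(f"{net} overlaps new subnet {other}")
    return problems


# Example with a default-VPC layout (existing subnets are assumed, not queried)
existing_subnets = ["172.31.0.0/20", "172.31.16.0/20", "172.31.32.0/20", "172.31.48.0/20"]
print(check_new_subnets("172.31.0.0/16", existing_subnets, ["172.31.64.0/20", "172.31.80.0/20"]))
# → []
```

An empty list only means the blocks do not collide with the subnets you listed; subnets you forgot to include are not checked.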


**Make sure the CIDR blocks in setup/cfn-nlp.yaml do not conflict with your existing subnets. You can change the FSx storage capacity (currently set to 1.2 TB) depending on your data sets.**
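FSx for Lustre only accepts certain storage capacities. As a sketch, assuming an SSD-based deployment type (PERSISTENT_1, PERSISTENT_2 or SCRATCH_2), where the valid capacities are 1,200 GiB, 2,400 GiB and then increments of 2,400 GiB, the helper below rounds a dataset size up to the next valid value. The function name and the 20% headroom for checkpoints and model artifacts are hypothetical choices; check the deployment type actually used in setup/cfn-nlp.yaml before relying on it.

```python
import math

SSD_MIN_GIB = 1200   # smallest SSD capacity (assumed SSD deployment type)
SSD_STEP_GIB = 2400  # above the minimum, capacity grows in 2,400 GiB steps


def lustre_ssd_capacity_gib(dataset_gib, headroom=1.2):
    """Round dataset_gib * headroom up to a valid SSD FSx for Lustre capacity (GiB)."""
    if dataset_gib < 0:
        raise ValueError("dataset size must be non-negative")
    needed = dataset_gib * headroom
    if needed <= SSD_MIN_GIB:
        return SSD_MIN_GIB
    # 2,400 GiB and every multiple of it are valid above the minimum
    return math.ceil(needed / SSD_STEP_GIB) * SSD_STEP_GIB


print(lustre_ssd_capacity_gib(500))   # small dataset, fits in the minimum
print(lustre_ssd_capacity_gib(3000))  # about 3,600 GiB needed, rounded up to the next step
```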

In [1]:
# Inputs

# Infrastructure
cfn_stack_name = 'large-scale-training' # CloudFormation stack name
AWS_VPC_ID = 'vpc-4c715b34' # your VPC ID
AWS_IGW_ID = 'igw-0b168a72' # your internet gateway ID; if there is no outbound internet access, modify the CFN template to use an S3 VPC endpoint
AWS_PUBLIC_SUBNET_CIDR_BLOCK = '172.31.64.0/20' # new public subnet; modify the CFN template if you want to use an existing one
AWS_PRIVATE_SUBNET_CIDR_BLOCK = '172.31.80.0/20' # new private subnet; modify the CFN template if you want to use an existing one

# Application
s3_data_bucket = 's3://nlp-largescale-training' # your S3 bucket for training data and artifacts
s3_data_train_prefix = 'train' # S3 prefix of the training data set
fsx_file_system_path = 'gpt2' # file system path on FSx where the S3 data is linked; can be any name

Infrastructure setup
- Set up networking components and FSx
- Configure FSx and add a data repository association to load data from S3

**Make sure AWS_REGION and AWS_REGION_AZ in the next cell point to the region and Availability Zone you want to use.**

In [2]:
# Set up the infrastructure stack (networking and FSx)
!AWS_REGION=us-west-2 AWS_REGION_AZ=us-west-2c sh ./setup/stack-nlp.sh $cfn_stack_name $AWS_VPC_ID $AWS_IGW_ID $AWS_PUBLIC_SUBNET_CIDR_BLOCK $AWS_PRIVATE_SUBNET_CIDR_BLOCK

# Grab the security group, FSx ID and private subnet from the CFN outputs.
# Each output is queried by key, so the result does not depend on the order CFN returns them in.
def get_stack_output(output_key):
    tmp = !aws cloudformation describe-stacks --stack-name $cfn_stack_name --query "Stacks[0].Outputs[?OutputKey=='{output_key}'].OutputValue" --no-paginate --output text
    return tmp.s.strip()

security_grp = get_stack_output('sg')
fsx_id = get_stack_output('fsx')
private_subnet_id = get_stack_output('privatesubnet')

# Grab the FSx mount name
tmp = !aws fsx describe-file-systems --file-system-ids $fsx_id --no-paginate --query "FileSystems[0].LustreConfiguration.MountName" --output text
fsx_mount_name = str(tmp.s)

{
    "StackId": "arn:aws:cloudformation:us-west-2:423047559695:stack/large-scale-training/eb4fd710-a170-11ec-94bf-065c86bf2b8f"
}
Creating stack [ eta 600 seconds ]
Stack large-scale-training status: Create in progress: [ 31 secs elapsed ]
Stack large-scale-training status: Create in progress: [ 62 secs elapsed ]
Stack large-scale-training status: Create in progress: [ 92 secs elapsed ]
Stack large-scale-training status: Create in progress: [ 123 secs elapsed ]
Stack large-scale-training status: Create in progress: [ 153 secs elapsed ]
Stack large-scale-training status: Create in progress: [ 184 secs elapsed ]
Stack large-scale-training status: Create in progress: [ 214 secs elapsed ]
Stack large-scale-training status: Create in progress: [ 245 secs elapsed ]
Stack large-scale-training status: Create in progress: [ 275 secs elapsed ]
Stack large-scale-training status: Create in progress: [ 306 secs elapsed ]
{
    "Stacks": [
        {
            "StackId": "arn:aws:cloudformation:us

In [3]:
# Configure FSx to load data from S3 and export changes back to S3, so training artifacts (model, checkpoints) are saved
fsx_data_assoc_cmd = f'create-data-repository-association --file-system-id {fsx_id} --file-system-path /{fsx_file_system_path} --data-repository-path {s3_data_bucket} \
                        --batch-import-meta-data-on-create --s3 "AutoExportPolicy={{Events=[NEW,CHANGED,DELETED]}}"'

!aws fsx $fsx_data_assoc_cmd

# make sure the association is created and available
def get_association_status(fsx_id):
    tmp = !aws fsx describe-data-repository-associations --no-paginate --filters "Name=file-system-id,Values={fsx_id}" --query Associations[0].Lifecycle
    status = str(tmp.s)
    return status

def wait_for_assoc_complete(fsx_id):
    import time
    status = get_association_status(fsx_id)
    while status == '"CREATING"':
        print(f'Waiting for s3 association in FSx, current status {status}  ...')
        time.sleep(20)
        status = get_association_status(fsx_id)
    if status != '"AVAILABLE"':
        raise SystemExit(f'Failed to create s3 associations in FSx, failure reason : {status}')
    print(f'Association {status}.')
    
wait_for_assoc_complete(fsx_id)

{
    "Association": {
        "AssociationId": "dra-0fddbb0325a40750d",
        "ResourceARN": "arn:aws:fsx:us-west-2:423047559695:association/fs-0243c1712a9f2b4b2/dra-0fddbb0325a40750d",
        "FileSystemId": "fs-0243c1712a9f2b4b2",
        "Lifecycle": "CREATING",
        "FileSystemPath": "/gpt2",
        "DataRepositoryPath": "s3://nlp-largescale-training",
        "BatchImportMetaDataOnCreate": true,
        "ImportedFileChunkSize": 1024,
        "S3": {
            "AutoExportPolicy": {
                "Events": [
                    "NEW",
                    "CHANGED",
                    "DELETED"
                ]
            }
        },
        "Tags": [],
        "CreationTime": 1647027033.847
    }
}
Waiting for s3 association in FSx, current status "CREATING"  ...
Waiting for s3 association in FSx, current status "CREATING"  ...
Waiting for s3 association in FSx, current status "CREATING"  ...
Waiting for s3 association in FSx, current status "CREATING"  ...
Associati
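The wait loop above polls with no upper bound while the association is CREATING. A bounded variant is sketched below; the function name wait_for_status, the 30-minute default timeout and the injectable sleep/clock parameters are assumptions made so the logic can be exercised without AWS. In the notebook it could wrap the existing helper, e.g. `wait_for_status(lambda: get_association_status(fsx_id).strip('"'))`, since that helper returns the status with its JSON quotes.

```python
import time


def wait_for_status(get_status, pending=("CREATING",), success="AVAILABLE",
                    interval=20.0, timeout=1800.0,
                    sleep=time.sleep, clock=time.monotonic):
    """Poll get_status() until it leaves `pending`; raise on timeout or on a non-success status."""
    deadline = clock() + timeout
    status = get_status()
    while status in pending:
        # Give up instead of polling forever
        if clock() >= deadline:
            raise TimeoutError(f"status still {status} after {timeout} seconds")
        print(f"Waiting, current status {status} ...")
        sleep(interval)
        status = get_status()
    # Anything other than the success status (e.g. a failure state) is an error
    if status != success:
        raise RuntimeError(f"finished with unexpected status {status}")
    return status
```

Raising TimeoutError and RuntimeError rather than SystemExit keeps the failure catchable if the cell is reused in a larger script.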

Sample SageMaker training job showing the parameters that must be passed for FSx [integration](https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html)
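For reference, the linked CreateTrainingJob API expresses an FSx channel as a FileSystemDataSource inside InputDataConfig, and the network settings as VpcConfig. The sketch below builds only those two fragments as plain dictionaries; the function name and example IDs are hypothetical, and a real request needs further fields (role, algorithm specification, resources, output location and so on) that the SageMaker Python SDK fills in for you.

```python
def fsx_training_job_fragments(fsx_id, directory_path, subnet_ids, security_group_ids,
                               channel_name="train", access_mode="rw"):
    """Build the FSx-related parts of a CreateTrainingJob request (not a complete request)."""
    if access_mode not in ("ro", "rw"):
        raise ValueError("access_mode must be 'ro' or 'rw'")
    input_channel = {
        "ChannelName": channel_name,
        "DataSource": {
            "FileSystemDataSource": {
                "FileSystemId": fsx_id,
                "FileSystemType": "FSxLustre",
                "FileSystemAccessMode": access_mode,
                # For FSx for Lustre this path starts with the file system's mount name
                "DirectoryPath": directory_path,
            }
        },
    }
    # The training job must run in subnets/security groups that can reach the file system
    vpc_config = {"Subnets": list(subnet_ids), "SecurityGroupIds": list(security_group_ids)}
    return {"InputDataConfig": [input_channel], "VpcConfig": vpc_config}


fragments = fsx_training_job_fragments(
    "fs-0123456789abcdef0", "/abcdefgh/gpt2",  # hypothetical file system ID and mount name
    subnet_ids=["subnet-0123456789abcdef0"], security_group_ids=["sg-0123456789abcdef0"])
print(fragments["InputDataConfig"][0]["DataSource"]["FileSystemDataSource"]["DirectoryPath"])
```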

In [4]:
import os
from sagemaker import get_execution_role
from sagemaker.pytorch import PyTorch
from sagemaker.inputs import FileSystemInput

# setup fsx config for data channels
fsx_directory_path = f'/{fsx_mount_name}/{fsx_file_system_path}'
fsx_input = FileSystemInput(
    file_system_id=fsx_id,
    file_system_type='FSxLustre',
    directory_path=fsx_directory_path,
    file_system_access_mode="rw", # write needed for saving model artifacts to fsx
)
data_channels = {"train": fsx_input}

# For simplicity, keep data, checkpoints and the model under the FSx channel so everything lives on FSx (and is exported to S3)
SM_TRAIN_DIR = "/opt/ml/input/data/train" # path where FSx is mounted in the training container
hyperparameters = {}
hyperparameters["checkpoint-dir"] = f"{SM_TRAIN_DIR}/checkpointdir"
hyperparameters["model-dir"] = f"{SM_TRAIN_DIR}/modeldir"
hyperparameters["training-dir"] = f"{SM_TRAIN_DIR}/{s3_data_train_prefix}"
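Because the data repository association links /gpt2 on FSx to the bucket root, and the "train" channel mounts that same directory at /opt/ml/input/data/train, a file written under the channel should appear at the matching key in S3 once auto export has run. A small, hypothetical helper that makes this mapping explicit, assuming the inputs defined earlier in this notebook:

```python
import posixpath


def container_path_to_s3_uri(container_path,
                             channel_dir="/opt/ml/input/data/train",
                             s3_repository_path="s3://nlp-largescale-training"):
    """Map a path inside the training container to the S3 URI it is exported to.

    Assumes channel_dir is the FSx directory linked to s3_repository_path by the
    data repository association (as set up earlier in this notebook).
    """
    if not posixpath.isabs(container_path):
        raise ValueError("container_path must be absolute")
    rel = posixpath.relpath(posixpath.normpath(container_path), channel_dir)
    # Paths outside the FSx channel (e.g. /opt/ml/model) are not exported by FSx
    if rel == "." or rel == ".." or rel.startswith("../"):
        raise ValueError(f"{container_path} is not under {channel_dir}")
    return s3_repository_path.rstrip("/") + "/" + rel


print(container_path_to_s3_uri("/opt/ml/input/data/train/modeldir/model.pth"))
# → s3://nlp-largescale-training/modeldir/model.pth
```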

In [None]:
# setup estimator and invoke
instance_type = "ml.p3.2xlarge"
instance_count = 1
base_job_name = 'sagemaker-fsx-mount-sample'

estimator = PyTorch(
    entry_point="train.py",
    source_dir=os.getcwd(),
    instance_type=instance_type,
    role=get_execution_role(),
    instance_count=instance_count,
    framework_version="1.8.1",
    py_version="py36",
    checkpoint_s3_uri=None, # not needed: checkpoints are written to FSx, which exports them to S3
    checkpoint_local_path=hyperparameters["checkpoint-dir"], # checkpoint directory on FSx
    hyperparameters=hyperparameters,
    base_job_name=base_job_name,
    subnets=[private_subnet_id], # give SageMaker training jobs access to FSx resources in your Amazon VPC
    security_group_ids=[security_grp],
    max_retry_attempts=30)

estimator.fit(inputs=data_channels)
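The contents of train.py are not shown here. As a hypothetical sketch of its entry side: SageMaker passes each hyperparameter to the script as a `--name value` command-line argument, so the script can read the FSx-backed directories with argparse. Because checkpoints live on FSx and max_retry_attempts allows retries, the script could also resume from the newest checkpoint it finds; the checkpoint-<step>.pt naming and the function names below are assumptions, not something this sample defines.

```python
import argparse
import os
import re

# Hypothetical checkpoint naming scheme: checkpoint-<step>.pt
CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)\.pt$")


def parse_args(argv=None):
    """Read the FSx-backed directories passed in as hyperparameters."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--training-dir", required=True)
    parser.add_argument("--model-dir", required=True)
    parser.add_argument("--checkpoint-dir", required=True)
    # Ignore any other arguments instead of failing on them
    args, _unknown = parser.parse_known_args(argv)
    return args


def pick_latest_checkpoint(filenames):
    """Return the checkpoint filename with the highest step, or None if none match."""
    best_step, best_name = -1, None
    for name in filenames:
        match = CHECKPOINT_RE.match(name)
        # Compare steps numerically so checkpoint-12 beats checkpoint-7
        if match and int(match.group(1)) > best_step:
            best_step, best_name = int(match.group(1)), name
    return best_name


def latest_checkpoint(checkpoint_dir):
    """Return the full path of the newest checkpoint in checkpoint_dir, or None."""
    if not os.path.isdir(checkpoint_dir):
        return None
    name = pick_latest_checkpoint(os.listdir(checkpoint_dir))
    return os.path.join(checkpoint_dir, name) if name else None
```

Inside the container, args.checkpoint_dir would be /opt/ml/input/data/train/checkpointdir, a directory on FSx, so checkpoints from an earlier attempt are still there when the job is retried.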

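Tear down the infrastructure. Training artifacts written to FSx are exported to S3 automatically by the data repository association; because the export is asynchronous, confirm they are in S3 before deleting the stack.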

In [None]:
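# Delete the stack and wait until the deletion has finished
!aws cloudformation delete-stack --stack-name $cfn_stack_name
!aws cloudformation wait stack-delete-complete --stack-name $cfn_stack_name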