In [None]:
%%time
! python3 -m pip install --upgrade sagemaker
import sagemaker

In [None]:
# S3 bucket used both as the session's default bucket and for the data/output
# paths built from `bucket` later in the notebook (single source of truth).
bucket = 'cvpr-derrick'
sagemaker_session = sagemaker.Session(default_bucket=bucket)

# IAM role ARN the training job will assume. get_execution_role() is intended
# for SageMaker-managed notebooks/Studio — presumably where this runs.
role = sagemaker.get_execution_role()
# The role name is the last "/"-separated segment of the ARN
# (e.g. ".../service-role/<RoleName>" -> "<RoleName>").
role_name = role.split("/")[-1]
print(f"The Amazon Resource Name (ARN) of the role used for this demo is: {role}")
print(f"The name of the role used for this demo is: {role_name}")

In [None]:
# Setting up File-system to import data from S3
# Input prefix (training data) and output prefix (job artifacts) under the bucket.
train_data_s3 = f"s3://{bucket}/ImageNet/lmdb"
train_output_s3 = f"s3://{bucket}/Output_MAE"

# One channel named "train": every object under the prefix is used as input,
# delivered to the instance in File mode (copied locally before the script runs).
data_channels = {
    'train': sagemaker.inputs.TrainingInput(
        s3_data=train_data_s3,
        s3_data_type='S3Prefix',
        input_mode='File',
    ),
}
print(data_channels, train_output_s3)

In [None]:
# Supported instance types: ml.p3.16xlarge, ml.p4d.24xlarge
instance_type = "ml.p4d.24xlarge"
# Number of training instances (nodes); e.g. 2, 4, 8.
instance_count = 2
# YOUR_ECR_IMAGE_BUILT_WITH_ABOVE_DOCKER_FILE — replace with your own ECR image URI.
docker_image = "509553677659.dkr.ecr.us-east-1.amazonaws.com/derrick-smdataparallel-sagemaker:1.0"
# Prefix for the SageMaker training job name. It is derived from instance_count
# so the node count in the name stays correct when instance_count changes, and
# it makes the job easy to find in the SageMaker Training jobs console.
job_name = f"MAE-pytorch-{instance_count}node"

In [None]:
# Seconds in one day (86,400).
ONE_DAY = 60 * 60 * 24

# Hyperparameters handed to the entry-point script. In SageMaker script mode
# these arrive as command-line arguments (`--key value`).
# NOTE(review): booleans reach the script as strings ("True"); if it parses them
# with argparse `type=bool`, any non-empty value is truthy — confirm before
# setting one to False.
hyperparameters = dict(
    # NOTE(review): the "normlize" spelling is kept deliberately; presumably it
    # matches the training script's CLI flag — check main_unsup.py before renaming.
    normlize_target=True,
    mask_ratio=0.75,
    model="pretrain_mae_base_patch16_224",
    batch_size=512,
    num_workers=8,
    opt="adamw",
    warmup_epochs=40,
    epochs=800,
    save_ckpt_freq=80,
    # NOTE(review): "bs8192" presumably means 512 per GPU x 8 GPUs x 2 nodes;
    # keep it in sync if batch_size or instance_count changes. This is also a
    # relative path — confirm the script writes it somewhere SageMaker uploads
    # (e.g. the model dir), otherwise checkpoints are lost when the job ends.
    output_dir="pretrain/mae_800ep_bs8192_base_size224_patch16_mask75_decdepth8_decdim512_norm_pos2d_mmseg",
)

In [None]:
import os
from sagemaker.pytorch import PyTorch

# Estimator describing the distributed MAE pre-training job.
estimator = PyTorch(
    # What to run: the script in source_dir, inside the custom container.
    entry_point="main_unsup.py",
    source_dir="./",
    image_uri=docker_image,
    # NOTE(review): with image_uri set, these presumably do not choose the
    # image — confirm they are still needed.
    framework_version="1.9.1",
    py_version="py38",
    # Where and how to run it.
    role=role,
    sagemaker_session=sagemaker_session,
    instance_type=instance_type,
    instance_count=instance_count,
    # Training using SMDataParallel Distributed Training Framework
    distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
    max_run=ONE_DAY * 5,  # 5 days, in seconds
    # Job naming, script arguments, and artifact destination.
    base_job_name=job_name,
    hyperparameters=hyperparameters,
    output_path=train_output_s3,
    debugger_hook_config=False,  # do not attach the SageMaker Debugger hook
)

In [None]:
# Launch the training job with the "train" channel. wait=True (the SDK default)
# keeps this cell running until the job finishes.
estimator.fit(inputs=data_channels, wait=True)

In [None]:
model_data = estimator.model_data
print("Storing {} as model_data".format(model_data))
%store model_data