In [None]:
# install once
!pip install -U boto3 sagemaker awscli
# restart jupyter kernel

In [None]:
# Session / identity setup shared by every later cell.
import os

import boto3
import sagemaker
from sagemaker import get_execution_role

# SageMaker session (wraps a boto3 session) used to submit the training job.
sess = sagemaker.Session()
# IAM role the training job is launched with. get_execution_role() is intended for
# SageMaker-hosted notebooks. NOTE(review): confirm where this notebook is run.
role = get_execution_role()
# SDK-provided default bucket; every S3 path in this notebook is built under it.
sagemaker_default_bucket = sess.default_bucket()
# Region of the underlying boto3 session; selects the regional DLC image registry.
region = sess.boto_session.region_name

In [None]:
from sagemaker.pytorch import PyTorch  # not used in the visible cells; the generic Estimator is used below
from sagemaker.estimator import Estimator

# --- Container image -------------------------------------------------------
# AWS Deep Learning Containers; available tags are listed at
# https://github.com/aws/deep-learning-containers/blob/master/available_images.md
# Previous image, kept for reference:
# image_uri = f'763104351884.dkr.ecr.{region}.amazonaws.com/pytorch-training:2.2.0-gpu-py310-cu121-ubuntu20.04-sagemaker'
dlc_registry = f'763104351884.dkr.ecr.{region}.amazonaws.com'
image_uri = dlc_registry + '/pytorch-training:2.3.0-gpu-py311-cu121-ubuntu20.04-sagemaker'

# --- Compute ---------------------------------------------------------------
# Alternatives (GPUs * type, memory per GPU):
#   "ml.g5.12xlarge"   4 * A10g (24G/GPU)
#   "ml.g5.48xlarge"   8 * A10g (24G/GPU)
#   "ml.p4d.24xlarge"  8 * A100 (40G/GPU)
#   "ml.p5.48xlarge"   8 * H100 (80G/GPU)
instance_type = "ml.g5.2xlarge"     # 1 * A10g (24G/GPU)
instance_count = 1                  # 1 or Multi-node

# --- Job environment -------------------------------------------------------
# S3 locations handed to the container as environment variables
# (presumably consumed by submit_src/estimator_entry.py; TODO confirm).
s3_root = f's3://{sagemaker_default_bucket}'
envs = {
    "DATA_S3_PATH": s3_root + '/qwen2-train-dataset/*',
    'MODEL_ID_OR_S3_PATH': s3_root + '/Qwen2-0.5B-Instruct/*',
    'MODEL_SAVE_PATH_S3': s3_root + '/output-model/2408/',
}

# No hyperparameters are passed; the entry script is configured through `envs`.
hypers = {}

smp_estimator = Estimator(
    # identity / session
    role=role,
    sagemaker_session=sess,
    # code & image
    image_uri=image_uri,
    source_dir='submit_src/',
    entry_point="estimator_entry.py",
    base_job_name='sm-qwen2-multinode',
    # compute
    instance_type=instance_type,
    instance_count=instance_count,
    max_run=2 * 60 * 60,                  # hard job timeout in seconds (2 hours)
    keep_alive_period_in_seconds=60,      # managed warm pool: keep instances briefly for reuse
    enable_remote_debug=True,             # remote debugging into the training container; consider disabling for unattended runs
    # configuration
    environment=envs,
    hyperparameters=hypers,
    disable_output_compression=True,      # upload model output uncompressed (no model.tar.gz)
)

# Blocks until the training job finishes.
smp_estimator.fit()

In [None]:
!echo s3://$sagemaker_default_bucket/output-model/2408/

In [None]:
!aws s3 ls s3://$sagemaker_default_bucket/output-model/2408/

In [None]:
!aws s3 rm --recursive s3://$sagemaker_default_bucket/output-model/2408/checkpoint-10/