In [27]:
!pip install -U sagemaker



In [28]:
import sagemaker
from sagemaker import get_execution_role

# IAM role ARN that the SageMaker training job launched below will assume.
# NOTE(review): this lookup typically only succeeds when running inside SageMaker
# (notebook instance / Studio); elsewhere, assign an explicit role ARN to `role`.
role = sagemaker.get_execution_role()

In [29]:
# AWS region used to build the boto3 / SageMaker sessions in the next cell.
region = "us-east-1"

In [30]:
import boto3

# Pin the underlying boto3 session to `region`; the SageMaker session wraps it,
# so SageMaker API calls made through `sagemaker_session` target that region.
boto_session = boto3.Session(region_name=region)
sagemaker_session = sagemaker.Session(
    boto_session=boto_session,
)

In [31]:
# S3 layout used by this notebook:
#   s3://<s3_bucket>/<s3_folder>/...                      -> job output and checkpoints
#   s3://<s3_bucket>/<dataset_folder>/<dataset_name>      -> training data channel
s3_folder = "gpt2_deepspeed_training"
dataset_folder = "datasets"
dataset_name = "wikitext"
s3_bucket = sagemaker_session.default_bucket()

# Forwarded to the training script as the `trainer.resume` hyperparameter.
resume = True

In [None]:
# Fetch the Weights & Biases API key from AWS Secrets Manager.
# The client is created from `boto_session` so the lookup happens in the same
# `region` as the rest of the notebook, instead of whatever default region the
# environment has (or failing with NoRegionError when none is configured).
secrets = boto_session.client("secretsmanager")
wandb_api_key = secrets.get_secret_value(SecretId="wandb_api_key")["SecretString"]

In [36]:
from sagemaker.pytorch.estimator import PyTorch

# Training job definition: 2 x ml.g5.xlarge, one training process per host.
# SageMaker syncs `checkpoint_local_path` with `checkpoint_s3_uri`; presumably this
# is what lets the script resume a previous run when `trainer.resume` is true.
estimator = PyTorch(
    entry_point='main.py',
    source_dir='gpt2_distributed_training/',
    role=role,
    instance_count=2,
    instance_type='ml.g5.xlarge',
    output_path=f's3://{s3_bucket}/{s3_folder}/sagemaker_output/',
    checkpoint_s3_uri=f"s3://{s3_bucket}/{s3_folder}/outputs",
    checkpoint_local_path="/opt/ml/checkpoints",
    # NOTE(review): a requirements.txt at the root of `source_dir` is normally
    # installed automatically, so this entry may be redundant — confirm.
    dependencies=["gpt2_distributed_training/requirements.txt"],
    distribution={
        "torch_distributed": {
            "enabled": True,
            "processes_per_host": 1,     # 1 GPU on each ml.g5.xlarge
        }
    },
    environment={
        'HYDRA_FULL_ERROR': '1',
        # NOTE(review): `environment` values are stored in the training-job
        # definition and are visible via DescribeTrainingJob / the console.
        # Consider having the training script read the secret itself using the
        # job's IAM role instead of passing it in plaintext here.
        "WANDB_API_KEY": wandb_api_key,
    },
    hyperparameters={
        "trainer": "deepspeed",
        # Must match the input channel name used in `estimator.fit` (the channel
        # is mounted at /opt/ml/input/data/<channel name>).
        "dataset.train_dataset.data_root": f"/opt/ml/input/data/{dataset_name}",
        "dataset.val_dataset.data_root": f"/opt/ml/input/data/{dataset_name}",
        "trainer.resume": resume,
    },
    py_version='py311',
    framework_version='2.5.1',
    # enable_sagemaker_metrics=True,
)

In [None]:
# Launch the training job (blocks until it finishes).
# The channel name decides the mount point inside the container
# (/opt/ml/input/data/<channel>), so it is keyed by `dataset_name` to stay in
# sync with the `data_root` hyperparameters passed to the estimator.
estimator.fit({
    dataset_name: f's3://{s3_bucket}/{dataset_folder}/{dataset_name}',
})