# Example notebook for using Ray with Amazon SageMaker Training Jobs

This notebook describes our best practices for using Ray distributed training in Amazon SageMaker training jobs.

The cell below shows the basic structure you should use for your training code. It shows:

* How to apply logging
* How to set up MLflow
* How to detect and set the workers and GPUs

Remember, we use PyTorch Lightning as our preferred framework.

In [None]:
%%writefile ../scripts/train.py
import os
import json
import logging

import ray
import mlflow
import mlflow.pytorch
import sagemaker_training.environment


# Root logging setup: INFO level, with timestamp, logger name and level on every record
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by the functions in this script
logger = logging.getLogger(__name__)


def train_func(config):
    """Template for the training function; extend it with your own code.

    Resolves the training-data and model-output locations from the
    ``SM_CHANNEL_TRAIN`` and ``SM_MODEL_DIR`` environment variables, using
    local fallback paths when a variable is not set.

    Args:
        config: Training configuration. Not read by the template itself.
    """
    # Local fallbacks apply when the SageMaker variables are absent
    model_path = os.getenv("SM_MODEL_DIR", "./model/")
    data_path = os.getenv("SM_CHANNEL_TRAIN", ".")

    # Continue your training code here


def setup_workers(env):
    """Derive the number of training workers from the resources Ray reports.

    One worker is used per available GPU. When Ray reports no GPUs, one worker
    is used per host instead: in a heterogeneous cluster these are the hosts
    of every instance group except the current one; in a homogeneous cluster
    they are ``env.hosts``.

    Args:
        env: SageMaker training environment object. This function reads
            ``num_cpus``, ``is_hetero``, ``instance_groups_dict``,
            ``current_instance_group`` and ``hosts`` from it.

    Returns:
        A ``(num_workers, num_gpus)`` tuple of ints.
    """
    # Query Ray once so the GPU and CPU counts come from the same snapshot
    resources = ray.available_resources()
    num_gpus = int(resources.get("GPU", 0))
    num_cpus = int(resources.get("CPU", env.num_cpus))
    logger.info(f"Found {num_gpus} GPUs, {num_cpus} CPUs")

    if env.is_hetero:
        logger.info("Heterogeneous cluster detected")
        # Hosts of every instance group other than the one this node belongs to
        hosts = [
            host
            for group in env.instance_groups_dict.values()
            if group["instance_group_name"] != env.current_instance_group
            for host in group["hosts"]
        ]
    else:
        logger.info("Homogeneous cluster detected")
        hosts = env.hosts

    # GPU cluster: one worker per GPU. CPU-only cluster: one worker per host.
    num_workers = num_gpus if num_gpus > 0 else len(hosts)
    if num_workers < 1:
        # Surface the problem here; the returned value itself is left unchanged
        logger.warning("No GPUs and no worker hosts found; number of workers is 0")

    logger.info(f"Number of workers: {num_workers}")

    return num_workers, num_gpus


def main():
    """Read the job parameters, size the workers, and open an MLflow run.

    Hyperparameters are parsed from the ``SM_HPS`` environment variable, the
    worker and GPU counts come from ``setup_workers``, and MLflow is pointed
    at the tracking server and experiment named by ``MLFLOW_TRACKING_ARN``
    and ``MLFLOW_EXPERIMENT_NAME``.

    Raises:
        KeyError: If ``SM_HPS``, ``MLFLOW_TRACKING_ARN`` or
            ``MLFLOW_EXPERIMENT_NAME`` is missing from the environment.
    """
    logger.info("Fetching parameters")
    hyperparams = json.loads(os.environ["SM_HPS"])
    env = sagemaker_training.environment.Environment()
    num_workers, num_gpus = setup_workers(env)

    logger.info("Initializing MLflow")
    mlflow.enable_system_metrics_logging()
    mlflow.autolog()
    try:
        # Index the environment directly: a missing variable must fail here
        # instead of being turned into the literal string "None"
        mlflow_arn = os.environ["MLFLOW_TRACKING_ARN"]
        mlflow_experiment = os.environ["MLFLOW_EXPERIMENT_NAME"]
    except KeyError as error:
        logger.error(f"Can't fetch MLflow details: {error}", exc_info=True)
        raise
    mlflow.set_tracking_uri(mlflow_arn)
    mlflow.set_experiment(mlflow_experiment)

    logger.info("Starting training run")
    with mlflow.start_run():
        # Continue your code here
        pass  # placeholder so the template is valid Python until code is added


if __name__ == "__main__":
    # Script entry point: log any failure with its traceback, then re-raise
    # so the error still propagates to the caller
    try:
        main()
    except Exception as error:
        logger.exception(f"Training failed: {error}")
        raise

Because we use Amazon SageMaker Unified Studio, our project has a designated S3 bucket and prefix that we must use. Note: SageMaker sessions also differ between SageMaker AI Studio and SageMaker Unified Studio! The correct import for Unified Studio is shown in the cell below.

We use the PyTorch 2.7.1 training image by default.

In [None]:
from sagemaker import image_uris
from sagemaker.modules import Session

# NOTE(review): `project` is not defined in this cell; presumably it is the
# Unified Studio project object created earlier in the notebook (confirm)
s3_uri = project.s3.root

# partition() also handles a root URI that has no key prefix, where
# split('/', 1) would fail to unpack into two names
bucket, _, prefix = s3_uri.replace("s3://", "").partition("/")
sagemaker_session = Session(
    default_bucket=bucket,
    default_bucket_prefix=prefix or None,
)

instance_type = "ml.g4dn.xlarge"

# `image_uris` is imported explicitly above: importing only `Session` from
# `sagemaker.modules` does not bind the name `sagemaker` in this notebook
image_uri = image_uris.retrieve(
    framework="pytorch",
    region=sagemaker_session.boto_session.region_name,
    version="2.7.1",
    instance_type=instance_type,
    image_scope="training",
)

Although there are multiple ways to create a SageMaker training job, we prefer using the newer ModelTrainer class. See the example below.

In [None]:
from sagemaker.modules.configs import (
    Compute,
    OutputDataConfig,
    SourceCode,
    StoppingCondition,
    InputData,
)
from sagemaker.modules.train import ModelTrainer
from datetime import datetime, timezone

# Timezone-aware replacement for datetime.utcnow(), which is deprecated since
# Python 3.12; the formatted UTC timestamp is the same
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
job_name = f"ray-training-job-{timestamp}"
# Drop any trailing slash so the joined S3 path never contains "//"
output_path = f"{project.s3.root.rstrip('/')}/{job_name}"

# Source code: the local scripts directory, its requirements file, and the entry script
source_code = SourceCode(
    entry_script="launcher.py",
    requirements="requirements.txt",
    source_dir="../scripts",
)

# Compute: three instances of the chosen type, with no keep-alive period
compute_configs = Compute(
    instance_count=3,
    instance_type=instance_type,
    keep_alive_period_in_seconds=0,
)

# Environment variables for the job; train.py reads the two MLFLOW_* values
# NOTE(review): `mlflow_server_arn` is not defined in the cells shown here;
# confirm it is set before running this cell
job_environment = {
    "entry_script": "train.py",
    "MLFLOW_TRACKING_ARN": mlflow_server_arn,
    "MLFLOW_EXPERIMENT_NAME": "initial-tests"
}

# Hyperparameters passed to the training job
job_hyperparameters = {
    "n_factors": 10,
    "sample_size": 500,
    "epochs": 20,
}

# Assemble the ModelTrainer from the pieces above
model_trainer = ModelTrainer(
    sagemaker_session=sagemaker_session,
    training_image=image_uri,
    source_code=source_code,
    base_job_name=job_name,
    compute=compute_configs,
    environment=job_environment,
    hyperparameters=job_hyperparameters,
    # 10800 seconds = 3 hours maximum runtime
    stopping_condition=StoppingCondition(max_runtime_in_seconds=10800),
    output_data_config=OutputDataConfig(s3_output_path=output_path),
)