## Setup Configuration file path

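If you have created a SageMaker managed MLflow tracking server, copy its ARN into the `mlflow-uri` hyperparameter of the `ModelTrainer` below, choose an experiment name for `mlflow-experiment-name`, and uncomment those lines together with `export-mlflow`.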

In [None]:
import os

# os.environ["AWS_PROFILE"] = "<aws_profile>"

In [None]:
import boto3
import sagemaker

In [None]:
# SageMaker session for this notebook, plus its default S3 bucket and the
# optional key prefix configured for that bucket (may be None/empty).
sagemaker_session = sagemaker.Session()
bucket_name = sagemaker_session.default_bucket()
default_prefix = sagemaker_session.default_bucket_prefix

# Build the STS client from the same boto3 session SageMaker uses, so the
# account ID looked up later and the region read from that session always
# come from one set of credentials/configuration.
sts = sagemaker_session.boto_session.client("sts")

#### Get ECR Image

In [None]:
# Cluster shape for the training job: two ml.g5.12xlarge instances.
instance_type, instance_count = "ml.g5.12xlarge", 2

# Display the selected instance type as the cell output.
instance_type

In [None]:
# Assemble the URI of the training image in this account's ECR registry.
region = sagemaker_session.boto_session.region_name
account_id = sts.get_caller_identity()["Account"]

# Use "sagemaker-cc-cluster-test" instead for the custom container.
repo_name = "sagemaker-cluster-test"
tag = "latest"

image_uri = "{}.dkr.ecr.{}.amazonaws.com/{}:{}".format(
    account_id, region, repo_name, tag
)

# Display the resolved image URI as the cell output.
image_uri

In [None]:
from sagemaker.modules.configs import (
    Compute,
    Networking,
    OutputDataConfig,
    SourceCode,
    StoppingCondition,
)
from sagemaker.modules.distributed import MPI
from sagemaker.modules.train import ModelTrainer

# Training code: run train_mpi.py from ./scripts, installing the extra
# dependencies listed in requirements.txt.
source_code = SourceCode(
    entry_script="train_mpi.py",
    source_dir="./scripts",
    requirements="requirements.txt",
)

# Compute for the job. keep_alive_period_in_seconds=1800 (30 minutes) asks
# SageMaker to keep the instances warm after the job ends (warm pool), so a
# follow-up run can reuse them.
compute_configs = Compute(
    instance_count=instance_count,
    instance_type=instance_type,
    keep_alive_period_in_seconds=1800,
)

# Optional VPC configuration. It only takes effect when passed to the
# ModelTrainer via its `networking=` argument (commented out by default).
# The IDs below are placeholders: replace them with your own subnet and
# security-group IDs. The two subnets are meant to be distinct (the original
# listed the same subnet twice).
networking = Networking(
    subnets=["subnet-0bb1c79de3EXAMPLE", "subnet-0bb1c79de4EXAMPLE"],
    security_group_ids=["sg-0a1b2c3d4e5f6g7h8"],
)

# Base name for the training job; it is also the last S3 path segment below.
job_name = "train-sagemaker-cluster-test"

# S3 destination for job output: s3://<bucket>[/<default_prefix>]/<job_name>.
# The prefix segment is included only when a default prefix is configured.
output_path = (
    f"s3://{bucket_name}/{default_prefix}/{job_name}"
    if default_prefix
    else f"s3://{bucket_name}/{job_name}"
)

# Define the ModelTrainer: ties together the image, code, compute, runtime
# limits, environment, script flags and output location for the job.
model_trainer = ModelTrainer(
    training_image=image_uri,
    source_code=source_code,
    base_job_name=job_name,
    compute=compute_configs,
    # Uncomment to run the job inside your VPC using the `networking` config.
    # networking=networking,
    # MPI launches this many processes per node. Keep it equal to the number
    # of GPUs per instance (presumably 4 on ml.g5.12xlarge; re-check if you
    # change instance_type).
    distributed=MPI(process_count_per_node=4),
    # Hard stop after 18000 s (5 hours).
    stopping_condition=StoppingCondition(max_runtime_in_seconds=18000),
    environment={
        "FI_PROVIDER": "efa",  # If instance type supports EFA
        # "FI_EFA_USE_DEVICE_RDMA": "1",  # If RDMA is supported
        "NCCL_PROTO": "simple",
        "NCCL_DEBUG": "INFO",  # See NCCL logs to confirm EFA usage
    },
    # Flags passed to train_mpi.py; presumably each one toggles a check
    # implemented by the script (confirm against scripts/train_mpi.py).
    hyperparameters={
        "gpu-check": "true",
        "system-check": "true",
        "network-check": "true",
        "dcgm-level1": "true",
        "dcgm-level3": "true",
        # To export results to SageMaker managed MLflow, uncomment the three
        # lines below and replace the placeholder with your tracking-server ARN.
        # "export-mlflow": "true",
        # "mlflow-uri": "arn:aws:sagemaker:<region>:<account-id>:mlflow-tracking-server/<tracking-server-name>",
        # "mlflow-experiment-name": "sagemaker-cluster-test",
    },
    # Write job outputs under output_path without compressing them.
    output_data_config=OutputDataConfig(
        s3_output_path=output_path,
        compression_type="NONE",
    ),
)

In [None]:
# Submit the training job. wait=False means the cell does not block until the
# job finishes (per the ModelTrainer.train docs); follow progress in the console.
model_trainer.train(wait=False)