In [None]:
import os
#from azure.ai.ml import MLClient, Input, MpiDistribution, command
from azure.ai.ml import MLClient, Input, Output, PyTorchDistribution, command
from azure.ai.ml.entities import AmlCompute, Environment, BuildContext, Data
from azure.identity import DefaultAzureCredential
from azure.ai.ml.constants import AssetTypes
from dotenv import load_dotenv
# Load settings from a local .env file; values there override variables that
# are already set in the process environment.
load_dotenv(override=True)

# Azure ML workspace configuration (read from the environment / .env file)
SUBSCRIPTION_ID = os.getenv("SUBSCRIPTION_ID")
RESOURCE_GROUP = os.getenv("RESOURCE_GROUP")
WORKSPACE_NAME = os.getenv("WORKSPACE_NAME")
COMPUTE_CLUSTER = "demo-gpucluster01"

# Fail fast with a clear message: os.getenv returns None for unset variables,
# and passing None (or an empty string) on to MLClient would only surface
# later as a harder-to-diagnose error from the service calls below.
missing_settings = [
    setting_name
    for setting_name, setting_value in (
        ("SUBSCRIPTION_ID", SUBSCRIPTION_ID),
        ("RESOURCE_GROUP", RESOURCE_GROUP),
        ("WORKSPACE_NAME", WORKSPACE_NAME),
    )
    if not setting_value
]
if missing_settings:
    raise RuntimeError(
        "Missing required environment variable(s): "
        + ", ".join(missing_settings)
        + ". Define them in the environment or in a .env file."
    )

# No credentials are hard-coded here: DefaultAzureCredential resolves them at
# runtime from the surrounding environment (e.g. service principal variables,
# a managed identity, or a developer login).
ml_client = MLClient(DefaultAzureCredential(), SUBSCRIPTION_ID, RESOURCE_GROUP, WORKSPACE_NAME)

# Check that the compute cluster can be retrieved from the workspace.
# This cell only verifies the cluster; it does not create one.
try:
    ml_client.compute.get(COMPUTE_CLUSTER)
except Exception as lookup_error:
    # The lookup can fail for reasons other than a missing cluster (for
    # example authentication or connectivity problems), so report the
    # underlying error instead of assuming "not found". The cell stays
    # non-fatal so the rest of the notebook can still be inspected.
    print(
        f"Could not retrieve compute cluster '{COMPUTE_CLUSTER}': "
        f"{type(lookup_error).__name__}: {lookup_error}"
    )

In [None]:
# Job sizing: number of nodes, and number of processes started on each node.
NUM_NODES = 1
NUM_GPU_PER_NODE = 1

# Distributed-launch settings handed to the job definition.
# NOTE(review): the node count is also passed to the job as instance_count;
# confirm that PyTorchDistribution actually uses the node_count keyword.
dist = PyTorchDistribution(process_count_per_instance=NUM_GPU_PER_NODE, node_count=NUM_NODES)

# Preprocessing job: runs megatron_lm/tools/preprocess_data.py on the wiki dump
# and stores the results in the "indexed" output folder on blob storage.
job = command(
    code="./azureml",
    # Shell pipeline; the steps are chained with `&&`, so a failing step
    # prevents the following ones from running:
    #   1. run preprocess_data.py with output prefix <indexed>/wikidump
    #   2. copy the raw input file into the output folder
    #   3. rename wikidump_text_document.{bin,idx} to wikidump.jsonl.{bin,idx}
    # The command is assembled from adjacent string literals only. Backslash
    # continuations inside a literal would embed the source indentation into
    # the command and break if whitespace ever followed a backslash.
    # NOTE(review): --tokenizer-model receives the model_dir *folder* input and
    # the tokenizer type is Llama2Tokenizer while the asset is named llama3-8b;
    # confirm preprocess_data.py accepts this combination.
    command=(
        "python megatron_lm/tools/preprocess_data.py "
        "--input ${{inputs.train_data}} "
        "--output-prefix ${{outputs.indexed}}/wikidump "
        "--tokenizer-type Llama2Tokenizer "
        "--tokenizer-model ${{inputs.model_dir}} "
        "--workers 1 && "
        "cp ${{inputs.train_data}} ${{outputs.indexed}} && "
        "mv ${{outputs.indexed}}/wikidump_text_document.bin ${{outputs.indexed}}/wikidump.jsonl.bin && "
        "mv ${{outputs.indexed}}/wikidump_text_document.idx ${{outputs.indexed}}/wikidump.jsonl.idx"
    ),
    inputs={
        # Raw text dump, referenced as the latest version of a registered asset.
        "train_data": Input(
            type=AssetTypes.URI_FILE,
            path="wiki_dump@latest"
        ),
        # Model folder asset; passed to the script as the tokenizer model path.
        "model_dir": Input(
            type=AssetTypes.URI_FOLDER,
            path="llama3-8b@latest"
        )
    },
    outputs={
        "indexed": Output(
            type=AssetTypes.URI_FOLDER,
            path="azureml://datastores/workspaceblobstore/paths/wiki-indexed-dataset1/",
            mode="rw_mount"
        )                      # can be mounted by the next job
    },
    environment="llama3-8b-wiki_env@latest",
    compute=COMPUTE_CLUSTER,
    instance_count=NUM_NODES,
    # NOTE(review): with a distribution configured, the same shell pipeline is
    # presumably launched once per process; confirm this is intended before
    # raising NUM_GPU_PER_NODE above 1, since all processes would write the
    # same output files.
    distribution=dist,
    environment_variables={
        "LOGLEVEL": "INFO",
        "NCCL_DEBUG": "WARN",
        # NOTE(review): "WARN" looks like a log level rather than an NCCL
        # subsystem name; verify the intended NCCL_DEBUG_SUBSYS value.
        "NCCL_DEBUG_SUBSYS": "WARN",
        "PYTHONFAULTHANDLER": "1",
        "CUDA_LAUNCH_BLOCKING": "0"
    },
    display_name="llama3-8b-wiki-index",
    experiment_name="llama3-8b-wiki-index-exp"
)

In [None]:
# Send the job definition to the workspace, then report the job name and the
# Studio URL where its progress can be followed.
returned_job = ml_client.jobs.create_or_update(job)
print(
    f"Job submitted: {returned_job.name}",
    f"Monitor at: {returned_job.studio_url}",
    sep="\n",
)