In [None]:
import os

from azure.ai.ml import MLClient, Input, Output, PyTorchDistribution, command
from azure.ai.ml.constants import AssetTypes
from azure.ai.ml.entities import AmlCompute, Environment, BuildContext, Data
from azure.identity import DefaultAzureCredential
from dotenv import load_dotenv

# Load workspace / W&B settings from a local .env file; override=True lets values in the
# file take precedence over variables already exported in the shell.
load_dotenv(override=True)

# Azure ML workspace configuration
SUBSCRIPTION_ID = os.getenv("SUBSCRIPTION_ID")
RESOURCE_GROUP = os.getenv("RESOURCE_GROUP")
WORKSPACE_NAME = os.getenv("WORKSPACE_NAME")
COMPUTE_CLUSTER = "demo-gpucluster01"
# Weights & Biases settings (passed to the training job further down)
WANDB_API_KEY = os.getenv("WANDB_API_KEY")
WANDB_ENTITY = os.getenv("WANDB_ENTITY")

# Fail fast: otherwise MLClient is built with None identifiers and only errors on first use.
_missing_settings = [
    name
    for name, value in (
        ("SUBSCRIPTION_ID", SUBSCRIPTION_ID),
        ("RESOURCE_GROUP", RESOURCE_GROUP),
        ("WORKSPACE_NAME", WORKSPACE_NAME),
    )
    if not value
]
if _missing_settings:
    raise EnvironmentError(
        f"Missing required environment variables: {', '.join(_missing_settings)}"
    )

# Authenticate via DefaultAzureCredential (managed identity, service principal, CLI login, ...)
# so no credentials are hard-coded in the notebook.
ml_client = MLClient(DefaultAzureCredential(), SUBSCRIPTION_ID, RESOURCE_GROUP, WORKSPACE_NAME)

# Check that the compute cluster exists. This cell does NOT create it.
# The original called `get` inside `except Exception`; listing the cluster names instead
# lets auth/network errors surface rather than being misreported as "not found".
existing_clusters = {compute.name for compute in ml_client.compute.list()}
if COMPUTE_CLUSTER not in existing_clusters:
    print(
        f"Compute cluster '{COMPUTE_CLUSTER}' was not found in workspace '{WORKSPACE_NAME}'; "
        "create it before submitting the training job."
    )

#### Docker environment

In [None]:
# Root folder uploaded as the job's code snapshot; the Docker build context is its `train/` subfolder.
CLOUD_DIR = "./azureml"
# %%writefile does not create parent folders, so make sure the build-context folder exists
# before the Dockerfile / requirements.txt cells write into it.
os.makedirs(f"{CLOUD_DIR}/train", exist_ok=True)

In [None]:
%%writefile {CLOUD_DIR}/train/Dockerfile
FROM mcr.microsoft.com/aifx/acpt/stable-ubuntu2004-cu121-py310-torch22x:biweekly.202504.1

# Install pip dependencies
COPY requirements.txt .
RUN pip install -r requirements.txt --no-cache-dir

# Inference requirements: pull the o16n base assets, then install their system requirements
# and wire up the rsyslog / nginx configuration and the runit service directory.
COPY --from=mcr.microsoft.com/azureml/o16n-base/python-assets:20230419.v1 /artifacts /var/
RUN /var/requirements/install_system_requirements.sh && \
    cp /var/configuration/rsyslog.conf /etc/rsyslog.conf && \
    cp /var/configuration/nginx.conf /etc/nginx/sites-available/app && \
    ln -sf /etc/nginx/sites-available/app /etc/nginx/sites-enabled/app && \
    rm -f /etc/nginx/sites-enabled/default
ENV SVDIR=/var/runit
ENV WORKER_TIMEOUT=400
EXPOSE 5001 8883 8888

# Support the DeepSpeed launcher's requirement of passwordless SSH login.
# `update` and `install` share one layer so a cached `apt-get update` layer can never be
# paired with a stale package index; the apt lists are removed to keep the image smaller.
RUN apt-get update && \
    apt-get install -y openssh-server openssh-client && \
    rm -rf /var/lib/apt/lists/*


In [None]:
%%writefile {CLOUD_DIR}/train/requirements.txt
# Python dependencies installed into the training image by the Dockerfile (each listed once).
azureml-core==1.60.0
azureml-dataset-runtime==1.60.0
azureml-defaults==1.60.0
azure-ml==0.0.1
azure-ml-component==0.9.18.post2
azureml-mlflow==1.60.0
azureml-contrib-services==1.60.0
torch-tb-profiler~=0.4.0
azureml-inference-server-http
inference-schema
MarkupSafe==2.1.2
regex
urllib3>=1.26.18
cryptography>=42.0.4
aiohttp>=3.8.5
py-spy==0.3.12
debugpy~=1.6.3
ipykernel~=6.0
tensorboard
psutil~=5.8.0
matplotlib~=3.5.0
tqdm~=4.66.3
py-cpuinfo==5.0.0
# huggingface
transformers>=4.36.0
datasets
accelerate
optimum
peft
appdirs
loralib
scipy
py7zr  # 7z compression / decompression library
bitsandbytes
fire  # CLI argument parsing
# formatter & linter
black
flake8
# tokenizer
sentencepiece
# logging
wandb
# multi node
mpi4py
# megatron-lm (pybind11 is presumably needed by `make` for its C++ dataset helpers)
nltk
pybind11

In [None]:
# Register an Azure ML environment built from the Docker context written by the cells above
# (Dockerfile + requirements.txt under CLOUD_DIR/train).
env_name = "llama3-8b-wiki_env"
docker_dir = f"{CLOUD_DIR}/train"

build_context = BuildContext(path=docker_dir)
env_docker_image = Environment(
    name=env_name,
    description="Environment created from a Docker context.",
    build=build_context,
)
env_asset = ml_client.environments.create_or_update(env_docker_image)

#### Register the training dataset

In [None]:
# Register the wiki dump folder on the workspace blob datastore as a versioned URI_FOLDER
# data asset; the training job below mounts it as `wiki_dump_01@latest`.
# NOTE(review): this only references the datastore path — it assumes the data is already there.
data = Data(
    name="wiki_dump_01",
    version="1",
    description="wiki dump data for pretraining",
    type=AssetTypes.URI_FOLDER,
    path="azureml://datastores/workspaceblobstore/paths/wiki-indexed-dataset1/",
)

ml_client.data.create_or_update(data)

In [None]:
# job configuration: NUM_NODES nodes x NUM_GPU_PER_NODE processes, PyTorch distribution.
NUM_NODES = 2
NUM_GPU_PER_NODE = 1
# Azure ML binding expression for the `cache` output declared below. It is a plain (non-f)
# string, so the literal `${{outputs.cache}}` survives interpolation into the command.
CACHE_DIR = "${{outputs.cache}}"

# Fail fast instead of submitting a job with `--wandb-entity None` or an unset API key
# while WANDB_MODE is "online".
if not WANDB_ENTITY or not WANDB_API_KEY:
    raise ValueError(
        "WANDB_ENTITY and WANDB_API_KEY must be set (e.g. in .env) before submitting the job."
    )

dist = PyTorchDistribution(
    process_count_per_instance=NUM_GPU_PER_NODE,
    node_count=NUM_NODES
)

# Arguments for examples/finetuning.py. Entries containing `${{inputs.*}}` MUST stay plain
# strings: inside an f-string `{{`/`}}` collapse to `{`/`}`. That turns the binding into
# `${inputs.model_dir}`, which Azure ML does not substitute and bash rejects.
finetune_args = [
    "--fsdp-cpu-offload",
    "--fsdp-activation-checkpointing",
    "--low-cpu-fsdp",
    "--bf16",
    "--epoch 3",
    "--base-model ${{inputs.model_dir}}",
    # NOTE(review): Llama2Tokenizer with a Llama 3 model — confirm finetuning.py handles this.
    "--tokenizer-type Llama2Tokenizer",
    "--tokenizer-model ${{inputs.model_dir}}",
    "--global-batch-size 128",
    "--micro-batch-size 8",
    "--min-lr 1e-5",
    "--lr 1e-4",
    "--lr-warmup-iters 0",
    "--lr-decay-style cosine",
    "--train-iters 8",
    "--lr-decay-iters 1",
    "--weight-decay 0.1",
    "--train-data-path ${{inputs.train_data}}/wikidump.jsonl",
    f"--data-cache-path {CACHE_DIR}",
    "--seq-length 4096",
    "--sliding-window-size 4096",
    "--num-workers 2",
    # NOTE(review): --save-interval (100) exceeds --train-iters (8); confirm finetuning.py
    # still writes a final checkpoint.
    "--save-interval 100",
    "--save ./outputs/checkpoints",
    "--load ./outputs/checkpoints",
    "--use-better-transformer",
    f"--wandb-entity {WANDB_ENTITY}",
    "--wandb-project llama3-8b-test",
    "--wandb-name llama3-8b-wiki_dataset",
]

train_command = " && ".join([
    # Remove any previously built Megatron C++ dataset helper and rebuild it with make.
    "rm -f megatron_lm/megatron/core/datasets/helpers_cpp*.so",
    "make -C megatron_lm/megatron/core/datasets",
    f"mkdir -p {CACHE_DIR}",
    # Run training
    "python examples/finetuning.py " + " ".join(finetune_args),
])

job = command(
    code=CLOUD_DIR,
    command=train_command,
    inputs={
        "train_data": Input(
            type=AssetTypes.URI_FOLDER,
            path="wiki_dump_01@latest",
            mode="ro_mount"
        ),
        "model_dir": Input(
            type=AssetTypes.URI_FOLDER,
            path="llama3-8b@latest",
            mode="ro_mount"
        )
    },
    outputs={
        "cache": Output(
            type=AssetTypes.URI_FOLDER,
            mode="rw_mount"
        )
    },
    environment="llama3-8b-wiki_env@latest",
    compute=COMPUTE_CLUSTER,
    instance_count=NUM_NODES,
    distribution=dist,
    environment_variables={
        # NOTE(review): confirm Azure ML resolves `${{outputs.*}}` inside environment variables.
        "MEGATRON_CACHE": CACHE_DIR,
        "NCCL_IB_DISABLE": "1",
        "LOGLEVEL": "INFO",
        "NCCL_DEBUG": "INFO",
        # NOTE(review): "WARN" is a debug level, not an NCCL subsystem name — confirm intent.
        "NCCL_DEBUG_SUBSYS": "WARN",
        "PYTHONFAULTHANDLER": "1",
        "CUDA_LAUNCH_BLOCKING": "0",
        "WANDB_MODE": "online",
        # NOTE(review): job environment variables are likely visible in the job definition;
        # consider a workspace secret / Key Vault for the API key.
        "WANDB_API_KEY": WANDB_API_KEY
    },
    display_name="llama3-8b-wiki-pretrain",
    experiment_name="llama3-8b-wiki-exp"
)

In [None]:
# Submit the configured command job, then print its run name and the Studio URL to follow it.
returned_job = ml_client.jobs.create_or_update(job)
print(
    f"Job submitted: {returned_job.name}",
    f"Monitor at: {returned_job.studio_url}",
    sep="\n",
)