In [1]:
from codeflare_sdk import Cluster, ClusterConfiguration, TokenAuthentication

In [2]:
import os
import sys

In [3]:
# Authenticate the CodeFlare SDK against the OpenShift API server.
# On OpenShift, you can retrieve the token by running `oc whoami -t`,
# and the server with `oc cluster-info`.
# Both are read from the AUTH_TOKEN and API_SERVER environment variables so
# that no secret is stored in the notebook.
auth_token = os.getenv("AUTH_TOKEN")
api_server = os.getenv("API_SERVER")
# Fail early with a clear message instead of handing None to the SDK.
if not auth_token or not api_server:
    raise ValueError(
        "Set the AUTH_TOKEN and API_SERVER environment variables before logging in"
    )

auth = TokenAuthentication(
    token=auth_token,
    server=api_server,
    # TLS verification of the API server certificate is disabled.
    # NOTE(review): only keep this for clusters with self-signed certificates.
    skip_tls=True,
)
auth.login()



'Logged into https://api.cluster-mpcjw.mpcjw.sandbox2589.opentlc.com:6443'

In [4]:
# Configure the Ray cluster: one head node plus five GPU workers.
# Memory values are in GB (reported as e.g. '48G' by `cluster.details()`).
ray_cluster_config = ClusterConfiguration(
    name='ray',
    namespace='distributed-training-demo',
    image="quay.io/rhoai/ray:2.23.0-py39-cu121-torch",
    local_queue="local-queue-ray",
    # Head node resources.
    head_cpus=16,
    head_memory=48,
    head_gpus=1,
    # Per-worker resources; min and max are kept equal so every worker
    # is sized identically.
    num_workers=5,
    min_cpus=8,
    max_cpus=8,
    min_memory=48,
    max_memory=48,
    num_gpus=1,
)
cluster = Cluster(ray_cluster_config)

Yaml resources loaded for ray


In [5]:
# Create the Ray cluster.
# Readiness is awaited separately with `cluster.wait_ready()`.
cluster.up()

In [6]:
# Block until the requested cluster resources and the Ray dashboard are ready.
cluster.wait_ready()

Waiting for requested resources to be set up...
Requested cluster is up and running!
Dashboard is ready!


In [7]:
# Show the cluster status, its head/worker resources and the Ray dashboard URL.
cluster.details()

RayCluster(name='ray', status=<RayClusterStatus.READY: 'ready'>, head_cpus=16, head_mem='48G', head_gpu=1, workers=5, worker_mem_min='48G', worker_mem_max='48G', worker_cpu=8, worker_gpu=1, namespace='distributed-training-demo', dashboard='https://ray-dashboard-ray-distributed-training-demo.apps.cluster-mpcjw.mpcjw.sandbox2589.opentlc.com')

In [8]:
# Initialize the Job Submission Client.
# The client is used below to submit and stop jobs on this Ray cluster.
client = cluster.job_client

In [11]:
# Create the training and evaluation datasets.
# This can be run only once.
!{sys.executable} -m pip install datasets
import create_dataset
create_dataset.main()


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip available: [0m[31;49m22.2.2[0m[39;49m -> [0m[32;49m24.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [12]:
# The S3 bucket where checkpoints are stored.
# Set it manually here; if left empty, it is read from the AWS_S3_BUCKET
# environment variable provided by the configured data connection.
s3_bucket = ""
if not s3_bucket:
    s3_bucket = os.environ.get('AWS_S3_BUCKET', '')
# Drop surrounding whitespace and any trailing "/" so that appending
# "/ray_finetune_llm_deepspeed/" at job submission does not produce "//".
s3_bucket = s3_bucket.strip().rstrip("/")
# Explicit check (unlike `assert`, it is not skipped when Python runs with -O).
if not s3_bucket:
    raise ValueError("An S3 bucket must be provided to store checkpoints")

In [13]:
# Submit the DeepSpeed LoRA fine-tuning job for Llama-2-7b-chat to the Ray cluster.
# Checkpoints are written under `{s3_bucket}/ray_finetune_llm_deepspeed/`, so the
# job receives the S3 credentials from the notebook environment.
# NOTE(review): Ray treats a storage path without a URI scheme (no "s3://") as a
# local path; confirm that s3_bucket or the training script supplies the scheme.
s3_env_vars = {
    "AWS_ACCESS_KEY_ID": os.environ.get('AWS_ACCESS_KEY_ID'),
    "AWS_SECRET_ACCESS_KEY": os.environ.get('AWS_SECRET_ACCESS_KEY'),
    # Honour a region provided by the environment; fall back to us-east-1.
    "AWS_DEFAULT_REGION": os.environ.get('AWS_DEFAULT_REGION') or 'us-east-1',
}
# Ray's runtime_env expects string values in env_vars; fail early with a clear
# message instead of forwarding None for an unset credential.
missing_vars = sorted(name for name, value in s3_env_vars.items() if not value)
if missing_vars:
    raise ValueError(
        f"Missing S3 credentials in the environment: {', '.join(missing_vars)}"
    )

submission_id = client.submit_job(
    entrypoint="python ray_finetune_llm_deepspeed.py "
               "--model-name=meta-llama/Llama-2-7b-chat-hf "
               "--lora "
               "--num-devices=2 "
               "--num-epochs=3 "
               "--ds-config=./deepspeed_configs/zero_3_llama_2_7b.json "
               f"--storage-path={s3_bucket}/ray_finetune_llm_deepspeed/ "
               "--batch-size-per-device=16 "
               "--eval-batch-size-per-device=32 ",
    runtime_env={
        "env_vars": s3_env_vars,
        "pip": "requirements.txt",
        # The local directory is packaged and uploaded to the cluster;
        # keep docs, notebooks and scratch files out of that package.
        "working_dir": "./",
        "excludes": ["/docs/", "*.ipynb", "*.md", "/tmp/"]
    },
)
print(submission_id)

2024-07-23 13:36:46,995	INFO dashboard_sdk.py:338 -- Uploading package gcs://_ray_pkg_a3b1460ee5c5cb6e.zip.
2024-07-23 13:36:46,996	INFO packaging.py:530 -- Creating a file package for local directory './'.


raysubmit_KU9TGTtQbv6Le2gY


In [14]:
# Cancel the fine-tuning job submitted above. Run this only to abort training early.
# To monitor the job instead, use `client.get_job_status(submission_id)` or
# `client.get_job_logs(submission_id)`.
client.stop_job(submission_id)

True

In [15]:
# Tear down the Ray cluster to release its CPUs, GPUs and memory,
# then end the CodeFlare SDK session opened by `auth.login()`.
cluster.down()
auth.logout()