In [36]:
from codeflare_sdk.cluster.cluster import Cluster, ClusterConfiguration
from codeflare_sdk.cluster.auth import TokenAuthentication
import os
import sys

In [None]:
# Create the training and evaluation datasets.
# NOTE(review): the original note says this "can be run only once" — presumably it
# only needs to be run once; confirm whether create_dataset is safe to re-run.
# The install uses the interpreter that runs this kernel (sys.executable), so the
# `datasets` package is installed for that same interpreter.
!{sys.executable} -m pip install datasets
# create_dataset is not defined in this notebook — presumably a module shipped alongside it.
import create_dataset
create_dataset.main()

In [None]:
# Log in through the CodeFlare SDK using token authentication.
# Fill in both values before running this cell. On OpenShift:
#   - `oc whoami -t` prints the token
#   - `oc cluster-info` prints the server
auth = TokenAuthentication(token="", server="", skip_tls=False)
auth.login()

In [None]:
# Describe the Ray cluster to create. The head node and the workers each
# request one GPU — presumably 1 head GPU + 7 worker GPUs are meant to match
# the 8 devices requested at job submission; confirm if either side changes.
cluster = Cluster(
    ClusterConfiguration(
        name="ray",
        namespace="ray-finetune-llm-deepspeed",
        image="quay.io/rhoai/ray:2.23.0-py311-cu121-torch",
        # Head node resources
        head_cpus=16,
        head_memory=128,
        head_gpus=1,
        # Worker resources
        num_workers=7,
        min_cpus=16,
        max_cpus=16,
        min_memory=128,
        max_memory=256,
        num_gpus=1,
    )
)

In [30]:
# Create the Ray cluster from the `cluster` object configured in the previous cell.
cluster.up()

In [None]:
# Wait for the cluster to become ready before going further
# (presumably blocks until the Ray cluster is up — confirm against the SDK docs).
cluster.wait_ready()

In [None]:
# Show the details of the cluster (left as the last expression so its result is displayed).
cluster.details()

In [40]:
# Initialize the Job Submission Client.
# `client` is used by the later cells to submit the fine-tuning job and to stop it.
client = cluster.job_client

In [None]:
# S3 bucket in which checkpoints are stored.
# Set it manually below; when left empty, it falls back to the AWS_S3_BUCKET
# environment variable, which comes from the configured data connection.
s3_bucket = ""
s3_bucket = s3_bucket or os.environ.get('AWS_S3_BUCKET')
assert s3_bucket, "An S3 bucket must be provided to store checkpoints"

In [None]:
# Submit the fine-tuning script as a Ray job and print its submission id.
#
# Only the AWS settings actually present in the environment are forwarded to
# the job: os.environ.get() yields None for an unset variable, whereas Ray's
# runtime_env expects every "env_vars" value to be a string, so passing None
# would make the submission fail instead of simply omitting the variable.
aws_env_vars = {
    name: os.environ[name]
    for name in ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_DEFAULT_REGION")
    if name in os.environ
}

submission_id = client.submit_job(
    # Command line executed by the job; checkpoints go under the S3 bucket
    # resolved in the previous cell.
    entrypoint="python ray_finetune_llm_deepspeed.py "
               "--model-name=meta-llama/Llama-2-7b-chat-hf "
               "--lora "
               "--num-devices=8 "
               "--num-epochs=3 "
               "--ds-config=./deepspeed_configs/zero_3_llama_2_7b.json "
               f"--storage-path=s3://{s3_bucket}/ray_finetune_llm_deepspeed/ "
               "--batch-size-per-device=32 "
               "--eval-batch-size-per-device=32 ",
    runtime_env={
        "env_vars": aws_env_vars,
        # Dependencies for the job are taken from requirements.txt.
        "pip": "requirements.txt",
        # The current directory is used as the job's working directory,
        # minus docs, notebooks and markdown files.
        "working_dir": "./",
        "excludes": ["/docs/", "*.ipynb", "*.md"]
    },
)
print(submission_id)

In [None]:
# Stop the job identified by the submission id obtained in the previous cell.
# NOTE(review): when the notebook is run top to bottom, this cell executes right
# after the submission — presumably it is meant to be run manually, only to
# cancel the job; confirm the intended workflow.
client.stop_job(submission_id)

In [42]:
# Bring the Ray cluster down once it is no longer needed
# (counterpart of the earlier cluster.up() call).
cluster.down()