In [None]:
# For matplotlib 3.8.4 see https://discourse.matplotlib.org/t/latest-versions-via-pip-jupyterlab-import-of-matplotlib-broken/24477
!pip install -U ray==2.7.0 codeflare_sdk==0.16.3 huggingface_hub==0.16.2 matplotlib==3.8.4

Restart the notebook kernel after the `pip install` above completes, so that the newly installed package versions are loaded.

In [None]:
import os

# Directory the notebook kernel was started in; later cells derive their paths from it.
work_dir = os.getcwd()
print("Current working directory: %s" % work_dir)

In [None]:
# Determine the images path mounted inside the Ray job.
# The code assumes the shared filesystem is seen at /opt/app-root/src from this workbench and at
# /home/ray/share from the Ray pods, so it re-roots work_dir from the first to the second.
from pathlib import Path

original_path = Path(work_dir)

# Prefix to remove: the workbench-side mount point of the shared filesystem
prefix = Path('/opt/app-root/src')
try:
    new_path = original_path.relative_to(prefix)
except ValueError as err:
    # Path.relative_to raises a terse ValueError; explain what it means for this notebook.
    raise ValueError(
        f"work_dir {work_dir!r} is not under {prefix}; the Ray job cannot reach it through "
        f"/home/ray/share. Run this notebook from a directory under {prefix}."
    ) from err
ray_job_images_dir = f'/home/ray/share/{new_path}/images'

print(f"Path in ray job: {ray_job_images_dir}")

In [None]:
# Create the Ray cluster described in raycluster.yaml: 4 workers with a single GPU each, sharing an
# RWX (ReadWriteMany) volume mounted across the cluster (per the notebook author — confirm in the manifest).
!oc apply -f {work_dir}/raycluster.yaml

In [None]:
# Get the Ray dashboard route host. The dashboard is authenticated by OpenShift OAuth.
# `ray_dashboard` captures the command output as an IPython SList; `.n` joins its lines with newlines.
ray_dashboard=!oc get routes ray-dashboard-stable-diffusion -o jsonpath='{.spec.host}'
print(f'Ray dashboard: https://{ray_dashboard.n}')

In [None]:
!git clone https://github.com/ray-project/ray.git
!(cd ray && git checkout 179b3d5ed6) #The HEAD rev when this notebook was written

In [None]:
!cp -r {work_dir}/ray/doc/source/templates/05_dreambooth_finetuning {work_dir}//dreambooth_finetuning
!rm -rf {work_dir}/ray

In [None]:
# Prepare the directories expected by the stable diffusion fine-tuning example. 
# Note: These directories are created in a shared filesystem, which would be also accesible on ray head and worker pods
# These are mounted as /home/ray/share in ray nodes
!mkdir -p {work_dir}/images/images-own
!mkdir -p {work_dir}/images/images-reg
!mkdir -p {work_dir}/images/images-new
!mkdir -p {work_dir}/images/model-orig
!mkdir -p {work_dir}/images/model-tuned

!chmod 777 -R {work_dir}/images

In [None]:
# Download the example dog image dataset into images/images-own/; the training job below uses it
# as --instance_images_dir (the subject being fine-tuned on).
!python {work_dir}/dreambooth_finetuning/dreambooth/download_example_dataset.py {work_dir}/images/images-own/

In [None]:
# Download the model checkpoint from the Hugging Face Hub and cache it in a directory on the shared
# filesystem. The pinned --revision is the snapshot hash that the Ray jobs below reference in
# their --model_dir paths (.../snapshots/b95be7d6f134c3a9e62ee616f310733567f069ce).
!python {work_dir}/dreambooth_finetuning/dreambooth/cache_model.py --model_dir={work_dir}/images/model-orig --model_name=CompVis/stable-diffusion-v1-4 --revision=b95be7d6f134c3a9e62ee616f310733567f069ce

In [None]:
from codeflare_sdk import generate_cert

# Ray cluster name and OpenShift namespace the client certificate is issued for
cluster_name, namespace = "stable-diffusion", "distributed"

# TLS setup for talking to the cluster from this kernel:
#   1. create the TLS certificate for the cluster/namespace,
#   2. export the environment variables that enable TLS.
cluster_ref = (cluster_name, namespace)
generate_cert.generate_tls_cert(*cluster_ref)
generate_cert.export_env(*cluster_ref)

In [None]:
# Connect to the Ray cluster's job submission API through the head service's Ray Client port.
import ray
from ray.job_submission import JobSubmissionClient

# Drop any Ray connection left over from an earlier run in this kernel before reconnecting.
ray.shutdown()

# Derive the address from cluster_name/namespace (set in the TLS cell) instead of repeating them.
# With the defaults this is ray://stable-diffusion-head-svc.distributed.svc.cluster.local:10001
ray_address = f"ray://{cluster_name}-head-svc.{namespace}.svc.cluster.local:10001"
client = JobSubmissionClient(ray_address)

In [None]:
# Step 3 of the Ray DreamBooth example: create the regularization (class) images.
# https://docs.ray.io/en/latest/train/examples/pytorch/dreambooth_finetuning.html#step-3-create-the-regularization-images
# The work runs as a Ray job; the Ray worker pods see the shared filesystem at /home/ray/share.
#### NOTE:
# generate.py uses 4 workers by default. If the cluster does not have that many GPUs, reduce the
# worker count in `dreambooth_finetuning/dreambooth/generate.py` before submitting.
# Upstream location (master branch; the local copy comes from commit 179b3d5ed6, so the line
# number may have drifted):
# https://github.com/ray-project/ray/blob/master/doc/source/templates/05_dreambooth_finetuning/dreambooth/generate.py#L51

# Snapshot of the original (not yet fine-tuned) model, as seen from inside the Ray job
orig_model_snapshot = (
    f"{ray_job_images_dir}/model-orig/models--CompVis--stable-diffusion-v1-4"
    "/snapshots/b95be7d6f134c3a9e62ee616f310733567f069ce"
)

# Command line executed by the job, one flag per entry
regularization_cmd = " ".join([
    "python generate.py",
    f"--model_dir='{orig_model_snapshot}'",
    f"--output_dir='{ray_job_images_dir}/images-reg'",
    "--prompts='photo of a dog'",
    "--num_samples_per_prompt=100",
    "--use_ray_data",
])

# Pinned dependencies installed into the job's runtime environment
job_pip_packages = [
    "accelerate==0.20.3",
    "bitsandbytes==0.39.1",
    "diffusers==0.19.3",
    "flax==0.6.11",
    "jax==0.4.13",
    "jaxlib==0.4.13",
    "ipywidgets",
    "huggingface_hub==0.16.2",
    "numpy==1.24.4",
    "torch==2.0.1",
    "torchvision==0.15.2",
    "transformers==4.30.2",
]

job_id = client.submit_job(
    entrypoint=regularization_cmd,
    runtime_env={
        # Local folder containing generate.py; Ray uploads it to the workers for the job
        "working_dir": f"{work_dir}/dreambooth_finetuning/dreambooth",
        "pip": job_pip_packages,
    },
)
print("jobid:", job_id)

In [None]:
# Check the status of the job submitted in the previous cell. IPython substitutes the current value
# of the Python variable job_id for {job_id}, so nothing needs to be replaced by hand.
!ray job status {job_id}  --address "ray://stable-diffusion-head-svc.distributed.svc.cluster.local:10001" 

# To cancel or clean up the job instead, uncomment one of these:
#!ray job stop {job_id}  --address "ray://stable-diffusion-head-svc.distributed.svc.cluster.local:10001" 
#!ray job delete {job_id}  --address "ray://stable-diffusion-head-svc.distributed.svc.cluster.local:10001" 

In [None]:
# https://docs.ray.io/en/latest/train/examples/pytorch/dreambooth_finetuning.html#step-4-fine-tune-the-model
# Start the distributed ray train job on the ray cluster. This specific example would use 4 ray workers (each worker using 1 GPU)
job_id = client.submit_job(
    entrypoint=f"python train.py --model_dir='{ray_job_images_dir}/model-orig/models--CompVis--stable-diffusion-v1-4/snapshots/b95be7d6f134c3a9e62ee616f310733567f069ce' --output_dir='{ray_job_images_dir}/model-tuned' --instance_images_dir='{ray_job_images_dir}/images-own' --instance_prompt='photo of unqtkn dog' --class_images_dir='{ray_job_images_dir}/images-reg' --class_prompt='photo of a dog' --train_batch=2 --lr=5e-6 --num_epochs=2 --max_train_steps=100 --num_workers=4",
    runtime_env={
        "working_dir": f"{work_dir}/dreambooth_finetuning/dreambooth",
        "pip": ["accelerate==0.20.3",
                "bitsandbytes==0.39.1",
                "diffusers==0.19.3",
                "flax==0.6.11",
                "jax==0.4.13",
                "jaxlib==0.4.13",
                "ipywidgets",
                "huggingface_hub==0.16.2",
                "numpy==1.24.4",
                "torch==2.0.1",
                "torchvision==0.15.2",
                "transformers==4.30.2"]
    }
)
print("jobid:", job_id)

In [None]:
# Check the status of the fine-tuning job submitted above ({job_id} is filled in by IPython).
!ray job status {job_id}  --address "ray://stable-diffusion-head-svc.distributed.svc.cluster.local:10001" 

In [None]:
# https://docs.ray.io/en/latest/train/examples/pytorch/dreambooth_finetuning.html#step-5-generate-images-of-the-subject
# Do inference with the tuned model
import shlex

# Edit these to try other prompts / sample counts
prompt = 'photo of a unqtkn dog in a sand castle'
num_samples = 9

# The job entrypoint is a shell command line. Quote the user-editable prompt with shlex.quote rather
# than wrapping it in literal single quotes: a prompt containing "'" (e.g. "dog's") would otherwise
# break the command. For the default prompt the resulting command line is unchanged.
inference_cmd = " ".join([
    "python generate.py",
    f"--model_dir='{ray_job_images_dir}/model-tuned'",
    f"--output_dir='{ray_job_images_dir}/images-new'",
    f"--prompts={shlex.quote(prompt)}",
    f"--num_samples_per_prompt={num_samples}",
    "--use_ray_data",
])

job_id = client.submit_job(
    entrypoint=inference_cmd,
    runtime_env={
        # Local folder containing generate.py; Ray uploads it for the job
        "working_dir": f"{work_dir}/dreambooth_finetuning/dreambooth",
        "pip": ["accelerate==0.20.3",
                "bitsandbytes==0.39.1",
                "diffusers==0.19.3",
                "flax==0.6.11",
                "jax==0.4.13",
                "jaxlib==0.4.13",
                "ipywidgets",
                "huggingface_hub==0.16.2",
                "numpy==1.24.4",
                "torch==2.0.1",
                "torchvision==0.15.2",
                "transformers==4.30.2"]
    }
)
print("jobid:", job_id)

In [None]:
# Check the status of the inference job submitted above ({job_id} is filled in by IPython).
!ray job status {job_id}  --address "ray://stable-diffusion-head-svc.distributed.svc.cluster.local:10001" 

In [None]:
# List the images written by the inference job into images-new on the shared filesystem.
!ls {work_dir}/images/images-new

In [None]:
# If matplotlib raises an exception about an invalid backend or a missing attribute 'colors',
# try restarting the workbench or the notebook kernel (see the matplotlib note in the first cell).

# Enable inline plotting
%matplotlib inline

# Required imports (os is re-imported so this plotting section can be run on its own)
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os

# Function to display images in a grid
def display_images_in_grid(image_paths, grid_shape, figsize=(15, 15)):
    """Show images on a single ``rows x cols`` matplotlib figure.

    Args:
        image_paths: Sequence of image file paths, each read with ``mpimg.imread``.
            Only the first ``rows * cols`` paths are shown; extra paths are ignored.
        grid_shape: ``(rows, cols)`` of the subplot grid.
        figsize: Figure size in inches passed to ``plt.subplots`` (default ``(15, 15)``).
    """
    rows, cols = grid_shape
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 or 1xN grid, so ravel() always works.
    fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
    flat_axes = axes.ravel()

    # zip stops at the shorter sequence, so surplus images are skipped.
    for ax, img_path in zip(flat_axes, image_paths):
        ax.imshow(mpimg.imread(img_path))

    # Hide ticks and frames on every panel, including empty ones when there are fewer images than cells.
    for ax in flat_axes:
        ax.axis('off')

    fig.tight_layout()
    plt.show()

In [None]:
image_directory = f'{work_dir}/images/images-new'
# Sort for a stable, reproducible panel order; os.listdir returns entries in arbitrary order.
image_paths = sorted(
    os.path.join(image_directory, name)
    for name in os.listdir(image_directory)
    if name.endswith('.jpg')
)

# Define the grid shape (rows, columns): 3x3 holds the 9 samples requested by the inference cell.
grid_shape = (3, 3)

display_images_in_grid(image_paths, grid_shape)

In [None]:
# Tear down the Ray cluster: delete the resources from raycluster.yaml, the manifest applied at the start.
!oc delete -f {work_dir}/raycluster.yaml