# Fine-tune Stable Diffusion XL

In [None]:
%%capture
import ads
import oci
import os
import ocifs
from oci.object_storage import ObjectStorageClient
from datetime import datetime, timedelta

# Authenticate ADS and the raw OCI SDK with the resource principal.
ads.set_auth(auth='resource_principal')
rps = oci.auth.signers.get_resource_principals_signer()

# Look up the tenancy's Object Storage namespace and publish it as the
# `namespace` environment variable.
object_storage_client = ObjectStorageClient(config={}, signer=rps)
OBJECT_STORAGE_NAMESPACE = object_storage_client.get_namespace().data
os.environ['namespace'] = OBJECT_STORAGE_NAMESPACE

# Deployment-specific settings read from the environment (KeyError if unset).
LOG_GROUP_ID, LOG_ID, BUCKET_NAME = (
    os.environ[var_name] for var_name in ('loggroup_ocid', 'log_ocid', 'bucket_name')
)

# INPUT_FOLDER is handed to the training job as `full_input_folder`;
# OUTPUT_FOLDER is where the job's ./output directory is sent.
_SDXL_ROOT = f"oci://{BUCKET_NAME}@{OBJECT_STORAGE_NAMESPACE}/sdxl"
INPUT_FOLDER = f"{_SDXL_ROOT}/input/"
OUTPUT_FOLDER = f"{_SDXL_ROOT}/output/"

# Service conda environment used by the training job runtime.
CONDA_ENV = "onnx110_p39_cpu_v1"

In [None]:
import shutil
import subprocess

# Local directory that the job cell uploads as the job artifact
# (`with_source("/home/datascience/job_artifact/")`).
JOB_ARTIFACT_DIR = "/home/datascience/job_artifact"
KOHYA_SS_DIR = JOB_ARTIFACT_DIR + "/kohya_ss"
STABLE_MAIN_SRC = "/home/datascience/repos/carlgira/oci-tf-odsc-sdxl/app/stable_main.py"

# subprocess/shutil replace the original `!` shell escapes. Those do not stop
# the notebook on a non-zero exit status, so a failed clone or copy would only
# surface once the (GPU) job run itself failed.
if not os.path.exists(KOHYA_SS_DIR):
    subprocess.run(
        ["git", "clone", "https://github.com/bmaltais/kohya_ss.git", KOHYA_SS_DIR],
        check=True,
    )

# Explicit destination file, so a missing artifact directory raises (like
# `cp src dir/`) instead of silently creating a file named "job_artifact".
shutil.copy(STABLE_MAIN_SRC, os.path.join(JOB_ARTIFACT_DIR, "stable_main.py"))

In [None]:
from ads.jobs import Job, DataScienceJob, PythonRuntime

# Compute side of the training job: logging targets and the VM shape.
training_infrastructure = (
    DataScienceJob()
    .with_log_group_id(LOG_GROUP_ID)
    .with_log_id(LOG_ID)
    .with_shape_name("VM.GPU2.1")
)

# What the job runs. The local job_artifact directory is uploaded, and
# stable_main.py is executed from the job_artifact working dir. ./output is
# mapped to OUTPUT_FOLDER via ADS `with_output`.
# NOTE(review): CONDA_ENV names a "*_cpu_*" service conda while the shape is a
# GPU shape. Confirm that stable_main.py installs GPU-enabled dependencies itself.
training_runtime = (
    PythonRuntime()
    .with_service_conda(CONDA_ENV)
    .with_source("/home/datascience/job_artifact/")
    .with_entrypoint("stable_main.py")
    .with_working_dir("job_artifact")
    .with_environment_variable(full_input_folder=INPUT_FOLDER)
    .with_output("./output", OUTPUT_FOLDER)
)

job = (
    Job(name="sdxl-train-job")
    .with_infrastructure(training_infrastructure)
    .with_runtime(training_runtime)
)

# Register the job definition with OCI Data Science.
job.create()

In [None]:
# Start a run of the job. full_input_folder is also set on the runtime;
# passing it here makes the per-run value explicit.
run_env_vars = {'full_input_folder': INPUT_FOLDER}
job_run_env = job.run(
    name="Job Run - Passing dynamic values",
    env_var=run_env_vars,
)
# Uncomment to stream the run's logs and block until it finishes. The download
# cell further down expects the run's output (presumably sks.safetensors) to
# already be in OUTPUT_FOLDER.
#job_run_watch = job_run_env.watch()

In [None]:
%%capture
import os
from ads.model.generic_model import GenericModel

import subprocess

# Shell work goes through subprocess.run(check=True) rather than `!` escapes.
# `!` never fails the cell on a non-zero exit status, and %%capture hides this
# cell's output, so a failed clone/download/install would go unnoticed until a
# later cell broke. CalledProcessError tracebacks are still reported.

fs = ocifs.OCIFileSystem()

comfyui = "ComfyUI"
if not os.path.exists(comfyui):
    subprocess.run(
        ["git", "clone", "https://github.com/comfyanonymous/ComfyUI", comfyui],
        check=True,
    )

# LoRA weights written to OUTPUT_FOLDER by the training job.
if not os.path.exists("ComfyUI/models/loras/sks.safetensors"):
    fs.invalidate_cache(OUTPUT_FOLDER)
    lora_uri = OUTPUT_FOLDER + "sks.safetensors"
    if not fs.exists(lora_uri):
        # The training run may still be in progress (the run cell does not wait
        # for it), so fail with an actionable message.
        raise FileNotFoundError(
            f"{lora_uri} does not exist yet; wait for the training job run to finish."
        )
    fs.get(lora_uri, comfyui + "/models/loras/", recursive=True, refresh=True)

checkpoint_path = "ComfyUI/models/checkpoints/sd_xl_base_1.0.safetensors"
if not os.path.exists(checkpoint_path):
    # Download to a temporary name and rename only after wget succeeds.
    # `wget -O` creates the target file even when the transfer fails, and a
    # truncated checkpoint at the final path would satisfy the check above forever.
    partial_path = checkpoint_path + ".part"
    subprocess.run(
        [
            "wget",
            "-O",
            partial_path,
            "https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors",
        ],
        check=True,
    )
    os.replace(partial_path, checkpoint_path)

# NOTE(review): nothing in this notebook creates ComfyUI/venv, so this guard is
# always true and the install runs on every execution (pip skips requirements
# that are already satisfied). Confirm what the guard was meant to detect.
if not os.path.exists("ComfyUI/venv"):
    subprocess.run(
        [
            "pip", "install", "torch", "torchvision", "torchaudio",
            "--extra-index-url", "https://download.pytorch.org/whl/cu118",
            "xformers",
        ],
        cwd=comfyui,
        check=True,
    )
    # Second step only runs if the first succeeded (mirrors the original `&&`).
    subprocess.run(["pip", "install", "-r", "requirements.txt"], cwd=comfyui, check=True)


In [None]:
import threading
import subprocess

def start_comfyui():
    """Run the ComfyUI server from the local ``ComfyUI`` checkout.

    Blocks the calling thread until the ``python3 ComfyUI/main.py`` process
    exits. The process output is not captured. A non-zero exit status is
    reported so a crashed server does not disappear silently.
    """
    result = subprocess.run(["python3", "ComfyUI/main.py"])
    if result.returncode != 0:
        print(f"ComfyUI exited with status {result.returncode}")


# Re-running this cell would otherwise launch a second server process next to
# the one that is still running. Only start a new thread if no earlier one is alive.
_previous_thread = globals().get("thread_1")
if _previous_thread is not None and _previous_thread.is_alive():
    print("ComfyUI is already running; not starting a second instance.")
else:
    # Run the blocking server call in a background thread so the notebook
    # stays usable while ComfyUI runs.
    thread_1 = threading.Thread(target=start_comfyui)
    thread_1.start()

In [None]:
import inference
import matplotlib.pyplot as plt
# Prompt using the "sks" token, which matches the sks.safetensors LoRA
# downloaded above.
# NOTE(review): assumes the ComfyUI server started in the previous cell is
# already accepting requests. If this fails right after startup, wait and retry.
prompt = 'portrait sks, pencil'
img = inference.generate_image(prompt)

# Hide the axes, then draw the generated image.
plt.axis('off')
plt.imshow(img)