In [78]:
!pip -q install wandb datasets

In [79]:
# for common
import os
import json
from pathlib import Path

import wandb
import boto3
import sagemaker
from datasets import load_dataset

# for training
from sagemaker.huggingface import HuggingFace

# utils
from utils import create_bucket_if_not_exists

In [80]:
# Resolve the IAM role ARN to hand to the training job.
try:
    aws_role = sagemaker.get_execution_role()
except Exception:
    # `except Exception` (instead of a bare `except:`) keeps KeyboardInterrupt and
    # SystemExit propagating. If the lookup fails, fall back to fetching the role
    # ARN by name through IAM.
    iam = boto3.client("iam")
    # TODO: replace with your role name (e.g. "AmazonSageMaker-ExecutionRole-20211014T154824")
    aws_role = iam.get_role(RoleName="<replace with your RoleName>")["Role"]["Arn"]

# One boto3 session is shared by the SageMaker session and the STS lookup below.
boto_session = boto3.Session()
aws_region = boto_session.region_name
sess = sagemaker.Session(boto_session=boto_session)
# Account id is used later to build a per-account bucket name.
account_id = boto_session.client("sts").get_caller_identity().get("Account")

print(aws_role)
print(aws_region)
print(sess.boto_region_name)

INFO:botocore.credentials:Found credentials from IAM Role: BaseNotebookInstanceEc2InstanceRole


arn:aws:iam::851725450449:role/service-role/AmazonSageMaker-ExecutionRole-20240125T135936
us-east-1
us-east-1


In [81]:
# Local staging folder for the training images and their metadata files.
local_training_dataset_folder = Path("dataset")
# Idempotent creation: no separate existence check, safe to re-run the cell.
local_training_dataset_folder.mkdir(parents=True, exist_ok=True)

In [82]:
# When True, the Hugging Face download/export below is skipped entirely and
# whatever is already in `local_training_dataset_folder` is used as-is.
use_local_images = False

cache_dir = "/home/ec2-user/SageMaker/.cache/dataset"
if not use_local_images:
    dataset_name = "haandol/icon"
    dataset = load_dataset(dataset_name, split="train", cache_dir=cache_dir)

    # One JSON line per image: the saved file name and its caption text.
    metadata = []
    for i, datum in enumerate(dataset):
        # Zero-padded index as file name: 000.jpg, 001.jpg, ...
        fn = f"{i:03d}.jpg"
        datum["image"].convert("RGB").save(local_training_dataset_folder / fn)
        metadata.append(json.dumps({"file_name": fn, "text": datum["text"]}))

    with open(local_training_dataset_folder / "metadata.jsonl", "w") as fp:
        fp.write("\n".join(metadata))

In [83]:
# The validation prompt is handed to the training script through a
# dataset_info.json file placed in the training folder; write that file here.
validation_prompt = "a black telephone handset in front of the Eiffel tower"
dataset_info_path = local_training_dataset_folder / "dataset_info.json"
with open(dataset_info_path, "w") as fp:
    json.dump({"validation_prompt": validation_prompt}, fp)

In [84]:
# Per-account bucket that receives the training dataset.
training_bucket = "sdxl-txt2img-lora-{}".format(account_id)

create_bucket_if_not_exists(training_bucket)

# S3 prefix the local dataset folder is copied to and the training channel reads from.
train_s3_path = "s3://{}/sdxl-txt2img-lora/".format(training_bucket)

Using an existing bucket sdxl-txt2img-lora-851725450449


In [85]:
!aws s3 cp --recursive $local_training_dataset_folder $train_s3_path

upload: dataset/.ipynb_checkpoints/dataset_info-checkpoint.json to s3://sdxl-txt2img-lora-851725450449/sdxl-txt2img-lora/.ipynb_checkpoints/dataset_info-checkpoint.json
upload: dataset/000.jpg to s3://sdxl-txt2img-lora-851725450449/sdxl-txt2img-lora/000.jpg
upload: dataset/001.jpg to s3://sdxl-txt2img-lora-851725450449/sdxl-txt2img-lora/001.jpg
upload: dataset/005.jpg to s3://sdxl-txt2img-lora-851725450449/sdxl-txt2img-lora/005.jpg
upload: dataset/.ipynb_checkpoints/metadata-checkpoint.jsonl to s3://sdxl-txt2img-lora-851725450449/sdxl-txt2img-lora/.ipynb_checkpoints/metadata-checkpoint.jsonl
upload: dataset/003.jpg to s3://sdxl-txt2img-lora-851725450449/sdxl-txt2img-lora/003.jpg
upload: dataset/002.jpg to s3://sdxl-txt2img-lora-851725450449/sdxl-txt2img-lora/002.jpg
upload: dataset/009.jpg to s3://sdxl-txt2img-lora-851725450449/sdxl-txt2img-lora/009.jpg
upload: dataset/013.jpg to s3://sdxl-txt2img-lora-851725450449/sdxl-txt2img-lora/013.jpg
upload: dataset/012.jpg to s3://sdxl-txt2img-

In [86]:
# Output location: a fixed prefix inside the SageMaker session's default bucket.
output_prefix = "sdxl-txt2img-lora"
output_bucket = sess.default_bucket()

s3_output_location = "s3://{}/{}/output".format(output_bucket, output_prefix)
s3_output_location

's3://sagemaker-us-east-1-851725450449/sdxl-txt2img-lora/output'

In [87]:
# Model identifiers passed to the training script through the hyperparameters:
# the VAE replacement first, then the base SDXL checkpoint.
vae_model = "madebyollin/sdxl-vae-fp16-fix"
base_model = "stabilityai/stable-diffusion-xl-base-1.0"

In [88]:
# Hyperparameters passed to the training job (keys keep their original order so
# the displayed dict is unchanged).
# "mixed_precision": "fp16" is deliberately left out: xformers with mixed precision
# has a bug in this version - https://github.com/huggingface/accelerate/issues/2182
hyperparameters = dict(
    # models
    pretrained_model_name_or_path=base_model,
    pretrained_vae_model_name_or_path=vae_model,
    # LoRA / optimisation
    rank=32,
    learning_rate=1e-4,
    max_train_steps=12000,
    train_batch_size=1,
    # checkpointing
    checkpointing_steps=500,
    checkpoints_total_limit=5,
    gradient_accumulation_steps=4,
    resolution=512,
    # memory-saving switches
    use_8bit_adam=True,
    gradient_checkpointing=True,
    train_text_encoder=True,
    enable_xformers_memory_efficient_attention=True,
    # learning-rate schedule
    lr_warmup_steps=100,
    lr_scheduler="cosine_with_restarts",
)
hyperparameters

{'pretrained_model_name_or_path': 'stabilityai/stable-diffusion-xl-base-1.0',
 'pretrained_vae_model_name_or_path': 'madebyollin/sdxl-vae-fp16-fix',
 'rank': 32,
 'learning_rate': 0.0001,
 'max_train_steps': 12000,
 'train_batch_size': 1,
 'checkpointing_steps': 500,
 'checkpoints_total_limit': 5,
 'gradient_accumulation_steps': 4,
 'resolution': 512,
 'use_8bit_adam': True,
 'gradient_checkpointing': True,
 'train_text_encoder': True,
 'enable_xformers_memory_efficient_attention': True,
 'lr_warmup_steps': 100,
 'lr_scheduler': 'cosine_with_restarts'}

In [89]:
# Base name for the training job.
# NOTE(review): the outputs recorded in this notebook show "sdxl-txt2img-lora-hf",
# which does not match the value assigned here; the saved outputs were presumably
# produced with an earlier value. Re-run the notebook to refresh them.
training_job_name = "sdxl-txt2img-lora"
training_job_name

'sdxl-txt2img-lora-hf'

In [90]:
# Checkpoints are stored in the output bucket under a prefix named after the job.
s3_checkpoint_location = "s3://{}/{}/checkpoints".format(output_bucket, training_job_name)
s3_checkpoint_location

's3://sagemaker-us-east-1-851725450449/sdxl-txt2img-lora-hf/checkpoints'

In [91]:
# Toggle Weights & Biases reporting for the training job.
use_wandb = True

# Environment variables handed to the estimator.
environment = {}
if use_wandb:
    wandb.login()
    # Do not hardcode the key in the notebook: prefer the WANDB_API_KEY environment
    # variable, otherwise reuse the key of the session that was just logged in.
    wandb_api_key = os.environ.get("WANDB_API_KEY") or wandb.api.api_key
    if not wandb_api_key:
        # An empty key would only surface later as a failure inside the training job.
        raise RuntimeError(
            "No Weights & Biases API key found: set WANDB_API_KEY or run wandb.login()."
        )
    # Mutates the shared hyperparameters dict in place.
    hyperparameters["report_to"] = "wandb"
    environment.update(
        {
            "WANDB_API_KEY": wandb_api_key,
            "WANDB_PROJECT": "sdxl-txt2img-lora",
        }
    )

In [92]:
# Retrieve the pre-trained model tarball to further fine-tune
train_model_uri = f"{output_bucket}/models/sdxl-1.0-base.tar.gz"
train_model_uri

'sagemaker-us-east-1-851725450449/models/sdxl-1.0-base.tar.gz'

In [103]:
# Training container image, assembled from registry / repository / tag.
# Note the registry host is pinned to us-east-1.
_dlc_registry = "763104351884.dkr.ecr.us-east-1.amazonaws.com"
_dlc_repository = "huggingface-pytorch-training"
_dlc_tag = "2.1.0-transformers4.36.0-gpu-py310-cu121-ubuntu20.04"
image_uri = f"{_dlc_registry}/{_dlc_repository}:{_dlc_tag}"

# Instance type the training job runs on.
instance_type = "ml.g5.16xlarge"

In [None]:
# Build the Hugging Face estimator and launch the training job.
# `model_uri=train_model_uri` is currently not passed to the estimator.
sd_estimator = HuggingFace(
    # --- training code and its configuration ---
    source_dir="./src",
    entry_point="train.py",
    hyperparameters=hyperparameters,
    environment=environment,
    # --- container ---
    image_uri=image_uri,
    pytorch_version="2.1",
    # NOTE(review): the SageMaker SDK documents py_version in the "py310" form;
    # presumably ignored here because image_uri is given explicitly - confirm.
    py_version="3.10",
    # --- infrastructure ---
    role=aws_role,
    instance_type=instance_type,
    instance_count=1,
    volume_size=128,
    max_run=360000,
    # --- outputs and checkpoints ---
    base_job_name=training_job_name,
    output_path=s3_output_location,
    checkpoint_s3_uri=s3_checkpoint_location,
    checkpoint_local_path="/opt/ml/checkpoints",
)

# Single input channel named "training", pointing at the uploaded dataset prefix.
training_inputs = {"training": train_s3_path}
sd_estimator.fit(training_inputs, logs=True)

INFO:sagemaker:Creating training-job with name: sdxl-txt2img-lora-hf-2024-02-28-02-33-49-872


2024-02-28 02:33:50 Starting - Starting the training job...
2024-02-28 02:34:17 Pending - Preparing the instances for training......
2024-02-28 02:35:08 Downloading - Downloading the training image.....................
2024-02-28 02:38:29 Training - Training image download completed. Training in progress.[34mbash: cannot set terminal process group (-1): Inappropriate ioctl for device[0m
[34mbash: no job control in this shell[0m
[34m2024-02-28 02:38:33,917 sagemaker-training-toolkit INFO     Imported framework sagemaker_pytorch_container.training[0m
[34m2024-02-28 02:38:33,936 sagemaker-training-toolkit INFO     No Neurons detected (normal if no neurons installed)[0m
[34m2024-02-28 02:38:33,945 sagemaker_pytorch_container.training INFO     Block until all host DNS lookups succeed.[0m
[34m2024-02-28 02:38:33,947 sagemaker_pytorch_container.training INFO     Invoking user training script.[0m
[34m2024-02-28 02:38:35,314 sagemaker-training-toolkit INFO     Installing dependenci