In [1]:
!pip -q install wandb datasets torchvision

In [1]:
# for common
import os
import json
from pathlib import Path

import wandb
from datasets import load_dataset, Image
import boto3, botocore
import sagemaker
from sagemaker import get_execution_role

# for resizing
from PIL import Image
from torchvision import transforms

# for training
from sagemaker import image_uris, model_uris, script_uris

from sagemaker.estimator import Estimator
from sagemaker.utils import name_from_base
from sagemaker.parameter import ContinuousParameter, IntegerParameter
from sagemaker.tuner import HyperparameterTuner

# for inference
import matplotlib.pyplot as plt
import numpy as np
from sagemaker.predictor import Predictor

# utils
from utils import create_bucket_if_not_exists

  from pandas.core.computation.check import NUMEXPR_INSTALLED


sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml


In [2]:
try:
    aws_role = sagemaker.get_execution_role()
except:
    iam = boto3.client("iam")
    # TODO: replace with your role name (i.e. "AmazonSageMaker-ExecutionRole-20211014T154824")
    aws_role = iam.get_role(RoleName="<replace with your RoleName>")["Role"]["Arn"]

boto_session = boto3.Session()
aws_region = boto_session.region_name
sess = sagemaker.Session(boto_session=boto_session)
account_id = boto3.client("sts").get_caller_identity().get("Account")

print(aws_role)
print(aws_region)
print(sess.boto_region_name)

arn:aws:iam::851725450449:role/service-role/AmazonSageMaker-ExecutionRole-20240125T135936
us-east-1
us-east-1


In [3]:
local_training_dataset_folder= Path("text_to_image_training_images")
if not os.path.exists(local_training_dataset_folder):
    os.mkdir(local_training_dataset_folder)

In [7]:
use_local_images = False

cache_dir = '/home/ec2-user/SageMaker/.cache/dataset'
if not use_local_images:
    dataset_name = 'haandol/icon'
    dataset = load_dataset(dataset_name, split='train', cache_dir=cache_dir)
    
    metadata = []
    for i, datum in enumerate(dataset):
        fn = f'{i}'.zfill(3) + '.jpg'
        datum['image'].convert('RGB').save(local_training_dataset_folder / fn)
        datum['text']
        metadata.append(json.dumps({
            'file_name': fn,
            'text': datum['text']
        }))
    
    with open(local_training_dataset_folder / 'metadata.jsonl', 'w') as fp:
        fp.write('\n'.join(metadata))

In [8]:
validation_prompt = 'a photo of black telephone handset in front of the Eiffel tower'
# Instance prompt is fed into the training script via dataset_info.json present in the training folder. Here, we write that file.
with open(os.path.join(local_training_dataset_folder, "dataset_info.json"), "w") as fp:
    fp.write(json.dumps({
        'validation_prompt': validation_prompt
    }))

In [10]:
training_bucket = f"sdxl-text-to-image-lora-{account_id}"

create_bucket_if_not_exists(training_bucket)

train_s3_path = f"s3://{training_bucket}/text-to-image/"

Created s3.Bucket(name='sdxl-text-to-image-lora-851725450449')


In [11]:
!aws s3 cp --recursive $local_training_dataset_folder $train_s3_path

upload: text_to_image_training_images/001.jpg to s3://sdxl-text-to-image-lora-851725450449/text-to-image/001.jpg
upload: text_to_image_training_images/000.jpg to s3://sdxl-text-to-image-lora-851725450449/text-to-image/000.jpg
upload: text_to_image_training_images/004.jpg to s3://sdxl-text-to-image-lora-851725450449/text-to-image/004.jpg
upload: text_to_image_training_images/002.jpg to s3://sdxl-text-to-image-lora-851725450449/text-to-image/002.jpg
upload: text_to_image_training_images/003.jpg to s3://sdxl-text-to-image-lora-851725450449/text-to-image/003.jpg
upload: text_to_image_training_images/007.jpg to s3://sdxl-text-to-image-lora-851725450449/text-to-image/007.jpg
upload: text_to_image_training_images/009.jpg to s3://sdxl-text-to-image-lora-851725450449/text-to-image/009.jpg
upload: text_to_image_training_images/006.jpg to s3://sdxl-text-to-image-lora-851725450449/text-to-image/006.jpg
upload: text_to_image_training_images/008.jpg to s3://sdxl-text-to-image-lora-851725450449/text-

In [12]:
output_bucket = sess.default_bucket()
output_prefix = "sdxl-text-to-image-lora"

s3_output_location = f"s3://{output_bucket}/{output_prefix}/output"
s3_output_location

's3://sagemaker-us-east-1-851725450449/sdxl-text-to-image-lora/output'

In [14]:
training_instance_type = "ml.g5.12xlarge"

# Retrieve the pre-trained model tarball to further fine-tune
train_model_uri = f'{output_bucket}/models/sdxl-1.0-base.tar.gz'
train_model_uri

'sagemaker-us-east-1-851725450449/models/sdxl-1.0-base.tar.gz'

In [15]:
base_model = 'stabilityai/stable-diffusion-xl-base-1.0'
vae_model = 'madebyollin/sdxl-vae-fp16-fix'

In [16]:
hyperparameters = {
    'pretrained_model_name_or_path': base_model,
    #    'pretrained_vae_model_name_or_path': vae_model,
    # mix_precision has bug in this version. - https://github.com/huggingface/accelerate/issues/2182
    #    'mixed_precision': 'fp16',
    'enable_xformers_memory_efficient_attention': True,
    'rank': 128,
    'learning_rate': 1e-06,
    'max_train_steps': 3000,
    'train_batch_size': 1,
    'checkpointing_steps': 500,
    'checkpoints_total_limit': 5,
    'gradient_accumulation_steps': 4,
    'resolution': 1024,
    'use_8bit_adam': True,
    'gradient_checkpointing': True,
    'lr_warmup_steps': 0,
    'lr_scheduler': 'constant',
}
hyperparameters

{'pretrained_model_name_or_path': 'stabilityai/stable-diffusion-xl-base-1.0',
 'enable_xformers_memory_efficient_attention': True,
 'rank': 128,
 'learning_rate': 1e-06,
 'max_train_steps': 3000,
 'train_batch_size': 1,
 'checkpointing_steps': 500,
 'checkpoints_total_limit': 5,
 'gradient_accumulation_steps': 4,
 'resolution': 1024,
 'use_8bit_adam': True,
 'gradient_checkpointing': True,
 'lr_warmup_steps': 0,
 'lr_scheduler': 'constant'}

In [17]:
training_job_name = "sdxl-text-to-image-lora"
training_job_name

'sdxl-text-to-image-lora'

In [18]:
s3_checkpoint_location = f"s3://{output_bucket}/{training_job_name}/checkpoints"
s3_checkpoint_location

's3://sagemaker-us-east-1-851725450449/sdxl-text-to-image-lora/checkpoints'

In [20]:
tag = 'da4dd2d995'
train_image_uri = f'{account_id}.dkr.ecr.{aws_region}.amazonaws.com/sdxl-text-to-image-lora-training-gpu:{tag}'
train_image_uri

'851725450449.dkr.ecr.us-east-1.amazonaws.com/sdxl-text-to-image-lora-training-gpu:da4dd2d995'

In [22]:
use_wandb = True
if use_wandb:
    wandb.login()
    hyperparameters['report_to'] = 'wandb'
    environment = {
        'WANDB_API_KEY': '',  # Update API key
        'WANDB_PROJECT': 'sdxl-text-to-image-lora',
    }

[34m[1mwandb[0m: Currently logged in as: [33mhaandol[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
sd_estimator = Estimator(
    role=aws_role,
    image_uri=train_image_uri,
#    model_uri=train_model_uri,
    instance_count=1,
    instance_type=training_instance_type,
    max_run=360000,
    volume_size=128,
    hyperparameters=hyperparameters,
    output_path=s3_output_location,
    base_job_name=training_job_name,
    checkpoint_local_path="/opt/ml/checkpoints",
    checkpoint_s3_uri=s3_checkpoint_location,
    environment=environment if use_wandb else None,
)
sd_estimator.fit({"training": train_s3_path}, logs=True)

INFO:sagemaker:Creating training-job with name: sdxl-text-to-image-lora-2024-01-25-06-40-58-349


2024-01-25 06:40:58 Starting - Starting the training job
2024-01-25 06:40:58 Pending - Training job waiting for capacity.........