<a href="https://colab.research.google.com/github/davidricardocr/sdxl-lora-fine-tuning/blob/main/SDXL_LoRA_Fine_Tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Stable Diffusion XL LoRA Fine-Tuning Guide

Welcome to this guide on fine-tuning **Stable Diffusion XL (SDXL) with LoRA (Low-Rank Adaptation)**. In this notebook, we will walk through setting up the environment, executing the fine-tuning script, and loading the resulting weights for inference.

**Objective**: Our goal is to customize the SDXL model to generate images with specific styles or themes. To achieve this, we use LoRA, a parameter-efficient fine-tuning technique, making the process computationally feasible even on limited hardware.
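
To make "parameter-efficient" concrete, here is a minimal sketch of the arithmetic, assuming the standard LoRA formulation in which a frozen `d_out × d_in` weight matrix is augmented by two trainable low-rank factors. The helper name `lora_param_counts`, the layer size, and the rank below are illustrative assumptions, not values taken from this notebook's configuration.

```python
def lora_param_counts(d_out: int, d_in: int, rank: int) -> tuple[int, int]:
    """Return (full, lora) trainable-parameter counts for one linear layer.

    Full fine-tuning updates the whole d_out x d_in weight matrix.
    LoRA freezes it and trains two low-rank factors instead:
    B (d_out x rank) and A (rank x d_in).
    """
    full = d_out * d_in
    lora = rank * (d_out + d_in)
    return full, lora


# Hypothetical layer size and rank, chosen only for illustration.
full, lora = lora_param_counts(d_out=1280, d_in=1280, rank=4)
print(f"full: {full:,}  lora: {lora:,}  ratio: {lora / full:.4%}")
```

Under these assumed sizes the adapter trains well under 1% of the layer's parameters, which is why LoRA fits on hardware that could not hold the gradients and optimizer state for full fine-tuning.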

In this tutorial, we will fine-tune the model using a **Pokémon-themed dataset**, `svjack/pokemon-blip-captions-en-zh`, which consists of a collection of Pokémon images paired with captions in both English and Chinese. This dataset offers a rich variety of images and descriptions that provide a creative and structured source for training the model. By using this dataset, we aim to adapt Stable Diffusion XL to generate unique, Pokémon-inspired visuals aligned with the captions' themes.

## Getting Started

We will start by downloading the setup files (`build.sh` and `config.yaml`) directly from the GitHub repository. These files will help us configure the environment and set the necessary parameters for fine-tuning.



In [1]:
# Download the build.sh script
!wget https://github.com/davidricardocr/sdxl-lora-fine-tuning/raw/main/build.sh -O build.sh

# Download the config.yaml file
!wget https://github.com/davidricardocr/sdxl-lora-fine-tuning/raw/main/config.yaml -O config.yaml

--2024-11-03 15:59:02--  https://github.com/davidricardocr/sdxl-lora-fine-tuning/raw/main/build.sh
Resolving github.com (github.com)... 140.82.113.4
Connecting to github.com (github.com)|140.82.113.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/davidricardocr/sdxl-lora-fine-tuning/main/build.sh [following]
--2024-11-03 15:59:03--  https://raw.githubusercontent.com/davidricardocr/sdxl-lora-fine-tuning/main/build.sh
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.108.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1007 [text/plain]
Saving to: ‘build.sh’


2024-11-03 15:59:03 (100 MB/s) - ‘build.sh’ saved [1007/1007]

--2024-11-03 15:59:03--  https://github.com/davidricardocr/sdxl-lora-fine-tuning/raw/main/config.yaml
Resolving github.com (githu

## Environment Setup
To simplify setup, we provide a shell script (`build.sh`) that installs the dependencies needed for SDXL fine-tuning. It clones the `diffusers` repository, installs the requirements for the SDXL examples, and configures the tools used to speed up training.

Simply run the following cell to execute the setup:

In [2]:
# Execute the setup script
!bash build.sh

Cloning Hugging Face diffusers repository...
Cloning into 'diffusers'...
remote: Enumerating objects: 73564, done.
remote: Counting objects: 100% (13000/13000), done.
remote: Compressing objects: 100% (1297/1297), done.
remote: Total 73564 (delta 12433), reused 11880 (delta 11632), pack-reused 60564 (from 1)
Receiving objects: 100% (73564/73564), 51.46 MiB | 31.37 MiB/s, done.
Resolving deltas: 100% (54669/54669), done.
Installing diffusers in editable mode...
Obtaining file:///content/diffusers
  Installing build dependencies ... done
  Checking if build backend supports build_editable ... done
  Getting requirements to build editable ... done
  Preparing editable metadata (pyproject.toml) ... done
Building wheels for collected packages: diffusers
  Building editable for diffusers (pyproject.toml) ... done
  Created wheel for diffusers: filename=diffusers-0.32.0.dev0-0.editable-py3-none-any.whl size=11113 sha256=1

## Fine-Tuning with LoRA and `accelerate`

Now that the environment is ready, let's fine-tune SDXL using LoRA. The following command uses the `accelerate` launcher, which reads `config.yaml` and takes care of device placement and distributed execution for the training script. Here's a breakdown of each parameter:

- `--pretrained_model_name_or_path`: The base model to fine-tune. Here, we're using `stabilityai/stable-diffusion-xl-base-1.0`.
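- `--pretrained_vae_model_name_or_path`: The pre-trained VAE to use. We use `madebyollin/sdxl-vae-fp16-fix`, a version of the SDXL VAE adjusted to run in fp16 without numerical instability, which matters because we train with `--mixed_precision="fp16"`.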
- `--dataset_name`: Specifies the dataset to use; in this case, a Pokémon captioning dataset.
- `--dataloader_num_workers`: Set to 8 to increase data loading efficiency during training.
- `--caption_column`: The column in the dataset that provides the captions for image generation.
- `--train_batch_size`: Batch size of 5 helps balance memory load and training speed.
- `--num_train_epochs`: We set this to 10, allowing sufficient training while controlling computational time.
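- `--learning_rate`: Set to 1e-4, a typical starting point for LoRA training, where only the small adapter matrices are updated.
- `--lr_scheduler` & `--lr_warmup_steps`: "constant" keeps the learning rate fixed for the whole run; with `--lr_warmup_steps=0` there is no warmup phase.
- `--resolution`: The resolution to which training images are resized. We use 512 to keep memory and compute manageable; SDXL's native resolution is 1024, so this trades some fidelity for feasibility.
- `--center_crop` & `--random_flip`: Crop each image to its center at the target resolution and randomly flip it horizontally, a basic augmentation to improve model robustness.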
- `--output_dir`: Where fine-tuned weights will be saved.
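- `--validation_prompt` & `--num_validation_images`: The prompt used to generate sample images (5 per validation run) so we can check the model's output during training.
- `--checkpointing_steps`: Save a training checkpoint every 500 steps.
- `--gradient_checkpointing` & `--gradient_accumulation_steps`: Memory-saving techniques. Gradient checkpointing recomputes activations during the backward pass instead of storing them; gradient accumulation sums gradients over 4 batches before each optimizer update, giving an effective batch size of 5 × 4 = 20 per device.
- `--mixed_precision="fp16"`: Runs training in 16-bit floating point where possible to reduce memory usage.
- `--use_8bit_adam`: Uses the 8-bit Adam optimizer (from `bitsandbytes`) to reduce the memory taken by optimizer state.
- `--seed`: Fixed random seed for reproducibility.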
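
To see how the batch size, gradient accumulation, and epoch settings interact, the following sketch estimates the effective batch size and the number of optimizer updates for a run. It is an approximation of the usual bookkeeping, not the training script's own code; the function name `training_schedule` is hypothetical, and the dataset size of 1000 is a placeholder to replace with the actual number of training examples.

```python
import math


def training_schedule(
    num_examples: int,
    train_batch_size: int,
    grad_accum_steps: int,
    num_epochs: int,
    num_processes: int = 1,
) -> dict:
    """Estimate effective batch size and optimizer-update counts for a run."""
    # Samples contributing to a single optimizer update.
    effective_batch = train_batch_size * grad_accum_steps * num_processes
    # Batches each process sees per epoch.
    batches_per_epoch = math.ceil(num_examples / (train_batch_size * num_processes))
    # One optimizer update happens every `grad_accum_steps` batches.
    updates_per_epoch = math.ceil(batches_per_epoch / grad_accum_steps)
    return {
        "effective_batch_size": effective_batch,
        "updates_per_epoch": updates_per_epoch,
        "total_optimizer_steps": updates_per_epoch * num_epochs,
    }


# Placeholder dataset size (an assumption); the other values mirror the command.
schedule = training_schedule(
    num_examples=1000, train_batch_size=5, grad_accum_steps=4, num_epochs=10
)
print(schedule)
```

Comparing `total_optimizer_steps` with `--checkpointing_steps` gives a rough idea of how many checkpoints a run will write.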

Run the following cell to start fine-tuning:


In [None]:
!accelerate launch --config_file config.yaml diffusers/examples/text_to_image/train_text_to_image_lora_sdxl.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
  --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
  --dataset_name="svjack/pokemon-blip-captions-en-zh" \
  --dataloader_num_workers=8 \
  --caption_column="en_text" \
  --train_batch_size=5 \
  --num_train_epochs=10 \
  --learning_rate=1e-04 \
  --lr_scheduler="constant" \
  --resolution=512 \
  --center_crop \
  --random_flip \
  --output_dir="sdxl-lora-weights" \
  --validation_prompt="A cat in the forest." \
  --num_validation_images=5 \
  --checkpointing_steps=500 \
  --lr_warmup_steps=0 \
  --gradient_checkpointing \
  --gradient_accumulation_steps=4 \
  --mixed_precision="fp16" \
  --use_8bit_adam \
  --seed=42


2024-11-03 16:00:27.878071: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-11-03 16:00:27.893925: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-11-03 16:00:27.914515: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-11-03 16:00:27.920875: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-03 16:00:27.936233: I tensorflow/core/platform/cpu_feature_guar

## Loading LoRA Weights and Running Inference

Once fine-tuning is complete, we can load the LoRA weights into the Stable Diffusion XL pipeline and run inference. The `diffusers` library provides a simple way to do this with the `StableDiffusionXLPipeline`, which allows us to leverage the fine-tuned model for custom image generation.

The following code snippet loads the weights and generates an image based on a prompt. Two key parameters in this process are `num_inference_steps` and `guidance_scale`:

* **num_inference_steps**: The number of denoising steps used to generate the image. More steps give the model more iterations to refine the output, with diminishing returns and proportionally longer generation time. Here we set it to 100, double the pipeline default of 50, favoring quality over speed.

* **guidance_scale**: Controls how closely the generated image follows the prompt. Higher values make the model adhere more strictly to the prompt, though excessively high values can reduce image coherence. Here we use 10, above the pipeline default of 5.0, for strong prompt adherence.
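
The effect of `guidance_scale` can be illustrated with the usual classifier-free guidance update, in which the prompt-conditioned noise prediction is extrapolated away from the unconditional one. The sketch below uses plain Python lists with made-up numbers in place of real tensors, and `apply_guidance` is a hypothetical helper name rather than a `diffusers` function.

```python
def apply_guidance(noise_uncond, noise_cond, guidance_scale):
    """Combine unconditional and prompt-conditioned noise predictions.

    A scale of 1.0 returns the conditional prediction unchanged;
    larger values push the result further in the prompt's direction.
    """
    return [
        u + guidance_scale * (c - u)
        for u, c in zip(noise_uncond, noise_cond)
    ]


# Toy values standing in for the model's two noise predictions.
uncond = [0.0, 1.0]
cond = [1.0, 3.0]

for scale in (1.0, 5.0, 10.0):
    print(scale, apply_guidance(uncond, cond, scale))
```

Because the difference between the two predictions is multiplied by the scale, large values amplify it strongly, which is consistent with the loss of coherence noted above.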


In [None]:
from diffusers import StableDiffusionXLPipeline
import torch

# Path where the LoRA weights are saved
model_path = "sdxl-lora-weights"

# Load the Stable Diffusion XL pipeline and set precision
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16
)
pipe.to("cuda")  # Use GPU for faster inference

# Load the fine-tuned LoRA weights
pipe.load_lora_weights(model_path)

# Generate an image with the fine-tuned model
image = pipe(
    prompt="A cat in the forest.",
    num_inference_steps=100,
    guidance_scale=10
).images[0]

# Display the generated image
image
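
To judge whether 100 steps and a guidance scale of 10 are the right trade-off for your own LoRA weights, it can help to generate the same prompt under several settings and compare the results. The following is a minimal sketch under the assumption that `pipe` is the pipeline loaded above; `sweep_settings` and `run_sweep` are hypothetical helper names, and the candidate values in the usage comment are arbitrary.

```python
from itertools import product


def sweep_settings(steps_options, scale_options):
    """Build one keyword-argument dict per (steps, scale) combination."""
    return [
        {"num_inference_steps": steps, "guidance_scale": scale}
        for steps, scale in product(steps_options, scale_options)
    ]


def run_sweep(pipe, prompt, settings):
    """Call the pipeline once per setting and keep the first image of each."""
    results = []
    for kwargs in settings:
        image = pipe(prompt=prompt, **kwargs).images[0]
        results.append((kwargs, image))
    return results


# Example usage with the pipeline from the cell above (requires a GPU):
# results = run_sweep(pipe, "A cat in the forest.", sweep_settings([30, 50, 100], [5, 10]))
# for kwargs, image in results:
#     image.save(f"steps{kwargs['num_inference_steps']}_cfg{kwargs['guidance_scale']}.png")
```

Each call is a full generation, so the sweep's runtime grows with the number of combinations and with the step counts chosen.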