<a href="https://colab.research.google.com/github/davidricardocr/sdxl-lora-fine-tuning/blob/main/SDXL_LoRA_Fine_Tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Stable Diffusion XL LoRA Fine-Tuning Guide

Welcome to this guide on fine-tuning Stable Diffusion XL (SDXL) with **LoRA (Low-Rank Adaptation)**. In this notebook, we will walk through setting up the environment, executing the fine-tuning script, and loading the resulting weights for inference.

**Objective**: Our goal is to customize the SDXL model to generate images in a specific style or theme. To achieve this, we use LoRA, a parameter-efficient fine-tuning technique that freezes the base model's weights and trains only a small set of low-rank matrices added to selected layers. This keeps fine-tuning computationally feasible even on limited hardware.
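
To make "parameter-efficient" concrete, here is a minimal sketch of the parameter arithmetic behind LoRA. For a frozen weight matrix of shape `d_out × d_in`, LoRA trains two small matrices of shapes `d_out × r` and `r × d_in` instead of the full matrix. The layer width (1280) and rank (4) below are illustrative assumptions, not values read from this notebook's configuration, and the helper names are hypothetical.

```python
def full_param_count(d_in: int, d_out: int) -> int:
    """Parameters in a dense weight matrix of shape (d_out, d_in)."""
    return d_in * d_out


def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters in the two low-rank LoRA factors:
    B with shape (d_out, rank) and A with shape (rank, d_in)."""
    return rank * (d_in + d_out)


# Illustrative values only (assumed, not taken from the training script)
d_in, d_out, rank = 1280, 1280, 4

full = full_param_count(d_in, d_out)
lora = lora_param_count(d_in, d_out, rank)
print(f"full matrix: {full:,} parameters")
print(f"LoRA factors: {lora:,} parameters ({lora / full:.2%} of the full matrix)")
```

Under these assumptions, the adapter for a single layer trains well under 1% of the parameters that full fine-tuning of that layer would update.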

We'll start by setting up the environment using a custom shell script, `build.sh`, which handles the installation of the required libraries and configurations. After that, we'll explain each parameter in the fine-tuning command and its role in managing computational load and model performance.


## Environment Setup
To simplify the environment setup, we have created a shell script (`build.sh`) that installs the necessary dependencies for SDXL fine-tuning. This includes cloning the `diffusers` repository, installing the specific requirements for the SDXL examples, and configuring tools to optimize computation.

Simply run the following cell to execute the setup:

In [None]:
# Execute the setup script
!bash build.sh

## Fine-Tuning with LoRA and `accelerate`

Now that the environment is ready, let's fine-tune SDXL using LoRA. The following command uses `accelerate launch` to start the training script with the settings in `config.yaml`, so that device placement and mixed precision are handled for us. Here's a breakdown of each parameter:

- `--pretrained_model_name_or_path`: The base model to fine-tune. Here, we're using `stabilityai/stable-diffusion-xl-base-1.0`.
- `--pretrained_vae_model_name_or_path`: A pre-trained VAE to use in place of the one bundled with the base model. We use `madebyollin/sdxl-vae-fp16-fix`, a version of the SDXL VAE adjusted to run in fp16 without producing NaN values.
- `--dataset_name`: Specifies the dataset to use; in this case, a Pokémon captioning dataset.
- `--dataloader_num_workers`: Set to 8 to increase data loading efficiency during training.
- `--caption_column`: The dataset column that holds the caption paired with each training image (`en_text` here).
- `--train_batch_size`: Batch size per device. A value of 5 balances memory use against training speed.
- `--num_train_epochs`: We set this to 10, allowing sufficient training while controlling computational time.
- `--learning_rate`: The initial learning rate. We use 1e-4, a common starting point for LoRA training, where only the small adapter matrices are updated.
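- `--lr_scheduler`: "constant" keeps the learning rate fixed for the whole run, which keeps the training setup simple.
- `--lr_warmup_steps`: Number of learning-rate warmup steps. Set to 0, so training starts at the full learning rate.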
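- `--resolution`: The resolution to which training images are resized. We use 512 to keep memory use and training time manageable; note that the SDXL base model itself was trained at around 1024×1024.
- `--center_crop` & `--random_flip`: `--center_crop` crops each image to the target resolution from its center instead of at a random position; `--random_flip` randomly flips images horizontally, a basic augmentation to improve robustness.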
- `--output_dir`: Where fine-tuned weights will be saved.
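- `--validation_prompt`: A prompt used to generate sample images periodically during training, so you can check the model's output.
- `--num_validation_images`: How many images to generate with the validation prompt at each check (5 here).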
- `--checkpointing_steps`: Save model checkpoints every 500 steps.
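- `--gradient_checkpointing`: Recomputes intermediate activations during the backward pass instead of storing them, trading extra compute for lower memory use.
- `--gradient_accumulation_steps`: Accumulates gradients over 4 batches before each optimizer update, giving an effective batch size of 5 × 4 = 20 per device without the memory cost of a larger batch.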
- `--mixed_precision="fp16"`: 16-bit floating point precision to reduce memory usage.
- `--use_8bit_adam`: Uses the 8-bit Adam optimizer from `bitsandbytes`, which stores optimizer state in 8 bits to reduce memory use.
- `--seed`: Random seed, fixed (42 here) for reproducibility.
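
The batch-size, accumulation, epoch, and checkpointing settings above interact, and a rough sketch of the arithmetic can help when adjusting them. The helper below is hypothetical and only approximates how the step count is derived; the dataset size of 1,000 examples is an assumed placeholder, not the actual size of the dataset used here.

```python
import math


def training_schedule(num_examples, per_device_batch_size, grad_accum_steps,
                      num_epochs, checkpointing_steps, num_processes=1):
    """Estimate optimizer-step counts from the command-line settings."""
    # Samples consumed per optimizer update across all processes
    effective_batch_size = per_device_batch_size * grad_accum_steps * num_processes
    # Batches each process sees in one epoch
    batches_per_epoch = math.ceil(num_examples / (per_device_batch_size * num_processes))
    # One optimizer update happens every `grad_accum_steps` batches
    updates_per_epoch = math.ceil(batches_per_epoch / grad_accum_steps)
    total_updates = updates_per_epoch * num_epochs
    # Checkpoints are written every `checkpointing_steps` optimizer updates
    num_checkpoints = total_updates // checkpointing_steps
    return {
        "effective_batch_size": effective_batch_size,
        "updates_per_epoch": updates_per_epoch,
        "total_updates": total_updates,
        "num_checkpoints": num_checkpoints,
    }


# Settings from the command below; num_examples is an assumed placeholder
schedule = training_schedule(
    num_examples=1000,
    per_device_batch_size=5,
    grad_accum_steps=4,
    num_epochs=10,
    checkpointing_steps=500,
)
print(schedule)
```

With these assumed numbers, the run performs 500 optimizer updates in total, so `--checkpointing_steps=500` would write a single intermediate checkpoint. For a smaller dataset, a lower value may be needed to get any intermediate checkpoints at all.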

Run the following cell to start fine-tuning:


In [None]:
!accelerate launch --config_file config.yaml diffusers/examples/text_to_image/train_text_to_image_lora_sdxl.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
  --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
  --dataset_name="svjack/pokemon-blip-captions-en-zh" \
  --dataloader_num_workers=8 \
  --caption_column="en_text" \
  --train_batch_size=5 \
  --num_train_epochs=10 \
  --learning_rate=1e-04 \
  --lr_scheduler="constant" \
  --resolution=512 \
  --center_crop \
  --random_flip \
  --output_dir="sdxl-lora-weights" \
  --validation_prompt="A cat in the forest." \
  --num_validation_images=5 \
  --checkpointing_steps=500 \
  --lr_warmup_steps=0 \
  --gradient_checkpointing \
  --gradient_accumulation_steps=4 \
  --mixed_precision="fp16" \
  --use_8bit_adam \
  --seed=42


## Loading LoRA Weights and Running Inference

Once fine-tuning is complete, we can load the LoRA weights into the Stable Diffusion XL pipeline and run inference. The `diffusers` library provides a simple way to do this with the `StableDiffusionXLPipeline`, which allows us to leverage the fine-tuned model for custom image generation.

The following code snippet loads the weights and generates an image based on a prompt. Two key parameters in this process are `num_inference_steps` and `guidance_scale`:

* **num_inference_steps**: This parameter controls how many denoising steps the model takes to generate the image. More steps give the model more iterations to refine the output, but generation time grows with the step count and the quality gains diminish. Here, we've set it to 100, twice the pipeline's default of 50, favoring image quality over speed.

* **guidance_scale**: This parameter controls the strength of classifier-free guidance, that is, how closely the generated image follows the prompt. A higher guidance scale makes the model adhere more strictly to the prompt, though excessively high values can reduce image coherence and introduce artifacts. Here, we use a guidance scale of 10, above the pipeline's default of 5.0, to keep the image closely aligned with the prompt.


In [None]:
from diffusers import StableDiffusionXLPipeline
import torch

# Path where the LoRA weights are saved
model_path = "sdxl-lora-weights"

# Load the Stable Diffusion XL pipeline and set precision
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16
)
pipe.to("cuda")  # Use GPU for faster inference

# Load the fine-tuned LoRA weights
pipe.load_lora_weights(model_path)

# Generate an image with the fine-tuned model
image = pipe(
    prompt="A cat in the forest.",
    num_inference_steps=100,
    guidance_scale=10
).images[0]

# Display the generated image
image
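
To illustrate what `guidance_scale` does numerically, here is a minimal sketch of the classifier-free guidance combination applied at each denoising step. It uses plain Python lists in place of tensors, and the function name and toy values are hypothetical; the real pipeline applies the same formula to the UNet's noise predictions.

```python
def apply_guidance(noise_uncond, noise_text, guidance_scale):
    """Combine unconditional and prompt-conditioned noise predictions.

    Computes uncond + scale * (text - uncond) element by element.
    """
    return [u + guidance_scale * (t - u) for u, t in zip(noise_uncond, noise_text)]


# Toy predictions standing in for the model's outputs (assumed values)
noise_uncond = [0.0, 1.0]
noise_text = [1.0, 3.0]

for scale in (1.0, 5.0, 10.0):
    print(scale, apply_guidance(noise_uncond, noise_text, scale))
```

A scale of 1.0 returns the prompt-conditioned prediction unchanged, while larger values push the result further along the direction from the unconditional prediction toward the conditioned one, which is why very high values can overshoot and degrade the image.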