# SDXL + DreamBooth + LoRA – Local Training Example

In this notebook, we demonstrate how to fine-tune Stable Diffusion XL (SDXL) with DreamBooth
using LoRA (Low-Rank Adaptation) for local usage.

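LoRA freezes the pretrained weights and injects small, trainable low-rank matrices into selected
layers of the model (for SDXL DreamBooth, the attention projections of the UNet), which drastically
reduces the number of trainable parameters. The model can then learn a new concept from a handful
of images, and the result is a small file of adapter weights rather than a full copy of the model.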
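To make the parameter savings concrete, here is a minimal pure-Python sketch of a LoRA-adapted linear layer. It assumes the common formulation `y = W x + (alpha / r) * B A x`, with `B` initialised to zero as PEFT does by default. It shows that the adapter starts as a no-op, that the low-rank update can be merged back into `W`, and how few parameters a rank-4 adapter adds to a 1280×1280 projection. The helper names, matrix sizes, `alpha` and rank are illustrative choices, not values taken from the training script.

```python
import random


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    inner = len(b)
    return [[sum(a[i][k] * b[k][j] for k in range(inner)) for j in range(len(b[0]))]
            for i in range(len(a))]


def add_scaled(a, b, scale=1.0):
    """Return a + scale * b, element-wise."""
    return [[x + scale * y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]


def lora_forward(W, A, B, x, alpha, rank):
    """y = W x + (alpha / rank) * B (A x); W stays frozen, only A and B are trained."""
    return add_scaled(matmul(W, x), matmul(B, matmul(A, x)), scale=alpha / rank)


def merge_lora(W, A, B, alpha, rank):
    """Fold the low-rank update into one matrix: W' = W + (alpha / rank) * B A."""
    return add_scaled(W, matmul(B, A), scale=alpha / rank)


def lora_param_counts(d_out, d_in, rank):
    """Trainable parameters: full fine-tuning of W vs. the LoRA pair (A, B)."""
    return d_out * d_in, rank * (d_out + d_in)


random.seed(0)
d_out, d_in, rank, alpha = 6, 8, 2, 2.0
W = [[random.gauss(0, 1) for _ in range(d_in)] for _ in range(d_out)]    # frozen weight
A = [[random.gauss(0, 0.01) for _ in range(d_in)] for _ in range(rank)]  # rank x d_in
B = [[0.0] * rank for _ in range(d_out)]                                 # d_out x rank, zero-initialised
x = [[random.gauss(0, 1)] for _ in range(d_in)]                          # input column vector

# With B = 0 the adapted layer reproduces the frozen layer exactly.
print(lora_forward(W, A, B, x, alpha, rank) == matmul(W, x))  # True

# Pretend training has changed B: the merged matrix gives the same output.
B = [[random.gauss(0, 0.1) for _ in range(rank)] for _ in range(d_out)]
unmerged = lora_forward(W, A, B, x, alpha, rank)
merged = matmul(merge_lora(W, A, B, alpha, rank), x)
print(max(abs(u[0] - m[0]) for u, m in zip(unmerged, merged)) < 1e-9)  # True

full, lora = lora_param_counts(1280, 1280, rank=4)
print(full, lora, f"{lora / full:.3%}")  # 1638400 10240 0.625%
```

Only `A` and `B` receive gradients, so for this layer training touches well under 1% of the parameters, and merging afterwards removes the extra matrix multiplications at inference time.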

## 1. Install and Import Dependencies

In [1]:
# Check GPU
# !nvidia-smi

# Install dependencies
!pip install bitsandbytes transformers accelerate peft -q
!pip install git+https://github.com/huggingface/diffusers.git -q

# Download the DreamBooth + LoRA SDXL training script
!wget https://raw.githubusercontent.com/huggingface/diffusers/main/examples/dreambooth/train_dreambooth_lora_sdxl.py


--2025-02-19 11:28:55--  https://raw.githubusercontent.com/huggingface/diffusers/main/examples/dreambooth/train_dreambooth_lora_sdxl.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.108.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 84922 (83K) [text/plain]
Saving to: ‘train_dreambooth_lora_sdxl.py’


2025-02-19 11:28:55 (3.84 MB/s) - ‘train_dreambooth_lora_sdxl.py’ saved [84922/84922]



## 2. Dataset
Below we show two ways to get the training images into a local folder (`./dog/`): upload your own
images (Google Colab only) or download the example dog images from the Hugging Face Hub.
The next step can optionally auto-generate a caption for each image with a BLIP model.

In [None]:
import os
import glob
from PIL import Image
from IPython.display import display
from huggingface_hub import snapshot_download

local_dir = "./dog/"
os.makedirs(local_dir, exist_ok=True)

# Option A (Google Colab only): upload your own images into local_dir.
# from google.colab import files
# os.chdir(local_dir)
# uploaded_images = files.upload()
# os.chdir("..")

# Option B: download the example dog images from the Hugging Face Hub.
# Comment this out if you uploaded your own images.
snapshot_download(
    "diffusers/dog-example",
    local_dir=local_dir,
    repo_type="dataset",
    ignore_patterns=".gitattributes",
)


def image_grid(imgs, rows, cols, resize=256):
    """Paste images into a rows x cols grid, optionally resizing each one to a square."""
    if resize is not None:
        imgs = [img.resize((resize, resize)) for img in imgs]
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


# Preview up to five training images.
img_paths = os.path.join(local_dir, "*.jpeg")
imgs = [Image.open(path) for path in glob.glob(img_paths)]
num_imgs_to_preview = min(5, len(imgs))
if num_imgs_to_preview > 0:
    display(image_grid(imgs[:num_imgs_to_preview], 1, num_imgs_to_preview))

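The preview cell above, like the captioning cell below, only picks up files ending in `.jpeg`. That matches the example dataset but silently skips `.jpg` or `.png` photos of your own. A small helper such as the hypothetical `list_training_images` below matches several extensions case-insensitively; the extension set is an assumption, so extend it as needed.

```python
from pathlib import Path

# Assumed set of extensions; extend it if your photos use another format.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp"}


def list_training_images(folder):
    """Return sorted paths of image files directly inside `folder`.

    Extensions are matched case-insensitively, so `.JPG` and `.jpeg` both count;
    sub-directories and non-image files such as metadata.jsonl are skipped.
    """
    return sorted(
        p for p in Path(folder).iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )
```

For example, `imgs = [Image.open(p) for p in list_training_images(local_dir)]` could replace the `glob` call in either cell.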

## 3. Auto-generate Captions with BLIP
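You can generate image captions automatically, then prepend (or append) a token that identifies your concept.
Here every caption starts with `a photo of TOK dog, `, where `TOK` is the rare identifier the model learns
to associate with this particular dog.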


In [None]:
import gc
import glob
import json
import os

import torch
from PIL import Image
from transformers import AutoProcessor, BlipForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"
# Use float16 only on GPU; fall back to float32 on CPU.
dtype = torch.float16 if device == "cuda" else torch.float32

blip_processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-base", torch_dtype=dtype
).to(device)


def caption_images(input_image):
    """Return a BLIP caption for a single PIL image."""
    inputs = blip_processor(images=input_image, return_tensors="pt").to(device, dtype)
    generated_ids = blip_model.generate(pixel_values=inputs.pixel_values, max_length=50)
    return blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


# Folder containing the images to caption:
local_dir = "./dog/"
imgs_and_paths = [(path, Image.open(path)) for path in glob.glob(os.path.join(local_dir, "*.jpeg"))]

# Prefix each caption with the concept token and write one JSON object per line.
caption_prefix = "a photo of TOK dog, "
with open(os.path.join(local_dir, "metadata.jsonl"), "w", encoding="utf-8") as outfile:
    for path, img in imgs_and_paths:
        caption = caption_prefix + caption_images(img).split("\n")[0]
        entry = {"file_name": os.path.basename(path), "prompt": caption}
        json.dump(entry, outfile)
        outfile.write("\n")

# Free the captioning model's memory before training.
del blip_processor, blip_model
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

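As far as we can tell, `--dataset_name="dog"` makes the training script load this folder through Hugging Face Datasets, whose image-folder loader pairs each image with the `metadata.jsonl` row whose `file_name` matches; `--caption_column="prompt"` then selects the caption. A mismatch between file names and captions is easy to miss, so a quick check before training can help. `check_metadata` below is a hypothetical helper; its default prefix mirrors `caption_prefix` from the cell above, and its extension set is an assumption.

```python
import json
from pathlib import Path

IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp"}  # assumed image types


def check_metadata(folder, caption_prefix="a photo of TOK dog", caption_key="prompt"):
    """Return a list of human-readable problems found in folder/metadata.jsonl.

    Flags entries that point to missing images, captions without the concept
    prefix, and images that have no caption entry at all.
    """
    folder = Path(folder)
    images = {p.name for p in folder.iterdir()
              if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS}
    problems, listed = [], set()
    with open(folder / "metadata.jsonl", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines
            entry = json.loads(line)
            name = entry.get("file_name")
            caption = entry.get(caption_key, "")
            listed.add(name)
            if name not in images:
                problems.append(f"line {line_no}: {name!r} has no matching image")
            if not caption.startswith(caption_prefix):
                problems.append(f"line {line_no}: caption lacks prefix: {caption!r}")
    for name in sorted(images - listed):
        problems.append(f"{name}: no caption entry")
    return problems
```

An empty list from `check_metadata("./dog/")` means every image has exactly the kind of caption the training command expects.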

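## 4. Prepare Accelerate & Configuration
Create a default Accelerate configuration non-interactively; `accelerate launch` reads it to decide how to run
training on the available hardware (CPU, a single GPU, or several GPUs).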


In [None]:
# Colab workaround: some runtimes report a non-UTF-8 locale, which breaks `!` shell commands.
import locale
locale.getpreferredencoding = lambda: "UTF-8"
!accelerate config default

## 5. Train the Model
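We launch the DreamBooth LoRA training script with Accelerate. `--dataset_name="dog"` points at the local
image folder and `--caption_column="prompt"` selects the captions written to `metadata.jsonl`.
`--mixed_precision="fp16"`, `--gradient_checkpointing` and `--use_8bit_adam` (which relies on bitsandbytes)
keep memory use down, and the fp16-fix VAE avoids the numerical problems the original SDXL VAE has in half
precision. When training finishes, the LoRA weights are saved to the local `--output_dir` (`corgy_dog_LoRA`).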

In [None]:
!pip install datasets -q

!accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
  --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
  --dataset_name="dog" \
  --output_dir="corgy_dog_LoRA" \
  --caption_column="prompt" \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of TOK dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=3 \
  --gradient_checkpointing \
  --learning_rate=1e-4 \
  --snr_gamma=5.0 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --use_8bit_adam \
  --max_train_steps=500 \
  --checkpointing_steps=717 \
  --seed="0"

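The step count interacts with the batch settings. With `--train_batch_size=1` and `--gradient_accumulation_steps=3` the effective batch size is 3, and in the diffusers example scripts `--max_train_steps` counts optimizer updates rather than forward passes. The sketch below mirrors the epoch formula those scripts use, assuming a single GPU; the function is hypothetical and the five-image count is an assumption about `diffusers/dog-example`. It also shows that `--checkpointing_steps=717` exceeds `--max_train_steps=500`, so no intermediate checkpoint is saved and only the final LoRA weights end up in the output directory.

```python
import math


def training_schedule(num_images, train_batch_size, grad_accum_steps,
                      max_train_steps, checkpointing_steps):
    """Estimate run length on one device.

    Assumes max_train_steps counts optimizer updates and that each epoch yields
    ceil(batches_per_epoch / grad_accum_steps) updates, as in the diffusers scripts.
    """
    batches_per_epoch = math.ceil(num_images / train_batch_size)
    updates_per_epoch = math.ceil(batches_per_epoch / grad_accum_steps)
    return {
        "effective_batch_size": train_batch_size * grad_accum_steps,
        "updates_per_epoch": updates_per_epoch,
        "epochs": math.ceil(max_train_steps / updates_per_epoch),
        # A checkpoint is written whenever the update count is a multiple of checkpointing_steps.
        "intermediate_checkpoints": max_train_steps // checkpointing_steps,
    }


# Values from the accelerate launch command above (num_images=5 is an assumption).
print(training_schedule(num_images=5, train_batch_size=1, grad_accum_steps=3,
                        max_train_steps=500, checkpointing_steps=717))
# {'effective_batch_size': 3, 'updates_per_epoch': 2, 'epochs': 250, 'intermediate_checkpoints': 0}
```

If you want resumable checkpoints during this run, set `--checkpointing_steps` below `--max_train_steps`.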

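`--snr_gamma=5.0` turns on Min-SNR loss weighting. As we understand the diffusers training scripts, for an epsilon-prediction model such as SDXL base each sample's MSE loss is multiplied by `min(SNR(t), gamma) / SNR(t)`, where `SNR(t) = alpha_bar_t / (1 - alpha_bar_t)`. The pure-Python sketch below computes those weights, assuming the `scaled_linear` beta schedule (beta from 0.00085 to 0.012 over 1000 steps) used by SDXL base's scheduler config; the function names are hypothetical.

```python
def scaled_linear_alphas_cumprod(beta_start=0.00085, beta_end=0.012, num_steps=1000):
    """Cumulative products alpha_bar_t for a 'scaled_linear' beta schedule:
    betas are spaced linearly in sqrt-space and then squared."""
    s0, s1 = beta_start ** 0.5, beta_end ** 0.5
    alphas_cumprod, prod = [], 1.0
    for t in range(num_steps):
        beta = (s0 + (s1 - s0) * t / (num_steps - 1)) ** 2
        prod *= 1.0 - beta
        alphas_cumprod.append(prod)
    return alphas_cumprod


def min_snr_weights(alphas_cumprod, gamma=5.0):
    """Per-timestep loss weights min(SNR, gamma) / SNR for an epsilon-prediction model."""
    weights = []
    for alpha_bar in alphas_cumprod:
        snr = alpha_bar / (1.0 - alpha_bar)
        weights.append(min(snr, gamma) / snr)
    return weights


weights = min_snr_weights(scaled_linear_alphas_cumprod(), gamma=5.0)
print(f"t=0: {weights[0]:.4f}, t=999: {weights[-1]:.4f}")
print("timesteps down-weighted:", sum(w < 1.0 for w in weights))
```

Low-noise timesteps, whose SNR is far above `gamma`, get small weights, while noisy timesteps keep a weight of 1, so training is not dominated by the easy, nearly clean samples.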
## 6. Local Inference
Once training has finished, we have a local folder (e.g., "corgy_dog_LoRA") containing LoRA weights.
We load them into the SDXL pipeline to generate new images for our concept:


In [None]:
import torch
from diffusers import AutoencoderKL, DiffusionPipeline
from IPython.display import display

lora_folder = "corgy_dog_LoRA"

# Same fp16-safe VAE as used during training.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to("cuda")

# Load the locally trained LoRA weights into the pipeline.
pipe.load_lora_weights(lora_folder)

# The prompt must contain the TOK identifier used during training.
prompt = "a photo of TOK dog in New York"
image = pipe(prompt=prompt, num_inference_steps=25).images[0]
display(image)

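If training stopped early, or a different diffusers version wrote a different file name, `load_lora_weights` can fail with a less obvious error. The hypothetical helper below checks the output folder first. It assumes the script saves `pytorch_lora_weights.safetensors`, falls back to a single `.safetensors` file otherwise, and returns a name that can be passed to `load_lora_weights` as `weight_name`.

```python
from pathlib import Path


def find_lora_weights(folder, preferred="pytorch_lora_weights.safetensors"):
    """Return the file name of the LoRA weights inside `folder`.

    Prefers `preferred` (assumed default name); otherwise requires exactly one
    .safetensors file and raises FileNotFoundError if the folder is missing or ambiguous.
    """
    folder = Path(folder)
    if not folder.is_dir():
        raise FileNotFoundError(f"{folder} does not exist; did training finish?")
    if (folder / preferred).is_file():
        return preferred
    candidates = sorted(p.name for p in folder.glob("*.safetensors"))
    if len(candidates) != 1:
        raise FileNotFoundError(f"expected one .safetensors file in {folder}, found {candidates}")
    return candidates[0]
```

Usage: `pipe.load_lora_weights(lora_folder, weight_name=find_lora_weights(lora_folder))`.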

### References

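- [LoRA: Low-Rank Adaptation of Large Language Models (Hu et al., 2021)](https://arxiv.org/abs/2106.09685)
- [DreamBooth: Fine Tuning Text-to-Image Diffusion Models for Subject-Driven Generation (Ruiz et al., 2022)](https://arxiv.org/abs/2208.12242)
- [Hugging Face Diffusers on GitHub](https://github.com/huggingface/diffusers)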