# Dreambooth Implementation Explained

Dreambooth is a fine-tuning technique for text-to-image diffusion models. It adapts a pre-trained model to generate high-quality, personalized images of a specific subject (person, object, or style) using only a few (5-6) images.

It is a unique fine-tuning technique that uses instance images (the subject) and regularization images (class-specific images) to fine-tune the model; the regularization images help prevent the model from overfitting and from forgetting what it previously learned.

In [101]:
import argparse
import itertools
import math
import os
from multiprocessing import Value
import toml

from tqdm import tqdm
import torch
import transformers

In [102]:
from accelerate.utils import set_seed
from diffusers import DDPMScheduler

### Prepare Dataset

We will load instance images (images of the subject we want to personalize the model for) and regularization images (images of the subject's class). They are indexed by parity: an even index retrieves an instance image and an odd index retrieves a regularization image.

In [103]:
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import pathlib
from PIL import Image, UnidentifiedImageError
import numpy as np
from pathlib import Path

# Preprocessing applied to every training image before it reaches the VAE.
# NOTE: this rebinds the name `transforms` from the torchvision module to the
# composed pipeline; the module is re-imported at the top of this cell, so
# re-running the cell still works.
transforms = transforms.Compose([
    transforms.ToTensor(),               # HWC uint8 array -> CHW float tensor in [0, 1]
    transforms.Resize((512, 512)),       # fixed training resolution
    transforms.Normalize([0.5], [0.5]),  # [0, 1] -> [-1, 1], the input range the Stable Diffusion VAE expects
])


class DreamBoothDataset(Dataset):
    """Dataset that interleaves subject (instance) and class (regularization) images.

    Even indices return an instance image paired with the tokenized
    ``instance_prompt``; odd indices return a randomly chosen regularization
    image paired with the tokenized ``class_prompt``.

    Args:
        data_dir: Directory holding the instance (subject) images.
        reg_dir: Directory holding the regularization (class) images.
        transforms: Callable applied to each image array before it is
            returned, or ``None`` to return the raw array.
        tokenizer: Callable tokenizer exposing ``model_max_length``; it is
            invoked with ``truncation``, ``padding``, ``max_length`` and
            ``return_tensors`` keyword arguments and must return a mapping
            containing ``'input_ids'``.
        instance_prompt: Prompt paired with instance images.
        class_prompt: Prompt paired with regularization images.

    Raises:
        ValueError: If either directory yields no readable image.
    """

    def __init__(self, data_dir, reg_dir, transforms, tokenizer, instance_prompt, class_prompt):
        self.instance_images = self.load_images(data_dir)
        self.reg_images = self.load_images(reg_dir)
        # Fail early with a clear message; otherwise __getitem__ would later die
        # with a ZeroDivisionError (instance) or a randint range error (reg).
        if not self.instance_images:
            raise ValueError(f"No readable instance images found in {data_dir}")
        if not self.reg_images:
            raise ValueError(f"No readable regularization images found in {reg_dir}")
        self.transforms = transforms
        self.tokenizer = tokenizer
        self.instance_prompt = instance_prompt
        self.class_prompt = class_prompt
        self._length = max(len(self.reg_images), len(self.instance_images))

    def load_images(self, data_dir):
        """Load every readable image file in ``data_dir`` into memory as an RGB array.

        Unreadable files are reported on stdout and skipped. Paths are sorted
        so the resulting order is deterministic.
        """
        images = []
        paths = sorted(p for p in pathlib.Path(data_dir).glob("*") if p.is_file())
        for img_path in tqdm(paths):
            try:
                # The context manager closes the file handle; convert("RGB")
                # guarantees 3 channels even for RGBA / grayscale / palette files.
                with Image.open(img_path) as img:
                    images.append(np.array(img.convert("RGB")))
            except UnidentifiedImageError:
                print(f"Error: {img_path} is not a valid image file.")
            except Exception as e:
                print(f"An error occurred while opening {img_path}: {e}")
        return images

    def process_text(self, tokenizer, input_text):
        """Tokenize ``input_text``, padded/truncated to ``tokenizer.model_max_length``.

        Returns the ``'input_ids'`` entry of the tokenizer output.
        """
        max_length = tokenizer.model_max_length
        text_input = tokenizer(
            input_text,
            truncation=True,
            padding="max_length",
            max_length=max_length,
            return_tensors="pt"
        )
        return text_input['input_ids']

    def __len__(self):
        """Return the larger of the instance and regularization image counts."""
        return self._length

    def __getitem__(self, idx):
        """Return ``(image, input_ids)`` for ``idx``.

        Even ``idx`` -> instance image (wrapping around the instance list);
        odd ``idx`` -> a uniformly sampled regularization image.
        """
        if idx % 2 == 0:
            img = self.instance_images[idx % len(self.instance_images)]
            prompt = self.instance_prompt
        else:
            reg_idx = torch.randint(0, len(self.reg_images), (1,)).item()
            img = self.reg_images[reg_idx]
            prompt = self.class_prompt
        input_ids = self.process_text(self.tokenizer, prompt)

        # Use the transform handed to this instance, not whatever global
        # happens to be named `transforms`.
        if self.transforms is not None:
            img = self.transforms(img)

        return img, input_ids

### Download the SD model

In [104]:
from huggingface_hub import snapshot_download

# Hugging Face Hub repository to download the base checkpoint from.
repo_id = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
# Local folder the snapshot is written into (also used later to load the tokenizer and pipeline).
local_model_dir = Path("./realitic_vision_sd1.5")
# Download (or reuse) the snapshot; `model_dir` is the path returned by the hub client.
model_dir = snapshot_download(repo_id, local_dir=local_model_dir)

Fetching 21 files: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 21/21 [00:00<00:00, 6743.77it/s]


### Define some variables

In [105]:
# --- Hardware ---
device = "cuda"

# --- Data locations ---
instance_dir = "./harsh_photos"   # folder with photos of the subject
reg_dir = Path("./reg_dir")       # folder with regularization (class) images

# --- Prompts ---
instance_prompt = "a photo of a zwx man"   # paired with instance images
class_prompt = "a photo of a man"          # paired with regularization images

# --- Loss weighting ---
# NOTE(review): defined here but not referenced by any other visible cell.
prior_loss_weight = 1.0

### Generate Regularization Images

In [106]:
from pathlib import Path
from diffusers import StableDiffusionPipeline

num_reg_images = 50   # how many regularization images must exist in reg_dir
batch_size = 4        # images requested from the pipeline per call


# (Re)generate the class images only when the folder is missing or incomplete.
if not reg_dir.exists() or len(list(reg_dir.glob("*.jpg"))) < num_reg_images:
    pipeline = StableDiffusionPipeline.from_pretrained(model_dir).to(device)
    reg_dir.mkdir(parents=True, exist_ok=True)

    total = 0  # running count of images written (was previously used without being initialised)
    for i, start in enumerate(range(0, num_reg_images, batch_size)):
        # The final batch may be smaller; this never requests 0 images, even
        # when num_reg_images is an exact multiple of batch_size.
        current_batch = min(batch_size, num_reg_images - start)
        images = pipeline(class_prompt, num_images_per_prompt=current_batch).images
        for j, img in enumerate(images):
            img.save(f'{reg_dir}/{i}_{j}.jpg')
        total += len(images)

    # Drop the generation pipeline so it does not keep GPU memory allocated
    # while a second copy of the model is loaded for training.
    del pipeline
    torch.cuda.empty_cache()

### Load Tokenizer

In [107]:
# load tokenizer
from transformers import CLIPTokenizer, CLIPTextModel

# Load the CLIP tokenizer from the `tokenizer` sub-folder of the downloaded model directory.
tokenizer = CLIPTokenizer.from_pretrained(local_model_dir / 'tokenizer')
# Bare expression: shows the tokenizer's repr as the cell output.
tokenizer

CLIPTokenizer(name_or_path='realitic_vision_sd1.5/tokenizer', vocab_size=49408, model_max_length=77, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|startoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	49406: AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	49407: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}

### Load the Dataset

In [108]:
# Build the dataset; both image folders are read into memory at construction time.
dataset = DreamBoothDataset(instance_dir, reg_dir, transforms, tokenizer, instance_prompt, class_prompt)
# Displayed length is max(#instance images, #regularization images).
len(dataset)

17it [00:01, 16.96it/s]
50it [00:00, 726.26it/s]


50

### Create Training Dataloader

In [109]:
# Shuffled loader that yields one (image, input_ids) sample per training step.
train_dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=8  # worker processes used for loading; or use 1
)

### Define Tensor Precision

In [110]:
# Precision that the VAE, text encoder, UNet and the input images are cast to.
dtype = torch.float16

### Load Text Encoder, Unet And VAE

In [111]:
from diffusers import StableDiffusionPipeline

# Load the full pipeline once onto the configured device (instead of a
# hard-coded 'cuda'), then pull out the three sub-models used during training.
pipe = StableDiffusionPipeline.from_pretrained(local_model_dir).to(device)

# Cast each component to the training precision.
vae = pipe.vae.to(dtype=dtype)
text_encoder = pipe.text_encoder.to(dtype=dtype)
unet = pipe.unet.to(dtype=dtype)

# Bare expression: displays the module structure of all three models.
vae, text_encoder, unet

Loading pipeline components...: 100%|█████████████████████████████████████████████████████████████████████████████████| 7/7 [00:02<00:00,  2.96it/s]


(AutoencoderKL(
   (encoder): Encoder(
     (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (down_blocks): ModuleList(
       (0): DownEncoderBlock2D(
         (resnets): ModuleList(
           (0-1): 2 x ResnetBlock2D(
             (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
             (conv1): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
             (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
             (dropout): Dropout(p=0.0, inplace=False)
             (conv2): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
             (nonlinearity): SiLU()
           )
         )
         (downsamplers): ModuleList(
           (0): Downsample2D(
             (conv): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(2, 2))
           )
         )
       )
       (1): DownEncoderBlock2D(
         (resnets): ModuleList(
           (0): ResnetBlock2D(
             (norm

#### Whether to train the text encoder

In [112]:
# When True the text encoder is fine-tuned together with the UNet.
train_text_encoder = True

In [113]:
# The UNet is always trained; the text encoder's gradients follow the flag above.
unet.requires_grad_(True)
text_encoder.requires_grad_(train_text_encoder)

CLIPTextModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 768)
      (position_embedding): Embedding(77, 768)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e

### We don't need to train VAE so make sure it is in eval mode

In [114]:
# Freeze the VAE: no gradients, and evaluation mode.
vae.requires_grad_(False)
vae.eval()

AutoencoderKL(
  (encoder): Encoder(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_blocks): ModuleList(
      (0): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0-1): 2 x ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
        (downsamplers): ModuleList(
          (0): Downsample2D(
            (conv): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(2, 2))
          )
        )
      )
      (1): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0): ResnetBlock2D(
            (norm1): GroupNorm(32, 128, ep

In [115]:
# Learning rate shared by every optimizer parameter group.
lr = 1e-4

In [116]:
# One optimizer parameter group per model that is actually being trained.
trainable_params = [
    {"params": list(unet.parameters()), "lr": lr},
]
# Only hand the text encoder's weights to the optimizer when it is fine-tuned;
# previously they were included even when the encoder was frozen.
if train_text_encoder:
    trainable_params.append({"params": list(text_encoder.parameters()), "lr": lr})

In [117]:
# Adafactor with an explicit learning rate; relative_step=False turns off its
# internally computed step size so `lr` is the one used.
optimizer_type = transformers.optimization.Adafactor
optimizer = optimizer_type(trainable_params, lr=lr, relative_step=False)
# Bare expression: displays the optimizer's parameter groups.
optimizer

Adafactor (
Parameter Group 0
    beta1: None
    clip_threshold: 1.0
    decay_rate: -0.8
    eps: (1e-30, 0.001)
    lr: 0.0001
    relative_step: False
    scale_parameter: True
    warmup_init: False
    weight_decay: 0.0

Parameter Group 1
    beta1: None
    clip_threshold: 1.0
    decay_rate: -0.8
    eps: (1e-30, 0.001)
    lr: 0.0001
    relative_step: False
    scale_parameter: True
    warmup_init: False
    weight_decay: 0.0
)

### Initialize Scheduler

In [118]:
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
# Pick the learning-rate schedule by name and build it for the optimizer.
scheduler_type = SchedulerType("constant") # alternatives: cosine, polynomial
lr_scheduler = TYPE_TO_SCHEDULER_FUNCTION[scheduler_type](optimizer)
# Bare expression: displays the scheduler object.
lr_scheduler

<torch.optim.lr_scheduler.LambdaLR at 0x70c64b236fb0>

In [119]:
# Models that are switched to train mode in the training loop.
# The former `else` branch evaluated `[unet]` without assigning it, leaving
# `training_models` undefined whenever the text encoder was not trained.
training_models = [text_encoder, unet] if train_text_encoder else [unet]

### Define Noise Scheduler

In [120]:
# DDPM noise schedule used to add noise to the latents during training:
# 1000 timesteps, "scaled_linear" betas from 0.00085 to 0.012, no sample clipping.
noise_scheduler = DDPMScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)

### Constants

In [121]:
# Optimisation budget: training stops once this many steps have been taken.
num_training_steps = 500

# Loop bookkeeping: steps taken so far, and a 0/1 flag that ends the outer loop.
total_steps, stop = 0, 0

# Range the random diffusion timestep is drawn from: [min_timestep, max_timestep).
min_timestep, max_timestep = 0, 1000

### Start Finetuning

This training loop iterates over batches, processing images and text inputs. It encodes images, adds noise, predicts noise using a UNet, 
calculates loss, and updates model parameters. The loop continues until a specified number of steps is reached.

How the model is learning:

a) Image encoding: The VAE encodes input images into a latent space. <br>
b) Text conditioning: The text encoder processes the input text, providing context for image generation. <br>
c) Noise addition: Random noise is added to the latent representation of the images at a randomly selected timestep. <br>
d) Noise prediction: The UNet attempts to predict this added noise, conditioned on the noisy latent, the timestep, and the text encoding.<br>
e) Loss calculation: The model's prediction is compared to the actual noise added, using mean squared error loss. <br>
f) Backpropagation: The loss is used to update the model parameters, improving its ability to predict noise accurately. <br>

Predicting the noise is central to the diffusion model's learning process. By learning to predict the noise that was added, the model learns to denoise the image. This is because if you can predict the noise accurately, you can subtract it from the noisy image to recover the original.

Regularization images are interleaved with the instance images during training. They help prevent the model from overfitting to the few subject images and from forgetting what it already knows about the class, which results in better identity preservation of the subject.

In [122]:
# Announce the run once (not once per epoch) and fail fast on an empty
# dataloader, which would otherwise make the outer loop spin forever.
if stop == 0:
    print("Training has started")
    if len(train_dataloader) == 0:
        raise ValueError("train_dataloader is empty; there is nothing to train on.")

# Each pass of the outer loop is one epoch over the dataloader; it repeats
# until `num_training_steps` optimizer steps have been taken.
while stop == 0:

    # set the train mode in all trainable models
    for model in training_models:
        model.train()

    for step, batch in enumerate(train_dataloader):

        optimizer.zero_grad()

        images = batch[0].to(device).to(dtype=dtype)
        input_ids = batch[1].to(device)
        with torch.no_grad():
            # extract the low-dimensional latents from the (frozen) VAE
            latents = vae.encode(images).latent_dist.sample()
            # latent scaling constant; presumably equals vae.config.scaling_factor — confirm
            latents = latents * 0.18215

        # Local name on purpose: the notebook-level `batch_size` (used by the
        # regularization-image cell) must not be overwritten here.
        num_latents = latents.shape[0]

        # get the text embedding for conditioning; flatten to (N, model_max_length)
        input_ids = input_ids.reshape((-1, tokenizer.model_max_length))
        encoder_hidden_states = text_encoder(input_ids)[0]

        # generate random noise with the same shape as the latents
        noise = torch.randn_like(latents, device=latents.device)
        # sample one random timestep per latent
        timesteps = torch.randint(min_timestep, max_timestep, (num_latents,), device=device)
        # apply the noise to the latents at those timesteps
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
        # run the unet to predict the noise that was added
        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
        target = noise

        # calculate loss: L2 (MSE) in float32, averaged per sample and then over the batch
        loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
        loss = loss.mean([1, 2, 3])

        loss = loss.mean()
        loss.backward()

        # clip the gradient norm of each trained model to keep updates from exploding
        for model in training_models:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        lr_scheduler.step()

        total_steps += 1
        print(f"Step: {total_steps}  Loss: {loss.item()}")
        if total_steps >= num_training_steps:
            stop = 1
            break

        
        
            

Training has started
Step: 1  Loss: 0.2859356105327606
Step: 2  Loss: 0.14254876971244812
Step: 3  Loss: 0.17556744813919067
Step: 4  Loss: 0.3015856146812439
Step: 5  Loss: 0.47736552357673645
Step: 6  Loss: 0.006030983291566372
Step: 7  Loss: 0.30024585127830505
Step: 8  Loss: 0.015318508259952068
Step: 9  Loss: 0.2885662913322449
Step: 10  Loss: 0.174933522939682
Step: 11  Loss: 0.20796789228916168
Step: 12  Loss: 0.0394824780523777
Step: 13  Loss: 0.04179912060499191
Step: 14  Loss: 0.13642758131027222
Step: 15  Loss: 0.009633347392082214
Step: 16  Loss: 0.14913204312324524
Step: 17  Loss: 0.1373412311077118
Step: 18  Loss: 0.013937128707766533
Step: 19  Loss: 0.007134093437343836
Step: 20  Loss: 0.3783424198627472
Step: 21  Loss: 0.08668725192546844
Step: 22  Loss: 0.005828009452670813
Step: 23  Loss: 0.01196327619254589
Step: 24  Loss: 0.06169142946600914
Step: 25  Loss: 0.016028687357902527
Step: 26  Loss: 0.15228384733200073
Step: 27  Loss: 0.3887144923210144
Step: 28  Loss: 0.

In [131]:
# Assemble an inference pipeline from the fine-tuned components.
# NOTE(review): the training-time DDPMScheduler is reused as the sampler and the
# models are not switched to eval mode in this cell — confirm both are intended.
pipeline = StableDiffusionPipeline(
    text_encoder=text_encoder,
    vae=vae,
    unet=unet,
    tokenizer=tokenizer,
    scheduler=noise_scheduler,
    safety_checker=None,  # safety filter disabled (see the warning printed below the cell)
    feature_extractor=None,
)

You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .


In [None]:
# Generate with a prompt containing the "zwx" token from `instance_prompt` and display the first image.
# NOTE(review): "potrait" in the prompt looks like a typo for "portrait" — confirm before changing.
pipeline("potrait of a zwx man in beach").images[0]

 90%|███████████████████████████████████████████████████████████████████████████████████████████████████▉           | 45/50 [00:01<00:00, 28.31it/s]