In [21]:
import argparse
import itertools
import math
import os
from multiprocessing import Value
import toml

from tqdm import tqdm
import torch
import transformers

In [2]:
from accelerate.utils import set_seed
from diffusers import DDPMScheduler

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


### Prepare Dataset

In [72]:
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import pathlib
from PIL import Image, UnidentifiedImageError
import numpy as np


transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((512,512))
])
instance_prompt = "a photo of a zwx man"

class DreamBoothDataset(Dataset):
    def __init__(self, data_dir, transforms, tokenizer):
        self.images = self.load_images(data_dir)
        self.transforms = transforms
        self.tokenizer = tokenizer

    def load_images(self, data_dir):
        images = []
        for img_path in pathlib.Path(data_dir).glob("*"):
            try:
                img = Image.open(img_path)
                images.append(np.array(img))
            except UnidentifiedImageError:
                print(f"Error: {img_path} is not a valid image file.")
            except Exception as e:
                print(f"An error occurred while opening {img_path}: {e}")
        return images

    def process_text(self, tokenizer):
        text_input = tokenizer(
        instance_prompt,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt"
        )
        return text_input['input_ids']

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = self.images[idx]
        if self.transforms:
            img = transforms(img)
        input_ids = process_text(self.tokenizer)
        return img, input_ids

In [73]:
# load tokenizer
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained('./runwayml-sd1.5/tokenizer')
tokenizer

CLIPTokenizer(name_or_path='./runwayml-sd1.5/tokenizer', vocab_size=49408, model_max_length=77, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|startoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	49406: AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	49407: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}

In [74]:
dataset = DreamBoothDataset('./harsh_photos', transforms, tokenizer)
len(dataset)

17

In [75]:
train_dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=8
)

In [76]:
from huggingface_hub import snapshot_download

snapshot_download('runwayml/stable-diffusion-v1-5', local_dir="./runwayml-sd1.5")

Fetching 36 files: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 36/36 [00:00<00:00, 4991.07it/s]


'/home/harshb/workspace/learnings/dreambooth/runwayml-sd1.5'

In [77]:
device = 'cuda'

### Define Tensor Presicion

In [79]:
dtype = torch.float16

### Load Text Encoder, Unet And VAE

In [80]:
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained('./runwayml-sd1.5').to('cuda')

vae = pipe.vae
text_encoder = pipe.text_encoder
unet = pipe.unet

vae, text_encoder, unet

Loading pipeline components...:  29%|███████████████████████████████▍                                                                              | 2/7 [00:00<00:00,  5.11it/s]`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.
Loading pipeline components...: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  9.80it/s]


(AutoencoderKL(
   (encoder): Encoder(
     (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (down_blocks): ModuleList(
       (0): DownEncoderBlock2D(
         (resnets): ModuleList(
           (0-1): 2 x ResnetBlock2D(
             (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
             (conv1): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
             (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
             (dropout): Dropout(p=0.0, inplace=False)
             (conv2): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
             (nonlinearity): SiLU()
           )
         )
         (downsamplers): ModuleList(
           (0): Downsample2D(
             (conv): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(2, 2))
           )
         )
       )
       (1): DownEncoderBlock2D(
         (resnets): ModuleList(
           (0): ResnetBlock2D(
             (norm

#### Weather to train text encoder

In [81]:
train_text_encoder = True

In [82]:
unet.requires_grad_(True)
text_encoder.requires_grad_(train_text_encoder)

CLIPTextModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 768)
      (position_embedding): Embedding(77, 768)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e

#### gradient checkpointing

In [83]:
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()

### We don't need to train VAE so make sure it is in eval mode

In [84]:
vae.requires_grad_(False)
vae.eval()

AutoencoderKL(
  (encoder): Encoder(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_blocks): ModuleList(
      (0): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0-1): 2 x ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
        (downsamplers): ModuleList(
          (0): Downsample2D(
            (conv): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(2, 2))
          )
        )
      )
      (1): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0): ResnetBlock2D(
            (norm1): GroupNorm(32, 128, ep

In [85]:
lr = 1e-5

In [86]:
trainable_params = [
    {"params": list(unet.parameters()), "lr": lr },
    {"params": list(text_encoder.parameters()), "lr": lr }
]

In [87]:
optimizer_type = transformers.optimization.Adafactor
optimizer = optimizer_type(trainable_params, lr=lr, relative_step=False)
optimizer

Adafactor (
Parameter Group 0
    beta1: None
    clip_threshold: 1.0
    decay_rate: -0.8
    eps: (1e-30, 0.001)
    lr: 1e-05
    relative_step: False
    scale_parameter: True
    warmup_init: False
    weight_decay: 0.0

Parameter Group 1
    beta1: None
    clip_threshold: 1.0
    decay_rate: -0.8
    eps: (1e-30, 0.001)
    lr: 1e-05
    relative_step: False
    scale_parameter: True
    warmup_init: False
    weight_decay: 0.0
)

### Initialize Scheduler

In [88]:
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
scheduler_type = SchedulerType("constant") # consine, polynomial
scheduler = TYPE_TO_SCHEDULER_FUNCTION[scheduler_type](optimizer)
scheduler

<torch.optim.lr_scheduler.LambdaLR at 0x7756f5e71780>

In [89]:
if train_text_encoder:
    training_models = [text_encoder, unet]
else:
    [unet]

### Difine Noise Scheduler

In [90]:
noise_scheduler = DDPMScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)

### Start Training

In [91]:
num_train_epochs = 1


In [92]:
for epoch in tqdm(range(num_train_epochs)):
    print("Training has started")

    # set the train mode in all trainable models
    for model in training_models:
        model.train()

    for step, batch in enumerate(train_dataloader):
        images = batch[0].to(device)
        input_ids = batch[1].to(device)
        # extract the low dim latents from the vae
        latents = vae.encode(images).latent_dist.sample()
        print(latents.shape)
        # get the text embedding for conditioning
        with torch.set_grad_enabled():
            encoder_hidden_states = text_encoder(input_ids)

  0%|                                                                                                                                                      | 0/1 [00:00<?, ?it/s]

Training has started


  0%|                                                                                                                                                      | 0/1 [00:00<?, ?it/s]


NameError: Caught NameError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/harshb/miniconda3/envs/sdxl_3.10/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/home/harshb/miniconda3/envs/sdxl_3.10/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/harshb/miniconda3/envs/sdxl_3.10/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/tmp/ipykernel_380336/2426679733.py", line 49, in __getitem__
    input_ids = process_text(self.tokenizer)
NameError: name 'process_text' is not defined


In [64]:
# tokenize prompt 
max_length = tokenizer.model_max_length


In [100]:
encoder_hidden_states = text_encoder(text_input['input_ids'])