In [3]:
import argparse
import itertools
import math
import os
from multiprocessing import Value
import toml

from tqdm import tqdm
import torch
import transformers

In [4]:
from accelerate.utils import set_seed
from diffusers import DDPMScheduler

  torch.utils._pytree._register_pytree_node(


### Prepare Dataset

In [223]:
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import pathlib
from PIL import Image, UnidentifiedImageError
import numpy as np


transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((512,512))
])
instance_prompt = "a photo of a zwx man"
class_prompt = "a photo of a man"
prior_loss_weight = 1.0


class DreamBoothDataset(Dataset):
    def __init__(self, data_dir, reg_dir, transforms, tokenizer, instance_prompt, class_prompt):
        self.instance_images = self.load_images(data_dir)
        self.reg_images = self.load_images(reg_dir)
        self.transforms = transforms
        self.tokenizer = tokenizer
        self.instance_prompt = instance_prompt
        self.class_prompt = class_prompt
        self._length = max(len(self.instance_images), len(self.reg_images))
        

    def load_images(self, data_dir):
        images = [] 
        for img_path in tqdm(pathlib.Path(data_dir).glob("*")):
            try:
                img = Image.open(img_path)
                images.append(np.array(img))
            except UnidentifiedImageError:
                print(f"Error: {img_path} is not a valid image file.")
            except Exception as e:
                print(f"An error occurred while opening {img_path}: {e}")
        return images

    def process_text(self, tokenizer, input_text):
        max_length = tokenizer.model_max_length
        text_input = tokenizer(
        input_text,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt"
        )
        return text_input['input_ids']

    def __len__(self):
        return self._length

    def __getitem__(self, idx):
        if idx % 2 == 0:
            img = self.instance_images[idx % len(self.instance_images)]
            input_ids = self.process_text(self.tokenizer, self.instance_prompt)
        else:
            reg_idx = torch.randint(0, len(self.reg_images), (1,)).item()
            img = self.reg_images[reg_idx]
            input_ids = self.process_text(self.tokenizer, self.class_prompt)
        if self.transforms:
            img = transforms(img)
        
        return img, input_ids

### Download the SD model

In [224]:
from huggingface_hub import snapshot_download

model_dir = snapshot_download('runwayml/stable-diffusion-v1-5', local_dir="./runwayml-sd1.5")

Fetching 36 files: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 36/36 [00:00<00:00, 2906.49it/s]


### Generate Reglarization Images

In [148]:
from pathlib import Path
from diffusers import StableDiffusionPipeline


reg_dir = Path("./reg_dir")
num_reg_images = 50
batch_size = 4
if not reg_dir.exists() or len(list(reg_dir.glob("*.jpg"))) < num_reg_images:
    pipeline = StableDiffusionPipeline.from_pretrained(model_dir).to('cuda')
    reg_dir.mkdir(exist_ok=True)
    for i in range(num_reg_images // batch_size):
        if i == (num_reg_images // batch_size)-1 :
            images = pipeline(class_prompt, num_images_per_prompt=num_reg_images % batch_size).images
        else:
            images = pipeline(class_prompt, num_images_per_prompt=batch_size).images

        [img.save(f'{reg_dir}/{i}_{j}.jpg') for j, img in enumerate(images)]

Loading pipeline components...:   0%|                                                                                                                                                    | 0/7 [00:00<?, ?it/s]`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  8.97it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████

In [225]:
# load tokenizer
from transformers import CLIPTokenizer, CLIPTextModel

tokenizer = CLIPTokenizer.from_pretrained('./runwayml-sd1.5/tokenizer')
tokenizer

CLIPTokenizer(name_or_path='./runwayml-sd1.5/tokenizer', vocab_size=49408, model_max_length=77, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|startoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	49406: AddedToken("<|startoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
	49407: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}

In [226]:
dataset = DreamBoothDataset('./harsh_photos', reg_dir, transforms, tokenizer, instance_prompt, instance_prompt)
len(dataset)

17it [00:01, 16.45it/s]
46it [00:00, 696.80it/s]


46

In [229]:
train_dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=8
)

In [230]:
device = 'cuda'

### Define Tensor Presicion

In [231]:
dtype = torch.float16

### Load Text Encoder, Unet And VAE

In [233]:
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained('./runwayml-sd1.5').to('cuda')

vae = pipe.vae
text_encoder = pipe.text_encoder
unet = pipe.unet

vae, text_encoder, unet

Loading pipeline components...:   0%|                                                                                                                                                    | 0/7 [00:00<?, ?it/s]`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.
Loading pipeline components...: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00,  7.94it/s]


(AutoencoderKL(
   (encoder): Encoder(
     (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
     (down_blocks): ModuleList(
       (0): DownEncoderBlock2D(
         (resnets): ModuleList(
           (0-1): 2 x ResnetBlock2D(
             (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
             (conv1): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
             (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
             (dropout): Dropout(p=0.0, inplace=False)
             (conv2): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
             (nonlinearity): SiLU()
           )
         )
         (downsamplers): ModuleList(
           (0): Downsample2D(
             (conv): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(2, 2))
           )
         )
       )
       (1): DownEncoderBlock2D(
         (resnets): ModuleList(
           (0): ResnetBlock2D(
             (norm

#### Weather to train text encoder

In [234]:
train_text_encoder = True

In [235]:
unet.requires_grad_(True)
text_encoder.requires_grad_(train_text_encoder)

CLIPTextModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 768)
      (position_embedding): Embedding(77, 768)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e

#### gradient checkpointing

In [236]:
unet.enable_gradient_checkpointing()
text_encoder.gradient_checkpointing_enable()

### We don't need to train VAE so make sure it is in eval mode

In [237]:
vae.requires_grad_(False)
vae.eval()

AutoencoderKL(
  (encoder): Encoder(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_blocks): ModuleList(
      (0): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0-1): 2 x ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
        (downsamplers): ModuleList(
          (0): Downsample2D(
            (conv): LoRACompatibleConv(128, 128, kernel_size=(3, 3), stride=(2, 2))
          )
        )
      )
      (1): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0): ResnetBlock2D(
            (norm1): GroupNorm(32, 128, ep

In [238]:
lr = 1e-5

In [239]:
trainable_params = [
    {"params": list(unet.parameters()), "lr": lr },
    {"params": list(text_encoder.parameters()), "lr": lr }
]

In [240]:
optimizer_type = transformers.optimization.Adafactor
optimizer = optimizer_type(trainable_params, lr=lr, relative_step=False)
optimizer

Adafactor (
Parameter Group 0
    beta1: None
    clip_threshold: 1.0
    decay_rate: -0.8
    eps: (1e-30, 0.001)
    lr: 1e-05
    relative_step: False
    scale_parameter: True
    warmup_init: False
    weight_decay: 0.0

Parameter Group 1
    beta1: None
    clip_threshold: 1.0
    decay_rate: -0.8
    eps: (1e-30, 0.001)
    lr: 1e-05
    relative_step: False
    scale_parameter: True
    warmup_init: False
    weight_decay: 0.0
)

### Initialize Scheduler

In [241]:
from diffusers.optimization import SchedulerType, TYPE_TO_SCHEDULER_FUNCTION
scheduler_type = SchedulerType("constant") # consine, polynomial
scheduler = TYPE_TO_SCHEDULER_FUNCTION[scheduler_type](optimizer)
scheduler

<torch.optim.lr_scheduler.LambdaLR at 0x769ef5adc970>

In [242]:
if train_text_encoder:
    training_models = [text_encoder, unet]
else:
    [unet]

### Difine Noise Scheduler

In [243]:
noise_scheduler = DDPMScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
)

### Start Training

In [244]:
num_train_epochs = 1


In [245]:
for epoch in tqdm(range(num_train_epochs)):
    print("Training has started")

    # set the train mode in all trainable models
    for model in training_models:
        model.train()

    for step, batch in enumerate(train_dataloader):
        images = batch[0].to(device)
        input_ids = batch[1].to(device)
        # extract the low dim latents from the vae
        latents = vae.encode(images).latent_dist.sample()
        print(latents.shape)
        # get the text embedding for conditioning
        with torch.set_grad_enabled(True):
            encoder_hidden_states = text_encoder(input_ids)[0]

        # sample a random timestep for each image, add noise to the latents
        batch_size = latents.shape[0]
        min_timestep = 0
        max_timestep = 1000

        # generate random noise
        noise = torch.randn_like(latents, device=latents.device)
        # generate random timestep 
        timesteps = torch.randint(min_timestep, max_timestep, (batch_size,), device=device)
        # apply noise to latents
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
        # run the unet
        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
        target = noise

        # calculate loss, we are using l2 loss
        loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
        loss = loss.mean([1,2,3])
        loss = loss * 1.0

        loss = loss.mean()
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad(set_to_none=True)
        
        break
    break
            

  0%|                                                                                                                                                                                    | 0/1 [00:00<?, ?it/s]

Training has started
torch.Size([1, 4, 64, 64])


  0%|                                                                                                                                                                                    | 0/1 [00:01<?, ?it/s]


NameError: name 'lr_scheduler' is not defined

In [None]:
# tokenize prompt 

In [None]:
encoder_hidden_states = text_encoder(text_input['input_ids'])