In [None]:
# Install the required libraries
!pip install diffusers["torch"] transformers datasets einops

In [2]:
import torch
import torch.nn as nn
from torch.optim import Adam, AdamW
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

from accelerate import Accelerator
from datasets import load_dataset

import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Dataset

from tqdm.auto import tqdm

from models import UNetModelVD_DualGuided

from transformers import CLIPTextModelWithProjection, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPImageProcessor
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.image_processor import VaeImageProcessor

from diffusers.utils.torch_utils import randn_tensor

# Set the default torch datatype to float16 to optimize memory usage and avoid RAM issues.
# This is a process-wide setting: every floating-point tensor created afterwards without an
# explicit dtype (for example via torch.randn / torch.rand further down) is float16.
torch.set_default_dtype(torch.float16)

In [None]:
# Download the pretrained weights of the Versatile Diffusion model and rename them from 'diffuser.text/image' to 'unet_text/image'
# In the official implementation, 'diffuser.text/image' represents the model. In this context, we are using 'unet_text/image'
from huggingface_hub import hf_hub_download

vd_path = hf_hub_download('shi-labs/versatile-diffusion-model', 'pretrained_pth/vd-four-flow-v1-0-fp16.pth')

# NOTE: torch.load unpickles the file, which can execute arbitrary code - only load checkpoints
# from a source you trust.
# map_location="cpu" keeps the whole checkpoint in host memory regardless of the device it was
# saved from; load_state_dict later copies the matching tensors onto the model's device.
state_dict = torch.load(vd_path, map_location="cpu")

# (old prefix, new prefix) pairs; only a leading prefix is rewritten, never a later occurrence
# of the same text inside the key.
key_prefix_renames = (
    ("diffuser.text.", "unet_text."),
    ("diffuser.image.", "unet_image."),
)

renamed_state_dict = {}

for key, value in state_dict.items():
    new_key = key
    for old_prefix, new_prefix in key_prefix_renames:
        if key.startswith(old_prefix):
            new_key = new_prefix + key[len(old_prefix):]
            break
    renamed_state_dict[new_key] = value

In [None]:
# Download a minimal dataset consisting of image-text pairs for experimentation.
# CustomDataset below reads the 'train' split and the 'image' / 'text' fields of each row.
dataset = load_dataset("lambdalabs/pokemon-blip-captions")

class CustomDataset(Dataset):
  """Image-text dataset that returns pre-encoded training inputs.

  Each item is the tuple ``(image_latent, image_context, text_context)`` produced by the
  three encoder callables handed to the constructor.

  Args:
    vae_encoder_img2latent: callable mapping a transformed image tensor to its latent
      (the leading dimension of the result is squeezed away).
    image_context_encode: callable mapping an RGB image to its context embedding.
    text_context_encode: callable mapping a caption string to its context embedding.
    transform: callable turning an RGB image into the tensor fed to the VAE encoder.
    data: optional indexable collection of rows with 'image' and 'text' fields.
      Defaults to the 'train' split of the module-level ``dataset``.
    device: optional device for the VAE input. Defaults to "cuda" when available,
      otherwise "cpu".
  """

  def __init__(self, vae_encoder_img2latent, image_context_encode, text_context_encode, transform, data=None, device=None):
    self.dataset = dataset['train'] if data is None else data
    self.vae_encoder_img2latent = vae_encoder_img2latent
    self.image_context_encode = image_context_encode
    self.text_context_encode = text_context_encode
    self.transform = transform
    if device is None:
      device = "cuda" if torch.cuda.is_available() else "cpu"
    self.device = device

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, idx):
    record = self.dataset[idx]

    # Convert once and feed the same RGB image to both encoders, so that palette / RGBA /
    # grayscale inputs are handled identically by the VAE and the CLIP image encoder.
    rgb_image = record['image'].convert("RGB")

    # The encoders only provide inputs and targets for training, so no autograd graph is needed.
    with torch.no_grad():
      # Process image through VAE encoder
      pixel_tensor = self.transform(rgb_image)
      image_latent = self.vae_encoder_img2latent(pixel_tensor.to(self.device).half()).squeeze(0)

      # Process image through CLIP image encoder
      image_context = self.image_context_encode(rgb_image).squeeze(0)

      # Process the corresponding text through CLIP text encoder
      text_context = self.text_context_encode(record['text']).squeeze(0)

    return image_latent, image_context, text_context

In [None]:
def load_config(model, sub_folder):
    """Load one pretrained component of Versatile Diffusion.

    Calls ``model.from_pretrained`` on the "shi-labs/versatile-diffusion" repository,
    selecting the component stored under ``sub_folder``, and returns the result.
    """
    repo_id = "shi-labs/versatile-diffusion"
    component = model.from_pretrained(repo_id, subfolder=sub_folder)
    return component

# Device on which load_models() places the VAE, the CLIP encoders and the UNet.
# It is fixed to CUDA here, so this notebook expects a GPU runtime.
device = "cuda"

# Load the configuration for various components: VAE, CLIP, and the final model.
# .half() will cast to float16 to optimize memory usage
def load_models():
  """Build every component used for training and sampling.

  The VAE and both CLIP encoders are cast to float16, moved to `device` and frozen
  (`requires_grad_(False)`); the dual-guided UNet is cast to float16, moved to `device`
  and initialised from `renamed_state_dict`.

  Returns:
    (model, scheduler, vae, tokenizer, text_encoder, image_feature_extractor, image_encoder)
  """
  import warnings

  scheduler = load_config(DDIMScheduler, "scheduler")
  with torch.no_grad():
    vae = load_config(AutoencoderKL, "vae").half().to(device).requires_grad_(False)
    tokenizer = load_config(CLIPTokenizer, "tokenizer")
    text_encoder = load_config(CLIPTextModelWithProjection, "text_encoder").half().to(device).requires_grad_(False)
    image_feature_extractor = load_config(CLIPImageProcessor, "image_feature_extractor")
    image_encoder = load_config(CLIPVisionModelWithProjection, "image_encoder").half().to(device).requires_grad_(False)
  # Instantiate the Dual Guided model.
  model = UNetModelVD_DualGuided().half().to(device)
  # Load the pretrained model weights.
  # strict=False is needed because the checkpoint holds more entries than this model uses,
  # but it also hides the opposite problem: model parameters that found no weight in the
  # checkpoint keep their random initialisation. Surface those instead of ignoring them.
  load_result = model.load_state_dict(renamed_state_dict, strict=False)
  if load_result.missing_keys:
    preview = ", ".join(load_result.missing_keys[:5])
    warnings.warn(
        f"{len(load_result.missing_keys)} model parameter(s) were not found in the checkpoint "
        f"and remain randomly initialised (first few: {preview})"
    )
  return model, scheduler, vae, tokenizer, text_encoder, image_feature_extractor, image_encoder

# Create all components once as module-level globals; DualGuidedVersatileDiffusion.__init__
# picks them up by these exact names.
model, scheduler, vae, tokenizer, text_encoder, image_feature_extractor, image_encoder = load_models()

In [22]:
class DualGuidedVersatileDiffusion():
  """Training and sampling wrapper for the dual-guided (image + text) Versatile Diffusion UNet.

  The constructor reads the module-level `model`, `scheduler`, `vae`, `tokenizer`,
  `text_encoder`, `image_feature_extractor` and `image_encoder` objects, so they must exist
  before an instance is created.
  """

  def __init__(self):
    super().__init__()

    self.device = "cuda" if torch.cuda.is_available() else "cpu"

    self.model = model
    self.scheduler = scheduler
    self.vae = vae
    self.tokenizer = tokenizer
    self.text_encoder = text_encoder
    self.image_feature_extractor = image_feature_extractor
    self.image_encoder = image_encoder

    # Ratio between the pixel resolution and the latent resolution (see prepare_latents_sampling).
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

    # NOTE(review): the model weights are float16 and are optimised directly in that precision;
    # consider mixed-precision training if the loss turns out to be unstable.
    self.optimizer = AdamW(self.model.parameters(), lr=1e-4)
    self.mse_loss = nn.MSELoss()
    self.batch_size = 1
    self.height = 256
    self.width = 256

    self.transform = transforms.Compose(
        [
          transforms.Resize((self.height, self.width)),
          transforms.ToTensor(),
          transforms.Normalize([0.5], [0.5]),
        ]
      )
    self.data_loader = DataLoader(
        CustomDataset(self.vae_encoder_img2latent, self.image_context_encode, self.text_context_encode, self.transform),
        batch_size=self.batch_size, shuffle=True
    )

  # Encode and normalize text context
  def text_context_encode(self, text):
    """Return the projected per-token CLIP text embeddings for `text`.

    The prompt is padded/truncated to the tokenizer's maximum length, every token's last hidden
    state is passed through `text_projection`, and the result is divided by the L2 norm of the
    pooled text embedding.
    """
    text_inputs = self.tokenizer(
          text,
          padding="max_length",
          max_length=self.tokenizer.model_max_length,
          truncation=True,
          return_tensors="pt",
          )
    text_embeds = self.text_encoder(text_inputs.input_ids.to(self.device))
    embeds = self.text_encoder.text_projection(text_embeds.last_hidden_state)
    embeds_pooled = text_embeds.text_embeds
    embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True)
    return embeds

  # Encode and normalize image context
  def image_context_encode(self, image):
    """Return the projected per-token CLIP image embeddings for `image`.

    Every token of the vision encoder's last hidden state goes through `post_layernorm` and
    `visual_projection`; the result is divided by the L2 norm of the first token's embedding.
    """
    image_features = self.image_feature_extractor(images=image, return_tensors="pt")
    image_features = image_features.pixel_values.to(self.device).to(self.image_encoder.dtype)
    image_embeds = self.image_encoder(image_features)
    embeds = self.image_encoder.vision_model.post_layernorm(image_embeds.last_hidden_state)
    embeds = self.image_encoder.visual_projection(embeds)
    embeds_pooled = embeds[:, 0:1]
    embeds = embeds / torch.norm(embeds_pooled, dim=-1, keepdim=True)
    return embeds

  # Encode input image into the latent space
  def vae_encoder_img2latent(self, image):
    """Encode a single (unbatched) image tensor into a scaled VAE latent with a batch dimension of 1."""
    image = image.unsqueeze(0)
    encoded = self.vae.encode(image)
    latent = encoded.latent_dist.sample() * self.vae.config.scaling_factor
    return latent

  def _unconditional_contexts(self):
    """Encode the "empty" conditioning used for classifier-free guidance.

    Returns the image context of a constant 0.5-valued image and the text context of the
    empty prompt, each with a batch dimension of 1.
    """
    with torch.no_grad():
      unconditional_image_context = self.image_context_encode([np.zeros((self.height, self.width, 3)) + 0.5])
      unconditional_text_context = self.text_context_encode([""])
    return unconditional_image_context, unconditional_text_context

  # Training
  #========================================================================================
  def train_step(self, mixing_type="attention"):
    """Run one optimisation pass over `self.data_loader`.

    Args:
      mixing_type: forwarded unchanged to the model's `forward`.

    Returns:
      The mean of the per-batch MSE losses (0.0 when the loader yields no batch).
    """
    running_loss = 0.0
    num_batches = 0
    train_bar = tqdm(self.data_loader, desc='Training')
    self.model.train()

    # For Classifier-free guidance. These do not depend on the batch, so encode them once.
    uncond_image_context, uncond_text_context = self._unconditional_contexts()

    for step, batch in enumerate(train_bar):

      self.optimizer.zero_grad()

      image_latents, image_context, text_context = batch
      # Use the real size of this batch: the last batch can be smaller than self.batch_size.
      current_batch_size = image_latents.shape[0]

      # Generate sample noise to be added to the image
      noise = torch.randn_like(image_latents)

      # Get a random timestep per sample
      timesteps = torch.randint(0, self.scheduler.config.num_train_timesteps, (current_batch_size,), device=image_latents.device).long()

      # Add noise at the specified timestep
      noisy_latents = self.scheduler.add_noise(image_latents, noise, timesteps).half().to(self.device)

      # Train the model 10% unconditionally and 90% conditionally
      # In this case, we consider image input and image context the same (it would be better if we have a tailored dataset)
      if torch.rand(1).item() < 0.1:
        # Repeat the single "empty" context along the batch dimension to match the latents.
        c0 = uncond_image_context.repeat(current_batch_size, *([1] * (uncond_image_context.dim() - 1)))
        c1 = uncond_text_context.repeat(current_batch_size, *([1] * (uncond_text_context.dim() - 1)))
      else:
        c0, c1 = image_context, text_context

      noise_pred = self.model.forward(noisy_latents, timesteps, xtype="image", c0=c0, c1=c1, c0_type='image', c1_type='text', c0_ratio=0.4, c1_ratio=0.6, mixing_type=mixing_type)

      loss = self.mse_loss(noise_pred.float(), noise.float())

      loss.backward()

      self.optimizer.step()

      running_loss += loss.item()
      num_batches = step + 1

      train_bar.set_postfix({'mse': f'{running_loss / num_batches:.4f}'})

    # Each loss is already a mean over its batch, so average over the number of batches.
    return running_loss / num_batches if num_batches else 0.0

  def train(self, epochs):
    """Train for several epochs.

    Args:
      epochs: either the number of epochs (int) or an iterable of epoch labels.

    Returns:
      List with the loss returned by `train_step` for every epoch.
    """
    epoch_labels = range(epochs) if isinstance(epochs, int) else epochs
    losses = []

    for epoch in epoch_labels:

      print("Epoch : ", epoch, " Start")
      loss = self.train_step()
      print("Epoch : ", epoch, " End with loss = ", loss)
      losses.append(loss)

    return losses

  #========================================================================================

  # Sampling

  def prepare_context_sampling(self, image_context, text_context):
    """Encode the conditioning inputs and prepend their unconditional counterparts.

    Returns `(image_embeddings, text_embeddings)`, each the concatenation along dim 0 of the
    unconditional context followed by the conditional one.
    """
    with torch.no_grad():
      conditional_image_context = self.image_context_encode(image_context)
      conditional_text_context = self.text_context_encode(text_context)

    unconditional_image_context, unconditional_text_context = self._unconditional_contexts()

    image_embeddings = torch.cat([unconditional_image_context, conditional_image_context], axis=0)
    text_embeddings = torch.cat([unconditional_text_context, conditional_text_context], axis=0)

    return image_embeddings, text_embeddings

  # Generate a sample noise
  def prepare_latents_sampling(self, num_channels_latents, device, dtype, batch_size=1, latents=None, generator=None):
    """Return the initial latents for sampling, scaled by the scheduler's `init_noise_sigma`.

    Args:
      num_channels_latents: number of latent channels.
      device, dtype: device and dtype of freshly drawn noise (`latents` is only moved to `device`).
      batch_size: number of samples.
      latents: optional pre-drawn noise; when given, no new noise is generated.
      generator: optional torch.Generator (or list of them, one per sample). When omitted, a
        generator seeded with 0 on the VAE's device is used, which makes sampling reproducible.

    Raises:
      ValueError: if `generator` is a list whose length differs from `batch_size`.
    """
    shape = (batch_size, num_channels_latents, self.height // self.vae_scale_factor, self.width // self.vae_scale_factor)
    if isinstance(generator, list) and len(generator) != batch_size:
      raise ValueError(
          f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
          f" size of {batch_size}. Make sure the batch size matches the length of the generators."
      )

    if latents is None:
      if generator is None:
        generator = torch.Generator(device=self.vae.device).manual_seed(0)
      latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    else:
      latents = latents.to(device)

    # scale the initial noise by the standard deviation required by the scheduler
    latents = latents * self.scheduler.init_noise_sigma
    return latents

  # Decode the generated latent image
  def decode_latents_sampling(self, latents):
    """Decode latents with the VAE and return a channels-last float numpy array with values in [0, 1]."""
    latents = 1 / self.vae.config.scaling_factor * latents
    image = self.vae.decode(latents, return_dict=False)[0]
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()
    return image

  def sampling(self, image, text, guidance_factor=7.5, num_inference_steps=1000):
    """Generate an image conditioned on both `image` and `text` with classifier-free guidance.

    Args:
      image: reference image used as image context.
      text: prompt used as text context.
      guidance_factor: classifier-free guidance scale.
      num_inference_steps: number of scheduler steps (defaults to the previous fixed value of 1000).

    Returns:
      The decoded images as returned by `decode_latents_sampling`.
    """
    # train_step leaves the model in training mode; inference must not use training-only behaviour.
    self.model.eval()

    self.scheduler.set_timesteps(num_inference_steps)
    latents = self.prepare_latents_sampling(4, self.vae.device, torch.float16)
    image_embeddings, text_embeddings = self.prepare_context_sampling(image, text)
    for t in tqdm(self.scheduler.timesteps):
      # Expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
      latent_model_input = torch.cat([latents] * 2)
      # NOTE(review): the -1 shift presumably keeps the timestep inside the scheduler's valid
      # index range when its timesteps start at num_train_timesteps - confirm against the
      # scheduler config.
      t = torch.tensor([t-1])
      t = torch.cat([t] * 2).to(self.device)

      latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep=t)

      # Predict the noise residual
      with torch.no_grad():
        noise_pred = self.model.forward(latent_model_input, t, xtype="image", c0=image_embeddings, c1=text_embeddings, c0_type='image', c1_type='text', c0_ratio=0.4, c1_ratio=0.6, mixing_type="attention")

      # Perform guidance
      noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
      # Linear interpolation
      noise_pred = noise_pred_uncond + guidance_factor * (noise_pred_cond - noise_pred_uncond)

      # compute the previous noisy sample x_t -> x_t-1
      latents = self.scheduler.step(noise_pred, t[0].cpu(), latents).prev_sample

    with torch.no_grad():
      images = self.decode_latents_sampling(latents)
    return images

In [23]:
# Build the trainer/sampler. The constructor also creates the optimizer and the DataLoader
# over the image-text dataset.
vd = DualGuidedVersatileDiffusion()

In [None]:
# Fine-tune for a single pass over the dataset and keep the loss value returned by train_step.
loss = vd.train_step()

In [None]:
import requests
from io import BytesIO
from PIL import Image

# let's download an initial image
url = "https://huggingface.co/datasets/diffusers/images/resolve/main/benz.jpg"

# A timeout keeps the cell from hanging on a stalled connection, and raise_for_status() turns
# an HTTP error into a clear exception instead of handing an error page to PIL.
response = requests.get(url, timeout=30)
response.raise_for_status()
image = Image.open(BytesIO(response.content)).convert("RGB")
text = "a double decker bus"

In [None]:
# Generate from both the reference image and the text prompt, using the default guidance factor.
# The result is the batch of decoded images returned by decode_latents_sampling.
images = vd.sampling(image, text)

In [None]:
import matplotlib.pyplot as plt
# Show the first generated sample. `images` is the batch returned by vd.sampling, whereas
# `image` is the PIL reference picture and cannot be indexed.
plt.imshow(images[0])