**Importing Useful Libraries**

In [13]:
pip install -q diffusers transformers accelerate

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import diffusers
from diffusers import StableDiffusionPipeline
from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
import numpy as np
from tqdm import tqdm
from torch import autocast

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [5]:
# Auto encoder model to convert image from pixel space to latent space
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")

# Load the tokenizer and text encoder to tokenize and encode the text.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# The UNet model for generating the latents.
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

# The noise scheduler
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

# To the GPU we go!
vae = vae.to(device)
text_encoder = text_encoder.to(device)
unet = unet.to(device);

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


vae/config.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/335M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/905 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/961k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.22M [00:00<?, ?B/s]



config.json:   0%|          | 0.00/4.52k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.71G [00:00<?, ?B/s]

unet/config.json:   0%|          | 0.00/743 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

In [6]:
def pixel_to_latent(image):
    with torch.no_grad():
        latent_image = vae.encode(image.to(device).unsqueeze(0)).latent_dist.sample()
    return latent_image

In [7]:
# Prep Scheduler
def set_timesteps(scheduler, num_inference_steps):
    scheduler.set_timesteps(num_inference_steps)
    scheduler.timesteps = scheduler.timesteps.to(torch.float32) # minor fix to ensure MPS compatibility, fixed in diffusers PR 3925


In [8]:
# Transformations for the CIFAR-10 dataset
transform = transforms.Compose([
    transforms.Resize((512, 512)),  # Resize images to 224x224 for ResNet
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalize images
])

# Load CIFAR-10 dataset
cifar10_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
prompt = cifar10_dataset.classes

# Get one image and its label
image_orignal, label = cifar10_dataset[200]

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:13<00:00, 12298130.88it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data


In [9]:
print(type(image_orignal), image_orignal.shape)

<class 'torch.Tensor'> torch.Size([3, 512, 512])


In [10]:
print(cifar10_dataset.classes)

['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


In [17]:
############ Diffusion Model ################

# Number of denoising steps
num_inference_steps = 30

# Scale for classifier-free guidance
guidance_scale = 7.5


# Conversion of image into latent vector
latent_image = pixel_to_latent(image_orignal)
print(latent_image.shape)
# Adding noise to the latent space of the original image


latents = latent_image * scheduler.init_noise_sigma # Scaling (previous versions did latents = latents * self.scheduler.sigmas[0]
print(torch.tensor(latents)[0].squeeze().shape)
# Generating Embedding of the possible classes for the original image

# Tokenization
text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")

# Converting tokens to its embeddings
cond_embeddings = text_encoder(text_input.input_ids.to(device))[0]
print(cond_embeddings.shape)
# Making unconditional text embeddings
max_length = text_input.input_ids.shape[-1]
batch_size = 10
uncond_input = tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]

# Concatinating unconditional and conditional embeddings
text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])
print(text_embeddings.shape)
# Setting timesteps
set_timesteps(scheduler, num_inference_steps)

 # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
latent_model_input = torch.cat([latents] * 20)



torch.Size([1, 4, 64, 64])
torch.Size([4, 64, 64])


  print(torch.tensor(latents)[0].squeeze().shape)


torch.Size([10, 77, 768])
torch.Size([20, 77, 768])


In [18]:
# Loop
for t in tqdm(scheduler.timesteps):


    latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
    print(f"latent_model_input shape: {latent_model_input.shape}")
    print(f"text_embeddings shape: {text_embeddings.shape}")

    # predict the noise residual
    with torch.no_grad():
        noise_pred = unet(latent_model_input, t, encoder_hidden_states = text_embeddings).sample

    # perform guidance
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    # compute the previous noisy sample x_t -> x_t-1
    latents = scheduler.step(noise_pred, t, latents).prev_sample

  0%|          | 0/30 [00:00<?, ?it/s]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


  3%|▎         | 1/30 [00:00<00:24,  1.18it/s]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


  7%|▋         | 2/30 [00:04<01:11,  2.56s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


 10%|█         | 3/30 [00:08<01:23,  3.10s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


 13%|█▎        | 4/30 [00:12<01:27,  3.35s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


 17%|█▋        | 5/30 [00:15<01:27,  3.49s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


 20%|██        | 6/30 [00:19<01:26,  3.58s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


 23%|██▎       | 7/30 [00:23<01:23,  3.65s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


 27%|██▋       | 8/30 [00:27<01:21,  3.70s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


 30%|███       | 9/30 [00:30<01:18,  3.73s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


 33%|███▎      | 10/30 [00:34<01:15,  3.76s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


 37%|███▋      | 11/30 [00:38<01:12,  3.79s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


 40%|████      | 12/30 [00:42<01:08,  3.83s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


 43%|████▎     | 13/30 [00:46<01:05,  3.86s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


 47%|████▋     | 14/30 [00:50<01:02,  3.89s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


 50%|█████     | 15/30 [00:54<00:58,  3.93s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


 53%|█████▎    | 16/30 [00:58<00:55,  3.97s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


 57%|█████▋    | 17/30 [01:02<00:52,  4.01s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


 60%|██████    | 18/30 [01:06<00:48,  4.04s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


 63%|██████▎   | 19/30 [01:10<00:44,  4.07s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


 67%|██████▋   | 20/30 [01:15<00:40,  4.09s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


 70%|███████   | 21/30 [01:19<00:36,  4.09s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


 73%|███████▎  | 22/30 [01:23<00:32,  4.08s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


 77%|███████▋  | 23/30 [01:27<00:28,  4.06s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


 80%|████████  | 24/30 [01:31<00:24,  4.04s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


 83%|████████▎ | 25/30 [01:35<00:20,  4.02s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


 87%|████████▋ | 26/30 [01:39<00:15,  4.00s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


 90%|█████████ | 27/30 [01:43<00:11,  3.98s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


 93%|█████████▎| 28/30 [01:47<00:07,  3.97s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


 97%|█████████▋| 29/30 [01:50<00:03,  3.95s/it]

latent_model_input shape: torch.Size([20, 4, 64, 64])
text_embeddings shape: torch.Size([20, 77, 768])


100%|██████████| 30/30 [01:54<00:00,  3.83s/it]


In [19]:
# img = latents.squeeze(0).permute(1, 2, 0).numpy()
print(latents.shape)
print(latents.dtype)
# Display the image
# plt.imshow(img)
# plt.title(f'Label: {label}')
# plt.axis('off')
# plt.show()

torch.Size([10, 4, 64, 64])
torch.float32
