In [1]:
# Stable Diffusion is a text-to-image latent diffusion model.
# It is called a latent diffusion model because it works with
# a lower-dimensional representation of the image instead of
# the actual pixel space, which makes it more memory efficient.
# An encoder compresses the image into this smaller representation,
# and a decoder converts the compressed representation back
# into an image. For text-to-image models, you also need a
# tokenizer and a text encoder to generate text embeddings. From
# the previous example, you already know you need a UNet model
# and a scheduler.

# This is already more complex than the DDPM pipeline, which only contains a UNet model.
# The Stable Diffusion model has three separate pretrained models
# (the VAE, the text encoder and the UNet), plus a tokenizer and a scheduler.

# You can find them in the pretrained
# CompVis/stable-diffusion-v1-4 checkpoint loaded below (the
# runwayml/stable-diffusion-v1-5 checkpoint uses the same layout),
# where each component is stored in a separate subfolder.
# import os
# os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

In [2]:
from PIL import Image
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler

# Every component is loaded from the same checkpoint, each from its own subfolder.
# Change MODEL_ID here to switch checkpoints in one place.
MODEL_ID = "CompVis/stable-diffusion-v1-4"

# Autoencoder that maps between pixel space and the latent space the UNet works in.
vae = AutoencoderKL.from_pretrained(MODEL_ID, subfolder="vae", use_safetensors=True)
# Tokenizer + text encoder turn the prompt into embeddings that condition the UNet.
tokenizer = CLIPTokenizer.from_pretrained(MODEL_ID, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(MODEL_ID, subfolder="text_encoder", use_safetensors=True)
# Text-conditioned UNet that predicts the noise residual at each denoising step.
unet = UNet2DConditionModel.from_pretrained(MODEL_ID, subfolder="unet", use_safetensors=True)

  from .autonotebook import tqdm as notebook_tqdm
  deprecate("Transformer2DModelOutput", "1.0.0", deprecation_message)


In [3]:
# Replace the checkpoint's default PNDMScheduler with UniPCMultistepScheduler,
# reusing the scheduler configuration stored in the same checkpoint.
from diffusers import UniPCMultistepScheduler
scheduler = UniPCMultistepScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

# Run everything on the CPU. To try Apple-silicon acceleration instead, use:
# torch_device = "mps" if torch.backends.mps.is_available() else "cpu"
torch_device = "cpu"
vae.to(torch_device)
text_encoder.to(torch_device)
unet.to(torch_device)

# Generation settings. The prompt's embeddings condition the UNet and steer
# the diffusion process toward an image that resembles the prompt.
prompt = ["a photograph of a little cat on sofa"]
height = 512                      # output image height in pixels
width = 512                       # output image width in pixels
num_inference_steps = 25          # number of denoising steps
guidance_scale = 7.5              # how much weight the prompt gets during guidance
generator = torch.manual_seed(0)  # seeds torch's global RNG for reproducibility
batch_size = len(prompt)

# Conditional input: the prompt, padded/truncated to the tokenizer's max length.
text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
# Unconditional input: one empty string per prompt, padded to the same length.
max_length = text_input.input_ids.shape[-1]
uncond_input = tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")

# Inference only: encode BOTH inputs without building an autograd graph
# (previously the unconditional pass tracked gradients and wasted memory).
with torch.no_grad():
    cond_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]

# Stack [unconditional, conditional] into one batch so the UNet handles both
# in a single forward pass; the denoising loop splits them with .chunk(2).
text_embeddings = torch.cat([uncond_embeddings, cond_embeddings])

In [4]:
# Generate the initial random noise in latent space: the starting point of the
# diffusion process, which the loop below gradually denoises. The latents are
# 1/8 of the pixel resolution (height // 8, width // 8).
#
# The generator must be seeded explicitly: a freshly constructed torch.Generator()
# ignores the earlier torch.manual_seed(0) call and uses its own fixed default
# seed, so the intended seed would silently be lost.
generator = torch.Generator(device=torch_device).manual_seed(0)

latents = torch.randn(
    (batch_size, unet.config.in_channels, height // 8, width // 8),
    generator=generator,
    device=torch_device,
)
latents = latents * scheduler.init_noise_sigma

# The denoising loop:
#   1. set the scheduler's timesteps to use during denoising,
#   2. iterate over the timesteps,
#   3. at each timestep, have the UNet predict the noise residual and pass it
#      to the scheduler to compute the previous noisy sample.
from tqdm.auto import tqdm

scheduler.set_timesteps(num_inference_steps)
# Pure inference: no step of the loop needs gradients.
with torch.no_grad():
    for t in tqdm(scheduler.timesteps):
        # Duplicate the latents so a single UNet forward pass covers both the
        # unconditional and conditional halves of text_embeddings.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)

        # Predict the noise residual.
        noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # Classifier-free guidance: push the prediction from the unconditional
        # estimate toward the text-conditioned one, scaled by guidance_scale.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Compute the previous noisy sample x_t -> x_{t-1}.
        latents = scheduler.step(noise_pred, t, latents).prev_sample

100%|██████████| 25/25 [01:49<00:00,  4.37s/it]


In [6]:
# Decode the image: the final step uses the VAE to decode the latent
# representation back into pixels, then converts the result to PIL images.

def decode_latents_to_pil(latents):
    """Decode a batch of latents with the VAE and return a list of PIL images.

    The VAE scaling factor is undone on a new tensor, so ``latents`` itself is
    not modified and this cell can be re-run without compounding the rescale.
    Works for any batch size (one image per batch item).
    """
    with torch.no_grad():
        decoded = vae.decode(latents / vae.config.scaling_factor).sample
    # Map to [0, 1]; the /2 + 0.5 transform assumes the decoder output is
    # roughly in [-1, 1], and clamp guards against overshoot.
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    # NCHW -> NHWC, then round to 8-bit pixel values.
    pixels = (decoded.permute(0, 2, 3, 1) * 255).round().to(torch.uint8).cpu().numpy()
    return [Image.fromarray(pixel_array) for pixel_array in pixels]


images = decode_latents_to_pil(latents)
image = images[0]
# Last expression: Jupyter renders the PIL image inline (Image.show() would
# instead try to launch an external viewer).
image

In [None]:
# Recap: the denoising loop sets the scheduler's timesteps,
# iterates over them, and at each timestep calls the UNet
# model to predict the noise residual, then passes that
# residual to the scheduler to compute the previous
# noisy sample.