In [None]:
!pip install transformers
!pip install diffusers
!pip install accelerate

In [None]:
import matplotlib.pyplot as plt
import torch
from diffusers import AutoencoderKL, EulerAncestralDiscreteScheduler, UNet2DConditionModel
# diffusers exports no `EulerAncestralScheduler`; the old name is kept as an alias
# so cells that still refer to it keep working.
from diffusers import EulerAncestralDiscreteScheduler as EulerAncestralScheduler
from PIL import Image
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

In [None]:
# Hugging Face Hub repository that hosts every component loaded below.
MODEL_ID = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# Each component lives in its own subfolder of the repository.
# VAE: decodes the final latents into images.
vae = AutoencoderKL.from_pretrained(MODEL_ID, subfolder="vae")
# Tokenizer + text encoder: turn the prompts into guidance embeddings.
tokenizer = CLIPTokenizer.from_pretrained(MODEL_ID, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(MODEL_ID, subfolder="text_encoder")
# UNet: predicts the noise at every denoising step.
unet = UNet2DConditionModel.from_pretrained(MODEL_ID, subfolder="unet")
# Scheduler: drives the denoising loop. The diffusers class is named
# `EulerAncestralDiscreteScheduler` (there is no `EulerAncestralScheduler`).
scheduler = EulerAncestralDiscreteScheduler.from_pretrained(MODEL_ID, subfolder="scheduler")

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs (slowly)
# on machines without CUDA instead of failing on the first `.to('cuda')`.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Move the three largest modules to the selected device.
# Looping (rather than ending the cell with a bare `unet.to(device)`) also keeps
# the cell from echoing the full module repr as its output.
for module in (vae, text_encoder, unet):
  module.to(device)

In [None]:
# One image is generated per prompt, so the batch size follows the prompt list.
prompt = [
  "a illustration of a anime girl",
  "a illustration of a anime boy",
]
batch_size = len(prompt)

# Output resolution in pixels.
width, height = 512, 512

# Number of scheduler steps used by the denoising loop.
steps = 20

# Classifier-free guidance scale applied in the denoising loop.
CFG = 7

# Seed the global RNG; torch.manual_seed returns the generator it seeded.
generator = torch.manual_seed(12345)

In [None]:
# Tokenize the POSITIVE prompts, padding/truncating every entry to the
# tokenizer's maximum length.
token_input = tokenizer(
  prompt,
  padding="max_length",
  max_length=tokenizer.model_max_length,
  truncation=True,
  return_tensors="pt",
)

# Sequence length of the padded token ids; the negative prompt is padded to the
# same length. The printed value should be 77.
max_length = token_input.input_ids.shape[-1]
print(tokenizer.model_max_length)

# Encode the POSITIVE token ids into guidance embeddings (no gradients needed).
positive_ids = token_input.input_ids.to(device)
with torch.no_grad():
  cond_embeddings = text_encoder(positive_ids)[0]

# To inspect the size of the POSITIVE guidance embeddings (batch size = 2),
# display `cond_embeddings.shape`.

In [None]:
# NEGATIVE prompt: one copy per positive prompt so both batches line up.
negative_prompt = "perlin noise,voronoi,ugly,lowres,low quality,disfigured,kitsch,ugly,grain,deformed,blurry"

# Tokenize the NEGATIVE prompt to the same length as the positive one.
# truncation=True mirrors the positive branch: without it, a longer negative
# prompt would yield more than `max_length` tokens.
uncond_token_input = tokenizer(
  [negative_prompt] * batch_size,
  padding="max_length",
  max_length=max_length,
  truncation=True,
  return_tensors="pt",
)

# Encode the NEGATIVE token ids into guidance embeddings (no gradients needed).
with torch.no_grad():
  uncond_embeddings = text_encoder(uncond_token_input.input_ids.to(device))[0]

# To inspect the size of the NEGATIVE guidance embeddings (batch size = 2),
# display `uncond_embeddings.shape`.

77
torch.Size([2, 77, 768])


In [None]:
# Stack the NEGATIVE embeddings first and the POSITIVE embeddings second along
# the batch axis; the denoising loop splits the UNet output in the same order.
embeddings = torch.cat((uncond_embeddings, cond_embeddings), dim=0)
print(embeddings.shape)

torch.Size([4, 77, 768])


In [None]:
# Configure the sampling schedule before reading init_noise_sigma, so the noise
# is scaled for the `steps`-step schedule actually used by the denoising loop
# (the same order the stock diffusers pipelines follow). Calling set_timesteps
# again later simply resets the schedule.
scheduler.set_timesteps(steps)

# Draw the initial latent noise on the CPU with the seeded generator.
# Shape: (batch, UNet input channels, height / 8, width / 8).
# `unet.config.in_channels` replaces the deprecated `unet.in_channels` shortcut.
latents = torch.randn(
  (batch_size, unet.config.in_channels, height // 8, width // 8),
  generator=generator,
)
# Move the noise to the device the models run on.
latents = latents.to(device)
# Scale the noise to the magnitude the scheduler expects at the first step.
latents = latents * scheduler.init_noise_sigma

In [None]:
# Run the denoising loop (`tqdm` is imported in the import cell at the top).
# NOTE: `latents` is overwritten on every step, so re-running this cell without
# re-creating the initial latents keeps denoising the previous result.
scheduler.set_timesteps(steps)
for t in tqdm(scheduler.timesteps, desc="Denoising"):
  # Duplicate the latents so the negative and positive branches share one UNet pass.
  latent_model_input = torch.cat([latents] * 2)
  latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)

  with torch.no_grad():
    noise_pred = unet(latent_model_input, t, encoder_hidden_states=embeddings).sample

  # Split in the order the embeddings were concatenated: negative first, positive second.
  noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
  # Classifier-free guidance: push the prediction away from the negative branch.
  epsilon = noise_pred_uncond + CFG * (noise_pred_cond - noise_pred_uncond)

  # Advance the latents by one scheduler step.
  latents = scheduler.step(epsilon, t, latents).prev_sample

In [None]:
# Rescale the latents by the VAE's configured scaling factor (0.18215 for this
# model) before decoding, instead of hard-coding the constant.
# NOTE: this rebinds `latents`, so re-running the cell divides a second time;
# the latent preview cell multiplies by 0.18215 to undo exactly one division.
latents = latents / vae.config.scaling_factor
with torch.no_grad():
  vae_outputs = vae.decode(latents).sample

In [None]:
# Turn the latents into displayable uint8 images: multiply by 0.18215 to undo the
# division applied before decoding, apply x / 2 + 0.5, clamp to [0, 1], move the
# channel axis last, and scale to 0..255.
# NOTE(review): an earlier comment said one latent channel is discarded, but all
# channels are passed to imshow here; with four channels the last one would be
# treated as alpha. Confirm which behaviour is intended.
latent_preview = (0.18215 * latents / 2 + 0.5).clamp(0, 1)
latent_preview = latent_preview.detach().cpu().permute(0, 2, 3, 1).numpy()
temp = (latent_preview * 255).round().astype('uint8')

# Plot one latent preview per prompt, side by side.
fig, axes = plt.subplots(1, len(prompt), figsize=(6, 3), squeeze=False)
for i, ax in enumerate(axes[0]):
  ax.imshow(temp[i])

In [None]:
# Convert the decoder output to uint8 images: apply x / 2 + 0.5, clamp to
# [0, 1], move the channel axis last, scale to 0..255, then wrap each as a PIL image.
images = (vae_outputs / 2 + 0.5).clamp(0, 1)
images = images.detach().cpu().permute(0, 2, 3, 1).numpy()
images = (images * 255).round().astype('uint8')
pil_images = list(map(Image.fromarray, images))

# Plot one final image per prompt, side by side.
fig, axes = plt.subplots(1, len(prompt), figsize=(10, 5), squeeze=False)
for i, ax in enumerate(axes[0]):
  ax.imshow(pil_images[i])