In [None]:
%pip install numpy
%pip install matplotlib
%pip install fastai
%pip install accelerate
%pip install transformers diffusers ftfy
%pip install torch
%pip install opencv-python
%pip install ipywidgets

### Setup


In [None]:
import logging
import os
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from diffusers import AutoencoderKL, LMSDiscreteScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
from fastai.basics import show_image, show_images
from fastcore.all import concat
from fastdownload import FastDownload
from huggingface_hub import notebook_login
from PIL import Image
from torch import autocast
from torchvision import transforms as tfms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer, logging

In [None]:
# Choose the compute device once for the whole notebook: the CUDA GPU when PyTorch
# can see one, otherwise the CPU. Stored as a plain string so it can be passed to `.to()`.
if torch.cuda.is_available():
    torch_device = "cuda"
else:
    torch_device = "cpu"

In [None]:
# Authenticate with the Hugging Face Hub. A token cached by an earlier login is
# expected at ~/.cache/huggingface/token; only prompt interactively when it is missing.
path = Path.home().joinpath(".cache", "huggingface", "token")
needs_login = not path.exists()
if needs_login:
    notebook_login()

In [None]:
# Load every model component once, at module level. The helper functions further down
# (image2latent, latents2images, get_embedding_for_prompt, the generate_* functions)
# read these globals directly. Seed torch's global RNG up front; note that the
# generate_* helpers reseed it themselves on every call.
torch.manual_seed(1)

# VAE: encodes images into latents (image2latent) and decodes latents back into image space.
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")

# CLIP tokenizer + text encoder: turn prompts into the embeddings the UNet is conditioned on.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# The UNet that predicts noise from (latents, timestep, text embeddings).
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

# The noise scheduler (LMS discrete).
# Per the original author, these beta hyperparameters match those used when training the model.
scheduler = LMSDiscreteScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000
)

# Move the models to `torch_device` (GPU when available).
vae = vae.to(torch_device)
text_encoder = text_encoder.to(torch_device)
unet = unet.to(torch_device)
# Latent scaling factor: latents are multiplied by it after encoding and divided by it before
# decoding. Per the original comment, the VAE was trained with this scale to bring latents
# closer to unit variance. NOTE(review): newer diffusers also expose this as
# `vae.config.scaling_factor` — confirm for the installed version.
vae_magic = 0.18215  # vae model trained with a scale term to get closer to unit variance

In [None]:
def image2latent(im):
    """Encode an image into scaled Stable Diffusion VAE latents.

    `im` is anything `torchvision.transforms.ToTensor` accepts (e.g. a PIL image).
    For 8-bit images ToTensor yields values in [0, 1]; these are mapped to [-1, 1]
    before encoding. A latent is sampled from the VAE posterior (so the result is
    random unless torch's RNG is seeded) and multiplied by `vae_magic`.
    The returned tensor carries a leading batch dimension of 1.
    """
    pixels = tfms.ToTensor()(im)[None].to(torch_device)
    with torch.no_grad():
        posterior = vae.encode(2 * pixels - 1).latent_dist
    return vae_magic * posterior.sample()

In [None]:
def decode_latent(latents):
    """Decode scaled latents back into image space with the VAE.

    Undoes the `vae_magic` scaling applied at encode time and returns the raw
    decoder output (its `.sample` field), with no clamping or conversion to [0, 1].
    No autograd graph is recorded.
    """
    with torch.no_grad():
        unscaled = latents / vae_magic
        decoded = vae.decode(unscaled)
    return decoded.sample

In [None]:
def latents2images(latents):
    """Decode a batch of scaled latents into a list of PIL images.

    The VAE output is mapped from [-1, 1] to [0, 1] and clamped, then moved to the
    CPU. It is permuted from (batch, channel, height, width) to
    (batch, height, width, channel), scaled to 0-255, rounded, and cast to uint8.
    One PIL image is produced per batch element.
    """
    with torch.no_grad():
        decoded = vae.decode(latents / vae_magic).sample
    pixels = (decoded / 2 + 0.5).clamp(0, 1)
    channels_last = pixels.detach().cpu().permute(0, 2, 3, 1).numpy()
    arrays = (channels_last * 255).round().astype("uint8")
    return list(map(Image.fromarray, arrays))

In [None]:
def get_embedding_for_prompt(prompt):
    """Return the CLIP text-encoder output for one prompt, without gradient tracking.

    The prompt is wrapped in a one-element batch and padded/truncated to the
    tokenizer's `model_max_length`. The token ids are moved to `torch_device` and fed
    to `text_encoder`; element [0] of its output is returned (for CLIPTextModel,
    presumably the last hidden state — one vector per token position).
    """
    batch = tokenizer(
        [prompt],
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        encoder_out = text_encoder(batch.input_ids.to(torch_device))
    return encoder_out[0]

In [None]:
def generate_noise_pred(prompts, im_latents, seed=32, g=0.15):
    """Run one guided UNet step and return the scheduler's estimate of the clean latents.

    Args:
        prompts: prompt forwarded to `get_embedding_for_prompt`, which wraps it in a
            one-element batch, so a single string is expected.
        im_latents: scaled VAE latents of a source image, or None. When given, they are
            noised to the timestep halfway through a 30-step schedule (img2img). When
            None, fresh Gaussian noise scaled by `scheduler.init_noise_sigma` is used and
            the first timestep of the schedule is evaluated.
        seed: seed for torch's global RNG, which drives every random draw below.
        g: strength of the norm-preserving guidance; values <= 0 use the unconditional
            prediction alone.

    Returns:
        `pred_original_sample` from `scheduler.step` for that single step.
    """
    height = 512  # default height of Stable Diffusion
    width = 512  # default width of Stable Diffusion
    num_inference_steps = 30  # size of the timestep grid the scheduler is set up for
    # Seed the global RNG: `torch.randn` / `torch.randn_like` below draw from it.
    torch.manual_seed(seed)

    uncond = get_embedding_for_prompt("")
    text = get_embedding_for_prompt(prompts)
    # Unconditional first, conditional second: matches the `.chunk(2)` split below.
    text_embeddings = torch.cat([uncond, text])

    scheduler.set_timesteps(num_inference_steps)

    if im_latents is not None:
        # img2img: noise the source latents to the timestep halfway through the schedule.
        start_step = int(num_inference_steps * 0.5)
        timesteps = torch.tensor([scheduler.timesteps[-start_step]], device=torch_device)
        noise = torch.randn_like(im_latents)
        latents = scheduler.add_noise(im_latents, noise, timesteps=timesteps)
        latents = latents.to(torch_device).float()
    else:
        # Text only: start from pure noise at the first (noisiest) timestep. This branch
        # previously never assigned `timesteps`, so the UNet call below raised NameError.
        timesteps = torch.tensor([scheduler.timesteps[0]], device=torch_device)
        latents = torch.randn((1, unet.config.in_channels, height // 8, width // 8))
        latents = latents.to(torch_device)
        latents = latents * scheduler.init_noise_sigma  # scale to initial amount of noise for t0

    # Duplicate the latents so one UNet forward pass yields both predictions.
    latent_model_input = torch.cat([latents] * 2)
    latent_model_input = scheduler.scale_model_input(latent_model_input, timesteps)
    with torch.no_grad():
        u, t = unet(latent_model_input, timesteps, encoder_hidden_states=text_embeddings).sample.chunk(2)

    if g > 0:
        # Move from the unconditional prediction toward the conditional one by `g * |u|`
        # along the unit direction (t - u), then rescale so the result has the norm of `u`.
        pred_nonscaled = u + g * (t - u) / torch.norm(t - u) * torch.norm(u)
        pred = pred_nonscaled * torch.norm(u) / torch.norm(pred_nonscaled)
    else:
        # Matches generate_image_from_embedding; also avoids 0/0 when t == u.
        pred = u
    return scheduler.step(pred, timesteps, latents).pred_original_sample

In [None]:
def generate_image_from_embedding(text_embeddings, im_latents, mask=None, seed=None, guidance_scale=0.15):
    """Denoise latents under a text embedding (optionally masked) and decode the result.

    Args:
        text_embeddings: conditional text embedding, e.g. from `get_embedding_for_prompt`.
            It is concatenated after the embedding of the empty prompt, and the UNet
            output is split back into (unconditional, conditional) halves.
        im_latents: scaled VAE latents of a source image for img2img, or None to start
            from Gaussian noise and run the full 30-step schedule.
        mask: optional latent-resolution tensor (e.g. from `image2latentmask`). After each
            step the latents become `latents * mask + im_latents * (1 - mask)`, so
            positions where mask == 0 are reset to `im_latents`. Requires `im_latents`.
        seed: seed for torch's global RNG; when None, a fresh nondeterministic seed is used.
        guidance_scale: strength of the norm-preserving guidance; values <= 0 use the
            unconditional prediction alone.

    Returns:
        (image, latent_delta): the first decoded PIL image, and the initial (noised)
        latents minus the final latents.

    Raises:
        ValueError: if `mask` is given without `im_latents`. Previously this failed with
            a TypeError at the first masked step, after UNet work had already been done.
    """
    if mask is not None and im_latents is None:
        raise ValueError("mask requires im_latents: unmasked regions are taken from the source latents")

    height = 512  # default height of Stable Diffusion
    width = 512  # default width of Stable Diffusion
    num_inference_steps = 30  # Number of denoising steps
    if seed is None:
        seed = torch.seed()
    # Seed the global RNG: `torch.randn` / `torch.randn_like` below draw from it.
    torch.manual_seed(seed)

    uncond = get_embedding_for_prompt("")
    # Unconditional first, conditional second: matches the `.chunk(2)` split in the loop.
    text_embeddings = torch.cat([uncond, text_embeddings])

    scheduler.set_timesteps(num_inference_steps)

    if im_latents is not None:
        # img2img: noise the source latents to the level of timestep `start_step`, then
        # denoise from that same step onward.
        start_step = 10
        noise = torch.randn_like(im_latents)
        latents = scheduler.add_noise(im_latents, noise, timesteps=torch.tensor([scheduler.timesteps[start_step]]))
        latents = latents.to(torch_device).float()
    else:
        # Text only: run every step of the schedule from pure noise.
        start_step = 0
        latents = torch.randn((1, unet.config.in_channels, height // 8, width // 8))
        latents = latents.to(torch_device)
        latents = latents * scheduler.init_noise_sigma  # scale to initial amount of noise for t0

    noisy_latent = latents.clone()
    for i, tm in tqdm(
        enumerate(scheduler.timesteps), total=num_inference_steps, desc="Generating Masked Image for Prompt"
    ):
        # Denoise from `start_step` inclusive. The img2img latents were noised to exactly
        # `scheduler.timesteps[start_step]`. The previous `i > start_step` skipped that step,
        # then scaled and stepped the latents as if they were one noise level lower.
        if i < start_step:
            continue

        # Duplicate the latents so one UNet forward pass yields both predictions.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, tm)

        with torch.no_grad():
            noise_pred = unet(latent_model_input, tm, encoder_hidden_states=text_embeddings)["sample"]
        u, t = noise_pred.chunk(2)

        if guidance_scale > 0:
            # Move from the unconditional prediction toward the conditional one by
            # `guidance_scale * |u|` along the unit direction (t - u), then rescale so
            # the result has the norm of `u`.
            pred_nonscaled = u + guidance_scale * (t - u) / torch.norm(t - u) * torch.norm(u)
            noise_pred = pred_nonscaled * torch.norm(u) / torch.norm(pred_nonscaled)
        else:
            noise_pred = u

        # compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, tm, latents).prev_sample
        if mask is not None:
            # NOTE(review): `im_latents` are the clean source latents, while `latents` stay
            # noisy until the last step. DiffEdit-style masking usually blends in the source
            # re-noised to the current level — confirm the clean blend is intended.
            latents = latents * mask + im_latents * (1.0 - mask)

    latent_delta = noisy_latent - latents
    return latents2images(latents)[0], latent_delta

In [None]:
def image2latentmask(im):
    """Turn a mask image into a binary 64x64 float tensor on `torch_device`.

    The image is converted with ToTensor (values in [0, 1] for 8-bit images) and
    averaged over its channels. Pixels above 0.5 become 1.0 and all others 0.0. The
    result is resized to 64x64 with nearest-neighbour interpolation, so it stays binary.
    """
    channels_last = tfms.ToTensor()(im).permute(1, 2, 0)
    gray = channels_last.mean(dim=-1)
    binary = (gray > 0.5).float().cpu().numpy()
    small = cv2.resize(binary, (64, 64), interpolation=cv2.INTER_NEAREST)
    return torch.tensor(small).to(torch_device)