In [1]:
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import logging
from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler
from tqdm.auto import tqdm
from torch import autocast
from PIL import Image
from matplotlib import pyplot as plt
import numpy as np
from torchvision import transforms as tfms

# For video display:
from IPython.display import HTML
from base64 import b64encode

# Supress some unnecessary warnings when loading the CLIPTextModel
logging.set_verbosity_error()

# Set device
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
import pandas as pd

## Load one aurora embedding

In [3]:
aur_img_embeds = np.load("data/aurora_embeds_3/img_emb/img_emb_0.npy")
aur_img_meta = pd.read_parquet("data/aurora_embeds_3/metadata/metadata_0.parquet")

In [4]:
first_aur_img = Image.open(aur_img_meta.iloc[4,0])
first_aur_embed = aur_img_embeds[4]

## Load some stuff

Lifted from [walk_with_stable_diffusion.ipynb](https://colab.research.google.com/drive/1Ef_3FOJUXNFm2gLl5vMe35A_uCk85kuZ?usp=sharing#scrollTo=nom-hSmvUvDE) by Zach Mueller (thanks!)

In [5]:
# Load the autoencoder model which will be used to decode the latents into image space. 
vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-2", subfolder="vae")

# Load the tokenizer and text encoder to tokenize and encode the text. 
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# The UNet model for generating the latents.
unet = UNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="unet")

# The noise scheduler
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

# To the GPU we go!
vae = vae.to(torch_device)
text_encoder = text_encoder.to(torch_device)
unet = unet.to(torch_device);

In [6]:
def embed(prompt):
    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0] 
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
    return text_embeddings

In [8]:
batch_size=1
height = 512
width=512
seed_2=42
noise  = torch.randn(
  (batch_size, unet.in_channels, height // 8, width // 8),
  generator=torch.manual_seed(seed_2),
)

In [9]:
e = embed("A good ol' aurora northern lights pizza")

In [10]:
def diffuse(text_embeddings, init_noise):

    height = 768                        # default height of Stable Diffusion
    width = 768                         # default width of Stable Diffusion
    num_inference_steps = 20            # Number of denoising steps
    guidance_scale = 7.5                # Scale for classifier-free guidance
    generator = torch.manual_seed(42)   # Seed generator to create the inital latent noise
    batch_size = 1

    # Prep Scheduler
    scheduler.set_timesteps(num_inference_steps)

    # Prep latents
    latents = init_noise
    latents = latents.to(torch_device)
    latents = latents * scheduler.init_noise_sigma # Scaling (previous versions did latents = latents * self.scheduler.sigmas[0]

    # Loop
    with autocast("cuda"):
        for i, t in tqdm(enumerate(scheduler.timesteps)):
            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
            latent_model_input = torch.cat([latents] * 2)
            sigma = scheduler.sigmas[i]
            # Scale the latents (preconditioning):
            # latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5) # Diffusers 0.3 and below
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            with torch.no_grad():
                noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            # perform guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            # latents = scheduler.step(noise_pred, i, latents)["prev_sample"] # Diffusers 0.3 and below
            latents = scheduler.step(noise_pred, t, latents).prev_sample

    # scale and decode the image latents with vae
    latents = 1 / 0.18215 * latents
    with torch.no_grad():
        image = vae.decode(latents).sample

    # Display
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images[0]

In [11]:
diffuse(e, noise[:769,:])

0it [00:00, ?it/s]

RuntimeError: mat1 and mat2 shapes cannot be multiplied (154x768 and 1024x320)

In [None]:
import torch
from einops import repeat

aur_embed_torch = torch.from_numpy(first_aur_embed).to(torch_device)

In [None]:
aur_embed_torch.shape

In [None]:
aur_embed_torch = repeat(aur_embed_torch, "c -> n c s", n=1, s=3)

diffuse(torch.from_numpy(first_aur_embed).to(torch_device), noise)