In [None]:
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import logging
from diffusers import AutoencoderKL, UNet2DConditionModel, LMSDiscreteScheduler
from tqdm.auto import tqdm
from torch import autocast
from PIL import Image
from matplotlib import pyplot as plt
import numpy
from torchvision import transforms as tfms
import glob
import os

# For video display:
from IPython.display import HTML
from base64 import b64encode

# Silence the non-critical warnings emitted while loading the CLIPTextModel
logging.set_verbosity_error()

# Pick the device: use CUDA when torch reports it as available, otherwise the CPU
if torch.cuda.is_available():
    torch_device = "cuda"
else:
    torch_device = "cpu"

In [None]:
# Autoencoder (VAE): used to encode images into latents and to decode latents back into image space.
# Each model is moved to the selected device right after it is loaded.
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(torch_device)

# Tokenizer + text encoder: turn the prompt text into token ids and then into embeddings.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(torch_device)

# UNet: the model that predicts the noise residual on the latents.
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet").to(torch_device)

# Noise scheduler used for both noising and the denoising steps.
scheduler = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
)

In [None]:
def pil_to_latent(input_im):
    """Encode a single PIL image into a scaled VAE latent.

    The image is converted to a tensor, given a leading batch dimension of 1,
    moved to ``torch_device`` and rescaled with ``x * 2 - 1`` before encoding.
    One sample is drawn from the resulting latent distribution and multiplied
    by 0.18215 (the same constant ``latents_to_pil`` divides out again).

    Returns:
        A latent tensor with batch size 1 (the original note gives its size as
        1, 4, 64, 64 -- presumably for a 512x512 input).
    """
    with torch.no_grad():
        pixels = tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device)
        encoder_output = vae.encode(pixels * 2 - 1)  # note the rescaling of the pixel values
    sampled = encoder_output.latent_dist.sample()
    return 0.18215 * sampled

def latents_to_pil(latents):
    """Decode a batch of latents into a list of PIL images.

    The latents are first multiplied by ``1 / 0.18215`` (undoing the scaling
    applied in ``pil_to_latent``), decoded by the VAE, mapped with
    ``x / 2 + 0.5`` and clamped to [0, 1], then converted to uint8 arrays in
    channel-last layout.

    Returns:
        One PIL image per item in the batch.
    """
    unscaled = (1 / 0.18215) * latents
    with torch.no_grad():
        decoded = vae.decode(unscaled).sample
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    # (batch, channels, h, w) -> (batch, h, w, channels) for PIL
    channel_last = decoded.detach().cpu().permute(0, 2, 3, 1).numpy()
    as_bytes = (channel_last * 255).round().astype("uint8")
    return [Image.fromarray(frame) for frame in as_bytes]

def create_movie(dir,movie_name,fps=12):
    """Encode the numbered JPEG frames in ``dir`` into an H.264 movie using ffmpeg.

    The command is passed to ffmpeg as an argument list instead of through a
    shell, so directory and movie names containing spaces or shell
    metacharacters (they are derived from free-text prompts) reach ffmpeg
    intact. It also makes this helper plain Python rather than IPython-only.

    Args:
        dir: Directory holding the frames, named with the ``%04d.jpg`` pattern
            (``0000.jpg``, ``0001.jpg``, ...).
        movie_name: Path of the movie file to write (ffmpeg is run with ``-y``).
        fps: Value passed to ffmpeg's ``-framerate`` option.

    Returns:
        None. Anything ffmpeg prints is echoed; a missing ffmpeg executable is
        reported with a message rather than an exception.
    """
    import subprocess  # local import so the helper is self-contained

    command = [
        "ffmpeg", "-v", "1", "-y",
        "-f", "image2",
        "-framerate", str(fps),
        "-i", f"{dir}/%04d.jpg",
        "-c:v", "libx264",
        "-preset", "slow",
        "-qp", "18",
        "-pix_fmt", "yuv420p",
        str(movie_name),
    ]
    try:
        result = subprocess.run(
            command,
            stdin=subprocess.DEVNULL,  # ffmpeg must never wait on the kernel's stdin
            capture_output=True,
            text=True,
            errors="replace",
        )
    except FileNotFoundError:
        print("ffmpeg executable not found; no movie was created.")
        return

    # Relay ffmpeg's own output so problems stay visible in the notebook.
    output = ((result.stdout or "") + (result.stderr or "")).strip()
    if output:
        print(output)

def embed_movie(movie_name):
    """Return an HTML ``<video>`` snippet with the movie inlined as a base64 data URL.

    Args:
        movie_name: Path of the MP4 file to embed.

    Returns:
        An HTML string containing a 600px-wide ``<video>`` element whose
        source is ``data:video/mp4;base64,<file contents>``.
    """
    # The context manager closes the file handle as soon as the bytes are read.
    with open(movie_name, 'rb') as movie_file:
        mp4 = movie_file.read()
    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
    return """
    <video width=600 controls>
          <source src="%s" type="video/mp4">
    </video>
    """ % data_url

In [None]:
def generate_image_from_embedding(prompt, encoded, start_step=10, seed=1,
                                  num_inference_steps=50, guidance_scale=7.5,
                                  height=512, width=512):
    """Generate one image for ``prompt``, from pure noise or from an existing latent.

    When ``encoded`` is None, random latents are drawn and every scheduler step
    is run. Otherwise ``encoded`` is noised to the level of scheduler step
    ``start_step`` and only the steps from ``start_step`` onwards are run.

    Args:
        prompt: The text prompt (a single string).
        encoded: Latent to start from (e.g. the output of ``pil_to_latent``),
            or None to start from noise.
        start_step: Index into ``scheduler.timesteps`` at which denoising
            starts. Ignored when ``encoded`` is None.
        seed: Value passed to ``torch.manual_seed`` before any noise is drawn.
        num_inference_steps: Number of denoising steps given to the scheduler.
        guidance_scale: Scale for classifier-free guidance.
        height: Image height in pixels; only used when ``encoded`` is None.
        width: Image width in pixels; only used when ``encoded`` is None.

    Returns:
        The first image returned by ``latents_to_pil`` for the final latents.
    """
    batch_size = 1
    # Seeds the global RNG (used by randn_like below) and returns the generator used for randn.
    generator = torch.manual_seed(seed)

    # Tokenize the prompt and an empty prompt, both padded to the same length.
    text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
    # Unconditional first, prompt second: the same order chunk(2) unpacks below.
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # Prep scheduler (setting the number of inference steps)
    scheduler.set_timesteps(num_inference_steps)

    if encoded is None:
        # Start from pure noise and run every step.
        start_step = 0
        latents = torch.randn(
            (batch_size, unet.config.in_channels, height // 8, width // 8),
            generator=generator,
        )
        latents = latents.to(torch_device)
        latents = latents * scheduler.init_noise_sigma
    else:
        # Noise the given latent to the level that matches start_step.
        noise = torch.randn_like(encoded)
        latents = scheduler.add_noise(encoded, noise, timesteps=torch.tensor([scheduler.timesteps[start_step]]))
        latents = latents.to(torch_device).float()

    # Denoising loop; `total` lets tqdm render a real progress bar for the enumerate().
    for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
        if i < start_step:
            # Steps before start_step are skipped; the latent was noised to match start_step.
            continue

        # Duplicate the latents so the unconditional and prompt predictions share one forward pass.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        # predict the noise residual
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

        # perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    return latents_to_pil(latents)[0]

In [None]:
def loop_diffusion(prompt, encoded, outdir, starting_iteration=1, iterations=10, start_step=10, seed=1, increment_seed=True):
    """Repeatedly re-diffuse an image towards ``prompt``, feeding each result back in.

    Each iteration generates an image from the current latent, saves it to
    ``outdir`` as a zero-padded ``NNNN.jpg`` file, and encodes it with
    ``pil_to_latent`` to obtain the latent for the next iteration.

    Args:
        prompt: Text prompt used for every iteration.
        encoded: Latent the first iteration starts from.
        outdir: Directory the frames are written to.
        starting_iteration: Number used for the first saved frame.
        iterations: How many frames to generate.
        start_step: Forwarded to ``generate_image_from_embedding``.
        seed: Seed for the first iteration.
        increment_seed: When true, the seed grows by one after every iteration.

    Returns:
        The list of generated PIL images, in generation order.
    """
    print(f"Number of iterations: {iterations}")
    frames = []
    current_latents = encoded
    current_seed = seed

    frame_numbers = range(starting_iteration, starting_iteration + iterations)
    for frame_no in tqdm(frame_numbers):
        frame = generate_image_from_embedding(prompt, current_latents, start_step, current_seed)
        if increment_seed:
            current_seed += 1
        # The freshly generated image becomes the starting point of the next round.
        current_latents = pil_to_latent(frame)
        frames.append(frame)

        frame_path = f'{outdir}/{frame_no:04}.jpg'
        frame.save(frame_path)
        print(f"Saved {frame_path}")
    return frames

In [None]:

# Prompt pair for this experiment: the first image comes from `initial_prompt`,
# every later frame is diffused towards `prompt`.
# Other pairs that have been tried:
#   "london bus"           -> "cat"
#   "london bus"           -> "dog"
#   "mount fuji in spring" -> "cat"
#   "mount fuji in spring" -> "an astronaut on a horse, photo"
initial_prompt = "mount fuji in spring"
prompt = "car"

# Run settings
increment_seed = True
initial_image_seed = 1000
diffusion_seed = 1001
start_step = 10
iterations = 200
starting_iteration = 1

# The label encodes every setting; the frame directory and movie file are
# named after it, with spaces replaced by underscores.
exp_label = f"{initial_prompt}-{prompt}-startstep-{start_step}-incrementseed-{increment_seed}-seeds-{initial_image_seed}-{diffusion_seed}"
img_dir, movie = (
    name.replace(' ','_')
    for name in (f'frames-{exp_label}', f'movie-{exp_label}.mp4')
)

def get_existing_experiment_or_create_new(img_dir):
    """Return the frames stored in ``img_dir``, generating the first frame if there are none.

    If ``img_dir`` is missing or contains no ``.jpg`` files (for example after
    an interrupted first run), it is created as needed and the initial image
    is generated from the module-level ``initial_prompt`` and
    ``initial_image_seed`` and saved as ``0000.jpg``. Otherwise every
    ``*.jpg`` in the directory is loaded in sorted filename order.

    Args:
        img_dir: Directory holding (or to hold) the experiment's frames.

    Returns:
        A non-empty list of PIL images; the last one is the frame to continue from.
    """
    os.makedirs(img_dir, exist_ok=True)
    # glob.escape keeps characters such as '[' in the directory name literal.
    img_files = sorted(glob.glob(f'{glob.escape(img_dir)}/*.jpg'))

    if not img_files:
        img = generate_image_from_embedding(initial_prompt,None,seed=initial_image_seed)
        img.save(f'{img_dir}/{0:04}.jpg')
        return [img]

    print("Image directory already exists. Continuing image generation...")
    print(f"Lowest numbered image (original image): {img_files[0]}")
    print(f"Highest numbered image (starting with this image): {img_files[-1]}")

    imgs = []
    for img_file in img_files:
        # Copy the pixels into memory and close the file straight away, so a
        # long experiment does not keep one open handle per frame.
        with Image.open(img_file) as opened:
            imgs.append(opened.copy())
    return imgs

imgs = get_existing_experiment_or_create_new(img_dir)

# Continue numbering after the frames that already exist.
starting_iteration = len(imgs)
# Offset the seed by the number of existing frames -- presumably so that a
# resumed run carries on the seed sequence instead of replaying it.
diffusion_seed = diffusion_seed + len(imgs)

# The most recent frame is the starting point for this run.
img = imgs[-1]
latent = pil_to_latent(img)

# plot initial image
plt.imshow(img)

# run diffusion and append to list of images
# The seed passed here is diffusion_seed, so the seeds recorded in the experiment label are the ones actually used.
imgs_new = loop_diffusion(prompt,latent,img_dir,starting_iteration=starting_iteration, iterations=iterations, start_step=start_step,seed=diffusion_seed,increment_seed=increment_seed)
imgs.extend(imgs_new)

# Encode every frame once, after the run, for the latent-space analysis.
generated_latents = [pil_to_latent(img) for img in imgs]


In [None]:
# Assemble the saved frames into a 6 fps movie, then show it inline as the cell's output.
create_movie(dir=img_dir, movie_name=movie, fps=6)
HTML(embed_movie(movie_name=movie))

In [None]:
import numpy as np

# Contact sheet of all frames, ten per row, in generation order.
ncols = 10
nrows = max(1, int(np.ceil(len(imgs)/ncols)))
# squeeze=False keeps `axs` two-dimensional even when there is only one row.
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20,20), squeeze=False)
for frame_index, grid_ax in enumerate(axs.flat):
    # Hide the axes of every cell, including the unused ones after the last frame.
    grid_ax.axis('off')
    if frame_index < len(imgs):
        grid_ax.imshow(imgs[frame_index])



In [None]:
# Reduce the flattened latents to two dimensions with PCA
from sklearn.decomposition import PCA

plot_latents = [latent_tensor.cpu().numpy().flatten() for latent_tensor in generated_latents]
print(f"Number of latents: {len(plot_latents)}")

# One row per latent; the same matrix is used to fit and to project.
latent_matrix = np.vstack(plot_latents)
pca = PCA(n_components=2)
pca.fit(latent_matrix)
pca_latents = pca.transform(latent_matrix)

# Scatter the projected latents and label each point with its position in the sequence
fig, ax = plt.subplots(1,1,figsize=(10,10))
ax.scatter(pca_latents[:,0],pca_latents[:,1])
for order in range(len(pca_latents)):
    x, y = pca_latents[order]
    ax.text(x,y,order,fontsize=10)

fig.savefig(f'latent-space-{exp_label}.png')




In [None]:
# Same scatter as before, with a thumbnail of every fifth frame drawn at its
# point (a static figure; nothing happens on hover).
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from matplotlib.cbook import get_sample_data

fig, ax = plt.subplots(1,1,figsize=(10,10))
ax.scatter(pca_latents[:,0],pca_latents[:,1])
for i in range(0, len(pca_latents), 5):
    x, y = pca_latents[i]
    thumbnail = OffsetImage(imgs[i], zoom=0.1)
    ax.add_artist(AnnotationBbox(thumbnail, (x, y), frameon=False))
    ax.text(x,y,i,fontsize=10)

# The axes are left visible here (ax.axis('off') would hide them).
# Save the figure next to the notebook.
fig.savefig(f'latent-space-with-images-{exp_label}.png')

In [None]:
# Directory of the experiment to compare against.
# NOTE(review): if this directory does not exist, get_existing_experiment_or_create_new
# will generate a new first frame from the *current* initial_prompt and save it there --
# confirm the directory exists before running this cell.
exp_label_compare = 'frames-london_bus-cat-startstep-10-incrementseed-True-seeds-1-1'
#exp_label_compare = 'frames-london_bus-cat-startstep-10-incrementseed-True-seeds-100-101'
imgs_compare = get_existing_experiment_or_create_new(exp_label_compare)
generated_latents_compare = [pil_to_latent(img) for img in imgs_compare]
plot_latents_compare = [l.cpu().numpy().flatten() for l in generated_latents_compare]

# Distance between latents with the same index in the two experiments.
# zip stops at the shorter run, so experiments with different numbers of
# frames are compared over their common prefix instead of raising IndexError.
distances = [
    np.linalg.norm(current - other)
    for current, other in zip(plot_latents, plot_latents_compare)
]

# plot the distances
fig, ax = plt.subplots(1,1,figsize=(10,10))
ax.plot(distances)
ax.set_xlabel('Iteration')
ax.set_ylabel('Distance')
fig.savefig(f'distance-{exp_label_compare}-{exp_label}.png')



In [None]:
# Load the experiment to compare against and encode its frames.
exp_label_compare = 'frames-london_bus-cat-startstep-10-incrementseed-True-seeds-1-1'
#exp_label_compare = 'frames-london_bus-cat-startstep-10-incrementseed-True-seeds-100-101'
imgs_compare = get_existing_experiment_or_create_new(exp_label_compare)
print(f"len(img): {len(imgs)} len(imgs_compare): {len(imgs_compare)}")
generated_latents_compare = [pil_to_latent(frame) for frame in imgs_compare]
plot_latents_compare = [latent_tensor.cpu().numpy().flatten() for latent_tensor in generated_latents_compare]

# Project both sets of latents into the same two-dimensional PCA space
from sklearn.decomposition import PCA

# The PCA basis is fitted on the current experiment only; the compared
# experiment is projected into that basis.
# (Alternative: fit on both sets together:
#  pca.fit(np.vstack(np.concatenate([plot_latents, plot_latents_compare]))))
current_matrix = np.vstack(plot_latents)
pca = PCA(n_components=2)
pca.fit(current_matrix)
pca_latents = pca.transform(current_matrix)
pca_latents_compare = pca.transform(np.vstack(plot_latents_compare))

print(f"len of pca_latents: {len(pca_latents)} len of pca_latents_compare: {len(pca_latents_compare)}")

# Current experiment in blue, compared experiment in red; each point is
# labelled with its position in its own sequence.
fig, ax = plt.subplots(1,1,figsize=(10,10))
for points, colour in ((pca_latents, 'blue'), (pca_latents_compare, 'red')):
    ax.scatter(points[:,0],points[:,1],color=colour)
for points in (pca_latents, pca_latents_compare):
    for i, (x,y) in enumerate(points):
        ax.text(x,y,i,fontsize=10)



