In [None]:
import os
import shutil
import sys
import time
from os.path import join

import clip
import torch
import torch
import torchvision
import torchvision.transforms as T
from torchvision.transforms import ToPILImage

base_directory = "../"
# Make the project packages under base_directory importable (stable_diffusion, utility, ...).
sys.path.insert(0, base_directory)

output_base_dir = join(base_directory, "./output/sd2-notebook/")
output_directory = join(output_base_dir, "clip_image_encoder/")

# Start every run from an empty output directory so files left by a previous run
# (e.g. an old cond.pt / uncond.pt) cannot be mistaken for fresh results.
try:
    shutil.rmtree(output_directory)
except FileNotFoundError:
    # First run: there is nothing to clear yet.
    pass
except OSError as e:
    # Best effort: report why the old directory could not be removed and keep going.
    print(e, "\n", "Could not clear the old output directory; reusing it...")
os.makedirs(output_directory, exist_ok=True)

from stable_diffusion.stable_diffusion import StableDiffusion
from stable_diffusion.model.clip_image_encoder import CLIPImageEncoder
from utility.labml.monit import section
# from stable_diffusion.utils.utils import SectionManager as section
from stable_diffusion.utils_model import *
from stable_diffusion.utils_backend import *
from stable_diffusion.constants import IODirectoryTree
from pathlib import Path

In [None]:
# Notebook configuration: batch size used for image generation below, the compute device,
# and the IODirectoryTree whose attributes (pt.embedder, pt.unet, pt.clip_model, ...) are
# unpacked as keyword arguments by the model loaders.
batch_size = 1
# get_device comes from one of the star imports above; NOTE(review): presumably None means
# "choose automatically" — confirm against its definition.
device = get_device(None)
pt = IODirectoryTree(base_directory=base_directory)

In [None]:
# Create the Stable Diffusion wrapper on the selected device, then report memory usage.
sd = StableDiffusion(device=device)
get_memory_status()

In [None]:
# Prepare the text side of the pipeline: quick-initialize, then load the CLIP text
# embedder and its sub-models with keyword arguments supplied by `pt`.
(
    sd.quick_initialize()
    .load_clip_embedder(**pt.embedder)
    .load_submodels(**pt.embedder_submodels)
)
get_memory_status()

In [None]:
# Encode the prompt: returns the unconditional embedding (used as the null prompt
# during generation) and the prompt-conditioned embedding.
uncond, cond = sd.get_text_conditioning(
    prompts=["A computer virus dancing tango."],
    uncond_scale=7.5,
)
get_memory_status()

In [None]:
# Display the shapes of both embeddings (unconditional first, then conditional).
tuple(embedding.shape for embedding in (uncond, cond))

In [None]:
# Persist both embeddings so they can be reloaded later.
# Each tensor goes to its own file: uncond.pt holds the unconditional embedding and
# cond.pt the prompt-conditioned one (previously `uncond` was written to both files).
torch.save(uncond, join(output_directory, "uncond.pt"))
torch.save(cond, join(output_directory, "cond.pt"))

In [None]:
# Load the U-Net, which the image-generation step below needs; its loading keyword
# arguments come from `pt.unet`. Then report memory usage.
sd.model.load_unet(**pt.unet)
get_memory_status()

In [None]:
# Load the autoencoder and then its decoder, which maps sampled latents back to images.
(
    sd.model.load_autoencoder(**pt.autoencoder)
    .load_decoder(**pt.decoder)
)
get_memory_status()

In [None]:
# Sample images from the embeddings computed above; `uncond` is passed as the null prompt.
images = sd.generate_images_from_embeddings(
    null_prompt=uncond,
    embedded_prompt=cond,
    batch_size=batch_size,
)
get_memory_status()

In [None]:
# Shape of the generated image batch (Tensor.size() returns the same torch.Size as .shape).
images.size()

In [None]:
# Convert the generated samples to a displayable image. Several samples are tiled into a
# grid (two per row); a single sample is converted directly. Note that `img` is what the
# CLIP preprocessing cells below consume as the "PIL path" input.
if batch_size > 1:
    # torchvision renamed make_grid's `range=` keyword to `value_range=` (the old name is
    # deprecated / removed in newer releases). It only takes effect when normalize=True.
    grid = torchvision.utils.make_grid(images, nrow=2, normalize=False, value_range=(-1, 1))
    img = to_pil(grid)
else:
    img = to_pil(images[0])
img

In [None]:
# Unload the Stable Diffusion models (presumably to free memory before the CLIP image
# encoder is loaded — check the reported memory). The `images` tensor stays available.
sd.unload_model()
get_memory_status()

In [None]:
# Re-check the generated batch after unloading the pipeline; it is still held by `images`.
images.size()

In [None]:
# Create the CLIP image encoder on the same device used for generation, then report memory.
img_encoder = CLIPImageEncoder(device=device)
get_memory_status()

In [None]:
# Load the CLIP model (keyword arguments from `pt.clip_model`) and set up the input
# preprocessor with center cropping enabled.
img_encoder.load_clip_model(**pt.clip_model)
# NOTE(review): this is not the cell's last expression, so its return value is not
# displayed; only whatever it prints will show up — confirm that is intended.
get_memory_status()
img_encoder.initialize_preprocessor(do_center_crop=True)

In [None]:
# Inspect the preprocessor configured by initialize_preprocessor above.
img_encoder.image_processor

In [None]:
# "Image path": preprocess the image returned by `to_pil` earlier and move it to the
# device; then display the input's type to confirm which path this is.
prep_from_img = img_encoder.preprocess_input(img).to(device)
type(img)

In [None]:
# Inspect the preprocessor again after it has been used on the image input
# (presumably to check that preprocess_input left its configuration unchanged — TODO confirm).
img_encoder.image_processor

In [None]:
# "Tensor path": preprocess the raw generated tensor directly, for comparison with the
# image path above. Unlike that path, no explicit .to(device) is applied here.
prep_from_tensor = img_encoder.preprocess_input(images)
type(images)

In [None]:
# Shapes of the image-path and tensor-path preprocessing outputs, in that order.
tuple(prepared.shape for prepared in (prep_from_img, prep_from_tensor))

In [None]:
# Exact element-wise equality of the two preprocessing results, reduced to one bool tensor.
(prep_from_img.to(device) == prep_from_tensor).all()

In [None]:
# Size of the disagreement: L2 (Frobenius) norm of the element-wise difference.
(prep_from_img.to(device) - prep_from_tensor).norm()

In [None]:
# Render the image-path preprocessing result (squeeze(0) drops a size-1 batch dimension).
to_pil(prep_from_img.squeeze(0))

In [None]:
# Render the tensor-path preprocessing result (squeeze(0) drops a size-1 batch dimension).
to_pil(prep_from_tensor.squeeze(0))

In [None]:
# Visualize where the two preprocessing paths disagree. Take the absolute difference,
# because negative values cannot be shown as pixel intensities.
# NOTE(review): assumes `to_pil` behaves like torchvision's ToPILImage — confirm.
# Both operands are put on `device`, as in the equality/norm cells above.
to_pil((prep_from_img.to(device) - prep_from_tensor).abs().squeeze(0))

In [None]:
# Preview the preprocessed inputs. With several generated images, tile the image-path and
# tensor-path results into one grid (two per row); with a single image, show the
# image-path result only.
# make_grid expects one 4-D mini-batch, so the two batches are concatenated along the
# batch dimension. Passing a list of 4-D tensors would be stacked into a 5-D tensor.
# A separate name is used so the source image `img` (the input of preprocess_input
# above) is not overwritten, which keeps the earlier cells safe to re-run.
if batch_size > 1:
    both_batches = torch.cat(
        [prep_from_img, prep_from_tensor.to(prep_from_img.device)], dim=0
    )
    # `value_range` replaces torchvision's deprecated `range=`; it only matters when normalize=True.
    grid = torchvision.utils.make_grid(both_batches, nrow=2, normalize=False, value_range=(-1, 1))
    prep_preview = to_pil(grid)
else:
    prep_preview = to_pil(prep_from_img.squeeze(0))
prep_preview

In [None]:
# Side-by-side comparison: image-path preprocessing (left) vs. tensor-path preprocessing (right).
# squeeze(0) drops only the batch dimension (assumes batch size 1, as in the cells above).
# Both tensors are put on the same device because make_grid stacks them.
# `value_range` replaces torchvision's deprecated `range=`; it only matters when normalize=True.
comparison_grid = torchvision.utils.make_grid(
    [prep_from_img.squeeze(0), prep_from_tensor.squeeze(0).to(prep_from_img.device)],
    nrow=2,
    normalize=False,
    value_range=(-1, 1),
)
comparison_img = to_pil(comparison_grid)
# Display the result; the original cell ended on an assignment and showed nothing.
comparison_img

In [None]:
# Encode the tensor-path input with CLIP. No gradients are needed for this comparison, so
# run under no_grad to avoid keeping an autograd graph (and its activations) in memory.
with torch.no_grad():
    img_features_tensor = img_encoder(prep_from_tensor)
img_features_tensor.shape

In [None]:
# Encode the image-path input with CLIP. This is inference only, so run under no_grad to
# avoid building an autograd graph (and holding its activations in memory).
with torch.no_grad():
    img_features_image = img_encoder(prep_from_img)
img_features_image.shape

In [None]:
# L2 (Frobenius) norm of the gap between the CLIP features of the two preprocessing paths.
(img_features_image - img_features_tensor).norm()