In [None]:
import os
import shutil
import sys
import time
from os.path import join

import torch
import torchvision
from torchvision.transforms import ToPILImage

base_directory = "../"
# Make the project package (stable_diffusion) importable from the notebook.
sys.path.insert(0, base_directory)

output_base_dir = join(base_directory, "./output/sd2-notebook/")
output_directory = join(output_base_dir, "unet/")

# Start every run from an empty output directory so stale images from earlier
# runs don't linger. rmtree is best-effort (a missing directory or a removal
# failure is ignored); makedirs then guarantees the directory exists.
shutil.rmtree(output_directory, ignore_errors=True)
os.makedirs(output_directory, exist_ok=True)

from stable_diffusion.stable_diffusion import StableDiffusion
from stable_diffusion.utils_backend import *
from stable_diffusion.utils_image import *
from stable_diffusion.utils_model import *
from stable_diffusion.utils_logger import *
from stable_diffusion.constants import IODirectoryTree


def to_pil(image):
    """Convert a tensor image with values nominally in [-1, 1] to a PIL image.

    Values are mapped to [0, 1] via (image + 1) / 2 and clamped, so anything
    outside [-1, 1] saturates. The tensor must be in a layout ToPILImage
    accepts (C x H x W for tensors).
    """
    return ToPILImage()(torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0))

In [None]:
# Compute device chosen by the project's get_device helper; it is passed to
# StableDiffusion and used as the target for the tensors created in this notebook.
device = get_device()

In [None]:
# IODirectoryTree rooted at base_directory. Its attributes (embedder,
# embedder_submodels, unet) are mappings that get unpacked as **kwargs into the
# load_* calls later in the notebook; presumably they hold checkpoint/config
# paths — confirm in stable_diffusion.constants.
pt = IODirectoryTree(base_directory=base_directory)

In [None]:
# Initialize an empty StableDiffusion wrapper on `device`; submodels are loaded
# step by step in the following cells.
stable_diffusion = StableDiffusion(device=device)
# Last expression: its result is shown as the cell output (presumably a memory
# usage report, used to track the cost of each loading step).
get_memory_status()

In [None]:
# Initialize an empty latent diffusion model; afterwards it is reachable as
# stable_diffusion.model, which every following cell uses.
# NOTE(review): the original note said this "returns self.model"; the return
# value is discarded here, so only the attribute matters.
stable_diffusion.quick_initialize()
get_memory_status()

In [None]:
# The unet needs a conditioning context, which we get from the CLIP text
# embedder. Load the embedder onto the latent diffusion model using the keyword
# arguments stored in pt.embedder.
stable_diffusion.model.load_clip_embedder(**pt.embedder)
get_memory_status()

In [None]:
# Load the embedder's submodels (the tokenizer and the text transformer) using
# the keyword arguments stored in pt.embedder_submodels.
stable_diffusion.model.clip_embedder.load_submodels(**pt.embedder_submodels)
get_memory_status()

In [None]:
# Display the loaded clip_embedder object to inspect its structure.
stable_diffusion.model.clip_embedder

In [None]:
# Encode the prompt with the CLIP text embedder to get the unet's conditioning
# context (the embedder is called with a list of prompts, i.e. a batch of one).
# Inference only: run under no_grad so no autograd graph is recorded. Otherwise
# the embedding carries a graph that may keep the embedder's activations and
# weights referenced even after the embedder is unloaded.
PROMPT = "A woman with flowers in her hair in a courtyard, in the style of Frank Frazetta"
with torch.no_grad():
    prompt_embedding = stable_diffusion.model.clip_embedder([PROMPT])

In [None]:
# Check memory after the embedder forward pass.
get_memory_status()
# Last expression: show the embedding's shape (presumably
# (batch, tokens, embed_dim) — confirm).
prompt_embedding.shape

In [None]:
# The prompt embedding has been computed, so the embedder is no longer needed;
# unload it before the unet is loaded.
stable_diffusion.model.unload_clip_embedder()
get_memory_status()

In [None]:
# Persist the prompt embedding alongside the other notebook outputs
# (presumably so later steps can reuse it without reloading the embedder).
prompt_embedding_path = join(output_base_dir, "prompt_embedding.pt")
torch.save(prompt_embedding, prompt_embedding_path)

In [None]:
# The unet is a submodel of the latent diffusion model, which therefore owns
# the method that loads it (kwargs come from pt.unet). The unet comes wrapped
# in a DiffusionWrapper and is reachable as stable_diffusion.model.model or
# through the alias stable_diffusion.model.unet.
stable_diffusion.model.load_unet(**pt.unet)
get_memory_status()
# Last expression: display the loaded unet.
stable_diffusion.model.unet

In [None]:
# Load the encoded (latent) image saved to output_base_dir by an earlier step.
# This notebook does not create the file, so it must already exist.
# map_location places the tensor directly on `device`. This also lets a tensor
# that was saved from a GPU be loaded on a machine without one, where a plain
# torch.load followed by .to(device) would fail.
encoded_img = torch.load(
    join(output_base_dir, "encoded_img_tensor.pt"), map_location=device
)
encoded_img.shape

In [None]:
# Draw a random latent with the same shape, dtype and device as the encoded image.
# Seed the global RNGs first (torch.manual_seed covers CPU and CUDA) so the
# sample — and therefore the unet output and the saved visualization — is
# reproducible on every Restart & Run All.
LATENT_SEED = 0
torch.manual_seed(LATENT_SEED)
sample = torch.randn_like(encoded_img)
get_memory_status()

In [None]:
# Diffusion timestep for the sample: a one-element float tensor holding 0,
# allocated directly on `device`.
time_step = torch.zeros(1, device=device)
time_step.shape

In [None]:
# Predict the noise for `sample` at `time_step`, conditioned on the prompt
# embedding. This is pure inference, so skip autograd bookkeeping. Without
# no_grad, the unet's intermediate activations are retained for a backward pass
# that never happens, which inflates memory use.
with torch.no_grad():
    unet_output = stable_diffusion.model.unet(sample, time_step, prompt_embedding)

In [None]:
# Check memory after the unet forward pass.
get_memory_status()
# Last expression: show the output shape (same rank as `sample`, given the
# 4-D permute applied to it in the visualization cell).
unet_output.shape

In [None]:
# Show each channel of the unet output as its own tile. permute(1, 0, 2, 3)
# swaps the first two axes so make_grid treats every channel as a separate
# single-channel image (assumes the batch dimension is 1 — confirm), laid out
# 2 tiles per row.
# With normalize=False, make_grid leaves values untouched; to_pil then maps
# [-1, 1] -> [0, 1] and clamps. Predicted noise is not bounded to [-1, 1], so
# outlying values saturate.
# The former range=(-1, 1) / scale_each=True arguments only take effect when
# normalize=True. Also, `range` was renamed to `value_range` and later removed
# from torchvision (TypeError on current versions), so both are dropped.
# detach() keeps this cell working even if unet_output was produced with
# autograd enabled.
grid = torchvision.utils.make_grid(
    unet_output.detach().permute(1, 0, 2, 3),
    nrow=2,
    normalize=False,
    pad_value=0,
)
dim_grid_image = to_pil(grid)
dim_grid_image.save(join(output_directory, "unet_output.png"))
dim_grid_image

In [None]:
# Unload the latent diffusion model's submodels (presumably the unet, the one
# loaded in this section — confirm), then return cached, unoccupied CUDA memory
# from PyTorch's caching allocator before reading the final memory status.
# NOTE(review): torch.cuda.empty_cache() cannot free memory still referenced by
# live tensors (prompt_embedding, encoded_img, sample, unet_output, grid), so
# `del` those first if a clean memory reading is the goal.
stable_diffusion.model.unload_submodels()
torch.cuda.empty_cache()
get_memory_status()