In [None]:
# Fetch the project and make the repository root the working directory.
# Every relative path used below (requirements.txt, the ./*.py scripts,
# base_directory) is resolved from the directory this `%cd` switches to.
# NOTE(review): not idempotent — re-running this cell after the `%cd` clones a
# nested copy inside the repository and changes into that copy.
!git clone https://github.com/kk-digital/kcg-ml-sd1p4
%cd kcg-ml-sd1p4

In [None]:
!pip install -r requirements.txt

In [None]:
# Run the repository's model download script from the repository root; it
# presumably fetches the checkpoints loaded further down — see download_models.py.
!python3 ./download_models.py

In [None]:
# Run the repository's model processing script. Presumably this produces the
# per-submodel checkpoints (autoencoder / encoder / decoder) referenced via the
# IODirectoryTree below — TODO confirm in process_models.py.
!python3 ./process_models.py

In [None]:
import os
import shutil
import sys
import time
from os.path import join

import torch
import torchvision
from torchvision.transforms import ToPILImage

# The first cell `%cd`s into the clone, so the working directory is already the
# repository root. Paths are resolved relative to it ("../" would point at the
# directory *containing* the clone, outside the project).
base_directory = "./"
sys.path.insert(0, base_directory)

output_base_dir = join(base_directory, "output/sd2-notebook/")
output_directory = join(output_base_dir, "autoencoder/")

# Start every run from an empty output directory. Only a missing directory is
# an expected failure here; anything else (e.g. a permission error) should
# surface instead of being silently swallowed.
try:
    shutil.rmtree(output_directory)
except FileNotFoundError as e:
    print(e, "\n", "Creating the path...")
os.makedirs(output_directory, exist_ok=True)

from stable_diffusion.stable_diffusion import StableDiffusion
from stable_diffusion.utils_backend import *
from stable_diffusion.utils_image import *
from stable_diffusion.utils_model import *
from stable_diffusion.utils_logger import *
from stable_diffusion.constants import IODirectoryTree


def to_pil(image):
    """Convert an image tensor with values nominally in [-1, 1] to a PIL image.

    Values are mapped to [0, 1] via (x + 1) / 2 and clamped to that range before
    conversion, so -1 -> black, 0 -> mid-gray, 1 -> white; out-of-range values
    saturate.
    """
    return ToPILImage()(torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0))


device = get_device()

In [None]:
# Resolve the project's model/input paths relative to `base_directory`.
pt = IODirectoryTree(base_directory=base_directory)
# Sanity check: this prints the *decoder* entry, which should point inside the
# clone — presumably `.../kcg-ml-sd1p4/input/model/autoencoder/decoder.ckpt`.
# `pt.decoder` is later unpacked with `**` into `load_decoder`, so it is a
# mapping of keyword arguments rather than a bare path.
print(pt.decoder)

In [None]:
# initialize an empty stable diffusion class on `device`; its submodels are
# loaded step by step in the following cells
stable_diffusion = StableDiffusion(device=device)
# memory check after this step (last expression, so its result is displayed)
get_memory_status()

In [None]:
# initialize an empty latent diffusion model; it returns self.model
# (the return value is not used here — the model is reached later through
# `stable_diffusion.model`)
stable_diffusion.quick_initialize()
get_memory_status()

In [None]:
# the latent diffusion class has a method to load the autoencoder, since it is a submodel of it. it returns the autoencoder
# `pt.autoencoder` is unpacked as keyword arguments (checkpoint location etc.)
stable_diffusion.model.load_autoencoder(**pt.autoencoder)
get_memory_status()

In [None]:
# the autoencoder has a method to load the encoder, since it's one of its submodels. it returns the encoder
# only the encoder half is loaded at this point; the decoder is loaded later,
# after the encoder has been unloaded
stable_diffusion.model.autoencoder.load_encoder(**pt.encoder)
get_memory_status()

In [None]:
# since each method returns the thing it loads, we could, for convenience, one-line that out: initialize a latent diffusion model, then load the autoencoder, then load the encoder
# NOTE(review): this calls `quick_initialize()` again, so it presumably replaces
# the model built in the previous three cells and loads the autoencoder and
# encoder a second time — it repeats that work rather than adding to it.
stable_diffusion.quick_initialize().load_autoencoder(**pt.autoencoder).load_encoder(
    **pt.encoder
)
get_memory_status()

In [None]:
# Test input for the encoder: read the sample image from the repo's .test
# folder, move it to the compute device, and preview it (batch dim squeezed).
test_img_path = join(base_directory, ".test/input/test_img.jpg")
img = load_img(test_img_path).to(device)
to_pil(img.squeeze(0))

In [None]:
# get the latent representation of the test image (runs the encoder loaded above)
encoded_img = stable_diffusion.encode(img)
get_memory_status()

In [None]:
# check its shape — presumably (batch, latent_channels, height, width) with a
# spatially downsampled height/width; TODO confirm against the encoder
encoded_img.shape

In [None]:
# Show each dimension (channel) of the latent representation as its own tile.
# permute(1, 0, 2, 3) swaps the batch and channel axes, so make_grid lays out
# one single-channel tile per latent channel (assumes a batch of one image —
# TODO confirm). make_grid does no value scaling here (normalize=False), so the
# previous `range=(-1, 1)` / `scale_each=True` arguments had no effect; they are
# dropped, which also avoids the deprecated `range` keyword that current
# torchvision no longer accepts. `to_pil` maps [-1, 1] to [0, 1] instead.
grid = torchvision.utils.make_grid(
    encoded_img.permute(1, 0, 2, 3),
    nrow=2,
    normalize=False,
    pad_value=0,
)
dim_grid_image = to_pil(grid)
dim_grid_image.save(join(output_directory, "encoding_dimensions_grid.png"))
dim_grid_image

In [None]:
# Persist the latent tensor to disk so it can be reloaded later without
# re-running the encoder.
encoded_tensor_path = join(output_base_dir, "encoded_img_tensor.pt")
torch.save(encoded_img, encoded_tensor_path)

In [None]:
# Drop the in-memory latent (a copy was saved to disk above) and release
# PyTorch's cached, unused GPU memory before checking memory again.
del encoded_img
torch.cuda.empty_cache()
# Displayed as the cell's last expression, consistent with every other cell
# (wrapping it in print() would also print a stray `None` if the helper only
# prints and returns nothing).
get_memory_status()

In [None]:
# Reload the latent saved earlier. Without `map_location`, torch.load restores
# the tensor onto the device it was saved from.
encoded_tensor_path = join(output_base_dir, "encoded_img_tensor.pt")
encoded_img = torch.load(encoded_tensor_path)
torch.cuda.empty_cache()
get_memory_status()

In [None]:
# the reloaded latent should have the same shape as the one shown before saving
encoded_img.shape

In [None]:
# Re-plot the per-channel grid for the latent reloaded from disk; it should
# match the grid rendered before saving. With normalize=False make_grid applies
# no value scaling, so the no-op `range=(-1, 1)` (a deprecated keyword rejected
# by current torchvision) and `scale_each=True` arguments are dropped.
grid = torchvision.utils.make_grid(
    encoded_img.permute(1, 0, 2, 3),
    nrow=2,
    normalize=False,
    pad_value=0,
)
to_pil(grid)

In [None]:
# the grid was only needed for display; free it and release cached GPU memory
del grid
torch.cuda.empty_cache()
get_memory_status()

In [None]:
# unload the encoder submodel — from here on only the decoder is needed, to
# decode the latent reloaded from disk
stable_diffusion.model.autoencoder.unload_encoder()
torch.cuda.empty_cache()
get_memory_status()

In [None]:
# load the decoder submodel; `pt.decoder` is unpacked as keyword arguments
stable_diffusion.model.autoencoder.load_decoder(**pt.decoder)
torch.cuda.empty_cache()
get_memory_status()

In [None]:
# Decode the latent reloaded from disk back into image space, write it out,
# free cached GPU memory, and preview the first image of the batch.
decoded_img = stable_diffusion.decode(encoded_img)
decoded_img_path = join(output_directory, "decoded_img.png")
save_images(decoded_img, decoded_img_path)
torch.cuda.empty_cache()
get_memory_status()
to_pil(decoded_img[0])

In [None]:
# initially loaded image isn't the same as the decoded image: the autoencoder
# round-trip is lossy. torch.norm with no `dim` gives the L2 (Frobenius) norm of
# the whole difference tensor — a single scalar measure of reconstruction error.
torch.norm(img - decoded_img)

In [None]:
# Visualize the per-pixel reconstruction error. to_pil maps a zero difference
# to mid-gray, so darker/brighter regions show where the decoder under/overshoots.
residual = (img - decoded_img).squeeze(0)
diff_img = to_pil(residual)
diff_img.save(join(output_directory, "diff_img.png"))
diff_img