# Compressing an embedding using the wavelet-based JPEG 2000 algorithm
## [jpeg-2000-wavelet-compression](http://www.jeanfeydy.com/Teaching/MasterClass_Radiologie/Part%207%20-%20JPEG2000%20compression.html)

This notebook demonstrates generating a prompt embedding, saving it, loading it back, and compressing it with the JPEG 2000 algorithm.

In [None]:
ENV_TYPE = "TEST"

if(ENV_TYPE != "TEST"):
  !git clone "https://github.com/kk-digital/kcg-ml-sd1p4.git"
  %cd kcg-ml-sd1p4
  !pip3 install -r requirements.txt
  exit()
  base_directory = "./"
else:
  base_directory = "../"

# Magical check for fixing all of our directory issues
import subprocess
output = subprocess.check_output(["pwd"], universal_newlines=True)
if "notebooks" in output:
    %cd ..
del output

In [None]:
# Check for dependency needed for using OpenCV
import subprocess

# `dpkg -s` exits with a non-zero status when the package cannot be found, so
# the exit status is used instead of matching the English error text.
result = subprocess.run(
    ['dpkg', '-s', 'libgl1-mesa-glx'],
    stdout=subprocess.PIPE,
    stderr=subprocess.PIPE,
    text=True,
)

# If the package is not installed, install it
if result.returncode != 0:
    print("Installing libgl, which is needed to run the GA script.")
    subprocess.run(["apt", "update"])
    # -y answers apt's confirmation prompt; a notebook cell has no way to
    # answer it interactively.
    install_result = subprocess.run(["apt", "install", "-y", "libgl1-mesa-glx"])
    if install_result.returncode != 0:
        # Report the failure instead of silently ignoring it; the notebook
        # still continues, as before.
        print(f"Installing 'libgl1-mesa-glx' failed (apt exit code {install_result.returncode}).")
else:
    print("Package 'libgl1-mesa-glx' is already installed.")

In [None]:
# Run the repository's download_models.py script in a python3 subprocess.
!python3 ./download_models.py

In [None]:
# Run the repository's process_models.py script in a python3 subprocess.
!python3 ./process_models.py

In [5]:
import os
import sys
import torch
import time
import shutil
from torchvision.transforms import ToPILImage
from os.path import join

# Treat the current working directory as the base directory (this overrides
# the value assigned in the first cell) and put it first on sys.path, ahead
# of the project-package imports that follow.
base_directory = "./"
sys.path.insert(0, base_directory)

from stable_diffusion.model_paths import *
from configs.model_config import ModelPathConfig
from stable_diffusion.utils_backend import *
from stable_diffusion.utils_image import *
from stable_diffusion.utils_model import *
from stable_diffusion.stable_diffusion import StableDiffusion
from utility.labml import monit


# Everything this notebook writes goes to
# <base_directory>/output/sd2-notebook/jpeg_embed_compression/.
output_base_dir = join(base_directory, "./output/sd2-notebook/")
output_directory = join(output_base_dir, "jpeg_embed_compression/")

# Try to remove whatever a previous run left behind; a failure (for example
# the directory not existing yet) is printed rather than raised. In either
# case the directory is (re)created afterwards.
try:
    shutil.rmtree(output_directory)
except Exception as e:
    print(e, "\n", "Creating the path...")
os.makedirs(output_directory, exist_ok=True)


def to_pil(image):
    """Convert an image tensor to a PIL image.

    The tensor is rescaled with ``(image + 1.0) / 2.0``, clamped to
    ``[0.0, 1.0]`` and then passed to torchvision's ``ToPILImage``.
    """
    converter = ToPILImage()
    rescaled = (image + 1.0) / 2.0
    return converter(torch.clamp(rescaled, min=0.0, max=1.0))

  warn(


In [6]:
# Device returned by the project's get_device() helper; used by later cells.
device = get_device()
# Also put the current working directory first on sys.path.
base_dir = os.getcwd()
sys.path.insert(0, base_dir)

# NOTE(review): batch_size is not used in the visible part of this notebook.
batch_size = 1
# Model path configuration, and the IODirectoryTree built from it.
model_config = ModelPathConfig()
pt = IODirectoryTree(model_config)



In [None]:
# Initialize an empty StableDiffusion instance on the selected device,
# then call get_memory_status for that device.
stable_diffusion = StableDiffusion(device=device)
get_memory_status(device)

In [None]:
# Initialize an empty latent diffusion model; quick_initialize() returns self.model.
# Then load the CLIP text embedder from the path `pt.embedder_path` with .load_clip_embedder().
# That call returns the CLIP embedder, so .load_submodels() can be chained to load
# the text embedder's submodels.

stable_diffusion.quick_initialize().load_clip_embedder().load_submodels()
get_memory_status(device)

In [None]:
# Echo the CLIP embedder object (the notebook displays the cell's last expression).
stable_diffusion.model.clip_embedder

In [None]:
# Get the embedding for a prompt: the CLIP embedder is called with a
# one-element list containing the prompt text.
prompt_embedding = stable_diffusion.model.clip_embedder(
    ["Just another prompt embedding"]
)

In [None]:
# Call get_memory_status for the device, then echo the embedding's shape
# (last expression of the cell).
get_memory_status(device)
prompt_embedding.shape

In [None]:
# The embedding has been computed, so unload the CLIP embedder and call
# get_memory_status again (presumably to compare with the earlier readings).
stable_diffusion.model.unload_clip_embedder()
get_memory_status(device)

In [None]:
# Save the prompt embedding, uncompressed, with torch.save into the output directory
torch.save(prompt_embedding, join(output_directory, "prompt_embedding_uncompressed.pt"))

In [7]:
import torch

# Load the prompt embedding back from the uncompressed .pt file written by
# torch.save, then echo its shape (last expression of the cell).
prompt_embedding = torch.load(join(output_directory, "prompt_embedding_uncompressed.pt"))
prompt_embedding.shape

In [None]:
import numpy as np

# Convert the PyTorch tensor to a numpy array: move it to the CPU first,
# then detach it from the autograd graph before calling .numpy().
prompt_embedding = prompt_embedding.cpu()
prompt_embedding_np = prompt_embedding.detach().numpy()

In [None]:
from PIL import Image
from pywt import dwt2, idwt2
import io

# Convert the numpy array to a Pillow image
# NOTE(review): Image.fromarray only accepts certain dtype/shape combinations.
# Confirm that the embedding's dtype and shape (e.g. float values or a leading
# batch axis) are supported by it and by the JPEG 2000 writer; the array may
# need to be squeezed and quantized to an integer type first.
prompt_embedding_img = Image.fromarray(prompt_embedding_np)

# Save the image as a JPEG 2000 (.jp2) file
prompt_embedding_img.save(join(output_directory, "prompt_embedding_compressed.jp2"), format="JPEG2000")

In [None]:
# Load the compressed JPEG 2000 image. The file name must match the one used
# when saving ("prompt_embedding_compressed.jp2"); the previous ".jpg" name
# pointed at a file that is never written.
loaded_image = Image.open(join(output_directory, "prompt_embedding_compressed.jp2"))

In [None]:
# Convert the loaded image to a numpy array (this decodes the image data)
loaded_embedding_np = np.array(loaded_image)

In [None]:
# Convert the numpy array back to a PyTorch tensor
# (torch.from_numpy shares memory with the numpy array rather than copying it)
loaded_embedding_tensor = torch.from_numpy(loaded_embedding_np)

In [None]:
# Check the shape of the loaded tensor against the author's expectation
print(loaded_embedding_tensor.shape)  # Should be (77, 768)