# Image Compression Using VQ-VAE

Recent advances in statistical machine learning have opened up new possibilities for data compression: compression algorithms can now be learned end-to-end from data using powerful generative models such as variational autoencoders, diffusion probabilistic models, and generative adversarial networks. Recent machine learning (ML) models achieve state-of-the-art compression performance by encoding the pixel data into a latent space (perceptual encoding) and then applying diffusion models for semantic encoding.

Variational autoencoders (VAEs) can map pixel data to a latent space and have been shown to have a rate-distortion (R-D) theory interpretation. Given a distortion metric, a VAE learns to “compress” data by minimizing a tight upper bound on the information R-D function of the data. Vector quantization (VQ) gives the latents a discrete representation, which helps with generating realistic images.
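The quantization step can be illustrated with a minimal NumPy sketch. It assumes a plain nearest-neighbour lookup in a codebook; the function name `vector_quantize`, the random codebook, and the sizes used here are hypothetical and only for illustration (a trained VQ-VAE learns its codebook, and the model used in this notebook may differ in detail).

```python
import numpy as np

def vector_quantize(latents, codebook):
    """Replace each latent vector by its nearest codebook entry.

    latents:  (N, D) array of continuous encoder outputs
    codebook: (K, D) array of embedding vectors
    Returns the quantized latents (N, D) and the integer codes (N,).
    """
    # Squared Euclidean distance between every latent and every code: (N, K)
    diffs = latents[:, None, :] - codebook[None, :, :]
    distances = np.sum(diffs ** 2, axis=-1)
    indices = distances.argmin(axis=1)   # index of the closest code per latent
    quantized = codebook[indices]        # look the codes up again
    return quantized, indices

rng = np.random.default_rng(0)
codebook = rng.standard_normal((8, 4))   # hypothetical: K=8 codes, D=4 dims
latents = rng.standard_normal((5, 4))    # hypothetical encoder outputs
quantized, indices = vector_quantize(latents, codebook)
print(indices.shape, quantized.shape)
```

Only the integer `indices` would need to be stored or transmitted; the decoder recovers `quantized` from them with the shared codebook.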

CelebA-HQ images (1024x1024) are used as the dataset. The Fréchet inception distance (FID) is used as the metric for images reconstructed by the VQ-VAE, and PSNR is reported as a pixel-level measure.
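For reference, FID is usually defined by fitting a Gaussian to Inception features of the real images (mean $\mu_r$, covariance $\Sigma_r$) and of the generated or reconstructed images (mean $\mu_g$, covariance $\Sigma_g$), and taking the Fréchet distance between the two fits:

$$
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2 + \mathrm{Tr}\left(\Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2}\right)
$$

Lower values mean the two feature distributions are closer. The value depends on which Inception layer supplies the features, so scores computed with different feature sizes are not directly comparable.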

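The VQ-VAE from the Hugging Face `diffusers` library is used for testing; it is loaded as the `vqvae` component of the `CompVis/ldm-super-resolution-4x-openimages` latent diffusion pipeline.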

In [2]:
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
from PIL import Image
from torchvision import models, transforms, utils
%matplotlib inline

In [1]:
from google.colab import drive
drive.mount('/content/gdrive')
!ln -s /content/gdrive/MyDrive /mydrive
%cd /mydrive/diffusion/LDM/super_resolution

Mounted at /content/gdrive
/content/gdrive/MyDrive/diffusion/LDM/super_resolution


In [3]:
!pip install git+https://github.com/huggingface/diffusers.git

Collecting git+https://github.com/huggingface/diffusers.git
  Cloning https://github.com/huggingface/diffusers.git to /tmp/pip-req-build-gnc5wa25
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/diffusers.git /tmp/pip-req-build-gnc5wa25
  Resolved https://github.com/huggingface/diffusers.git to commit 1f22c9882020cbe2cc08acfee54fab553bbb5678
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: diffusers
  Building wheel for diffusers (pyproject.toml) ... done
  Created wheel for diffusers: filename=diffusers-0.27.0.dev0-py3-none-any.whl size=1982931 sha256=9f80f74b305c12c8e1ee492f1fadeed88000c109458565cd989cf192afb51625
  Stored in directory: /tmp/pip-ephem-wheel-cache-s54z7q1x/wheels/4d/b7/a8/6f9549ceec5daad78675b857ac57d697c387062506520a7b50
Successfully built diffusers
Installing

In [4]:

from diffusers import LDMSuperResolutionPipeline

from tqdm import tqdm


device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "CompVis/ldm-super-resolution-4x-openimages"

# load model and scheduler
pipeline = LDMSuperResolutionPipeline.from_pretrained(model_id)
pipeline = pipeline.to(device)

Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: 
```
pip install accelerate
```
.
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


unet/diffusion_pytorch_model.safetensors not found

  deprecate("LoRACompatibleLinear", "1.0.0", deprecation_message)
  deprecate("LoRACompatibleConv", "1.0.0", deprecation_message)


In [5]:
!pip install torch-fidelity
!pip install torchmetrics
from torchmetrics.image.fid import FrechetInceptionDistance
_ = torch.manual_seed(123)


Collecting torch-fidelity
  Downloading torch_fidelity-0.3.0-py3-none-any.whl (37 kB)
Installing collected packages: torch-fidelity
Successfully installed torch-fidelity-0.3.0
Collecting torchmetrics
  Downloading torchmetrics-1.3.1-py3-none-any.whl (840 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 840.4/840.4 kB 5.3 MB/s eta 0:00:00
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.10.1-py3-none-any.whl (24 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.10.1 torchmetrics-1.3.1


In [18]:


def preprocess(image):
    """Convert a PIL RGB image to a (1, 3, H, W) float tensor scaled to [-1, 1]."""
    w, h = image.size
    #w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
    #image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)  # HWC -> NCHW with a batch dimension
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0

def postprocess(image):
    """Map a decoder output in [-1, 1] to an (N, H, W, 3) NumPy array in [0, 1]."""
    image = torch.clamp(image, -1.0, 1.0)
    image = image / 2 + 0.5
    return image.cpu().permute(0, 2, 3, 1).detach().numpy()

def vae_encode(model, image):
    """Encode a PIL image into VQ-VAE latents."""
    image = preprocess(image).to(model.device)
    with torch.no_grad():  # inference only, so no gradients are stored
        return model.encode(image).latents


def vae_decode(model, latents):
    """Decode latents back to an image array in [0, 1]."""
    latents = latents.to(model.device)
    with torch.no_grad():
        decode_image = model.decode(latents).sample
    return postprocess(decode_image)


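def quantize_int8(in_tensor):
    """Simulate symmetric int8 quantization: scale to [-127, 127], round, rescale."""
    max_val_tensor = torch.max(torch.abs(in_tensor))
    tensor_scale = 127*in_tensor/max_val_tensor
    # round to nearest; a bare cast to int8 would truncate toward zero
    tensor_int8 = torch.round(tensor_scale).to(torch.int8)
    tensor_quant = tensor_int8.to(in_tensor.dtype)
    tensor_quant = tensor_quant * max_val_tensor/127
    return tensor_quant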


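def compute_MSE(org_image, decode_image):
    """Mean squared error per pixel channel between two PIL images.

    Note: preprocess() scales pixels to [-1, 1], so the error is measured on a
    value range of width 2, not on [0, 1] or [0, 255].
    """
    image1 = preprocess(org_image)
    image2 = preprocess(decode_image)
    loss = nn.MSELoss(reduction='none')
    loss_result = torch.sum(loss(image1,image2))
    return loss_result/(org_image.size[0]*org_image.size[1]*3)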

def convertint8(image):
    """Convert a PIL RGB image to a (1, 3, H, W) uint8 tensor for the FID module."""
    w, h = image.size
    #w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
    #image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
    image = np.array(image).astype(np.uint8)
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return image

def compute_FID(fid_module, predict_array, target_array):
    """Add one batch of target (real) and predicted images and return the FID so far."""
    images = [convertint8(image) for image in target_array]
    target_images = torch.cat(images, 0)
    images = [convertint8(image) for image in predict_array]

    predict_images = torch.cat(images, 0)
    # fid_module accumulates statistics across calls until reset() is called
    fid_module.update(target_images, real=True)
    fid_module.update(predict_images, real=False)
    return fid_module.compute()

In [7]:
def split_image(image, width_splits, height_splits):
    """Crop a PIL image into a row-major list of width_splits x height_splits tiles."""
    image_array = []
    width, height = image.size
    split_width = width//width_splits
    split_height = height//height_splits

    h_start = 0
    h_end = split_height

    for i in range(height_splits):
        w_start = 0
        w_end = split_width
        for j in range(width_splits):
            crop_image = image.crop((w_start, h_start, w_end, h_end))
            image_array.append(crop_image)
            # advance to the next tile in this row, clamped to the image width
            w_start = w_end
            w_end = min(w_end+split_width, width)

        # advance to the next row of tiles, clamped to the image height
        h_start = h_end
        h_end = min(h_end+split_height, height)

    return image_array


def join_images(image_array, width_splits, height_splits):
    """Paste row-major tiles (as produced by split_image) back into one RGB image."""
    # total width from the first row of tiles, total height from the first column
    total_width = sum(image_array[i].size[0] for i in range(width_splits))
    total_height = sum(image_array[i*width_splits].size[1] for i in range(height_splits))

    image = Image.new('RGB', (total_width, total_height))
    h_start = 0
    for i in range(height_splits):
        w_start = 0
        for j in range(width_splits):
            tile = image_array[j+i*width_splits]
            image.paste(tile, (w_start, h_start))
            img_width, img_height = tile.size
            w_start = w_start+img_width

        h_start = h_start + img_height

    return image

In [None]:
from glob import glob

# file_type must match the directory component used in path_name
path_name ="/content/gdrive/MyDrive/diffusion/LDM/dataset/train/png/labelA/*.png"
file_type = "/png/"

def vqvae_compression(path_name, file_type, split=True, split_factor = 2):
    """Encode and decode every image matching path_name with the VQ-VAE and save the result."""
    for png_file_name in glob(path_name):
        png_image = Image.open(png_file_name).convert("RGB")

        if split:
            # process the image tile by tile, then stitch the decoded tiles together
            split_images = split_image(png_image, split_factor, split_factor)
            decode_split_images = []
            for image in split_images:
                # encode the tile to latents
                latents = vae_encode(pipeline.vqvae, image)

                # decode back to the original size of the tile
                decode_image = vae_decode(pipeline.vqvae, latents)

                decode_image = Image.fromarray((decode_image[0]*255).astype('uint8'))
                decode_split_images.append(decode_image)

            vae_out_image = join_images(decode_split_images, split_factor, split_factor)
        else:
            # encode the full image to latents
            latents = vae_encode(pipeline.vqvae, png_image)

            # decode back to the original size of the image
            decode_image = vae_decode(pipeline.vqvae, latents)

            vae_out_image = Image.fromarray((decode_image[0]*255).astype('uint8'))

        # mirror the input path: .../train/png/... becomes .../predict/vqvae/png/...
        vae_file_name = png_file_name.replace("/train/", "/predict/")
        vae_file_name = vae_file_name.replace(file_type, "/vqvae"+file_type)
        os.makedirs(os.path.dirname(vae_file_name), exist_ok=True)
        vae_out_image.save(vae_file_name,format="PNG")


vqvae_compression(path_name, file_type, split=True, split_factor = 2)

In [13]:
# VQ Compression on 1024 x 1024 image
from glob import glob
# feature=64 selects a 64-dimensional Inception feature layer (the torchmetrics default is 2048)
fid_module = FrechetInceptionDistance(feature=64)


In [None]:
batch_size = 20
file_type = "/jpg/"
jpg_file_names = glob("/content/gdrive/MyDrive/diffusion/LDM/dataset/predict/vqvae"+ file_type +"labelA/*.jpg")

fid_module.reset()
# files beyond the last full batch are skipped
for count in range(len(jpg_file_names)//batch_size):
    target_images = []
    predict_images = []
    for i in range(batch_size):
        jpg_file_name = jpg_file_names[count*batch_size+i]
        predict_images.append(Image.open(jpg_file_name).convert("RGB"))
        # map the reconstruction path back to the original image path
        target_file_name = jpg_file_name.replace("/predict/", "/train/")
        target_file_name = target_file_name.replace("/vqvae"+file_type, file_type)
        target_images.append(Image.open(target_file_name).convert("RGB"))
    # prints the running FID over all batches seen so far
    print(compute_FID(fid_module, predict_images, target_images))


In [19]:
batch_size = 1
file_type = "/jpg/"
mse_value = 0
jpg_file_names = glob("/content/gdrive/MyDrive/diffusion/LDM/dataset/predict/vqvae"+ file_type +"labelA/*.jpg")

#fid_module.reset()
for count in range(len(jpg_file_names)//batch_size):
    target_images = []
    predict_images = []
    for i in range(batch_size):
        jpg_file_name = jpg_file_names[count*batch_size+i]
        predict_images.append(Image.open(jpg_file_name).convert("RGB"))
        target_file_name = jpg_file_name.replace("/predict/", "/train/")
        target_file_name = target_file_name.replace("/vqvae"+file_type, file_type)
        target_images.append(Image.open(target_file_name).convert("RGB"))
    mse_value = mse_value + compute_MSE(target_images[0], predict_images[0])
    #print(compute_FID(fid_module, predict_images, target_images))
# average the per-image MSE over the dataset
mse_value = mse_value/len(jpg_file_names)
# PSNR = 10*log10(peak^2/MSE); this expression takes peak = 1, while compute_MSE
# measures the error on pixels scaled to [-1, 1]
print(f"PSNR :{-10*np.log10(mse_value)}")

PSNR :29.33060073852539
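The PSNR figure depends on the value range assumed for the pixels. The following is a minimal sketch, not part of the original notebook, showing how the same pair of images yields different numbers under different conventions. The helper `psnr`, its `peak_to_peak` parameter, and the random test images are hypothetical.

```python
import numpy as np

def psnr(reference, reconstruction, peak_to_peak):
    """PSNR in dB, given the width of the valid pixel range (e.g. 1.0 or 255)."""
    reference = np.asarray(reference, dtype=np.float64)
    reconstruction = np.asarray(reconstruction, dtype=np.float64)
    mse = np.mean((reference - reconstruction) ** 2)
    return 10.0 * np.log10(peak_to_peak ** 2 / mse)

rng = np.random.default_rng(0)
ref01 = rng.random((8, 8, 3))  # hypothetical image with pixels in [0, 1]
rec01 = np.clip(ref01 + 0.01 * rng.standard_normal(ref01.shape), 0.0, 1.0)

# Pixels in [0, 1] with peak-to-peak 1
psnr_unit = psnr(ref01, rec01, peak_to_peak=1.0)

# The same images rescaled to [-1, 1] with peak-to-peak 2 give the same PSNR
ref_sym, rec_sym = 2.0 * ref01 - 1.0, 2.0 * rec01 - 1.0
psnr_sym = psnr(ref_sym, rec_sym, peak_to_peak=2.0)

# -10*log10(MSE) on [-1, 1] data implicitly uses a peak of 1
psnr_peak_one = -10.0 * np.log10(np.mean((ref_sym - rec_sym) ** 2))

print(psnr_unit - psnr_peak_one)  # offset of 20*log10(2), about 6.02 dB
```

Under the peak-to-peak convention, an MSE measured on [-1, 1] data corresponds to a PSNR about 6.02 dB higher than `-10*log10(MSE)`. This is worth keeping in mind when comparing the value printed above with figures reported elsewhere.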
