## Compression using Latent Diffusion Model

Latent Diffusion Models (LDM) achieve compression by mapping pixel data to discrete latent representations and denoising diffusion on the latent space. The transformation to latent space provides perceptual compression and diffusion model provides semantic representation. Typically Variational Auto Encoders (VAEs) are used for the first step of latent space compression and Denoising Diffision Probabilistic Models (DDPMs) are used as diffusion model. This two step process is euqivalent to learning respresentations using VAE and then learning probability distribution in the second phase.

### Advatanges of LDM over GANS
  * LDMs are trained using Maximum Likelihood cost function (Stable and asymptotically optimal with large dataset)
  * Strong foundation in neurophysics as Theory of Active Inference (Learning is inference in the latent domain and Baysean Brain Hypothesis)
  * Better image quality (IS and FID) than GANs due to inherent inductive bias of learning image represenations using Covnets
  * LDM also supports other modality inputs to condition/guide the generation process


In [1]:
from google.colab import drive
drive.mount('/content/gdrive')
!ln -s /content/gdrive/MyDrive /mydrive
%cd /mydrive/diffusion/LDM/super_resolution

Mounted at /content/gdrive
/content/gdrive/MyDrive/diffusion/LDM/super_resolution


In [2]:
import torch
import os
import torch.nn as nn
import torchvision
from torchvision import models, transforms, utils
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
import scipy.misc
from PIL import Image
import json
%matplotlib inline

In [3]:
!pip install git+https://github.com/huggingface/diffusers.git


Collecting git+https://github.com/huggingface/diffusers.git
  Cloning https://github.com/huggingface/diffusers.git to /tmp/pip-req-build-7tpwwvej
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/diffusers.git /tmp/pip-req-build-7tpwwvej
  Resolved https://github.com/huggingface/diffusers.git to commit ee6a3a993dd3111146669a30b253af296c3910a8
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: diffusers
  Building wheel for diffusers (pyproject.toml) ... [?25l[?25hdone
  Created wheel for diffusers: filename=diffusers-0.27.0.dev0-py3-none-any.whl size=2025638 sha256=b37d89525aee28c7a0a7d90ad738692f2dfdd3dd10ed41faa62cdb9e73d6568f
  Stored in directory: /tmp/pip-ephem-wheel-cache-z8z7smt5/wheels/4d/b7/a8/6f9549ceec5daad78675b857ac57d697c387062506520a7b50
Successfully built diffusers
Installing

In [None]:
!pip install torch-fidelity
!pip install torchmetrics
from torchmetrics.image.fid import FrechetInceptionDistance

In [5]:

from diffusers import LDMSuperResolutionPipeline
#from latent_diffusion import LDMSuperResolutionPipeline

from tqdm import tqdm


device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "CompVis/ldm-super-resolution-4x-openimages"

# load model and scheduler

pipeline = LDMSuperResolutionPipeline.from_pretrained(model_id)
pipeline = pipeline.to(device)
_ = torch.manual_seed(123)
#fid_module = FrechetInceptionDistance(feature=64)

Cannot initialize model with low cpu memory usage because `accelerate` was not found in the environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install `accelerate` for faster and less memory-intense model loading. You can do so with: 
```
pip install accelerate
```
.
unet/diffusion_pytorch_model.safetensors not found


Loading pipeline components...:   0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
import torchvision
import torchvision.transforms as transforms

def preprocess(image):
    w, h = image.size
    #w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
    #image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0

def postprocess(image):
    image = torch.clamp(image, -1.0, 1.0)
    image = image / 2 + 0.5
    decode_image = image.cpu().permute(0, 2, 3, 1).detach().numpy()
    return(decode_image)

def vae_encode(model, image):
    image = preprocess(image)
    image = image.to(model.device)
    encode_image= model.encode(image).latents
    return(encode_image)


def vae_decode(model, latents):
    latents = latents.to(model.device)
    decode_image= model.decode(latents).sample
    decode_image = postprocess(decode_image)
    return(decode_image)


def quantize_int8(in_tensor):
    max_val_tensor = torch.max(torch.abs(in_tensor))
    tensor_scale = 127*in_tensor/max_val_tensor
    tensor_int8 = tensor_scale.to(torch.int8)
    tensor_quant = tensor_int8.to(in_tensor.dtype)
    tensor_quant = tensor_quant * max_val_tensor/127
    return(tensor_quant)


def compute_MSE(org_image, decode_image):
    image1 = preprocess(org_image)
    image2 = preprocess(decode_image)
    loss = nn.MSELoss(reduction='none')
    loss_result = torch.sum(loss(image1,image2))
    return(loss_result/(org_image.size[0]*org_image.size[1]*3))


def convertint8(image):
    w, h = image.size
    #w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
    #image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
    image = np.array(image).astype(np.uint8)
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return image

def compute_FID(fid_module, predict_array, target_array):
    images = [convertint8(image) for image in target_array]
    target_images = torch.cat(images, 0)
    images = [convertint8(image) for image in predict_array]

    predict_images = torch.cat(images, 0)
    fid_module.update(target_images, real=True)
    fid_module.update(predict_images, real=False)
    return(fid_module.compute())


In [11]:
from glob import glob

#file_type and path_name should match
path_name ="/content/gdrive/MyDrive/diffusion/LDM/dataset/train/lr_jpg/labelA/*.jpg"
file_type = "/lr_jpg/"

counter = 0
for jpg_file_name in glob(path_name):
    jpg_image = Image.open(jpg_file_name).convert("RGB")
    ldm_out_image =  pipeline(jpg_image, num_inference_steps=3, eta=0.1).images[0]

    ldm_file_name = "/content/gdrive/MyDrive/diffusion/LDM/sample_photo_LDM_3.png"
    ldm_out_image.save(ldm_file_name,format="PNG")
    if counter == 0:
       break




  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
from glob import glob

#file_type and path_name should match
path_name ="/content/gdrive/MyDrive/diffusion/LDM/dataset/train/lr_jpg/labelA/*.jpg"
file_type = "/lr_jpg/"
def ldm_compression(path_name, file_type):

  for jpg_file_name in glob(path_name):
    jpg_image = Image.open(jpg_file_name).convert("RGB")
    ldm_out_image =  pipeline(jpg_image, num_inference_steps=20, eta=0.1).images[0]



    ldm_file_name = jpg_file_name.replace("/train/", "/predict/")
    ldm_file_name = ldm_file_name.replace(file_type, "/ldm/png/")
    ldm_file_name = ldm_file_name.replace(".jpg", ".png")
    ldm_out_image.save(ldm_file_name,format="PNG")

ldm_compression(path_name, file_type)

In [10]:
batch_size = 20
file_type = "/png/"
png_file_names = []
for png_file_name in glob("/content/gdrive/MyDrive/diffusion/LDM/dataset/predict/ldm"+ file_type +"labelA/*.png"):
  png_file_names.append(png_file_name)

fid_module.reset()
for count in range(len(png_file_names)//batch_size):
    target_images = []
    predict_images = []
    for i in range(batch_size):
      png_file_name = png_file_names[count*batch_size+i]
      predict_images.append(Image.open(png_file_name).convert("RGB"))
      target_file_name = png_file_name.replace("/predict/", "/train/")
      target_file_name = target_file_name.replace("/ldm"+file_type, file_type)
      target_images.append(Image.open(target_file_name).convert("RGB"))
    print(compute_FID(fid_module, predict_images, target_images))

tensor(0.0893)
tensor(0.0864)
tensor(0.0833)
tensor(0.0778)
tensor(0.0781)
tensor(0.0761)
tensor(0.0733)
tensor(0.0736)
tensor(0.0719)
tensor(0.0703)
tensor(0.0703)
tensor(0.0684)
tensor(0.0661)
tensor(0.0643)
tensor(0.0623)
tensor(0.0613)
tensor(0.0604)
tensor(0.0591)
tensor(0.0579)
tensor(0.0569)
tensor(0.0555)
tensor(0.0546)
tensor(0.0535)
tensor(0.0521)
tensor(0.0509)
tensor(0.0503)
tensor(0.0490)
tensor(0.0483)
tensor(0.0478)
tensor(0.0472)
tensor(0.0464)
tensor(0.0455)
tensor(0.0448)
tensor(0.0442)
tensor(0.0435)
tensor(0.0429)
tensor(0.0421)
tensor(0.0417)
tensor(0.0413)
tensor(0.0405)
tensor(0.0400)
tensor(0.0396)
tensor(0.0391)
tensor(0.0386)
tensor(0.0382)
tensor(0.0378)
tensor(0.0374)
tensor(0.0371)
tensor(0.0369)
tensor(0.0366)


In [None]:
from glob import glob
target_images = []
predict_images = []
for jpg_file_name in glob('/content/gdrive/MyDrive/diffusion/LDM/dataset/predict/jpg/labelA/*.jpg'):
    predict_images.append(Image.open(jpg_file_name).convert("RGB"))
    #output_image =  pipeline(lr_jpg_image, num_inference_steps=20, eta=0.1).images[0]

    target_file_name = jpg_file_name.replace("/predict/", "/train/")
    target_images.append(Image.open(target_file_name).convert("RGB"))


In [None]:
from glob import glob
target_images = []
predict_images = []
for jpg_file_name in glob('/content/gdrive/MyDrive/diffusion/LDM/dataset/train/lr_jpg/labelA/*.jpg'):
    predict_images.append(Image.open(jpg_file_name).convert("RGB"))
    #output_image =  pipeline(lr_jpg_image, num_inference_steps=20, eta=0.1).images[0]

    #target_file_name = jpg_file_name.replace("/predict/", "/train/")
    target_file_name = jpg_file_name.replace("/lr_jpg/", "/jpg/")
    target_images.append(Image.open(target_file_name).convert("RGB"))


In [None]:
predict_images = pred_images
images = [convertint8(image) for image in target_images]
target_img = torch.cat(images, 0)
images = [convertint8(image) for image in predict_images]
pred_img = torch.cat(images, 0)
fid_module = FrechetInceptionDistance(feature=768)
fid_module.update(target_img, real=True)
fid_module.update(pred_img, real=False)
print(fid_module.compute())
del fid_module

tensor(0.0521)


In [None]:
def split_image(image, width_splits, height_splits):
   image_array = []
   width, height = image.size
   split_width = width//width_splits
   split_height = height//height_splits

   h_start = 0
   h_end = split_height

   for i in range(height_splits):
      w_start = 0
      w_end = split_width
      for j in range(width_splits):
          #print(w_start, w_end, h_start, h_end)
          crop_image = image.crop((w_start, h_start, w_end, h_end))
          image_array.append(crop_image)
          w_start = w_end
          w_end = w_end+split_width
          if (w_end > width):
             w_end = width

      h_start = h_end
      h_end = h_end+split_height
      if (h_end > height):
          h_end = height

   return(image_array)


def join_images(image_array, width_splits, height_splits):

   total_width = 0
   total_height = 0


   for i in range(width_splits):
       img_width, img_height = image_array[i].size
       total_width = total_width + img_width

   for i in range(height_splits):
       img_width, img_height = image_array[i*width_splits].size
       total_height = total_height + img_height

   print(total_width, total_height)

   image = Image.new('RGB', (total_width, total_height))
   h_start = 0
   for i in range(height_splits):
      w_start = 0
      for j in range(width_splits):
          image.paste(image_array[j+i*width_splits], (w_start, h_start))
          img_width, img_height = image_array[j+i*width_splits].size
          w_start = w_start+img_width


      h_start = h_start + img_height


   return(image)