<!-- Copyright 2020 InterDigital Communications, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. -->

# CompressAI Experiments

This notebook shows some experiments done with the VAE implementation in CompressAI.

In [None]:
import torch
from torchvision import transforms
import numpy as np

from PIL import Image
import matplotlib.pyplot as plt

from evaluation_functions import *

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

## Load a model

In [None]:
from compressai.zoo import bmshj2018_hyperprior

inference_ready_model_path = #...

net = bmshj2018_hyperprior(quality=5, pretrained=False)
net.load_state_dict(torch.load(inference_ready_model_path))
net.eval()
net.to(device)
print(f"Parameters: {sum(p.numel() for p in net.parameters())}")

In [None]:
import matplotlib

# Select the PGF backend, then apply the rc settings used for the PGF export.
matplotlib.use("pgf")
matplotlib.rcParams.update(
    dict(
        [
            ("pgf.texsystem", "pdflatex"),
            ("font.family", "serif"),
            ("text.usetex", True),
            ("pgf.rcfonts", False),
        ]
    )
)

## Inference

### Load image and convert to 4D float tensor

In [None]:
image_path = #...

# Load the image as RGB and remember its original size for display later.
original_image = Image.open(image_path).convert("RGB") 
shape_original = original_image.size
# The network is fed a fixed 256x256 resize of the original image.
shape_input = (256, 256)
input_image = original_image.resize(shape_input)  

%matplotlib inline
plt.figure()
plt.axis("off")
# Show the network input resized back to the original image size.
plt.imshow(input_image.resize(shape_original))
plt.show()

In [None]:
# Size of the original image as reported by PIL's Image.size.
print(shape_original)

### Run the network

In [None]:
# Convert the resized PIL image to a tensor, add a leading batch dimension
# and move it to the selected device.
input_tensor = transforms.ToTensor()(input_image).unsqueeze(0).to(device)

with torch.no_grad():
    # Call the module itself rather than .forward() so that any registered
    # forward hooks are run as well.
    output_net = net(input_tensor)
# Clamp the reconstruction in place to the [0, 1] range.
output_net["x_hat"].clamp_(0, 1)
output_tensor = output_net["x_hat"]

### Visualize result

In [None]:
# Drop only the batch dimension (dim 0); a bare squeeze() would also remove
# any other dimension that happens to have size 1.
output_image = transforms.ToPILImage()(output_tensor.squeeze(0).cpu())
# Absolute reconstruction error averaged over dim 1 (the channel dimension of
# the batched ToTensor output), giving a single-channel difference map.
diff = transforms.ToPILImage()(
    (output_tensor - input_tensor).abs().mean(dim=1).squeeze(0).cpu()
)

In [None]:
# Figure size passed to plt.subplots for the three-panel comparison plot.
figsize_3_images = (5, 4)
# Global font size for the figures that follow.
matplotlib.rcParams['font.size'] = 9

from matplotlib import gridspec

In [None]:
%matplotlib inline
fig, axes = plt.subplots(1, 3, figsize=figsize_3_images)
for ax in axes:
    ax.axis("off")
    
axes[0].imshow(input_image.resize(shape_original))
axes[0].title.set_text("Original")

axes[1].imshow(output_image.resize(shape_original))
axes[1].title.set_text("Reconstructed")

axes[2].imshow(diff.resize(shape_original), cmap="viridis")
axes[2].title.set_text("Difference")

plt.savefig("./for_latex/Orig_Recon_Diff_plot.pgf")

plt.show()
plt.close()


### Compute metrics

Now, let's compute some common metrics...

In [None]:
# Report the three metrics, one per line, for the network reconstruction.
print(
    f"PSNR: {compute_psnr(input_tensor, output_tensor):.2f}dB",
    f"MS-SSIM: {compute_msssim(input_tensor, output_tensor):.4f}",
    f"Bit-rate: {compute_bpp(output_net):.3f} bpp",
    sep="\n",
)


## Comparison to classical codecs

### Quality comparison at similar bit-rate

In [None]:
# The network's bit-rate is the target that the JPEG comparison is matched to.
target_bpp = compute_bpp(output_net)
net_msssim = compute_msssim(input_tensor, output_tensor)
# presumably returns the JPEG reconstruction whose bit-rate is closest to
# target_bpp, together with that bit-rate — confirm in evaluation_functions.
rec_jpeg, bpp_jpeg = find_closest_bpp(target_bpp, input_image) 
rec_jpeg_msssim = compute_msssim(input_tensor, transforms.ToTensor()(rec_jpeg).unsqueeze(0).to(device))
gan_image_path = #...
gan_image_bpp = #...
# NOTE(review): unlike original_image, gan_image is opened without
# .convert("RGB"); if the file is not RGB its tensor will not have the same
# number of channels as input_tensor — confirm.
gan_image = Image.open(gan_image_path)
rec_gan_msssim = compute_msssim(input_tensor, transforms.ToTensor()(gan_image.resize(shape_input)).unsqueeze(0).to(device))

# 2x2 grid: original and JPEG on the top row, VAE and GAN on the bottom row.
fig = plt.figure()
gs = gridspec.GridSpec(2, 2) #, wspace=0.05, bottom=0.3) 

ax = plt.subplot(gs[0,0])
ax.imshow(original_image.resize(shape_original))
ax.set_title("Original")
ax.set_axis_off()
ax = plt.subplot(gs[0,1])
ax.imshow(rec_jpeg.resize(shape_original))
ax.set_title(f"Compressed by JPEG ({bpp_jpeg:.3f} bpp)")
ax.set_axis_off()
ax = plt.subplot(gs[1,0])
ax.imshow(output_image.resize(shape_original))
ax.set_title(f"Reconstruction VAE ({target_bpp:.3f} bpp)")
ax.set_axis_off()
ax = plt.subplot(gs[1,1])
ax.imshow(gan_image.resize(shape_original))
ax.set_title(f"Reconstruction GAN ({gan_image_bpp:.3f} bpp)")
ax.set_axis_off()

plt.savefig("./for_latex/Compare_Compressions_with_JPEG_Same_BPP.pgf", dpi=400, bbox_inches='tight')

plt.show()
#plt.close()

### Script for out of domain plot

In [None]:
# Disabled cell: the whole script below sits inside a string literal, so none
# of it executes.
# NOTE(review): the literal opens with four quote characters, so the string
# content starts with a stray double quote.
# NOTE(review): the first subplot shows `original_image`, not
# `original_image_lion` — looks like a copy-paste slip; confirm before
# re-enabling this script.
""""
matplotlib.rcParams.update({'font.size': 6})
def make_prediction(image_path):
    original_image = Image.open(image_path).convert("RGB") 
    shape_original = original_image.size
    shape_input = (256, 256)#(384, 192) #(1152,384)
    input_image = original_image.resize(shape_input)  
    input_tensor = transforms.ToTensor()(input_image).unsqueeze(0).to(device)
    with torch.no_grad():
        output_net = net.forward(input_tensor)
    output_net["x_hat"].clamp_(0, 1)
    output_tensor = output_net["x_hat"]
    output_image = transforms.ToPILImage()(output_tensor.squeeze().cpu())
    target_bpp = compute_bpp(output_net)
    net_msssim = compute_msssim(input_tensor, output_tensor)
    rec_jpeg, bpp_jpeg = find_closest_bpp(target_bpp, input_image)

    return (original_image, shape_original, output_image, output_tensor, target_bpp, net_msssim, rec_jpeg, bpp_jpeg)

image_path_lion = #..
(original_image_lion,
    shape_original_lion,
    output_image_lion, 
    output_tensor_lion, 
    target_bpp_lion, 
    net_msssim_lion,
    rec_jpeg_lion, 
    bpp_jpeg_lion)  = make_prediction(image_path_lion)
gan_image_lion = Image.open(#..) 


image_path_night = #..
(original_image_night,
    shape_original_night,
    output_image_night, 
    output_tensor_night, 
    target_bpp_night, 
    net_msssim_night,
    rec_jpeg_night, 
    bpp_jpeg_night)  = make_prediction(image_path_night)
gan_image_night = Image.open(#..)

fig = plt.figure(figsize=(7,2.5))
gs = gridspec.GridSpec(2, 4)
plt.tight_layout()

ax = plt.subplot(gs[0,0])
ax.imshow(original_image.resize(shape_original_lion))
ax.set_title("Original corner case")
ax.set_axis_off()
ax = plt.subplot(gs[0,1])
ax.imshow(rec_jpeg_lion.resize(shape_original_lion))
ax.set_title(f"Compressed by JPEG\n({bpp_jpeg_lion:.3f} bpp)")
ax.set_axis_off()
ax = plt.subplot(gs[0,2])
ax.imshow(output_image_lion.resize(shape_original_lion))
ax.set_title(f"Reconstruction VAE\n({target_bpp_lion:.3f} bpp)")
ax.set_axis_off()
ax = plt.subplot(gs[0,3])
ax.imshow(gan_image_lion.resize(shape_original_lion))
ax.set_title(f"Reconstruction GAN\n({#..:.3f} bpp)")
ax.set_axis_off()

ax = plt.subplot(gs[1,0])
ax.imshow(original_image_night.resize(shape_original_night))
ax.set_title("Original night scene")
ax.set_axis_off()
ax = plt.subplot(gs[1,1])
ax.imshow(rec_jpeg_night.resize(shape_original_night))
ax.set_title(f"Compressed by JPEG\n({bpp_jpeg_night:.3f} bpp)")
ax.set_axis_off()
ax = plt.subplot(gs[1,2])
ax.imshow(output_image_night.resize(shape_original_night))
ax.set_title(f"Reconstruction VAE\n({target_bpp_night:.3f} bpp)")
ax.set_axis_off()
ax = plt.subplot(gs[1,3])
ax.imshow(gan_image_night.resize(shape_original_night))
ax.set_title(f"Reconstruction GAN\n({#..:.3f} bpp)")
ax.set_axis_off()

plt.savefig("./for_latex/compare_cases_compression.pgf", dpi=400, bbox_inches='tight')
"""

## Inference Pipeline Test

In [None]:
# Round trip through the actual entropy coder: encode the input, then decode
# from the returned strings and shape, all without tracking gradients.
with torch.no_grad():
    compress = net.compress(input_tensor)
    decompress = net.decompress(compress["strings"], compress["shape"])

# List the keys of both result dictionaries, one line each.
print(compress.keys(), decompress.keys(), sep="\n")

In [None]:
# Uncompressed size: bytes of the resized input image as a NumPy array.
input_tensor_np = np.asarray(input_image)
input_size = input_tensor_np.nbytes / 1024
print(f"Inputs size: {input_size} KBytes")

# Compressed size: add up the real length of every encoded byte string.
# Packing the strings and the shape into one NumPy array (the previous
# approach) pads every entry to the widest item, so its nbytes does not
# reflect the payload size.
# NOTE(review): assumes compress["strings"] is a list of lists of bytes
# (one inner list per latent) — confirm against the model's compress().
string_bytes = sum(len(s) for group in compress["strings"] for s in group)
# The shape has to be transmitted as well; count it as a NumPy integer array.
shape_bytes = np.asarray(compress["shape"]).nbytes
output_size = (string_bytes + shape_bytes) / 1024
print(f"Compressed size: {output_size} KBytes")
print(f"Compression: {round(output_size / input_size * 100, 2)}%")

In [None]:
# Convert the decoded reconstruction to a PIL image and display it.
# NOTE(review): squeeze() drops every size-1 dimension, not only the batch one.
decrompress_image = transforms.ToPILImage()(decompress["x_hat"].squeeze().cpu())
plt.imshow(decrompress_image)

## Object Detection Test

In [None]:
import tensorflow_hub as hub
import tensorflow as tf

from PIL import ImageColor
from PIL import ImageDraw
from PIL import ImageFont
from PIL import ImageOps

In [None]:
def display_image(image):
  """Show `image` in a new 20x15 matplotlib figure with the grid disabled."""
  # The figure handle is not needed; pyplot keeps it as the current figure.
  plt.figure(figsize=(20, 15))
  plt.grid(False)
  plt.imshow(image)

def _text_size(font, text):
  """Return the (width, height) that `text` occupies when drawn with `font`."""
  # ImageFont.getsize() was removed in Pillow 10; getbbox() replaces it.
  # The right/bottom edges of the bounding box correspond to what getsize()
  # used to return, so the label layout stays the same.
  if hasattr(font, "getbbox"):
    _, _, right, bottom = font.getbbox(text)
    return right, bottom
  # Older Pillow releases without getbbox().
  return font.getsize(text)


def draw_bounding_box_on_image(image,
                               ymin,
                               xmin,
                               ymax,
                               xmax,
                               color,
                               font,
                               thickness=4,
                               display_str_list=()):
  """Adds a bounding box with stacked text labels to an image, in place.

  Args:
    image: PIL image to draw on; it is modified in place.
    ymin, xmin, ymax, xmax: Box edges as fractions of the image height/width
      (they are multiplied by the image size to get pixel coordinates).
    color: Colour of the box outline and of the label background.
    font: PIL font used to measure and draw the labels.
    thickness: Line width of the box outline.
    display_str_list: Label strings; the last one is drawn closest to the box.
  """
  draw = ImageDraw.Draw(image)
  im_width, im_height = image.size
  (left, right, top, bottom) = (xmin * im_width, xmax * im_width,
                                ymin * im_height, ymax * im_height)
  draw.line([(left, top), (left, bottom), (right, bottom), (right, top),
             (left, top)],
            width=thickness,
            fill=color)

  # If the total height of the display strings added to the top of the bounding
  # box exceeds the top of the image, stack the strings below the bounding box
  # instead of above.
  display_str_heights = [_text_size(font, ds)[1] for ds in display_str_list]
  # Each display_str has a top and bottom margin of 0.05x.
  total_display_str_height = (1 + 2 * 0.05) * sum(display_str_heights)

  if top > total_display_str_height:
    text_bottom = top
  else:
    text_bottom = top + total_display_str_height
  # Reverse list and print from bottom to top.
  for display_str in display_str_list[::-1]:
    text_width, text_height = _text_size(font, display_str)
    margin = np.ceil(0.05 * text_height)
    draw.rectangle([(left, text_bottom - text_height - 2 * margin),
                    (left + text_width, text_bottom)],
                   fill=color)
    draw.text((left + margin, text_bottom - text_height - margin),
              display_str,
              fill="black",
              font=font)
    # Each label occupies its text height plus a margin above and below, so
    # move up by that full amount; subtracting the margins made stacked labels
    # overlap.
    text_bottom -= text_height + 2 * margin


def draw_boxes(image, boxes, class_names, scores, max_boxes=10, min_score=0.1):
  """Overlay labeled boxes on an image with formatted scores and label names.

  Args:
    image: NumPy array holding the image; it is modified in place.
    boxes: Array whose rows unpack as (ymin, xmin, ymax, xmax).
    class_names: Sequence of bytes labels (decoded as ASCII), one per box.
    scores: Sequence of scores, one per box.
    max_boxes: Only the first `max_boxes` rows of `boxes` are considered.
    min_score: Boxes with a score below this value are skipped.

  Returns:
    The same `image` array, with the selected boxes drawn onto it.
  """
  colors = list(ImageColor.colormap.values())

  try:
    font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationSansNarrow-Regular.ttf",
                              25)
  except IOError:
    print("Font not found, using default font.")
    font = ImageFont.load_default()

  # Convert to a PIL image lazily and only once, draw every selected box on
  # it, then copy the pixels back a single time. Previously the whole image
  # went through this round trip for every box.
  image_pil = None
  for i in range(min(boxes.shape[0], max_boxes)):
    if scores[i] < min_score:
      continue
    if image_pil is None:
      image_pil = Image.fromarray(np.uint8(image)).convert("RGB")
    ymin, xmin, ymax, xmax = tuple(boxes[i])
    display_str = "{}: {}%".format(class_names[i].decode("ascii"),
                                   int(100 * scores[i]))
    # NOTE: hash() of bytes is randomized per interpreter run unless
    # PYTHONHASHSEED is set, so a class's colour can differ between runs.
    color = colors[hash(class_names[i]) % len(colors)]
    draw_bounding_box_on_image(
        image_pil,
        ymin,
        xmin,
        ymax,
        xmax,
        color,
        font,
        display_str_list=[display_str])

  # Leave the array untouched when no box passed the score threshold.
  if image_pil is not None:
    np.copyto(image, np.array(image_pil))
  return image

def run_detector(detector, img):
  """Run `detector` on `img` and display the image with its detections drawn.

  Args:
    detector: Callable that takes the batched, 0-1 scaled image tensor and
      returns a mapping whose values expose `.numpy()`.
    img: Image accepted by both `img_to_array` and `np.array`.
  """
  # Scale pixel values by 1/255 and add a leading batch axis.
  pixels = tf.keras.preprocessing.image.img_to_array(img)
  converted_img = tf.convert_to_tensor(pixels[tf.newaxis, ...] / 255.)

  # Convert every output tensor to a NumPy array.
  result = {}
  for key, value in detector(converted_img).items():
    result[key] = value.numpy()

  print("Found %d objects." % len(result["detection_scores"]))

  # Draw on a fresh array copy of the image.
  image_with_boxes = draw_boxes(
      np.array(img),
      result["detection_boxes"],
      result["detection_class_entities"],
      result["detection_scores"])

  display_image(image_with_boxes)

In [None]:
# TF Hub handle of the object detector (per the URL: Faster R-CNN with an
# Inception-ResNet v2 backbone, Open Images v4).
module_handle = "https://tfhub.dev/google/faster_rcnn/openimages_v4/inception_resnet_v2/1"
# Load the module and keep its "default" signature as the callable detector.
detector = hub.load(module_handle).signatures["default"]

In [None]:
# Detect objects on the resized network input.
run_detector(detector, input_image)

In [None]:
# Detect objects on the network's reconstruction, for comparison with the input.
run_detector(detector, output_image)

In [None]:
# Same preprocessing as run_detector, applied to the reconstruction so the raw
# detector output can be inspected directly.
converted_img = tf.convert_to_tensor(
    tf.keras.preprocessing.image.img_to_array(output_image)[tf.newaxis, ...] / 255.
)
result = dict(
    (key, value.numpy()) for key, value in detector(converted_img).items()
)
result