In [57]:
from datasets import load_dataset
from utils import jpeg_compress, webp_compress, pad, crop, nn_compress, hific_lo_compress
import compressai
from evaluate import evaluator
from transformers import pipeline
import torch
import numpy as np
import matplotlib.pyplot as plt
from piq import LPIPS
from torch.nn import MSELoss
from torchvision import transforms

In [2]:
def jpeg_compress_in(sample):
    img,bpp = jpeg_compress(sample['image'])
    sample['image'] = img
    sample['bpp' ] = bpp
    return sample
    
def webp_compress_in(sample):
    img,bpp = webp_compress(sample['image'])
    sample['image'] = img
    sample['bpp' ] = bpp
    return sample

net_mbt2018 = compressai.zoo.mbt2018(quality=1, pretrained=True).eval().to("cpu")
def mbt2018_compress_in(sample):
    img,bpp = nn_compress(sample['image'],net_mbt2018,"cpu")
    sample['image'] = img
    sample['bpp'] = bpp
    return sample

def hific_lo_compress_in(sample):
    img,bpp = hific_lo_compress(sample['image'])
    sample['image'] = img
    sample['bpp' ] = bpp
    return sample

In [3]:
image_compression_methods = [
    jpeg_compress_in,
    webp_compress_in,
    mbt2018_compress_in,
    hific_lo_compress_in
]

In [4]:
imagenet = [load_dataset("imagenet-1k", split="validation[:1000]")]

In [5]:
for method in image_compression_methods:
    imagenet.append(imagenet[0].map(method))

In [6]:
image_bpp = [np.mean(method['bpp']) for method in imagenet[1:]]
image_bpp

[0.26474601405625764,
 0.14787044451833736,
 0.14990276963116184,
 0.02639249681257391]

In [7]:
lpips_metric = LPIPS()
mse_metric = MSELoss()



In [61]:
lpips = []
mse = []
for method in imagenet:
    method = method.with_format("torch")
    lpips.append([])
    mse.append([])
    for i_sample,sample in enumerate(imagenet[0].with_format("torch")):
        compressed_sample = method[i_sample]
        reference = sample['image'].to(torch.float32)
        if (reference.shape[-1] != 3):        
            reference = reference.unsqueeze(2)
        reference = reference.permute(2,0,1).unsqueeze(0)
        distorted = compressed_sample['image'].to(torch.float32)
        if (distorted.shape[-1] != 3):        
            distorted = distorted.unsqueeze(2)
        distorted = distorted.permute(2,0,1).unsqueeze(0)

        if reference.shape[1] != distorted.shape[1]:
            distorted = transforms.functional.rgb_to_grayscale(distorted)
        
        lpips[-1].append(lpips_metric(reference,distorted).detach().item())
        mse[-1].append(mse_metric(reference,distorted).detach().item())

In [62]:
PSNR = [[20*np.log10(255)-10*np.log10(d) for d in m] for m in mse]
PSNR = [np.mean(d) for d in PSNR]
PSNR

  PSNR = [[20*np.log10(255)-10*np.log10(d) for d in m] for m in mse]


[inf,
 23.182460226047187,
 24.766241227347265,
 26.674174333189445,
 26.257512162905442]

In [63]:
neg_log_lpips = [[-10*np.log10(d) for d in m] for m in lpips]
neg_log_lpips = [np.mean(d) for d in neg_log_lpips]
neg_log_lpips

  neg_log_lpips = [[-10*np.log10(d) for d in m] for m in lpips]


[inf,
 6.109442391611871,
 7.017269687247636,
 7.945727968631091,
 10.834666055816148]