In [None]:
from datasets import load_dataset
from utils import jpeg_compress, webp_compress, pad, crop, nn_compress, hific_lo_compress
import compressai
from evaluate import evaluator
from transformers import pipeline
import torch
import numpy as np
import matplotlib.pyplot as plt
from piq import LPIPS
from torch.nn import MSELoss

In [None]:
def jpeg_compress_in(sample):
    """Run `jpeg_compress` on ``sample['image']`` and store the results in place.

    `jpeg_compress` returns a pair: the first value replaces ``sample['image']``
    and the second is stored under ``sample['bpp']`` (presumably bits per
    pixel -- confirm against `utils.jpeg_compress`).

    Returns the same (mutated) ``sample`` mapping.
    """
    compressed, rate = jpeg_compress(sample['image'])
    sample.update({'image': compressed, 'bpp': rate})
    return sample
    
def webp_compress_in(sample):
    """Run `webp_compress` on ``sample['image']`` and store the results in place.

    `webp_compress` returns a pair: the first value replaces ``sample['image']``
    and the second is stored under ``sample['bpp']`` (presumably bits per
    pixel -- confirm against `utils.webp_compress`).

    Returns the same (mutated) ``sample`` mapping.
    """
    compressed, rate = webp_compress(sample['image'])
    sample.update({'image': compressed, 'bpp': rate})
    return sample

# Pretrained `mbt2018` model from the compressai zoo at quality=1,
# switched to eval mode and placed on the CPU.
net_mbt2018 = compressai.zoo.mbt2018(quality=1, pretrained=True)
net_mbt2018 = net_mbt2018.eval().to("cpu")
def mbt2018_compress_in(sample):
    """Run `nn_compress` with the module-level `net_mbt2018` on ``sample['image']``.

    `nn_compress` is called with the image, the network and the device string
    "cpu". The first value it returns replaces ``sample['image']`` and the
    second is stored under ``sample['bpp']`` (presumably bits per pixel --
    confirm against `utils.nn_compress`).

    Returns the same (mutated) ``sample`` mapping.
    """
    compressed, rate = nn_compress(sample['image'], net_mbt2018, "cpu")
    sample.update({'image': compressed, 'bpp': rate})
    return sample

def hific_lo_compress_in(sample):
    """Run `hific_lo_compress` on ``sample['image']`` and store the results in place.

    `hific_lo_compress` returns a pair: the first value replaces
    ``sample['image']`` and the second is stored under ``sample['bpp']``
    (presumably bits per pixel -- confirm against `utils.hific_lo_compress`).

    Returns the same (mutated) ``sample`` mapping.
    """
    compressed, rate = hific_lo_compress(sample['image'])
    sample.update({'image': compressed, 'bpp': rate})
    return sample

In [None]:
# Per-sample compression functions, in the order their results are appended
# to `imagenet` (after the reference split at index 0).
image_compression_methods = [
    jpeg_compress_in, webp_compress_in,
    mbt2018_compress_in, hific_lo_compress_in,
]

In [None]:
# `imagenet[0]` is the reference split (first 1000 validation samples);
# compressed variants are appended after it by a later cell.
imagenet = [
    load_dataset("imagenet-1k", split="validation[:1000]"),
]

In [None]:
# Build one compressed variant of the reference split (`imagenet[0]`) per
# compression method, stored at imagenet[1:], in the same order as
# `image_compression_methods`.
# Slice assignment keeps this cell idempotent: re-running it replaces the
# previous variants instead of appending duplicates, which would misalign
# `imagenet[1:]` with `image_compression_methods` in later cells.
imagenet[1:] = [imagenet[0].map(method) for method in image_compression_methods]

In [None]:
# Mean of the 'bpp' column for every compressed variant (the reference split
# at index 0 is skipped).
image_bpp = [
    np.mean(compressed_split['bpp'])
    for compressed_split in imagenet[1:]
]
image_bpp

In [None]:
# Distortion metrics used by the evaluation loop, both with default settings.
lpips_metric, mse_metric = LPIPS(), MSELoss()

In [None]:
%%time
# lpips[k][i] / mse[k][i]: distortion of sample i in imagenet[k] measured
# against sample i of the reference split imagenet[0]. Entry k == 0 compares
# the reference with itself.
lpips = []
mse = []
# The reference is `imagenet[0]`, the split every variant was mapped from.
# Formatting it once here avoids redoing the loop-invariant call per variant.
reference_split = imagenet[0].with_format("torch")
# Results are only read out as Python floats, so no autograd graph is needed.
with torch.no_grad():
    for method in imagenet:
        method = method.with_format("torch")
        lpips.append([])
        mse.append([])
        for i_sample, sample in enumerate(reference_split):
            compressed_sample = method[i_sample]
            # permute(2,0,1) moves the last axis first and unsqueeze(0) adds a
            # leading axis (presumably HWC -> 1xCxHxW -- confirm image layout).
            # NOTE(review): tensors are cast to float32 without rescaling; if
            # pixel values are 0-255 (as the PSNR cell's peak of 255 suggests),
            # confirm this matches the input range piq's LPIPS expects.
            reference = sample['image'].to(torch.float32).permute(2, 0, 1).unsqueeze(0)
            distorted = compressed_sample['image'].to(torch.float32).permute(2, 0, 1).unsqueeze(0)
            lpips[-1].append(lpips_metric(reference, distorted).item())
            mse[-1].append(mse_metric(reference, distorted).item())

In [None]:
# Per-sample PSNR in dB from the MSE values, using a peak value of 255,
# then averaged over the samples of each dataset variant.
PSNR = [
    np.mean([20*np.log10(255) - 10*np.log10(sample_mse) for sample_mse in variant_mse])
    for variant_mse in mse
]
PSNR

In [None]:
# Per-sample -10*log10(LPIPS), then averaged over the samples of each
# dataset variant.
neg_log_lpips = [
    np.mean([-10*np.log10(sample_lpips) for sample_lpips in variant_lpips])
    for variant_lpips in lpips
]
neg_log_lpips