In [1]:
from datasets import load_dataset
from utils import jpeg_compress, webp_compress, pad, crop, nn_compress, hific_lo_compress
import compressai
from evaluate import evaluator
from transformers import pipeline
import torch
import numpy as np
import matplotlib.pyplot as plt
from piq import LPIPS
from torch.nn import MSELoss

2023-10-26 15:33:14.584539: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-10-26 15:33:14.603349: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-10-26 15:33:14.603364: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-10-26 15:33:14.603380: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-10-26 15:33:14.607127: I tensorflow/core/platform/cpu_feature_g

In [2]:
def jpeg_compress_xray(sample):
    img,bpp = jpeg_compress(sample['image'])
    sample['image'] = img
    sample['bpp' ] = bpp
    return sample
    
def webp_compress_xray(sample):
    img,bpp = webp_compress(sample['image'])
    sample['image'] = img
    sample['bpp' ] = bpp
    return sample

net_mbt2018 = compressai.zoo.mbt2018(quality=1, pretrained=True).eval().to("cpu")
def mbt2018_compress_xray(sample):
    img,bpp = nn_compress(sample['image'],net_mbt2018,"cpu")
    sample['image'] = img
    sample['bpp'] = bpp
    return sample

def hific_lo_compress_xray(sample):
    img,bpp = hific_lo_compress(sample['image'])
    sample['image'] = img
    sample['bpp' ] = bpp
    return sample

In [3]:
image_compression_methods = [
    jpeg_compress_xray,
    webp_compress_xray,
    mbt2018_compress_xray,
    hific_lo_compress_xray
]

In [4]:
xray = [load_dataset("keremberke/chest-xray-classification", "full", split="validation")]

In [5]:
for method in image_compression_methods:
    xray.append(xray[0].map(method))

Map:   0%|          | 0/1165 [00:00<?, ? examples/s]

In [6]:
image_bpp = [np.mean(method['bpp']) for method in xray[1:]]
image_bpp

[0.16469623457618024,
 0.045985783261802575,
 0.03231692596566524,
 0.010857920634388412]

In [7]:
lpips_metric = LPIPS()
mse_metric = MSELoss()



In [9]:
lpips = []
mse = []
for method in xray:
    method = method.with_format("torch")
    lpips.append([])
    mse.append([])
    for i_sample,sample in enumerate(xray[0].with_format("torch")):
        compressed_sample = method[i_sample]
        reference = sample['image'].to(torch.float32)
        if (reference.shape[-1] != 3):        
            reference = reference.unsqueeze(2)
        reference = reference.permute(2,0,1).unsqueeze(0)
        distorted = compressed_sample['image'].to(torch.float32)
        if (distorted.shape[-1] != 3):        
            distorted = distorted.unsqueeze(2)
        distorted = distorted.permute(2,0,1).unsqueeze(0)

        if reference.shape[1] != distorted.shape[1]:
            distorted = transforms.functional.rgb_to_grayscale(distorted)
        
        lpips[-1].append(lpips_metric(reference,distorted).detach().item())
        mse[-1].append(mse_metric(reference,distorted).detach().item())

In [10]:
PSNR = [[20*np.log10(255)-10*np.log10(d) for d in m] for m in mse]
PSNR = [np.mean(d) for d in PSNR]
PSNR

  PSNR = [[20*np.log10(255)-10*np.log10(d) for d in m] for m in mse]


[inf,
 30.006285440033686,
 32.69496335930339,
 34.703742060097206,
 36.44207814200317]

In [11]:
neg_log_lpips = [[-10*np.log10(d) for d in m] for m in lpips]
neg_log_lpips = [np.mean(d) for d in neg_log_lpips]
neg_log_lpips

  neg_log_lpips = [[-10*np.log10(d) for d in m] for m in lpips]


[inf, 6.851561133139928, 7.5847302963319, 7.799718846584859, 13.24957840692775]