In [1]:
import torch
import torch.nn as nn
from torchvision import transforms
import zlib
import numpy as np
from compressai.entropy_models import EntropyBottleneck
from compressai.layers import GDN
from compressai.models import CompressionModel
from compressai.models.utils import conv, deconv
from datasets import load_dataset
import PIL.Image as Image

In [2]:
def pil_to_pt(img):
    """Turn a PIL image into a zero-centred float tensor with a batch axis.

    The image is converted with ``transforms.functional.pil_to_tensor``, cast
    to ``torch.float``, divided by 255 and shifted by -0.5, so pixel values
    0..255 map onto [-0.5, 0.5]. A new leading dimension of size 1 is added.

    Args:
        img: PIL image to convert.

    Returns:
        Float tensor holding the normalised pixels, with an extra leading
        (batch) dimension.
    """
    pixels = transforms.functional.pil_to_tensor(img).to(torch.float)
    # Same arithmetic order as scale-then-shift: (p / 255) - 0.5.
    normalised = pixels / 255 - 0.5
    return normalised.unsqueeze(0)
def pt_to_pil(t):
    """Undo the [-0.5, 0.5] normalisation and quantise to 8-bit values.

    NOTE(review): despite the name, this returns a ``torch.uint8`` tensor and
    not a PIL image; no PIL conversion and no batch-axis removal happens here.

    Args:
        t: Tensor whose nominal value range is [-0.5, 0.5].

    Returns:
        ``torch.uint8`` tensor of the same shape with values in 0..255.
    """
    scaled = (t + 0.5) * 255
    # Clamp just inside the rounding boundaries so rounding always lands in
    # 0..255 before the unsigned cast.
    bounded = torch.clamp(scaled, min=-0.49, max=255.49)
    return bounded.round().to(torch.uint8)

In [3]:
class Network(CompressionModel):
    """Convolutional autoencoder with an entropy bottleneck on the latent.

    ``encode`` is a stack of three ``conv`` layers with a ``GDN`` between each
    pair; ``decode`` mirrors it with three ``deconv`` layers and inverse
    ``GDN``. Attribute names and construction order are kept fixed because the
    saved checkpoint's state-dict keys are derived from them.

    Args:
        N: Number of channels used by every hidden layer and by the latent.
    """

    def __init__(self, N=128):
        super().__init__()
        self.entropy_bottleneck = EntropyBottleneck(N)

        # Analysis transform: 3 input channels -> N latent channels.
        analysis_layers = [
            conv(3, N),
            GDN(N),
            conv(N, N),
            GDN(N),
            conv(N, N),
        ]
        self.encode = nn.Sequential(*analysis_layers)

        # Synthesis transform: N latent channels -> 3 output channels.
        synthesis_layers = [
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, 3),
        ]
        self.decode = nn.Sequential(*synthesis_layers)

    def forward(self, x):
        """Run encode -> entropy bottleneck -> decode.

        Args:
            x: Input batch fed to ``encode``.

        Returns:
            Tuple ``(x_hat, y_likelihoods)``: the decoded reconstruction and
            the second output of the entropy bottleneck.
        """
        latent = self.encode(x)
        latent_hat, likelihoods = self.entropy_bottleneck(latent)
        return self.decode(latent_hat), likelihoods


In [4]:
def lossy_analysis_transform(img):
    """Encode a PIL image into an integer latent with the global ``net``.

    The image is normalised with ``pil_to_pt``, passed through ``net.encode``,
    rounded to the nearest integer and stored as ``int8``.

    Args:
        img: PIL image; ``net.encode`` starts with a 3-input-channel layer.

    Returns:
        ``numpy.ndarray`` of dtype ``int8`` holding the rounded latent.
    """
    # Follow the model's device instead of hard-coding one, so the input
    # always lands where the weights are.
    device = next(net.parameters()).device
    x = pil_to_pt(img).to(device)
    # Inference only: skip autograd bookkeeping instead of building a graph
    # and detaching it afterwards.
    with torch.no_grad():
        y = net.encode(x).round()
    # Saturate to the int8 range; casting out-of-range floats to an integer
    # dtype would otherwise silently corrupt the latent.
    y = y.clamp(min=-128, max=127)
    z = y.to(torch.int8).to("cpu").numpy()
    return z
    
def lossless_entropy_encode(z):
    """Losslessly pack a latent array with zlib at maximum compression.

    Args:
        z: Array exposing ``tobytes()`` and ``shape`` (a NumPy array in this
            notebook).

    Returns:
        Tuple ``(compressed_bytes, shape)`` where ``shape`` is ``z.shape``;
        the shape must be kept because the raw byte stream does not carry it.
    """
    payload = z.tobytes()
    return zlib.compress(payload, level=9), z.shape

def compress(img):
    """Compress one image: lossy analysis transform, then lossless packing.

    Args:
        img: PIL image passed to ``lossy_analysis_transform``.

    Returns:
        Tuple ``(compressed_bytes, latent_shape)`` exactly as produced by
        ``lossless_entropy_encode``.
    """
    latent = lossy_analysis_transform(img)
    return lossless_entropy_encode(latent)
    
def compress_dataset(sample):
    """Per-example callback for ``dataset.map``: attach the compressed image.

    Args:
        sample: Mapping with a PIL image under the ``'image'`` key.

    Returns:
        The same ``sample`` with two added keys: ``'compressed_image'`` (the
        zlib-compressed latent bytes) and ``'latent_shape'`` (shape of the
        latent before it was flattened to bytes).
    """
    img = sample['image']
    # The encoder takes 3-channel input, so every non-RGB mode has to be
    # converted, not only 'L' and 'CMYK'. convert('RGB') is the same
    # conversion that pasting into a blank RGB image performs.
    if img.mode != 'RGB':
        img = img.convert('RGB')

    compressed_image, original_shape = compress(img)
    sample['compressed_image'] = compressed_image
    sample['latent_shape'] = original_shape
    return sample

In [5]:
# Build the model on the GPU, then restore the trained weights from disk.
net = Network().to("cuda")
checkpoint = torch.load("checkpoint.pth")
# The checkpoint is a dict; the weights live under 'model_state_dict'.
net.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [None]:
# Train split: compress every image, drop the original pixels, publish.
dataset = (
    load_dataset("imagenet-1k", split='train')
    .map(compress_dataset)
    .remove_columns('image')
)
dataset.push_to_hub("danjacobellis/imagenet_RDAE", split='train')

In [None]:
# Test split: compress every image, drop the original pixels, publish.
dataset = (
    load_dataset("imagenet-1k", split='test')
    .map(compress_dataset)
    .remove_columns('image')
)
dataset.push_to_hub("danjacobellis/imagenet_RDAE", split='test')

In [None]:
# Validation split: compress every image, drop the original pixels, publish.
dataset = (
    load_dataset("imagenet-1k", split='validation')
    .map(compress_dataset)
    .remove_columns('image')
)
dataset.push_to_hub("danjacobellis/imagenet_RDAE", split='validation')