In [1]:
import torch
import torch.nn as nn
from torchvision import transforms
import zlib
import numpy as np
from compressai.entropy_models import EntropyBottleneck
from compressai.layers import GDN
from compressai.models import CompressionModel
from compressai.models.utils import conv, deconv
from datasets import load_dataset, Dataset, Image
import PIL

In [2]:
class Network(CompressionModel):
    """Image autoencoder built on compressai's ``CompressionModel``.

    The model has three parts:

    * ``encode``: three ``conv`` layers with a ``GDN`` after each of the
      first two, mapping a 3-channel input to an ``N``-channel latent.
    * ``entropy_bottleneck``: an ``EntropyBottleneck`` over the ``N``
      latent channels.
    * ``decode``: the mirror of ``encode`` using ``deconv`` layers and
      inverse ``GDN``, mapping the latent back to 3 channels.

    Args:
        N: channel count used for the latent and all hidden layers.
    """

    def __init__(self, N=128):
        super().__init__()
        self.entropy_bottleneck = EntropyBottleneck(N)

        # Analysis transform (image -> latent).
        analysis_layers = [
            conv(3, N),
            GDN(N),
            conv(N, N),
            GDN(N),
            conv(N, N),
        ]
        self.encode = nn.Sequential(*analysis_layers)

        # Synthesis transform (latent -> image); same layout, inverted.
        synthesis_layers = [
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, 3),
        ]
        self.decode = nn.Sequential(*synthesis_layers)

    def forward(self, x):
        """Run the full encode -> bottleneck -> decode pipeline.

        Returns:
            A ``(x_hat, y_likelihoods)`` tuple: the output of ``decode`` and
            the likelihoods returned by the entropy bottleneck.
        """
        latent = self.encode(x)
        latent_hat, latent_likelihoods = self.entropy_bottleneck(latent)
        reconstruction = self.decode(latent_hat)
        return reconstruction, latent_likelihoods


In [3]:
def lossy_analysis_transform(img):
    """Encode ``img`` with the global ``net`` and quantize the latent to int8.

    Args:
        img: tensor accepted by ``net.encode``; it is moved to the CUDA
            device before encoding.

    Returns:
        A NumPy ``int8`` array holding the rounded latent, on the CPU.
    """
    int8_range = torch.iinfo(torch.int8)
    # Inference only: skip autograd bookkeeping instead of building a graph
    # that is discarded right away.
    with torch.no_grad():
        latent = net.encode(img.to("cuda")).round()
        # Saturate before the cast: converting a float that does not fit in
        # int8 is not well defined, so out-of-range latents are clamped to
        # [-128, 127] rather than turned into arbitrary values.
        latent = latent.clamp(int8_range.min, int8_range.max)
    return latent.to(torch.int8).to("cpu").numpy()
    
def lossless_entropy_encode(z):
    """Compress the raw bytes of array ``z`` with zlib at level 9.

    Args:
        z: array-like object exposing ``tobytes()`` and ``shape``.

    Returns:
        ``(compressed_bytes, shape)``. The shape is returned alongside the
        payload because the byte stream itself does not carry it.
    """
    raw_bytes = z.tobytes()
    return zlib.compress(raw_bytes, 9), z.shape

def prep_dataset(sample):
    """Record an example's image size and normalise selected modes to RGB.

    Adds ``width`` and ``height`` entries to ``sample``. Images whose mode is
    ``L``, ``CMYK`` or ``RGBA`` are pasted onto a new RGB image of the same
    size; any other mode is passed through unchanged. The image is then
    stored back via ``Image().encode_example``.

    Args:
        sample: mapping with an ``'image'`` entry exposing ``width``,
            ``height``, ``mode`` and ``size``.

    Returns:
        The same ``sample`` mapping, modified in place.
    """
    picture = sample['image']
    sample['width'], sample['height'] = picture.width, picture.height

    if picture.mode in ('L', 'CMYK', 'RGBA'):
        canvas = PIL.Image.new("RGB", picture.size)
        canvas.paste(picture)
        picture = canvas

    sample['image'] = Image().encode_example(picture)
    return sample

In [4]:
# Build the model on the GPU and restore its weights. The checkpoint file is
# expected to be a dict with the weights under 'model_state_dict'.
checkpoint = torch.load("checkpoint.pth")
net = Network().to("cuda")
net.load_state_dict(checkpoint['model_state_dict'])
# Number of images packed into each row of the batched datasets built below.
batch_size = 64

# Training split

Only image sizes (width × height pairs) that occur at least 2048 times are kept for the training split.

In [None]:
%%time
dataset = load_dataset("imagenet-1k",split='train')
dataset = dataset.map(prep_dataset)

In [None]:
# Empty accumulator: one row per batch of same-sized images is appended later.
train_dataset = Dataset.from_dict(
    {column: [] for column in ("img_batch", "label_batch", "width", "height")}
)

In [None]:
# Find the (width, height) pairs that occur at least 2048 times in the split.
width = torch.tensor(dataset['width'])
height = torch.tensor(dataset['height'])
# return_counts yields the occurrences of every unique pair in one pass,
# instead of rescanning the whole dataset once per unique pair.
unique_pairs, counts = torch.unique(
    torch.stack([width, height], dim=1), dim=0, return_counts=True
)
pair_counts = {(w, h): c
               for (w, h), c in zip(unique_pairs.tolist(), counts.tolist())}
# Most frequent sizes first; entries look like ((width, height), count).
sizes = sorted(pair_counts.items(), key=lambda x: x[1], reverse=True)
# The list is sorted by count, so the qualifying sizes form a prefix. Counting
# them (rather than walking an index until the test fails) cannot run past
# the end of the list when every size qualifies or the list is empty.
N = sum(1 for _, count in sizes if count >= 2048)
sizes = sizes[:N]

In [None]:
%%time
for size, count in sizes:
    w = size[0]; h = size[1]
    filtered = dataset.filter(lambda x: x['width']==w and x['height']==h)
    for i_batch in range(len(filtered)//batch_size):
        ind = range(i_batch * batch_size, (i_batch + 1) * batch_size)
        img_batch = filtered[ind]['image']
        img_batch = [Image().encode_example(pil_img) for pil_img in img_batch] 
        label_batch = filtered[ind]['label']

        train_dataset = train_dataset.add_item({
            "img_batch" : img_batch,
            "label_batch" : label_batch,
            "width" : w,
            "height": h,
        })

In [None]:
# Upload the batched training rows to the Hub as the 'train' split.
train_dataset.push_to_hub(
    repo_id="danjacobellis/imagenet_batched_64",
    split='train',
)

# Test split

Only image sizes (width × height pairs) that occur at least 256 times are kept for the test split.

In [5]:
%%time
dataset = load_dataset("imagenet-1k",split='test')
dataset = dataset.map(prep_dataset)

Map:   0%|          | 0/100000 [00:00<?, ? examples/s]

CPU times: user 3min 19s, sys: 8.55 s, total: 3min 27s
Wall time: 3min 31s


In [6]:
# Empty accumulator: one row per batch of same-sized images is appended later.
test_dataset = Dataset.from_dict(
    {column: [] for column in ("img_batch", "label_batch", "width", "height")}
)

In [7]:
# Find the (width, height) pairs that occur at least 256 times in the split.
width = torch.tensor(dataset['width'])
height = torch.tensor(dataset['height'])
# return_counts yields the occurrences of every unique pair in one pass,
# instead of rescanning the whole dataset once per unique pair.
unique_pairs, counts = torch.unique(
    torch.stack([width, height], dim=1), dim=0, return_counts=True
)
pair_counts = {(w, h): c
               for (w, h), c in zip(unique_pairs.tolist(), counts.tolist())}
# Most frequent sizes first; entries look like ((width, height), count).
sizes = sorted(pair_counts.items(), key=lambda x: x[1], reverse=True)
# The list is sorted by count, so the qualifying sizes form a prefix. Counting
# them (rather than walking an index until the test fails) cannot run past
# the end of the list when every size qualifies or the list is empty.
N = sum(1 for _, count in sizes if count >= 256)
sizes = sizes[:N]

In [None]:
%%time
for size, count in sizes:
    w = size[0]; h = size[1]
    filtered = dataset.filter(lambda x: x['width']==w and x['height']==h)
    for i_batch in range(len(filtered)//batch_size):
        ind = range(i_batch * batch_size, (i_batch + 1) * batch_size)
        img_batch = filtered[ind]['image']
        img_batch = [Image().encode_example(pil_img) for pil_img in img_batch] 
        label_batch = filtered[ind]['label']

        test_dataset = test_dataset.add_item({
            "img_batch" : img_batch,
            "label_batch" : label_batch,
            "width" : w,
            "height": h,
        })

Filter:   0%|          | 0/100000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/100000 [00:00<?, ? examples/s]

In [None]:
# Upload the batched test rows to the Hub as the 'test' split.
test_dataset.push_to_hub(
    repo_id="danjacobellis/imagenet_batched_64",
    split='test',
)

# Validation split

Only image sizes (width × height pairs) that occur at least 64 times are kept for the validation split.

In [None]:
%%time
dataset = load_dataset("imagenet-1k",split='validation')
dataset = dataset.map(prep_dataset)

In [None]:
# Empty accumulator: one row per batch of same-sized images is appended later.
val_dataset = Dataset.from_dict(
    {column: [] for column in ("img_batch", "label_batch", "width", "height")}
)

In [None]:
# Find the (width, height) pairs that occur at least 64 times in the split.
width = torch.tensor(dataset['width'])
height = torch.tensor(dataset['height'])
# return_counts yields the occurrences of every unique pair in one pass,
# instead of rescanning the whole dataset once per unique pair.
unique_pairs, counts = torch.unique(
    torch.stack([width, height], dim=1), dim=0, return_counts=True
)
pair_counts = {(w, h): c
               for (w, h), c in zip(unique_pairs.tolist(), counts.tolist())}
# Most frequent sizes first; entries look like ((width, height), count).
sizes = sorted(pair_counts.items(), key=lambda x: x[1], reverse=True)
# The list is sorted by count, so the qualifying sizes form a prefix. Counting
# them (rather than walking an index until the test fails) cannot run past
# the end of the list when every size qualifies or the list is empty.
N = sum(1 for _, count in sizes if count >= 64)
sizes = sizes[:N]

In [None]:
%%time
for size, count in sizes:
    w = size[0]; h = size[1]
    filtered = dataset.filter(lambda x: x['width']==w and x['height']==h)
    for i_batch in range(len(filtered)//batch_size):
        ind = range(i_batch * batch_size, (i_batch + 1) * batch_size)
        img_batch = filtered[ind]['image']
        img_batch = [Image().encode_example(pil_img) for pil_img in img_batch] 
        label_batch = filtered[ind]['label']

        val_dataset = val_dataset.add_item({
            "img_batch" : img_batch,
            "label_batch" : label_batch,
            "width" : w,
            "height": h,
        })

In [None]:
# Upload the batched validation rows to the Hub as the 'val' split.
val_dataset.push_to_hub(
    repo_id="danjacobellis/imagenet_batched_64",
    split='val',
)