In [1]:
import torch
import torch.nn as nn
from torchvision import transforms
import zlib
import numpy as np
from compressai.entropy_models import EntropyBottleneck
from compressai.layers import GDN
from compressai.models import CompressionModel
from compressai.models.utils import conv, deconv
from datasets import load_dataset, Dataset
import PIL.Image as Image

In [2]:
class Network(CompressionModel):
    """Convolutional autoencoder with a learned entropy bottleneck on the latent.

    The analysis transform (``encode``) maps a 3-channel image to an N-channel
    latent through three ``conv`` layers separated by GDN. The synthesis
    transform (``decode``) mirrors it with ``deconv`` layers and inverse GDN,
    ending in 3 channels.

    Attribute names and the layer order inside each ``nn.Sequential``
    determine the state_dict keys (``encode.0.*``, ``decode.1.*``, ...), so
    they must stay in sync with saved checkpoints.

    Args:
        N: number of hidden/latent channels (default 128).
    """

    def __init__(self, N=128):
        super().__init__()
        self.entropy_bottleneck = EntropyBottleneck(N)

        # Analysis transform: image (3 channels) -> latent (N channels).
        analysis = [conv(3, N), GDN(N), conv(N, N), GDN(N), conv(N, N)]
        self.encode = nn.Sequential(*analysis)

        # Synthesis transform: latent (N channels) -> reconstruction (3 channels).
        synthesis = [
            deconv(N, N), GDN(N, inverse=True),
            deconv(N, N), GDN(N, inverse=True),
            deconv(N, 3),
        ]
        self.decode = nn.Sequential(*synthesis)

    def forward(self, x):
        """Encode ``x``, pass the latent through the entropy bottleneck, and decode.

        Returns:
            A tuple ``(x_hat, y_likelihoods)``: the reconstruction and the
            likelihoods the entropy bottleneck reports for the latent.
        """
        latent_hat, latent_likelihoods = self.entropy_bottleneck(self.encode(x))
        return self.decode(latent_hat), latent_likelihoods


In [3]:
def lossy_analysis_transform(img, model=None, device="cuda"):
    """Run the analysis transform on an image batch and quantize the latent to int8.

    Args:
        img: image tensor accepted by ``model.encode`` (e.g. a batch of the
            ``img_tensor`` values produced by ``prep_dataset``).
        model: object exposing an ``encode`` callable. Defaults to the global
            ``net`` so existing callers keep working unchanged.
        device: device the encoder runs on (default "cuda", as before).

    Returns:
        numpy.ndarray of dtype int8 holding the rounded latent. Values outside
        the int8 range are clamped to [-128, 127] rather than silently
        wrapping during the cast.
    """
    if model is None:
        model = net  # global model built in the model-loading cell
    int8_info = torch.iinfo(torch.int8)
    # Inference only: don't build an autograd graph for the encoder.
    with torch.no_grad():
        y = model.encode(img.to(device))
        z = y.round().clamp(int8_info.min, int8_info.max).to(torch.int8)
    return z.cpu().numpy()
    
def lossless_entropy_encode(z):
    """Losslessly compress an array's raw bytes with zlib at level 9.

    Returns a ``(compressed_bytes, shape)`` pair. The shape is kept so the
    array can be rebuilt after ``zlib.decompress``; the dtype is not stored
    and must be known by the decoder.
    """
    raw = z.tobytes()
    return zlib.compress(raw, level=9), z.shape

def prep_dataset(sample):
    """Add size metadata and a normalized float tensor to one dataset row.

    Intended for ``Dataset.map``. Reads ``sample['image']`` (a PIL image) and
    writes:
      * ``width`` / ``height``: the image size in pixels;
      * ``img_tensor``: the pixels as float32 scaled to [-0.5, 0.5].

    Any non-RGB image is converted to 3-channel RGB before conversion. This
    covers grayscale 'L', 'CMYK' and 'RGBA', as before, but also palette 'P',
    'LA', 'I', '1', and others. Every tensor therefore has the 3 channels
    that the encoder's first ``conv(3, N)`` expects.
    """
    img = sample['image']
    sample['width'] = img.width
    sample['height'] = img.height

    # Previously only L/CMYK/RGBA were converted, so other modes (e.g. 'P' or
    # 'LA') yielded tensors with the wrong channel count.
    if img.mode != 'RGB':
        img = img.convert('RGB')

    t = transforms.functional.pil_to_tensor(img)  # uint8 pixel values 0..255
    sample['img_tensor'] = t.to(torch.float) / 255 - 0.5
    return sample

In [4]:
# Build the model on the GPU, switch it to inference mode, and load the
# trained weights. The last expression's result (the load report) is displayed.
net = Network().to("cuda")
net.eval()  # used for inference only below; disable any training-mode behaviour
checkpoint = torch.load("checkpoint.pth", map_location="cuda")
net.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [5]:
%%time
dataset = load_dataset("imagenet-1k",split='train[0:250000]')
dataset = dataset.map(prep_dataset)
dataset = dataset.remove_columns('image')
dataset = dataset.with_format("torch")

Map:   0%|          | 0/250000 [00:00<?, ? examples/s]

CPU times: user 3h 15min 46s, sys: 15min 18s, total: 3h 31min 5s
Wall time: 49min 45s


In [6]:
# Image sizes present in the subset and how often each occurs.
width = dataset['width']
height = dataset['height']
MIN_COUNT = 1000  # keep only (width, height) sizes with at least this many examples

# One pass with return_counts, instead of re-scanning both columns once per
# unique size.
unique_pairs, pair_totals = torch.unique(
    torch.stack([width, height], dim=1), dim=0, return_counts=True
)
pair_counts = {
    (w, h): count
    for (w, h), count in zip(unique_pairs.tolist(), pair_totals.tolist())
}
# Most common sizes first. The sort is stable, so ties keep torch.unique's order.
sizes = sorted(pair_counts.items(), key=lambda x: x[1], reverse=True)

# Keep the frequent sizes. A filter, unlike scanning for the first count below
# MIN_COUNT, cannot index past the end when every size qualifies.
sizes = [(size, count) for size, count in sizes if count >= MIN_COUNT]

In [7]:
%%time
batch_size = 64
compressed_batch = []
label = []
latent_size = []
for size, count in sizes:
    w = size[0]; h = size[1]
    filtered = dataset.filter(lambda x: x['width']==w and x['height']==h)
    for i_batch in range(len(filtered)//batch_size):
        ind = range(i_batch * batch_size, (i_batch + 1) * batch_size)
        batch_img = filtered[ind]['img_tensor']
        z = lossy_analysis_transform(batch_img)
        compressed = [lossless_entropy_encode(z[i])[0] for i in range(batch_size)]
        batch_label = filtered[ind]['label']
        compressed_batch.append(compressed)
        label.append(batch_label)
        latent_size.append(z.shape)

Filter:   0%|          | 0/250000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/250000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/250000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/250000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/250000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/250000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/250000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/250000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/250000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/250000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/250000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/250000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/250000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/250000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/250000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/250000 [00:00<?, ? examples/s]

CPU times: user 5d 17h 14min 58s, sys: 1h 27min 3s, total: 5d 18h 42min 1s
Wall time: 7h 49min 5s


In [8]:
%%time
new_dataset = Dataset.from_dict({
    "compressed_batch" : compressed_batch,
    "label" : label,
    "latent_size" : latent_size})

CPU times: user 238 ms, sys: 716 ms, total: 953 ms
Wall time: 1.18 s


In [9]:
# Upload the batched, compressed training split to the Hugging Face Hub.
new_dataset.push_to_hub(
    repo_id="danjacobellis/imagenet_RDAE_batched_250k",
    split="train",
)

Pushing dataset shards to the dataset hub:   0%|          | 0/2 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/11 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/11 [00:00<?, ?ba/s]

In [10]:
new_dataset  # display the features and row count (one row per encoded batch)

Dataset({
    features: ['compressed_batch', 'label', 'latent_size'],
    num_rows: 2062
})

In [None]:
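# TODO: `compress_dataset` is not defined anywhere in this notebook; the train
# split is processed by prep_dataset plus the batching loop above. This cell also
# targets a different repo ("imagenet_RDAE_dry") than the train split. Update
# both before re-enabling.
# dataset = load_dataset("imagenet-1k",split='test')
# dataset = dataset.map(compress_dataset)
# dataset = dataset.remove_columns('image');
# dataset.push_to_hub("danjacobellis/imagenet_RDAE_dry",split='test')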

In [None]:
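# TODO: `compress_dataset` is not defined anywhere in this notebook; the train
# split is processed by prep_dataset plus the batching loop above. This cell also
# targets a different repo ("imagenet_RDAE_dry") than the train split. Update
# both before re-enabling.
# dataset = load_dataset("imagenet-1k",split='validation')
# dataset = dataset.map(compress_dataset)
# dataset = dataset.remove_columns('image');
# dataset.push_to_hub("danjacobellis/imagenet_RDAE_dry",split='validation')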