In [1]:
import zlib

import numpy as np
import PIL
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from compressai.entropy_models import EntropyBottleneck
from compressai.layers import GDN1
from compressai.models import CompressionModel
from datasets import load_dataset
from evaluate import evaluator
from fastprogress.fastprogress import master_bar, progress_bar
from IPython.display import display
from torch.utils.data import DataLoader
from torchvision.transforms import ToPILImage
from transformers import AutoConfig, AutoImageProcessor, AutoModelForImageClassification, pipeline

2023-11-22 08:18:31.302064: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-22 08:18:31.302109: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-22 08:18:31.302158: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [None]:
# Load every tensor onto the CPU, whatever device it was saved from. A dict such as
# {'cuda:0': 'cpu'} would only remap tensors saved on cuda:0.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
checkpoint = torch.load("checkpoint_dino_rdae_215.pth", map_location="cpu")

In [None]:
# Validation split of the feature dataset, with columns returned as torch tensors.
# Presumably precomputed DINOv2 features: the rows are read via a 'patch_tokens' column.
imagenet_valid = load_dataset(
    "danjacobellis/imagenet_dino",
    split='validation',
).with_format("torch")

In [None]:
def conv(in_channels, out_channels, kernel_size=5, stride=2, groups=32):
    """Build a grouped ``nn.Conv2d`` padded with ``kernel_size // 2``.

    For odd kernel sizes this "same"-style padding makes the output spatial size
    ``ceil(H / stride)``. ``in_channels`` and ``out_channels`` must both be
    divisible by ``groups``.
    """
    half_kernel = kernel_size // 2
    layer_options = dict(
        kernel_size=kernel_size,
        stride=stride,
        padding=half_kernel,
        groups=groups,
    )
    return nn.Conv2d(in_channels, out_channels, **layer_options)

In [None]:
def deconv(in_channels, out_channels, kernel_size=5, stride=2, groups=32):
    """Build a grouped ``nn.ConvTranspose2d`` that mirrors ``conv``.

    It uses padding ``kernel_size // 2`` and ``output_padding = stride - 1``. For odd
    kernel sizes the output spatial size is then exactly ``H * stride``, which undoes
    the down-sampling of ``conv`` on even-sized inputs.
    """
    extra_rows = stride - 1
    return nn.ConvTranspose2d(
        in_channels, out_channels, kernel_size,
        stride=stride,
        padding=kernel_size // 2,
        output_padding=extra_rows,
        groups=groups,
    )

In [None]:
class RateDistortionAutoEncoder(CompressionModel):
    """Learned lossy codec for a 2-D grid of feature vectors.

    Encoder:
        1x1 grouped conv (``in_channels`` -> ``N``), GDN1, then a 5x5 stride-2
        grouped conv (``N`` -> ``N``) that halves the spatial size (rounding up).
    Decoder (mirror image):
        5x5 stride-2 grouped transposed conv, inverse GDN1, then a 1x1 grouped
        transposed conv back to ``in_channels``.

    ``forward`` passes the latent through an ``EntropyBottleneck`` with ``N`` channels
    to obtain likelihoods for a rate term.

    The attribute names ``entropy_bottleneck``, ``encode`` and ``decode`` and the layer
    order inside them determine the ``state_dict`` keys. Keep them unchanged so that
    saved checkpoints still load.

    Args:
        N: latent channel count. Must be divisible by 32, the ``groups`` default
            used by ``conv``/``deconv``.
        in_channels: channel count of the input and of the reconstruction. Must also
            be divisible by 32. Defaults to 1536, the feature width used in this notebook.
    """

    def __init__(self, N=4096, in_channels=1536):
        super().__init__()
        self.entropy_bottleneck = EntropyBottleneck(N)
        self.encode = nn.Sequential(
            conv(in_channels, N, kernel_size=1, stride=1),
            GDN1(N),
            conv(N, N, kernel_size=5, stride=2),
        )

        self.decode = nn.Sequential(
            deconv(N, N, kernel_size=5, stride=2),
            GDN1(N, inverse=True),
            deconv(N, in_channels, kernel_size=1, stride=1),
        )

    def forward(self, x):
        """Encode ``x``, pass the latent through the entropy bottleneck, and decode.

        Args:
            x: feature map, presumably ``(batch, in_channels, H, W)``.

        Returns:
            ``(x_hat, y_likelihoods)``: the reconstruction and the bottleneck's
            likelihoods for the latent ``y``.
        """
        y = self.encode(x)
        # compressai's EntropyBottleneck adds uniform noise in training mode and
        # quantizes in eval mode. Call .eval() on the model for inference.
        y_hat, y_likelihoods = self.entropy_bottleneck(y)
        x_hat = self.decode(y_hat)
        return x_hat, y_likelihoods

In [None]:
net = RateDistortionAutoEncoder()
net.load_state_dict(checkpoint['model_state_dict'])
# The codec is used for inference only. eval() turns off training-mode behaviour:
# compressai's EntropyBottleneck adds noise instead of quantizing while training.
# The trailing `;` stops the notebook from displaying the whole module repr.
net.eval();

In [None]:
def lossy_compress_patch_tokens(sample, model=None, grid_size=16, zlib_level=9):
    """Round-trip a sample's patch tokens through the learned codec and record the rate.

    Steps:
        1. Lay the tokens out as a ``grid_size x grid_size`` feature map.
        2. Encode them with ``model.encode``.
        3. Quantize to int8: round, then clamp to [-128, 127].
        4. Pack losslessly with zlib.
        5. Unpack the bytes again and decode with ``model.decode``, so the sample
           carries the reconstruction a receiver of the compressed bytes would see.

    Args:
        sample: dict-like row with a ``'patch_tokens'`` tensor holding
            ``grid_size * grid_size`` tokens. The token width is the last dimension,
            e.g. ``(256, 1536)`` or ``(1, 256, 1536)`` with the defaults
            (TODO confirm against the dataset schema).
        model: object providing ``encode`` and ``decode`` callables. Defaults to the
            global ``net``.
        grid_size: side length of the square token grid (16 -> 256 tokens).
        zlib_level: zlib compression level, 0-9.

    Returns:
        ``sample``, modified in place:
            ``'patch_tokens'``: replaced by the reconstruction, with the same shape
                as the input tokens.
            ``'bps'``: compressed bits per input element.
    """
    if model is None:
        model = net
    x = sample['patch_tokens']
    channels = x.shape[-1]
    with torch.no_grad():
        # (..., tokens, C) -> (1, C, H, W): the codec is convolutional over the token grid.
        x_grid = x.reshape((1, grid_size, grid_size, channels)).permute((0, 3, 1, 2))
        z = model.encode(x_grid).round().clamp(-128, 127).to(torch.int8).numpy()
        compressed = zlib.compress(z.tobytes(), level=zlib_level)
        # Decode from the compressed bytes rather than from `z`, so the reconstruction
        # reflects exactly what was stored.
        z_hat = np.frombuffer(zlib.decompress(compressed), dtype=np.int8).reshape(z.shape)
        z_hat = torch.tensor(z_hat).to(torch.float)  # torch.tensor copies the read-only buffer
        x_hat = model.decode(z_hat)
        # Back to token layout. Restore the caller's original shape instead of a
        # hard-coded one, so downstream code sees the layout it passed in.
        x_hat = x_hat.permute((0, 2, 3, 1)).reshape(x.shape)
    sample['patch_tokens'] = x_hat
    sample['bps'] = 8 * len(compressed) / x.numel()
    return sample

In [None]:
# NOTE(review): this rebinds `imagenet_valid`. Re-running this cell alone would compress
# the already-reconstructed tokens a second time, so re-run from the load_dataset cell.
imagenet_valid = imagenet_valid.map(function=lossy_compress_patch_tokens)

In [3]:
# Presumably the Hugging Face `evaluate` image-classification evaluator. Make sure
# `from evaluate import evaluator` has run in the import cell first.
task_evaluator = evaluator(task="image-classification")

In [4]:
# DINOv2-giant image-classification checkpoint, wrapped in a Hugging Face pipeline.
DINO_CLASSIFIER_ID = 'facebook/dinov2-giant-imagenet1k-1-layer'
processor = AutoImageProcessor.from_pretrained(DINO_CLASSIFIER_ID)
model = AutoModelForImageClassification.from_pretrained(DINO_CLASSIFIER_ID)
pipe = pipeline(
    task="image-classification",
    model=model,
    image_processor=processor,
    # Use the first GPU when one is available; otherwise fall back to CPU instead of
    # failing on a GPU-less host.
    device="cuda:0" if torch.cuda.is_available() else "cpu",
)

In [5]:
# ImageNet label-name -> class-id mapping, passed to the evaluator as `label_mapping`.
# Only the model config holds it, so load the config rather than downloading and
# instantiating the full ResNet-50 weights.
label2id = AutoConfig.from_pretrained("microsoft/resnet-50").label2id

In [6]:
%%time
results = task_evaluator.compute(
                model_or_pipeline=pipe,
                data=imagenet,
                metric="accuracy",
                label_mapping=label2id)

CPU times: user 2h 22min 10s, sys: 13.8 s, total: 2h 22min 23s
Wall time: 27min 35s
