In [1]:
import PIL
from datasets import load_dataset
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as nnf
from torchvision.transforms import ToPILImage
from torchvision import transforms
from torchmetrics.classification import MulticlassJaccardIndex
import numpy as np
from transformers import AutoModel
from compressai.models import CompressionModel
from compressai.entropy_models import EntropyBottleneck
from compressai.layers import GDN1
import zlib

In [2]:
ade20k = load_dataset("scene_parse_150",split='validation')

In [3]:
class LinearSegmentationHead(nn.Module):
    def __init__(self, in_channels=1536, num_classes=150):
        super(LinearSegmentationHead, self).__init__()
        self.conv_seg = nn.Conv2d(in_channels, num_classes, kernel_size=1)

    def forward(self, x):
        x = self.conv_seg(x)
        return x

In [4]:
model = AutoModel.from_pretrained('facebook/dinov2-giant')
model = model.cuda()

In [5]:
def conv(in_channels, out_channels, kernel_size=5, stride=2, groups=32):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=kernel_size // 2,
        groups=groups
    )

In [6]:
def deconv(in_channels, out_channels, kernel_size=5, stride=2, groups=32):
    return nn.ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        output_padding=stride - 1,
        padding=kernel_size // 2,
        groups=groups
    )

In [7]:
class RateDistortionAutoEncoder(CompressionModel):
    def __init__(self, N=4096):
        super().__init__()
        self.entropy_bottleneck = EntropyBottleneck(N)
        self.encode = nn.Sequential(
            conv(1536, N, kernel_size=1, stride=1),
            GDN1(N),
            conv(N, N, kernel_size=5, stride=2),
            GDN1(N),
            conv(N, N, kernel_size=5, stride=2),
        )

        self.decode = nn.Sequential(
            deconv(N, N, kernel_size=5, stride=2),
            GDN1(N, inverse=True),
            deconv(N, N, kernel_size=5, stride=2),
            GDN1(N, inverse=True),
            deconv(N, 1536, kernel_size=1, stride=1),
        )

    def forward(self, x):
        y = self.encode(x)
        y_hat, y_likelihoods = self.entropy_bottleneck(y)
        x_hat = self.decode(y_hat)
        return x_hat, y_likelihoods

In [8]:
head_checkpoint = torch.load("dinov2_vitg14_ade20k_linear_head.pth")
new_state_dict = {key.replace("decode_head.", ""): value for key, value in head_checkpoint['state_dict'].items()}
new_state_dict.pop('bn.weight', None)
new_state_dict.pop('bn.bias', None)
new_state_dict.pop('bn.running_var', None)
new_state_dict.pop('bn.running_mean', None)
new_state_dict.pop('bn.num_batches_tracked', None)
head = LinearSegmentationHead()
head.load_state_dict(new_state_dict)
head = head.to("cuda")

In [9]:
rdae_checkpoint = torch.load("checkpoint_dino_rdae_20.pth", map_location={'cuda:0': 'cpu'})
rdae = RateDistortionAutoEncoder()
rdae.load_state_dict(rdae_checkpoint['model_state_dict'])
rdae = rdae.to("cuda")

In [10]:
def lossy_compress_patch_tokens(x):
    with torch.no_grad():
        z = rdae.encode(x)
        z = z.round()
        z = z.clamp(-128,127)
        z = z.to(torch.int8)
        z = z.to("cpu").detach().numpy()
        original_shape = z.shape
        compressed = zlib.compress(z.tobytes(), level=9)
        decompressed = zlib.decompress(compressed)
        ẑ = np.frombuffer(decompressed, dtype=np.int8)
        ẑ = ẑ.reshape(original_shape)
        ẑ = torch.tensor(ẑ)
        ẑ = ẑ.to(torch.float).to("cuda")
        x̂ = rdae.decode(ẑ)
        bps = 8*len(compressed)/(1536*32*18)
        return x̂,bps

In [14]:
%%time
iou = []
bps = []
for i_sample,sample in enumerate(ade20k):
    img = sample['image'].resize((448,252))
    ground_truth = sample['annotation']
    while ground_truth.width > 1000:
        ground_truth = ground_truth.resize((ground_truth.width//2,ground_truth.height//2),PIL.Image.Resampling.NEAREST)
    x = transforms.ToTensor()(img).unsqueeze(0).to("cuda")
    with torch.no_grad():
        
        y = model.forward(x)[0]
        cls_token = y[:, 0]
        patch_tokens = y[:, 1:]
        patch_tokens = patch_tokens.reshape((1,18,32,1536)).permute((0,3,1,2))

        patch_tokens,bps_i = lossy_compress_patch_tokens(patch_tokens)
        bps.append(bps_i)
        
        patch_tokens = nnf.interpolate(patch_tokens,
                                       size=(ground_truth.height,ground_truth.width),
                                       mode='bilinear')
        logits = head(patch_tokens)
        predicted = ToPILImage()(logits[0].argmax(dim=0).to(torch.uint8))
        
        x1 = transforms.PILToTensor()(ground_truth) 
        x2 = transforms.PILToTensor()(predicted)
        x1 = x1-1

        iou.append(
            MulticlassJaccardIndex(num_classes=151,average='micro',ignore_index=255)
            (x1,x2).item()
        )

CPU times: user 6min 49s, sys: 1.11 s, total: 6min 51s
Wall time: 6min 7s


In [15]:
np.mean(iou)

0.3364191932779795

In [16]:
32/np.mean(bps)

40.31396681539592