In [1]:
from PIL import Image
from datasets import load_dataset, Dataset
from transformers import CLIPProcessor, CLIPModel
from IPython.display import display
from io import BytesIO
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision.transforms import PILToTensor, ToTensor, ToPILImage
from dance.image import compress, decompress, preprocess, postprocess, RateDistortionAutoEncoder
import compressai
import zlib

In [2]:
# CLIP ViT-L/14 (processor + model on the GPU) used to score every image
# against the text prompts in ``classes``.
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").cuda()
# Validation split of the ``aria_ea_rgb_100k`` dataset; every codec below is
# evaluated on these images.
dataset = load_dataset("danjacobellis/aria_ea_rgb_100k", split="validation")

Resolving data files:   0%|          | 0/70 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/18 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/70 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/18 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/18 [00:00<?, ?it/s]

In [3]:
# Text prompts for CLIP zero-shot scoring. Index i of the last axis of every
# CLIP logit tensor in this notebook corresponds to classes[i].
classes = [
    # household / food
    "kitchen", "cooking", "food", "drink", "spill", "table",
    # screens and games
    "television", "phone", "laptop", "video game", "board game",
    # clothing
    "clothes", "laundry",
]

In [4]:
def get_clip(sample):
    """Score the uncompressed ``sample['image']`` against every prompt in ``classes``.

    Written for ``Dataset.map``: the CLIP ``logits_per_image`` tensor is stored
    in ``sample['clip_logit']`` and the same (mutated) sample is returned.
    """
    with torch.no_grad():
        encoded = clip_processor(text=classes, images=sample['image'],
                                 return_tensors="pt", padding=True)
        # Tensor-valued processor outputs are moved to the GPU, where clip_model lives.
        model_inputs = {
            name: (value.cuda() if hasattr(value, "device") else value)
            for name, value in encoded.items()
        }
        sample['clip_logit'] = clip_model(**model_inputs).logits_per_image
    return sample

In [None]:
# Baseline: CLIP logits for the original, uncompressed images, gathered into
# a single torch tensor across the whole split.
clip_dataset = dataset.map(get_clip)
clip_logit = clip_dataset.with_format("torch")["clip_logit"]

In [None]:
def jpeg_compress(img, rr=3, subsampling=0, quality=5):
    """JPEG round trip with downsampling: shrink, encode, decode, upsample.

    Parameters
    ----------
    img : PIL.Image.Image
        Source image. Modes JPEG cannot store (e.g. RGBA, P, LA) are converted
        to RGB before encoding.
    rr : int
        Integer factor by which width and height are divided before encoding.
    subsampling : int
        Pillow JPEG chroma-subsampling option (0 means 4:4:4).
    quality : int
        Pillow JPEG quality option.

    Returns
    -------
    (PIL.Image.Image, float)
        The decoded image resized back to the original ``(width, height)``, and
        the rate in bits per *original* pixel.
    """
    w, h = img.width, img.height
    # Pillow's JPEG writer rejects alpha/palette modes; encode those as RGB.
    if img.mode not in ("L", "RGB", "CMYK"):
        img = img.convert("RGB")
    # max(1, ...) avoids a zero-sized resize when a side is shorter than ``rr``.
    small = img.resize((max(1, w // rr), max(1, h // rr)))
    with BytesIO() as buf:
        small.save(buf, format='JPEG', subsampling=subsampling, quality=quality)
        payload = buf.getvalue()
    # Rate is normalised by the original pixel count, not the downsampled one.
    bpp = 8 * len(payload) / (w * h)
    decoded = Image.open(BytesIO(payload)).resize((w, h))
    return decoded, bpp

In [None]:
def get_clip_jpeg(sample):
    """Score a JPEG round trip of ``sample['image']`` with CLIP.

    Stores CLIP ``logits_per_image`` in ``sample['clip_logit']`` and the rate
    reported by ``jpeg_compress`` in ``sample['bpp']``, then returns ``sample``.
    """
    with torch.no_grad():
        decoded, rate = jpeg_compress(sample['image'])
        encoded = clip_processor(text=classes, images=decoded,
                                 return_tensors="pt", padding=True)
        # Tensor-valued processor outputs are moved to the GPU for clip_model.
        model_inputs = {
            name: (value.cuda() if hasattr(value, "device") else value)
            for name, value in encoded.items()
        }
        logits = clip_model(**model_inputs).logits_per_image
    sample['clip_logit'] = logits
    sample['bpp'] = rate
    return sample

In [None]:
# JPEG codec: CLIP logits on the decoded images plus the measured rate.
jpeg_dataset = dataset.map(get_clip_jpeg)
jpeg_clip_logit, jpeg_bpp = (
    jpeg_dataset.with_format("torch")[column] for column in ("clip_logit", "bpp")
)

In [None]:
def dance_compress(img, device, model=None):
    """Round-trip ``img`` through the DANCE codec and measure its rate.

    Parameters
    ----------
    img : PIL.Image.Image
        Source image.
    device : str or torch.device
        Device handed to ``preprocess``.
    model : optional
        Codec passed to ``compress``/``decompress``; defaults to the
        notebook-global ``dance_model`` (the previous behaviour).

    Returns
    -------
    (reconstruction, float)
        Whatever ``postprocess`` returns for the decoded tensor (presumably an
        image ``clip_processor`` accepts — TODO confirm), and the rate in bits
        per pixel of the *original* image.
    """
    if model is None:
        model = dance_model
    # Normalise by the source image's pixel count, as jpeg_compress and
    # dgml_compress do, so the bpp columns are comparable even if
    # ``preprocess`` pads or resizes the batch.
    num_pixels = img.width * img.height
    with torch.no_grad():
        hwc = PILToTensor()(img).permute(1, 2, 0)  # CHW -> HWC
        batch = preprocess(hwc, device)
        compressed_img, original_shape = compress(batch, model)
        rec = decompress(compressed_img, original_shape, model)
        bpp = 8 * len(compressed_img) / num_pixels
        return postprocess(rec), bpp

In [None]:
def get_clip_dance(sample):
    """Score a DANCE round trip of ``sample['image']`` with CLIP.

    Stores CLIP ``logits_per_image`` in ``sample['clip_logit']`` and the rate
    reported by ``dance_compress`` in ``sample['bpp']``, then returns ``sample``.
    """
    # ``.device`` is not a standard torch.nn.Module attribute. Use it if
    # RateDistortionAutoEncoder provides it, otherwise take the parameters' device.
    codec_device = getattr(dance_model, "device", None)
    if codec_device is None:
        codec_device = next(dance_model.parameters()).device
    with torch.no_grad():
        image, bpp = dance_compress(sample['image'], codec_device)
        inputs = clip_processor(text=classes, images=image, return_tensors="pt", padding=True)
        # Tensor-valued processor outputs are moved to the GPU for clip_model.
        for k in inputs.keys():
            if hasattr(inputs[k], "device"):
                inputs[k] = inputs[k].cuda()
        outputs = clip_model(**inputs)
        sample['clip_logit'] = outputs.logits_per_image
        sample['bpp'] = bpp
        return sample

In [None]:
device = "cuda"
dance_model = RateDistortionAutoEncoder()
# map_location lets a checkpoint saved on a different device load straight onto ``device``.
checkpoint = torch.load("dance/image.pth", map_location=device)
dance_model.load_state_dict(checkpoint['model_state_dict'])
# The model is only used for inference, so switch off training-mode behaviour
# (dropout, batch-norm statistics, ...) that a freshly constructed Module has on.
dance_model = dance_model.to(device).eval()

In [None]:
# DANCE codec: CLIP logits on the decoded images plus the measured rate.
dance_dataset = dataset.map(get_clip_dance)
dance_clip_logit, dance_bpp = (
    dance_dataset.with_format("torch")[column] for column in ("clip_logit", "bpp")
)

In [None]:
def dgml_compress(img, model, device, pad_multiple=64):
    """Round-trip ``img`` through a learned transform codec and estimate its rate.

    ``model.g_a`` encodes the image and the latent is rounded to integers. The
    size of those integers is estimated by zlib-compressing them as int8
    (the model's own entropy coder is not used). The *same* rounded latent is
    then decoded with ``model.g_s``.

    Parameters
    ----------
    img : PIL.Image.Image
        Source image; any non-RGB mode is converted to RGB.
    model : torch.nn.Module
        Model exposing ``g_a`` (analysis) and ``g_s`` (synthesis) transforms,
        e.g. a CompressAI zoo model.
    device : str or torch.device
        Device on which the input tensor is placed for ``model``.
    pad_multiple : int, optional
        The input is zero-padded on the bottom/right to a multiple of this so
        the decoded output can be cropped back to the original size. 64 follows
        CompressAI's own evaluation padding. Padded latent elements are
        included in the byte count.

    Returns
    -------
    (PIL.Image.Image, float)
        Reconstruction at the original size, and bits per original pixel.
    """
    w = img.width
    h = img.height

    # Only ``Image`` is imported from PIL, so ``PIL.Image.new`` would raise
    # NameError. ``convert`` handles L/CMYK/RGBA and every other non-RGB mode.
    if img.mode != "RGB":
        img = img.convert("RGB")

    x = ToTensor()(img).unsqueeze(0).to(device)
    # Pad so the strided analysis/synthesis transforms round-trip the size exactly.
    pad_h = (-h) % pad_multiple
    pad_w = (-w) % pad_multiple
    x = F.pad(x, (0, pad_w, 0, pad_h))

    with torch.no_grad():
        z = model.g_a(x)
        # Quantize once and use those integers for BOTH the rate estimate and the
        # reconstruction. Decoding the unquantized latent would report quality
        # the counted bits cannot deliver. The clamp keeps the int8 cast in range.
        z_hat = torch.round(z).clamp(-128, 127)
        symbols = z_hat.to(torch.int8).cpu().numpy()
        num_bytes = len(zlib.compress(symbols.tobytes(), level=9))
        recovered = model.g_s(z_hat)

    # Crop away the padding, clip to the displayable range, and drop only the
    # batch axis (squeeze() would also drop height/width of size 1).
    recovered = recovered[0, :, :h, :w].clamp(0, 1).cpu()
    bpp = 8 * num_bytes / (w * h)
    return ToPILImage()(recovered), bpp

In [None]:
def get_clip_dgml(sample):
    """Score a ``dgml_compress`` round trip of ``sample['image']`` with CLIP.

    Uses the global ``dgml_model`` on "cuda". Stores CLIP ``logits_per_image``
    in ``sample['clip_logit']`` and the estimated rate in ``sample['bpp']``,
    then returns ``sample``.
    """
    with torch.no_grad():
        decoded, rate = dgml_compress(sample['image'], model=dgml_model, device="cuda")
        encoded = clip_processor(text=classes, images=decoded,
                                 return_tensors="pt", padding=True)
        # Tensor-valued processor outputs are moved to the GPU for clip_model.
        model_inputs = {
            name: (value.cuda() if hasattr(value, "device") else value)
            for name, value in encoded.items()
        }
        logits = clip_model(**model_inputs).logits_per_image
    sample['clip_logit'] = logits
    sample['bpp'] = rate
    return sample

In [None]:
device = "cuda"
# Pretrained CompressAI Cheng2020 (attention) model at quality level 1.
dgml_model = compressai.zoo.cheng2020_attn(quality=1, pretrained=True)
# Inference only: make sure any training-mode behaviour is switched off.
dgml_model = dgml_model.to(device).eval()

In [None]:
# Learned (CompressAI) codec: CLIP logits on the decoded images plus the rate.
dgml_dataset = dataset.map(get_clip_dgml)
dgml_clip_logit, dgml_bpp = (
    dgml_dataset.with_format("torch")[column] for column in ("clip_logit", "bpp")
)

In [None]:
# One row per validation image. Column order: the three bpp columns, then the
# per-class logits for the uncompressed baseline, then for jpeg, dance, dgml.
logit_sources = (
    ("", clip_logit),
    ("jpeg_", jpeg_clip_logit),
    ("dance_", dance_clip_logit),
    ("dgml_", dgml_clip_logit),
)
combined_columns = {"jpeg_bpp": jpeg_bpp, "dance_bpp": dance_bpp, "dgml_bpp": dgml_bpp}
combined_columns.update(
    (prefix + c, logits[:, 0, i])
    for prefix, logits in logit_sources
    for i, c in enumerate(classes)
)
combined_dataset = Dataset.from_dict(combined_columns)
# NOTE(review): "area_compression" may be a typo for "aria" (the source dataset
# is aria_ea_rgb_100k); left unchanged because renaming targets a different Hub repo.
combined_dataset.push_to_hub("danjacobellis/area_compression", split="validation")