In [None]:
from tqdm import tqdm
from PIL import Image
import math
import torch
import torchvision.transforms.functional as tvf
torch.set_grad_enabled(False)

from lvae import get_model, known_datasets

In [None]:
# File the finished figure is saved to (resolved relative to the working directory).
fig_save_path = 'bpp-distribution-abs.pdf'

# Location stored under the 'kodak' key of lvae's `known_datasets`; it is
# searched for images further down. Echoed so the resolved value shows up in
# the cell output.
img_dir = known_datasets['kodak']
print(img_dir)

def get_bpp_distribution(model: torch.nn.Module):
    """Average the per-latent bit rate of `model` over the files in `img_dir`.

    Every file is loaded as a single-image batch and passed through
    `model.forward_end2end` with `lmb=model.default_lmb`. For each entry of the
    returned statistics list, the `'kl'` tensor is summed, scaled by log2(e)
    (i.e. nats -> bits, assuming `kl` is expressed in nats) and divided by the
    number of pixels, giving one bits-per-pixel value per entry.

    Args:
        model: network exposing `forward_end2end(im, lmb=...)` and a
            `default_lmb` attribute. Images are moved to the device of its
            first parameter.

    Returns:
        1-D tensor with one element per statistics entry, averaged over all
        files found under the module-level `img_dir`.

    Raises:
        ValueError: if no files are found under `img_dir`.
    """
    device = next(model.parameters()).device

    # Keep regular files only ('*.*' also matches directories whose name has a
    # dot) and sort them so the accumulation order is reproducible.
    img_paths = sorted(p for p in img_dir.rglob('*.*') if p.is_file())
    if not img_paths:
        # Without this guard the function would fail later with an opaque
        # "NoneType / int" TypeError.
        raise ValueError(f'no image files found under {img_dir}')

    log2_e = math.log2(math.e)
    bpps_all = None
    for impath in img_paths:
        # The context manager releases the file handle as soon as the pixel
        # data has been copied into the tensor.
        with Image.open(impath) as img:
            im = tvf.to_tensor(img).unsqueeze_(0).to(device=device)
        _, stats_all = model.forward_end2end(im, lmb=model.default_lmb)
        # Pixel count of one image: height * width (last two tensor dims).
        npix = float(im.shape[-2] * im.shape[-1])
        bpps = torch.stack([stat['kl'].sum() * log2_e / npix for stat in stats_all])
        bpps_all = bpps if (bpps_all is None) else (bpps_all + bpps)
    return bpps_all / len(img_paths)


In [None]:
model = get_model('qarv_base', pretrained=True)

# Use the GPU when one is present; otherwise stay on the CPU so the notebook
# still runs (slowly) instead of crashing inside `.cuda()`.
model = model.cuda() if torch.cuda.is_available() else model
model.eval()

# Alternative: a fixed, power-of-two grid.
# lambdas = [16, 32, 64, 128, 256, 512, 1024, 2048]
# Sweep `steps` values spaced evenly in log-space across the model's lambda range.
steps = 15
_loglow, _loghigh = math.log(model.lmb_range[0]), math.log(model.lmb_range[1])
lambdas = torch.linspace(_loglow, _loghigh, steps=steps).exp()

# One row of per-latent bpp values per lambda.
# NOTE: `model.default_lmb` is left at the last swept value after this loop.
stats_all = []
for lmb in tqdm(lambdas):
    model.default_lmb = lmb
    bpps = get_bpp_distribution(model)
    stats_all.append(bpps)
# Final layout: (number of lambdas, number of latent entries).
stats_all = torch.stack(stats_all, dim=0)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Rows: one per lambda; columns: one per latent entry.
data = stats_all.cpu().numpy()
# data = np.flip(data, axis=1) # change the order to Z_N -> Z_1
data_cum = data.cumsum(axis=1)
num_latents = data.shape[1]

# One colour per latent, sampled evenly from the 'tab20' colormap.
category_colors = plt.get_cmap('tab20')(np.linspace(0, 1, num_latents))

fig, ax = plt.subplots(figsize=(13.4, 4.8))

# Raw f-string: '\l' is not a valid Python escape, so the non-raw form emits a
# SyntaxWarning on recent Python versions. The resulting text is unchanged.
labels         = [rf'$\lambda = {lmb:.0f}$' for lmb in lambdas]
category_names = [f'$Z_{{ {i} }}$'          for i   in range(1, num_latents+1)]

# Stacked horizontal bars: each latent's segment starts where the cumulative
# sum of the previous latents ends.
for i, (colname, color) in enumerate(zip(category_names, category_colors)):
    widths = data[:, i]
    starts = data_cum[:, i] - widths
    ax.barh(labels, widths, left=starts, height=0.8,
            label=colname, color=color)

# Legend entries are reversed to follow the bar order on the inverted x-axis.
legend_handles, legend_labels = ax.get_legend_handles_labels()
ax.legend(
    legend_handles[::-1], legend_labels[::-1], ncol=len(category_names), loc='lower left',
    fontsize=12.2, bbox_to_anchor=(0.2, 1.0), handletextpad=0.24
)
ax.tick_params(axis='both', which='major', labelsize=12)
ax.set_ylim(-0.6, len(labels)-0.4)
ax.invert_yaxis()
ax.yaxis.tick_right()

# Ticks every 0.1 bpp, extended far enough to cover the largest total rate
# instead of stopping at a hard-coded 2.3 bpp.
x_upper = np.sum(data, axis=1).max()+0.02
ax.set_xticks(np.arange(0, float(x_upper) + 0.1, 0.1))
ax.set_xlim(0, x_upper)
ax.invert_xaxis()
ax.set_xlabel('Bits per pixel (bpp)', fontdict={'size':14})

fig.tight_layout()
# Manual margins applied after tight_layout to leave room for the legend row
# placed above the axes.
plt.subplots_adjust(bottom=0.11, top=0.91)
fig.savefig(fig_save_path)