**In this notebook we check whether the actual number of bits our compressed latents take up matches the bits-per-pixel (BPP) estimate that we report.**

In [1]:
import torch

import pytorch_lightning as pl

import matplotlib.pyplot as plt

import models, train, datasets

# Helper functions

In [2]:
def load_compressor(path, model_class, map_location="cpu"):
    """Rebuild a compressor from a checkpoint file and prepare it for compression.

    The checkpoint must be a mapping with a ``"hyper_parameters"`` entry (keyword
    arguments for ``model_class``) and a ``"state_dict"`` entry (the weights).
    This matches the layout PyTorch Lightning writes; that is an assumption, so
    confirm it for checkpoints produced any other way.

    Args:
        path: Location of the checkpoint file.
        model_class: Compressor class to instantiate from the stored hyper-parameters.
        map_location: Device the checkpoint tensors are mapped onto (default "cpu").

    Returns:
        The compressor in eval mode, with its weights loaded and
        ``update_bottleneck_values()`` already called.
    """
    checkpoint = torch.load(path, map_location=map_location)

    compressor = model_class(**checkpoint["hyper_parameters"]).eval()
    compressor.load_state_dict(checkpoint["state_dict"])
    # Called after the weights are loaded. It presumably refreshes the
    # bottleneck tables that compress() relies on -- confirm in `models`.
    compressor.update_bottleneck_values()
    return compressor

# Checking on a single batch

In [36]:
# Configuration for the single-batch check: seed, data, and checkpoints to evaluate.
SEED = 21
# Seed before the dataloader is built so the sampled BATCH is reproducible
# (assuming the loader's randomness comes from the globally seeded RNGs).
pl.seed_everything(SEED)

DATASET = "clevr"

TASKS = ["rgb", "depth_euclidean", "normal"]

BATCH_SIZE = 128

# Draw one batch from the training split; every compressor below is evaluated on it.
_, DATALOADER = train.get_dataloader(
    dataset_name=DATASET,
    batch_size=BATCH_SIZE,
    num_workers=0,
    tasks=TASKS,
    is_train=True,
    collate=datasets.transforms.make_collate_fn(TASKS),
)

BATCH = next(iter(DATALOADER))


# Checkpoint paths per lambda, presumably the rate-distortion trade-off weight used
# in training. The value is encoded in the filename: "01" = 0.1, "0001" = 0.001.
CHECKPOINTS_BY_LAMBDA = {
    0.1: {
        models.MultiTaskMixedLatentCompressor: "../model_weights/model-2-01-balance.ckpt",
        models.MultiTaskDisjointLatentCompressor: "../model_weights/model-3-01-balance.ckpt",
        models.MultiTaskSharedLatentCompressor: "../model_weights/model-4-01-balance.ckpt",
    },
    # FIXME: lambda 0.01 used to be listed with the lambda 0.1 filenames
    # ("model-{2,3,4}-01-balance.ckpt"), which is a copy-paste error. Add the
    # real 0.01 checkpoint paths here before setting LAMBDA = 0.01.
    0.001: {
        models.MultiTaskMixedLatentCompressor: "../model_weights/model-2-0001-v10-balance.ckpt",
        models.MultiTaskDisjointLatentCompressor: "../model_weights/model-3-0001-v10-balance.ckpt",
        models.MultiTaskSharedLatentCompressor: "../model_weights/model-4-0001-v10-balance.ckpt",
    },
}

# Select which lambda and which compressors to evaluate here,
# instead of commenting lines in and out.
LAMBDA = 0.001
ACTIVE_MODELS = [models.MultiTaskDisjointLatentCompressor]

MODEL_PATHS = {model_cls: CHECKPOINTS_BY_LAMBDA[LAMBDA][model_cls] for model_cls in ACTIVE_MODELS}

MODEL_NAME = {
    models.MultiTaskMixedLatentCompressor: "Mixed",
    models.MultiTaskDisjointLatentCompressor: "Disjoint",
    models.MultiTaskSharedLatentCompressor: "Shared"
}

Global seed set to 21


In [37]:
# Compress BATCH with every selected model. print_info=True prints the measured
# size of the compressed strings next to the BPP estimated by the compression loss.
# The return values are kept per model class, so they can be inspected later
# without recompressing.
COMPRESSION_RESULTS = {}
for model_class, weights_path in MODEL_PATHS.items():
    print(f"--- {MODEL_NAME[model_class]} ---")
    compressor = load_compressor(weights_path, model_class)
    COMPRESSION_RESULTS[model_class] = compressor.compress(BATCH, print_info=True)

--- Disjoint ---
Number of actual bytes in a string is: 12648, which gives a BPP = 0.004
Estimated BPP (compression loss) is: 0.004


In [38]:
# Inspect one model's compressed output in detail. Load it explicitly instead of
# relying on `compressor` leaking from the loop above, which only holds whichever
# model ran last. The last entry of MODEL_PATHS is chosen to keep that same model.
INSPECT_MODEL_CLASS = list(MODEL_PATHS)[-1]
compressor = load_compressor(MODEL_PATHS[INSPECT_MODEL_CLASS], INSPECT_MODEL_CLASS)
ans, number_of_bytes, _ = compressor.compress(BATCH, print_info=True)

Number of actual bytes in a string is: 12648, which gives a BPP = 0.004
Estimated BPP (compression loss) is: 0.004


In [39]:
# Dump every compressed latent string to disk, so the on-disk size can be compared
# with the byte count compress() reported. Each entry of ans["strings"] is iterated
# as a sequence of byte strings; what each entry corresponds to (latent / task)
# is not visible here, so confirm it in `models`.
LATENT_DUMP_PATH = "lol.txt"  # raw binary data despite the .txt extension
with open(LATENT_DUMP_PATH, "wb") as f:
    # In binary mode, write() returns the number of bytes written, so summing
    # the return values gives the total size of the dump.
    bytes_written = sum(f.write(string) for latents in ans["strings"] for string in latents)

print(f"Wrote {bytes_written} bytes to {LATENT_DUMP_PATH}; compress() reported {number_of_bytes}.")