In [10]:
import random
import sys

import numpy as np
import torch

# Add the parent directory to the import path so the local `src` package
# resolves; assumes the notebook is run from a subdirectory of the repo.
sys.path.append("..")

from src.datasets import configs, data, utils
from src.metrics import hungarian_slots_loss

# Seed every RNG used in this notebook (Python `random`, NumPy, PyTorch) so the
# random draws below are reproducible. The trailing `;` suppresses the
# `torch.Generator` repr that `torch.manual_seed` returns.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed);

<torch._C.Generator at 0x1bf1cb28270>

In [11]:
# Step 1: Create a diagonal dataset, to get valid latents
# The sampling settings below are bound to names because Step 3 reuses them
# verbatim when it builds the permuted dataset.
n_samples, n_slots = 5000, 2
default_cfg = configs.SpriteWorldConfig()
sample_mode = "diagonal"
no_overlap = True
delta = 0.125

diagonal_dataset = data.SpriteWorldDataset(
    n_samples, n_slots, default_cfg,
    sample_mode=sample_mode, no_overlap=no_overlap, delta=delta,
)

Generating images (sampling: diagonal): 100%|█████████████████████████████████████| 5000/5000 [00:36<00:00, 135.78it/s]


In [27]:
# Step 2: Create a dataloader
def _keep_latent_dims(values: torch.Tensor) -> torch.Tensor:
    """Select the latent dimensions that are normalized and fed to the model.

    Keeps every entry along the last axis except the final four, then appends
    the third-from-last entry (``values[:, :, :-4]`` followed by
    ``values[:, :, -3:-2]``). Shared by ``scale`` and ``min_offset`` so the two
    normalization tensors always select the same dimensions.

    NOTE(review): which sprite factors these positions correspond to depends on
    the key order of ``SpriteWorldConfig.get_ranges()`` — confirm.
    """
    return torch.cat([values[:, :, :-4], values[:, :, -3:-2]], dim=-1)


# Read the ranges once so `scale` and `min_offset` are built from the same list.
latent_ranges = list(default_cfg.get_ranges().values())
scale = _keep_latent_dims(
    torch.FloatTensor([rng.max - rng.min for rng in latent_ranges]).reshape(1, 1, -1)
)
min_offset = _keep_latent_dims(
    torch.FloatTensor([rng.min for rng in latent_ranges]).reshape(1, 1, -1)
)

batch_size = 128
# shuffle=False: Step 3 relies on the sequential batch layout of this loader.
loader = torch.utils.data.DataLoader(
    diagonal_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda b: utils.collate_fn_normalizer(b, min_offset, scale),
)

In [28]:
# Step 3: Make a permutation for every batch and create permuted dataset
# Draw one random permutation per `loader` batch and offset it into global
# sample indices, so samples are only reordered *within* their batch. The batch
# boundaries follow `loader`'s layout (sequential, `batch_size` per batch, short
# last batch) and are computed directly instead of iterating the whole
# normalizing loader just to read batch sizes.
n_dataset = len(diagonal_dataset)
perms = [
    torch.randperm(min(batch_size, n_dataset - start)) + start
    for start in range(0, n_dataset, batch_size)
]
perms_concated = torch.cat(perms)
# Sanity check: the concatenation is a permutation of 0..n_dataset-1.
assert torch.equal(perms_concated.sort().values, torch.arange(n_dataset))

# NOTE(review): this indexes the first axis of `z`. Assuming `z` is
# (n_samples, n_slots, ...), it reorders whole samples (all slots together).
# If the goal is to recombine slots across samples, permute a single slot
# instead — confirm intent.
permuted_latents = diagonal_dataset.z[perms_concated]

# Same sampling settings as Step 1, but with the latents supplied explicitly.
permuted_dataset = data.SpriteWorldDataset(
    n_samples,
    n_slots,
    default_cfg,
    sample_mode=sample_mode,
    no_overlap=no_overlap,
    delta=delta,
    z=permuted_latents,
)

# Must mirror `loader` (same batch_size, shuffle=False) so Step 6 can slice the
# permuted predictions batch by batch.
permuted_dataloader = torch.utils.data.DataLoader(
    permuted_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda b: utils.collate_fn_normalizer(b, min_offset, scale),
)

Generating images (sampling: diagonal): 100%|█████████████████████████████████████| 5000/5000 [00:35<00:00, 139.71it/s]


In [30]:
# Step 4: Load the model
# TODO: replace the placeholder with the trained model to evaluate. The cells
# below call `model(images)` and index its output with "predicted_latents",
# and call `model.decoder(...)` on slot latents.
model = ...
if model is Ellipsis:
    raise NotImplementedError(
        "Step 4: assign a trained model to `model` before running the evaluation cells."
    )
model.eval();

In [None]:
# Step 5: Get the latents for the original dataset and reshuffle them
# Run the model over the original (unshuffled) loader, then reorder each
# sample's predicted slots so they line up with its ground-truth slots, using
# the assignment returned by `hungarian_slots_loss`.
latents_batches = []
with torch.no_grad():
    for images, true_latents in loader:
        output = model(images)
        predicted_latents = output["predicted_latents"]

        _, indexes = hungarian_slots_loss(predicted_latents, true_latents)

        # NOTE(review): assumes `indexes` converts to a (batch, n_slots, 2)
        # tensor whose last axis holds (true_slot, predicted_slot) pairs,
        # ordered by true slot — confirm against `hungarian_slots_loss`.
        indexes = torch.LongTensor(indexes)
        predicted_latents = predicted_latents.cpu()

        # Reorder predicted slots so position i holds the slot matched to true
        # slot i. The index is expanded to the width of the tensor being
        # gathered, not the width of `true_latents`.
        slot_order = indexes[:, :, 1].unsqueeze(-1).expand(
            -1, -1, predicted_latents.shape[-1]
        )
        latents_batches.append(predicted_latents.gather(1, slot_order))

# Matched predictions for the ORIGINAL (unpermuted) dataset, one row per sample
# in `loader` order; Step 6 applies the Step 3 permutation to them.
latents = torch.cat(latents_batches)
assert len(latents) == len(diagonal_dataset)

In [None]:
# Step 6: Permute predicted latents, and compare with permuted dataset
# Apply the Step 3 permutation to the matched predictions so row i corresponds
# to sample i of `permuted_dataset`. A new name keeps this cell idempotent:
# re-running it does not permute `latents` a second time.
permuted_predicted_latents = latents[perms_concated]

mse = 0.0
offset = 0
with torch.no_grad():
    for permuted_images, _batch_latents in permuted_dataloader:
        n_batch = len(permuted_images)
        imagined_images = model.decoder(
            permuted_predicted_latents[offset : offset + n_batch]
        )
        offset += n_batch
        # Compare the rendered permuted images with the images decoded from the
        # permuted predicted latents (summed squared error, averaged per sample).
        mse += ((permuted_images - imagined_images) ** 2).sum() / len(permuted_dataset)

# Every permuted prediction must have been decoded exactly once.
assert offset == len(permuted_predicted_latents)
float(mse)