In [1]:
import sys

sys.path.append("..")

import numpy as np
import torch
import random


from src.datasets import data, utils, configs
from torchvision import transforms as transforms

seed = 42
# Seed the Python and NumPy global RNGs first; torch is seeded last so the
# value displayed by the cell is still the Generator returned by manual_seed.
for seed_rng in (random.seed, np.random.seed):
    seed_rng(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x23153e18370>

In [2]:
# Step 1: Create a diagonal dataset, to get valid latents
n_samples, n_slots = 5000, 2
sample_mode, no_overlap, delta = "diagonal", True, 0.125
default_cfg = configs.SpriteWorldConfig()

# Every sample is passed through ToPILImage followed by ToTensor.
diagonal_dataset = data.SpriteWorldDataset(
    n_samples,
    n_slots,
    default_cfg,
    sample_mode=sample_mode,
    no_overlap=no_overlap,
    delta=delta,
    transform=transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()]),
)

Generating images (sampling: diagonal): 100%|█████████████████████████████████████| 5000/5000 [00:46<00:00, 107.68it/s]


In [3]:
# Step 2: Create a dataloader
def _select_latent_dims(t):
    """Keep all but the last four entries of the final axis, plus the entry at index -3."""
    return torch.cat([t[:, :, :-4], t[:, :, -3:-2]], dim=-1)


def _normalizing_collate(batch):
    """Collate a batch with the project normalizer, using `min_offset` and `scale`."""
    return utils.collate_fn_normalizer(batch, min_offset, scale)


# Width (max - min) of every configured range, shaped (1, 1, n_ranges)
# before the latent dimensions are selected.
range_widths = torch.FloatTensor(
    [bounds.max - bounds.min for bounds in default_cfg.get_ranges().values()]
)
scale = _select_latent_dims(range_widths[None, None, :])

# Lower bound of every configured range, reduced the same way as `scale`.
range_minima = torch.FloatTensor(
    [bounds.min for bounds in default_cfg.get_ranges().values()]
)
min_offset = _select_latent_dims(range_minima[None, None, :])

batch_size = 128
# shuffle=False keeps batches in dataset order.
loader = torch.utils.data.DataLoader(
    diagonal_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=_normalizing_collate,
)

In [4]:
# Step 3: Make a permutation for every batch and create permuted dataset
batch_perms = []
offset = 0
for _, batch_latents in loader:
    n_in_batch = batch_latents.shape[0]
    # Permute indices inside the batch only, then shift them to dataset-level indices.
    batch_perms.append(offset + torch.randperm(n_in_batch))
    offset += n_in_batch

perms_concated = torch.cat(batch_perms)

# Slot 0 keeps its original latents; slot 1 takes the latents of the permuted sample.
diagonal_z = diagonal_dataset.z
permuted_latents = torch.stack(
    [diagonal_z[:, 0], diagonal_z[perms_concated, 1]], dim=1
)

permuted_dataset = data.SpriteWorldDataset(
    n_samples,
    n_slots,
    default_cfg,
    sample_mode=sample_mode,
    no_overlap=no_overlap,
    delta=delta,
    z=permuted_latents,  # use the provided latents instead of sampling new ones
    transform=transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()]),
)

permuted_dataloader = torch.utils.data.DataLoader(
    permuted_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda batch: utils.collate_fn_normalizer(batch, min_offset, scale),
)

Generating images (sampling: diagonal): 100%|█████████████████████████████████████| 5000/5000 [00:45<00:00, 109.59it/s]


In [86]:
# Step 4: Load the model
from src.models import base_models, slot_attention

# NOTE(review): absolute, machine-specific path; adjust it before running the
# notebook on another machine.
checkpoint_path = ("D:/git_projects/bethgelab/lab_rotation/object_centric_ood/notebooks/models/"
                   "SlotAttention, diagonal, 2 objects/"
                   "2027.pt")
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
# machine; the model below is never moved off the CPU in this notebook.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
checkpoint = torch.load(checkpoint_path, map_location="cpu")


# SlotAttention
encoder = slot_attention.SlotAttentionEncoder(
    resolution=(64, 64),
    hid_dim=16,
    ch_dim=32,
    dataset_name="dsprites",
)
decoder = slot_attention.SlotAttentionDecoder(
    hid_dim=16,
    ch_dim=32,
    resolution=(64, 64),
    dataset_name="dsprites",
)
model = slot_attention.SlotAttentionAutoEncoder(
    encoder=encoder,
    decoder=decoder,
    num_slots=2,
    num_iterations=3,
    hid_dim=16,
    dataset_name="dsprites",
)

# Callable used in Step 6 to turn latents back into images.
decoder_hook = model.decode

# SlotMLPAdditive (alternative model; swap in together with a matching checkpoint)
# model = base_models.SlotMLPAdditive(3, 2, 16)
# decoder_hook = model.decoder



# Restore the trained weights and switch to inference mode.
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

SlotAttentionAutoEncoder(
  (encoder_cnn): SlotAttentionEncoder(
    (conv1): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (conv2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (conv3): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (conv4): Conv2d(32, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (encoder_pos): SoftPositionEmbed(
      (embedding): Linear(in_features=4, out_features=16, bias=True)
    )
  )
  (decoder_cnn): SlotAttentionDecoder(
    (conv_list): ModuleList(
      (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (1): ReLU()
      (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (3): ReLU()
      (4): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (5): ReLU()
      (6): Conv2d(32, 4, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (decoder_pos): SoftPositionEmbed(
      (embedding): Linear(in_f

In [87]:
# Step 5: Get the latents for the original dataset and reshuffle them
from src.metrics import hungarian_slots_loss

latents = []
with torch.no_grad():
    for images, true_latents in loader:
        # Along dim 1, every entry but the last is a per-object figure;
        # the last entry is the image fed to the model.
        true_figures, images = images[:, :-1], images[:, -1].squeeze(1)

        output = model(images)

        # Flatten everything after the first two dims of the true figures.
        figures_reshaped = true_figures.view(*true_figures.shape[:2], -1)

        # Swap the first two dims of the reconstructions, then flatten the same way.
        predicted_figures = output["reconstructed_figures"].transpose(0, 1)
        predicted_figures_reshaped = predicted_figures.reshape(
            *predicted_figures.shape[:2], -1
        )

        _, indexes = hungarian_slots_loss(figures_reshaped, predicted_figures_reshaped)
        indexes = torch.LongTensor(indexes)

        predicted_latents = output["predicted_latents"].detach().cpu()
        true_latents = true_latents.detach().cpu()

        # Reorder predicted latents along the slot dim to match the true latents.
        # Entry 1 of each index pair is used as the predicted-slot position
        # (presumably the Hungarian assignment; verify against hungarian_slots_loss).
        slot_order = indexes[:, :, 1, None].expand(-1, -1, predicted_latents.shape[-1])
        predicted_latents = predicted_latents.gather(1, slot_order)
        latents.append(predicted_latents)

latents = torch.cat(latents)
# From this point on, z_hat is matched to the original z, not to the permuted z.

In [88]:
# Step 6: Permute predicted latents, and compare with permuted dataset
# NOTE: this cell rebinds `latents`, so running it a second time applies the
# permutation again; re-run Step 5 first if the cell must be repeated.
first_slot, permuted_second_slot = latents[:, 0], latents[perms_concated, 1]
latents = torch.stack([first_slot, permuted_second_slot], dim=1)

In [89]:
# Step 6: continued
# Decode the permuted predicted latents batch by batch and accumulate the
# squared error against the images of the permuted dataset. The total is the
# summed squared error divided by the number of samples.
mse = 0
batch_size_accum = 0
with torch.no_grad():
    # The loader's latents are not needed here; binding them to a throwaway
    # name keeps the `permuted_latents` tensor built in Step 3 intact.
    for permuted_images, _batch_latents in permuted_dataloader:
        # Only the last entry along dim 1 is compared with the decoder output.
        permuted_images = permuted_images[:, -1, ...].squeeze(1)

        batch_latents = latents[batch_size_accum : batch_size_accum + len(permuted_images)]
        output = decoder_hook(batch_latents)
        imagined_images = output[0]

        batch_size_accum += len(permuted_images)
        # compare reconstructed images with imagined images
        mse += ((permuted_images - imagined_images) ** 2).sum() / len(permuted_dataset)

# Every predicted latent must have been decoded exactly once.
assert batch_size_accum == len(latents), (
    f"decoded {batch_size_accum} samples but have {len(latents)} predicted latents"
)

# Display the result; the numbers in the Results cell are copied from here.
mse

In [91]:
# Results
# MSE values collected by hand from separate runs of Step 6 (one value per
# trained seed); the mean and standard deviation are printed per model.

# Supervised model
model_1_mse = np.array([2.0411, 2.0226, 2.1136, 2.1914, 2.0097, 2.0147, 2.1721])
print("Supervised model: ", f"{model_1_mse.mean()=}",f"{model_1_mse.std()=}")

# # Unsupervised model (all seeds, including the failed ones)
# model_2_mse = np.array([2.1929, 2.0903, 64.8243, 64.7224, 2.0591, 2.0205, 64.3132])
# print("Unsupervised model: ", f"{model_2_mse.mean()=}",f"{model_2_mse.std()=}")

# Unsupervised model
model_2_mse = np.array([2.1929, 2.0903, 2.0591, 2.0205]) # failed seeds excluded
# Label corrected: this line previously printed "Supervised model: ".
print("Unsupervised model: ", f"{model_2_mse.mean()=}",f"{model_2_mse.std()=}")


# SlotAttention model
model_3_mse = np.array([7.4330, 20.6778, 26.6505, 5.7046, 9.1443])
print("SlotAttention model: ", f"{model_3_mse.mean()=}",f"{model_3_mse.std()=}")

Supervised model:  model_1_mse.mean()=2.0807428571428574 model_1_mse.std()=0.07174520107680378
Supervised model:  model_2_mse.mean()=2.0907 model_2_mse.std()=0.06397577666585992
SlotAttention model:  model_3_mse.mean()=13.92204 model_3_mse.std()=8.247602644793213


In [55]:
# For debugging purposes
# import matplotlib.pyplot as plt
# plt.imshow(imagined_images.clip(0, 1)[1].permute(1, 2, 0))
# plt.show()
# plt.imshow(permuted_images[1].permute(1, 2, 0))
# plt.show()