In [1]:
import sys
import os

sys.path.append("..")

import numpy as np
import torch
import random
import ipywidgets as widgets

import matplotlib.pyplot as plt

from src.datasets import data, configs, utils
from src.datasets.utils import PreGeneratedDataset

from torchvision import transforms as transforms

seed = 43
# Python, NumPy and PyTorch each keep their own global RNG; seed all of them
# (in that order) so the generated datasets are reproducible.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(seed)

<torch._C.Generator at 0x7fd1982ef610>

The dataset was created in a similar manner to the code below.
The exact version can be downloaded from [link will be added later].

In [2]:
# Filtering code, to remove overlapping objects OOD
def filter_objects(latents, max_samples=5000, threshold=0.2, sort=False):
    """
    Keep only samples in which every pair of objects is at least ``threshold`` apart.

    Distances are Euclidean and computed on the first two latent dimensions of
    each slot (presumably the object's x/y position -- TODO confirm against the
    SpriteWorld latent layout).

    Args:
        latents: Tensor of shape (batch_size, n_slots, n_latents)
        max_samples: Number of samples to keep at most; the first
            ``max_samples`` passing samples (in their original order) are kept.
        threshold: Minimum allowed distance between any two objects of a sample.
        sort: If True, order the kept samples by ascending minimum pairwise
            object distance (closest pair first).

    Returns:
        ``(filtered_samples, filtered_indices)``: the kept latents and their
        indices into ``latents`` as a list of ints, or ``(None, [])`` if no
        sample passes the filter.
    """
    N, slots, _ = latents.size()

    # Pairwise slot distances for every sample in one batched call: (N, slots, slots).
    positions = latents[:, :, :2]
    distances = torch.cdist(positions, positions, p=2)
    # Ignore each slot's distance to itself.
    self_pairs = torch.eye(slots, device=latents.device).bool()
    distances = distances.masked_fill(self_pairs, float("inf"))
    # Closest pair of objects in each sample: (N,). Reused below for sorting,
    # so the filter and the sort always agree on the distance definition.
    min_distances = distances.flatten(start_dim=1).min(dim=1).values

    # Only keep samples in which no two objects are closer than the threshold.
    # Compare in float64 so the threshold is not rounded to the latents' dtype.
    mask = min_distances.double() >= threshold

    # If all objects are "close", print a message and return
    if not torch.any(mask):
        print("No objects were found that meet the distance threshold.")
        return None, []

    # Keep at most `max_samples` of the passing samples.
    filtered_indices = torch.arange(N, device=latents.device)[mask][:max_samples]
    filtered_samples = latents[filtered_indices]

    if sort:
        # Sort only the kept samples (the old code sorted all passing samples and
        # then indexed the truncated tensor, raising IndexError). Stable, so ties
        # keep their original order.
        order = torch.sort(min_distances[filtered_indices], stable=True).indices
        filtered_samples = filtered_samples[order]
        filtered_indices = filtered_indices[order]

    return filtered_samples, filtered_indices.tolist()

In [4]:
# Create an OOD (off-diagonal) dataset
n_samples, n_slots = 10000, 2
sample_mode = "off_diagonal"
no_overlap = True
delta = 0.125
default_cfg = configs.SpriteWorldConfig()

off_diagonal_dataset = data.SpriteWorldDataset(
    n_samples, n_slots, default_cfg,
    sample_mode=sample_mode,
    no_overlap=no_overlap,
    delta=delta,
    # Each image is round-tripped through a PIL image back into a tensor.
    transform=transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()]),
)

Generating images (sampling: off_diagonal): 100%|██████████| 10000/10000 [01:31<00:00, 109.61it/s]


In [17]:
# Checking the generated dataset
def display_data(index, dataset):
    """Plot the last image stored in ``dataset[index][0]`` with matplotlib.

    The image's first axis is moved to the end before plotting, i.e. the tensor
    is assumed to be channels-first (C, H, W) -- TODO confirm.
    """
    images = dataset[index][0]
    channels_last = images[-1].permute(1, 2, 0)
    plt.imshow(channels_last)
    plt.show()


def display_off_diagonal(index):
    """Display sample ``index`` of ``off_diagonal_dataset`` via ``display_data``."""
    display_data(index, dataset=off_diagonal_dataset)


# Previous (misspelled) name, kept so any existing reference keeps working.
dispaly_off_diagonal = display_off_diagonal

num_samples = len(off_diagonal_dataset)

# Slider over every sample index; the trailing ";" hides the repr of the
# function that `interact` returns.
widgets.interact(
    display_off_diagonal,
    index=widgets.IntSlider(min=0, max=num_samples - 1, step=1, value=0),
);

interactive(children=(IntSlider(value=0, description='index', max=9999), Output()), _dom_classes=('widget-inte…

<function __main__.<lambda>(index)>

In [19]:
# Filter the dataset: keep at most `n_objects` OOD samples in which no two
# objects are closer than 0.2 (see `filter_objects`).
n_objects = 5000
_, indicies = filter_objects(
    off_diagonal_dataset.z, max_samples=n_objects, threshold=0.2
)
# `filter_objects` returns an empty index list when no sample passes; fail
# loudly instead of silently saving empty tensors.
if not indicies:
    raise RuntimeError("No OOD samples passed the overlap filter; nothing to save.")

# save the filtered dataset
ood_data_path = "YOUR PATH"

# `exist_ok=True` makes re-running this cell safe when the folders exist.
os.makedirs(os.path.join(ood_data_path, "images"), exist_ok=True)
os.makedirs(os.path.join(ood_data_path, "latents"), exist_ok=True)
torch.save(
    off_diagonal_dataset.x[indicies],
    os.path.join(ood_data_path, "images", "images.pt"),
)
# Latent dimension 4 and the last two latent dimensions are dropped before saving.
# NOTE(review): presumably they are not part of the latent layout that
# PreGeneratedDataset expects -- confirm.
torch.save(
    torch.cat(
        [
            off_diagonal_dataset.z[indicies, :, :4],
            off_diagonal_dataset.z[indicies, :, 5:-2],
        ],
        dim=-1,
    ),
    os.path.join(ood_data_path, "latents", "latents.pt"),
)

In [14]:
# Checking the generated dataset: reload what the previous cell saved.
no_overlaps_ood = PreGeneratedDataset(ood_data_path)


def display_no_overlaps_ood(index):
    """Display sample ``index`` of the reloaded OOD dataset via ``display_data``."""
    display_data(index, dataset=no_overlaps_ood)


# Previous (misspelled) name, kept so any existing reference keeps working.
dispaly_no_overlaps_ood = display_no_overlaps_ood

# Size the slider from the reloaded data: the filter may have kept fewer than
# `n_objects` samples. NOTE(review): assumes PreGeneratedDataset implements __len__.
widgets.interact(
    display_no_overlaps_ood,
    index=widgets.IntSlider(min=0, max=len(no_overlaps_ood) - 1, step=1, value=0),
);

interactive(children=(IntSlider(value=0, description='index', max=4999), Output()), _dom_classes=('widget-inte…

<function __main__.<lambda>(index)>

In [None]:
# Generating the ID dataset (test)
n_samples = 5000
n_slots = 2
sample_mode = "diagonal"
no_overlap = True
delta = 0.125

# NOTE(review): unlike the OOD dataset above, no `transform` is passed here --
# confirm this is intended.
test_diagonal_dataset = data.SpriteWorldDataset(
    n_samples,
    n_slots,
    default_cfg,
    sample_mode=sample_mode,
    no_overlap=no_overlap,
    delta=delta,
)

utils.dump_generated_dataset(test_diagonal_dataset, "your path to test diagonal")

In [None]:
# Generating the ID dataset (train)
n_samples = 100000
n_slots = 2
sample_mode = "diagonal"
no_overlap = True
delta = 0.125

# NOTE(review): unlike the OOD dataset above, no `transform` is passed here --
# confirm this is intended.
train_diagonal_dataset = data.SpriteWorldDataset(
    n_samples,
    n_slots,
    default_cfg,
    sample_mode=sample_mode,
    no_overlap=no_overlap,
    delta=delta,
)

utils.dump_generated_dataset(train_diagonal_dataset, "your path to train diagonal")