In [10]:
import sys
import os

sys.path.append("..")

import numpy as np
import torch
import random
import ipywidgets as widgets

import matplotlib.pyplot as plt

from src.datasets import data, configs
from src.datasets.utils import PreGeneratedDataset

from torchvision import transforms as transforms

# One shared seed for every RNG used while generating data: Python's `random`,
# NumPy's legacy global RNG and PyTorch. Restart & Run All reproduces the datasets.
seed = 43
for _seed_rng in (random.seed, np.random.seed):
    _seed_rng(seed)
# Kept as the cell's last expression, so the returned Generator is echoed as before.
torch.manual_seed(seed)

<torch._C.Generator at 0x2232f097510>

In [2]:
# Filtering code: drop OOD samples whose objects overlap (lie closer together than a distance threshold)
def filter_objects(latents, max_objects=5000, threshold=0.2, sort=False, position_dims=2):
    """
    Keep only samples whose objects are pairwise well separated.

    A sample is kept when the smallest Euclidean distance between any two of its
    slots, measured on the first ``position_dims`` latent dimensions, is at least
    ``threshold``. A sample with a single slot has no pair and is always kept.

    Args:
        latents: Tensor of shape (batch_size, n_slots, n_latents). Only the first
            ``position_dims`` latents enter the distance; presumably these are the
            (x, y) object positions -- TODO confirm against the dataset's latent layout.
        max_objects: Maximum number of samples to keep. Surplus samples are dropped
            in their original order, before any sorting.
        threshold: Minimum allowed pairwise slot distance (inclusive).
        sort: If True, order the kept samples by ascending minimum pairwise
            distance (the same distance used for filtering), most crowded first.
        position_dims: Number of leading latent dimensions used for the distance.

    Returns:
        ``(filtered_latents, filtered_indices)``: a tensor holding the kept samples
        and a list of their indices into ``latents``. If no sample passes the
        threshold, ``(None, [])``.
    """
    _, n_slots, _ = latents.size()

    # Pairwise slot distances for every sample at once: (batch_size, n_slots, n_slots).
    positions = latents[:, :, :position_dims]
    distances = torch.cdist(positions, positions, p=2)
    # Ignore each slot's zero distance to itself.
    self_pairs = torch.eye(n_slots, dtype=torch.bool, device=latents.device)
    distances = distances.masked_fill(self_pairs, float("inf"))
    min_distances = distances.amin(dim=(1, 2))

    # A sample passes when its closest pair of slots is at least `threshold` apart.
    mask = min_distances >= threshold

    # If all objects are "close", print a message and return
    if not torch.any(mask):
        print("No objects were found that meet the distance threshold.")
        return None, []

    # Indices of passing samples in original order, truncated to `max_objects`.
    filtered_indices = mask.nonzero(as_tuple=True)[0][:max_objects]

    if sort:
        # Sort only the samples actually kept, so the permutation always matches
        # the truncated selection.
        order = torch.argsort(min_distances[filtered_indices])
        filtered_indices = filtered_indices[order]

    filtered_objects = latents[filtered_indices]
    return filtered_objects, filtered_indices.tolist()

In [3]:
# Step 1: Create an OOD dataset (sample_mode "off_diagonal", 2 slots per scene).
# Generation settings; every one of them is forwarded to SpriteWorldDataset below.
default_cfg = configs.SpriteWorldConfig()
n_samples = 10000
n_slots = 2
sample_mode = "off_diagonal"
no_overlap = True
delta = 0.125

# The transform is a ToPILImage -> ToTensor round trip (presumably applied to each
# generated image -- confirm in SpriteWorldDataset).
off_diagonal_dataset = data.SpriteWorldDataset(
    n_samples, n_slots, default_cfg,
    sample_mode=sample_mode,
    delta=delta,
    no_overlap=no_overlap,
    transform=transforms.Compose(
        [transforms.ToPILImage(), transforms.ToTensor()]
    ),
)

Generating images (sampling: off_diagonal): 100%|███████████████████████████████| 10000/10000 [01:25<00:00, 117.23it/s]


In [4]:
# Checking the generated dataset
def display_data(index, dataset):
    """Plot the last image of ``dataset[index][0]``.

    Assumes ``dataset[index][0]`` is a stack of channels-first (C, H, W) image
    tensors -- TODO confirm the item layout of the datasets passed in.
    """
    sample_images = dataset[index][0]
    # matplotlib wants channels-last (H, W, C), so move the channel axis to the end.
    image_hwc = sample_images[-1].permute(1, 2, 0)
    plt.imshow(image_hwc)
    plt.show()


num_samples = len(off_diagonal_dataset)


def dispaly_off_diagonal(index):
    """Slider callback: show sample ``index`` of ``off_diagonal_dataset``."""
    return display_data(index, dataset=off_diagonal_dataset)


# Interactive slider spanning every generated sample.
widgets.interact(
    dispaly_off_diagonal,
    index=widgets.IntSlider(value=0, min=0, max=num_samples - 1, step=1),
)

In [6]:
# Step 2: Filter the dataset and save the kept samples
n_objects = 5000
_, indicies = filter_objects(
    off_diagonal_dataset.z, max_objects=n_objects, threshold=0.2
)
# Refuse to write an empty dataset when nothing passed the distance threshold.
if not indicies:
    raise ValueError("filter_objects kept no samples; nothing to save.")

# save the filtered dataset
# Absolute path (leading "/"): the reload cell reads "D:/mnt/qb/...", whereas the old
# relative "mnt/qb/..." resolved under the notebook's working directory. On Windows a
# rooted path resolves against the current drive -- confirm it is D:.
path = "/mnt/qb/work/bethge/apanfilov27/object_centric_consistency_project/dsprites/test/new_ood/2_objects/"
images_dir = os.path.join(path, "images")
latents_dir = os.path.join(path, "latents")
# torch.save does not create missing directories.
os.makedirs(images_dir, exist_ok=True)
os.makedirs(latents_dir, exist_ok=True)

torch.save(off_diagonal_dataset.x[indicies], os.path.join(images_dir, "images.pt"))
# Latent column 4 and the last two columns are dropped before saving.
# NOTE(review): presumably they are not consumed by PreGeneratedDataset -- confirm.
torch.save(
    torch.cat(
        [
            off_diagonal_dataset.z[indicies, :, :4],
            off_diagonal_dataset.z[indicies, :, 5:-2],
        ],
        dim=-1,
    ),
    os.path.join(latents_dir, "latents.pt"),
)

In [13]:
# Reload the filtered 2-object dataset from disk and browse it
# NOTE(review): hardcoded absolute Windows path; consider a configurable data root.
no_overlaps_ood = PreGeneratedDataset(
    "D:/mnt/qb/work/bethge/apanfilov27/object_centric_consistency_project/dsprites/test/new_ood/2_objects"
)

dispaly_no_overlaps_ood = lambda index: display_data(index, dataset=no_overlaps_ood)

# slider
# Bounded by the number of samples actually stored. `n_objects` is only the requested
# maximum (fewer samples may pass the filter) and is rebound by later cells.
# Assumes PreGeneratedDataset implements __len__ -- TODO confirm.
widgets.interact(
    dispaly_no_overlaps_ood,
    index=widgets.IntSlider(min=0, max=len(no_overlaps_ood) - 1, step=1, value=0),
)

In [None]:
# Generate OOD datasets with 2, 3 and 4 objects and save a filtered subset of each
n_samples = 10000
default_cfg = configs.SpriteWorldConfig()
sample_mode = "off_diagonal"
# NOTE(review): the 2-object cell above uses no_overlap=True; here overlap removal relies
# on filter_objects alone -- confirm this difference is intended.
no_overlap = False
delta = 0.125
# Samples kept per object count (3 * 1666 = 4998 in total across the three files).
n_objects = 1666
# Absolute path (leading "/"); the previous relative "mnt/qb/..." resolved under the
# notebook's working directory instead of the shared data location.
path = "/mnt/qb/work/bethge/apanfilov27/object_centric_consistency_project/dsprites/test/new_ood/mixed"
# torch.save does not create missing directories.
os.makedirs(os.path.join(path, "images"), exist_ok=True)
os.makedirs(os.path.join(path, "latents"), exist_ok=True)

# NOTE(review): this loop rebinds `off_diagonal_dataset` and `n_slots`; re-running earlier
# cells afterwards will operate on the last (4-object) dataset.
for n_slots in [2, 3, 4]:
    off_diagonal_dataset = data.SpriteWorldDataset(
        n_samples,
        n_slots,
        default_cfg,
        sample_mode=sample_mode,
        no_overlap=no_overlap,
        delta=delta,
        transform=transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()]),
    )
    _, indicies = filter_objects(
        off_diagonal_dataset.z, max_objects=n_objects, threshold=0.2
    )
    # Refuse to write an empty file when nothing passed the distance threshold.
    if not indicies:
        raise ValueError(f"filter_objects kept no samples for n_slots={n_slots}.")

    torch.save(
        off_diagonal_dataset.x[indicies],
        os.path.join(path, "images", f"images_{n_slots}.pt"),
    )
    # Latent column 4 and the last two columns are dropped, as for the 2-object dataset.
    torch.save(
        torch.cat(
            [
                off_diagonal_dataset.z[indicies, :, :4],
                off_diagonal_dataset.z[indicies, :, 5:-2],
            ],
            dim=-1,
        ),
        os.path.join(path, "latents", f"latents_{n_slots}.pt"),
    )