In [10]:
import sys

sys.path.append("..")

import numpy as np
import torch
import random
from scipy.spatial import distance
from tqdm import tqdm
import ipywidgets as widgets

import matplotlib.pyplot as plt

from src.datasets import data, utils, configs
from src.datasets.utils import dump_generated_dataset, PreGeneratedDataset
from src.metrics import hungarian_slots_loss
from src.utils.training_utils import sample_z_from_latents


from torchvision import transforms as transforms

import imageio

seed = 43
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x2232f097510>

In [2]:
def filter_objects(latents, max_objects=5000, threshold=0.3, sort=False):
    """
    Filter objects based on their Euclidean distance.
    Args:
        latents: Tensor of shape (batch_size, n_slots, n_latents)
        max_objects: Number of objects to keep at most
        threshold: Distance threshold
        sort: Whether to sort the objects by distance
    """
    N, slots, _ = latents.size()
    mask = torch.zeros(N, dtype=bool)

    # Compute Euclidean distance for each pair of slots in each item
    for n in range(N):
        slots_distances = torch.cdist(latents[n, :, :2], latents[n, :, :2], p=2)
        slots_distances.fill_diagonal_(float('inf'))  # Ignore distance to self

        # Consider an object as "close" if its minimal distance to any other object is below the threshold
        min_distance = slots_distances.min().item()
        if min_distance >= threshold:
            mask[n] = True

    # If all objects are "close", print a message and return
    if not torch.any(mask):
        print("No objects were found that meet the distance threshold.")
        return None, []

    # Apply the mask to the latents
    filtered_objects = latents[mask]
    filtered_indices = torch.arange(N)[mask]

    # If the number of filtered objects exceeds the maximum, truncate them
    if filtered_objects.size(0) > max_objects:
        filtered_objects = filtered_objects[:max_objects]
        filtered_indices = filtered_indices[:max_objects]

    if sort:
        # Sort the filtered objects by minimum distance to any other object
        min_distances = torch.zeros(mask.sum().item())
        for i, n in enumerate(torch.where(mask)[0]):
            slots_distances = torch.cdist(latents[n], latents[n], p=2)
            slots_distances.fill_diagonal_(float('inf'))
            min_distances[i] = slots_distances.min().item()

        indices = torch.argsort(min_distances)
        filtered_objects = filtered_objects[indices]
        filtered_indices = filtered_indices[indices]

    return filtered_objects, filtered_indices.tolist()



In [3]:
# Step 1: Create a diagonal dataset, to get valid latents
n_samples = 10000
n_slots = 2
default_cfg = configs.SpriteWorldConfig()
sample_mode = "off_diagonal"
no_overlap = True
delta = 0.125

off_diagonal_dataset = data.SpriteWorldDataset(
    n_samples,
    n_slots,
    default_cfg,
    sample_mode=sample_mode,
    no_overlap=no_overlap,
    delta=delta,
    transform=transforms.Compose(
    [transforms.ToPILImage(), transforms.ToTensor()]),
)

Generating images (sampling: off_diagonal): 100%|███████████████████████████████| 10000/10000 [01:25<00:00, 117.23it/s]


In [4]:
def display_data(index, dataset):
    plt.imshow(dataset[index][0][-1].permute(1, 2, 0))
    plt.show()

In [5]:
dispaly_off_diagonal = lambda index: display_data(index, dataset=off_diagonal_dataset)

num_samples = len(off_diagonal_dataset)

# slider
widgets.interact(dispaly_off_diagonal, index=widgets.IntSlider(min=0, max=num_samples-1, step=1, value=0));


interactive(children=(IntSlider(value=0, description='index', max=9999), Output()), _dom_classes=('widget-inte…

In [6]:
_, indicies = filter_objects(off_diagonal_dataset.z, max_objects=5000, threshold=0.2)

In [7]:
torch.save(off_diagonal_dataset.x[indicies], "D:/mnt/qb/work/bethge/apanfilov27/object_centric_consistency_project/dsprites/test/new_ood/2_objects/images/images.pt")

In [12]:
torch.save(torch.cat([off_diagonal_dataset.z[indicies, :, :4], off_diagonal_dataset.z[indicies, :, 5:-2]], dim=-1), "D:/mnt/qb/work/bethge/apanfilov27/object_centric_consistency_project/dsprites/test/new_ood/2_objects/latents/latents.pt")

In [13]:
no_overlaps_ood = PreGeneratedDataset("D:/mnt/qb/work/bethge/apanfilov27/object_centric_consistency_project/dsprites/test/new_ood/2_objects")

In [14]:
dispaly_no_overlaps_ood = lambda index: display_data(index, dataset=no_overlaps_ood)

num_samples = len(no_overlaps_ood)

# slider
widgets.interact(dispaly_no_overlaps_ood, index=widgets.IntSlider(min=0, max=num_samples-1, step=1, value=0));


interactive(children=(IntSlider(value=0, description='index', max=4999), Output()), _dom_classes=('widget-inte…

In [9]:
torch.save(diagonal_dataset.x, "D:/mnt/qb/work/bethge/apanfilov27/object_centric_consistency_project/dsprites/test/random/mixed/images/images_2.pt")

In [10]:
torch.save(torch.cat([diagonal_dataset.z[:, :, :4], diagonal_dataset.z[:, :, 5:-2]], dim=-1), "D:/mnt/qb/work/bethge/apanfilov27/object_centric_consistency_project/dsprites/test/random/mixed/latents/latents_2.pt")

torch.Size([10000, 2, 8])


In [16]:
torch.save(diagonal_dataset.z, "D:/mnt/qb/heatmap_dataset/initial_id_latents.pt")

In [33]:
diagonal_dataset = data.SpriteWorldDataset(
    len(new_z),
    n_slots,
    default_cfg,
    sample_mode=sample_mode,
    no_overlap=no_overlap,
    delta=delta,
    transform=transforms.Compose(
    [transforms.ToPILImage(), transforms.ToTensor()]),
    z=new_z
)

Delta is too big for 'no_overlap' mode, setting it to 0.08333333333333333.


Generating images (sampling: diagonal): 100%|██████████████████████████████████████| 2673/2673 [00:38<00:00, 68.78it/s]


In [3]:
# Step 2: Create a off_diagonal dataset, to get valid latents
n_samples = 5000
n_slots = 2
default_cfg = configs.SpriteWorldConfig()
sample_mode = "off_diagonal"
no_overlap = False
delta = 0.125

off_diagonal_dataset = data.SpriteWorldDataset(
    n_samples,
    n_slots,
    default_cfg,
    sample_mode=sample_mode,
    no_overlap=no_overlap,
    delta=delta,
    transform=transforms.Compose(
    [transforms.ToPILImage(), transforms.ToTensor()])
)

# Extract ood_latents and replace their x and y coordinate by diagonal latens
no_overlap_z = off_diagonal_dataset.z.clone()
no_overlap_z[:, :, :1] = diagonal_dataset.z[:, :, :1]

Generating images (sampling: off_diagonal): 100%|█████████████████████████████████| 5000/5000 [00:44<00:00, 111.91it/s]


In [33]:
# Step 3: Make OOD no_overlap dataset

n_samples = 5000
n_slots = 2
default_cfg = configs.SpriteWorldConfig()
sample_mode = "off_diagonal"
no_overlap = False
delta = 0.125

no_overlap_off_diagonal_dataset = data.SpriteWorldDataset(
    n_samples,
    n_slots,
    default_cfg,
    sample_mode=sample_mode,
    no_overlap=no_overlap,
    delta=delta,
    transform=transforms.Compose(
    [transforms.ToPILImage(), transforms.ToTensor()]),
    z=no_overlap_z
)

Generating images (sampling: off_diagonal):  50%|████████████████▌                | 2501/5000 [00:23<00:23, 106.27it/s]


KeyboardInterrupt: 

In [4]:
# Step 2 - alternative:
# Generate off_diagonal dataset and reject the pairs with close x-coordinates

n_samples = 10000
n_slots = 2
default_cfg = configs.SpriteWorldConfig()
sample_mode = "off_diagonal"
no_overlap = False
delta = 0.125

off_diagonal_dataset = data.SpriteWorldDataset(
    n_samples,
    n_slots,
    default_cfg,
    sample_mode=sample_mode,
    no_overlap=no_overlap,
    delta=delta,
    transform=transforms.Compose(
    [transforms.ToPILImage(), transforms.ToTensor()])
)


Generating images (sampling: off_diagonal): 100%|███████████████████████████████| 10000/10000 [01:32<00:00, 107.55it/s]


In [10]:
# Step 2.1 filtering
no_overlap_z = filter_objects(off_diagonal_dataset.z, max_objects=5000, threshold=0.3)

Filtering objects:  75%|█████████████████████████████████████████▍             | 7524/10000 [00:00<00:00, 20246.88it/s]


In [11]:
# Step 3 - alternative, filter the latents and create new data

n_samples = 5000
n_slots = 2
default_cfg = configs.SpriteWorldConfig()
sample_mode = "off_diagonal"
no_overlap = False
delta = 0.125

no_overlap_off_diagonal_dataset = data.SpriteWorldDataset(
    n_samples,
    n_slots,
    default_cfg,
    sample_mode=sample_mode,
    no_overlap=no_overlap,
    delta=delta,
    transform=transforms.Compose(
    [transforms.ToPILImage(), transforms.ToTensor()]),
    z=no_overlap_z
)

Generating images (sampling: off_diagonal): 100%|█████████████████████████████████| 5000/5000 [00:38<00:00, 128.89it/s]


In [81]:

num_samples = len(no_overlap_off_diagonal_dataset)

# Convert images to numpy arrays and add them to the images list
images = []
for i in range(num_samples):
    img = no_overlap_off_diagonal_dataset[i][0][-1].permute(1, 2, 0).numpy()
    
    # Scaling the image data to [0, 255] and convert to uint8
    img = (img * 255).astype(np.uint8)
    
    images.append(img)

# Save images as a GIF
imageio.mimsave('output.gif', images)

In [31]:
test = data.SpriteWorldDataset(
    len(z_sampled),
    n_slots,
    default_cfg,
    sample_mode="off_diagonal",
    no_overlap=True,
    delta=delta,
    transform=transforms.Compose(
    [transforms.ToPILImage(), transforms.ToTensor()]),
    z=z_sampled
)

Generating images (sampling: off_diagonal): 100%|█████████████████████████████████| 5000/5000 [00:49<00:00, 100.02it/s]


In [32]:
def display_data(index):
    plt.imshow(test[index][0][-1].permute(1, 2, 0))
    plt.show()

num_samples = len(test)

# slider
widgets.interact(display_data, index=widgets.IntSlider(min=0, max=num_samples-1, step=1, value=0));


interactive(children=(IntSlider(value=0, description='index', max=4999), Output()), _dom_classes=('widget-inte…

In [45]:
num_samples = len(z_sampled)

# Convert images to numpy arrays and add them to the images list
images = []
for i in range(num_samples):
    img = test[i][0][-1].permute(1, 2, 0).numpy()
    
    # Scaling the image data to [0, 255] and convert to uint8
    img = (img * 255).astype(np.uint8)
    
    images.append(img)

# Save images as a GIF
imageio.mimsave('output.gif', images)

In [3]:
dump_generated_dataset(diagonal_dataset, "D:/mnt/qb/work/bethge/apanfilov27/object_centric_consistency_project/dsprites/test/diagonal/4_objects")

5000it [00:00, 6555.41it/s]


In [9]:
path = "D:/mnt/qb/work/bethge/apanfilov27/object_centric_consistency_project/dsprites/train/diagonal/4_objects/latents/latents.pt"

In [40]:
latents_4 = torch.load("D:/mnt/qb/work/bethge/apanfilov27/object_centric_consistency_project/dsprites/test/off_diagonal/4_objects/latents/latents.pt")

In [41]:
torch.save(latents_4[:1666, ...], "D:/mnt/qb/work/bethge/apanfilov27/object_centric_consistency_project/dsprites/test/off_diagonal/mixed/latents/latents_4.pt")

In [16]:
def sample_delta_diagonal_cube(
    n_samples: int, n_slots: int, n_latents: int, delta: float, oversampling: int = 100
) -> torch.Tensor:
    """
    Sample near the diagonal in latent space i.e. all distances from the diagonal are less than delta.

    Algorithm:
        1. Draw points on the diagonal of [0, 1)^(n_slots, n_latents) cube.
        2. For every latent draw uniformly noise from n_slots-dimensional ball. For drawing uniformly inside the ball we
            use the following theorem (http://compneuro.uwaterloo.ca/files/publications/voelker.2017.pdf):
            if point uniformly sampled from the (n+1)-sphere, then n-first coordinates are uniformly sampled from the n-ball.
        3. Project sampled inside-ball points to the hyperplane perpendicular to the diagonal and normalize them
            (this gives us points on (n_slots-2)-sphere embedded in n_slots-space).
        4. Get final points by adding the diagonal point to the projected points.
        5. Keep only points inside the [0, 1)^(n_slots, n_latents) cube.
    """
    _n = oversampling * n_samples
    z_out = torch.Tensor(0, n_slots, n_latents)
    while z_out.shape[0] < n_samples:
        # sample randomly on diagonal
        z_sampled = torch.repeat_interleave(
            torch.rand(_n, n_latents), n_slots, dim=0
        ).reshape(_n, n_slots, n_latents)

        # sample noise from n_slots-ball
        noise = torch.randn(_n, n_slots + 2, n_latents)
        noise = noise / torch.norm(noise, dim=1, keepdim=True)  # points on n-sphere
        noise = noise[:, :n_slots, :]  # remove two last points

        # project to hyperplane perpendicular to diagonal
        ort_vec = noise - z_sampled * (noise * z_sampled).sum(axis=1, keepdim=True) / (
            z_sampled * z_sampled
        ).sum(axis=1, keepdim=True)
        ort_vec /= torch.norm(ort_vec, p=2, dim=1, keepdim=True)

        # final step
        # why n - 1 here? because we sample
        # "radius" not in the original space, but in the embedded
        final = z_sampled + (
            ort_vec
            * torch.pow(torch.rand([_n, 1, n_latents]), 1 / (n_slots - 1))
            * delta
        )

        # only keep samples inside [0, 1]^{k×l}
        mask = ((final - 0.5).abs() <= 0.5).flatten(1).all(1)
        idx = mask.nonzero().squeeze(1)

        z_out = torch.cat([z_out, final[idx]])
    z_out = z_out[:n_samples]
    return z_out[:n_samples]


In [17]:
n = 10000
n_slots = 2
n_latents = 5
delta = 0.5  # in [0 .. sqrt(n_slots)/2]

# check that ID samples don't contain any OOD samples
z_ID = sample_delta_diagonal_cube(n, n_slots, n_latents, delta).numpy()

diag_unit = np.ones(n_slots) / np.sqrt(n_slots)
# calculate the projection onto the diagonal (and from there the distance) along the
# `slots`-dimension since one diagonal contains the i-th latent of each slot
z_diag_scalar_component = np.dot(z_ID.transpose(0, 2, 1), diag_unit)
z_diag_component = z_diag_scalar_component[:, None, :] * diag_unit[None, :, None]
z_orth_component = z - z_diag_component
z_orth_component_norm = np.linalg.norm(z_orth_component, axis=1)

mask_ID = np.all(z_orth_component_norm <= delta, axis=1)
mask_OOD_any = np.any(z_orth_component_norm > delta, axis=1)
# mask_OOD_all = np.all(z_orth_component_norm > delta, axis=1)

n_ID = mask_ID.sum()
n_OOD_any = mask_OOD_any.sum()
# n_OOD_all = mask_OOD_all.sum()

print(n_ID, n_OOD_any, n_ID + n_OOD_any)

NameError: name 'z' is not defined