In [3]:
import sys

sys.path.append("..")

import numpy as np
import torch
import random
from scipy.spatial import distance
from tqdm import tqdm
import ipywidgets as widgets

import matplotlib.pyplot as plt

from src.datasets import data, utils, configs
from src.datasets.utils import dump_generated_dataset
from src.metrics import hungarian_slots_loss
from src.utils.training_utils import sample_z_from_latents


from torchvision import transforms as transforms

import imageio

seed = 43
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x187c80db510>

In [15]:
# Step 1: Create a diagonal dataset, to get valid latents
n_samples = 10000
n_slots = 2
default_cfg = configs.SpriteWorldConfig()
sample_mode = "diagonal"
no_overlap = True
delta = 0.125

diagonal_dataset = data.SpriteWorldDataset(
    n_samples,
    n_slots,
    default_cfg,
    sample_mode=sample_mode,
    no_overlap=no_overlap,
    delta=delta,
    transform=transforms.Compose(
    [transforms.ToPILImage(), transforms.ToTensor()]),
#     z=new_z
)

Generating images (sampling: diagonal): 100%|███████████████████████████████████| 10000/10000 [01:27<00:00, 113.84it/s]


In [9]:
torch.save(diagonal_dataset.x, "D:/mnt/qb/work/bethge/apanfilov27/object_centric_consistency_project/dsprites/test/random/mixed/images/images_2.pt")

In [10]:
torch.save(torch.cat([diagonal_dataset.z[:, :, :4], diagonal_dataset.z[:, :, 5:-2]], dim=-1), "D:/mnt/qb/work/bethge/apanfilov27/object_centric_consistency_project/dsprites/test/random/mixed/latents/latents_2.pt")

In [14]:
def display_data(index, dataset):
    plt.imshow(dataset[index][0][-1].permute(1, 2, 0))
    plt.show()

In [15]:
dispaly_diagonal = lambda index: display_data(index, dataset=diagonal_dataset)

num_samples = len(diagonal_dataset)

# slider
widgets.interact(dispaly_diagonal, index=widgets.IntSlider(min=0, max=num_samples-1, step=1, value=0));


interactive(children=(IntSlider(value=0, description='index', max=33332), Output()), _dom_classes=('widget-int…

In [11]:
new_z, indicies = filter_objects(diagonal_dataset.z, max_objects=10000, threshold=0.2)

In [12]:
print(new_z.shape)

torch.Size([10000, 2, 8])


In [16]:
torch.save(diagonal_dataset.z, "D:/mnt/qb/heatmap_dataset/initial_id_latents.pt")

In [33]:
diagonal_dataset = data.SpriteWorldDataset(
    len(new_z),
    n_slots,
    default_cfg,
    sample_mode=sample_mode,
    no_overlap=no_overlap,
    delta=delta,
    transform=transforms.Compose(
    [transforms.ToPILImage(), transforms.ToTensor()]),
    z=new_z
)

Delta is too big for 'no_overlap' mode, setting it to 0.08333333333333333.


Generating images (sampling: diagonal): 100%|██████████████████████████████████████| 2673/2673 [00:38<00:00, 68.78it/s]


In [3]:
# Step 2: Create a off_diagonal dataset, to get valid latents
n_samples = 5000
n_slots = 2
default_cfg = configs.SpriteWorldConfig()
sample_mode = "off_diagonal"
no_overlap = False
delta = 0.125

off_diagonal_dataset = data.SpriteWorldDataset(
    n_samples,
    n_slots,
    default_cfg,
    sample_mode=sample_mode,
    no_overlap=no_overlap,
    delta=delta,
    transform=transforms.Compose(
    [transforms.ToPILImage(), transforms.ToTensor()])
)

# Extract ood_latents and replace their x and y coordinate by diagonal latens
no_overlap_z = off_diagonal_dataset.z.clone()
no_overlap_z[:, :, :1] = diagonal_dataset.z[:, :, :1]

Generating images (sampling: off_diagonal): 100%|█████████████████████████████████| 5000/5000 [00:44<00:00, 111.91it/s]


In [33]:
# Step 3: Make OOD no_overlap dataset

n_samples = 5000
n_slots = 2
default_cfg = configs.SpriteWorldConfig()
sample_mode = "off_diagonal"
no_overlap = False
delta = 0.125

no_overlap_off_diagonal_dataset = data.SpriteWorldDataset(
    n_samples,
    n_slots,
    default_cfg,
    sample_mode=sample_mode,
    no_overlap=no_overlap,
    delta=delta,
    transform=transforms.Compose(
    [transforms.ToPILImage(), transforms.ToTensor()]),
    z=no_overlap_z
)

Generating images (sampling: off_diagonal):  50%|████████████████▌                | 2501/5000 [00:23<00:23, 106.27it/s]


KeyboardInterrupt: 

In [4]:
# Step 2 - alternative:
# Generate off_diagonal dataset and reject the pairs with close x-coordinates

n_samples = 10000
n_slots = 2
default_cfg = configs.SpriteWorldConfig()
sample_mode = "off_diagonal"
no_overlap = False
delta = 0.125

off_diagonal_dataset = data.SpriteWorldDataset(
    n_samples,
    n_slots,
    default_cfg,
    sample_mode=sample_mode,
    no_overlap=no_overlap,
    delta=delta,
    transform=transforms.Compose(
    [transforms.ToPILImage(), transforms.ToTensor()])
)


Generating images (sampling: off_diagonal): 100%|███████████████████████████████| 10000/10000 [01:32<00:00, 107.55it/s]


In [10]:
# Step 2.1 filtering
no_overlap_z = filter_objects(off_diagonal_dataset.z, max_objects=5000, threshold=0.3)

Filtering objects:  75%|█████████████████████████████████████████▍             | 7524/10000 [00:00<00:00, 20246.88it/s]


In [11]:
# Step 3 - alternative, filter the latents and create new data

n_samples = 5000
n_slots = 2
default_cfg = configs.SpriteWorldConfig()
sample_mode = "off_diagonal"
no_overlap = False
delta = 0.125

no_overlap_off_diagonal_dataset = data.SpriteWorldDataset(
    n_samples,
    n_slots,
    default_cfg,
    sample_mode=sample_mode,
    no_overlap=no_overlap,
    delta=delta,
    transform=transforms.Compose(
    [transforms.ToPILImage(), transforms.ToTensor()]),
    z=no_overlap_z
)

Generating images (sampling: off_diagonal): 100%|█████████████████████████████████| 5000/5000 [00:38<00:00, 128.89it/s]


In [81]:

num_samples = len(no_overlap_off_diagonal_dataset)

# Convert images to numpy arrays and add them to the images list
images = []
for i in range(num_samples):
    img = no_overlap_off_diagonal_dataset[i][0][-1].permute(1, 2, 0).numpy()
    
    # Scaling the image data to [0, 255] and convert to uint8
    img = (img * 255).astype(np.uint8)
    
    images.append(img)

# Save images as a GIF
imageio.mimsave('output.gif', images)

In [2]:
def filter_objects(latents, max_objects=5000, threshold=0.3, sort=False):
    """
    Filter objects based on their Euclidean distance.
    Args:
        latents: Tensor of shape (batch_size, n_slots, n_latents)
        max_objects: Number of objects to keep at most
        threshold: Distance threshold
        sort: Whether to sort the objects by distance
    """
    N, slots, _ = latents.size()
    mask = torch.zeros(N, dtype=bool)

    # Compute Euclidean distance for each pair of slots in each item
    for n in range(N):
        slots_distances = torch.cdist(latents[n, :, :2], latents[n, :, :2], p=2)
        slots_distances.fill_diagonal_(float('inf'))  # Ignore distance to self

        # Consider an object as "close" if its minimal distance to any other object is below the threshold
        min_distance = slots_distances.min().item()
        if min_distance >= threshold:
            mask[n] = True

    # If all objects are "close", print a message and return
    if not torch.any(mask):
        print("No objects were found that meet the distance threshold.")
        return None, []

    # Apply the mask to the latents
    filtered_objects = latents[mask]
    filtered_indices = torch.arange(N)[mask]

    # If the number of filtered objects exceeds the maximum, truncate them
    if filtered_objects.size(0) > max_objects:
        filtered_objects = filtered_objects[:max_objects]
        filtered_indices = filtered_indices[:max_objects]

    if sort:
        # Sort the filtered objects by minimum distance to any other object
        min_distances = torch.zeros(mask.sum().item())
        for i, n in enumerate(torch.where(mask)[0]):
            slots_distances = torch.cdist(latents[n], latents[n], p=2)
            slots_distances.fill_diagonal_(float('inf'))
            min_distances[i] = slots_distances.min().item()

        indices = torch.argsort(min_distances)
        filtered_objects = filtered_objects[indices]
        filtered_indices = filtered_indices[indices]

    return filtered_objects, filtered_indices.tolist()



def sample_z_from_latents_no_overlap(
    gt_z, hat_z, gt_figures, hat_figures, device, n_samples=1024
):
    _, transposed_indices = hungarian_slots_loss(
        gt_figures.view(gt_figures.shape[0], gt_figures.shape[1], -1),
        hat_figures.view(hat_figures.shape[0], hat_figures.shape[1], -1),
        device=device,
    )

    transposed_indices = transposed_indices.to(device)

    hat_z_permuted = hat_z.gather(
        1,
        transposed_indices[:, :, 1].unsqueeze(-1).expand(-1, -1, hat_z.shape[-1]),
    )
    gt_z_flatten = gt_z.view(-1, gt_z.shape[2])
    z_sampled, indices = sample_z_from_latents(hat_z_permuted.detach(), n_samples=20000)

    # reshape z_flatten with indices
    z_flatten = gt_z_flatten[indices].reshape(-1, gt_z.shape[1], gt_z.shape[2])

    z_sampled, selected_pairs_indices = filter_objects(z_flatten, max_objects=n_samples, threshold=0.3)

    return z_sampled

In [29]:
z_sampled = sample_z_from_latents_no_overlap(diagonal_dataset.z, diagonal_dataset.z, diagonal_dataset.x[:, :-1, ...], diagonal_dataset.x[:, :-1, ...], "cpu", n_samples=5000)

In [30]:
z_sampled.shape

torch.Size([5000, 2, 8])

In [31]:
test = data.SpriteWorldDataset(
    len(z_sampled),
    n_slots,
    default_cfg,
    sample_mode="off_diagonal",
    no_overlap=True,
    delta=delta,
    transform=transforms.Compose(
    [transforms.ToPILImage(), transforms.ToTensor()]),
    z=z_sampled
)

Generating images (sampling: off_diagonal): 100%|█████████████████████████████████| 5000/5000 [00:49<00:00, 100.02it/s]


In [32]:
def display_data(index):
    plt.imshow(test[index][0][-1].permute(1, 2, 0))
    plt.show()

num_samples = len(test)

# slider
widgets.interact(display_data, index=widgets.IntSlider(min=0, max=num_samples-1, step=1, value=0));


interactive(children=(IntSlider(value=0, description='index', max=4999), Output()), _dom_classes=('widget-inte…

In [45]:
num_samples = len(z_sampled)

# Convert images to numpy arrays and add them to the images list
images = []
for i in range(num_samples):
    img = test[i][0][-1].permute(1, 2, 0).numpy()
    
    # Scaling the image data to [0, 255] and convert to uint8
    img = (img * 255).astype(np.uint8)
    
    images.append(img)

# Save images as a GIF
imageio.mimsave('output.gif', images)

In [3]:
dump_generated_dataset(diagonal_dataset, "D:/mnt/qb/work/bethge/apanfilov27/object_centric_consistency_project/dsprites/test/diagonal/4_objects")

5000it [00:00, 6555.41it/s]


In [9]:
path = "D:/mnt/qb/work/bethge/apanfilov27/object_centric_consistency_project/dsprites/train/diagonal/4_objects/latents/latents.pt"

In [40]:
latents_4 = torch.load("D:/mnt/qb/work/bethge/apanfilov27/object_centric_consistency_project/dsprites/test/off_diagonal/4_objects/latents/latents.pt")

In [41]:
torch.save(latents_4[:1666, ...], "D:/mnt/qb/work/bethge/apanfilov27/object_centric_consistency_project/dsprites/test/off_diagonal/mixed/latents/latents_4.pt")

In [4]:
import os

In [5]:
os.listdir('D:/mnt/qb/work/bethge/apanfilov27/object_centric_consistency_project/dsprites/train/diagonal/4_objects/latents')

['latens.pt']

In [11]:
os.path.exists(path)

False

In [2]:
import torch
from scipy.optimize import linear_sum_assignment
import numpy as np

In [31]:
a = torch.tensor([[[0.9041,  0.0196], [-0.3108, 0]], [[-0.4821,  1.059], [-0.4821,  1.059]]])
b = torch.tensor([[[-2.1763, -0.4713], [-0.6986,  1.3702]], [[-0.4821,  1.059], [-0.4821,  1.059]]])


pairwise_cost = torch.cdist(a, b, p=2).transpose(-1, -2)


In [32]:
pairwise_cost

tensor([[[3.1193, 1.9241],
         [2.0959, 1.4240]],

        [[0.0000, 0.0000],
         [0.0000, 0.0000]]])

In [33]:
indices = np.array(
        list(map(linear_sum_assignment, pairwise_cost))
    )

In [29]:
map(linear_sum_assignment, pairwise_cost)

<map at 0x1d93ae0f670>

In [34]:
cost = np.array([[np.inf, np.inf, 4], 
                 [np.inf, np.inf, np.inf], 
                 [np.inf, 4, np.inf]])
row_ind, col_ind = linear_sum_assignment(cost)
print(cost)
print(col_ind)

ValueError: cost matrix is infeasible

In [None]:
import torch

# Load tensors
tensor1 = torch.load('images_2.pt')
tensor2 = torch.load('images_3.pt')
tensor3 = torch.load('images_4.pt')

# Make sure the tensors are in the correct shape
assert tensor1.shape == (33333, 3, 3, 64, 64)
assert tensor2.shape == (33333, 4, 3, 64, 64)
assert tensor3.shape == (33333, 5, 3, 64, 64)

# Create tensors of zeros with the desired final shape
zeros_1 = torch.zeros((33333, 4, 3, 64, 64))
zeros_2 = torch.zeros((33333, 4, 3, 64, 64))
zeros_3 = torch.zeros((33333, 4, 3, 64, 64))

# Fill in the parts of the zeros tensor with the loaded tensors
zeros_1[:, :2, ] = tensor1
zeros_2[:, :3, :] = tensor2
zeros_3[:, :4, :] = tensor3

# Concatenate along the first dimension
concat_tensor = torch.cat((zeros_1, zeros_2, zeros_3), dim=0)

# Check the shape of the result
assert concat_tensor.shape == (99999, 4, 5)

# Save the new tensor
torch.save(concat_tensor, 'images.pt')
