In [1]:
%reload_ext autoreload
%autoreload 2

import sys
sys.path.append("..")

import os
import torch
from torchvision import transforms as transforms


from src.datasets import data, configs, utils
from src.utils import training_utils

# Seed through the project helper before any dataset is sampled, so repeated
# Restart-and-Run-All passes start from the same state.
# NOTE(review): assumes set_seed covers every RNG the dataset code draws from
# -- confirm in src/utils/training_utils.
training_utils.set_seed(42)

The dataset for the paper was generated in a similar way using the code below.
Note that `delta` in this code is the distance from the diagonal in the latent space,
whereas in the paper `delta` describes the width of the band around the diagonal,
meaning that `delta_in_paper = 2 * delta_in_code`.

### Setting the path to the data

In [2]:
# All generated splits live under "<root>/dsprites"; the root here is
# "data_path" -- replace it with your path.
SAVE_PATH = os.path.join("data_path", "dsprites")

# Create the dataset root up front (no-op when it already exists).
os.makedirs(SAVE_PATH, exist_ok=True)

For training it is necessary to have the dataset in the following structure:

```
data_path
└── dsprites
   ├── train
   │   └── diagonal
   │       ├── images
   │       │   └── images.pt
   │       └── latents
   │           └── latents.pt
   └── test
       ├── diagonal
       │   ├── images
       │   │   └── images.pt
       │   └── latents
       │       └── latents.pt
       └── no_overlap_off_diagonal
           ├── images
           │   └── images.pt
           └── latents
               └── latents.pt
```

### Generating in-diagonal training data

In [3]:
# Sampling settings for the diagonal training split.
sample_mode = "diagonal"
delta = 0.125
n_samples = 100
n_slots = 2
no_overlap = True

default_cfg = configs.SpriteWorldConfig()
train_diagonal_dataset = data.SpriteWorldDataset(
    n_samples,
    n_slots,
    default_cfg,
    delta=delta,
    no_overlap=no_overlap,
    sample_mode=sample_mode,
)

# Target folder: <SAVE_PATH>/train/<sample_mode>
TRAIN_PATH = os.path.join(SAVE_PATH, "train", sample_mode)
os.makedirs(TRAIN_PATH, exist_ok=True)

# Persist the generated split into that folder.
utils.dump_generated_dataset(train_diagonal_dataset, TRAIN_PATH)

Generating images (sampling: diagonal): 100%|██████████| 100/100 [00:00<00:00, 199.00it/s]
100it [00:00, 11343.93it/s]


### Generating in-diagonal test dataset

In [4]:
# Sampling settings for the diagonal test split (same values as the training cell).
sample_mode = "diagonal"
delta = 0.125
n_samples = 100
n_slots = 2
no_overlap = True

default_cfg = configs.SpriteWorldConfig()
test_diagonal_dataset = data.SpriteWorldDataset(
    n_samples,
    n_slots,
    default_cfg,
    delta=delta,
    no_overlap=no_overlap,
    sample_mode=sample_mode,
)

# Target folder: <SAVE_PATH>/test/<sample_mode>
TEST_PATH = os.path.join(SAVE_PATH, "test", sample_mode)
os.makedirs(TEST_PATH, exist_ok=True)

# Persist the generated split into that folder.
utils.dump_generated_dataset(test_diagonal_dataset, TEST_PATH)

Generating images (sampling: diagonal): 100%|██████████| 100/100 [00:00<00:00, 190.44it/s]
100it [00:00, 16881.20it/s]


### Generating off-diagonal test dataset

In [5]:
# Sampling settings for the off-diagonal test split.
sample_mode = "off_diagonal"
delta = 0.125
n_samples = 100
n_slots = 2
# Original author's note: this setting makes the split the opposite of the
# diagonal case, but it does not account for overlaps.
no_overlap = True

default_cfg = configs.SpriteWorldConfig()
off_diagonal_dataset = data.SpriteWorldDataset(
    n_samples,
    n_slots,
    default_cfg,
    delta=delta,
    no_overlap=no_overlap,
    sample_mode=sample_mode,
    # Convert each image to a PIL image and back to a tensor.
    transform=transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()]),
)

# Target folder: <SAVE_PATH>/test/<sample_mode>
TEST_PATH = os.path.join(SAVE_PATH, "test", sample_mode)
os.makedirs(TEST_PATH, exist_ok=True)

# Persist the generated split into that folder.
utils.dump_generated_dataset(off_diagonal_dataset, TEST_PATH)

Generating images (sampling: off_diagonal): 100%|██████████| 100/100 [00:00<00:00, 180.10it/s]
100it [00:00, 13297.52it/s]


### (Optional) Excluding overlapping OOD samples

For the paper, we excluded the overlapping OOD samples from the test dataset.
You might want to filter the dataset we created in the previous step, or create a new one
and save it under the "no_overlap_off_diagonal" directory.

In [6]:
# Build a fresh off-diagonal dataset, select a subset of its samples with
# utils.filter_objects, and save only that subset under
# <SAVE_PATH>/test/no_overlap_off_diagonal.
#
# The indices returned by the filter refer to `no_overlap_off_diagonal_dataset`,
# so the images and latents that get saved must be taken from that same dataset.
# Indexing a different, independently sampled dataset with them would store
# samples that were never checked by the filter.
default_cfg = configs.SpriteWorldConfig()

delta = 0.125
sample_mode = "off_diagonal"
n_slots = 2
n_samples = 100
# Original author's note: this setting makes the split the opposite of the
# diagonal case, but it does not account for overlaps (hence the filter below).
no_overlap = True

no_overlap_off_diagonal_dataset = data.SpriteWorldDataset(
    n_samples,
    n_slots,
    default_cfg,
    sample_mode=sample_mode,
    no_overlap=no_overlap,
    delta=delta,
    # Convert each image to a PIL image and back to a tensor.
    transform=transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()]),
)

# Upper bound on the number of samples kept by the filter.
n_objects = 50
# NOTE(review): `threshold` presumably is the overlap criterion used by
# filter_objects -- confirm its exact meaning in src/datasets/utils.
_, indicies = utils.filter_objects(
    no_overlap_off_diagonal_dataset.z, max_samples=n_objects, threshold=0.2
)

# making the folder for the dataset
TEST_PATH = os.path.join(SAVE_PATH, "test", "no_overlap_off_diagonal")
os.makedirs(TEST_PATH, exist_ok=True)
os.makedirs(os.path.join(TEST_PATH, "images"), exist_ok=True)
os.makedirs(os.path.join(TEST_PATH, "latents"), exist_ok=True)

# saving the filtered images of the dataset the indices were computed on
torch.save(
    no_overlap_off_diagonal_dataset.x[indicies],
    os.path.join(TEST_PATH, "images", "images.pt"),
)

# Saved latents keep columns 0-3 and 5 up to (excluding) the last two along
# the final axis, i.e. column 4 and the two trailing columns are dropped.
# NOTE(review): presumably mirrors the layout written by
# utils.dump_generated_dataset -- confirm.
torch.save(
    torch.cat(
        [
            no_overlap_off_diagonal_dataset.z[indicies, :, :4],
            no_overlap_off_diagonal_dataset.z[indicies, :, 5:-2],
        ],
        dim=-1,
    ),
    os.path.join(TEST_PATH, "latents", "latents.pt"),
)

Generating images (sampling: off_diagonal): 100%|██████████| 100/100 [00:00<00:00, 183.26it/s]
