Copyright 2017 Google Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


# dSprites - Disentanglement testing Sprites dataset

## Description
Procedurally generated dataset of 2D shapes. It uses 6 latents, controlling the color, shape, scale, rotation and position of a sprite (color does not vary here; its value is fixed).

All possible combinations of the latents are present.

The ordering of images in the dataset (i.e. shape[0] in all ndarrays) is fixed and meaningful, see below.

We chose the smallest changes in latent values that generated different pixel outputs at our 64x64 resolution after rasterization.

No noise is added, and there is a single image sample for each latent setting.

## Details about the ordering of the dataset

The dataset was generated procedurally, and its order is deterministic.
For example, the image at index 0 corresponds to the latents (0, 0, 0, 0, 0, 0).

Then the image at index 1 increases the least significant "bit" of the latent:
(0, 0, 0, 0, 0, 1)

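Similarly, counting continues until index 32, where we get (0, 0, 0, 0, 1, 0): the last latent (posY) has 32 possible values, so index 32 is the first one that carries over into posX.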

Hence the dataset is sequentially addressable using variable bases for every "bit".
Using dataset['metadata']['latents_sizes'] makes this conversion trivial, see below.

In [16]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from matplotlib import pyplot as plt
import numpy as np
import seaborn as sns

# Change figure aesthetics: render matplotlib figures inline in the notebook and use
# seaborn's 'talk' context (fonts scaled by 1.2, line width 1.5) for every later plot.
%matplotlib inline
sns.set_context('talk', font_scale=1.2, rc={'lines.linewidth': 1.5})


In [17]:
# Load the dSprites archive and pull out its arrays plus the metadata dict.
# SECURITY NOTE: allow_pickle=True is needed because 'metadata' is stored as a pickled
# Python object; unpickling can execute code, so only load this file from a trusted source.
# encoding='latin1' controls how Python 2 strings inside the pickle are decoded
# (presumably the archive was written under Python 2 -- confirm if it is regenerated).
dataset_zip = np.load('dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz',
                      encoding='latin1', allow_pickle=True)
print('Keys in the dataset:', dataset_zip.keys())

imgs, latents_values, latents_classes = (
    dataset_zip[key] for key in ('imgs', 'latents_values', 'latents_classes'))
# [()] unwraps the object stored in the 'metadata' array.
metadata = dataset_zip['metadata'][()]

print('Metadata: \n', metadata)


Keys in the dataset: KeysView(NpzFile 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz' with keys: metadata, imgs, latents_classes, latents_values)
Metadata: 
 {'date': 'April 2017', 'description': 'Disentanglement test Sprites dataset.Procedurally generated 2D shapes, from 6 disentangled latent factors.This dataset uses 6 latents, controlling the color, shape, scale, rotation and position of a sprite. All possible variations of the latents are present. Ordering along dimension 1 is fixed and can be mapped back to the exact latent values that generated that image.We made sure that the pixel outputs are different. No noise added.', 'version': 1, 'latents_names': ('color', 'shape', 'scale', 'orientation', 'posX', 'posY'), 'latents_possible_values': {'orientation': array([0.        , 0.16110732, 0.32221463, 0.48332195, 0.64442926,
       0.80553658, 0.96664389, 1.12775121, 1.28885852, 1.44996584,
       1.61107316, 1.77218047, 1.93328779, 2.0943951 , 2.25550242,
       2.41660973, 2.577717

In [18]:
# Number of classes per latent, and each latent's place value in the dataset's
# mixed-radix ordering (last latent varies fastest): base[i] = prod(latents_sizes[i+1:]).
latents_sizes = metadata['latents_sizes']
latents_bases = np.append(np.cumprod(latents_sizes[::-1])[::-1][1:], 1)

def latent_to_index(latents, bases=None):
  """Map latent class vectors to flat positions in the dataset arrays.

  The dataset is ordered like a mixed-radix number (last latent varies fastest),
  so the flat index is the dot product of the class vector with per-latent bases.

  Args:
    latents: latent class indices, shape (n_latents,) or (n_samples, n_latents).
      Float input with integer values (as returned by sample_latent) is accepted.
    bases: optional per-latent place values. Defaults to the module-level
      `latents_bases` computed from the dataset metadata.

  Returns:
    Integer index, or integer array of indices, into `imgs` and the latent arrays.
  """
  if bases is None:
    bases = latents_bases
  return np.dot(latents, bases).astype(int)


def sample_latent(size=1, sizes=None):
  """Draw `size` latent class vectors uniformly at random.

  Args:
    size: number of samples to draw.
    sizes: optional number of classes per latent (array-like). Defaults to the
      module-level `latents_sizes` from the dataset metadata.

  Returns:
    float64 array of shape (size, len(sizes)). Column i holds integer-valued classes
    in [0, sizes[i]). NumPy's global RNG is used, so np.random.seed makes it repeatable.
  """
  if sizes is None:
    sizes = latents_sizes
  sizes = np.asarray(sizes)
  samples = np.zeros((size, sizes.size))
  for lat_i, lat_size in enumerate(sizes):
    samples[:, lat_i] = np.random.randint(lat_size, size=size)

  return samples


In [19]:
# Helper function to show images
def show_images_grid(imgs_, num_images=25):
  """Plot the first `num_images` entries of `imgs_` on a near-square grid.

  Args:
    imgs_: indexable collection of 2D images that supports len().
    num_images: number of images to draw; clamped to len(imgs_). Nothing is
      drawn when this ends up <= 0.

  Grid cells beyond `num_images` are hidden. Returns None.
  """
  num_images = min(num_images, len(imgs_))
  if num_images <= 0:
    return
  ncols = int(np.ceil(num_images ** 0.5))
  nrows = int(np.ceil(num_images / ncols))
  # plt.subplots takes (nrows, ncols) in that order, and figsize is (width, height),
  # so width scales with ncols. squeeze=False keeps `axes` 2D even for a 1x1 grid,
  # so .flatten() also works when num_images == 1.
  _, axes = plt.subplots(nrows, ncols, figsize=(ncols * 3, nrows * 3), squeeze=False)

  for ax_i, ax in enumerate(axes.flatten()):
    if ax_i < num_images:
      ax.imshow(imgs_[ax_i], cmap='Greys_r', interpolation='nearest')
      ax.set_xticks([])
      ax.set_yticks([])
    else:
      ax.axis('off')

def show_density(imgs):
  """Show the per-pixel mean of `imgs` (averaged over axis 0) as a grayscale image."""
  _, ax = plt.subplots()
  ax.imshow(imgs.mean(axis=0), interpolation='nearest', cmap='Greys_r')
  # grid() expects a bool. The string 'off' is truthy, and current Matplotlib treats it
  # as visible=True, which would draw the seaborn grid over the image.
  ax.grid(False)
  ax.set_xticks([])
  ax.set_yticks([])

('color', 'shape', 'scale', 'orientation', 'posX', 'posY')

In [36]:
## Sample latents, then binarize scale / posY / posX into two coarse bins each
# Seed BEFORE sampling so the selected subset is reproducible under Restart & Run All
# (the shuffle cell further down only seeds after this sampling has happened).
np.random.seed(0)
latents_sampled = sample_latent(size=10000)
# Image indices must come from the original latent classes, i.e. before the
# in-place binarization below overwrites columns 2, 4 and 5.
indices_sampled = latent_to_index(latents_sampled)
imgs_sampled = imgs[indices_sampled]
latents_sampled[:, 2] = (latents_sampled[:, 2] >= 3.0).astype(np.float32)   # scale: class >= 3 -> 1
latents_sampled[:, 5] = (latents_sampled[:, 5] >= 16.0).astype(np.float32)  # posY: class >= 16 -> 1
latents_sampled[:, 4] = (latents_sampled[:, 4] >= 16.0).astype(np.float32)  # posX: class >= 16 -> 1

# One-hot concept matrix with 12 columns:
#   0-2 shape | 3-4 size | 5-6 posY | 7-8 posX | 9-11 color
# Color gets 3 one-hot slots even though the sampled color latent is always class 0.
c_dx = np.concatenate([
    (np.arange(3) == latents_sampled[:, 1][:, None]).astype(np.float32),  # shape
    (np.arange(2) == latents_sampled[:, 2][:, None]).astype(np.float32),  # size (binarized scale)
    (np.arange(2) == latents_sampled[:, 5][:, None]).astype(np.float32),  # posY
    (np.arange(2) == latents_sampled[:, 4][:, None]).astype(np.float32),  # posX
    (np.arange(3) == latents_sampled[:, 0][:, None]).astype(np.float32),  # color
], axis=1)

In [37]:
# Peek at the first row of the one-hot concept matrix
# (columns: 0-2 shape | 3-4 size | 5-6 posY | 7-8 posX | 9-11 color).
c_dx[0]

array([0., 0., 1., 0., 1., 0., 1., 1., 0., 1., 0., 0.], dtype=float32)

In [38]:
# Keep four (shape, size, posX) groups of rows from the one-hot concept matrix, whose
# columns are: 0-2 shape | 3-4 size | 5-6 posY | 7-8 posX | 9-11 color.
# NOTE(review): the *_pos0 / *_pos1 suffixes do not always match the posX value
# selected -- see the per-line comments for what each mask really keeps.
# NOTE: this cell replaces c_dx with the selected rows. Re-running it on its own raises
# an IndexError (mask length != len(latents_sampled)); re-run the sampling cell first.
filter_shape0_pos0 = (c_dx[:, 0] == 1) & (c_dx[:, 3] == 1) & (c_dx[:, 7] == 1)  # shape 0, size 0, posX 0
filter_shape0_pos1 = (c_dx[:, 0] == 1) & (c_dx[:, 3] == 1) & (c_dx[:, 8] == 1)  # shape 0, size 0, posX 1
filter_shape1_pos0 = (c_dx[:, 1] == 1) & (c_dx[:, 4] == 1) & (c_dx[:, 8] == 1)  # shape 1, size 1, posX 1
filter_shape1_pos1 = (c_dx[:, 1] == 1) & (c_dx[:, 3] == 1) & (c_dx[:, 7] == 1)  # shape 1, size 0, posX 0

# Apply each mask to the three row-aligned arrays (latents, concepts, images).
latents_sampled_0, c_dx_0, imgs_sampled_0 = (
    arr[filter_shape0_pos0] for arr in (latents_sampled, c_dx, imgs_sampled))
latents_sampled_1, c_dx_1, imgs_sampled_1 = (
    arr[filter_shape0_pos1] for arr in (latents_sampled, c_dx, imgs_sampled))
latents_sampled_2, c_dx_2, imgs_sampled_2 = (
    arr[filter_shape1_pos0] for arr in (latents_sampled, c_dx, imgs_sampled))
latents_sampled_3, c_dx_3, imgs_sampled_3 = (
    arr[filter_shape1_pos1] for arr in (latents_sampled, c_dx, imgs_sampled))

# Stack the groups back together in order 0, 1, 2, 3.
latent_sample = np.concatenate((latents_sampled_0, latents_sampled_1,
                                latents_sampled_2, latents_sampled_3))
c_dx = np.concatenate((c_dx_0, c_dx_1, c_dx_2, c_dx_3))
imgs_sample = np.concatenate((imgs_sampled_0, imgs_sampled_1,
                              imgs_sampled_2, imgs_sampled_3))


In [39]:
# Number of sampled rows in the filter_shape1_pos1 group (shape 1, size 0, posX 0).
filter_shape1_pos1.sum()

861

In [40]:
# Reassign color (the last 3 columns of c_dx) from (shape, posY):
# only shape 1 with posY 1 becomes color 1; every other pair becomes color 0.
filter_pos0_shape0 = (c_dx[:, 0] == 1) & (c_dx[:, 5] == 1)   # shape 0, posY 0
c_dx[filter_pos0_shape0, -3:] = [1, 0, 0]                      # -> color 0

filter_pos1_shape0 = (c_dx[:, 0] == 1) & (c_dx[:, 6] == 1)   # shape 0, posY 1
c_dx[filter_pos1_shape0, -3:] = [1, 0, 0]                      # -> color 0

filter_pos0_shape1 = (c_dx[:, 1] == 1) & (c_dx[:, 5] == 1)   # shape 1, posY 0
c_dx[filter_pos0_shape1, -3:] = [1, 0, 0]                      # -> color 0

filter_pos1_shape1 = (c_dx[:, 1] == 1) & (c_dx[:, 6] == 1)   # shape 1, posY 1
c_dx[filter_pos1_shape1, -3:] = [0, 1, 0]                      # -> color 1


# Labels: rows with (size 1, color 1) get label 1; the other size/color combinations
# are explicitly set to 0 (y already starts at 0).
# NOTE(review): "col1"/"col2" in these names mean color 0/color 1, and the
# size suffix does not always match the size tested -- trust the comments.
y = np.zeros(c_dx.shape[0])

filter_col1_size0 = (c_dx[:, 3] == 1) & (c_dx[:, 9] == 1)    # size 0, color 0
y[filter_col1_size0] = 0                                      # -> label 0

filter_col2_size0 = (c_dx[:, 4] == 1) & (c_dx[:, 10] == 1)   # size 1, color 1
y[filter_col2_size0] = 1                                      # -> label 1

filter_col1_size1 = (c_dx[:, 4] == 1) & (c_dx[:, 9] == 1)    # size 1, color 0
y[filter_col1_size1] = 0                                      # -> label 0

filter_col2_size1 = (c_dx[:, 3] == 1) & (c_dx[:, 10] == 1)   # size 0, color 1
y[filter_col2_size1] = 0                                      # -> label 0

In [41]:
# First row of the filtered concept matrix, after the color columns were reassigned.
c_dx[0]

array([1., 0., 0., 1., 0., 1., 0., 1., 0., 1., 0., 0.], dtype=float32)

In [42]:
# Drop the two concept columns expected to be constant zero at this point:
#   column 2  = shape 2 (the row filters only keep shapes 0 and 1)
#   column 11 = color 2 (only colors 0 and 1 are ever assigned)
# The indices are hard-coded, so verify the assumption instead of dropping data silently.
assert not c_dx[:, [2, 11]].any(), 'expected concept columns 2 and 11 to be all zero'
c_dx = c_dx[:, [0, 1, 3, 4, 5, 6, 7, 8, 9, 10]]

In [43]:
# Preview the column positions kept by the next cell. After dropping columns 2 and 11,
# the concepts come in (value 0, value 1) pairs, so odd positions are the "value 1" columns.
list(range(1, 10, 2))

[1, 3, 5, 7, 9]

In [44]:
# Keep one indicator per binary concept: positions 1, 3, 5, 7, 9 are the "value 1"
# column of each (value 0, value 1) pair -> shape 1, size 1, posY 1, posX 1, color 1.
c_dx = c_dx[:, [1, 3, 5, 7, 9]]

In [45]:
# Final 5-column concept matrix: shape 1, size 1, posY 1, posX 1, color 1 indicators.
c_dx

array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       ...,
       [1., 0., 0., 0., 0.],
       [1., 0., 1., 0., 1.],
       [1., 0., 0., 0., 0.]], dtype=float32)

In [46]:
# Append the label as a final column, then count every distinct (concepts..., label) row.
c_y = np.column_stack((c_dx, y))
np.unique(c_y, axis=0, return_counts=True)

(array([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0.],
        [0., 0., 1., 0., 0., 0.],
        [0., 0., 1., 1., 0., 0.],
        [1., 0., 0., 0., 0., 0.],
        [1., 0., 1., 0., 1., 0.],
        [1., 1., 0., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1.]]),
 array([422, 421, 399, 402, 433, 428, 398, 428]))

In [14]:
# Shuffle all row-aligned arrays with one shared permutation, then split 80% train / 20% test.
# np.random.permutation(n) draws the same permutation as shuffling np.arange(n) with
# the same seed.
# NOTE: re-running this cell reshuffles arrays that are already shuffled, which gives a
# different split -- re-run the cells above first.
np.random.seed(0)
indices = np.random.permutation(y.shape[0])
c_dx, latent_sample, imgs_sample, y = (
    arr[indices] for arr in (c_dx, latent_sample, imgs_sample, y))

# First 80% of rows go to train, the rest to test (same cut for every array).
c_train, c_test = np.split(c_dx, [int(0.8 * c_dx.shape[0])])
latent_sample_train, latent_sample_test = np.split(
    latent_sample, [int(0.8 * latent_sample.shape[0])])
train_set_imgs, test_set_imgs = np.split(imgs_sample, [int(0.8 * imgs_sample.shape[0])])
y_train, y_test = np.split(y, [int(0.8 * y.shape[0])])


In [15]:
import os

# Write each split to its own .npy file under ./datasets/dsprites (created if missing).
save_dir = './datasets/dsprites'
os.makedirs(save_dir, exist_ok=True)

(train_images_file, test_images_file,
 train_labels_file, test_labels_file,
 train_concepts_file, test_concepts_file) = (
    os.path.join(save_dir, file_name)
    for file_name in ('train_images.npy', 'test_images.npy',
                      'train_labels.npy', 'test_labels.npy',
                      'train_concepts.npy', 'test_concepts.npy'))

# Images, then labels, then concepts; train before test within each pair.
np.save(train_images_file, train_set_imgs)
np.save(test_images_file, test_set_imgs)
np.save(train_labels_file, y_train)
np.save(test_labels_file, y_test)
np.save(train_concepts_file, c_train)
np.save(test_concepts_file, c_test)

In [159]:
# Inspect the training images that the save cell writes to train_images.npy.
train_set_imgs

array([[[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       ...,

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]],

       [[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 