In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import os, time

import PIL
import tensorflow_datasets as tfds
from matplotlib.gridspec import GridSpec
import tensorflow_hub as hub

import utils

In [2]:
#@title load dataset
def load_dataset(dataset_name, data_dir='Data/'):
  ####################################################################################################################
  if dataset_name in ['dsprites', 'cifar10', 'shapes3d', 'mnist', 'fashion_mnist', 'plant_village']:
    data_rescaling_factor = 1.

    if dataset_name == 'dsprites':
      data_rescaling_factor = 255.

    dset_loaded, dset_info = tfds.load(dataset_name, data_dir=os.path.join(data_dir, 'tensorflow_datasets'), with_info=True, decoders={
        'image': tfds.decode.SkipDecoding(),
    })

    dset = dset_loaded['train']
    if dataset_name == 'cifar10':  ## I didn't use the test set pre-cifar
      dset = dset.concatenate(dset_loaded['test'])
    dset = dset.map(lambda example: example['image'])

    # dset = dset.shuffle(1_000_000)

    dset = dset.map(
        lambda image: dset_info.features['image'].decode_example(image))
    dset = dset.map(lambda image: tf.image.convert_image_dtype(image, tf.float32)*data_rescaling_factor)

    return dset
  ####################################################################################################################
  elif dataset_name == 'smallnorb':
    SMALLNORB_TEMPLATE = os.path.join(data_dir,
        "disentanglement_lib", "small_norb",
        "smallnorb-{}-{}.mat")

    SMALLNORB_CHUNKS = [
        "5x46789x9x18x6x2x96x96-training",
        "5x01235x9x18x6x2x96x96-testing",
    ]

    def _load_small_norb_chunks(path_template, chunk_names):
      """Loads several chunks of the small norb data set for final use."""
      list_of_images, list_of_features = _load_chunks(path_template, chunk_names)
      features = np.concatenate(list_of_features, axis=0)
      features[:, 3] = features[:, 3] / 2  # azimuth values are 0, 2, 4, ..., 24
      return np.concatenate(list_of_images, axis=0), features


    def _load_chunks(path_template, chunk_names):
      """Loads several chunks of the small norb data set into lists."""
      list_of_images = []
      list_of_features = []
      for chunk_name in chunk_names:
        norb = _read_binary_matrix(path_template.format(chunk_name, "dat"))
        list_of_images.append(_resize_images(norb[:, 0]))
        norb_class = _read_binary_matrix(path_template.format(chunk_name, "cat"))
        norb_info = _read_binary_matrix(path_template.format(chunk_name, "info"))
        list_of_features.append(np.column_stack((norb_class, norb_info)))
      return list_of_images, list_of_features


    def _read_binary_matrix(filename):
      """Reads and returns binary formatted matrix stored in filename."""
      with tf.io.gfile.GFile(filename, "rb") as f:
        s = f.read()
        magic = int(np.frombuffer(s, "int32", 1))
        ndim = int(np.frombuffer(s, "int32", 1, 4))
        eff_dim = max(3, ndim)
        raw_dims = np.frombuffer(s, "int32", eff_dim, 8)
        dims = []
        for i in range(0, ndim):
          dims.append(raw_dims[i])

        dtype_map = {
            507333717: "int8",
            507333716: "int32",
            507333713: "float",
            507333715: "double"
        }
        data = np.frombuffer(s, dtype_map[magic], offset=8 + eff_dim * 4)
      data = data.reshape(tuple(dims))
      return data

    def _resize_images(integer_images):
      resized_images = np.zeros((integer_images.shape[0], 64, 64))
      for i in range(integer_images.shape[0]):
        image = PIL.Image.fromarray(integer_images[i, :, :])
        image = image.resize((64, 64), PIL.Image.ANTIALIAS)
        resized_images[i, :, :] = image
      return resized_images / 255.

    images, features = _load_small_norb_chunks(SMALLNORB_TEMPLATE,
                                                SMALLNORB_CHUNKS)
    factor_sizes = [5, 10, 9, 18, 6]
    # Instances are not part of the latent space.
    latent_factor_indices = [0, 2, 3, 4]
    num_total_factors = features.shape[1]
    np.random.shuffle(images)
    return tf.data.Dataset.from_tensor_slices(np.expand_dims(images, -1).astype(np.float32))

  ####################################################################################################################
  elif dataset_name == 'cars3d':
    import scipy.io as sio
    from sklearn.utils import extmath

    CARS3D_PATH = os.path.join(data_dir, "disentanglement_lib", "cars")
    """Cars3D data set.

    The data set was first used in the paper "Deep Visual Analogy-Making"
    (https://papers.nips.cc/paper/5845-deep-visual-analogy-making) and can be
    downloaded from http://www.scottreed.info/. The images are rescaled to 64x64.

    The ground-truth factors of variation are:
    0 - elevation (4 different values)
    1 - azimuth (24 different values)
    2 - object type (183 different values)
    """

    class StateSpaceAtomIndex(object):
      """Index mapping from features to positions of state space atoms."""

      def __init__(self, factor_sizes, features):
        """Creates the StateSpaceAtomIndex.

        Args:
          factor_sizes: List of integers with the number of distinct values for each
            of the factors.
          features: Numpy matrix where each row contains a different factor
            configuration. The matrix needs to cover the whole state space.
        """
        self.factor_sizes = factor_sizes
        num_total_atoms = np.prod(self.factor_sizes)
        self.factor_bases = num_total_atoms / np.cumprod(self.factor_sizes)
        feature_state_space_index = self._features_to_state_space_index(features)
        if np.unique(feature_state_space_index).size != num_total_atoms:
          raise ValueError("Features matrix does not cover the whole state space.")
        lookup_table = np.zeros(num_total_atoms, dtype=np.int64)
        lookup_table[feature_state_space_index] = np.arange(num_total_atoms)
        self.state_space_to_save_space_index = lookup_table

      def features_to_index(self, features):
        """Returns the indices in the input space for given factor configurations.

        Args:
          features: Numpy matrix where each row contains a different factor
            configuration for which the indices in the input space should be
            returned.
        """
        state_space_index = self._features_to_state_space_index(features)
        return self.state_space_to_save_space_index[state_space_index]

      def _features_to_state_space_index(self, features):
        """Returns the indices in the atom space for given factor configurations.

        Args:
          features: Numpy matrix where each row contains a different factor
            configuration for which the indices in the atom space should be
            returned.
        """
        if (np.any(features > np.expand_dims(self.factor_sizes, 0)) or
            np.any(features < 0)):
          raise ValueError("Feature indices have to be within [0, factor_size-1]!")
        return np.array(np.dot(features, self.factor_bases), dtype=np.int64)

    def _load_data():
      dataset = np.zeros((24 * 4 * 183, 64, 64, 3))
      all_files = [x for x in tf.io.gfile.listdir(CARS3D_PATH) if ".mat" in x]
      for i, filename in enumerate(all_files):
        data_mesh = _load_mesh(filename)
        factor1 = np.array(list(range(4)))
        factor2 = np.array(list(range(24)))
        all_factors = np.transpose([
            np.tile(factor1, len(factor2)),
            np.repeat(factor2, len(factor1)),
            np.tile(i,
                    len(factor1) * len(factor2))
        ])
        indexes = index.features_to_index(all_factors)
        dataset[indexes] = data_mesh
      return dataset


    def _load_mesh(filename):
      """Parses a single source file and rescales contained images."""
      with open(os.path.join(CARS3D_PATH, filename), "rb") as f:
        mesh = np.einsum("abcde->deabc", sio.loadmat(f)["im"])
      flattened_mesh = mesh.reshape((-1,) + mesh.shape[2:])
      rescaled_mesh = np.zeros((flattened_mesh.shape[0], 64, 64, 3))
      for i in range(flattened_mesh.shape[0]):
        pic = PIL.Image.fromarray(flattened_mesh[i, :, :, :])
        # pic.thumbnail((64, 64, 3), PIL.Image.ANTIALIAS)
        pic = pic.resize((64, 64), PIL.Image.ANTIALIAS)
        rescaled_mesh[i, :, :, :] = np.array(pic)
      return rescaled_mesh * 1. / 255


    factor_sizes = [4, 24, 183]

    latent_factor_indices = [0, 1, 2]

    features = extmath.cartesian(
            [np.array(list(range(i))) for i in factor_sizes])
    index = StateSpaceAtomIndex(factor_sizes, features)

    data_shape = [64, 64, 3]
    images = _load_data()

    np.random.shuffle(images)
    return tf.data.Dataset.from_tensor_slices(images.astype(np.float32))
  elif dataset_name == 'celebA':
      images = np.load(os.path.join(data_dir, 'celebA/data.npy'))
      dataset = tf.data.Dataset.from_tensor_slices(images)
      return dataset.map(lambda img: tf.image.convert_image_dtype(img, tf.float32))

In [7]:
model_start = 7750
model_end = 7755
number_models = model_end - model_start

number_bottleneck_channels = 10
monte_carlo_number_random_samples = 20_00
models_dir = 'trained_models/'
dataset_name = 'smallnorb'

dataset_sizes = {'dsprites': 737280,
                 'cars3d': 17568,
                 'smallnorb': 48600}

ct = time.time()
image_dataset = load_dataset(dataset_name, data_dir='Data/')
print(f'Loaded {dataset_name}, took {time.time()-ct:.3f} sec')

### Embed the full dataset
embs_mus_all, embs_logvars_all = [[], []]
ct = time.time()
for model_num in range(model_start, model_end):
  embed = hub.load(os.path.join(models_dir, str(model_num), 'model/tfhub'))

  image_chunk_size = 1000
  embs_mus, embs_logvars = [[], []]
  for image_chunk in image_dataset.batch(image_chunk_size):
    embs = embed.signatures['gaussian_encoder'](image_chunk)
    embs_mus.append(embs['mean'])
    embs_logvars.append(embs['logvar'])
  embs_mus_all.append(np.concatenate(embs_mus, 0))
  embs_logvars_all.append(np.concatenate(embs_logvars, 0))

print(f'Embedded full dataset (number instances: {embs_mus_all[0].shape[0]}) for {number_models} models.  Took {time.time()-ct:.3f} sec.')
### Now we have them, run through everything
single_infos, double_infos, combined_infos = [[], [], []]
nmis, vis = [[], []]
nmi_errs, vi_errs = [[], []]
ct = time.time()
for model_num1 in range(number_models):
  single_infos.append(
      utils.monte_carlo_info(embs_mus_all[model_num1],
                             embs_logvars_all[model_num1],
                             number_random_samples=monte_carlo_number_random_samples))
  double_infos.append(
      utils.monte_carlo_info(np.tile(embs_mus_all[model_num1], [1, 2]),
                             np.tile(embs_logvars_all[model_num1], [1, 2]),
                             number_random_samples=monte_carlo_number_random_samples))
  for model_num2 in range(model_num1+1, number_models):
    combined_infos.append(
        utils.monte_carlo_info(np.concatenate([embs_mus_all[model_num1], embs_mus_all[model_num2]], 1),
                               np.concatenate([embs_logvars_all[model_num1], embs_logvars_all[model_num2]], 1),
                               number_random_samples=monte_carlo_number_random_samples))
print(f'Computed infos {int(number_models*(number_models+3)/2)} times.  Took {time.time()-ct:.3f} sec.')
running_index = 0
for model_num1 in range(number_models):
  for model_num2 in range(model_num1+1, number_models):
    i1 = single_infos[model_num1][0]
    i2 = single_infos[model_num2][0]
    i11 = double_infos[model_num1][0]
    i22 = double_infos[model_num2][0]
    i12 = combined_infos[running_index][0]
    nmis.append((i1+i2-i12) / np.sqrt((2*i1-i11)*(2*i2-i22)))
    vis.append(2*i12 - i11 - i22)

    i11_err = double_infos[model_num1][1]
    i1_err = single_infos[model_num1][1]
    i22_err = double_infos[model_num2][1]
    i2_err = single_infos[model_num2][1]
    i12_err = combined_infos[running_index][1]
    vi_errs.append(np.sqrt(4*i12_err**2 - i11_err**2 - i22_err**2))

    partial11_sq = (i1+i2-i12)**2/4/(2*i1-i11)**3/(2*i2-i22)
    partial1_sq = (i1+i12-i11-i2)**2/(2*i1-i11)**3/(2*i2-i22)
    partial12_sq = 1/(2*i1-i11)/(2*i2-i22)
    partial22_sq = (i1+i2-i12)**2/4/(2*i2-i22)**3/(2*i1-i11)
    partial2_sq = (i2+i12-i22-i1)**2/(2*i2-i22)**3/(2*i1-i11)
    combined_nmi_err = np.sqrt(partial11_sq*i11_err*22 + partial1_sq*i1_err**2 + partial12_sq*i12_err**2 + partial22_sq*i22_err**2 + partial2_sq*i2_err**2)
    nmi_errs.append(combined_nmi_err)


    running_index += 1

nmis = np.array(nmis)
nmi_errs = np.array(nmi_errs)
vis = np.array(vis)
vi_errs = np.array(vi_errs)
single_infos_errs = np.array(single_infos)[:, 1]
single_infos = np.array(single_infos)[:, 0]

comb_info = np.sum(single_infos/single_infos_errs**2)/np.sum(1./single_infos_errs**2)
comb_info_err = 1/np.sqrt(np.sum(1./single_infos_errs**2))

comb_nmi = np.sum(nmis/nmi_errs**2)/np.sum(1./nmi_errs**2)
comb_nmi_err = 1/np.sqrt(np.sum(1./nmi_errs**2))

comb_vi = np.sum(vis/vi_errs**2)/np.sum(1./vi_errs**2)
comb_vi_err = 1/np.sqrt(np.sum(1./vi_errs**2))

print(f'I(U;X)/H(X): {comb_info/np.log2(dataset_sizes[dataset_name]):.3f} +- {comb_info_err/np.log2(dataset_sizes[dataset_name]):.3f} bits')
print(f'NMI: {comb_nmi:.3f} +- {comb_nmi_err:.3f}')
print(f'VI: {comb_vi:.3f} +- {comb_vi_err:.3f} bits')

Loaded smallnorb, took 9.365 sec
Embedded full dataset (number instances: 48600) for 5 models.  Took 5.194 sec.
Computed infos 20 times.  Took 12.526 sec.
I(U;X)/H(X): 0.753 +- 0.002 bits
NMI: 0.965 +- 0.018
VI: 0.731 +- 0.028 bits
