Many bug fixes and simplifications to the codebase.
Bug fixes for LARS:
* When used as the proposal for a VAE, correctly use the current sample.
* The exponential moving average was double exponentiating.
Switch to TFDS for all datasets and compute the train mean on the fly.
Implement some convolutional networks and versions of the models.
Remove flat datasets; all methods should now handle the flattening themselves.
Fix caching bugs in datasets.py.
Bug fixes and TF updates for small_problems.py.

PiperOrigin-RevId: 275357997
gjtucker authored and Copybara-Service committed Oct 17, 2019
1 parent 4318424 commit 4850f1e
Showing 11 changed files with 742 additions and 539 deletions.
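
On the LARS exponential-moving-average fix listed above: "double exponentiating" describes applying an exponential transform to a statistic that is already stored in exponentiated (linear-space) form, so the transform lands twice. The sketch below is illustrative only and is not the LARS code from this commit; decay and value are hypothetical names.

    import tensorflow as tf  # TF1-style API, matching this codebase

    ema = tf.get_variable("ema", initializer=0., trainable=False)
    decay = 0.99
    value = tf.placeholder(tf.float32, [])  # hypothetical new observation
    # Correct update: blend the raw value into the running average once;
    # wrapping ema or value in tf.exp here would re-apply the transform.
    update_ema = tf.assign(ema, decay * ema + (1. - decay) * value)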
eim/datasets.py: 123 changes (55 additions, 68 deletions)
@@ -17,23 +17,15 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
-import os
-import numpy as np
 import tensorflow as tf
+import tensorflow_datasets as tfds
 import tensorflow_probability as tfp
 
 tfd = tfp.distributions
 flags = tf.flags
 
-flags.DEFINE_string("ROOT_PATH", "/tmp", "The root directory of datasets.")
+flags.DEFINE_string("data_dir", None, "Directory to store datasets.")
 FLAGS = flags.FLAGS
 
-ROOT_PATH = lambda: FLAGS.ROOT_PATH
-MNIST_PATH = "data/mnist"
-STATIC_BINARIZED_MNIST_PATH = "data/static_binarized_mnist"
-CELEBA_PATH = "data/celeba"
-FASHION_MNIST_PATH = "data/fashion_mnist"
 CELEBA_IMAGE_SIZE = 64
 

@@ -53,27 +45,47 @@ def get_nine_gaussians(batch_size, scale=0.1, spacing=1.0):
   return batch
 
 
-def get_mnist(split="train", data_dir=MNIST_PATH, shuffle_files=None):
-  """Get MNIST dataset."""
-  del shuffle_files  # Ignored
-  path = os.path.join(ROOT_PATH(), data_dir, split + ".npy")
-  with tf.io.gfile.GFile(path, "rb") as f:
-    np_ims = np.load(f)
-  # Always load the train mean, no matter what split.
-  mean_path = os.path.join(ROOT_PATH(), data_dir, "train_mean.npy")
-  with tf.io.gfile.GFile(mean_path, "rb") as f:
-    mean = np.load(f).astype(np.float32)
-  dataset = tf.data.Dataset.from_tensor_slices(np_ims)
-
-  mean *= 255.
-  dataset = dataset.map(lambda im: tf.to_float(im) * 255.)
+def compute_mean(dataset):
+  def _helper(aggregate, x):
+    total, n = aggregate
+    return total + x, n + 1
 
-  return dataset, mean
+  total, n = tfds.as_numpy(dataset.reduce((0., 0), _helper))
+  return tf.to_float(total / n)
 
 
-def get_static_mnist(split="train", shuffle_files=None):
-  return get_mnist(split, data_dir=STATIC_BINARIZED_MNIST_PATH,
-                   shuffle_files=shuffle_files)
+def get_mnist(split="train", shuffle_files=False):
+  """Get MNIST dataset."""
+  split_map = {
+      "train": "train",
+      "valid": "validation",
+      "test": "test",
+  }
+  datasets = dict(
+      zip(["train", "validation", "test"],
+          tfds.load(
+              "mnist:3.*.*",
+              split=["train[:50000]", "train[50000:]", "test"],
+              shuffle_files=shuffle_files,
+              data_dir=FLAGS.data_dir)))
+  preprocess = lambda x: tf.to_float(x["image"])
+  train_mean = compute_mean(datasets[split_map["train"]].map(preprocess))
+  return datasets[split_map[split]].map(preprocess), train_mean
+
+
+def get_static_mnist(split="train", shuffle_files=False):
+  """Get Static Binarized MNIST dataset."""
+  split_map = {
+      "train": "train",
+      "valid": "validation",
+      "test": "test",
+  }
+  preprocess = lambda x: tf.cast(x["image"], tf.float32) * 255.
+  datasets = tfds.load(name="binarized_mnist",
+                       shuffle_files=shuffle_files,
+                       data_dir=FLAGS.data_dir)
+  train_mean = compute_mean(datasets[split_map["train"]].map(preprocess))
+  return datasets[split_map[split]].map(preprocess), train_mean
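
The new compute_mean helper replaces the precomputed train_mean.npy files: it makes one streaming pass over the training split with tf.data's reduce, carrying a (running_sum, count) pair. A minimal self-contained sketch of the same idea on a toy in-memory dataset (TF 1.14+ graph mode assumed; the toy shapes are hypothetical, not this file's pipelines):

    import numpy as np
    import tensorflow as tf

    # Three fake 2x2 "images"; reduce() threads a (running_sum, count) state.
    images = np.arange(12, dtype=np.float32).reshape(3, 2, 2)
    dataset = tf.data.Dataset.from_tensor_slices(images)

    initial = (tf.zeros([2, 2]), tf.constant(0.))
    total, n = dataset.reduce(initial,
                              lambda agg, x: (agg[0] + x, agg[1] + 1.))
    mean = total / n  # elementwise mean image over the dataset

    with tf.Session() as sess:
      print(sess.run(mean))  # equals images.mean(axis=0)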


@@ -83,13 +95,10 @@ def get_celeba(split="train", shuffle_files=False):
       "valid": "validation",
       "test": "test",
   }
-  datasets = tfds.load("celeb_a:2.*.*", shuffle_files=shuffle_files)
-
-  mean_path = os.path.join(ROOT_PATH(), CELEBA_PATH, "train_mean.npy")
-  with tf.io.gfile.GFile(mean_path, "rb") as f:
-    train_mean = np.load(f).astype(np.float32)
-
-  def _preprocess(sample, crop_width=80, image_size=CELEBA_IMAGE_SIZE):
+  datasets = tfds.load("celeb_a:2.*.*",
+                       shuffle_files=shuffle_files,
+                       data_dir=FLAGS.data_dir)
+  def preprocess(sample, crop_width=80, image_size=CELEBA_IMAGE_SIZE):
     """Output images are in [0, 255]."""
     image_shape = sample["image"].shape
     crop_slices = [
@@ -99,9 +108,8 @@ def _preprocess(sample, crop_width=80, image_size=CELEBA_IMAGE_SIZE):
     image_resized = tf.image.resize_images(image_cropped, [image_size] * 2)
     x = tf.to_float(image_resized)
     return x
-
-  data = datasets[split_map[split]].map(_preprocess)
-  return data, train_mean
+  train_mean = compute_mean(datasets[split_map["train"]].map(preprocess))
+  return datasets[split_map[split]].map(preprocess), train_mean
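
The CelebA preprocess above center-crops to crop_width and resizes to CELEBA_IMAGE_SIZE; the slice arithmetic itself is elided by the diff. A hedged sketch of that standard crop-then-resize recipe (the helper name and the exact slice computation are assumptions, not the elided code):

    import tensorflow as tf

    def center_crop_and_resize(image, crop_width=80, image_size=64):
      """Center-crop to crop_width, then resize; values stay in [0, 255]."""
      h, w = image.shape.as_list()[:2]
      top, left = (h - crop_width) // 2, (w - crop_width) // 2
      cropped = image[top:top + crop_width, left:left + crop_width, :]
      return tf.to_float(tf.image.resize_images(cropped, [image_size] * 2))

    image = tf.zeros([218, 178, 3], dtype=tf.uint8)  # CelebA's native size
    out = center_crop_and_resize(image)  # float32, shape [64, 64, 3]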


@@ -111,17 +119,12 @@ def get_fashion_mnist(split="train", shuffle_files=False):
       "valid": "train",  # No validation set, so reuse train.
       "test": "test",
   }
-  dataset = (
-      tfds.load(name="fashion_mnist",
-                split=split_map[split],
-                shuffle_files=shuffle_files,
-               ).map(lambda x: tf.to_float(x["image"])))
-
-  train_mean_path = os.path.join(ROOT_PATH(), FASHION_MNIST_PATH,
-                                 "train_mean.npy")
-  with tf.io.gfile.GFile(train_mean_path, "rb") as f:
-    train_mean = np.load(f).astype(np.float32)
-  return dataset, train_mean
+  datasets = tfds.load("fashion_mnist",
+                       shuffle_files=shuffle_files,
+                       data_dir=FLAGS.data_dir)
+  preprocess = lambda x: tf.to_float(x["image"])
+  train_mean = compute_mean(datasets[split_map["train"]].map(preprocess))
+  return datasets[split_map[split]].map(preprocess), train_mean
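
All four loaders now share one shape: map a canonical split name through split_map, load with tfds.load(..., data_dir=FLAGS.data_dir), apply a small preprocess map, and derive train_mean via compute_mean (Fashion-MNIST reuses train as valid since TFDS ships no validation split for it). For MNIST the code instead carves a validation set out of train with TFDS split slicing; a standalone sketch (the /tmp/tfds path is hypothetical):

    import tensorflow_datasets as tfds

    # Split slicing: carve a 10k-example validation set from MNIST's train.
    train_ds, valid_ds, test_ds = tfds.load(
        "mnist",
        split=["train[:50000]", "train[50000:]", "test"],
        data_dir="/tmp/tfds")  # hypothetical cache location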


@@ -131,7 +134,6 @@ def dataset_and_mean_to_batch(dataset,
                               repeat=True,
                               shuffle=True,
                               initializable=False,
-                              flatten=False,
                               jitter=False):
   """Transforms data based on args (assumes images in [0, 255])."""

@@ -153,17 +155,13 @@ def _preprocess(im):
     else:  # [0, 1]
       im /= 255.
 
-    if flatten:
-      im = tf.reshape(im, [-1])
     return im
 
   dataset = dataset.map(_preprocess)
 
   if repeat:
     dataset = dataset.repeat()
 
-  dataset = dataset.cache()
-
   if shuffle:
     dataset = dataset.shuffle(1024)
 
@@ -177,8 +175,6 @@ def _preprocess(im):
 
   ims = itr.get_next()
 
-  if flatten:
-    train_mean = tf.reshape(train_mean, [-1])
   if jitter:
     train_mean += 0.5
   elif binarize:
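
On the caching fix: the hunk above drops a dataset.cache() that sat after repeat(), where it would try to materialize an infinite stream; caching after a stochastic map (such as the binarize branch here) would also freeze randomness that should be resampled each epoch. A sketch of an ordering that avoids both pitfalls, assuming only deterministic work is cached (toy dataset, not this file's pipeline):

    import tensorflow as tf

    ds = tf.data.Dataset.range(10)
    ds = ds.map(lambda x: tf.cast(x, tf.float32) / 255.)  # deterministic
    ds = ds.cache()    # safe: the dataset is still finite here
    ds = ds.repeat()   # cache() after repeat() never completes a pass
    ds = ds.shuffle(8)
    ds = ds.batch(4, drop_remainder=True)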
@@ -198,35 +194,26 @@ def get_dataset(dataset,
   """Return the reference dataset with options."""
   dataset_map = {
       "dynamic_mnist": (get_mnist, {
-          "binarize": True
+          "binarize": True,
       }),
       "raw_mnist": (get_mnist, {}),
       "static_mnist": (get_static_mnist, {}),
       "jittered_mnist": (get_mnist, {
-          "jitter": True
+          "jitter": True,
       }),
       "jittered_celeba": (get_celeba, {
           "jitter": True
       }),
-      "jittered_flat_celeba": (get_celeba, {
-          "jitter": True,
-          "flatten": True
-      }),
       "fashion_mnist": (get_fashion_mnist, {
           "binarize": True
       }),
-      "flat_fashion_mnist": (get_fashion_mnist, {
-          "binarize": True,
-          "flatten": True
-      }),
-      "jittered_flat_fashion_mnist": (get_fashion_mnist, {
+      "jittered_fashion_mnist": (get_fashion_mnist, {
           "jitter": True,
-          "flatten": True
       }),
   }
 
   dataset_fn, dataset_kwargs = dataset_map[dataset]
-  raw_dataset, mean = dataset_fn(split, shuffle_files=shuffle)
+  raw_dataset, mean = dataset_fn(split, shuffle_files=False)
   data_batch, mean, itr = dataset_and_mean_to_batch(
       raw_dataset,
       mean,
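
The map above also encodes the dynamic- versus static-binarization distinction: dynamic_mnist passes binarize=True so binary pixels are resampled on every pass through the data, while static_mnist loads the fixed binarized_mnist set. The binarize branch of _preprocess is elided in this diff; given the tfd import at the top of the file, it plausibly draws Bernoulli samples, roughly as in this assumed sketch (im is a float image in [0, 255]):

    import tensorflow as tf
    import tensorflow_probability as tfp

    tfd = tfp.distributions

    def dynamically_binarize(im):
      # One fresh Bernoulli draw per visit: intensity -> sample in {0, 1}.
      return tfd.Bernoulli(probs=im / 255.).sample()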
