diff --git a/train_variational_autoencoder_pytorch.py b/train_variational_autoencoder_pytorch.py index 5a662a9..bda329f 100644 --- a/train_variational_autoencoder_pytorch.py +++ b/train_variational_autoencoder_pytorch.py @@ -154,7 +154,7 @@ def load_binary_mnist(cfg, **kwcfg): if not fname.exists(): print('Downloading binary MNIST data...') data.download_binary_mnist(fname) - f = h5py.File(pathlib.os.path.join(pathlib.os.environ['DAT'], 'binarized_mnist.hdf5'), 'r') + f = h5py.File(pathlib.os.path.join(pathlib.os.environ['DAT'], 'binary_mnist.h5'), 'r') x_train = f['train'][::] x_val = f['valid'][::] x_test = f['test'][::] diff --git a/train_variational_autoencoder_tensorflow.py b/train_variational_autoencoder_tensorflow.py index 922b83f..b7531f2 100644 --- a/train_variational_autoencoder_tensorflow.py +++ b/train_variational_autoencoder_tensorflow.py @@ -135,7 +135,7 @@ def train(): sess = tfc.InteractiveSession() sess.run(init_op) - mnist_data = tfds.load(name='binarized_mnist', split='train', shuffle_files=False) + mnist_data = tfds.load(name='binary_mnist', split='train', shuffle_files=False) dataset = mnist_data.repeat().shuffle(buffer_size=1024).batch(FLAGS.batch_size) print('Saving TensorBoard summaries and images to: %s' % FLAGS.logdir)