Skip to content

Commit

Permalink
Only initialize datasets once and cache the result.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 251296188
Change-Id: I4d6e5529f7fddd5517c77087c93e8641bdee25a0
  • Loading branch information
tomhennigan authored and sonnet-copybara committed Jun 3, 2019
1 parent bc966e7 commit 56f22a1
Showing 1 changed file with 22 additions and 14 deletions.
36 changes: 22 additions & 14 deletions examples/simple_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,22 @@

def mnist(split, batch_size):
"""Returns a tf.data.Dataset with MNIST image/label pairs."""
@tf.function
def map_fn(images, labels):
def preprocess_dataset(images, labels):
# MNIST images are uint8 [0, 255], we cast and rescale to float32 [-1, 1].
images = ((tf.cast(images, tf.float32) / 255.) - .5) * 2.
return images, labels

dataset = tfds.load(name="mnist", split=split, as_supervised=True)
dataset = dataset.map(map_fn)
dataset = dataset.map(preprocess_dataset)
dataset = dataset.shuffle(buffer_size=4 * batch_size)
dataset = dataset.batch(batch_size)
# Autotune the number of prefetched records to avoid becoming input bound.
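# Cache the result of the data pipeline to avoid recomputation. The pipeline
# is only ~100MB so this should not be a significant cost and will afford a
# decent speedup.
# NOTE(review): the cache is applied after shuffle() and batch(), so the
# shuffled, batched order produced on the first pass is replayed on every
# later pass; confirm that a fixed order across epochs is intended.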
dataset = dataset.cache()
# Prefetching batches onto the GPU will help avoid us being too input bound.
# We allow tf.data to determine how much to prefetch since this will vary
# between GPUs.
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
return dataset

Expand All @@ -56,21 +61,21 @@ def train_step(model, optimizer, images, labels):


@tf.function
def train_epoch(model, optimizer):
train_data = mnist("train", batch_size=128)
def train_epoch(model, optimizer, dataset):
loss = 0.
for images, labels in train_data:
for images, labels in dataset:
loss = train_step(model, optimizer, images, labels)
return loss


def test_accuracy(model):
@tf.function
def test_accuracy(model, dataset):
correct, total = 0, 0
for images, labels in mnist("test", batch_size=1000):
for images, labels in dataset:
preds = tf.argmax(model(images), axis=1)
correct += tf.math.count_nonzero(tf.equal(preds, labels))
total += len(labels)
accuracy = (correct / tf.cast(total, tf.int64)) * 100.
correct += tf.math.count_nonzero(tf.equal(preds, labels), dtype=tf.int32)
total += tf.shape(labels)[0]
accuracy = (correct / tf.cast(total, tf.int32)) * 100.
return {"accuracy": accuracy, "incorrect": total - correct}


Expand All @@ -86,9 +91,12 @@ def main(unused_argv):

optimizer = snt.optimizers.SGD(0.1)

train_data = mnist("train", batch_size=128)
test_data = mnist("test", batch_size=1000)

for epoch in range(5):
train_loss = train_epoch(model, optimizer)
test_metrics = test_accuracy(model)
train_loss = train_epoch(model, optimizer, train_data)
test_metrics = test_accuracy(model, test_data)
print("[Epoch %d] train loss: %.05f, test acc: %.02f%% (%d wrong)" %
(epoch, train_loss, test_metrics["accuracy"],
test_metrics["incorrect"]))
Expand Down

0 comments on commit 56f22a1

Please sign in to comment.