In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import time

In [2]:
# Functional-API classifier: a flat 784-feature input, two 64-unit ReLU
# hidden layers, and a 10-unit output layer with no activation (raw logits).
inputs = keras.Input(name="digits", shape=(784,))
x1 = layers.Dense(units=64, activation="relu")(inputs)
x2 = layers.Dense(units=64, activation="relu")(x1)
outputs = layers.Dense(units=10, name="predictions")(x2)
model = keras.Model(inputs=inputs, outputs=outputs)

In [3]:
# Plain SGD optimizer, plus a sparse categorical cross-entropy loss that is
# told the model emits logits (from_logits=True).
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

# Load MNIST and flatten every image into a 784-element row.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 784)
x_test = x_test.reshape(-1, 784)

# Hold out the last 10,000 training samples as the validation split.
# The right-hand sides are evaluated before rebinding, so both slices are
# taken from the full training arrays.
x_train, x_val = x_train[:-10000], x_train[-10000:]
y_train, y_val = y_train[:-10000], y_train[-10000:]

# Training pipeline: shuffle through a 1024-element buffer, then batch.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=1024)
    .batch(batch_size)
)

# Validation pipeline: batched only, original order preserved.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size)

In [4]:
@tf.function
def train_step(x, y):
    """Run one optimization step on a single batch and return its loss.

    Uses the notebook-level `model`, `loss_fn` and `optimizer`.

    Args:
        x: Batch of inputs passed to `model`.
        y: Targets for `x`, passed to `loss_fn` together with the model output.

    Returns:
        The loss value `loss_fn` produced for this batch.
    """
    # Record the forward pass so the loss can be differentiated afterwards.
    with tf.GradientTape() as grad_tape:
        predictions = model(x, training=True)
        batch_loss = loss_fn(y, predictions)

    # Differentiate with respect to the trainable weights and apply the update.
    gradients = grad_tape.gradient(batch_loss, model.trainable_weights)
    grads_and_vars = zip(gradients, model.trainable_weights)
    optimizer.apply_gradients(grads_and_vars)
    return batch_loss

In [5]:
# Summary file writer targeting ./logdir.
# NOTE(review): nothing in the visible cells writes through it — confirm it is
# used later or drop it.
summary_writer = tf.summary.create_file_writer(logdir="./logdir")

In [6]:
epochs = 2

# Profile the whole training run; the profiler writes its trace to ./logdir.
tf.profiler.experimental.start('./logdir')
try:
    for epoch in range(epochs):
        print("\nStart of epoch %d" % (epoch,))
        start_time = time.time()

        # Iterate over the batches of the dataset.
        for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
            loss_value = train_step(x_batch_train, y_batch_train)

            # Log every 200 batches.
            if step % 200 == 0:
                print(
                    "Training loss (for one batch) at step %d: %.4f"
                    % (step, float(loss_value))
                )
                print("Seen so far: %d samples" % ((step + 1) * batch_size))
        print("Time taken: %.2fs" % (time.time() - start_time))
finally:
    # Always close the profiling session, even if training raises or the
    # kernel is interrupted. Otherwise the session stays open and re-running
    # this cell fails on start() because a profiler is already active.
    tf.profiler.experimental.stop()


Start of epoch 0
Training loss (for one batch) at step 0: 76.0347
Seen so far: 64 samples
Training loss (for one batch) at step 200: 1.3555
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 1.0161
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.4908
Seen so far: 38464 samples
Time taken: 1.74s

Start of epoch 1
Training loss (for one batch) at step 0: 0.9098
Seen so far: 64 samples
Training loss (for one batch) at step 200: 0.4573
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 0.7323
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.6312
Seen so far: 38464 samples
Time taken: 1.20s
