In [1]:
import tensorflow as tf

from smot.jupyter import model_reports
from smot.problems.mnist.LeNet5 import lenet5_lib

In [2]:
# Fetch the standard MNIST train/test splits; the helper caches the data locally.
[(x_train, y_train), (x_test, y_test)] = lenet5_lib.load_LeNet5_datasets()

In [3]:
# Fixed seed for weight initialisation, batch shuffling and augmentation so
# reruns are comparable (GPU kernels may still add small run-to-run noise).
seed = 42
tf.random.set_seed(seed)

# Replicate the model on every visible device.
strategy = tf.distribute.MirroredStrategy()
print(f'Number of devices: {strategy.num_replicas_in_sync}')

with strategy.scope():
  model = lenet5_lib.build_LeNet5_model()

# Print the model summary.
model.summary()

# Under MirroredStrategy, Keras treats batch_size as the *global* batch size
# and splits each batch across the replicas.
batch_size = 128
validation_split = 0.2
epochs = 50

# Random augmentation is applied to the training subset only. Validation
# images go through a plain (non-augmenting) generator so validation metrics
# are measured on the real digits and are comparable from epoch to epoch.
# Both generators use the same validation_split; ImageDataGenerator carves
# the subsets out of x_train as fixed slices before any shuffling, so the
# training and validation subsets stay disjoint.
#
# datagen.fit() is intentionally not called: it only computes statistics for
# featurewise_center / featurewise_std_normalization / zca_whitening, none of
# which are enabled here. If any of them is enabled, fit() must be called
# (on the training data) for BOTH generators.
datagen = tf.keras.preprocessing.image.ImageDataGenerator(
  rotation_range=10,
  shear_range=0.1,
  width_shift_range=0.1,
  height_shift_range=0.1,
  zoom_range=0.2,
  validation_split=validation_split,
)
validation_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
  validation_split=validation_split,
)

training_generator = datagen.flow(
  x_train,
  y_train,
  subset='training',
  batch_size=batch_size,
  shuffle=True,
  seed=seed,
)
# No shuffling needed for evaluation; a fixed order keeps validation stable.
validation_generator = validation_datagen.flow(
  x_train,
  y_train,
  subset='validation',
  batch_size=batch_size,
  shuffle=False,
)

history = model.fit(
  training_generator,
  validation_data=validation_generator,
  epochs=epochs,
  verbose=1,
)

# Evaluate the model with the test data.
test_loss, test_accuracy = model_reports.model_fit_report(
  model=model,
  history=history,
  test_data=(x_test, y_test),
)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')
Number of devices: 2
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Model: "LeNet5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 28, 28, 6)         156       
________________________________________________________

KeyboardInterrupt: 