In [1]:
import tensorflow as tf

from smot.training import build_management
from smot.jupyter import model_reports
from smot.problems.mnist import mnist_lib

# Let TensorFlow allocate GPU memory on demand instead of reserving it all up
# front. NOTE(review): the "Blas xGEMM launch failed" InternalError recorded
# from the training cell is commonly caused by cuBLAS being unable to get GPU
# memory, and memory growth is the usual mitigation — confirm on the target
# machine. This must run before any op initializes the GPU, which is why it
# lives in the first cell, right after the imports.
try:
    for gpu in tf.config.list_physical_devices("GPU"):
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as err:
    # Raised when the GPUs were already initialized (e.g. this cell is re-run
    # after training). The kernel keeps its current GPU configuration;
    # restart the kernel for the setting to take effect.
    print(f"GPU memory growth not changed: {err}")

# Fix TensorFlow's global seed so weight initialization and the shuffling
# done by `fit` are repeatable across Restart & Run All. GPU kernels may
# still introduce small run-to-run differences.
SEED = 42
tf.random.set_seed(SEED)

# Build-cache target used by the optional save/load cell at the end of the
# notebook.
model_build_target = build_management.build_cache().target(
    name="mnist/simple",
)

In [2]:
# Standard MNIST data, as loaded (and cached) by mnist_lib: a train split and
# a test split, each an (inputs, labels) pair.
train_split, test_split = mnist_lib.load_mnist_data_28x28x1()
x_train, y_train = train_split
x_test, y_test = test_split

In [3]:
# Build a small fully-connected classifier: flatten each input image, pass it
# through one 128-unit ReLU hidden layer, then a softmax output layer with one
# unit per class.
model = tf.keras.Sequential(
    [
        tf.keras.layers.Flatten(
            input_shape=mnist_lib.INPUT_SHAPE,
        ),
        # Hidden layer.
        tf.keras.layers.Dense(
            units=128,
            activation="relu",
        ),
        # Output layer: a probability distribution over the classes.
        tf.keras.layers.Dense(
            units=mnist_lib.N_CLASSES,
            activation="softmax",
        ),
    ]
)

# `categorical_crossentropy` expects one-hot labels as wide as the softmax
# output. Fail fast with a clear message if the loader returned something else
# (integer class ids would need `sparse_categorical_crossentropy` instead).
assert tuple(y_train.shape[1:]) == (mnist_lib.N_CLASSES,), (
    f"expected one-hot labels of width {mnist_lib.N_CLASSES}, "
    f"got label shape {tuple(y_train.shape)}"
)

# Compile with the Adam optimizer and the categorical loss matching the
# softmax output.
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.categorical_crossentropy,
    metrics=["accuracy"],
)

# Print the model summary.
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 128)               100480    
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1290      
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________


In [4]:
# Stop training once validation loss has not improved for 3 consecutive
# epochs. `restore_best_weights=True` rolls the model back to the epoch with
# the best validation loss when early stopping triggers. Without it, the test
# evaluation below would score the weights from the last epoch, which can be
# several epochs past the best one.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=3,
    restore_best_weights=True,
)

# Train, holding out the last 20% of the training data for validation.
history = model.fit(
    x=x_train,
    y=y_train,
    batch_size=128,
    epochs=50,
    verbose=1,
    validation_split=0.2,
    callbacks=[early_stopping],
)

# Evaluate the model with the test data.
test_loss, test_accuracy = model_reports.model_fit_report(
    model=model,
    history=history,
    test_data=(x_test, y_test),
)

Epoch 1/50


InternalError:  Blas xGEMM launch failed : a.shape=[1,128,784], b.shape=[1,784,128], m=128, n=128, k=784
	 [[node sequential/dense/MatMul (defined at <ipython-input-4-511e3263857f>:1) ]] [Op:__inference_train_function_598]

Function call stack:
train_function


In [10]:
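# Optional: save the trained model to the build cache ...
# model_build_target.save_model(model)

# ... or skip training and reload a previously saved model instead:
# model = model_build_target.load_model()
# model.summary()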