First, set the environment variables and initialize the Spark context:

In [None]:
%env SPARK_DRIVER_MEMORY=8g
%env PYSPARK_PYTHON=/usr/bin/python3.5
%env PYSPARK_DRIVER_PYTHON=/usr/bin/python3.5

from zoo.common.nncontext import *
# Run Spark locally with 4 worker threads ("local[4]")
sc = init_nncontext(init_spark_conf().setMaster("local[4]"))

# MNIST

The problem we are trying to solve here is to classify grayscale images of handwritten digits (28 pixels by 28 pixels) into their 10 categories (0 to 9). The dataset we will use is the MNIST dataset.

The MNIST dataset comes preloaded in Keras as a set of four NumPy arrays: training images, training labels, test images and test labels.

In [None]:
from keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

Import the modules needed to build the network. In Keras this is:

    from keras import models
    from keras import layers

To use analytics-zoo instead, replace them with the following:

In [None]:
from zoo.pipeline.api.keras import models
from zoo.pipeline.api.keras import layers

Build the network, then compile and fit it:

In [None]:
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', input_shape=(28 * 28,)))
network.add(layers.Dense(10, activation='softmax'))

network.compile(optimizer='rmsprop',
                loss='categorical_crossentropy',
                metrics=['accuracy'])

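# Flatten each 28x28 image into a 784-element vector and scale pixels to [0, 1]
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype('float32') / 255

test_images = test_images.reshape((10000, 28 * 28))
test_images = test_images.astype('float32') / 255

# One-hot encode the training labels for categorical_crossentropy
from keras.utils.np_utils import to_categorical
train_labels = to_categorical(train_labels)
# test_labels stay as integer class indices (not one-hot)
#test_labels = to_categorical(test_labels)

# analytics-zoo follows the Keras 1.2.2-style API, so the epoch count is
# passed as nb_epoch (Keras 2 calls it epochs)
network.fit(train_images, train_labels, nb_epoch=5, batch_size=128)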
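To see what the cell above computes without Spark, here is a minimal NumPy sketch under stated assumptions: random synthetic "images" stand in for MNIST, `np.eye(10)[labels]` stands in for `to_categorical(labels, 10)`, and the weights are random and untrained. It reproduces the flatten-and-scale step, a forward pass through the same Dense(512, relu) -> Dense(10, softmax) stack, and the parameter count of a standard Keras-style Dense layer (weights plus biases). The names `dense_param_count`, `w1`, `b1`, `w2`, `b2` are hypothetical illustration names, not part of the analytics-zoo API.

```python
import numpy as np

# Hypothetical stand-in for MNIST: a small batch of random uint8 "images"
# and integer labels. The real arrays come from mnist.load_data().
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(8, 28, 28)).astype(np.uint8)
labels = rng.integers(0, 10, size=8)

# Same preprocessing as the notebook: flatten 28x28 -> 784, scale to [0, 1]
x = images.reshape((images.shape[0], 28 * 28)).astype('float32') / 255

# One-hot encoding, equivalent in effect to to_categorical(labels, 10)
y = np.eye(10, dtype='float32')[labels]

def dense_param_count(n_in, n_out):
    # Hypothetical helper: a Dense layer has an n_in x n_out weight
    # matrix plus n_out biases
    return n_in * n_out + n_out

n_params = dense_param_count(784, 512) + dense_param_count(512, 10)

# Forward pass with random, untrained weights (illustration only):
# Dense(512, relu) -> Dense(10, softmax)
w1 = rng.normal(0, 0.05, size=(784, 512)).astype('float32')
b1 = np.zeros(512, dtype='float32')
w2 = rng.normal(0, 0.05, size=(512, 10)).astype('float32')
b2 = np.zeros(10, dtype='float32')

hidden = np.maximum(0, x @ w1 + b1)           # ReLU
logits = hidden @ w2 + b2
logits -= logits.max(axis=1, keepdims=True)   # subtract row max for numerical stability
exp = np.exp(logits)
probs = exp / exp.sum(axis=1, keepdims=True)  # softmax: each row sums to 1

print(x.shape, y.shape, probs.shape, n_params)
```

It prints `(8, 784) (8, 10) (8, 10) 407050`: each flattened image has 784 features, each softmax row is a probability distribution over the 10 digits, and the two Dense layers hold 401,920 + 5,130 = 407,050 trainable parameters.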

#### Evaluation result
Check the result on the test set. In Keras this is:

    test_loss, test_acc = network.evaluate(test_images, test_labels)

In analytics-zoo, `evaluate` does not return Keras's `(loss, accuracy)` pair. It returns a list of evaluation result objects, and here the accuracy is read from the `result` attribute of the first element:

In [None]:
test_result = network.evaluate(test_images, test_labels, batch_size=32)
print('test_acc:', test_result[0].result)
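The printed `test_acc` is a top-1 accuracy: the fraction of test images whose highest-probability class matches the true digit. Below is a minimal NumPy sketch of that computation on a hypothetical toy batch. The helper `top1_accuracy` is an illustration, not analytics-zoo's internal implementation; it accepts integer labels (as `test_labels` is left above) or one-hot rows.

```python
import numpy as np

def top1_accuracy(probs, labels):
    """Hypothetical helper: fraction of rows whose argmax equals the label.

    `labels` may be integer class indices or one-hot rows; one-hot rows
    are converted back to indices with argmax.
    """
    probs = np.asarray(probs)
    labels = np.asarray(labels)
    if labels.ndim == 2:
        labels = labels.argmax(axis=1)
    predicted = probs.argmax(axis=1)
    return float((predicted == labels).mean())

# Toy batch: 4 samples, 3 classes; the third prediction (class 0) is wrong
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.6, 0.3, 0.1],
                  [0.2, 0.2, 0.6]])
labels = np.array([0, 1, 1, 2])
print(top1_accuracy(probs, labels))  # 0.75
```

Three of the four argmax predictions match their labels, so the accuracy is 0.75; passing `np.eye(3)[labels]` instead gives the same value.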