<a href="https://colab.research.google.com/github/Le-Zheng/analytics-zoo/blob/colab/docs/docs/colab-tutorials/lenet_mnist_graph.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Environment Preparation**

In [1]:
# Install JDK 8 (headless package); apt output is redirected to /dev/null to keep the cell quiet.
!apt-get install openjdk-8-jdk-headless -qq > /dev/null
import os
# Export JAVA_HOME for this kernel process so it points at the JDK 8 install directory.
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
# Select the JDK 8 binary as the system `java` alternative.
!update-alternatives --set java /usr/lib/jvm/java-8-openjdk-amd64/jre/bin/java
# Print the active Java version to confirm the switch took effect.
!java -version

openjdk version "1.8.0_275"
OpenJDK Runtime Environment (build 1.8.0_275-8u275-b01-0ubuntu1~18.04-b01)
OpenJDK 64-Bit Server VM (build 25.275-b01, mixed mode)


In [None]:
# Install the latest released version of analytics-zoo from pip (no version pin).
# Installing analytics-zoo from pip will automatically install pyspark, bigdl, and their dependencies.
# NOTE(review): the output recorded further down in this notebook loads a 0.9.0
# analytics-zoo jar; pin that version here if those outputs must be reproduced — TODO confirm.
!pip install analytics-zoo

In [None]:
# Install the pinned TensorFlow 1.x stack required by the example below
# (the model code uses TF1-style tf.placeholder / tf.layers APIs) plus
# tensorflow-datasets, which provides the MNIST data.
!pip install tensorflow==1.15.0 tensorflow-probability==0.7.0 tensorflow-datasets==2.1.0

**lenet_mnist_graph example**

In [4]:
#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse

import tensorflow as tf
import tensorflow_datasets as tfds
from zoo.orca.learn.tf.estimator import Estimator
from zoo.orca import init_orca_context, stop_orca_context


def accuracy(logits, labels):
    """Return the fraction of rows whose arg-max class equals the label.

    Args:
        logits: score tensor; the predicted class is the arg-max along axis 1.
        labels: tensor of true class ids; its dtype is reused for the
            predictions so the two can be compared directly.

    Returns:
        A scalar float32 tensor: the mean of the per-row match indicator.
    """
    predicted_classes = tf.argmax(logits, axis=1, output_type=labels.dtype)
    hits = tf.equal(predicted_classes, labels)
    # Cast the boolean matches to 0.0 / 1.0 so their mean is the hit rate.
    return tf.reduce_mean(tf.cast(hits, dtype=tf.float32))


def lenet(images):
    """Build the LeNet-style classifier graph and return its logits.

    Layers, in order: 5x5 conv (32 filters, ReLU) -> 2x2 max-pool (stride 2)
    -> 5x5 conv (64 filters, ReLU) -> 2x2 max-pool (stride 2) -> flatten
    -> dense 1024 (ReLU) -> dense 10 (no activation).

    Args:
        images: input tensor fed to the first convolution. The caller in this
            file passes a float32 placeholder of shape (None, 28, 28, 1).

    Returns:
        The output tensor of the final 10-unit dense layer.
    """
    # `values` has to be passed by keyword: the second positional parameter of
    # tf.variable_scope is `default_name`, so a positional `[images]` would be
    # bound to the wrong argument instead of registering the input tensor.
    with tf.variable_scope('LeNet', values=[images]):
        net = tf.layers.conv2d(images, 32, (5, 5), activation=tf.nn.relu, name='conv1')
        net = tf.layers.max_pooling2d(net, (2, 2), 2, name='pool1')
        net = tf.layers.conv2d(net, 64, (5, 5), activation=tf.nn.relu, name='conv2')
        net = tf.layers.max_pooling2d(net, (2, 2), 2, name='pool2')
        net = tf.layers.flatten(net)
        net = tf.layers.dense(net, 1024, activation=tf.nn.relu, name='fc3')
        # The final layer is left unnamed on purpose so the variable names
        # TensorFlow generates for it stay the same as before.
        logits = tf.layers.dense(net, 10)
        return logits


def preprocess(data):
    """Turn one dataset record into an ``(image, label)`` pair.

    Args:
        data: mapping with an ``"image"`` entry and a ``"label"`` entry
            (presumably raw pixel values in [0, 255] — verify against the
            dataset being mapped).

    Returns:
        A tuple ``(image, label)`` where ``image`` is the record's image cast
        to float32 and divided by 255, and ``label`` is passed through as is.
    """
    # Keep the scaled image in a local instead of writing it back into `data`,
    # so the caller's mapping is not modified as a side effect.
    image = tf.cast(data["image"], tf.float32) / 255.
    return image, data['label']


def main(max_epoch, dataset_dir, batch_size=320, checkpoint_path="/tmp/lenet/model"):
    """Train LeNet on MNIST with an Orca graph Estimator, evaluate, and save.

    Args:
        max_epoch: number of epochs passed to ``Estimator.fit``.
        dataset_dir: directory handed to ``tfds.load`` as ``data_dir``.
        batch_size: training batch size passed to ``Estimator.fit``
            (defaults to the previously hard-coded 320).
        checkpoint_path: target passed to ``save_tf_checkpoint``
            (defaults to the previously hard-coded "/tmp/lenet/model").

    Side effects:
        Prints the evaluation result and writes a TensorFlow checkpoint to
        ``checkpoint_path``.
    """
    mnist_train = tfds.load(name="mnist", split="train", data_dir=dataset_dir)
    mnist_test = tfds.load(name="mnist", split="test", data_dir=dataset_dir)

    mnist_train = mnist_train.map(preprocess)
    mnist_test = mnist_test.map(preprocess)

    # tensorflow inputs
    images = tf.placeholder(dtype=tf.float32, shape=(None, 28, 28, 1))
    # tensorflow labels
    labels = tf.placeholder(dtype=tf.int32, shape=(None,))

    logits = lenet(images)

    # With its default reduction, tf.losses.sparse_softmax_cross_entropy
    # already returns a scalar, so no extra tf.reduce_mean is needed.
    loss = tf.losses.sparse_softmax_cross_entropy(logits=logits, labels=labels)

    acc = accuracy(logits, labels)

    # create an estimator
    est = Estimator.from_graph(inputs=images,
                               outputs=logits,
                               labels=labels,
                               loss=loss,
                               optimizer=tf.train.AdamOptimizer(),
                               metrics={"acc": acc})
    est.fit(data=mnist_train,
            batch_size=batch_size,
            epochs=max_epoch,
            validation_data=mnist_test,
            # tfds mnist only has one file and cannot be sharded on files,
            # falling back on sharding on records.
            auto_shard_files=False)

    result = est.evaluate(mnist_test, auto_shard_files=False)
    print(result)

    est.save_tf_checkpoint(checkpoint_path)


if __name__ == '__main__':

    # Start a local Orca context with 4 cores for the training run.
    init_orca_context(cluster_mode="local", cores=4)
    try:
        dataset_dir = "~/tensorflow_datasets"

        # Train for a single epoch.
        main(1, dataset_dir)
    finally:
        # Always release the Orca context, even if training fails, so the
        # cell can be re-run without a stale context left behind.
        stop_orca_context()

Prepending /usr/local/lib/python3.6/dist-packages/bigdl/share/conf/spark-bigdl.conf to sys.path
Adding /usr/local/lib/python3.6/dist-packages/zoo/share/lib/analytics-zoo-bigdl_0.12.1-spark_2.4.3-0.9.0-jar-with-dependencies.jar to BIGDL_JARS
Prepending /usr/local/lib/python3.6/dist-packages/zoo/share/conf/spark-analytics-zoo.conf to sys.path

Initializing orca context
Current pyspark location is : /usr/local/lib/python3.6/dist-packages/pyspark/__init__.py
Start to getOrCreate SparkContext
pyspark_submit_args is:  --driver-class-path /usr/local/lib/python3.6/dist-packages/zoo/share/lib/analytics-zoo-bigdl_0.12.1-spark_2.4.3-0.9.0-jar-with-dependencies.jar:/usr/local/lib/python3.6/dist-packages/bigdl/share/lib/bigdl-0.12.1-jar-with-dependencies.jar pyspark-shell 
Successfully got a SparkContext


local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead set
data_dir=gs://tfds-data/datasets.



[1mDownloading and preparing dataset mnist/3.0.0 (download: 11.06 MiB, generated: Unknown size, total: 11.06 MiB) to /root/tensorflow_datasets/mnist/3.0.0...[0m


HBox(children=(FloatProgress(value=0.0, description='Dl Completed...', max=4.0, style=ProgressStyle(descriptio…



[1mDataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.0. Subsequent calls will reuse this data.[0m
Instructions for updating:
Use `tf.keras.layers.Conv2D` instead.


Instructions for updating:
Use `tf.keras.layers.Conv2D` instead.


Instructions for updating:
Please use `layer.__call__` method instead.


Instructions for updating:
Please use `layer.__call__` method instead.


Instructions for updating:
Use keras.layers.MaxPooling2D instead.


Instructions for updating:
Use keras.layers.MaxPooling2D instead.


Instructions for updating:
Use keras.layers.flatten instead.


Instructions for updating:
Use keras.layers.flatten instead.


Instructions for updating:
Use keras.layers.Dense instead.


Instructions for updating:
Use keras.layers.Dense instead.


Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where


Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
































creating: createFakeOptimMethod
Instructions for updating:
Use `tf.cast` instead.


Instructions for updating:
Use `tf.cast` instead.








creating: createStatelessMetric
























creating: createTFTrainingHelper
creating: createIdentityCriterion
creating: createEstimator
creating: createMaxEpoch
creating: createEveryEpoch
INFO:tensorflow:Restoring parameters from /tmp/tmp3ikl1wdc/model


INFO:tensorflow:Restoring parameters from /tmp/tmp3ikl1wdc/model


creating: createStatelessMetric
Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.simple_save.


Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.simple_save.


Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.


Instructions for updating:
This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.


INFO:tensorflow:Assets added to graph.


INFO:tensorflow:Assets added to graph.


INFO:tensorflow:No assets to write.


INFO:tensorflow:No assets to write.


'TFDataDataset' object has no attribute 'name'


'TFDataDataset' object has no attribute 'name'


'TFDataDataset' object has no attribute 'name'


'TFDataDataset' object has no attribute 'name'


INFO:tensorflow:SavedModel written to: /tmp/tmp1wow28y6/saved_model.pb


INFO:tensorflow:SavedModel written to: /tmp/tmp1wow28y6/saved_model.pb


{'acc': 0.9807000160217285}
'TFDataDataset' object has no attribute 'name'


'TFDataDataset' object has no attribute 'name'


'TFDataDataset' object has no attribute 'name'


'TFDataDataset' object has no attribute 'name'
