In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import argparse
import subprocess
from tensorflowonspark import TFCluster
import mnist_dist

In [2]:
# Configure root logging for the notebook.
# logging.basicConfig() does nothing when the root logger already has handlers,
# so drop any existing root handlers first. This replaces the previous
# `reload(logging)` trick, which relies on the Python 2-only `reload` builtin
# and re-executes the whole logging module.
while logging.getLogger().handlers:
    logging.getLogger().removeHandler(logging.getLogger().handlers[0])
logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level=logging.INFO, datefmt='%I:%M:%S')

In [3]:
# Register the local mnist_dist.py file with the Spark context as a Python
# dependency; mnist_dist.map_fun is what later cells hand to TFCluster.run.
# NOTE(review): `sc` is not defined anywhere in this notebook -- presumably the
# SparkContext pre-created by the PySpark kernel; confirm before running elsewhere.
sc.addPyFile("mnist_dist.py")

In [4]:
def _str_to_bool(value):
    """Convert a command-line token such as "true"/"False"/"1"/"0" to a bool.

    Raises argparse.ArgumentTypeError for unrecognised tokens so argparse
    reports a normal usage error instead of silently accepting them.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("1", "true", "t", "yes", "y", "on"):
        return True
    if token in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


# Command-line style options; later cells parse explicit argument lists with
# this parser and pass the resulting namespace to TFCluster.run.
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", help="number of epochs", type=int, default=1)
parser.add_argument("--images", help="HDFS path to MNIST images in parallelized format")
parser.add_argument("--labels", help="HDFS path to MNIST labels in parallelized format")
parser.add_argument("--format", help="example format", choices=["csv","pickle","tfr"], default="csv")
parser.add_argument("--model", help="HDFS path to save/load model during train/test", default="mnist_model")
parser.add_argument("--readers", help="number of reader/enqueue threads", type=int, default=1)
parser.add_argument("--steps", help="maximum number of steps", type=int, default=500)
parser.add_argument("--batch_size", help="number of examples per batch", type=int, default=100)
parser.add_argument("--mode", help="train|inference", default="train")
# Without a converter, "--rdma False" was stored as the truthy string "False".
# Accept a bare "--rdma" (-> True) as well as an explicit "--rdma true/false".
parser.add_argument("--rdma", help="use rdma connection", type=_str_to_bool, nargs="?", const=True, default=False)
# Number of Spark executors requested for the TensorFlow cluster in later cells.
num_executors = 3

In [5]:
# HDFS locations of the CSV-formatted MNIST training images and labels.
train_images_files, train_labels_files = (
    "hdfs://10.110.18.217:8020/user/root/mnist/csv/train/images",
    "hdfs://10.110.18.217:8020/user/root/mnist/csv/train/labels",
)

In [6]:
# Parse the training configuration: one flag/value pair per line.
args = parser.parse_args(
    ['--mode', 'train']
    + ['--steps', '3000']
    + ['--epochs', '1']
    + ['--images', train_images_files]
    + ['--labels', train_labels_files]
)

In [7]:
# Launch the TensorFlow cluster on the Spark executors for training:
# one parameter-server node, TensorBoard enabled, data fed from Spark RDDs.
cluster = TFCluster.run(
    sc,
    mnist_dist.map_fun,
    args,
    num_executors,
    num_ps=1,
    tensorboard=True,
    input_mode=TFCluster.InputMode.SPARK,
)

07:46:39 INFO:Reserving TFSparkNodes w/ TensorBoard
07:46:39 INFO:listening for reservations at ('idap-agent-218.idap.com', 32967)
07:46:39 INFO:Starting TensorFlow on executors
07:46:39 INFO:Waiting for TFSparkNodes to start
07:46:39 INFO:waiting for 3 reservations
07:46:40 INFO:waiting for 3 reservations
07:46:41 INFO:all reservations completed
07:46:41 INFO:All TFSparkNodes started
07:46:41 INFO:{'addr': '/tmp/pymp-NFrXpy/listener-iZ0hvY', 'task_index': 0, 'port': 33074, 'authkey': 'Y\xa3SC\xcf\rD\xf7\x81Yx\xe1\x02w9\xfd', 'worker_num': 1, 'host': 'idap-server-216.idap.com', 'ppid': 7294, 'job_name': 'worker', 'tb_pid': 7314, 'tb_port': 49400}
07:46:41 INFO:{'addr': ('idap-agent-217.idap.com', 41033), 'task_index': 0, 'port': 44425, 'authkey': '\xce2\xb4J\xf5*Oz\xb3\x0c\xf6\x7f\xb6i\xf1\xf8', 'worker_num': 0, 'host': 'idap-agent-217.idap.com', 'ppid': 26458, 'job_name': 'ps', 'tb_pid': 0, 'tb_port': 0}
07:46:41 INFO:{'addr': '/tmp/pymp-IMZUQf/listener-bZZZkw', 'task_index': 1, 'port

In [8]:
# Train the model.
# Each comma-separated line is parsed into a list of ints (images file) or
# floats (labels file); zip() then pairs the two RDDs element by element.
images = sc.textFile(args.images).map(lambda row: list(map(int, row.split(','))))
labels = sc.textFile(args.labels).map(lambda row: list(map(float, row.split(','))))
dataRDD = images.zip(labels)
# Feed the paired records to the running cluster for args.epochs epochs.
cluster.train(dataRDD, args.epochs)

07:46:45 INFO:Feeding training data


In [9]:
# Shut down the cluster started for training; a separate cluster is launched
# for inference in a later cell.
cluster.shutdown()

07:47:28 INFO:Stopping TensorFlow nodes
07:47:28 INFO:Shutting down cluster


In [10]:
# HDFS locations of the CSV-formatted MNIST test images and labels.
test_images_files, test_labels_files = (
    "hdfs://10.110.18.217:8020/user/root/mnist/csv/test/images",
    "hdfs://10.110.18.217:8020/user/root/mnist/csv/test/labels",
)

In [11]:
# Parse the inference configuration (all other options keep their defaults).
args = parser.parse_args(
    ['--mode', 'inference']
    + ['--images', test_images_files]
    + ['--labels', test_labels_files]
)

In [12]:
# Launch a fresh TensorFlow cluster for inference:
# one parameter-server node, no TensorBoard, data fed from Spark RDDs.
cluster = TFCluster.run(
    sc,
    mnist_dist.map_fun,
    args,
    num_executors,
    num_ps=1,
    tensorboard=False,
    input_mode=TFCluster.InputMode.SPARK,
)

07:47:33 INFO:Reserving TFSparkNodes 
07:47:33 INFO:listening for reservations at ('idap-agent-218.idap.com', 42731)
07:47:33 INFO:Starting TensorFlow on executors
07:47:34 INFO:Waiting for TFSparkNodes to start
07:47:34 INFO:waiting for 3 reservations
07:47:35 INFO:all reservations completed
07:47:35 INFO:All TFSparkNodes started
07:47:35 INFO:{'addr': '/tmp/pymp-wS9V8W/listener-aTau95', 'task_index': 1, 'port': 54526, 'authkey': '\x8bJ\xc35\x16\x8aB\x00\xbc\xc0$\\Q\xfb\xe8\xc3', 'worker_num': 2, 'host': 'idap-server-216.idap.com', 'ppid': 7294, 'job_name': 'worker', 'tb_pid': 0, 'tb_port': 0}
07:47:35 INFO:{'addr': '/tmp/pymp-4rCYEk/listener-luPchF', 'task_index': 0, 'port': 49621, 'authkey': '_\xd1\x1b\xbf\x9e\x1bG\xd5\x88t\xb3LXO\xce\xef', 'worker_num': 1, 'host': 'idap-agent-218.idap.com', 'ppid': 16705, 'job_name': 'worker', 'tb_pid': 0, 'tb_port': 0}
07:47:35 INFO:{'addr': ('idap-agent-217.idap.com', 57144), 'task_index': 0, 'port': 40200, 'authkey': '\xd0\xc5\xa4\xbd\xc5\x9f@\x

In [13]:
# Prepare the test data as a Spark RDD.
# Each comma-separated line is parsed into a list of ints (images file) or
# floats (labels file); zip() then pairs the two RDDs element by element.
images = sc.textFile(args.images).map(lambda row: list(map(int, row.split(','))))
labels = sc.textFile(args.labels).map(lambda row: list(map(float, row.split(','))))
dataRDD = images.zip(labels)
# Feed the paired records to the cluster for inference and keep the result RDD.
prediction_results = cluster.inference(dataRDD)
# Display the first 20 results as the cell output.
prediction_results.take(20)

07:47:35 INFO:Feeding inference data


['2017-09-04T19:47:46.427081 Label: 7, Prediction: 7',
 '2017-09-04T19:47:46.427149 Label: 2, Prediction: 2',
 '2017-09-04T19:47:46.427167 Label: 1, Prediction: 1',
 '2017-09-04T19:47:46.427177 Label: 0, Prediction: 0',
 '2017-09-04T19:47:46.427187 Label: 4, Prediction: 4',
 '2017-09-04T19:47:46.427198 Label: 1, Prediction: 1',
 '2017-09-04T19:47:46.427207 Label: 4, Prediction: 4',
 '2017-09-04T19:47:46.427217 Label: 9, Prediction: 9',
 '2017-09-04T19:47:46.427227 Label: 5, Prediction: 6',
 '2017-09-04T19:47:46.427237 Label: 9, Prediction: 9',
 '2017-09-04T19:47:46.427247 Label: 0, Prediction: 0',
 '2017-09-04T19:47:46.427256 Label: 6, Prediction: 6',
 '2017-09-04T19:47:46.427265 Label: 9, Prediction: 9',
 '2017-09-04T19:47:46.427275 Label: 0, Prediction: 0',
 '2017-09-04T19:47:46.427294 Label: 1, Prediction: 1',
 '2017-09-04T19:47:46.427339 Label: 5, Prediction: 5',
 '2017-09-04T19:47:46.427351 Label: 9, Prediction: 9',
 '2017-09-04T19:47:46.427362 Label: 7, Prediction: 7',
 '2017-09-

In [14]:
# Shut down the cluster that was started for inference.
cluster.shutdown()

07:47:47 INFO:Stopping TensorFlow nodes
07:47:47 INFO:Shutting down cluster
