In [None]:
!pip3 install tqdm requests dill

In [1]:
import requests
from tqdm import tqdm
import os

def download_from_url(url, dst, chunk_size=1024):
    """Download ``url`` to the local path ``dst``, resuming a partial file.

    If ``dst`` already holds at least as many bytes as the server reports,
    nothing is downloaded. Otherwise the missing tail is requested with an
    HTTP ``Range`` header and appended. If the server ignores the range and
    sends the whole body (status 200 instead of 206), ``dst`` is rewritten
    from scratch rather than having the full body appended to the partial one.

    Args:
        url: HTTP(S) URL of the file.
        dst: Local destination path.
        chunk_size: Bytes per streamed read (default 1024).

    Returns:
        The remote size in bytes, taken from the HEAD response's Content-Length.

    Raises:
        requests.HTTPError: if the HEAD or GET request returns an error status.
        KeyError: if the HEAD response carries no Content-Length header.
    """
    # Ask for the raw bytes. With gzip transfer, Content-Length counts compressed
    # bytes while iter_content() yields decompressed ones, so sizes and byte
    # ranges would no longer line up with what is on disk.
    no_compression = {"Accept-Encoding": "identity"}
    head = requests.head(url, headers=no_compression, allow_redirects=True)
    head.raise_for_status()
    file_size = int(head.headers["Content-Length"])

    first_byte = os.path.getsize(dst) if os.path.exists(dst) else 0
    if first_byte >= file_size:
        return file_size

    # The Range end position is inclusive, so the last byte index is file_size - 1.
    headers = dict(no_compression, Range="bytes=%d-%d" % (first_byte, file_size - 1))
    with requests.get(url, headers=headers, stream=True) as resp:
        resp.raise_for_status()
        if resp.status_code != 206:
            # Server ignored the Range header and is sending the whole file.
            first_byte = 0
        mode = 'ab' if first_byte else 'wb'
        with tqdm(total=file_size, initial=first_byte, unit='B', unit_scale=True,
                  desc=url.split('/')[-1]) as pbar, open(dst, mode) as f:
            for chunk in resp.iter_content(chunk_size=chunk_size):
                if chunk:
                    f.write(chunk)
                    # The final chunk is usually shorter than chunk_size.
                    pbar.update(len(chunk))
    return file_size

In [2]:
# Fetch the MNIST training CSV next to the notebook; nothing is re-downloaded if the
# local file is already at least as large as the remote one.
download_from_url(
    url='https://raw.githubusercontent.com/sjwhitworth/golearn/master/examples/datasets/mnist_train.csv',
    dst='mnist_train.csv',
)

1115034

In [3]:
from sparkflow.graph_utils import build_graph
from sparkflow.tensorflow_async import SparkAsyncDL
import tensorflow as tf
from pyspark.ml.feature import VectorAssembler, OneHotEncoder
from pyspark.ml.pipeline import Pipeline
from sparkflow.graph_utils import build_adam_config
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
    
def small_model():
    """Define the MNIST classifier graph that ``build_graph`` turns into ``mg``.

    Creates float32 placeholders ``x`` (784 values per row) and ``y`` (10 values
    per row), two 256-unit ReLU hidden layers, a 10-unit linear output layer,
    and an ``argmax`` op named ``out`` over that output.

    Returns:
        The softmax cross-entropy loss between ``y`` and the output logits.
    """
    inputs = tf.placeholder(tf.float32, shape=[None, 784], name='x')
    targets = tf.placeholder(tf.float32, shape=[None, 10], name='y')

    hidden = inputs
    for _ in range(2):
        hidden = tf.layers.dense(hidden, 256, activation=tf.nn.relu)
    logits = tf.layers.dense(hidden, 10)

    # Not used inside this function: the op only has to exist in the graph so it
    # can be fetched by name ('out:0', the tfOutput passed to SparkAsyncDL).
    tf.argmax(logits, 1, name='out')

    return tf.losses.softmax_cross_entropy(targets, logits)

In [4]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import rand
# Reuse the running Spark session if there is one, otherwise start it.
sparkSession = (
    SparkSession.builder
    .appName("csv")
    .getOrCreate()
)

In [5]:
# First CSV line holds the column names; column types are inferred from the data.
df = (
    sparkSession.read
    .option("header", True)
    .option("inferSchema", True)
    .csv('mnist_train.csv')
)

In [6]:
# Every column except `label` is a pixel. Select by name rather than by a positional
# slice so the label can't be swept in (or a pixel dropped) if column order changes.
pixel_cols = [c for c in df.columns if c != 'label']
# small_model()'s input placeholder 'x' has shape [None, 784].
assert len(pixel_cols) == 784, "expected 784 pixel columns, got %d" % len(pixel_cols)
va = VectorAssembler(inputCols=pixel_cols, outputCol='features').transform(df)

In [7]:
# Peek at the raw class label of the first row.
va.select(va['label']).show(n=1)

+-----+
|label|
+-----+
|    1|
+-----+
only showing top 1 row



In [8]:
# dropLast=False keeps one slot per class, so the vector width matches the
# [None, 10] 'y' placeholder in small_model().
encoded = (
    OneHotEncoder(inputCol='label', outputCol='labels', dropLast=False)
    .transform(va)
    .select('features', 'labels')
)

In [9]:
# Graph definition from small_model(), handed to SparkAsyncDL as `tensorflowGraph`.
mg = build_graph(small_model)
# Adam hyper-parameters, handed to SparkAsyncDL as `optimizerOptions`.
adam_config = build_adam_config(
    learning_rate=1e-3,
    beta1=0.9,
    beta2=0.999,
)

In [10]:
spark_model = SparkAsyncDL(
    # data column and graph; tensor names match the 'x'/'y' placeholders and the
    # 'out' argmax op created in small_model()
    inputCol='features', tensorflowGraph=mg,
    tfInput='x:0', tfLabel='y:0', tfOutput='out:0',
    # optimisation schedule
    tfOptimizer='adam', miniBatchSize=300, miniStochasticIters=1,
    shufflePerIter=True, iters=50,
    # output / label columns
    predictionCol='predicted', labelCol='labels',
    # execution
    partitions=3, verbose=1, optimizerOptions=adam_config,
)

In [12]:
# sparkflow serves a Flask app while fitting; its request log is the output below.
fitted_model = spark_model.fit(dataset=encoded)

 * Serving Flask app "sparkflow.HogwildSparkModel" (lazy loading)
 * Environment: production
   Use a production WSGI server instead.
 * Debug mode: on


 * Running on http://0.0.0.0:5000/ (Press CTRL+C to quit)
172.23.0.2 - - [22/Nov/2018 06:14:55] "[37mGET /parameters HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:14:55] "[37mPOST /update HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:14:55] "[37mGET /parameters HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:14:56] "[37mPOST /update HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:14:56] "[37mGET /parameters HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:14:56] "[37mPOST /update HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:14:56] "[37mGET /parameters HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:14:56] "[37mPOST /update HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:14:56] "[37mGET /parameters HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:14:56] "[37mPOST /update HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:14:56] "[37mGET /parameters HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:14:56] "[37mPOST /update HTTP/1.1[0m" 200 -


172.23.0.2 - - [22/Nov/2018 06:15:12] "[37mPOST /update HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:15:12] "[37mPOST /update HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:15:12] "[37mGET /parameters HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:15:12] "[37mGET /parameters HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:15:12] "[37mPOST /update HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:15:12] "[37mPOST /update HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:15:12] "[37mGET /parameters HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:15:12] "[37mGET /parameters HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:15:12] "[37mPOST /update HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:15:12] "[37mGET /parameters HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:15:12] "[37mPOST /update HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:15:12] "[37mPOST /update HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:15:12] "[37mGET /parameters H

172.23.0.2 - - [22/Nov/2018 06:15:14] "[37mGET /parameters HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:15:14] "[37mPOST /update HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:15:14] "[37mGET /parameters HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:15:15] "[37mPOST /update HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:15:15] "[37mPOST /update HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:15:15] "[37mGET /parameters HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:15:15] "[37mGET /parameters HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:15:15] "[37mPOST /update HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:15:15] "[37mGET /parameters HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:15:15] "[37mPOST /update HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:15:15] "[37mGET /parameters HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:15:15] "[37mPOST /update HTTP/1.1[0m" 200 -
172.23.0.2 - - [22/Nov/2018 06:15:15] "[37mGET /parameter

In [13]:
# Score the same frame the model was fit on.
predictions = fitted_model.transform(dataset=encoded)

In [14]:
predictions.show(n=1)

+--------------------+--------------+---------+
|            features|        labels|predicted|
+--------------------+--------------+---------+
|(784,[132,133,134...|(10,[1],[1.0])|       []|
+--------------------+--------------+---------+
only showing top 1 row



In [None]:
# MulticlassClassificationEvaluator needs a numeric label column, but `encoded`
# kept only the one-hot `labels` vector; recover the class index from it.
from pyspark.sql.functions import udf
from pyspark.sql.types import DoubleType

one_hot_to_index = udf(lambda v: float(v.toArray().argmax()), DoubleType())
scored = predictions.withColumn('label_index', one_hot_to_index('labels'))

# NOTE(review): `predicted` rendered as `[]` in the sample row shown earlier; the
# evaluator also requires it to be a double/float column. Confirm what sparkflow
# emits for tfOutput='out:0' before trusting this number.
evaluator = MulticlassClassificationEvaluator(
    labelCol="label_index", predictionCol="predicted", metricName="accuracy")
accuracy = evaluator.evaluate(scored)
# `predictions` came from transforming the same `encoded` frame the model was fit
# on, so this is training error, not held-out test error.
print("Training Error = %g" % (1.0 - accuracy))