In [1]:
from __future__ import print_function

from pyspark import SparkContext, SparkConf
from pyspark.ml.linalg import DenseVector, VectorUDT
from pyspark.sql import SQLContext

from pyspark.ml.classification import MultilayerPerceptronClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, ArrayType

import time

In [2]:
# Extract the .gz archives downloaded from the MNIST website into ./MNIST_data.
# -k keeps the original .gz files next to the extracted idx files.
# NOTE(review): re-running this cell when the extracted files already exist may
# make gunzip refuse to overwrite them -- confirm, and add -f if that is wanted.
!gunzip -k ./MNIST_data/*.gz

In [3]:
# Transform to csv
def convert(imgf, labelf, outf, n):
    """Convert an MNIST idx image/label file pair into a CSV file.

    Every output line has the form ``label,p0,p1,...,p783``: the label byte
    followed by the 28*28 pixel bytes of the matching image, all written as
    decimal integers.

    Args:
        imgf: Path to the idx3-ubyte image file; its first 16 bytes are skipped.
        labelf: Path to the idx1-ubyte label file; its first 8 bytes are skipped.
        outf: Path of the CSV file to create (an existing file is overwritten).
        n: Number of samples to convert.

    Raises:
        ValueError: If either input file ends before ``n`` samples were read.
            Rows converted before the short read remain in ``outf``.
    """
    pixels_per_image = 28 * 28

    # Context managers guarantee all three files are closed even when a read
    # or write fails part-way through.
    with open(imgf, "rb") as image_file, open(labelf, "rb") as label_file:
        with open(outf, "w") as csv_file:
            image_file.read(16)
            label_file.read(8)

            for index in range(n):
                # bytearray yields integers on both Python 2 and Python 3.
                label = bytearray(label_file.read(1))
                pixels = bytearray(image_file.read(pixels_per_image))
                if len(label) != 1 or len(pixels) != pixels_per_image:
                    raise ValueError(
                        "input ended after %d of %d samples" % (index, n)
                    )
                # Stream each row straight to disk instead of buffering all
                # n rows in memory first.
                row = ",".join(str(value) for value in label + pixels)
                csv_file.write(row + "\n")

# Write one CSV per MNIST split: 60000 training samples, then 10000 test samples.
convert(
    imgf="./MNIST_data/train-images-idx3-ubyte",
    labelf="./MNIST_data/train-labels-idx1-ubyte",
    outf="./MNIST_data/mnist_train.csv",
    n=60000
)
convert(
    imgf="./MNIST_data/t10k-images-idx3-ubyte",
    labelf="./MNIST_data/t10k-labels-idx1-ubyte",
    outf="./MNIST_data/mnist_test.csv",
    n=10000
)

In [4]:
# Data IO
def data_frame_from_file(sqlContext, file_name, fraction, seed=None):
    """Load a ``label,pixel,...`` CSV file into a Spark DataFrame.

    Each line is split on commas and parsed as integers. The first value
    becomes the ``label`` column (as a float); the remaining values are
    divided by 255.0 and stored as a DenseVector in the ``features`` column.

    The file is read through the module-level SparkContext ``sc``, which must
    already exist when this function is called.

    Args:
        sqlContext: SQLContext used to build the DataFrame.
        file_name: Path of the CSV file to read.
        fraction: Sampling fraction passed to ``RDD.sample`` (no replacement).
        seed: Optional seed for the sampling step, so a subsample can be
            reproduced. ``None`` keeps Spark's default random seed.

    Returns:
        A DataFrame with columns ``label`` (double) and ``features`` (vector).
    """
    lines = sc.textFile(file_name).sample(False, fraction, seed)

    def parse_line(line):
        # Build real lists: on Python 3 ``map`` is lazy and can be neither
        # indexed nor handed to DenseVector.
        values = [int(field) for field in line.split(",")]
        return (
            float(values[0]),
            DenseVector([pixel / 255.0 for pixel in values[1:]])
        )

    samples = lines.map(parse_line)

    schema = StructType([
        StructField("label", DoubleType(), True),
        StructField("features", VectorUDT(), True)
    ])

    return sqlContext.createDataFrame(samples, schema)

In [5]:
# Spark
# Executor memory: set this to about 3/4 of the memory of your machine.
conf = SparkConf(True).set("spark.executor.memory", "12g")

# The master URL below points at a standalone Spark master on this machine.
# Allow remote login in the system settings and replace
# "Administrators-MacBook-Pro" with your own computer name.
# (To run without a cluster, a local master such as "local[4]" can be used instead.)
sc = SparkContext(
    master="spark://Administrators-MacBook-Pro.local:7077",
    appName="multilayer_perceptron_classification_example",
    conf=conf
)

# SQL entry point used to build DataFrames on top of the context above.
sqlContext = SQLContext(sc)

In [6]:
# Prepare data
# fraction=1 keeps every row of both CSV files.
train = data_frame_from_file(
    sqlContext,
    file_name="./MNIST_data/mnist_train.csv",
    fraction=1
)
test = data_frame_from_file(
    sqlContext,
    file_name="./MNIST_data/mnist_test.csv",
    fraction=1
)

In [7]:
# Model
# Layer widths: 28*28 inputs, hidden layers of 14*14 and 5*5 units, 10 outputs.
layers = [28 * 28, 14 * 14, 5 * 5, 10]
# Alternative topology with a single wide hidden layer:
# layers = [28 * 28, 1024, 10]

# Create the trainer and set its parameters.
trainer = MultilayerPerceptronClassifier(
    layers=layers,
    maxIter=40,
    blockSize=128,
    stepSize=0.01,
    seed=1234
)

In [8]:
# Train model
# Measure the wall-clock duration of the fit on the training DataFrame.
start_time = time.time()
model = trainer.fit(train)
train_time = time.time() - start_time
print("Training time: ", str(train_time), "seconds", sep="")

Training time: 214.66910696seconds


In [9]:
# Test model
# Evaluator configured to report plain classification accuracy.
evaluator = MulticlassClassificationEvaluator(metricName="accuracy")

# Score the held-out test DataFrame, then keep only the two columns the
# evaluator is given.
result = model.transform(test)
predictionAndLabels = result.select("prediction", "label")

In [10]:
# Preview the scored test rows: label, features and the model's prediction.
result.show()

+-----+--------------------+----------+
|label|            features|prediction|
+-----+--------------------+----------+
|  7.0|[0.0,0.0,0.0,0.0,...|       7.0|
|  2.0|[0.0,0.0,0.0,0.0,...|       2.0|
|  1.0|[0.0,0.0,0.0,0.0,...|       1.0|
|  0.0|[0.0,0.0,0.0,0.0,...|       0.0|
|  4.0|[0.0,0.0,0.0,0.0,...|       4.0|
|  1.0|[0.0,0.0,0.0,0.0,...|       1.0|
|  4.0|[0.0,0.0,0.0,0.0,...|       4.0|
|  9.0|[0.0,0.0,0.0,0.0,...|       9.0|
|  5.0|[0.0,0.0,0.0,0.0,...|       5.0|
|  9.0|[0.0,0.0,0.0,0.0,...|       9.0|
|  0.0|[0.0,0.0,0.0,0.0,...|       0.0|
|  6.0|[0.0,0.0,0.0,0.0,...|       6.0|
|  9.0|[0.0,0.0,0.0,0.0,...|       9.0|
|  0.0|[0.0,0.0,0.0,0.0,...|       0.0|
|  1.0|[0.0,0.0,0.0,0.0,...|       1.0|
|  5.0|[0.0,0.0,0.0,0.0,...|       5.0|
|  9.0|[0.0,0.0,0.0,0.0,...|       9.0|
|  7.0|[0.0,0.0,0.0,0.0,...|       7.0|
|  3.0|[0.0,0.0,0.0,0.0,...|       3.0|
|  4.0|[0.0,0.0,0.0,0.0,...|       4.0|
+-----+--------------------+----------+
only showing top 20 rows



In [11]:
# Compute the accuracy on the test predictions and report it.
accuracy = evaluator.evaluate(predictionAndLabels)
print("Accuracy: " + str(accuracy))

# Shut down the SparkContext now that the notebook is done with it.
sc.stop()

Accuracy: 0.9676
