In [None]:
import findspark

# Run findspark's initialisation before importing anything from pyspark.
findspark.init()

from pyspark.sql import SparkSession

# Reuse an already-running session if there is one, otherwise start a new one
# under the application name "MNIST_Neural_Networks".
spark = (
    SparkSession.builder
    .appName("MNIST_Neural_Networks")
    .getOrCreate()
)

In [None]:
import os
# Working directory of the kernel; the data-loading cells build their
# file:// URLs for the MNIST CSV files from this value.
path = os.getcwd()
print(path)

## Loading Data 

In [None]:
# Names of the 784 feature columns, "_c1" through "_c784".
# Column "_c0" is not included: the loading cells select it as the label.
feature_culumns = [f"_c{idx}" for idx in range(1, 28 * 28 + 1)]

In [None]:
from pyspark.ml.feature import VectorAssembler

# Read the headerless training CSV and let Spark infer the column types.
train_reader = spark.read.options(header = False, inferSchema = True)
df_training = train_reader.csv(f"file://{path}/mnist-data/mnist_train.csv")

# Packs the 784 feature columns into a single vector column named "features".
vectorizer = VectorAssembler(inputCols=feature_culumns, outputCol="features")

# Keep only the first CSV column (renamed to "label") and the assembled vector,
# spread the rows over 15 partitions and mark the frame for caching.
assembled_training = vectorizer.transform(df_training)
training = (
    assembled_training
    .select("_c0", "features")
    .toDF("label", "features")
    .repartition(15)
    .cache()
)

In [None]:
# Read the headerless test CSV and let Spark infer the column types.
test_reader = spark.read.options(header = False, inferSchema = True)
df_testing = test_reader.csv(f"file://{path}/mnist-data/mnist_test.csv")

# Apply the assembler defined in the training cell, then keep the first CSV
# column (renamed to "label") and the assembled vector, and mark the frame
# for caching.
assembled_testing = vectorizer.transform(df_testing)
testing = (
    assembled_testing
    .select("_c0", "features")
    .toDF("label", "features")
    .cache()
)

## Multilayer Perceptron Classifier

In [None]:
%%time
from pyspark.ml.classification import MultilayerPerceptronClassifier

# Layer sizes: 784 inputs (one per feature column), a 300-unit hidden layer,
# and 10 outputs.
layers = [28*28, 300, 10]

# Pin the seed so the initial weights are the same on every run; without it
# the trained model can differ between kernel restarts.
mpc = MultilayerPerceptronClassifier(maxIter=30, layers=layers, seed=42)

# Train on the cached training frame ("label" / "features" columns).
model = mpc.fit(training)

In [None]:
from pyspark.sql.functions import expr

# Score the test frame with the trained model, then add a boolean "matched"
# column that is true where the prediction equals the label.
predictions = model.transform(testing)
result = predictions.withColumn("matched", expr("label == prediction"))

In [None]:
# Preview the first three scored rows.
result.show(n=3)

In [None]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Accuracy evaluator comparing the "prediction" column against "label".
evaluator = MulticlassClassificationEvaluator(
    labelCol="label",
    predictionCol="prediction",
    metricName="accuracy",
)

In [None]:
# Accuracy on the scored test frame (displayed as the cell output).
evaluator.evaluate(dataset=result)

In [None]:
# Keep only the rows where the prediction did not equal the label.
mismatch_predicate = "matched = false"
wrong_df = result.filter(mismatch_predicate)

In [None]:
# Bring up to 36 misclassified rows to the driver — enough to fill the 6x6
# grid drawn by the plotting cell.
n_examples = 36
images = wrong_df.take(n_examples)

In [None]:
import matplotlib.pyplot as plt

# 6x6 grid of panels, one per misclassified example.
fig, _ = plt.subplots(6, 6, figsize = (20, 20))

# Pair each panel with one row. zip stops at the shorter sequence, so having
# fewer than 36 misclassified rows no longer raises IndexError.
for ax, r in zip(fig.axes, images):
    label = r.label
    prediction = int(r.prediction)  # the prediction is shown as an integer
    features = r.features
    # Reshape the flat 784-element feature vector into a 28x28 image.
    ax.imshow(features.toArray().reshape(28, 28), cmap = "Greys")
    ax.set_title(f"True: {label} / Pred: {prediction}")

# Hide any panels left without an image.
for ax in fig.axes[len(images):]:
    ax.axis("off")

plt.show()

In [None]:
# Shut down the Spark session created in the first cell.
spark.stop()