# Distributed Keras Experiment

In this notebook we evaluate all distributed optimizers and compare their performance. You can also extend this notebook with your own algorithms to obtain a baseline metric and supporting evidence for future pull requests.

In [1]:
import numpy as np

from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation

from pyspark import SparkContext
from pyspark import SparkConf

from pyspark.ml.feature import StandardScaler
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import StringIndexer
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

from distkeras.distributed import *

Using Theano backend.


In [2]:
# Modify these variables according to your needs.
application_name = "Distributed Keras Experimentation"
using_spark_2 = False
num_cores = 3
num_executors = 1

In [3]:
# This variable is derived from the number of cores and executors, and will 
# be used to assign the number of model trainers.
num_workers = num_executors * num_cores

In [4]:
import os

os.environ['PYSPARK_SUBMIT_ARGS'] = '--packages com.databricks:spark-csv_2.10:1.4.0 pyspark-shell'
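
The `PYSPARK_SUBMIT_ARGS` value is a plain command-line string, and PySpark expects it to end with the `pyspark-shell` token. As a minimal sketch, the string could be assembled from a list of package coordinates instead of being written by hand. The helper name `build_submit_args` is hypothetical and is not part of PySpark or distkeras.

```python
import os


def build_submit_args(packages):
    """Compose a PYSPARK_SUBMIT_ARGS string (hypothetical helper)."""
    args = []
    if packages:
        # --packages takes a comma-separated list of Maven coordinates.
        args += ["--packages", ",".join(packages)]
    # The string is expected to end with the "pyspark-shell" token.
    args.append("pyspark-shell")
    return " ".join(args)


submit_args = build_submit_args(["com.databricks:spark-csv_2.10:1.4.0"])
os.environ["PYSPARK_SUBMIT_ARGS"] = submit_args
```

The variable has to be set before the Spark context is created, because it is read when the JVM gateway is launched.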

## Preparing a Spark Context

To read our (large) dataset into the Spark cluster, we first need a Spark context. Note that the initialization changed in Spark 2.0: for example, SQLContext and HiveContext no longer have to be initialized separately, which simplifies the process.

In [5]:
conf = SparkConf()
conf.set("spark.master", "spark://127.0.0.1:7077")
conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
# Spark configuration values are strings, so convert the integers explicitly.
conf.set("spark.executor.cores", str(num_cores))
conf.set("spark.executor.instances", str(num_executors))

# Check if the user is running Spark 2.0 +
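if using_spark_2:
    # SparkSession (Spark 2.0+) is not imported above, so import it here.
    from pyspark.sql import SparkSession
    sc = SparkSession.builder.config(conf=conf) \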
            .appName(application_name) \
            .getOrCreate()
else:
    # Create the Spark context.
    sc = SparkContext(conf=conf)
    # Add the missing imports
    from pyspark import SQLContext
    sqlContext = SQLContext(sc)

Exception: Java gateway process exited before sending the driver its port number
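
This exception is raised when PySpark fails to start its JVM gateway process. Common causes include a missing Java installation, an unset or incorrect `JAVA_HOME`, or a `PYSPARK_SUBMIT_ARGS` value that does not end with `pyspark-shell`. The following is a rough diagnostic sketch under those assumptions. It requires Python 3.3 or later (for `shutil.which`), `check_spark_environment` is a hypothetical helper rather than a PySpark function, and it does not cover every possible cause.

```python
import os
import shutil


def check_spark_environment(environ=None):
    """Return a list of likely configuration problems (hypothetical helper)."""
    environ = os.environ if environ is None else environ
    problems = []
    # The gateway needs a Java runtime, located via JAVA_HOME or the PATH.
    if not environ.get("JAVA_HOME") and shutil.which("java") is None:
        problems.append("no Java runtime found (JAVA_HOME unset, java not on PATH)")
    # When set, PYSPARK_SUBMIT_ARGS is expected to end with 'pyspark-shell'.
    submit_args = environ.get("PYSPARK_SUBMIT_ARGS", "").strip()
    if submit_args and not submit_args.endswith("pyspark-shell"):
        problems.append("PYSPARK_SUBMIT_ARGS does not end with 'pyspark-shell'")
    return problems


for problem in check_spark_environment():
    print("Possible cause:", problem)
```

An empty result only means that these two checks passed. The master URL and the Spark installation itself may still need to be verified.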