## Notebook requirements

* DBR 5.2 ML cluster
* Install pyarrow v0.12 on the cluster. It is available as a PyPI package and can be installed via `workspace libraries`. **Note: this should no longer be needed on DBR 5.3 and later.**
  * Apache Arrow is a standard in-memory columnar format shared by many data systems; in this notebook it speeds up the Spark-to-pandas transfer behind the pandas UDF used for prediction.
  * It lets you convert from one format to Arrow and back without having to write or find a custom adapter (for example, Cassandra to pandas).
* Use an init script to mount persistent storage on your cluster (it is used below as the working directory).

In [2]:
from tensorflow import keras
def get_keras_dataset():
  """Load MNIST through Keras as ``((x_train, y_train), (x_test, y_test))``.

  Each image array gets a trailing channel axis (reshaped to
  ``(-1, 28, 28, 1)``) so it can be fed to ``keras.layers.Conv2D``;
  the label arrays are returned unchanged.
  """
  train_split, test_split = keras.datasets.mnist.load_data()

  def _with_channel_axis(split):
    # split is an (images, labels) pair; only the images are reshaped
    images, labels = split
    return images.reshape(-1, 28, 28, 1), labels

  return _with_channel_axis(train_split), _with_channel_axis(test_split)


Each MNIST image is a 28x28-pixel, single-channel (grayscale) handwritten digit from 0 to 9.

Here are some examples:

![image](https://corochann.com/wp-content/uploads/2017/02/mnist_plot.png)

In [4]:
import random
from scipy import ndimage

from pyspark.sql.types import IntegerType, StructField, StructType
from pyspark.ml.image import ImageSchema

# OpenCV type code for 8-bit unsigned, single-channel images. spark_image puts
# this value in the mode slot of every image row it builds (it renders as 0 in
# the displayed output).
uint8_one_channel = ImageSchema.ocvTypes['CV_8UC1']

def spark_image(numpy_image):
  """Pack a (height, width, channels) numpy image into a Spark image-row list.

  The list is ``['', height, width, channels, uint8_one_channel, pixel_bytes]``:
  an empty origin string, the array's dimensions, the fixed OpenCV mode code,
  and the raw pixel buffer as a ``bytearray``. The mode is always
  ``uint8_one_channel``, so callers are expected to pass uint8 single-channel
  images.
  """
  rows, cols, channels = numpy_image.shape
  origin = ''
  pixel_bytes = bytearray(numpy_image.tobytes())
  return [origin, rows, cols, channels, uint8_one_channel, pixel_bytes]

enhancement_factor = 5

# every source image is expanded into `enhancement_factor` randomly rotated copies
def add_rotation(example):
  """Yield randomly rotated copies of one example to mimic real handwriting.

  ``example`` is an ``(image, label)`` pair. For each of ``enhancement_factor``
  copies, an integer angle is drawn uniformly from [-45, 45] degrees, the image
  is rotated with ``reshape=False`` (output keeps the input's shape) and cast to
  uint8, and ``[spark_image(rotated), label]`` is yielded.
  """
  image, label = example
  for _ in range(enhancement_factor):
    degrees = random.randint(-45, 45)
    rotated = ndimage.rotate(image, degrees, reshape=False)
    yield [spark_image(rotated.astype('uint8')), label]
    
# Schema of the [image_row, label] records yielded by add_rotation.
schema = StructType([
  StructField("image", ImageSchema.columnSchema),
  StructField("label", IntegerType()),
])

(x_train, y_train), (x_test, y_test) = get_keras_dataset()

# Distribute (image, label) pairs over 8 partitions, then fan each pair out
# into its rotated copies.
train_pairs = sc.parallelize(zip(x_train, y_train.tolist()), 8)
enhanced_data = train_pairs.flatMap(add_rotation).toDF(schema)

display(enhanced_data)



image,label
"List(, 28, 28, 1, 0, Binary image data placeholder. Access the image data field directly to view raw binary data., iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABQ0lEQVR42mNgoDVgZGRgYsIhBaVYWRiZmDHUCHkra5iJMUGVsbIxIWu0+v/z6uU9D1YEFsYHsSGZBWaZPPoPAf9ePbVomBfGhSTLwFfefOv//w/Pvvx/tnQ7UE01iqVszPP/P00K2Pq/TMwztSZfkQHJViYGrmV/whkYeFZJA73EyYzqXGbG9v8NDFwMfGAtjOi+Mf89mQGbOMTg0w+TZBhwhVLE/6fTmHAFL1PL/+8BLFiNZWBmkJ31/7gmpqWMELvYnv+fwYZiGCskroQ17CKj//9/JYSqjZuZzTR08vaHH159+P9zuShCgsc9N/XIvg3v//wBhvn//z+/ODKwwiWtb0Jj4//715fm9vnpcCB8ysigkL3r+f4rLy43hsiLMcMcBg8XoB/kXfm4QUxmZrQAYmJlBfuRmRF7qDIwMVI3JQMAwDVsPRxAxO8AAAAASUVORK5CYII=)",5
"List(, 28, 28, 1, 0, Binary image data placeholder. Access the image data field directly to view raw binary data., iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABQ0lEQVR42mNgoDVgZGBiYmDCJcvGCCSYWLBJsTC0bHCzkgZKszIwY8qe+v//5uZ2V7AVjGiS3N23/v//9///+qx4JQYmRjZUWR1XoN5vf///v3/JgJ2BgRnZcGYG1s3/P688+Pn/x4vN27fZoRnsdO+aK4Pc5v9x3kv+/P8Vi2qv8KP7MgwM9tMY2OWdqnPVUf3JcuIKI9BLfEDMxIHhm7UvnMDeYASFF7pvUv4fCcYQhAH+w/+XMaEGLysDK9AQRiZGFqbM7w8tkHWihsS6/7VIPJOWtvBpUZrBPomJJfGxyg//X5NBSDq//P//1/83/7//+vPr/8tn/3+9lkRyElv5x18f/4PBv/9//3/O5EGxkiOgcdvVbcsmLJu9YMWMWQzIfmECSrOyA4OFlYEd7DpW1MTDBBRkZmBmZGJmYGEYYAAAYpJsg1Ni4TQAAAAASUVORK5CYII=)",5
"List(, 28, 28, 1, 0, Binary image data placeholder. Access the image data field directly to view raw binary data., iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABRklEQVR42mNgoDlgZGBgBhHY5UT5gBQLVmkmhnkXPLVZGRhZsMmKXfn/+3iXOQMDG9B4dCAw49L//x//r8ZuuILPrz///v1/sihajAFNMyMzg8ePX18ffPz//+F6oEY2FGlm/uX/ey30tv7/f2pjVr8iurkr/vsyMGh+qNdo/vT/fySKFAvDzC8KDBy8+QwMMr7zTtZyouqc/d4MTLMyMIvJ8EBDDQby/682BIUBI0KcjQEWJkoP/+chm6RuCgw3iFcYmXmX/5/OyYSQ3Pxrz5q5ljqKijxCDOwMpl8vqiHp1L/4////X89fXzi9e3GNVfr/954osdbw4gFQ/h8Q/nzy5///TCQpVmCQpm17/eTl4//fgWr+HrVCDlUmIJuTQ0fe2CenbtWGMwoYscwGxExAu9gZhBiwRjUjIxMzExMDrkRELwAApupsfN86vYgAAAAASUVORK5CYII=)",5
"List(, 28, 28, 1, 0, Binary image data placeholder. Access the image data field directly to view raw binary data., iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAA0ElEQVR42mNgGMyAWUhIqK5jvdSy/9/rUSTkVOJmrfoLAg/X/P102AFZzvDdXyj4HRsUZKGOolHoNljm2LbvH7HYFzAn++/fs9wM2rOwuYaPcdbfKNyO7f67jwmnJPe+v264tSp/fLgghxGXbOCHv3/LJXHJ6u76+3eaNC5Zgdg/f3fjtvjn358O2GX0mrb//Xseq4fUpzwFBuGvbVikJIrugoL3pB+mlLjTVXDIB2KaKbQaHCuHAzgxpMzXPAJJfWnlxmJbB1DmSnuLAJ1SIwD7bWuGGa95MgAAAABJRU5ErkJggg==)",5
"List(, 28, 28, 1, 0, Binary image data placeholder. Access the image data field directly to view raw binary data., iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABTklEQVR42mNgQAKM7AwMzAyMDFgBM4PwDB9cUgyaL/9HYtUJlGMo/b9VH7tOhvCZ8/ZbsGCTYWUI2va53QaHHMPS/3+8wYajAUYWBuHl/3//P4xNHyOD4M7/f+5s6TfA5lTJ/pe/XsVh1cfEEPfs//9PsgzcQDuZGRkZUWz2uv//7/8XbhwMDDIcgkKcKFrFrvx/0nvw97NVDpmbj+/Ysq9dDmgYPHCcl0yTkVlw4/9/oAG/gIQfkpeYGXh4GRg4BHTLJlaW2m/4/+eMBLJPGIHuYGKBhHn49/8rsIcvC4PloV/fgxlYsEvX/fu/jw2rDCOj9pW/31JBQY0NJP38f14CuxQT/7p/P4qwxQ0IOH76/0QNu43MDKX//u8XxKHR7Pr/n17YNbIJHfj/aoEcMxPWKJe89f+AKo7kJ7zq/11NoF+xSuo83sfNwIKQAwBpzWvdeD1AaAAAAABJRU5ErkJggg==)",5
"List(, 28, 28, 1, 0, Binary image data placeholder. Access the image data field directly to view raw binary data., iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABa0lEQVR42mNgIAEwsjAxkAcELIzYcZjJYLj4/70ZogzYjGYUOvf/25//nkBV6ICJkXv1/59fdjy9kieGrpeRgbn375//n4sn/3/kyMCCqo/BaPevX7/e/98ScuqND5okM0Py//+f2qZ8ymEo/D+BkRHVUPuHv77VscbYMjBY/n+YiKpVdM7/H6etgQx2Bu3/XxahSHLVf/gaqSoNdDIjg8j/H5uRnMvIIP3hfy4L2H+MDOz/fx9GcU/Rv3tcUL+zMHz6v50DERBM3Nf+L+aE+UnuyTsHhBwjg+3D/5GMULbw6v+zuJAl9S//t4Y6UHjKzwmSLMieVDv90gRoIFC5wdJvM9CCVXbfZ0cGUEQ6Xf6/T4GBDUVSYun/tdrsUiXn/n1d7sCDFnZCc/5/uL759M//D/vMwQLIkuyl////+vn7/61oBrQIAcnKtF75//9/syQDK9bUZxRsqcSGJYFAAgYfYGJlYWIkIREDAOiuef7MZkTpAAAAAElFTkSuQmCC)",0
"List(, 28, 28, 1, 0, Binary image data placeholder. Access the image data field directly to view raw binary data., iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABV0lEQVR42mNgGOSAEYgYGRlYWDBkmCCyPCL8EDaaNkMdLb+eU6evl6CKMzIw8tjH3f3z6u/////+PzBDkgNaYdq49Mv/v////v3/58vn/9ORXSFYcuf//7c/fl/buur9n839/30ZWOFytnv//z+VYxdWq8+a/ud/08T/iSBRiKTUhP9fF6oBmdIsTF6p+Xpr/m8Uh8mxRr39nyPBwMwIU37wbRHcTv0D77vBqoByzAxcDOb/TwkxsEA0Ctf+vyoAcwBI8Or/5UBFEEmnR//XwOWAjKr/X1ZywLhJ/39e1AcJMzMxsrAyKP7//9CfARa4Gpd+/9vibSwG5rDLbv//YTo8bJnYJvz///XPtfN1Dc31U3a/+P9qhzLCFgaGGmC4AcGn/5/fXD4321YM6hwQYGNgUEiI7s/um9AD9C0DPzgaEMHOBgsOHgZOBizRzAI0gIVwMqE5AACTJ3pm3GlL1wAAAABJRU5ErkJggg==)",0
"List(, 28, 28, 1, 0, Binary image data placeholder. Access the image data field directly to view raw binary data., iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABYklEQVR42mNgGByAkYGFGas4I5TBhEULEAg6tkwvY8GQYmIQ8bXpuvbj/7//+UzomjjNF/z/9f/319+f/31QQ5ZkZVDxn/v/54//v////7Lq0WOoHSDAzMCofuLt/y//Xy+s7P3/P+rway0kMzlKr/9/sMS3PIVDeuLzbT6X/vewI+Qyvv2/YAM0QYKBQZFTiOfUi0SYZxgZ1Lf83+0Nci7IIjYGof9nhBiYoXJsyb/+GII8zgjisTNM/L+fHa5Rbfv7FLhKVgahv/+Py8Lk2OK+/xaHBxgHQ8//TzPgPpHe/P+uGhMTAzMrAzNQCfPv/3fNgQZAgMTd/9+mibFAHC/qt/z/u1mIMGCu/f7z+8Fps7PcKusXn//7/2ySElwjA4PK9f/AwP79/++v/8/3l8RpQ9wNA2xl785dvLlq3/Rub0hgIsUIGxAr8SgwCDEycAJdiwaYwY5hYWDBlW4Y6ZZGAS9GelTPaffDAAAAAElFTkSuQmCC)",0
"List(, 28, 28, 1, 0, Binary image data placeholder. Access the image data field directly to view raw binary data., iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABYklEQVR42mNgoB9gxCvLwoJNlAkIeRVlscoAgaDvsfP/ZwtiSDIzWHtWnf7//++vFzlM6I7gN73x/8+/v7++Pv+/hYEVxUh+27V///78/+73z9txb7qAxiDZJrX21r8//2+VRKz4fz3h778wRiTbHFb/f783rjWQQX/CnAirG6/iGOC2Mrne+//ah5WBXYSBgYuBgefTUWmYuYwMolP/PyxkAPOZgbKe/48yIYy1uvY/nQMoDhZhYdj7/7gy3BOc5f/3sEIsYQT6QeH/71OC0ABmZFDb+d8b7nYuhp7/H9oYYKHLHPnlvwVQBzMLWAXT2//3TWCBwMg+8f+HSfxQldIpe/7f6UGKto7/f7/OjU1NVI9rXnr75/9tVkxIASS+5f+fL/8/////4f+7bSmCUkxI0c3EYAOMi3d/7x3b3Z3GhyUhqGe5OeozqDCwsTAwoaUZdhZorGFPPSygBMJIn/QLAOjjem0vgCWHAAAAAElFTkSuQmCC)",0
"List(, 28, 28, 1, 0, Binary image data placeholder. Access the image data field directly to view raw binary data., iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABbklEQVR42mNgIBawsDEw4pNnZMAlz9ewrtaGgRmrJtaSv///XxKSwKpZ5///z7d+fNukzsCCqbPz/883OX3/36ZhmMzEKH/jz/8/N4Ka/p/WYWZgQpPO/f/3xf//pgxr/x8JZmBFlUv78/9D5uJ0Bga3/182oLqJhaH+/4/ZLNa8QOte/L9ngqKVkW/Xr9+lIJvYGTr+f2hCdZPe59+7XfiAYqwMSv//tqLIMeb9PikMcSM7w5mftijOlTz1PwtsFNAhes/vaSK7h4n7zn87mD21f6NRPWL66k0U2PlMDAb3D6NIMTPY/v/iDJJkZRBa/jkOLXDtbv0PAcoApRv/X7FiRY1i8Zn/O4DhyyCW9/5BFGqUMTIwR//awMXAID3nz+8+RbT4ZGTQOvJ/VVzE6v//V2DENSMwifz/8v371387jRjYMNOB17P/v///v2OKNdWxp6y4+LZVnoGREZssI5MUyMs4UiYwrNlZiM4BAKfVeq5q7dXwAAAAAElFTkSuQmCC)",0


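Set up the working directory on the FUSE-mounted persistent storage. We use it to stage data, store checkpoints, and save anything else we don't want to lose.
This should be a high-I/O storage subsystem, because training reads the TFRecords from it and writes checkpoints and logs to it.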

In [6]:
import os
from time import time

# Working directory for data, checkpoints and logs. It is intended to be the
# persistent storage mounted by the cluster init script.
# NOTE(review): if that mount is missing, the call below just creates a plain
# local directory and nothing written here will persist -- consider checking
# os.path.ismount(working_dir).
working_dir = "/tmp/blobfuse"

# makedirs(exist_ok=True) replaces the racy exists()+mkdir() pair and also
# creates any missing parent directories.
os.makedirs(working_dir, exist_ok=True)
display(dbutils.fs.ls("file:/tmp"))

path,name,size
file:/tmp/.ICE-unix/,.ICE-unix/,4096
file:/tmp/.X11-unix/,.X11-unix/,4096
file:/tmp/hsperfdata_root/,hsperfdata_root/,4096
file:/tmp/spark-root-org.apache.spark.deploy.master.Master-1.pid,spark-root-org.apache.spark.deploy.master.Master-1.pid,5
file:/tmp/driver-daemon.pid,driver-daemon.pid,5
file:/tmp/custom-spark.conf,custom-spark.conf,231
file:/tmp/blobfuse/,blobfuse/,4096
file:/tmp/master-params,master-params,18
file:/tmp/chauffeur-env.sh,chauffeur-env.sh,156
file:/tmp/tmp.poAeS0dNOe,tmp.poAeS0dNOe,0


In [7]:




# Tag every artifact directory of this run with the same Unix-seconds
# timestamp, so separate runs write to separate directories.
current_time = int(time())
run_tag = str(current_time)

checkpoint_dir = f"{working_dir}/checkpoints-{run_tag}"  # model weights
data_dir = f"{working_dir}/tfrecords-{run_tag}"          # TFRecord training data
log_dir = f"{working_dir}/logs-{run_tag}"                # TensorBoard logs


In [8]:
# Export (raw pixel bytes, label) rows as TFRecord `Example`s so the tf.data
# pipeline can read them directly.
tfrecord_rows = enhanced_data.select(
  enhanced_data.image["data"].alias("image_raw"),
  enhanced_data.label,
)
(tfrecord_rows.write
  .format("tfrecords")
  .option("recordType", "Example")
  .mode("overwrite")
  .save("file:" + data_dir))


In [9]:
import tensorflow as tf

# this will run on each Horovod worker process
def build_dataset(files, batch_size, steps_per_epoch):
  """Build this worker's infinite, shuffled, batched MNIST input pipeline.

  Args:
    files: glob pattern matching the TFRecord part files.
    batch_size: number of examples per batch (per worker).
    steps_per_epoch: batches per epoch; the shuffle buffer holds
      ``batch_size * steps_per_epoch`` records.

  Returns:
    A ``tf.data.Dataset`` of ``(image, one_hot_label)`` batches, where images
    come out of ``decode`` and ``normalize``.
  """
  # Horovod: list files in a fixed order. `list_files` shuffles by default and
  # every worker would shuffle independently, so sharding the records
  # afterwards would NOT give workers disjoint slices of the data.
  files = tf.data.Dataset.list_files(files, shuffle=False)
  dataset = (tf.data.TFRecordDataset(files)
             # Horovod: each worker keeps every hvd.size()-th record
             .shard(hvd.size(), hvd.rank())
             .repeat()
             .shuffle(batch_size * steps_per_epoch)
             .map(decode)
             .map(normalize)
             .batch(batch_size)
             # prefetch whole batches (10 batches == batch_size * 10 records)
             # so input preparation overlaps with training
             .prefetch(10))
  return dataset

def decode(serialized_example):
  """Parse one serialized ``tf.train.Example`` into ``(image, label)``.

  Expects a bytes feature ``image_raw`` holding 28*28 uint8 pixels and an
  int64 feature ``label``. Returns the image as a uint8 tensor of shape
  ``(28, 28, 1)`` and the label one-hot encoded over 10 classes.
  """
  feature_spec = {
    'image_raw': tf.FixedLenFeature([], tf.string),
    'label': tf.FixedLenFeature([], tf.int64),
  }
  parsed = tf.parse_single_example(serialized_example, features=feature_spec)

  pixels = tf.decode_raw(parsed['image_raw'], tf.uint8)
  pixels.set_shape(28*28)
  image = tf.reshape(pixels, (28, 28, 1))

  label = tf.one_hot(tf.cast(parsed['label'], tf.int32), 10)
  return image, label

def normalize(image, label):
  """Cast ``image`` to float32 and divide by 255; ``label`` passes through."""
  scaled_image = tf.cast(image, tf.float32) / 255.
  return scaled_image, label
  
    

In [10]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D




In [11]:
# Report the tf.keras version -- the Keras used by the model and by
# horovod.tensorflow.keras. A bare `import keras` here would rebind the global
# name `keras` to the standalone Keras package, and later cells that call
# keras.optimizers / keras.callbacks / keras.backend would then silently mix
# standalone-Keras objects with a tf.keras model.
from tensorflow import keras
print(keras.__version__)

In [12]:

# CNN classifier: two conv layers, max-pooling, dropout, and a dense softmax head.
def get_model():
  """Return a new, uncompiled Sequential CNN.

  Input shape is ``(28, 28, 1)``; the output is a 10-way softmax.
  """
  layer_stack = [
    Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.25),
    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(10, activation='softmax'),
  ]
  return Sequential(layer_stack)

In [13]:
import horovod.tensorflow.keras as hvd
import tensorflow.keras.backend as k

def train_hvd(learning_rate,epochs,batch_size,steps_per_epoch):
  """Train the MNIST CNN with Horovod; run once per worker by HorovodRunner.

  Each process pins itself to one local GPU and trains ``get_model()`` on its
  shard of the TFRecords from ``build_dataset`` with a Horovod-wrapped
  Adadelta optimizer. Only rank 0 writes checkpoints and TensorBoard logs.

  Args:
    learning_rate: single-worker Adadelta learning rate; multiplied by
      ``hvd.size()``.
    epochs: number of training epochs.
    batch_size: per-worker batch size.
    steps_per_epoch: batches each worker processes per epoch.
  """
  import os

  hvd.init()

  # Pin this process to its own local GPU and let TF allocate memory lazily.
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  config.gpu_options.visible_device_list = str(hvd.local_rank())

  tf.reset_default_graph()
  with tf.Session(config=config) as sess:
    # Make tf.keras use the GPU-pinned session explicitly.
    k.set_session(sess)

    dataset = build_dataset(data_dir + "/part-*", batch_size, steps_per_epoch)
    model = get_model()

    # Horovod: scale the learning rate by the number of workers.
    # tf.keras is referenced explicitly because the global name `keras` may
    # have been rebound to standalone Keras (e.g. by `import keras`), whose
    # optimizers/callbacks must not be mixed with a tf.keras model.
    optimizer = tf.keras.optimizers.Adadelta(lr=learning_rate * hvd.size())
    optimizer = hvd.DistributedOptimizer(optimizer)

    model.compile(optimizer=optimizer,
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    callbacks = [
      # Horovod: broadcast initial variable states from rank 0 so every worker
      # starts from identical weights (random init or restored checkpoint).
      hvd.callbacks.BroadcastGlobalVariablesCallback(0),
    ]

    # Save checkpoints/logs only on worker 0 so workers don't corrupt each
    # other's files. The checkpoint directory must exist before saving.
    if hvd.rank() == 0:
      os.makedirs(checkpoint_dir, exist_ok=True)
      callbacks.append(tf.keras.callbacks.ModelCheckpoint(checkpoint_dir + '/checkpoint-{epoch}', save_weights_only=True))
      callbacks.append(tf.keras.callbacks.TensorBoard(log_dir))

    model.fit(dataset,
              callbacks=callbacks,
              epochs=epochs,  # was `epoch=`, which fit() does not accept
              # only rank 0 prints progress, to avoid interleaved worker logs
              verbose=2 if hvd.rank() == 0 else 0,
              steps_per_epoch=steps_per_epoch)
    
    
      
    

In [14]:
# Start Databricks' TensorBoard on this run's log directory; during training the
# rank-0 TensorBoard callback writes its event files there.
dbutils.tensorboard.start(log_dir)

In [15]:
from sparkdl import HorovodRunner
# NOTE(review): np=0 -- per sparkdl's HorovodRunner docs this is believed to mean
# "use all task slots on the cluster" (deprecated in later runtimes), while a
# negative np runs -np processes locally on the driver; confirm for this DBR.
hr = HorovodRunner(np=0)
# Keyword arguments are forwarded to train_hvd. The evaluation cells below
# load checkpoints for epochs 1 and 5, so keep epochs >= 5.
hr.run(train_hvd,learning_rate=1.0,epochs=5,batch_size=64,steps_per_epoch=200)


In [16]:
import numpy as np
import pandas as pd

from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.types import IntegerType

# Use Arrow for the Spark <-> pandas transfer behind the pandas UDF below, and
# cap each Arrow batch at 1000 rows (bounds how many images one UDF call gets).
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch","1000")

def predict(epoch):
  """Return a scalar pandas UDF that classifies images with one checkpoint.

  Args:
    epoch: epoch number of the checkpoint to load, i.e. the weights saved at
      ``checkpoint_dir + "/checkpoint-<epoch>"``.

  Returns:
    A pandas UDF mapping a column of raw image bytes (28*28 uint8 pixels per
    row) to the predicted digit as an IntegerType column.
  """
  @pandas_udf(IntegerType(),PandasUDFType.SCALAR)
  def predict_for_epoch(pandas_series):
    import tensorflow as tf

    # Concatenate every row's raw bytes and view them as a batch of images.
    image_batch = np.frombuffer(b"".join(pandas_series), 'uint8')
    # Apply the same scaling as the training pipeline's `normalize`
    # (float32 in [0, 1]); raw 0-255 pixels would not match training inputs.
    image_batch = image_batch.reshape(-1, 28, 28, 1).astype('float32') / 255.

    # Start from a fresh graph on every call so repeated batches don't keep
    # adding model copies to the worker's graph, then switch tf.keras (the
    # backend the model actually uses) to inference mode.
    tf.keras.backend.clear_session()
    tf.keras.backend.set_learning_phase(0)

    model = get_model()
    model.load_weights(checkpoint_dir + "/checkpoint-%s" % epoch)

    predictions = model.predict_on_batch(image_batch)
    # argmax yields int64; cast to match the declared IntegerType (int32).
    return pd.Series(predictions.argmax(1).astype('int32'))
  return predict_for_epoch

  
    

In [17]:
# Apply the same random-rotation augmentation to the held-out MNIST test set.
testData = (sc.parallelize(zip(x_test, y_test.tolist()), 8)
            .flatMap(add_rotation)
            .toDF(schema))


In [18]:
from pyspark.sql.functions import sum, when

from pyspark.sql import functions as F

# Each test image was expanded into `enhancement_factor` rotated copies.
test_cases = len(x_test) * enhancement_factor

predictions = (testData
  .withColumn("predicted_label_epoch_1", predict(epoch=1)(testData.image.data))
  .withColumn("predicted_label_epoch_5", predict(epoch=5)(testData.image.data))
  .cache())

def _accuracy(predicted_label):
  """Fraction of the `test_cases` rows whose prediction equals `label`.

  `.otherwise(0)` counts misses as 0 instead of null, so the result is 0.0
  rather than null when no prediction is correct.
  """
  hits = F.when(predicted_label == predictions.label, 1).otherwise(0)
  return F.sum(hits) / test_cases

accuracy = predictions.agg(
  _accuracy(predictions.predicted_label_epoch_1).alias("epoch_1_accuracy"),
  _accuracy(predictions.predicted_label_epoch_5).alias("epoch_5_accuracy"))

display(accuracy)
  
  

In [19]:
# Test rows that the epoch-1 checkpoint misclassified.
epoch_1_miss = predictions.label != predictions.predicted_label_epoch_1
bad_predictions = predictions.where(epoch_1_miss)

display(bad_predictions)
