# <div class="alert alert-block alert-info" style="border-width:4px">SBrain Keras Training Tutorial </div>

# Introduction

This tutorial walks you through distributed Keras training with SBrain. The example trains a Keras model on the MNIST handwritten-digit dataset.
 

### Let's try it out


Before we begin, make a copy of this notebook and append your name to its file name. If several people edit the same notebook at the same time, it keeps reloading.

#### Imports

Below are the necessary imports.

In [None]:
from sbrain.learning.experiment import *
from sbrain.dataset.dataset import *

#### Unique Names

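As in the other notebooks, every asset name must be unique. The helper below appends your user name and the current timestamp to a base name; set `user_name` to your own name first.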

In [None]:
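import time

user_name = "albin"  # replace with your own name

def uniquify(name):
    """Append the user name and a timestamp so that asset names don't collide."""
    should_uniquify = True  # set to False to keep names unchanged
    if should_uniquify:
        return name + user_name + str(time.time()).replace(".", "")
    else:
        return name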
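The Python documentation notes that not every system provides `time.time()` with better than one-second precision, so two calls in quick succession can in principle yield the same suffix. One hedged alternative is a random suffix from the standard `uuid` module. The function `uniquify_uuid` below is hypothetical, not part of SBrain, and assumes SBrain accepts lowercase hex letters in asset names (the original suffix uses only digits).

```python
import uuid

user_name = "albin"  # same value as above; replace with your own name

def uniquify_uuid(name):
    """Hypothetical alternative to uniquify(): append the user name and a random suffix.

    The first 12 hex characters of a uuid4 are all random (48 bits), so two calls
    made within the same clock tick still produce different names.
    """
    return name + user_name + uuid.uuid4().hex[:12]

a = uniquify_uuid("MyMnistKerasEstimator")
b = uniquify_uuid("MyMnistKerasEstimator")
print(a)
print(a != b)  # True
```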

#### Input Function

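The input function loads MNIST with TensorFlow's tutorial helper `input_data.read_data_sets`, which downloads the data from the internet into `local_dir` if it is not already there. In `"train"` mode it returns the training set shuffled, batched and repeated indefinitely; in any other mode it returns the test set, batched, for a single pass.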

In [None]:
def input_function(mode, batch_size, params):
    from tensorflow.examples.tutorials.mnist import input_data
    import tensorflow as tf

    local_dir = "/workspace/shared-dir/sample-notebooks/demo-data/learning/mnist/"

    # Downloads MNIST into local_dir on first use; labels are one-hot vectors of length 10.
    mnist = input_data.read_data_sets(local_dir, one_hot=True)

    if mode == "train":
        # The feature key "data" is the same name as the model's first Dense layer.
        dataset = tf.data.Dataset.from_tensor_slices(({"data": mnist.train.images}, mnist.train.labels))
        # Shuffle within a 1000-example buffer, batch, and repeat indefinitely.
        dataset = dataset.shuffle(1000).batch(batch_size).repeat()
    else:
        # Evaluation: one ordered pass over the test set.
        dataset = tf.data.Dataset.from_tensor_slices(({"data": mnist.test.images}, mnist.test.labels))
        dataset = dataset.batch(batch_size)
    return dataset
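The training branch chains `shuffle(1000)`, `batch(batch_size)` and `repeat()`. The pure-Python sketch below imitates what that chain yields, under simplifying assumptions and without TensorFlow; the function names are illustrative only. A shuffle buffer only randomizes locally: it fills a buffer, then repeatedly emits a random element and refills that slot from the input. Because `batch` comes before `repeat`, each epoch ends with a possibly smaller final batch and no batch mixes two epochs.

```python
import itertools
import random

def buffered_shuffle(elements, buffer_size, rng):
    """Emulate a shuffle buffer: fill it, then emit a random element and refill its slot."""
    buffer = []
    for x in elements:
        if len(buffer) < buffer_size:
            buffer.append(x)
            continue
        i = rng.randrange(buffer_size)
        yield buffer[i]
        buffer[i] = x
    # Input exhausted: drain what is left in random order.
    rng.shuffle(buffer)
    yield from buffer

def batched(elements, batch_size):
    """Group consecutive elements; the final batch may be smaller."""
    batch = []
    for x in elements:
        batch.append(x)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:
        yield batch

def train_pipeline(examples, buffer_size, batch_size, seed=0):
    """shuffle -> batch -> repeat: endless batches, reshuffled every epoch.

    examples must be re-iterable (e.g. a list), since each epoch walks it again.
    """
    rng = random.Random(seed)
    while True:  # repeat() without a count never stops
        yield from batched(buffered_shuffle(examples, buffer_size, rng), batch_size)

# buffer_size=1 disables shuffling, which makes the batch/repeat structure easy to see.
first_six = list(itertools.islice(train_pipeline(list(range(10)), buffer_size=1, batch_size=4), 6))
print(first_six)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9], [0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# With a buffer of 3 the order changes, but no element moves more than 2 places earlier.
epoch = list(buffered_shuffle(range(10), 3, random.Random(1)))
print(sorted(epoch) == list(range(10)))  # True
```

`read_data_sets` keeps 55,000 images in `mnist.train` by default, so a 1000-element buffer draws each batch from a sliding window of the data rather than from the whole training set. A larger buffer shuffles more thoroughly at the cost of memory.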

#### Model Function

Here we define a Keras model function that builds the model layer by layer and compiles it with a loss function, an optimizer and the metrics to track. SBrain expects the model function to return a compiled Keras model. See the <a href="https://keras.io">Keras Model API</a> for more on building Keras models.

In [None]:
def keras_model_function(params):
    from keras.models import Sequential
    from keras.layers import Dense, Dropout
    from keras.optimizers import SGD

    # A small fully connected network: 784 pixel inputs -> 64 -> 64 -> 10 class probabilities.
    layers = []
    layers.append(Dense(64, activation='relu', input_shape=[784], name="data"))
    layers.append(Dropout(0.5))
    layers.append(Dense(64, activation='relu'))
    layers.append(Dropout(0.5))
    layers.append(Dense(10, activation='softmax'))

    model = Sequential(layers=layers)

    # decay and momentum are supplied through params (set in HParams further down).
    decay = params["decay"]
    momentum = params["momentum"]
    sgd = SGD(lr=0.01, decay=decay, momentum=momentum, nesterov=True)
    model.compile(loss='categorical_crossentropy',
                  optimizer=sgd,
                  metrics=['accuracy'])
    # SBrain expects a compiled Keras model.
    return model
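As a quick sanity check on the architecture, the number of trainable parameters can be worked out by hand; `model.summary()` should report the same total. The helper `dense_params` is illustrative and not part of Keras.

```python
def dense_params(n_in, n_out):
    """A Dense layer holds an n_in x n_out weight matrix plus one bias per output unit."""
    return n_in * n_out + n_out

# Layer sizes from keras_model_function: 784 -> 64 -> 64 -> 10.
# Dropout layers have no trainable parameters, so they are skipped.
layer_sizes = [784, 64, 64, 10]
per_layer = [dense_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
total = sum(per_layer)
print(per_layer)  # [50240, 4160, 650]
print(total)      # 55050
```

About 91% of the weights sit in the first layer, which maps the 784 input pixels to 64 hidden units.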

#### SBrain Keras Estimator

SBrain uses the same estimator API as in the other notebooks to wrap the Keras model function in an SBrain asset. The estimator is saved under a name of your choice and can be reused across many runs of the same model with different hyperparameters.

In [None]:
estimator = Estimator.NewClassificationEstimator(keras_model_fn=keras_model_function)
name = uniquify("MyMnistKerasEstimator")
estimator = Estimator.create(name, "KerasEstimator", estimator)

#### Hyperparameters and RunConfig

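Here we define the hyperparameters and the run configuration. The model function reads `decay` and `momentum` from `params`. The run configuration requests one parameter server and two workers, disables evaluation (`run_eval=False`) and runs on CPU (`use_gpu=False`).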

In [None]:
hyper_parameters = HParams(iterations=5000, batch_size=128, decay=1e-6, momentum=0.9)
rc = RunConfig(no_of_ps=1, no_of_workers=2, summary_save_frequency=30, run_eval=False, use_gpu=False)
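The `decay` and `momentum` hyperparameters end up in the `SGD` optimizer. The sketch below reproduces, as we understand the update rule of the standalone Keras 2.x `SGD` optimizer, how they enter each step: the learning rate shrinks as lr / (1 + decay * t), where t counts the updates applied so far, and Nesterov momentum adds a look-ahead on the velocity. It runs on a toy scalar loss f(w) = w² for illustration only; how SBrain counts steps across the two workers is not specified in this notebook.

```python
def decayed_lr(lr, decay, t):
    """Time-based decay as in Keras 2.x SGD: lr / (1 + decay * t)."""
    return lr * (1.0 / (1.0 + decay * t))

def sgd_nesterov_step(w, v, grad, lr, momentum):
    """One Nesterov-momentum update in the form used by Keras 2.x SGD."""
    v = momentum * v - lr * grad          # new velocity
    w = w + momentum * v - lr * grad      # look-ahead parameter update
    return w, v

lr, decay, momentum = 0.01, 1e-6, 0.9     # values from keras_model_function and HParams

lr_after_5000 = decayed_lr(lr, decay, 5000)
print(lr_after_5000)  # about 0.00995: decay=1e-6 barely changes the rate over 5000 steps

# Minimize the toy loss f(w) = w**2 (gradient 2*w), starting from w = 1.0.
w, v = 1.0, 0.0
for t in range(200):
    w, v = sgd_nesterov_step(w, v, 2 * w, decayed_lr(lr, decay, t), momentum)
print(w)  # close to 0
```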

Now create an experiment, as in the other notebooks, and run it. The cell also prints the job's TensorBoard URL and whether it has finished yet.

In [None]:
name = uniquify("MyMnistKerasExperiment")
experiment = Experiment.run(experiment_name=name,
                            description="Mnist Keras Experiment",
                            estimator=estimator,
                            hyper_parameters=hyper_parameters,
                            run_config=rc,
                            dataset_version_split=None,  # no SBrain dataset; data comes from input_function
                            input_function=input_function)

job = experiment.get_single_job()
print("TensorBoard URL:")
print(job.get_tensorboard_url())

# The job runs asynchronously, so this is most likely False right after submission.
print(job.has_finished())

Let's wait until the job is done.

In [None]:
job.wait_until_finish()
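`job.wait_until_finish()` blocks until the job ends. If you would rather give up after a fixed time, a small polling loop over `job.has_finished()` is one option. The helper `wait_with_timeout` and its default polling interval are hypothetical, not part of SBrain; it relies only on `has_finished()`, which the notebook already calls. The stand-in `_FakeJob` exists only to demonstrate the helper.

```python
import time

def wait_with_timeout(job, timeout_s, poll_interval_s=30.0):
    """Poll job.has_finished() until it returns True or timeout_s seconds elapse.

    Returns True if the job finished in time, False on timeout.
    """
    deadline = time.monotonic() + timeout_s
    while not job.has_finished():
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            return False
        time.sleep(min(poll_interval_s, remaining))
    return True

# Demonstration with a stand-in object that reports completion on its third poll.
class _FakeJob:
    def __init__(self, finish_after):
        self.polls = 0
        self.finish_after = finish_after

    def has_finished(self):
        self.polls += 1
        return self.polls > self.finish_after

finished = wait_with_timeout(_FakeJob(2), timeout_s=5.0, poll_interval_s=0.01)
timed_out = not wait_with_timeout(_FakeJob(10**9), timeout_s=0.05, poll_interval_s=0.01)
print(finished, timed_out)  # True True
```

In the notebook this would be called as `wait_with_timeout(job, timeout_s=2 * 60 * 60)` in place of `job.wait_until_finish()`.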

Finally, print whether the job succeeded, along with the model metrics.

In [None]:
print("Job succeeded: {}".format(job.is_success()))
print("Model metrics:")
print(job.get_model().model_metrics)