# <div class="alert alert-block alert-info" style="border-width:4px">Keras Early Stopping </div>

# Introduction

This notebook walks through how to use Keras early stopping in SBrain.
 

In [None]:
from sbrain.learning.experiment import *
from sbrain.dataset.dataset import *

In [None]:
import time

user_name = "albin"

def uniquify(name, should_uniquify=True):
    """Return ``name`` with the user name and a timestamp appended.

    Used to give each estimator/experiment created by this notebook its own
    name (presumably so re-running the cells does not clash with earlier runs).

    Args:
        name: Base name, e.g. "MyMnistKerasEstimator".
        should_uniquify: When False, ``name`` is returned unchanged. Defaults
            to True, which matches the previously hard-coded behaviour.

    Returns:
        ``name + user_name + <time.time() with the decimal point removed>``
        when ``should_uniquify`` is true, otherwise ``name`` itself.
    """
    if not should_uniquify:
        return name
    # Strip the decimal point so the timestamp suffix is a plain run of digits.
    return name + user_name + str(time.time()).replace(".", "")

#### Input Function

For this example, we use an input function that downloads the MNIST data from the internet into a local directory.

In [None]:
def input_function(mode, batch_size, params):
    """Build the MNIST ``tf.data.Dataset`` for training or evaluation.

    Args:
        mode: ``"train"`` selects the training split; any other value selects
            the test split.
        batch_size: Number of examples per batch.
        params: Hyper-parameters (not used by this input function).

    Returns:
        A dataset of ``({"data": images}, one_hot_labels)`` pairs. The training
        split is shuffled (buffer 1000), batched and repeated indefinitely.
        The test split is only batched.
    """
    from tensorflow.examples.tutorials.mnist import input_data
    import tensorflow as tf

    data_dir = "/workspace/shared-dir/sample-notebooks/demo-data/learning/mnist/"
    training = mode == "train"

    mnist = input_data.read_data_sets(data_dir, one_hot=True)
    split = mnist.train if training else mnist.test

    ds = tf.data.Dataset.from_tensor_slices(({"data" : split.images}, split.labels))
    if not training:
        return ds.batch(batch_size)
    # Training data is shuffled and repeated so the job can run for any number of steps.
    return ds.shuffle(1000).batch(batch_size).repeat()

### KerasExecutionHints in model_function

The main API difference is that SBrain exposes a new class, `KerasExecutionHints`, which can be used to pass early-stopping and best-model settings for Keras.

 ```python

early_stop_settings = sbrain.learning.KerasEarlyStopSettings(
    metric_name="accuracy", 
    threshold=0.8,
    higher_is_better=True,
    check_every_n_seconds=5
)
exec_hints = sbrain.learning.KerasExecutionHints(early_stop_settings=early_stop_settings)

```

- **metric_name** - Name of the metric to monitor for early stopping, e.g. `accuracy` or `loss`. The name should match a metric added to the graph (for a Keras model, one passed to `model.compile(metrics=...)`).
- **threshold** - Threshold value; training should stop once the metric crosses it.
- **higher_is_better** - Defines what "better" means: whether higher or lower values of the metric are better. Use True for accuracy and False for loss.
- **check_every_n_seconds** - How often, in seconds, the training code checks whether the threshold has been crossed.


A Keras model function can return this additional hints object if you need the early stopping feature. Model functions that return only the model continue to work. SBrain checks whether the return value is a tuple and, if so, interprets the first element as the Keras model and the second as the execution hints.

In [None]:
def keras_model_function(params):
    """Build and compile the MNIST Keras classifier together with early-stop hints.

    Args:
        params: Hyper-parameters. They must provide ``"decay"`` and
            ``"momentum"`` for the SGD optimizer (presumably these are the
            ``HParams`` given to the experiment; confirm).

    Returns:
        A ``(model, keras_exec_hints)`` tuple. SBrain treats the first element
        as the compiled Keras model and the second as execution hints. Here
        the hints ask training to stop once ``accuracy`` crosses 0.5, checked
        every 4 seconds.
    """
    from keras.models import Sequential
    from keras.layers import Dense, Dropout
    from keras.optimizers import SGD

    model = Sequential(layers=[
        # The first layer is named "data", the same key input_function uses for
        # its features (assumption: SBrain relies on this match; confirm).
        Dense(64, activation='relu', input_shape=[784], name="data"),
        Dropout(0.5),
        Dense(64, activation='relu'),
        Dropout(0.5),
        Dense(10, activation='softmax'),
    ])

    sgd = SGD(lr=0.01, decay=params["decay"], momentum=params["momentum"],
              nesterov=True)
    # The "accuracy" metric registered here is the one the early-stop
    # settings below monitor; the two names must match.
    model.compile(loss='categorical_crossentropy',
                  optimizer=sgd,
                  metrics=['accuracy'])

    keras_early_stop_settings = KerasEarlyStopSettings(metric_name="accuracy", threshold=0.5, higher_is_better=True, check_every_n_seconds=4)
    keras_exec_hints = KerasExecutionHints(early_stop_settings=keras_early_stop_settings)
    return model, keras_exec_hints

#### Rest of it

The code below is unchanged.

In [None]:
# Register the Keras model function as a classification estimator under a
# unique name.
name = uniquify("MyMnistKerasEstimator")
estimator = Estimator.create(
    name,
    "KerasEstimator",
    Estimator.NewClassificationEstimator(keras_model_fn=keras_model_function),
)

# Training hyper-parameters. keras_model_function reads "decay" and "momentum"
# (presumably these HParams are what it receives as `params`; confirm).
hyper_parameters = HParams(
    iterations=20000,
    batch_size=128,
    decay=1e-6,
    momentum=0.9,
)

# Distributed run configuration: 1 parameter server, 2 workers, CPU only,
# evaluation enabled.
rc = RunConfig(
    no_of_ps=1,
    no_of_workers=2,
    summary_save_frequency=30,
    run_eval=True,
    use_gpu=False,
    checkpoint_frequency_in_steps=500,
)

In [None]:
# Launch the experiment using the estimator, hyper-parameters and run config
# defined above.
name = uniquify("MyMnistKerasExperiment")
experiment = Experiment.run(experiment_name=name,
                            description="Mnist Keras Experiment",
                            estimator=estimator,
                            hyper_parameters=hyper_parameters,
                            run_config=rc,
                            dataset_version_split=None,
                            input_function=input_function)

job = experiment.get_single_job()
print("tensorboard url")
print(job.get_tensorboard_url())

# This cell intentionally does not block on the job, so the TensorBoard URL
# above can be used to monitor training while it runs. Waiting for completion
# is done in a separate cell.
print(job.has_finished())

Let's wait until it is done.

In [None]:
# Wait for the training job launched above to finish before reading its results.
job.wait_until_finish()

Finally, print the metrics.

In [None]:
# Report the job's final status and the metrics recorded for the trained model.
succeeded = job.is_success()
print("Is the job success?? : {}".format(succeeded))
print("Model metrics..")
trained_model = job.get_model()
print(trained_model.model_metrics)