# TensorFlow Estimators - a very quick guide

For simplicity, we will not implement a model or an input pipeline; we assume that functions taking care of these already exist. First, we need a model function that builds the model graph and returns different outputs depending on whether we are training, validating or testing the model. It should have the following signature.

In [None]:
import tensorflow as tf  # this guide uses the TensorFlow 1.x API

def model_fn(features, labels, mode, params):

The arguments `features` and `labels` contain the inputs and the ground truth. You can pass in multiple inputs or labels by packing them into a container such as a `dict` of tensors. The argument `mode`, discussed below, indicates whether we are training, validating or testing the model. You can use the optional argument `params` to configure your model.

Inside `model_fn` we need to build the model graph and compute the loss before returning different outputs based on `mode`. Let us say another function, `get_model`, builds the model graph. Then the first step is to call this function within `model_fn`. Notice how we use `params` to pass additional arguments into `model_fn`.

In [None]:
    logits = get_model(features, **params['model_kwargs'])
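
The `**params['model_kwargs']` idiom simply unpacks a `dict` into keyword arguments. A minimal sketch in plain Python, where this `get_model` and its arguments `hidden_units` and `n_classes` are hypothetical stand-ins and not the real model builder:

```python
# Hypothetical stand-in for the `get_model` function assumed in the text;
# it only records the keyword arguments it receives.
def get_model(features, hidden_units=64, n_classes=10):
    return {"n_inputs": len(features),
            "hidden_units": hidden_units,
            "n_classes": n_classes}

# `params` as it would be handed to `model_fn`.
params = {"model_kwargs": {"hidden_units": 128, "n_classes": 3}}

# `**` unpacks the dict, so this call is equivalent to
# get_model(features, hidden_units=128, n_classes=3).
model = get_model([0.1, 0.2, 0.3, 0.4], **params["model_kwargs"])
print(model)
```

Keeping all model options under one key means `model_fn` does not have to change when the model gains a new option.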

Assuming that we have a softmax classification output, let us also get the predicted probabilities and labels.

In [None]:
    probs_pred = tf.nn.softmax(logits)
    labels_pred = tf.argmax(logits, axis=-1)
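
To make these two ops concrete, here is a small NumPy sketch of what they compute. The `softmax` helper and the example logits are illustrative assumptions, not TensorFlow code:

```python
import numpy as np

def softmax(logits):
    # Subtract the row-wise maximum for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

# Two examples, three classes (made-up values).
logits = np.array([[2.0, 1.0, 0.1],
                   [0.5, 2.5, 0.0]])

probs_pred = softmax(logits)              # each row sums to 1
labels_pred = np.argmax(logits, axis=-1)  # index of the largest logit per row

print(probs_pred)
print(labels_pred)
```

Because softmax preserves the ordering of the logits, taking the `argmax` of the logits gives the same labels as taking it over the probabilities.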

Now we will add a cross entropy loss op

In [None]:
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
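
As a rough illustration of the quantity this op produces, the following NumPy sketch computes the mean negative log-probability of the true class. It assumes unit weights and a batch-mean reduction; `sparse_xent` is a hypothetical helper, not the TensorFlow implementation:

```python
import numpy as np

def sparse_xent(labels, logits):
    # Numerically stable log-softmax over the class axis.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # "Sparse" means labels are integer class indices, not one-hot vectors,
    # so we pick the log-probability of the true class for each example.
    per_example = -log_probs[np.arange(len(labels)), labels]
    return per_example.mean()

# Made-up batch: two examples, three classes.
logits = np.array([[2.0, 1.0, 0.1],
                   [0.5, 2.5, 0.0]])
labels = np.array([0, 1])

loss = sparse_xent(labels, logits)
print(loss)
```

With uniform logits over `k` classes the loss equals `log(k)`, which is a handy sanity check for an untrained classifier.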

Finally (assuming our classes are balanced) let us add an op to find the accuracy of our predictions

In [None]:
    accuracy = tf.metrics.accuracy(labels=labels, predictions=labels_pred)
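
Note that `tf.metrics.accuracy` is a streaming metric: it returns a `(value, update_op)` pair and accumulates counts over batches. The idea can be sketched in plain Python; the `StreamingAccuracy` class below is a hypothetical illustration, not TensorFlow code:

```python
class StreamingAccuracy:
    """Accumulates correct/total counts across mini-batches."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, labels, predictions):
        # Plays the role of the update op: fold one batch into the counts.
        self.correct += sum(int(l == p) for l, p in zip(labels, predictions))
        self.total += len(labels)

    def value(self):
        # Plays the role of the value tensor: accuracy over all batches so far.
        return self.correct / self.total if self.total else 0.0

metric = StreamingAccuracy()
metric.update(labels=[0, 1, 2, 1], predictions=[0, 1, 1, 1])  # 3 of 4 correct
metric.update(labels=[2, 0], predictions=[2, 1])              # 1 of 2 correct
print(metric.value())  # 4 correct out of 6 examples
```

Accumulating counts rather than averaging per-batch accuracies gives the correct result even when batches differ in size.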

The parameter `mode` can be `TRAIN`, `EVAL` or `PREDICT`, and `model_fn` needs to handle each of these. They correspond to training, validation and testing respectively. For each mode you return an instance of `tf.estimator.EstimatorSpec`, whose required arguments depend on the mode:

- For `TRAIN` you need to pass in a `train_op` and a `loss`

In [None]:
    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001)
        train_op = optimizer.minimize(
            loss=loss,
            global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

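- For `EVAL` you need to pass in a `loss`. In the example below we also pass in `eval_metric_ops`, which should be a `dict` mapping metric names to the `(value, update_op)` pairs returned by functions such as `tf.metrics.accuracy`. Any metrics you pass in this manner will be displayed on TensorBoard along with the `loss`.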

In [None]:
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode=mode,
                loss=loss, eval_metric_ops={'accuracy': accuracy})

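- For `PREDICT` you need to pass in `predictions`. Since `labels` is `None` in this mode, a complete `model_fn` has to return from this branch before the loss and accuracy ops are built.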

In [None]:
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, 
                                          predictions=labels_pred)
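
The order in which the modes are handled matters when the cells are assembled into one function. A minimal pure-Python sketch of the control flow, where `Spec`, the mode constants and the toy model are hypothetical stand-ins rather than the TensorFlow API:

```python
from collections import namedtuple

# Hypothetical stand-ins for tf.estimator.ModeKeys and EstimatorSpec.
TRAIN, EVAL, PREDICT = "train", "eval", "predict"
Spec = namedtuple("Spec", ["mode", "predictions", "loss", "train_op"],
                  defaults=(None, None, None))

def model_fn_sketch(features, labels, mode):
    # Toy "model": predict class 1 when the feature is positive.
    predictions = [int(x > 0) for x in features]

    if mode == PREDICT:
        # Return before any code touches `labels`, which is None here.
        return Spec(mode=mode, predictions=predictions)

    # From here on, labels are available. Toy loss: misclassification rate.
    loss = sum(int(p != y) for p, y in zip(predictions, labels)) / len(labels)

    if mode == EVAL:
        return Spec(mode=mode, loss=loss)
    if mode == TRAIN:
        # A string placeholder where the real code would return a train op.
        return Spec(mode=mode, loss=loss, train_op="train_op_placeholder")
    raise ValueError("Unknown mode: {}".format(mode))

print(model_fn_sketch([1.0, -2.0], None, PREDICT))
print(model_fn_sketch([1.0, -2.0], [1, 1], EVAL))
```

Raising on an unknown mode, rather than silently returning `None`, makes a mistyped mode key fail early.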

Now we can build a classifier using `tf.estimator.Estimator`. Assume that `model_kwargs` is a `dict` that we have defined elsewhere, whilst `model_dir` is the location at which checkpoints and TensorBoard summaries are saved.

In [None]:
classifier = tf.estimator.Estimator(
    model_fn=model_fn,
    model_dir='./',
    params={
        "model_kwargs": model_kwargs
    })

In order to run the model, we also need to pass in an `input_fn` for each mode. This function should return a (`features`, `labels`) tuple containing a mini-batch of inputs and labels, as required by `model_fn`. The example below again assumes that `train_input_fn`, `valid_input_fn` and `test_input_fn` already exist. Note that we can specify the number of `steps` for which to train, but this is not necessary if the code within `train_input_fn` raises a `tf.errors.OutOfRangeError` (as might be the case when using `tf.data.Dataset`) or a `StopIteration` exception, as this is the signal to stop.

In [None]:
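# Train until the input pipeline is exhausted...
classifier.train(input_fn=train_input_fn)
# ...or for a fixed number of steps:
# classifier.train(input_fn=train_input_fn, steps=10000)

# Returns a dict containing the loss and the evaluation metrics.
metrics = classifier.evaluate(input_fn=valid_input_fn)

# `predict` returns a generator, so iterate over it to obtain the results.
predictions = list(classifier.predict(input_fn=test_input_fn))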
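
The interplay between `steps` and an exhausted input can be pictured with a plain-Python loop. This is only an analogy under simplifying assumptions, not how the Estimator is implemented; `run_training` and `toy_input` are hypothetical names:

```python
def run_training(batches, steps=None):
    """Consume batches until `steps` is reached or the input runs out."""
    completed = 0
    iterator = iter(batches)
    while steps is None or completed < steps:
        try:
            features, labels = next(iterator)
        except StopIteration:
            # End of input: the signal to stop, whatever `steps` says.
            break
        completed += 1  # a real loop would run one optimisation step here
    return completed

def toy_input(n_batches):
    # Yields (features, labels) mini-batches, then stops.
    for i in range(n_batches):
        yield [float(i)], [i % 2]

print(run_training(toy_input(5)))            # runs until the input ends
print(run_training(toy_input(5), steps=3))   # stops at the step limit
print(run_training(toy_input(5), steps=10))  # input ends before the limit
```

Whichever limit is reached first ends the loop, so an input pipeline that repeats indefinitely needs an explicit `steps` value to terminate.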