# Estimators

Estimators in TensorFlow provide a safe, distribution-ready training loop that handles graph building, variable initialization, data loading, exception handling, checkpointing, and writing summaries for TensorBoard.

#### Steps in developing an Estimator model

**1.** Acquiring the data and creating the input functions

**2.** Creating feature columns

**3.** Instantiating the Estimator

**4.** Training and evaluating the model

```python
import tensorflow as tf
import numpy as np
```

#### 1. Acquire & preprocess the data

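```python
fashion_mnist = tf.keras.datasets.fashion_mnist

(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()

# Scale pixel values from [0, 255] to [0, 1] and store them as float32
x_train = (x_train / 255.).astype('float32')
x_test = (x_test / 255.).astype('float32')

# Integer class labels 0-9
y_train = y_train.astype('int32')
y_test = y_test.astype('int32')
```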

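**Input functions**

`tf.compat.v1.estimator.inputs.numpy_input_fn` returns an input function that feeds a dict of NumPy feature arrays, plus labels, to the Estimator. It handles batching (`batch_size`), repetition (`num_epochs`, where `None` repeats indefinitely) and shuffling (`shuffle`). The training function below shuffles and repeats the data; the test function makes a single ordered pass over the test set.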

```python
# Training input: shuffled batches of 64, repeated indefinitely (num_epochs=None)
train_input_func = tf.compat.v1.estimator.inputs.numpy_input_fn(
    x={'x': x_train},
    y=y_train,
    num_epochs=None,
    batch_size=64,
    shuffle=True
)
```

```python
# Evaluation input: one ordered pass over the test set in batches of 128
test_input_func = tf.compat.v1.estimator.inputs.numpy_input_fn(
    x={'x': x_test},
    y=y_test,
    num_epochs=1,
    batch_size=128,
    shuffle=False
)
```
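To make the `num_epochs`, `batch_size` and `shuffle` options concrete, here is a plain-NumPy sketch of the iteration they configure. `iterate_batches` is a hypothetical helper, not part of TensorFlow, and reshuffling with a full permutation each epoch is a simplification of how the TensorFlow input pipeline actually shuffles. The sketch keeps a smaller final batch when the number of examples is not a multiple of `batch_size`.

```python
import numpy as np


def iterate_batches(features, labels, batch_size, num_epochs=None, shuffle=True, seed=0):
    """Hypothetical stand-in mimicking numpy_input_fn's batching options.

    Yields (features_batch, labels_batch) tuples. num_epochs=None repeats
    forever; num_epochs=1 makes exactly one pass over the data.
    """
    rng = np.random.default_rng(seed)
    n = len(labels)
    epoch = 0
    while num_epochs is None or epoch < num_epochs:
        # Reshuffle once per epoch when requested; otherwise keep the original order
        order = rng.permutation(n) if shuffle else np.arange(n)
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            yield {key: value[idx] for key, value in features.items()}, labels[idx]
        epoch += 1


# Ten toy examples, one ordered pass in batches of 4
x = np.arange(10, dtype=np.float32).reshape(10, 1)
y = np.arange(10, dtype=np.int32)
sizes = [len(batch_labels) for _, batch_labels in
         iterate_batches({'x': x}, y, batch_size=4, num_epochs=1, shuffle=False)]
print(sizes)  # [4, 4, 2]
```

With `num_epochs=None` the generator never ends, which is why the training call must bound the run with `steps`.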

#### 2. Create feature columns

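- Feature columns tell the Estimator how to interpret each raw input. The key `'x'` must match the key of the feature dict passed to `numpy_input_fn`, and `shape=[28, 28]` matches the shape of one image.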

```python
feature_column = tf.feature_column.numeric_column(key='x', shape=[28, 28])  # Input image shape
```
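The canned DNN estimators combine their dense feature columns into one 2-D input tensor, so each `[28, 28]` image reaches the first hidden layer as a flat vector of 784 values. A NumPy sketch of that reshape, with random values standing in for real images:

```python
import numpy as np

# A batch of 64 images, matching batch_size=64 in train_input_func (random stand-in data)
batch = np.random.default_rng(0).random((64, 28, 28), dtype=np.float32)

# A numeric column with shape [28, 28] holds 28 * 28 = 784 values per example
num_elements = int(np.prod([28, 28]))

# Flatten each example into a single row, as the dense input layer does
dense_input = batch.reshape(batch.shape[0], num_elements)
print(dense_input.shape)  # (64, 784)
```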

#### 3. Initialize Estimator

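```python
# Build a DNN classifier with two hidden layers (256 and 32 units)
clf = tf.compat.v1.estimator.DNNClassifier(
    hidden_units=[256, 32],
    feature_columns=[feature_column],
    optimizer='Adam',
    n_classes=10,                # ten clothing categories
    dropout=.1,                  # probability of dropping a given coordinate during training
    model_dir='/tmp/mnist_est',  # checkpoints and TensorBoard summaries are written here
    loss_reduction=tf.compat.v1.losses.Reduction.SUM  # batch loss = sum of per-example losses
)
```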
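For a sense of the model's size: with a 784-value input, `hidden_units=[256, 32]` and `n_classes=10`, the network is a stack of fully connected layers 784 → 256 → 32 → 10. The NumPy sketch below builds randomly initialized weights for that stack, runs a forward pass with ReLU hidden activations (the `DNNClassifier` default), and counts weights and biases. The helpers `init_params` and `forward` are hypothetical, the initializer is arbitrary rather than TensorFlow's, and dropout is omitted because it only applies during training.

```python
import numpy as np


def init_params(input_dim, hidden_units, n_classes, seed=0):
    """Hypothetical helper: random (weights, bias) pairs for a fully connected stack."""
    rng = np.random.default_rng(seed)
    dims = [input_dim, *hidden_units, n_classes]
    return [(rng.normal(0.0, 0.05, (d_in, d_out)), np.zeros(d_out))
            for d_in, d_out in zip(dims[:-1], dims[1:])]


def forward(x, params):
    """ReLU hidden layers, then a linear layer producing one logit per class."""
    for i, (w, b) in enumerate(params):
        x = x @ w + b
        if i < len(params) - 1:
            x = np.maximum(x, 0.0)  # ReLU on hidden layers only
    return x


params = init_params(784, [256, 32], 10)
n_params = sum(w.size + b.size for w, b in params)
logits = forward(np.zeros((64, 784)), params)
print(n_params, logits.shape)  # 209514 (64, 10)
```

Most of the roughly 210k trainable parameters sit in the first 784 × 256 weight matrix.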

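#### 4. Train and evaluate the model

```python
# Train the model; steps counts additional steps on top of any global_step
# restored from a checkpoint in model_dir
clf.train(input_fn=train_input_func, steps=10000)
```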

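A quick check of what `steps=10000` means for this data. Fashion-MNIST has 60,000 training images, so 10,000 steps of 64 examples is a little over ten epochs. `Estimator.train(steps=...)` trains for that many steps beyond the `global_step` restored from `model_dir` (whereas `max_steps` caps the total), so the `global_step` of 14688 reported by `evaluate` below suggests this run started from a checkpoint at step 4688. That last figure is an inference from the reported output, assuming no other training ran between this cell and the evaluation.

```python
train_examples = 60_000  # size of the Fashion-MNIST training split
batch_size = 64          # batch_size passed to train_input_func
steps = 10_000           # steps passed to clf.train

examples_seen = steps * batch_size
epochs = examples_seen / train_examples
print(examples_seen, round(epochs, 2))  # 640000 10.67

# global_step reported by clf.evaluate(); steps were added on top of the restored value
reported_global_step = 14_688
print(reported_global_step - steps)  # 4688
```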

```python
clf.evaluate(input_fn=test_input_func)
```

```
{'accuracy': 0.3614,
 'average_loss': 1.5363648,
 'loss': 194.47656,
 'global_step': 14688}
```
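The two loss figures measure different things. `average_loss` is the mean per-example loss over the test set, while `loss` is the mean of the per-batch losses, and with `Reduction.SUM` each batch loss is the sum over that batch's examples. The 10,000 test images in batches of 128 make 79 batches (78 full ones plus a final batch of 16), so `loss` should equal `average_loss` × 10000 / 79. A small arithmetic check against the reported values:

```python
import math

test_examples = 10_000     # size of the Fashion-MNIST test split
batch_size = 128           # batch_size passed to test_input_func
average_loss = 1.5363648   # reported mean per-example loss

n_batches = math.ceil(test_examples / batch_size)  # 78 full batches + 1 batch of 16
total_loss = average_loss * test_examples          # sum of per-example losses
mean_batch_loss = total_loss / n_batches           # mean of SUM-reduced batch losses
print(n_batches, round(mean_batch_loss, 3))  # 79 194.477
```

The result matches the reported `loss` of 194.47656, so `average_loss` is the figure to compare across different batch sizes.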