In [0]:
# This tutorial shows you how to solve the Iris classification problem in TensorFlow using Estimators.
from __future__ import absolute_import, division, print_function, unicode_literals

!pip install -q tensorflow==2.0.0-beta0

import numpy as np
import pandas as pd
import tensorflow as tf

[K     |████████████████████████████████| 87.9MB 1.4MB/s 
[K     |████████████████████████████████| 501kB 39.5MB/s 
[K     |████████████████████████████████| 3.1MB 44.5MB/s 
[?25h

In [0]:
# Column order of the Iris CSV files; the last column holds the integer species label.
CSV_COLUMN_NAMES = [
    'SepalLength',
    'SepalWidth',
    'PetalLength',
    'PetalWidth',
    'Species',
]
# Human-readable species names, indexed by label id.
SPECIES = ['Setosa', 'Versicolor', 'Virginica']

In [0]:
# Fetch the Iris training and test CSVs via tf.keras.utils.get_file.
train_path = tf.keras.utils.get_file(
    fname="iris_training.csv",
    origin="https://storage.googleapis.com/download.tensorflow.org/data/iris_training.csv",
)
test_path = tf.keras.utils.get_file(
    fname="iris_test.csv",
    origin="https://storage.googleapis.com/download.tensorflow.org/data/iris_test.csv",
)

# header=0 together with `names` replaces each file's own first row with CSV_COLUMN_NAMES.
train, test = (
    pd.read_csv(csv_path, names=CSV_COLUMN_NAMES, header=0)
    for csv_path in (train_path, test_path)
)

In [0]:
# Peek at the first rows of the raw training frame (features plus label).
train.head(n=5)

Unnamed: 0,SepalLength,SepalWidth,PetalLength,PetalWidth,Species
0,6.4,2.8,5.6,2.2,2
1,5.0,2.3,3.3,1.0,1
2,4.9,2.5,4.5,1.7,2
3,4.9,3.1,1.5,0.1,0
4,5.7,3.8,1.7,0.3,0


In [0]:
# Split the label column off each frame. `pop` removes it in place, so this cell is
# not idempotent: re-running it fails because 'Species' is already gone.
train_y, test_y = (frame.pop('Species') for frame in (train, test))

train.head(5)

Unnamed: 0,SepalLength,SepalWidth,PetalLength,PetalWidth
0,6.4,2.8,5.6,2.2
1,5.0,2.3,3.3,1.0
2,4.9,2.5,4.5,1.7
3,4.9,3.1,1.5,0.1
4,5.7,3.8,1.7,0.3


In [0]:
# To use Estimators:
# 1. Create one or more input functions.
# 2. Define the model's feature columns.
# 3. Instantiate an Estimator, specifying the feature columns and various hyperparameters.
# 4. Call one or more methods on the Estimator object, passing the appropriate input function as the source of data.

In [0]:
# An input function returns a tf.data.Dataset object that outputs a two-element tuple:
# 1. Features - A Python dictionary in which:
#    - Each key is the name of a feature.
#    - Each value is an array containing all of that feature's values.

# 2. Label - An array containing the value of the label for every example.

def input_evaluation_set():
    """Return a tiny hand-written ``(features, labels)`` pair in input-function format.

    ``features`` maps each feature name to a NumPy array with that feature's value
    for every example; ``labels`` is a NumPy array of the matching species ids.
    """
    raw_columns = {
        'SepalLength': [6.4, 5.5],
        'SepalWidth': [2.8, 2.3],
        'PetalLength': [5.6, 3.3],
        'PetalWidth': [2.2, 1.0],
    }
    features = {name: np.array(values) for name, values in raw_columns.items()}
    return features, np.array([2, 1])

In [0]:
def input_fn(features, labels, training=True, batch_size=256):
    """Build a batched ``tf.data.Dataset`` of ``(features_dict, label)`` pairs.

    Used for both training and evaluation.

    Args:
        features: Tabular inputs (e.g. a pandas DataFrame); ``dict(features)`` is
            used so each column name maps to that column's values.
        labels: Labels aligned with the rows of ``features``.
        training: If True, shuffle with a 1000-element buffer and repeat forever.
        batch_size: Number of examples per batch.

    Returns:
        The batched dataset.
    """
    examples = tf.data.Dataset.from_tensor_slices((dict(features), labels))
    if not training:
        # Evaluation: a single, in-order pass over the data.
        return examples.batch(batch_size)
    # Training: the stream never ends, so callers must bound training (e.g. with `steps`).
    return examples.shuffle(1000).repeat().batch(batch_size)

In [0]:
# Feature columns describe how to use the input: every remaining column of
# `train` (the label has already been popped off) becomes a numeric feature.
my_feature_columns = [
    tf.feature_column.numeric_column(key=column_name) for column_name in train.keys()
]

In [0]:
## Instantiate an estimator

# Build a DNN with 2 hidden layers with 30 and 10 hidden nodes each, with one
# output class per entry in SPECIES.
# A fixed tf_random_seed is passed through RunConfig so that reruns of the
# notebook are consistent (without it, each Restart & Run All trains differently).
classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[30, 10],
    n_classes=len(SPECIES),
    config=tf.estimator.RunConfig(tf_random_seed=42),
)

W0614 07:11:01.857672 139763058550656 estimator.py:1811] Using temporary folder as model directory: /tmp/tmpcc6ipvtq


In [0]:
## Train the model
# The training dataset repeats indefinitely, so `steps` is what ends training.
# The trailing `;` hides the returned estimator's repr, which is only noise here.
classifier.train(
    input_fn=lambda: input_fn(train, train_y, training=True),
    steps=5000,
);

W0614 07:14:10.179111 139763058550656 deprecation.py:506] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
W0614 07:14:11.263737 139763058550656 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow_estimator/python/estimator/canned/head.py:437: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `tf.cast` instead.
W0614 07:14:11.398242 139763058550656 deprecation.py:506] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/adagrad.py:76: calling Constant.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for 

<tensorflow_estimator.python.estimator.canned.dnn.DNNClassifier at 0x7f1cd98433c8>

In [0]:
# Evaluate on the held-out test set. With training=False the dataset is neither
# shuffled nor repeated, i.e. a single pass over the test examples.
eval_result = classifier.evaluate(input_fn=lambda: input_fn(test, test_y, training=False))

print('\nTest set accuracy: {accuracy:0.3f}\n'.format(accuracy=eval_result['accuracy']))

W0614 07:19:01.191463 139763058550656 deprecation.py:323] From /usr/local/lib/python3.6/dist-packages/tensorflow/python/training/saver.py:1276: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.



Test set accuracy: 0.967



In [0]:
# Display the full metrics dictionary returned by `classifier.evaluate`
# (as the cell's last expression, so Jupyter renders it via rich display).
eval_result

{'accuracy': 0.96666664, 'average_loss': 0.08420571, 'loss': 2.5261712, 'global_step': 5000}


In [0]:
# Generate predictions from the model.
# `expected[i]` is the species we expect for row i of `predict_x`.
expected = ['Setosa', 'Versicolor', 'Virginica']
predict_x = {
    'SepalLength': [5.1, 5.9, 6.9],
    'SepalWidth': [3.3, 3.0, 3.1],
    'PetalLength': [1.7, 4.2, 5.4],
    'PetalWidth': [0.5, 1.5, 2.1],
}


def predict_input_fn(features, batch_size=256):
    """An input function for prediction.

    Args:
        features: Mapping from feature name to a sequence of values, one per example.
        batch_size: Number of examples per batch.

    Returns:
        A batched ``tf.data.Dataset`` of feature dictionaries, without labels.
    """
    # Deliberately not named `input_fn`: redefining that name would shadow the
    # training/evaluation input function, and re-running the train or evaluate
    # cells afterwards would then fail on the `labels`/`training` arguments.
    return tf.data.Dataset.from_tensor_slices(dict(features)).batch(batch_size)


# `classifier.predict` yields results lazily from a one-shot generator; collect
# them into a list so the results can be iterated again when the next cell is re-run.
predictions = list(classifier.predict(
    input_fn=lambda: predict_input_fn(predict_x)))

In [0]:
# Compare each prediction with the species we expected for that example.
for prediction, expected_species in zip(predictions, expected):
    # First (presumably only) entry of `class_ids` is the predicted label id.
    predicted_id = prediction['class_ids'][0]
    confidence = prediction['probabilities'][predicted_id]

    print('Prediction is "{}" ({:.1f}%), expected "{}"'.format(
        SPECIES[predicted_id], 100 * confidence, expected_species))

Prediction is "Setosa" (100.0%), expected "Setosa"
Prediction is "Versicolor" (100.0%), expected "Versicolor"
Prediction is "Virginica" (100.0%), expected "Virginica"
