##### Copyright 2018 The TensorFlow Authors.

In [0]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [0]:
#@title MIT License
#
# Copyright (c) 2017 François Chollet
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

# Welcome to the `xla.compile()` tutorial

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/xla/jit#turning_on_jit_compilation"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.sandbox.google.com/github/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/compiler/xla/g3doc/tutorials/xla_compile.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

`xla.compile()` is a new experimental API that compiles part or all of a model with [XLA](https://www.tensorflow.org/extend/xla/).

Please run all code blocks in order.

In [0]:
import tensorflow as tf

Import the XLA library, which includes the experimental `xla.compile()` API.

In [0]:
from tensorflow.contrib.compiler import xla

Define the constants used below and prepare the MNIST dataset.

In [0]:
# Size of each input image, 28 x 28 pixels
IMAGE_SIZE = 28 * 28
# Number of distinct digit labels, [0..9]
NUM_CLASSES = 10
# Number of examples in each training batch (step)
TRAIN_BATCH_SIZE = 100
# Number of training steps to run
TRAIN_STEPS = 1000
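As a quick sanity check on these constants, the following sketch works out how much data the training loop consumes. It assumes the standard MNIST training split of 60,000 images, which is a property of the dataset rather than something defined in this notebook; `MNIST_TRAIN_EXAMPLES` is a name introduced here purely for illustration.

```python
# Constants repeated from the cell above so this sketch runs on its own.
IMAGE_SIZE = 28 * 28
TRAIN_BATCH_SIZE = 100
TRAIN_STEPS = 1000

# Assumption: the standard MNIST training split holds 60,000 images.
MNIST_TRAIN_EXAMPLES = 60000

# Total examples consumed by the training loop.
examples_seen = TRAIN_BATCH_SIZE * TRAIN_STEPS
# Number of batches needed to pass over the training split once.
steps_per_epoch = MNIST_TRAIN_EXAMPLES // TRAIN_BATCH_SIZE
# How many passes over the training split the loop makes.
epochs = examples_seen / MNIST_TRAIN_EXAMPLES

print("Flattened image length:", IMAGE_SIZE)
print("Examples processed:", examples_seen)
print("Steps per epoch:", steps_per_epoch)
print("Approximate epochs: %.2f" % epochs)
```

Under that assumption, training passes over the dataset more than once, which is why the training dataset in the next cell is built with `.repeat()`.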

In [0]:
# Loads the MNIST dataset.
train, test = tf.keras.datasets.mnist.load_data()
train_ds = tf.data.Dataset.from_tensor_slices(train).batch(TRAIN_BATCH_SIZE).repeat()
test_ds = tf.data.Dataset.from_tensor_slices(test).batch(TRAIN_BATCH_SIZE)

iterator = tf.data.Iterator.from_structure(train_ds.output_types, train_ds.output_shapes)
images, labels = iterator.get_next()
images = tf.reshape(images, [-1, IMAGE_SIZE])
images, labels = tf.cast(images, tf.float32), tf.cast(labels, tf.int64)

## Define the `build_mnist_model` function to construct the model

The following code block contains a function that constructs a simple model with one dense layer, including both forward and backward propagation.

When called, it returns two values. `y` is a `tf.Tensor` holding the logits (unnormalized scores) for each target class; the dense layer has no activation, so the softmax is applied inside the loss. `train_step` is a `tf.Operation` that applies one gradient-descent update to the model variables.

In [0]:
def build_mnist_model(x, y_):
  y = tf.keras.layers.Dense(NUM_CLASSES).apply(x)

  cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y)
  train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

  return y, train_step
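To make the loss concrete, here is a minimal pure-Python sketch of what a sparse softmax cross-entropy computes for a single example, assuming the model output is a vector of raw logits. The helpers `softmax` and `sparse_softmax_cross_entropy` below are illustrative stand-ins written for this sketch, not the TensorFlow implementations, and the logits are made-up values.

```python
import math

def softmax(logits):
    """Converts raw logits into probabilities that sum to 1."""
    # Subtract the largest logit first for numerical stability.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sparse_softmax_cross_entropy(label, logits):
    """Returns -log of the probability assigned to the integer class `label`."""
    return -math.log(softmax(logits)[label])

# Hypothetical logits for a 3-class problem; the true class is index 0.
logits = [2.0, 1.0, 0.1]
probs = softmax(logits)
loss = sparse_softmax_cross_entropy(0, logits)

# Gradient of the loss with respect to the logits: probs - one_hot(label).
grad = [p - (1.0 if i == 0 else 0.0) for i, p in enumerate(probs)]

print("probabilities:", [round(p, 3) for p in probs])
print("loss: %.3f" % loss)
print("gradient wrt logits:", [round(g, 3) for g in grad])
```

The label is an integer class index rather than a one-hot vector, which is what "sparse" refers to. The gradient is negative for the true class and positive for the others, so a gradient-descent step raises the true-class logit and lowers the rest.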

## Use `xla.compile()` with `build_mnist_model` to enable XLA

The following code block wraps the model with `xla.compile()`, which allows the target function, called with the provided inputs, to be executed by XLA.

In [0]:
[y] = xla.compile(build_mnist_model, inputs=[images, labels])

When compiling the graph, XLA replaces all the graph nodes constructed in the target function with a few XLA ops.

`xla.compile()` does not return any `tf.Operation` nodes that can be executed independently from the generated XLA ops. Instead, the `tf.Operation` nodes returned from the target function are added as control dependencies of all returned `tf.Tensor` values. This triggers execution of the `tf.Operation` nodes whenever the returned tensors are evaluated.

In pseudocode, the implementation of `xla.compile()` looks as follows:

---
```
# Ask TensorFlow to execute the code in an XLA-friendly manner

y, train_step = build_mnist_model(images, labels)
with tf.control_dependencies([train_step]):
  y = tf.identity(y)

# Ask TensorFlow to STOP executing the code in an XLA-friendly manner
```
---
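The control-dependency behavior described above can be mimicked in plain Python. The sketch below is only an analogy: `ToyTensor`, `toy_train_step`, and `toy_state` are hypothetical names invented for illustration and are not part of TensorFlow or XLA. It shows how evaluating a value can run an attached side-effecting operation first.

```python
class ToyTensor:
    """Hypothetical stand-in for a tensor with attached control dependencies."""

    def __init__(self, compute, control_deps=()):
        self._compute = compute
        self._control_deps = list(control_deps)

    def evaluate(self):
        # Run every control dependency first, then produce the value.
        for op in self._control_deps:
            op()
        return self._compute()

# Toy "model variables" that the operation below mutates.
toy_state = {"weight": 0.0, "updates": 0}

def toy_train_step():
    """Toy 'operation': it has a side effect and returns no value."""
    toy_state["weight"] += 0.5
    toy_state["updates"] += 1

# Analogous to wrapping `y` in tf.identity under tf.control_dependencies.
toy_y = ToyTensor(lambda: toy_state["weight"] * 2.0,
                  control_deps=[toy_train_step])

# Each evaluation of the value also applies one update.
for _ in range(3):
    toy_value = toy_y.evaluate()

print("updates applied:", toy_state["updates"])
print("last value:", toy_value)
```

Nothing in this toy calls `toy_train_step` directly; the updates happen only because the value that depends on it is evaluated, which is the same reason the training loop later in this tutorial only needs to run `y`.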

`xla.compile()` always returns a list of `tf.Tensor` objects, even if the list has only one element.

If you printed the constructed graph now, you would see that it is not much different from a normal TensorFlow graph, and you would not find the XLA ops mentioned above. This is because the actual compilation happens later, when you execute the graph with `sess.run()`. At that time, TensorFlow triggers a series of graph rewrite passes that generate the XLA ops, which compile and execute the computation once all inputs are ready.

## Train and test the model

In [0]:
# Creates a session and initializes all variables.
# xla.compile() does not yet work with the Keras model.fit() API or TF eager mode.
sess = tf.Session()
sess.run(tf.global_variables_initializer())

The following code block trains the model.

Note that evaluating `y` also triggers its control dependency node `train_step`, which updates model variables.

In [0]:
# Feeds training dataset
sess.run(iterator.make_initializer(train_ds))

# Runs TRAIN_STEPS steps
for i in range(TRAIN_STEPS):
  sess.run(y)
print("Model trained for %s steps." % TRAIN_STEPS)

In [0]:
# Tests trained model

# Feeds testing dataset
sess.run(iterator.make_initializer(test_ds))

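# Calculates accuracy on one batch of the test set.
# Note: evaluating `accuracy` evaluates `y`, which also runs `train_step`
# once on that batch through its control dependency.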
correct_prediction = tf.equal(tf.argmax(y, 1), labels)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print("Prediction accuracy after training: %s" % sess.run(accuracy))
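The accuracy computation above takes the highest-scoring class in each row, compares it with the label, and averages the matches. A minimal pure-Python sketch of the same arithmetic follows; the `toy_logits` and `toy_labels` values are made up for illustration, and `argmax` is a small helper defined here rather than a TensorFlow function.

```python
def argmax(row):
    """Returns the index of the largest value in `row`."""
    return max(range(len(row)), key=lambda i: row[i])

# Hypothetical logits for four examples over three classes.
toy_logits = [
    [0.1, 2.0, 0.3],
    [1.5, 0.2, 0.1],
    [0.0, 0.4, 0.9],
    [0.7, 0.6, 0.2],
]
# Hypothetical integer class labels for the same four examples.
toy_labels = [1, 0, 2, 1]

# Mirrors tf.equal(tf.argmax(y, 1), labels): one boolean per example.
toy_correct = [argmax(row) == label
               for row, label in zip(toy_logits, toy_labels)]

# Mirrors tf.reduce_mean(tf.cast(correct_prediction, tf.float32)).
toy_accuracy = sum(1.0 if c else 0.0 for c in toy_correct) / len(toy_correct)

print("Per-example correctness:", toy_correct)
print("Toy accuracy:", toy_accuracy)
```

Because the argmax is taken over raw logits, no softmax is needed here: applying softmax would not change which class scores highest.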

In [0]:
# Cleans up session
sess.close()