# This is a basic TensorFlow tutorial on image classification

In [None]:
import tensorflow as tf

## The MNIST dataset

The MNIST dataset contains images of handwritten digits (0, 1, 2, ..., 9).

Download the dataset using TF's built-in method.

In [None]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

Every MNIST sample has two parts:
an image of a handwritten digit (flattened row by row into a 784-element vector) and a corresponding label.

In [None]:
import matplotlib.pyplot as plt

def show_sample(index):
    """Display training image `index` and print its one-hot label."""
    image = mnist.train.images[index].reshape(28, 28)  # 784 -> 28x28
    label = mnist.train.labels[index]

    plt.imshow(image, cmap='Greys')
    plt.show()
    plt.close()  # release the figure once it has been shown
    print('label[%d]: %s' % (index, str(label)))

show_sample(10)
show_sample(24)
show_sample(12)
show_sample(11)
show_sample(18)
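
Because the dataset was loaded with `one_hot=True`, each label printed above is a 10-element vector with a single 1 at the position of the digit, rather than a plain integer. The following is a minimal sketch of that encoding in plain Python; the helper name `one_hot` is illustrative and not part of TensorFlow.

```python
def one_hot(digit, num_classes=10):
    """Return a list with 1.0 at position `digit` and 0.0 elsewhere."""
    vector = [0.0] * num_classes
    vector[digit] = 1.0
    return vector

example_label = one_hot(3)
print(example_label)             # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
print(example_label.index(1.0))  # 3 (position of the 1 recovers the digit)
```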

## Our classification model

We're going to train a model to look at images and predict what digits they are.

A function $M: \mathbb{R}^{28\times 28}\rightarrow \mathbb{R}^{10}$ maps each input image to a vector of classification scores, one per digit class.
In other words, $M(\text{image})=\text{a vector of per-class scores}$.
A higher score for class $c$ should translate to higher confidence that $c$ is the correct class.

For example, if $M$ outputs
$$
    (0.05, 0.03, 0.82, 0.02, 0.01, 0.02, 0.01, 0.02, 0.01, 0.1)
$$
for an input image, it classifies that image as a $2$.
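
Picking the predicted digit amounts to finding the position of the largest score. A small sketch in plain Python, using the score vector from the example above (the variable names are illustrative):

```python
scores = [0.05, 0.03, 0.82, 0.02, 0.01, 0.02, 0.01, 0.02, 0.01, 0.1]

# The predicted class is the index of the highest score (an "argmax").
predicted_class = max(range(len(scores)), key=lambda c: scores[c])
print(predicted_class)  # 2
```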

Let us choose a very simple classification model first:
$$
    M(\mathbf{x})=
    \mathbf{x}\cdot\mathbf{W} + \mathbf{b}
    ,
$$
where $\mathbf{x}\in\mathbb{R}^{784}$ is a vectorized input image, and $\mathbf{W}\in\mathbb{R}^{784\times 10}$ and $\mathbf{b}\in\mathbb{R}^{10}$ are the model parameters. The elements of $M(\mathbf{x})$ are sometimes called **logits**.
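
To make the formula concrete, here is a sketch that evaluates $\mathbf{x}\cdot\mathbf{W}+\mathbf{b}$ by hand in plain Python. It assumes toy sizes (3 inputs, 2 classes) instead of 784 inputs and 10 classes, and the names `linear_model`, `x_toy`, `W_toy` and `b_toy` are made up for illustration.

```python
def linear_model(x, W, b):
    """Compute x . W + b for a single input vector x.

    x has length n, W has n rows and m columns, b has length m.
    """
    return [
        sum(x[i] * W[i][c] for i in range(len(x))) + b[c]
        for c in range(len(b))
    ]

# Toy values chosen only to keep the arithmetic easy to follow.
x_toy = [1.0, 2.0, 3.0]
W_toy = [[0.5, -1.0],
         [0.0, 1.0],
         [1.0, 0.0]]
b_toy = [0.5, -0.5]

toy_logits = linear_model(x_toy, W_toy, b_toy)
print(toy_logits)  # [4.0, 0.5]
```

Each output entry is the dot product of the input with one column of the weight matrix, plus the bias for that class.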

In [None]:
# `x` is a batch of input images (each reshaped into a vector)
x = tf.placeholder(tf.float32, [None, 784])
# define the model
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
logits = tf.matmul(x, W) + b

## Learning the model parameters from data

Initially, $\mathbf{W}$ and $\mathbf{b}$ are all zeros (see the code above), so the model assigns the same score to every class and cannot produce correct classification results.

We have to tune these tensors by minimizing an appropriate loss function that will "measure" the quality of classification.

We will use the **cross entropy criterion**:
$$
    L(\mathbf{x}, c)=
    -\log p_c(\mathbf{x})
    ,
$$
where $p_c(\mathbf{x})$ is the **probability** assigned by the model that $\mathbf{x}$ belongs to class $c$,
$$
    p_c=
    \frac{e^{l_c}}{\sum_{j=0}^{9} e^{l_j}}
    ,
$$
and $(l_0, l_1, \ldots, l_9)=M(\mathbf{x})$ are the logits output by the model.
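
The two formulas can be checked by hand with a short sketch in plain Python. The function names `softmax` and `cross_entropy` are illustrative; in the notebook itself TensorFlow performs this computation.

```python
import math

def softmax(logits):
    """Turn logits into probabilities that sum to 1."""
    shift = max(logits)  # subtracting the max avoids overflow in exp
    exps = [math.exp(l - shift) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, c):
    """Loss -log p_c for the correct class index c."""
    return -math.log(softmax(logits)[c])

# With all-zero parameters every logit is 0, so each class gets probability 1/10.
zero_logits = [0.0] * 10
print(softmax(zero_logits)[2])        # 0.1
print(cross_entropy(zero_logits, 2))  # about 2.3026, i.e. log(10)
```

A loss near $\log 10$ therefore corresponds to a model that cannot tell the classes apart; the loss approaches $0$ as the probability of the correct class approaches $1$.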

The derivatives can now be computed by TensorFlow, and the model can be tuned with **stochastic gradient descent** with learning rate $\eta$ ($k=0, 1, 2, \ldots$):
$$
    \mathbf{W}_{k+1}=
    \mathbf{W}_k - \eta\frac{\partial L}{\partial\mathbf{W}_k}
$$
$$
    \mathbf{b}_{k+1}=
    \mathbf{b}_k - \eta\frac{\partial L}{\partial\mathbf{b}_k}
$$
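
As an illustration of a single update, the sketch below performs one gradient descent step by hand in plain Python on a toy problem (2 inputs, 3 classes). It relies on the gradient of the cross entropy with respect to the logits, $\partial L/\partial l_j = p_j - [j = c]$, which for this linear model gives $\partial L/\partial W_{ij} = x_i\,(p_j - [j = c])$ and $\partial L/\partial b_j = p_j - [j = c]$. TensorFlow derives these automatically; the function names, toy values and learning rate here are assumptions made for the example.

```python
import math

def softmax(logits):
    shift = max(logits)  # for numerical stability
    exps = [math.exp(l - shift) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def forward(x, W, b):
    """Logits x . W + b for one input vector."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) + b[j]
            for j in range(len(b))]

def sample_loss(x, c, W, b):
    """Cross entropy -log p_c for one sample with correct class c."""
    return -math.log(softmax(forward(x, W, b))[c])

def sgd_step(x, c, W, b, eta):
    """Return updated copies: W - eta * dL/dW and b - eta * dL/db."""
    p = softmax(forward(x, W, b))
    # dL/dl_j = p_j - 1 for the correct class, p_j otherwise
    grad_logits = [p[j] - (1.0 if j == c else 0.0) for j in range(len(b))]
    new_W = [[W[i][j] - eta * x[i] * grad_logits[j] for j in range(len(b))]
             for i in range(len(x))]
    new_b = [b[j] - eta * grad_logits[j] for j in range(len(b))]
    return new_W, new_b

# Toy problem with zero-initialized parameters.
x_toy, c_toy = [1.0, -2.0], 1
W_toy = [[0.0] * 3 for _ in range(2)]
b_toy = [0.0] * 3

loss_before = sample_loss(x_toy, c_toy, W_toy, b_toy)
W_toy, b_toy = sgd_step(x_toy, c_toy, W_toy, b_toy, eta=0.1)
loss_after = sample_loss(x_toy, c_toy, W_toy, b_toy)
print(loss_before > loss_after)  # True
```

The step moves the parameters so that the logit of the correct class grows relative to the others, which lowers the loss on that sample.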

In [None]:
# prepare the loss function (`labels` holds the one-hot encoding of the correct class `c` from the text above)
labels = tf.placeholder(tf.float32, [None, 10])
# `loss` holds one value per image in the batch; `minimize` differentiates their sum
loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
# we will use SGD with learning rate 1e-4 to learn the model
step = tf.train.GradientDescentOptimizer(1e-4).minimize(loss)

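In practice, the loss $L$ is evaluated on a small batch of images at each step rather than on the whole training set.
The code above handles this case as well, because the placeholders accept any number of rows.
We set the batch size to $100$ in our experiment.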

We measure the quality of the model on a separate testing dataset by computing the fraction of images that it classifies correctly:
$$
    \text{accuracy}=
    \frac{\text{number of correctly classified samples}}{\text{total number of samples}}
$$
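
As a quick sanity check of the formula, here is a sketch in plain Python with made-up predictions and labels (five samples, one of them misclassified); the function name `fraction_correct` is illustrative.

```python
def fraction_correct(predicted, actual):
    """Fraction of positions where the predicted class equals the true class."""
    correct = sum(1 for p, a in zip(predicted, actual) if p == a)
    return correct / len(actual)

# Made-up values for illustration only, not real model output.
predicted_digits = [7, 2, 1, 0, 4]
actual_digits = [7, 2, 1, 0, 9]
print(fraction_correct(predicted_digits, actual_digits))  # 0.8
```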

In [None]:
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

Before starting the learning process, we must create a session and initialize the variables.

In [None]:
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()

Start the learning process.

In [None]:
for k in range(5000):
    batch_xs, batch_labels = mnist.train.next_batch(100)
    sess.run(step, feed_dict={x: batch_xs, labels: batch_labels})
    if k % 200 == 0:
        acc = 100*sess.run(accuracy, feed_dict={x: mnist.test.images, labels: mnist.test.labels})
        print('* iter %d: test set accuracy=%.2f %%' % (k, acc))