# Tutorial Part 2: Learning MNIST Digit Classifiers

In the previous tutorial, we learned some basics of how to load data into DeepChem and how to use the basic DeepChem objects to load and manipulate this data. In this tutorial, you'll put the parts together and learn how to train a basic image classification model in DeepChem. You might ask, why are we bothering to learn this material in DeepChem? Part of the reason is that image processing is an increasingly important part of AI for the life sciences. So learning how to train image processing models will be very useful for using some of the more advanced DeepChem features.

The MNIST dataset contains images of handwritten digits along with their human-annotated labels. The learning challenge for this dataset is to train a model that maps each digit image to its true label. MNIST has been a standard machine learning benchmark for decades.

![MNIST](https://github.com/deepchem/deepchem/blob/master/examples/tutorials/mnist_examples.png?raw=1)

## Colab

This tutorial and the rest in this sequence are designed to be done in Google Colab. If you'd like to open this notebook in Colab, you can use the following link.

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deepchem/deepchem/blob/master/examples/tutorials/02_Learning_MNIST_Digit_Classifiers.ipynb)

## Setup

We recommend running this tutorial on Google Colab. You'll need to run the following cell of installation commands on Colab to get your environment set up. If you'd rather run the tutorial locally, make sure you don't run these commands, since they'll download and install a new Anaconda Python setup.

In [None]:
!curl -Lo conda_installer.py https://raw.githubusercontent.com/deepchem/deepchem/master/scripts/colab_install.py
import conda_installer
conda_installer.install()
!/root/miniconda/bin/conda info -e

In [None]:
!pip install --pre deepchem
import deepchem
deepchem.__version__

First, let's import the libraries we will be using and load the data (which comes bundled with TensorFlow).

In [1]:
import deepchem as dc
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Conv2D, Flatten, Dense

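# load_data() returns ((train_x, train_y), (test_x, test_y)).
mnist = tf.keras.datasets.mnist.load_data(path='mnist.npz')
# Add a trailing channel axis and scale pixel values from [0, 255] to [0, 1].
train_images = mnist[0][0].reshape((-1, 28, 28, 1))/255
valid_images = mnist[1][0].reshape((-1, 28, 28, 1))/255
# Wrap the arrays as DeepChem datasets; the MNIST test split serves as our validation set.
train = dc.data.NumpyDataset(train_images, mnist[0][1])
valid = dc.data.NumpyDataset(valid_images, mnist[1][1])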

Now create the model.  We use two convolutional layers followed by two dense layers.  The final layer outputs ten numbers for each sample.  These correspond to the ten possible digits.

How does the model know how to interpret the output?  That is determined by the loss function.  We specify `SparseSoftmaxCrossEntropy`.  This is a very convenient class that implements a common case:

1. Each label is an integer which is interpreted as a class index (i.e. which of the ten digits this sample is a drawing of).
2. The outputs are passed through a softmax function, and the result is interpreted as a probability distribution over those same classes.

The model learns to produce a large output for the correct class, and small outputs for all other classes.
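To make the two steps above concrete, here is a minimal pure-Python sketch of a softmax followed by a cross-entropy loss on an integer label. It is an illustration of the idea only, not DeepChem's implementation; the function names and the three-class example scores are made up for this sketch.

```python
import math

def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    # Subtracting the max improves numerical stability without changing the result.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sparse_softmax_cross_entropy(logits, label):
    """Negative log-probability that the softmax assigns to the integer class `label`."""
    return -math.log(softmax(logits)[label])

# Hypothetical outputs for a three-class problem (MNIST has ten classes).
confident = [5.0, 0.0, 0.0]  # large score on class 0
uniform = [0.0, 0.0, 0.0]    # no preference between classes

loss_right = sparse_softmax_cross_entropy(confident, label=0)
loss_uniform = sparse_softmax_cross_entropy(uniform, label=0)
loss_wrong = sparse_softmax_cross_entropy(confident, label=1)
print(loss_right, loss_uniform, loss_wrong)
```

The loss is smallest when the largest output sits on the correct class, which is why minimizing it pushes the model toward a large output for the correct digit and small outputs for the others.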

In [2]:
keras_model = tf.keras.Sequential([
    Conv2D(filters=32, kernel_size=5, activation=tf.nn.relu),
    Conv2D(filters=64, kernel_size=5, activation=tf.nn.relu),
    Flatten(),
    Dense(1024, activation=tf.nn.relu),
    Dense(10),
])
model = dc.models.KerasModel(keras_model, dc.models.losses.SparseSoftmaxCrossEntropy())
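It can help to trace how the shape of one image changes as it passes through these layers. The sketch below does the arithmetic by hand, assuming the Keras `Conv2D` defaults of stride 1 and no padding (`padding='valid'`), which is what the cell above uses since it sets neither. The helper `conv2d_valid` is a hypothetical name for this illustration, not a Keras or DeepChem function.

```python
def conv2d_valid(height, width, kernel_size, filters):
    """Output shape of a stride-1 convolution with no padding."""
    return height - kernel_size + 1, width - kernel_size + 1, filters

# Start from a single 28x28 grayscale image (one channel).
h, w, c = 28, 28, 1
params = 0

for kernel_size, filters in [(5, 32), (5, 64)]:
    # Each filter has kernel_size * kernel_size * in_channels weights plus one bias.
    params += (kernel_size * kernel_size * c + 1) * filters
    h, w, c = conv2d_valid(h, w, kernel_size, filters)
    print("after Conv2D:", (h, w, c))

features = h * w * c
print("after Flatten:", features)

for units in [1024, 10]:
    # A dense layer has one weight per (input, unit) pair plus one bias per unit.
    params += (features + 1) * units
    features = units
    print("after Dense:", features)

print("trainable parameters:", params)
```

Under these assumptions, almost all of the parameters sit in the first dense layer, because it connects every flattened convolutional feature to each of its 1024 units.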

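Fit the model on the training set for two epochs. `fit()` returns the average loss over the most recent checkpoint interval, which is the value shown in the cell's output.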

In [3]:
model.fit(train, nb_epoch=2)

0.031744494438171386

Let's see how well it works.  We ask the model to predict the class of every sample in the validation set.  Remember there are ten outputs for each sample.  We use `argmax()` to identify the largest one, which corresponds to the predicted class.

In [4]:
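# Each row of the output holds ten scores; argmax picks the highest-scoring class.
prediction = np.argmax(model.predict_on_batch(valid.X), axis=1)
# accuracy_score expects the true labels first, then the predictions.
score = dc.metrics.accuracy_score(valid.y, prediction)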
print('Validation set accuracy: ', score)

Validation set accuracy:  0.9891


It gets about 99% of samples correct.  Not too bad for such a simple model!
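The argmax-and-compare step used for this evaluation can be sketched in plain Python on a toy example. The scores and labels below are invented for illustration, and the helper functions are stand-ins for `np.argmax` and `accuracy_score`, not their actual implementations.

```python
def argmax(values):
    """Index of the largest entry (the first one if there are ties)."""
    return max(range(len(values)), key=lambda i: values[i])

def accuracy(true_labels, predicted_labels):
    """Fraction of positions where the prediction matches the label."""
    matches = sum(t == p for t, p in zip(true_labels, predicted_labels))
    return matches / len(true_labels)

# Hypothetical model outputs for four samples of a three-class problem.
outputs = [
    [2.0, 0.1, -1.0],
    [0.3, 0.2, 4.0],
    [-0.5, 1.5, 1.0],
    [0.9, 0.8, 0.7],
]
labels = [0, 2, 1, 1]

predicted = [argmax(row) for row in outputs]
print(predicted)
print(accuracy(labels, predicted))
```

Note that no softmax is needed at prediction time: softmax preserves the ordering of the outputs, so the largest raw output and the largest probability fall on the same class.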

# Congratulations! Time to join the Community!

Congratulations on completing this tutorial notebook! If you enjoyed working through the tutorial, and want to continue working with DeepChem, we encourage you to finish the rest of the tutorials in this series. You can also help the DeepChem community in the following ways:

## Star DeepChem on [GitHub](https://github.com/deepchem/deepchem)
This helps build awareness of the DeepChem project and the tools for open source drug discovery that we're trying to build.

## Join the DeepChem Gitter
The DeepChem [Gitter](https://gitter.im/deepchem/Lobby) hosts a number of scientists, developers, and enthusiasts interested in deep learning for the life sciences. Join the conversation!