# Tutorial Part 2: Learning MNIST Digit Classifiers

In the previous tutorial, we covered the basics of loading data into DeepChem and manipulating it with the core DeepChem objects. In this tutorial, you'll put those pieces together and train a basic image classification model in DeepChem. You might ask why we're bothering to learn this material in DeepChem. Part of the reason is that image processing is an increasingly important part of AI for the life sciences, so knowing how to train image processing models will help you use some of the more advanced DeepChem features.

The MNIST dataset contains images of handwritten digits along with their human-annotated labels. The learning challenge for this dataset is to train a model that maps each digit image to its true label. MNIST has been a standard machine learning benchmark for decades.

![MNIST](mnist_examples.png)

## Setup

We recommend running this tutorial on Google Colab. You'll need to run the following cell of installation commands on Colab to set up your environment. If you'd rather run the tutorial locally, skip these commands, since they download and install a new Anaconda Python distribution into `/usr/local`.

In [None]:
!wget -c https://repo.anaconda.com/archive/Anaconda3-2019.10-Linux-x86_64.sh
!chmod +x Anaconda3-2019.10-Linux-x86_64.sh
!bash ./Anaconda3-2019.10-Linux-x86_64.sh -b -f -p /usr/local
!conda install -y -c deepchem -c rdkit -c conda-forge -c omnia deepchem-gpu=2.3.0
import sys
sys.path.append('/usr/local/lib/python3.7/site-packages/')
import deepchem as dc

In [1]:
from tensorflow.examples.tutorials.mnist import input_data

In [2]:
# TODO: `tensorflow.examples.tutorials` is deprecated and was removed in TensorFlow 2.x.
# Replace it with a DeepChem-native loader for maintainability.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
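
With `one_hot=True`, each label is stored as a 10-element vector with a 1 at the index of the digit instead of as a single integer, and each image is stored as a flattened vector of 28 × 28 = 784 pixel values. As a minimal sketch of that label encoding in plain Python (the helper names `one_hot` and `from_one_hot` are ours, purely for illustration, and are not part of TensorFlow or DeepChem):

```python
def one_hot(label, num_classes=10):
    """Return a list with 1.0 at position `label` and 0.0 elsewhere."""
    if not 0 <= label < num_classes:
        raise ValueError("label must be in [0, num_classes)")
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


def from_one_hot(vector):
    """Recover the integer label: the index of the largest entry."""
    return max(range(len(vector)), key=lambda i: vector[i])


encoded = one_hot(3)
print(encoded)                # 10 entries, with the 1.0 at index 3
print(from_one_hot(encoded))  # back to the integer label
```

This is why the evaluation code later in the tutorial can index the labels by column: column `i` of the label array marks which images show digit `i`.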


In [3]:
import deepchem as dc
import tensorflow as tf
from tensorflow.keras.layers import Reshape, Conv2D, Flatten, Dense, Softmax

In [4]:
train = dc.data.NumpyDataset(mnist.train.images, mnist.train.labels)
valid = dc.data.NumpyDataset(mnist.validation.images, mnist.validation.labels)

In [5]:
keras_model = tf.keras.Sequential([
    Reshape((28, 28, 1)),
    Conv2D(filters=32, kernel_size=5, activation=tf.nn.relu),
    Conv2D(filters=64, kernel_size=5, activation=tf.nn.relu),
    Flatten(),
    Dense(1024, activation=tf.nn.relu),
    Dense(10),
    Softmax()
])
model = dc.models.KerasModel(keras_model, dc.models.losses.CategoricalCrossEntropy())
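
To see how the image shrinks as it moves through the network, it can help to trace the tensor shapes by hand. The sketch below assumes the Keras defaults for `Conv2D` (stride 1 and `padding='valid'`), which the cell above does not override. The helper `conv_output_size` is our own illustrative function, not a Keras or DeepChem API.

```python
def conv_output_size(size, kernel_size, stride=1):
    """Spatial output size of a convolution with no padding ('valid')."""
    return (size - kernel_size) // stride + 1


height = width = 28   # shape after Reshape((28, 28, 1))
channels = 1
params = 0

for filters in (32, 64):
    # Kernel weights (k * k * in_channels * filters) plus one bias per filter.
    params += 5 * 5 * channels * filters + filters
    height = conv_output_size(height, 5)
    width = conv_output_size(width, 5)
    channels = filters
    print("Conv2D  ->", (height, width, channels))

features = height * width * channels
print("Flatten ->", features)

for units in (1024, 10):
    # Dense weights (inputs * units) plus one bias per unit.
    params += features * units + units
    features = units
    print("Dense   ->", units)

print("trainable parameters:", params)
```

Under these assumptions the two convolutions reduce the 28 × 28 image to 24 × 24 and then 20 × 20, and almost all of the parameters sit in the first `Dense` layer.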

In [6]:
model.fit(train, nb_epoch=2)

0.0
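
The model was built with a final `Softmax` layer and a categorical cross-entropy loss, so training pushes the predicted probability of the true digit toward 1. The following is a minimal pure-Python sketch of those two formulas on made-up numbers. The function names and the `eps` clipping value are our own choices for illustration, and this is not DeepChem's implementation.

```python
import math


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    shift = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - shift) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_cross_entropy(one_hot_label, probabilities, eps=1e-12):
    """Negative log-probability assigned to the true class."""
    return -sum(y * math.log(max(p, eps))
                for y, p in zip(one_hot_label, probabilities))


label = [0.0] * 10
label[3] = 1.0            # the true digit is 3

confident = [0.0] * 10
confident[3] = 5.0        # scores that strongly favor class 3
uniform = [0.0] * 10      # scores with no preference at all

print(categorical_cross_entropy(label, softmax(confident)))
print(categorical_cross_entropy(label, softmax(uniform)))
```

The uniform prediction gives a loss of ln(10), the value you would expect from a classifier that has learned nothing, while the confident correct prediction gives a loss close to zero.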

In [7]:
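from sklearn.metrics import roc_curve, auc
import numpy as np

print("Validation")
# Predicted class probabilities for every validation image, one column per digit.
prediction = np.squeeze(model.predict_on_batch(valid.X))

# Compute a one-vs-rest ROC AUC for each digit class from the one-hot labels.
fpr = {}
tpr = {}
roc_auc = {}
for i in range(10):
    # The thresholds returned by roc_curve are not needed here.
    fpr[i], tpr[i], _ = roc_curve(valid.y[:, i], prediction[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])
    print("class %s:auc=%s" % (i, roc_auc[i]))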

Validation
class 0:auc=0.9998836328172079
class 1:auc=0.9999571662641497
class 2:auc=0.9998310516219043
class 3:auc=0.9999563446718672
class 4:auc=0.9999418111793702
class 5:auc=0.9995639983771051
class 6:auc=0.9998478260194437
class 7:auc=0.9998357507660879
class 8:auc=0.9999599342922393
class 9:auc=0.9998551553268534
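
Each number above is a one-vs-rest ROC AUC: for class `i`, it is the probability that a randomly chosen image of digit `i` receives a higher class-`i` score than a randomly chosen image of any other digit, with ties counted as half. A minimal pure-Python sketch of that pairwise definition follows. It uses made-up scores, not the model's predictions, and `pairwise_auc` is an illustrative helper that does not mirror how scikit-learn computes the value internally.

```python
def pairwise_auc(labels, scores):
    """AUC as the fraction of (positive, negative) pairs ranked correctly."""
    positives = [s for y, s in zip(labels, scores) if y == 1]
    negatives = [s for y, s in zip(labels, scores) if y == 0]
    if not positives or not negatives:
        raise ValueError("need at least one positive and one negative example")
    wins = 0.0
    for p in positives:
        for n in negatives:
            if p > n:
                wins += 1.0   # positive ranked above negative
            elif p == n:
                wins += 0.5   # a tie counts as half a win
    return wins / (len(positives) * len(negatives))


# Hypothetical one-vs-rest labels and scores for a single class.
labels = [1, 0, 1, 0, 0]
scores = [0.9, 0.8, 0.7, 0.3, 0.1]
print(pairwise_auc(labels, scores))
```

In this toy example one positive is ranked below one negative, so 5 of the 6 pairs are ordered correctly. Values as close to 1 as those reported above mean that almost every such pair in the validation set is ordered correctly.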
