# Train a model on MNIST

In this tutorial, we'll walk you through a machine-learning 101 example: training a simple convolutional model on the MNIST dataset. We'll use Keras to build the model and DeepChem's dataset module to make data handling convenient. The tutorial also shows how to mix and match TensorFlow tooling with DeepChem, a pattern you will likely want to use in your own models.

To start, we'll pull in the raw data for MNIST from one of the TensorFlow tutorials.

In [1]:
from tensorflow.examples.tutorials.mnist import input_data



In [2]:
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting MNIST_data/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
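The `tensorflow.examples.tutorials.mnist` module used above ships only with TensorFlow 1.x and was removed in TensorFlow 2.0. If it is unavailable, the following is a minimal sketch, assuming you obtain the raw arrays some other way (for example with `tf.keras.datasets.mnist.load_data()`), that converts them into the layout `read_data_sets(..., one_hot=True)` produces: flattened 784-pixel float rows scaled to [0, 1], one-hot labels, and a 5,000-image validation split taken from the start of the training set (the loader's default, as we understand it). The helpers `to_tutorial_format` and `split_validation` are hypothetical names, not part of TensorFlow or DeepChem.

```python
import numpy as np


def to_tutorial_format(images, labels, num_classes=10):
    """Flatten uint8 images to float32 rows in [0, 1] and one-hot encode the labels.

    Assumes `images` has shape (n, 28, 28) with pixel values 0..255 and
    `labels` holds integer digits, the layout returned by
    tf.keras.datasets.mnist.load_data().
    """
    X = images.reshape(len(images), -1).astype(np.float32) / 255.0
    y = np.eye(num_classes, dtype=np.float32)[labels]
    return X, y


def split_validation(X, y, validation_size=5000):
    """Hold out the first `validation_size` examples as a validation set."""
    train = (X[validation_size:], y[validation_size:])
    valid = (X[:validation_size], y[:validation_size])
    return train, valid


# Synthetic stand-in arrays so the sketch runs without downloading anything.
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(6000, 28, 28)).astype(np.uint8)
fake_labels = rng.integers(0, 10, size=6000)

X, y = to_tutorial_format(fake_images, fake_labels)
(train_X, train_y), (valid_X, valid_y) = split_validation(X, y)
print(train_X.shape, train_y.shape, valid_X.shape, valid_y.shape)
# (1000, 784) (1000, 10) (5000, 784) (5000, 10)
```

With TensorFlow 2.x you could pass the output of `(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()` through these helpers and then wrap the resulting arrays in `dc.data.NumpyDataset` as the tutorial does.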

In [3]:
import deepchem as dc
import tensorflow as tf
from tensorflow.keras.layers import Reshape, Conv2D, Flatten, Dense, Softmax



Let's wrap these datasets with DeepChem's `NumpyDataset`.

In [5]:
train = dc.data.NumpyDataset(mnist.train.images, mnist.train.labels)
valid = dc.data.NumpyDataset(mnist.validation.images, mnist.validation.labels)
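`dc.data.NumpyDataset` keeps the arrays in memory and hands them to the model in mini-batches. The class below is a hypothetical, simplified stand-in (`MiniNumpyDataset` is our name, not DeepChem's) that illustrates the idea: it bundles features `X`, labels `y`, per-sample weights `w` that default to ones, and ids that default to `0..n-1`, and yields batches of all four. Consult the DeepChem API documentation for the real `NumpyDataset` constructor and `iterbatches` signature.

```python
import numpy as np


class MiniNumpyDataset:
    """Hypothetical, simplified in-memory dataset wrapper (not DeepChem's code).

    Stores features X, labels y, per-sample weights w (default: all ones)
    and sample ids (default: 0..n-1), and yields mini-batches of all four.
    """

    def __init__(self, X, y, w=None, ids=None):
        self.X = np.asarray(X)
        self.y = np.asarray(y)
        n = len(self.X)
        self.w = np.ones((n, 1), dtype=np.float32) if w is None else np.asarray(w)
        self.ids = np.arange(n) if ids is None else np.asarray(ids)

    def iterbatches(self, batch_size, shuffle=False, seed=None):
        """Yield (X, y, w, ids) batches; the last batch may be smaller."""
        order = np.arange(len(self.X))
        if shuffle:
            np.random.default_rng(seed).shuffle(order)
        for start in range(0, len(order), batch_size):
            idx = order[start:start + batch_size]
            yield self.X[idx], self.y[idx], self.w[idx], self.ids[idx]


X = np.zeros((250, 784), dtype=np.float32)
y = np.eye(10, dtype=np.float32)[np.arange(250) % 10]
ds = MiniNumpyDataset(X, y)
sizes = [len(Xb) for Xb, _, _, _ in ds.iterbatches(batch_size=100)]
print(sizes)  # [100, 100, 50]
```

Carrying weights alongside labels lets a training loop down-weight or mask individual samples without changing the arrays themselves; here every sample counts equally.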

We're going to train a simple convolutional neural network with two convolutional layers and two dense layers on this data. We'll use the `tf.keras.Sequential` class to simplify building a model of this type, and then wrap it in `dc.models.KerasModel`, which lets it train directly on DeepChem dataset objects.
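It helps to work out the tensor shapes and parameter counts for this architecture before building it. Keras `Conv2D` defaults to `padding="valid"` and `strides=1`, so each 5×5 convolution shrinks the image by 4 pixels per side. The arithmetic below is a sketch under those defaults; `conv2d_valid` is a hypothetical helper, and the layer sizes are the ones used in the next cell.

```python
def conv2d_valid(size, kernel):
    """Spatial size after a stride-1 convolution with 'valid' (no) padding."""
    return size - kernel + 1


side = 28                        # Reshape((28, 28, 1))
side = conv2d_valid(side, 5)     # Conv2D(32, 5)  -> 24 x 24 x 32
side = conv2d_valid(side, 5)     # Conv2D(64, 5)  -> 20 x 20 x 64
flat = side * side * 64          # Flatten()      -> 25600 features

# Trainable parameters: kernel weights plus one bias per output unit/filter.
params = {
    "conv1": 5 * 5 * 1 * 32 + 32,
    "conv2": 5 * 5 * 32 * 64 + 64,
    "dense1": flat * 1024 + 1024,
    "dense2": 1024 * 10 + 10,
}
print(side, flat)            # 20 25600
print(params)                # {'conv1': 832, 'conv2': 51264, 'dense1': 26215424, 'dense2': 10250}
print(sum(params.values()))  # 26277770
```

The first dense layer holds over 99% of the roughly 26.3 million weights, because it connects every one of the 25,600 flattened features to 1,024 units. `Reshape`, `Flatten`, and `Softmax` add no parameters.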

In [6]:
keras_model = tf.keras.Sequential([
    Reshape((28, 28, 1)),
    Conv2D(filters=32, kernel_size=5, activation=tf.nn.relu),
    Conv2D(filters=64, kernel_size=5, activation=tf.nn.relu),
    Flatten(),
    Dense(1024, activation=tf.nn.relu),
    Dense(10),
    Softmax()
])
model = dc.models.KerasModel(keras_model, dc.models.losses.CategoricalCrossEntropy())

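As we understand DeepChem's `CategoricalCrossEntropy`, it expects one-hot labels and model outputs that are already probabilities, which is why the Keras model ends in a `Softmax` layer rather than emitting raw scores. The NumPy sketch below shows what that combination computes; the function names and the clipping constant `eps` are our own choices for illustration, not DeepChem's implementation.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax, shifted by each row's max for numerical stability."""
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


def categorical_cross_entropy(probs, one_hot, eps=1e-7):
    """Per-sample loss -sum_k y_k * log(p_k) for probability outputs.

    Probabilities are clipped to [eps, 1] so log(0) never occurs.
    """
    probs = np.clip(probs, eps, 1.0)
    return -np.sum(one_hot * np.log(probs), axis=1)


probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.1, 0.8]])
labels = np.array([[1, 0, 0],
                   [0, 1, 0]])
losses = categorical_cross_entropy(probs, labels)
print(losses)  # equals [-ln(0.7), -ln(0.1)]
```

With one-hot labels only the predicted probability of the true class contributes, so a confident correct prediction (0.7) costs little while a confident wrong one (true class given 0.1) costs much more.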


Let's now fit our simple model on the data! We'll run for just 2 epochs so the training doesn't take too long.

In [7]:
model.fit(train, nb_epoch=2)

0.0
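To get a feel for what `nb_epoch=2` means, here is the arithmetic for the number of gradient updates. It assumes `KerasModel`'s default batch size of 100, as far as we know (the constructor accepts `batch_size=` to change it), and the 55,000 images left in `mnist.train` after the validation split; `training_steps` is a hypothetical helper that counts a final partial batch as one step.

```python
import math


def training_steps(n_samples, batch_size, nb_epoch):
    """Return (steps per epoch, total steps), counting a final partial batch as one step."""
    steps_per_epoch = math.ceil(n_samples / batch_size)
    return steps_per_epoch, steps_per_epoch * nb_epoch


n_train = 55000  # size of mnist.train once the 5,000 validation images are held out
print(training_steps(n_train, batch_size=100, nb_epoch=2))  # (550, 1100)
print(training_steps(n_train, batch_size=32, nb_epoch=2))   # (1719, 3438)
```

Smaller batches mean more (noisier) updates per epoch, so "2 epochs" is not a fixed amount of optimization work on its own.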

Let's now evaluate this model on our validation set. Treating each digit as a one-vs-rest binary problem, we'll print the ROC AUC score for each of the ten classes.

In [8]:
from sklearn.metrics import roc_curve, auc
import numpy as np

print("Validation")
# Predicted class probabilities, shape (n_samples, 10)
prediction = np.squeeze(model.predict_on_batch(valid.X))

fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(10):
    # One-vs-rest ROC curve: digit i is the positive class, all others negative
    fpr[i], tpr[i], _ = roc_curve(valid.y[:, i], prediction[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])
    print("class %s:auc=%s" % (i, roc_auc[i]))

Validation
class 0:auc=0.9999676757825577
class 1:auc=0.9999515618501131
class 2:auc=0.9999504963085688
class 3:auc=0.9999450932986371
class 4:auc=0.9999581375391153
class 5:auc=0.9999444905341222
class 6:auc=0.9999179236548019
class 7:auc=0.9997867211440246
class 8:auc=0.9998344904691312
class 9:auc=0.9995919236762745
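The ROC AUC for class i can be read as the probability that a randomly chosen image of digit i receives a higher class-i score than a randomly chosen image of any other digit, with ties counting as half. The NumPy sketch below computes AUC directly from that pairwise definition (the function names are ours). It is an alternative to, not a copy of, how scikit-learn integrates the ROC curve, though the two should agree up to floating-point error.

```python
import numpy as np


def pairwise_auc(y_true, scores):
    """ROC AUC as P(random positive outscores random negative); ties count 0.5.

    Builds an (n_pos, n_neg) comparison matrix, so it is meant for
    illustration rather than very large inputs.
    """
    y_true = np.asarray(y_true).astype(bool)
    scores = np.asarray(scores, dtype=float)
    pos = scores[y_true][:, None]
    neg = scores[~y_true][None, :]
    wins = (pos > neg).sum() + 0.5 * (pos == neg).sum()
    return wins / (pos.size * neg.size)


def one_vs_rest_auc(one_hot, probs):
    """Per-class AUC, treating each label column as its own binary problem."""
    return [pairwise_auc(one_hot[:, k], probs[:, k]) for k in range(one_hot.shape[1])]


print(pairwise_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

Calling `one_vs_rest_auc(valid.y, prediction)` after the evaluation cell should reproduce the ten per-class scores printed above. Because AUC only depends on how scores rank positives against negatives, values this close to 1.0 mean almost every image of each digit is ranked above the images of other digits; it does not by itself report classification accuracy.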
