# Handwritten Digit Classification using a Recurrent Neural Network

In this example, we use the MNIST dataset to train a recurrent neural network (RNN) classifier. MNIST is a simple computer vision dataset of handwritten digits, with 60,000 training examples and 10,000 test examples. "It is a good database for people who want to try learning techniques and pattern recognition methods on real-world data while spending minimal efforts on preprocessing and formatting." For more details, please check out the [MNIST](http://yann.lecun.com/exdb/mnist/) website.

In [1]:
import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkContext

import com.intel.analytics.bigdl.utils._
import com.intel.analytics.bigdl.utils.{Engine, LoggerFilter, T, Table}
import com.intel.analytics.bigdl.dataset.DataSet
import com.intel.analytics.bigdl.dataset.image.{BytesToGreyImg, GreyImgNormalizer, GreyImgToBatch, GreyImgToSample}
import com.intel.analytics.bigdl.models.lenet.Utils._
import com.intel.analytics.bigdl.nn._
import com.intel.analytics.bigdl.optim._
import com.intel.analytics.bigdl.optim.{Adam, Top1Accuracy, Trigger}
import com.intel.analytics.bigdl.visualization.{TrainSummary, ValidationSummary}
import com.intel.analytics.bigdl.tensor.Tensor
import com.intel.analytics.bigdl.numeric.NumericFloat

Engine.init

## 1. Load MNIST dataset

In [2]:
import java.nio.ByteBuffer
import java.nio.file.{Files, Path, Paths}

import com.intel.analytics.bigdl.dataset.ByteRecord
import com.intel.analytics.bigdl.utils.File
import scopt.OptionParser

def load(featureFile: String, labelFile: String): Array[ByteRecord] = {
    val featureBuffer = ByteBuffer.wrap(Files.readAllBytes(Paths.get(featureFile)))
    val labelBuffer = ByteBuffer.wrap(Files.readAllBytes(Paths.get(labelFile)))
    
    val labelMagicNumber = labelBuffer.getInt()
    require(labelMagicNumber == 2049)
    val featureMagicNumber = featureBuffer.getInt()
    require(featureMagicNumber == 2051)

    val labelCount = labelBuffer.getInt()
    val featureCount = featureBuffer.getInt()
    require(labelCount == featureCount)

    val rowNum = featureBuffer.getInt()
    val colNum = featureBuffer.getInt()

    val result = new Array[ByteRecord](featureCount)
    var i = 0
    while (i < featureCount) {
      val img = new Array[Byte]((rowNum * colNum))
      var y = 0
      while (y < rowNum) {
        var x = 0
        while (x < colNum) {
          img(x + y * colNum) = featureBuffer.get()
          x += 1
        }
        y += 1
      }
      // BigDL class labels are 1-based, so digit d is stored as label d + 1
      result(i) = ByteRecord(img, labelBuffer.get().toFloat + 1.0f)
      i += 1
    }

    result
}
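The IDX files that `load` reads start with a big-endian header: a magic number (2049 for label files, 2051 for image files), the item count and, for image files only, the row and column counts, followed by one unsigned byte per label or pixel. The sketch below builds a tiny synthetic image file in memory (hypothetical data: 2 images of 2x3 pixels) and parses its header the same way `load` does. `java.nio.ByteBuffer` is big-endian by default, which is why `load` needs no byte-order setup.

```scala
import java.nio.ByteBuffer

// Build a tiny synthetic IDX3 image file in memory (hypothetical data).
val rows = 2
val cols = 3
val count = 2
val buf = ByteBuffer.allocate(16 + count * rows * cols) // 4 header ints + pixel bytes
buf.putInt(2051).putInt(count).putInt(rows).putInt(cols)
(0 until count * rows * cols).foreach(i => buf.put((i * 40).toByte))
buf.flip() // switch the buffer from writing to reading

// Parse the header in the same order as `load`.
val magic = buf.getInt()
require(magic == 2051, s"unexpected magic number $magic")
val n = buf.getInt()
val r = buf.getInt()
val c = buf.getInt()

// Read the first image; mask with 0xff because pixel bytes are unsigned (0-255).
val firstImage = Array.fill(r * c)(buf.get() & 0xff)
println(s"count=$n rows=$r cols=$c first=${firstImage.mkString(",")}")
```

Note that `load` itself keeps the raw signed `Byte`s; the unsigned interpretation only matters once the bytes are converted to pixel values.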

Next, set the paths to the MNIST files. Edit them if your copy of the dataset is stored elsewhere.

In [3]:
val trainData = "../../datasets/mnist/train-images-idx3-ubyte"
val trainLabel = "../../datasets/mnist/train-labels-idx1-ubyte"
val validationData = "../../datasets/mnist/t10k-images-idx3-ubyte"
val validationLabel = "../../datasets/mnist/t10k-labels-idx1-ubyte"
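If one of these paths is wrong, `load` fails with a `NoSuchFileException` thrown by `Files.readAllBytes`. A quick check before loading can save a confusing stack trace. The helper `missingFiles` below is a hypothetical name, not part of BigDL; the sketch repeats the four paths literally so it runs on its own, but in the notebook you could pass `Seq(trainData, trainLabel, validationData, validationLabel)` instead.

```scala
import java.nio.file.{Files, Paths}

// Hypothetical helper: returns the paths that do not point to an existing regular file.
def missingFiles(paths: Seq[String]): Seq[String] =
  paths.filterNot(p => Files.isRegularFile(Paths.get(p)))

val mnistFiles = Seq(
  "../../datasets/mnist/train-images-idx3-ubyte",
  "../../datasets/mnist/train-labels-idx1-ubyte",
  "../../datasets/mnist/t10k-images-idx3-ubyte",
  "../../datasets/mnist/t10k-labels-idx1-ubyte")

val missing = missingFiles(mnistFiles)
if (missing.isEmpty) println("All MNIST files found.")
else println(s"Missing files, please fix the paths: ${missing.mkString(", ")}")
```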

## 2. Recurrent Neural Network Model Setup

This time we use a recurrent neural network (RNN) to classify the handwritten digits. Each 28x28 image is fed to the network as a sequence of 28 time steps, one 28-pixel row per step. You can check out this blog for a detailed explanation of recurrent neural networks, and LSTMs in particular.

In [4]:
// Training parameters
val batchSize = 64
val maxEpochs = 5

// Network parameters
val nInput = 28   // input features per time step: one 28-pixel row of a 28x28 MNIST image
val nHidden = 128 // size of the RNN hidden state
val nClasses = 10 // MNIST classes (digits 0-9)

Next, create the training and validation datasets and build the model.

In [5]:
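// Each pipeline: raw bytes -> 28x28 grey image -> normalized image -> mini-batches of batchSize.
// trainMean/trainStd and testMean/testStd come from com.intel.analytics.bigdl.models.lenet.Utils.
val trainSet = DataSet.array(load(trainData, trainLabel), sc) ->
    BytesToGreyImg(28, 28) ->
    GreyImgNormalizer(trainMean, trainStd) ->
    GreyImgToBatch(batchSize)
val validationSet = DataSet.array(load(validationData, validationLabel), sc) ->
    BytesToGreyImg(28, 28) ->
    GreyImgNormalizer(testMean, testStd) ->
    GreyImgToBatch(batchSize)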
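`GreyImgNormalizer` standardizes each pixel as (x - mean) / std, assuming `BytesToGreyImg` has already scaled each byte into [0, 1] by dividing by 255. The training and validation pipelines use their own statistics (`trainMean`/`trainStd` and `testMean`/`testStd`). The plain-Scala sketch below shows the transform and its inverse; the mean and standard deviation are the commonly quoted MNIST values, roughly 0.1307 and 0.3081, used here as assumptions (BigDL's constants may differ slightly). The inverse is what turns a normalized tensor back into a viewable image.

```scala
// Commonly quoted MNIST statistics (assumed values, not read from BigDL).
val mean = 0.1307
val std = 0.3081

// Forward: byte value in [0, 255] -> scaled to [0, 1] -> standardized.
def normalize(pixel: Int): Double = (pixel / 255.0 - mean) / std

// Inverse: standardized value -> nearest valid grey level, clamped to [0, 255].
def denormalize(value: Double): Int =
  math.max(0, math.min(255, math.round((value * std + mean) * 255).toInt))

println(f"black -> ${normalize(0)}%.4f, white -> ${normalize(255)}%.4f")
```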

In [6]:
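// A single-layer RNN with tanh activation, applied to each image row by row
val recurrent = Recurrent().add(RnnCell(nInput, nHidden, Tanh()))
val rnnModel = Sequential().
  add(InferReshape(Array(-1, nInput), true)). // view each batch as (batch, time steps, nInput)
  add(recurrent).                             // hidden state at every step: (batch, time steps, nHidden)
  add(Select(2, -1)).                         // keep only the last time step
  add(Linear(nHidden, nClasses))              // one score per class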
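To make the data flow through `rnnModel` concrete, here is a plain-Scala sketch (not BigDL code) with random weights and a hidden size of 8 instead of 128 to keep it small. It unrolls the standard Elman recurrence h_t = tanh(W_x x_t + W_h h_{t-1} + b), which is the kind of update a tanh `RnnCell` performs, keeps only the final hidden state as `Select(2, -1)` does, and maps it to 10 class scores as `Linear(nHidden, nClasses)` does.

```scala
import scala.util.Random

// Plain-Scala sketch of the forward pass (not BigDL's implementation).
val nInput = 28
val nHidden = 8   // small hidden size for the sketch; the notebook uses 128
val nClasses = 10
val rnd = new Random(42)

def randMatrix(rows: Int, cols: Int): Array[Array[Double]] =
  Array.fill(rows, cols)(rnd.nextGaussian() * 0.1)

// Matrix-vector product: one dot product per matrix row.
def matVec(m: Array[Array[Double]], v: Array[Double]): Array[Double] =
  m.map(row => row.zip(v).map { case (a, b) => a * b }.sum)

val wx = randMatrix(nHidden, nInput)   // input-to-hidden weights
val wh = randMatrix(nHidden, nHidden)  // hidden-to-hidden weights
val b = Array.fill(nHidden)(0.0)
val wo = randMatrix(nClasses, nHidden) // output layer weights
val bo = Array.fill(nClasses)(0.0)

// A random 28x28 "image": row t is the input at time step t.
val image = Array.fill(28, nInput)(rnd.nextDouble())

// Unroll the recurrence over the rows, starting from a zero hidden state.
val lastHidden = image.foldLeft(Array.fill(nHidden)(0.0)) { (h, x) =>
  val pre = matVec(wx, x).zip(matVec(wh, h)).zip(b).map { case ((u, w), c) => u + w + c }
  pre.map(math.tanh)
}

// Linear layer on the final state: one score per digit class.
val logits = matVec(wo, lastHidden).zip(bo).map { case (s, c) => s + c }
println(logits.map(v => f"$v%.3f").mkString(", "))
```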

## 3. Optimizer Setup

In [7]:
val optimizer = Optimizer(model = rnnModel, dataset = trainSet, criterion = CrossEntropyCriterion[Float]())
optimizer.setValidation(trigger = Trigger.everyEpoch, dataset = validationSet, vMethods = Array(new Top1Accuracy))
optimizer.setOptimMethod(new Adam())
optimizer.setEndWhen(Trigger.maxEpoch(maxEpochs))

com.intel.analytics.bigdl.optim.DistriOptimizer@cecdb7c
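`CrossEntropyCriterion` combines a log-softmax with a negative log-likelihood loss, and it expects 1-based class indices, which is why `load` adds 1 to each label. A minimal plain-Scala sketch of the per-example loss (an illustration, not BigDL's implementation):

```scala
// Softmax cross-entropy for one example, with a 1-based target as BigDL expects.
def crossEntropy(logits: Array[Double], target: Int): Double = {
  require(target >= 1 && target <= logits.length, s"target $target out of range")
  val maxLogit = logits.max // subtract the max for numerical stability
  val logSumExp = maxLogit + math.log(logits.map(z => math.exp(z - maxLogit)).sum)
  logSumExp - logits(target - 1) // = -log(softmax(logits)(target))
}

val uniform = Array.fill(10)(0.0)
println(crossEntropy(uniform, target = 1)) // about 2.3026, i.e. ln(10)

// Class 4 (digit 3) has by far the highest score, so the loss for target 4 is small.
val confident = Array.tabulate(10)(k => if (k == 3) 10.0 else 0.0)
println(crossEntropy(confident, target = 4))
```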

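The following cell creates the training and validation summaries. They are written to /tmp/bigdl_summaries and read back in Section 4 to draw the performance curves; parameter statistics are recorded every 50 iterations.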

In [8]:
import java.text.SimpleDateFormat
import java.util.Calendar
val today = Calendar.getInstance
val formatDate = new SimpleDateFormat("yyyyMMdd-HHmmss") // HH = 24-hour clock
val name = "rnn-" + formatDate.format(today.getTime())
val trainSummary = TrainSummary(logDir="/tmp/bigdl_summaries", appName=name)
trainSummary.setSummaryTrigger("Parameters", Trigger.severalIteration(50))
val valSummary = ValidationSummary(logDir="/tmp/bigdl_summaries", appName=name)
optimizer.setTrainSummary(trainSummary)
optimizer.setValidationSummary(valSummary)
printf("saving logs to %s", name)

saving logs to rnn-20170925-024042
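A note on the timestamp pattern: in `SimpleDateFormat`, `hh` is the 12-hour clock (01-12) and `HH` the 24-hour clock (00-23), so with `hh` a run at 15:30 gets the same hour digits as a run at 03:30. The sketch below formats one fixed, hypothetical instant (15:30:00 UTC) with both patterns.

```scala
import java.text.SimpleDateFormat
import java.util.{Calendar, TimeZone}

val utc = TimeZone.getTimeZone("UTC")
val cal = Calendar.getInstance(utc)
cal.clear()
cal.set(2017, Calendar.SEPTEMBER, 25, 15, 30, 0) // hypothetical time: 15:30:00 UTC

// Format the fixed instant with the given pattern, in UTC so the result is reproducible.
def stamp(pattern: String): String = {
  val f = new SimpleDateFormat(pattern)
  f.setTimeZone(utc)
  f.format(cal.getTime)
}

val twelveHour = stamp("yyyyMMdd-hhmmss")     // "20170925-033000"
val twentyFourHour = stamp("yyyyMMdd-HHmmss") // "20170925-153000"
println(s"rnn-$twelveHour vs rnn-$twentyFourHour")
```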

In [9]:
// Start the training process
val trainedModel = optimizer.optimize()
print("Optimization Done.")

can't find locality partition for partition 0 Partition locations are (ArrayBuffer(172.168.2.109)) Candidate partition locations are
(0,List()).
Optimization Done.

In [10]:
// Reload the validation set as an RDD of samples
// (the second argument of parallelize is the number of partitions; batchSize is reused here)
val rddData = sc.parallelize(load(validationData, validationLabel), batchSize)
val transformer = BytesToGreyImg(28, 28) -> GreyImgNormalizer(testMean, testStd) -> GreyImgToSample()
val evaluationSet = transformer(rddData)

val result = trainedModel.evaluate(evaluationSet, Array(new Top1Accuracy[Float]), Some(batchSize))

result.foreach(r => println(s"${r._2} is ${r._1}"))

Top1Accuracy is Accuracy(correct: 9509, count: 10000, accuracy: 0.9509)
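The reported accuracy is simply correct / count, here 9509 / 10000 = 0.9509. The sketch below shows what `Top1Accuracy` measures, in plain Scala with made-up scores (not BigDL's implementation): an example counts as correct when the 1-based index of its highest score equals its label.

```scala
// 1-based index of the highest score, matching BigDL's 1-based labels.
def argmax1(scores: Array[Double]): Int = scores.indices.maxBy(i => scores(i)) + 1

// Fraction of examples whose top-scoring class equals the label.
def top1Accuracy(scores: Seq[Array[Double]], labels: Seq[Int]): Double = {
  require(scores.length == labels.length, "one label per example")
  val correct = scores.zip(labels).count { case (s, l) => argmax1(s) == l }
  correct.toDouble / labels.length
}

// Made-up scores for 3 examples and 3 classes: the first and third are correct.
val scores = Seq(Array(0.1, 0.7, 0.2), Array(0.5, 0.3, 0.2), Array(0.2, 0.2, 0.6))
val labels = Seq(2, 3, 3)
println(top1Accuracy(scores, labels))
```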


In [11]:
val predictions = trainedModel.predict(evaluationSet)
val preLabels = predictions.take(8).map(_.toTensor.max(1)._2.valueAt(1)).mkString(",")
val labels = evaluationSet.take(8).map(_.label.valueAt(1)).mkString(",")
println(preLabels)
println(labels)

5.0,2.0,5.0,10.0,5.0,2.0,5.0,10.0
8.0,3.0,2.0,1.0,5.0,2.0,5.0,10.0
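Both rows are 1-based class indices: `load` stores each digit d as the label d + 1, and the predicted index comes from `max`, whose indices are also 1-based. Subtract 1 to read them as digits; for example, 10.0 means the digit 9. A minimal sketch of the conversion (`toDigits` is a hypothetical helper):

```scala
// Convert a comma-separated string of 1-based BigDL class labels into digits 0-9.
def toDigits(labels: String): Seq[Int] =
  labels.split(",").toSeq.map(s => s.trim.toFloat.toInt - 1)

println(toDigits("8.0,3.0,2.0,1.0,5.0,2.0,5.0,10.0").mkString(",")) // 7,2,1,0,4,1,4,9
```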


In [12]:
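// Concatenate the first 8 test images side by side (along dimension 2) into a single 28 x 224 tensor
val input = T.array(evaluationSet.take(8).map(_.feature))
val layer = JoinTable(2, 2)
val output = layer.forward(input)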
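`JoinTable(2, 2)` concatenates its inputs along dimension 2; the second argument gives the number of dimensions of each non-batched input. With eight 28x28 images this yields one 28 x 224 strip, which the next cell saves as a PNG. A plain-Scala analogue with hypothetical constant images (not BigDL code):

```scala
// Concatenate matrices with the same number of rows along the column dimension.
def joinColumns(images: Seq[Array[Array[Float]]]): Array[Array[Float]] = {
  require(images.nonEmpty && images.map(_.length).distinct.size == 1, "all images need the same height")
  val height = images.head.length
  // Row r of the result is row r of every image, placed one after another.
  Array.tabulate(height)(r => images.map(img => img(r)).reduce(_ ++ _))
}

// Eight hypothetical 28x28 images, each filled with its own index.
val images = (0 until 8).map(k => Array.fill(28, 28)(k.toFloat))
val strip = joinColumns(images)
println(s"${strip.length} x ${strip.head.length}") // 28 x 224
```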

In [13]:
import java.awt.Color
import java.awt.image.BufferedImage
import java.io.File
import javax.imageio.ImageIO

// output is a 28 x 224 tensor: size(0) is the image height, size(1) its width
val imageSize = output.size()
val width = imageSize(1)
val height = imageSize(0)
println(width)
println(height)
val img = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY)
for (i <- 0 until width)
    for (j <- 0 until height) {
        // undo GreyImgNormalizer and the 1/255 scaling of BytesToGreyImg,
        // then clamp to a valid grey level in [0, 255]
        val pixel = ((output.valueAt(j + 1, i + 1) * testStd + testMean) * 255).round.toInt
        val value = math.max(0, math.min(255, pixel))
        img.setRGB(i, j, new Color(value, value, value).getRGB)
    }
ImageIO.write(img, "PNG", new File("outimg.png"))

224
28

In [14]:
%%html
<body>
<img src="outimg.png"/>
</body>

## 4. Draw the performance curves

In [18]:
import vegas._
import vegas.render.HTMLRenderer._
import vegas.DSL._

In [21]:
val loss = trainSummary.readScalar("Loss")
val lossXY = loss.map(r => (r._1, r._2)).toSeq
Vegas(description = "The Loss curve.", width = 700.0, height = 300.0).
  withXY(lossXY).
  encodeX("x", Quantitative, bin = Bin(maxbins = 500.0), title = "Iteration").
  encodeY("y", Quantitative, title = "Loss").
  mark(Line).
  show

In [22]:
val top1 = valSummary.readScalar("Top1Accuracy")
val top1XY = top1.map(r => (r._1, r._2)).toSeq
Vegas(description = "The Top1Accuracy curve.", width = 700.0, height = 300.0).
  withXY(top1XY).
  encodeX("x", Quantitative, bin = Bin(maxbins = 500.0), title = "Iteration").
  encodeY("y", Quantitative, bin = Bin(base = 0.9), title = "Top1Accuracy").
  mark(Line).
  show

Finally, stop the SparkContext.

In [20]:
sc.stop()