# Training and Inference Module
We have modularized commonly used code for training and inference in the `module` (or `mod` for short) package. This package provides intermediate-level and high-level interfaces for executing predefined networks.

## Jupyter Scala kernel
Add the MXNet Scala jar, which is created as part of the MXNet Scala package installation, to the classpath as follows:

**Note**: The process for adding this jar to your Scala kernel classpath can differ depending on the Scala kernel you are using.

We used the [jupyter-scala kernel](https://github.com/alexarchambault/jupyter-scala) to create this notebook.

```
classpath.addPath(<path_to_jar>)

// e.g.
classpath.addPath("mxnet-full_2.11-osx-x86_64-cpu-0.1.2-SNAPSHOT.jar")
```

## Basic Usage
### Preliminary
In this tutorial, we will use a simple multilayer perceptron with 10 output classes.

In [2]:
import ml.dmlc.mxnet._
import ml.dmlc.mxnet.module.{FitParams, Module}

val data = Symbol.Variable("data")
val fc1 = Symbol.FullyConnected(name = "fc1")()(Map("data" -> data, "num_hidden" -> 64))
val act1 = Symbol.Activation(name = "relu1")()(Map("data" -> fc1, "act_type" -> "relu"))
val fc2 = Symbol.FullyConnected(name = "fc2")()(Map("data" -> act1, "num_hidden" -> 10))
val softmax = Symbol.SoftmaxOutput(name = "softmax")()(Map("data" -> fc2))

import ml.dmlc.mxnet._
import ml.dmlc.mxnet.module.{FitParams, Module}
data: ml.dmlc.mxnet.Symbol = ml.dmlc.mxnet.Symbol@153695b4
fc1: ml.dmlc.mxnet.Symbol = ml.dmlc.mxnet.Symbol@782359b1
act1: ml.dmlc.mxnet.Symbol = ml.dmlc.mxnet.Symbol@69728e46
fc2: ml.dmlc.mxnet.Symbol = ml.dmlc.mxnet.Symbol@6d9120e2
softmax: ml.dmlc.mxnet.Symbol = ml.dmlc.mxnet.Symbol@659455de

### Create Module
The most widely used module class is Module, which wraps a Symbol and one or more Executors.

We construct a module by specifying:

- symbol : the network Symbol
- contexts : the device (or a list of devices) for execution
- dataNames : the list of data variable names
- labelNames : the list of label variable names

Refer to data.ipynb for more explanation of the last two arguments. Here we have only one data input, named `data`, and one label, named `softmax_label`, which is automatically derived from the name `softmax` that we specified for the SoftmaxOutput operator.

In [5]:
import ml.dmlc.mxnet.optimizer.SGD

val mod = new Module(softmax, contexts=Context.cpu(), dataNames=Array("data"), labelNames=Array("softmax_label"))

import ml.dmlc.mxnet.optimizer.SGD
mod: Module = ml.dmlc.mxnet.module.Module@3110b89

Create the data iterators using the MNIST data:

In [6]:
val batchSize=2

val trainIter = IO.MNISTIter(Map(
        "image" -> ("data/train-images-idx3-ubyte"),
        "label" -> ("data/train-labels-idx1-ubyte"),
        "label_name" -> "softmax_label",
        "input_shape" -> "(784,)",
        "batch_size" -> batchSize.toString,
        "shuffle" -> "True",
        "flat" -> "True", "silent" -> "False", "seed" -> "10"))
val evalIter = IO.MNISTIter(Map(
        "image" -> ("data/t10k-images-idx3-ubyte"),
        "label" -> ("data/t10k-labels-idx1-ubyte"),
        "label_name" -> "softmax_label",
        "input_shape" -> "(784,)",
        "batch_size" -> batchSize.toString,
        "flat" -> "True", "silent" -> "False"))

batchSize: Int = 2
trainIter: DataIter = non-empty iterator
evalIter: DataIter = non-empty iterator

### Train, Predict, and Evaluate
Modules provide high-level APIs for training, predicting and evaluating. To fit a module, simply call the fit function with some DataIters.

In [7]:
mod.fit(trainIter, 
        evalData=scala.Option(evalIter),
        fitParams = new FitParams().setOptimizer(new SGD(learningRate = 0.1f, momentum = 0.9f, wd = 0.0001f)),
        numEpoch=5)




To predict with a module, simply call predict() with a DataIter. It will collect and return all the prediction results.


In [8]:
val y = mod.predict(evalIter)
y.size

y: IndexedSeq[NDArray] = ArrayBuffer(
  ml.dmlc.mxnet.NDArray@e4216872,
  ml.dmlc.mxnet.NDArray@b6c053f0,
  ml.dmlc.mxnet.NDArray@f3e68a11,
  ml.dmlc.mxnet.NDArray@b04c5734,
  ml.dmlc.mxnet.NDArray@e4635691,
  ml.dmlc.mxnet.NDArray@38f232b,
  ml.dmlc.mxnet.NDArray@f77fe955,
  ml.dmlc.mxnet.NDArray@eef4e2b3,
  ml.dmlc.mxnet.NDArray@e42c116c,
  ml.dmlc.mxnet.NDArray@acf7250f,
  ml.dmlc.mxnet.NDArray@e538781,
  ml.dmlc.mxnet.NDArray@199fdaef,
  ml.dmlc.mxnet.NDArray@cb3d9293,
  ml.dmlc.mxnet.NDArray@eaf6f77c,
  ml.dmlc.mxnet.NDArray@ca9ff00,
  ml.dmlc.mxnet.NDArray@c406c5ef,
  ml.dmlc.mxnet.NDArray@b670b4d8,
  ml.dmlc.mxnet.NDArray@d9194c50,
  ml.dmlc.mxnet.NDArray@2dd299,
...
res7_1: Int = 5000

When the prediction results might be too large to fit in memory, another convenient API is `predictEveryBatch`:

In [9]:
import org.slf4j.LoggerFactory

private val logger = LoggerFactory.getLogger("mnist")   
val preds = mod.predictEveryBatch(evalIter)

import org.slf4j.LoggerFactory
preds: IndexedSeq[IndexedSeq[NDArray]] = ArrayBuffer(
  Vector(ml.dmlc.mxnet.NDArray@bd4d9298),
  Vector(ml.dmlc.mxnet.NDArray@ca6e5090),
  Vector(ml.dmlc.mxnet.NDArray@3168f3d),
  Vector(ml.dmlc.mxnet.NDArray@ca102191),
  Vector(ml.dmlc.mxnet.NDArray@cdc7d946),
  Vector(ml.dmlc.mxnet.NDArray@c1c349f4),
  Vector(ml.dmlc.mxnet.NDArray@c8a05432),
  Vector(ml.dmlc.mxnet.NDArray@de269f8b),
  Vector(ml.dmlc.mxnet.NDArray@374bbd4),
  Vector(ml.dmlc.mxnet.NDArray@ae8dfe32),
  Vector(ml.dmlc.mxnet.NDArray@c0b6a5dd),
  Vector(ml.dmlc.mxnet.NDArray@e5b6dacb),
  Vector(ml.dmlc.mxnet.NDArray@1f207ca7),
  Vector(ml.dmlc.mxnet.NDArray@f46b7260),
  Vector(ml.dmlc.mxnet.NDArray@cb95d452),
  Vector(ml.dmlc.mxnet.NDArray@a95649d4),
  Vector(ml.dmlc.mxnet.NDArray@ed

In [10]:
// calculate accuracy manually from the per-batch predictions in preds
evalIter.reset()
var accSum = 0.0f
var accCnt = 0
var i = 0
while (evalIter.hasNext) {
  val batch = evalIter.next()
  // predicted class = index of the largest score in each row
  val predLabel: Array[Int] = NDArray.argmax_channel(preds(i)(0)).toArray.map(_.toInt)
  val label = batch.label(0).toArray.map(_.toInt)
  accSum += (predLabel zip label).count { case (py, y) => py == y }
  accCnt += predLabel.length
  // do not call mod.score(evalIter, ...) inside this loop: it iterates over
  // evalIter itself, so the loop would stop after the first batch
  i += 1
}
println("manual accuracy " + accSum / accCnt)

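The manual accuracy computation above reduces to an argmax over each row of class scores followed by a comparison with the labels. The following is a minimal sketch of that idea in plain Scala, using ordinary arrays in place of NDArray; the object name `ManualAccuracy`, its helper functions, and the sample scores are made up for illustration and are not part of the MXNet API.

```scala
object ManualAccuracy {
  // index of the largest score in one row of class scores
  def argmax(scores: Array[Float]): Int = scores.indexOf(scores.max)

  // fraction of rows whose argmax equals the corresponding label
  def accuracy(scores: Array[Array[Float]], labels: Array[Int]): Float = {
    require(scores.length == labels.length, "one label per row of scores")
    val correct = scores.map(argmax).zip(labels).count { case (py, y) => py == y }
    correct.toFloat / labels.length
  }

  def main(args: Array[String]): Unit = {
    // illustrative scores for 4 samples and 3 classes (not real model output)
    val scores = Array(
      Array(0.1f, 0.7f, 0.2f), // predicted class 1
      Array(0.8f, 0.1f, 0.1f), // predicted class 0
      Array(0.3f, 0.3f, 0.4f), // predicted class 2
      Array(0.2f, 0.5f, 0.3f)  // predicted class 1
    )
    val labels = Array(1, 0, 1, 1)
    println("accuracy " + accuracy(scores, labels))
  }
}
```

Three of the four illustrative predictions match their labels, so this sketch reports an accuracy of 0.75.
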
In [14]:
val acc = mod.score(evalIter, new Accuracy)
val (n,v) = acc.get
(n,v)

acc: EvalMetric = ml.dmlc.mxnet.Accuracy@7055caf9
n: String = "accuracy"
v: Float = 0.1135F
res13_2: (String, Float) = ("accuracy", 0.1135F)

### Save and Load
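We can save the module parameters at the end of each training epoch. When training with `fit`, this can be done by calling the `setEpochEndCallback` method of the `FitParams` object; the cell below instead calls `saveCheckpoint` directly inside a manual training loop.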

In [15]:
// save a checkpoint at the end of every epoch of a manual training loop
val modelPrefix: String = "mx mlp"
val metric = new Accuracy()

for (epoch <- 0 until 5) {
  while (trainIter.hasNext) {
    val batch = trainIter.next()
    mod.forward(batch)
    mod.updateMetric(metric, batch.label)
    mod.backward()
    mod.update()
  }
  // saveOptStates = true also saves the optimizer states
  val checkpoint = mod.saveCheckpoint(modelPrefix, epoch, saveOptStates = true)

  val (name, value) = metric.get
  metric.reset()
  trainIter.reset()
}

modelPrefix: String = "mx mlp"
metric: Accuracy = ml.dmlc.mxnet.Accuracy@1f90e6f6

To load the saved module parameters, call the `loadCheckpoint` function. You can specify the CPU/GPU devices to use as a Context, and a workLoadList that controls how the workload is distributed across those devices.

`loadCheckpoint` creates a module from a previously saved checkpoint.

In [16]:
// Epoch to load
val loadModelEpoch = 2
// set loadOptimizerStates = true only if the checkpoint was saved with saveOptStates = true
val mod = Module.loadCheckpoint(modelPrefix, loadModelEpoch, loadOptimizerStates = true)

loadModelEpoch: Int = 2
mod: Module = ml.dmlc.mxnet.module.Module@7aaade12

To initialize parameters, first bind the symbols to construct executors with the `bind` method. Then initialize the parameters and auxiliary states by calling the `initParams()` method.

In [23]:
mod.bind(dataShapes = trainIter.provideData, labelShapes = Some(trainIter.provideLabel))
mod.initParams()



Get the current parameters using the `getParams` method.

In [24]:
val (argParams, auxParams) = mod.getParams

argParams: Map[String, NDArray] = Map(
  "fc1_weight" -> ml.dmlc.mxnet.NDArray@96840c2c,
  "fc2_bias" -> ml.dmlc.mxnet.NDArray@367f5b1c,
  "fc2_weight" -> ml.dmlc.mxnet.NDArray@34041eee,
  "fc1_bias" -> ml.dmlc.mxnet.NDArray@ea1ecde3
)
auxParams: Map[String, NDArray] = Map()

Now, assign the parameter and auxiliary state values using the `setParams` method.

In [25]:
mod.setParams(argParams, auxParams)




If we just want to resume training from a saved checkpoint, we can call fit() directly instead of setParams(), passing the loaded parameters so that fit() starts from them instead of initializing the parameters randomly. We also set beginEpoch so that fit() knows we are resuming from a previously saved epoch.
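
As a rough illustration of what beginEpoch changes, the sketch below lists the epochs a fit-style loop would run. It assumes the loop iterates from beginEpoch (inclusive) up to numEpoch (exclusive); the object `ResumeEpochs` and the helper `epochsToRun` are hypothetical names written for this example, not MXNet functions.

```scala
object ResumeEpochs {
  // Epochs a fit-style loop would run, assuming it iterates over
  // beginEpoch (inclusive) until numEpoch (exclusive).
  def epochsToRun(beginEpoch: Int, numEpoch: Int): Seq[Int] =
    beginEpoch until numEpoch

  def main(args: Array[String]): Unit = {
    // training from scratch: every epoch is run
    println("fresh run:   " + epochsToRun(0, 5).mkString(", "))
    // resuming at epoch 4 with numEpoch = 5: only the last epoch remains
    println("resumed run: " + epochsToRun(4, 5).mkString(", "))
  }
}
```

Under this assumption, numEpoch is the total number of epochs rather than the number of additional epochs, so resuming with beginEpoch equal to numEpoch would train nothing.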

In [26]:
val beginEpoch = 4
mod.fit(trainIter,
        evalData = scala.Option(evalIter),
        fitParams = new FitParams()
          .setArgParams(argParams)
          .setAuxParams(auxParams)
          .setBeginEpoch(beginEpoch)
          .setOptimizer(new SGD(learningRate = 0.1f, momentum = 0.9f, wd = 0.0001f)),
        numEpoch = 5)

beginEpoch: Int = 4

## Module as a computation "machine"
We have already seen how to use a module for basic training and inference. Now we will show a more flexible usage of modules.

A module represents a computation component. It is designed to abstract a computation “machine” that accepts Symbol programs and data, on which we can then run forward, backward, update parameters, etc.

We aim to make the APIs easy and flexible to use, especially when we need to use the imperative API to work with multiple modules (e.g. a stochastic depth network).

A module has several states:

- **Initial state**. Memory is not allocated yet, not ready for computation yet.
- **Binded**. Shapes for inputs, outputs, and parameters are all known, memory allocated, ready for computation.
- **Parameter initialized**. For modules with parameters, doing computation before initializing the parameters might result in undefined outputs.
- **Optimizer installed**. An optimizer can be installed to a module. After this, the parameters of the module can be updated according to the optimizer after gradients are computed (forward-backward).
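
The ordering of these states can be sketched with a toy class that performs no computation and only records which setup steps have run. This is an illustration under the assumption that each step requires the previous ones; `ModuleStateSketch` and `ToyModule` are hypothetical names and do not reflect how the real Module class is implemented.

```scala
object ModuleStateSketch {
  // Hypothetical stand-in for a module: it performs no computation and
  // only records which setup steps have been completed.
  final class ToyModule {
    private var bindedFlag = false
    private var paramsFlag = false
    private var optimizerFlag = false

    def binded: Boolean = bindedFlag
    def paramsInitialized: Boolean = paramsFlag
    def optimizerInitialized: Boolean = optimizerFlag

    // initial state -> binded
    def bind(): Unit = { bindedFlag = true }

    // binded -> parameter initialized
    def initParams(): Unit = {
      require(bindedFlag, "call bind() before initParams()")
      paramsFlag = true
    }

    // parameter initialized -> optimizer installed
    def initOptimizer(): Unit = {
      require(bindedFlag && paramsFlag, "call bind() and initParams() before initOptimizer()")
      optimizerFlag = true
    }

    // a parameter update only makes sense once an optimizer is installed
    def update(): Unit =
      require(optimizerFlag, "call initOptimizer() before update()")
  }

  def main(args: Array[String]): Unit = {
    val m = new ToyModule
    m.bind()
    m.initParams()
    m.initOptimizer()
    m.update()
    println(s"binded=${m.binded} paramsInitialized=${m.paramsInitialized} " +
      s"optimizerInitialized=${m.optimizerInitialized}")
  }
}
```

Calling the steps out of order, for example `initParams()` on a fresh `ToyModule`, fails the `require` check with an `IllegalArgumentException`.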

The following code implements a simplified fit(). Here we use other components, including an initializer, an optimizer, and a metric, which are explained in other notebooks.

In [27]:
// initial state
val mod = new Module(softmax)

// bind, tell the module the data and label shapes, so
// that memory could be allocated on the devices for computation
mod.bind(dataShapes=trainIter.provideData, labelShapes=Some(trainIter.provideLabel))

// init parameters
mod.initParams(initializer=new Xavier(magnitude = 2f))

// init optimizer
mod.initOptimizer("local", new SGD(learningRate = 0.1f, momentum = 0.9f, wd = 0.0001f))

// use accuracy as the metric
val metric = new Accuracy

// train one epoch, i.e. going over the data iter one pass
while (trainIter.hasNext) {
    val batch = trainIter.next()
    mod.forward(batch)                     // compute predictions
    mod.updateMetric(metric, batch.label)  // accumulate prediction accuracy
    mod.backward()                         // compute gradients
    mod.update()                           // update parameters using SGD
}

// training accuracy
val (name, value) = metric.get
(name, value)

mod: Module = ml.dmlc.mxnet.module.Module@27cbfc5
metric: Accuracy = ml.dmlc.mxnet.Accuracy@273f6196
name: String = "accuracy"
value: Float = 0.10195F
res26_7: (String, Float) = ("accuracy", 0.10195F)

Besides these operations, a module provides a lot of useful information.

basic names:
- **dataNames**: list of strings indicating the names of the required data.
- **outputNames**: list of strings indicating the names of the outputs.

state information:
- **binded**: Boolean, indicating whether the memory buffers needed for computation have been allocated.
- **forTraining**: whether the module is binded for training (if binded).
- **paramsInitialized**: Boolean, indicating whether the parameters of this module have been initialized.
- **optimizerInitialized**: Boolean, indicating whether an optimizer is defined and initialized.
- **inputsNeedGrad**: Boolean, indicating whether gradients with respect to the input data are needed. Might be useful when implementing composition of modules.

input/output information:
- **dataShapes**: a list of (name, shape). In theory, since the memory is allocated, we could directly provide the data arrays. But in the case of data parallelization, the data arrays might not have the same shape as viewed from the external world.
- **labelShapes**: a list of (name, shape). This might be empty if the module does not need labels (e.g. it does not contain a loss function at the top), or if the module is not binded for training.
- **outputShapes**: a list of (name, shape) for the outputs of the module.

parameters (for modules with parameters):
- **getParams()**: returns a tuple (argParams, auxParams). Each of these is a map from name to NDArray. These NDArrays always live on the CPU. The actual parameters used for computation might live on other devices (GPUs); this function retrieves (a copy of) the latest parameters.

setup:
- **bind()**: prepare the environment for computation.
- **initOptimizer()**: install an optimizer for parameter updating.

computation:
- **forward(dataBatch)**: forward operation.
- **backward(outGrads=None)**: backward operation.
- **update()**: update parameters according to the installed optimizer.
- **getOutputs()**: get the outputs of the previous forward operation.
- **getInputGrads()**: get the gradients with respect to the inputs computed in the previous backward operation.
- **updateMetric(metric, labels)**: update the performance metric with the results of the previous forward computation.


In [28]:
(mod.dataShapes, mod.labelShapes, mod.outputShapes)
mod.getParams

res27_0: (IndexedSeq[DataDesc], IndexedSeq[DataDesc], IndexedSeq[(String, Shape)]) = (
  Vector(DataDesc[data,(2,784),Float32,NCHW]),
  Vector(DataDesc[softmax_label,(2),Float32,NCHW]),
  ArrayBuffer(("softmax_output", (2,10)))
)
res27_1: (Map[String, NDArray], Map[String, NDArray]) = (
  Map(
    "fc1_weight" -> ml.dmlc.mxnet.NDArray@a073e593,
    "fc1_bias" -> ml.dmlc.mxnet.NDArray@15f069eb,
    "fc2_weight" -> ml.dmlc.mxnet.NDArray@fb482814,
    "fc2_bias" -> ml.dmlc.mxnet.NDArray@56ac2896
  ),
  Map()
)
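
The shapes printed above (a 784-wide input and a 10-wide output) also determine the shapes of the four parameters. The sketch below derives them in plain Scala, assuming each fully connected layer stores its weight as (numHidden, inputDim) and its bias as (numHidden); `MlpShapes`, `paramShapes`, and `paramCount` are hypothetical helpers written for this example and are not part of the module API.

```scala
object MlpShapes {
  // (name, shape) pairs for a stack of fully connected layers, assuming
  // each weight has shape (numHidden, inputDim) and each bias (numHidden).
  def paramShapes(inputDim: Int, layers: Seq[(String, Int)]): Seq[(String, Seq[Int])] = {
    // the input width of a layer is the output width of the previous one
    val inputDims = inputDim +: layers.map(_._2).dropRight(1)
    layers.zip(inputDims).flatMap { case ((name, numHidden), inDim) =>
      Seq(
        s"${name}_weight" -> Seq(numHidden, inDim),
        s"${name}_bias" -> Seq(numHidden)
      )
    }
  }

  // total number of scalar parameters across all shapes
  def paramCount(shapes: Seq[(String, Seq[Int])]): Int =
    shapes.map(_._2.product).sum

  def main(args: Array[String]): Unit = {
    // the network from this tutorial: 784 inputs -> fc1 (64) -> fc2 (10)
    val shapes = paramShapes(784, Seq("fc1" -> 64, "fc2" -> 10))
    shapes.foreach { case (name, shape) =>
      println(name + ": " + shape.mkString("(", ",", ")"))
    }
    println("total parameters: " + paramCount(shapes))
  }
}
```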

## More on Modules
Module simplifies the implementation of new modules. For example:
- [SequentialModule](https://github.com/dmlc/mxnet/blob/master/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/SequentialModule.scala) can chain multiple modules together

See also [examples](https://github.com/dmlc/mxnet/tree/master/scala-package/examples/src/main/scala/ml/dmlc/mxnet/examples/module) for a list of code examples using the module API.

## Implementation
The module package is implemented in Scala and is located at [scala/mxnet/module](https://github.com/dmlc/mxnet/tree/master/scala-package/core/src/main/scala/ml/dmlc/mxnet/module).

## Further Reading
[module API](http://mxnet.io/api/scala/docs/index.html#ml.dmlc.mxnet.module.Module)