Showing 3 changed files with 75 additions and 1 deletion.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
In this tutorial, we will work through examples of training a simple multi-layer perceptron and then a convolutional neural network (the LeNet architecture) on the [MNIST handwritten digit dataset](http://yann.lecun.com/exdb/mnist/). The code for this tutorial can be found in [`Pkg.dir("MXNet")`/examples/mnist/](https://github.com/dmlc/MXNet.jl/tree/master/examples/mnist). | ||
|
||
# Simple 3-layer MLP | ||
|
||
This is a tiny 3-layer MLP that can easily be trained on a CPU. The script starts with | ||
```julia | ||
using MXNet | ||
``` | ||
to load the `MXNet` module. Then we are ready to define the network architecture via the [symbolic API](../user-guide/overview.md#symbols-and-composition). We start with a placeholder `data` symbol, | ||
```julia | ||
data = mx.Variable(:data) | ||
``` | ||
and then cascade fully-connected layers and activation functions: | ||
```julia | ||
fc1 = mx.FullyConnected(data = data, name=:fc1, num_hidden=128) | ||
act1 = mx.Activation(data = fc1, name=:relu1, act_type=:relu) | ||
fc2 = mx.FullyConnected(data = act1, name=:fc2, num_hidden=64) | ||
act2 = mx.Activation(data = fc2, name=:relu2, act_type=:relu) | ||
fc3 = mx.FullyConnected(data = act2, name=:fc3, num_hidden=10) | ||
``` | ||
Note that in each composition we pass the previous symbol as the `data` argument, forming a feedforward chain. The architecture looks like | ||
``` | ||
Input --> 128 units (ReLU) --> 64 units (ReLU) --> 10 units | ||
``` | ||
where the last 10 units correspond to the 10 output classes (digits 0,...,9). We then add a final `Softmax` operation to turn the 10-dimensional prediction into proper probability values for the 10 classes: | ||
```julia | ||
mlp = mx.Softmax(data = fc3, name=:softmax) | ||
``` | ||
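To make the role of the final layer concrete, here is a minimal plain-Julia sketch of the softmax function itself. It assumes Julia 1.x and is not MXNet.jl's implementation; the function name `softmax` and the `scores` vector are made up for illustration.

```julia
# Plain-Julia sketch of the softmax function (not MXNet.jl's implementation).
function softmax(scores::AbstractVector{<:Real})
    # Subtract the maximum before exponentiating for numerical stability;
    # this shift does not change the result.
    shifted = exp.(scores .- maximum(scores))
    return shifted ./ sum(shifted)
end

# Hypothetical raw outputs of the last layer for one image (10 classes).
scores = [1.0, 2.0, 0.5, 0.0, 3.0, -1.0, 0.2, 0.1, 1.5, 0.3]
probs = softmax(scores)

println(sum(probs))          # the probabilities sum to 1 (up to rounding)
println(argmax(probs) - 1)   # predicted digit (Julia indices start at 1)
```

The largest raw score always receives the largest probability, so the predicted class is unchanged; softmax only rescales the outputs so they are positive and sum to one.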
|
||
After defining the architecture, we are ready to load the MNIST data. MXNet.jl provides built-in data providers for the MNIST dataset, which can automatically download the dataset into `Pkg.dir("MXNet")/data/mnist` if necessary. We wrap the code that constructs the data providers in `mnist-data.jl` so that it can be shared by both the MLP example and the LeNet ConvNet example. | ||
```julia | ||
batch_size = 100 | ||
include("mnist-data.jl") | ||
train_provider, eval_provider = get_mnist_providers(batch_size) | ||
``` | ||
If you need to write your own data providers for a customized data format, please refer to **TODO**: pointer to data provider API. | ||
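As a rough illustration of what a batch size of 100 means for the 60,000 MNIST training images, the following plain-Julia sketch splits sample indices into consecutive mini-batches. It assumes Julia 1.x; the helper `batch_ranges` is hypothetical and is not part of the MXNet.jl data provider API.

```julia
# Split sample indices 1:num_samples into consecutive mini-batches.
# Hypothetical helper for illustration; not an MXNet.jl function.
function batch_ranges(num_samples::Int, batch_size::Int)
    # The last range is clipped in case batch_size does not divide num_samples.
    return [i:min(i + batch_size - 1, num_samples) for i in 1:batch_size:num_samples]
end

batch_size = 100
num_train = 60_000          # size of the MNIST training set
batches = batch_ranges(num_train, batch_size)

println(length(batches))    # number of mini-batches per epoch
println(first(batches))     # index range of the first mini-batch
```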
|
||
Given the architecture and data, we can instantiate an *estimator* to do the actual training. `mx.FeedForward` is the built-in estimator that is suitable for most feed-forward architectures. When constructing the estimator, we also specify the *context* on which the computation should be carried out. Because this is a really tiny MLP, we will just run on a single CPU device. | ||
```julia | ||
estimator = mx.FeedForward(mlp, context=mx.cpu()) | ||
``` | ||
You can use `mx.gpu()` instead, or provide a list of devices (e.g. `[mx.gpu(0), mx.gpu(1)]`), in which case data parallelization is used automatically. But for this tiny example, using a GPU device might not help. | ||
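The idea behind data parallelization is that each device processes its own slice of every mini-batch. The plain-Julia sketch below shows one way such a split could be computed. It assumes Julia 1.x; the helper `split_batch` is hypothetical, and how MXNet.jl actually divides a batch internally is not shown here.

```julia
# Divide the sample indices of one mini-batch among several devices.
# Hypothetical helper; MXNet.jl's internal splitting may differ.
function split_batch(batch_size::Int, num_devices::Int)
    base, extra = divrem(batch_size, num_devices)
    slices = UnitRange{Int}[]
    start = 1
    for d in 1:num_devices
        # The first `extra` devices take one additional sample each.
        len = base + (d <= extra ? 1 : 0)
        push!(slices, start:start + len - 1)
        start += len
    end
    return slices
end

println(split_batch(100, 2))   # slices for two devices and a batch of 100
```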
|
||
The last thing we need to specify is the optimization algorithm (a.k.a. *optimizer*) to use. We use basic SGD with a fixed learning rate of 0.1, a fixed momentum of 0.9, and a small weight decay of 0.00001: | ||
```julia | ||
optimizer = mx.SGD(lr_scheduler=mx.FixedLearningRateScheduler(0.1), | ||
                   mom_scheduler=mx.FixedMomentumScheduler(0.9), | ||
                   weight_decay=0.00001) | ||
``` | ||
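To show what these three hyperparameters do, here is a plain-Julia sketch of one common formulation of the SGD-with-momentum update, applied to a toy quadratic. It assumes Julia 1.x; the function `sgd_momentum_step!` is hypothetical, and the exact update used inside `mx.SGD` may differ in detail.

```julia
# One SGD-with-momentum step on a single parameter vector.
# Hypothetical helper; not MXNet.jl's implementation.
function sgd_momentum_step!(w, v, grad; lr=0.1, momentum=0.9, weight_decay=0.00001)
    # Weight decay adds a small pull toward zero to the gradient.
    # The velocity v is a decaying running sum of past steps.
    @. v = momentum * v - lr * (grad + weight_decay * w)
    @. w = w + v
    return w
end

# Toy problem: minimize f(w) = sum(w .^ 2) / 2, whose gradient is w itself.
w = [1.0, -2.0]
v = zeros(2)
for _ in 1:200
    sgd_momentum_step!(w, v, copy(w))
end
println(w)   # both entries end up close to 0
```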
Now we can do the training. Here the `epoch_stop` parameter specifies that we want to train for 20 epochs. We also supply `eval_data` to monitor accuracy on the validation set. | ||
```julia | ||
mx.fit(estimator, optimizer, train_provider, epoch_stop=20, eval_data=eval_provider) | ||
``` | ||
Here is a sample output: | ||
``` | ||
INFO: Start training on [CPU0] | ||
INFO: Initializing parameters... | ||
INFO: Creating KVStore... | ||
INFO: == Epoch 001 ========== | ||
INFO: ## Training summary | ||
INFO: :accuracy = 0.7554 | ||
INFO: time = 1.3165 seconds | ||
INFO: ## Validation summary | ||
INFO: :accuracy = 0.9502 | ||
... | ||
INFO: == Epoch 020 ========== | ||
INFO: ## Training summary | ||
INFO: :accuracy = 0.9949 | ||
INFO: time = 0.9287 seconds | ||
INFO: ## Validation summary | ||
INFO: :accuracy = 0.9775 | ||
``` |
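The `:accuracy` values in the log are classification accuracies: the fraction of samples whose most probable class matches the true label. A minimal plain-Julia sketch of that computation follows. It assumes Julia 1.x and a layout with one sample per column; the function `accuracy` and the data are made up for illustration and are not MXNet.jl's metric code.

```julia
# Fraction of samples whose most probable class equals the true label.
# Hypothetical helper; not MXNet.jl's metric implementation.
function accuracy(probs::AbstractMatrix{<:Real}, labels::AbstractVector{<:Integer})
    # Each column holds the class probabilities for one sample.
    # Julia indices start at 1, so subtract 1 to get zero-based class labels.
    predicted = [argmax(view(probs, :, j)) - 1 for j in 1:size(probs, 2)]
    return sum(predicted .== labels) / length(labels)
end

# Three made-up samples over 3 classes (each column sums to 1).
probs = [0.7 0.1 0.2;
         0.2 0.8 0.3;
         0.1 0.1 0.5]
labels = [0, 1, 1]   # the third sample is predicted as class 2, so it counts as an error
println(accuracy(probs, labels))
```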