
Commit

doc for mnist
pluskid committed Oct 22, 2015
1 parent 210aba8 commit 478568c
Showing 3 changed files with 75 additions and 1 deletion.
72 changes: 72 additions & 0 deletions docs/tutorials/mnist.md
@@ -0,0 +1,72 @@
In this tutorial, we will work through examples of training a simple multi-layer perceptron and then a convolutional neural network (the LeNet architecture) on the [MNIST handwritten digit dataset](http://yann.lecun.com/exdb/mnist/). The code for this tutorial can be found in [`Pkg.dir("MXNet")`/examples/mnist/](https://github.com/dmlc/MXNet.jl/tree/master/examples/mnist).

# Simple 3-layer MLP

This is a tiny 3-layer MLP that can easily be trained on a CPU. The script starts with
```julia
using MXNet
```
to load the `MXNet` module. Then we are ready to define the network architecture via the [symbolic API](../user-guide/overview.md#symbols-and-composition). We start with a placeholder `data` symbol,
```julia
data = mx.Variable(:data)
```
and then cascade fully-connected layers and activation functions:
```julia
fc1 = mx.FullyConnected(data = data, name=:fc1, num_hidden=128)
act1 = mx.Activation(data = fc1, name=:relu1, act_type=:relu)
fc2 = mx.FullyConnected(data = act1, name=:fc2, num_hidden=64)
act2 = mx.Activation(data = fc2, name=:relu2, act_type=:relu)
fc3 = mx.FullyConnected(data = act2, name=:fc3, num_hidden=10)
```
Note that in each composition we take the previous symbol as the `data` argument, forming a feed-forward chain. The architecture looks like
```
Input --> 128 units (ReLU) --> 64 units (ReLU) --> 10 units
```
where the last 10 units correspond to the 10 output classes (digits 0,...,9). We then add a final `Softmax` operation to turn the 10-dimensional prediction into proper probability values for the 10 classes:
```julia
mlp = mx.Softmax(data = fc3, name=:softmax)
```
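As a quick sanity check, you can inspect the inputs and learnable parameters of the composed network. The snippet below is only a sketch: it assumes the symbolic API exposes `mx.list_arguments`, and the exact names in the output may differ across versions.
```julia
# List the arguments of the composed symbol: the :data input, the
# weights/biases of the three FullyConnected layers, and the label
# required by the Softmax output.
println(mx.list_arguments(mlp))
```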

After defining the architecture, we are ready to load the MNIST data. MXNet.jl provides built-in data providers for the MNIST dataset, which can automatically download the dataset into `Pkg.dir("MXNet")/data/mnist` if necessary. We wrap the code that constructs the data providers in `mnist-data.jl` so that it can be shared by both the MLP example and the LeNet ConvNet example.
```julia
batch_size = 100
include("mnist-data.jl")
train_provider, eval_provider = get_mnist_providers(batch_size)
```
If you need to write your own data provider for a customized data format, please refer to **TODO**: pointer to data provider API.
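To give a rough idea of what that might look like, here is a hypothetical sketch that serves in-memory Julia arrays as a provider; it assumes a helper such as `mx.ArrayDataProvider` is available, and the names `fake_data` and `fake_label` are made up for illustration. Consult the data provider API for the real interface before relying on this.
```julia
# Hypothetical: serve randomly generated in-memory arrays instead of the
# built-in MNIST providers.
n_samples  = 1000
fake_data  = rand(Float32, 28*28, n_samples)   # one flattened 28x28 image per column
fake_label = float(rand(0:9, n_samples))       # digit labels
custom_provider = mx.ArrayDataProvider(:data => fake_data,
                                        :softmax_label => fake_label,
                                        batch_size=batch_size)
```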

Given the architecture and data, we can instantiate an *estimator* to do the actual training. `mx.FeedForward` is the built-in estimator that is suitable for most feed-forward architectures. When constructing the estimator, we also specify the *context* on which the computation should be carried out. Because this is a really tiny MLP, we will just run on a single CPU device.
```julia
estimator = mx.FeedForward(mlp, context=mx.cpu())
```
You can also use `mx.gpu()`; if a list of devices (e.g. `[mx.gpu(0), mx.gpu(1)]`) is provided, data parallelization will be used automatically. For this tiny example, though, using a GPU device might not help.
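For example, switching between a single GPU and multiple GPUs is only a matter of changing the `context` argument; the sketch below assumes two GPUs are visible to MXNet.
```julia
# single GPU
estimator = mx.FeedForward(mlp, context=mx.gpu())
# two GPUs: each mini-batch is split across the devices (data parallelism)
estimator = mx.FeedForward(mlp, context=[mx.gpu(0), mx.gpu(1)])
```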

The last thing we need to specify is the optimization algorithm (a.k.a. *optimizer*) to use. We use basic SGD with a fixed learning rate of 0.1 and momentum of 0.9:
```julia
optimizer = mx.SGD(lr_scheduler=mx.FixedLearningRateScheduler(0.1),
                   mom_scheduler=mx.FixedMomentumScheduler(0.9),
                   weight_decay=0.00001)
```
Now we can do the training. Here the `epoch_stop` parameter specifies that we want to train for 20 epochs. We also supply an `eval_data` provider to monitor accuracy on the validation set.
```julia
mx.fit(estimator, optimizer, train_provider, epoch_stop=20, eval_data=eval_provider)
```
Here is a sample output:
```
INFO: Start training on [CPU0]
INFO: Initializing parameters...
INFO: Creating KVStore...
INFO: == Epoch 001 ==========
INFO: ## Training summary
INFO: :accuracy = 0.7554
INFO: time = 1.3165 seconds
INFO: ## Validation summary
INFO: :accuracy = 0.9502
...
INFO: == Epoch 020 ==========
INFO: ## Training summary
INFO: :accuracy = 0.9949
INFO: time = 0.9287 seconds
INFO: ## Validation summary
INFO: :accuracy = 0.9775
```
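After training, the fitted estimator can be used to make predictions on new data. The following is a minimal sketch, assuming an `mx.predict` helper that runs the network over a data provider and returns one column of class probabilities per example.
```julia
# Run the trained network over the validation set and pick the most
# probable digit for each example (rows are the 10 class probabilities,
# columns are examples; subtract 1 to map 1-based indices to digits).
probs = mx.predict(estimator, eval_provider)
pred  = [indmax(probs[:, j]) - 1 for j in 1:size(probs, 2)]
```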
2 changes: 1 addition & 1 deletion examples/mnist/mlp.jl
@@ -15,7 +15,7 @@ include("mnist-data.jl")
train_provider, eval_provider = get_mnist_providers(batch_size)

# setup estimator
- estimator = mx.FeedForward(mlp, context=mx.Context(mx.CPU))
+ estimator = mx.FeedForward(mlp, context=mx.cpu())

# optimizer
optimizer = mx.SGD(lr_scheduler=mx.FixedLearningRateScheduler(0.1),
2 changes: 2 additions & 0 deletions mkdocs.yml
@@ -7,6 +7,8 @@ pages:
  - User Guide:
    - 'Installation Guide' : 'user-guide/install.md'
    - 'Overview' : 'user-guide/overview.md'
+ - Tutorials:
+   - 'MNIST': 'tutorials/mnist.md'
  - API Documentation:
    - 'ndarray': 'api/ndarray.md'
    - 'symbol': 'api/symbol.md'
