clean up lenet example
pluskid committed Oct 22, 2015
1 parent a75a8ef commit 43b2486
Showing 2 changed files with 13 additions and 8 deletions.
14 changes: 11 additions & 3 deletions docs/tutorials/mnist.md
@@ -26,6 +26,16 @@ where the last 10 units correspond to the 10 output classes (digits 0,...,9). We
```julia
mlp = mx.Softmax(data = fc3, name=:softmax)
```
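For reference, the tutorial builds this network up layer by layer just above this hunk; a sketch of those preceding definitions, assuming the tutorial's variable names:
```julia
data = mx.Variable(:data)
fc1  = mx.FullyConnected(data = data, name=:fc1, num_hidden=128)
act1 = mx.Activation(data = fc1, name=:relu1, act_type=:relu)
fc2  = mx.FullyConnected(data = act1, name=:fc2, num_hidden=64)
act2 = mx.Activation(data = fc2, name=:relu2, act_type=:relu)
fc3  = mx.FullyConnected(data = act2, name=:fc3, num_hidden=10)
```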
As we can see, the MLP is just a chain of layers. For a case like this, we can also use the `mx.chain` macro. The same architecture as above can then be defined as:
```julia
mlp = @mx.chain mx.Variable(:data) =>
      mx.FullyConnected(name=:fc1, num_hidden=128) =>
      mx.Activation(name=:relu1, act_type=:relu) =>
      mx.FullyConnected(name=:fc2, num_hidden=64) =>
      mx.Activation(name=:relu2, act_type=:relu) =>
      mx.FullyConnected(name=:fc3, num_hidden=10) =>
      mx.Softmax(name=:softmax)
```

After defining the architecture, we are ready to load the MNIST data. MXNet.jl provides built-in data providers for the MNIST dataset, which will automatically download the data into `Pkg.dir("MXNet")/data/mnist` if necessary. We wrap the code that constructs the data provider into `mnist-data.jl` so that it can be shared by both the MLP example and the LeNet ConvNet example.
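The wrapped file itself is not shown in this diff; a minimal sketch of what `mnist-data.jl` could look like, assuming the built-in `mx.MNISTProvider` (the file paths here are illustrative, and the actual file may download them differently):
```julia
# assumes `using MXNet` has been issued, binding the package to the `mx` module
function get_mnist_providers(batch_size::Int; flat=true)
  # flat=true yields flattened 1-D vectors for the MLP;
  # flat=false keeps the 2-D image shape needed by the LeNet ConvNet
  train_provider = mx.MNISTProvider(image="data/mnist/train-images-idx3-ubyte",
                                    label="data/mnist/train-labels-idx1-ubyte",
                                    batch_size=batch_size, shuffle=true, flat=flat)
  eval_provider  = mx.MNISTProvider(image="data/mnist/t10k-images-idx3-ubyte",
                                    label="data/mnist/t10k-labels-idx1-ubyte",
                                    batch_size=batch_size, shuffle=false, flat=flat)
  return (train_provider, eval_provider)
end
```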
@@ -43,9 +53,7 @@ You can use a `mx.gpu()` or if a list of devices (e.g. `[mx.gpu(0), mx.gpu(1)]`)

The last thing we need to specify is the optimization algorithm (a.k.a. *optimizer*) to use. We use basic SGD with a fixed learning rate of 0.1 and momentum of 0.9:
```julia
-optimizer = mx.SGD(lr_scheduler=mx.FixedLearningRateScheduler(0.1),
-                   mom_scheduler=mx.FixedMomentumScheduler(0.9),
-                   weight_decay=0.00001)
+optimizer = mx.SGD(lr=0.1, momentum=0.9, weight_decay=0.00001)
```
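Concretely, with learning rate $\eta = 0.1$, momentum $\mu = 0.9$, and weight decay $\lambda = 10^{-5}$, each update follows the standard momentum-SGD form (a reference formula; MXNet's exact bookkeeping may differ in minor details):

$$v \leftarrow \mu v - \eta \, (\nabla_W \ell + \lambda W), \qquad W \leftarrow W + v$$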
Now we can do the training. Here the `epoch_stop` parameter specifies that we want to train for 20 epochs. We also supply an `eval_data` provider to monitor accuracy on the validation set.
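The training call itself is folded away in this diff; a minimal sketch mirroring the LeNet example below (the variable name `model` and the CPU default are assumptions):
```julia
model = mx.FeedForward(mlp)  # no context given: defaults to the CPU
mx.fit(model, optimizer, train_provider, epoch_stop=20, eval_data=eval_provider)
```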
7 changes: 2 additions & 5 deletions examples/mnist/lenet.jl
@@ -36,13 +36,10 @@ train_provider, eval_provider = get_mnist_providers(batch_size; flat=false)
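# The folded lines above this hunk define the `lenet` symbol. A sketch of a
# LeNet-style definition (kernel sizes, filter counts, and activation choices
# here are illustrative assumptions, not necessarily the file's exact values):
data  = mx.Variable(:data)
conv1 = @mx.chain mx.Convolution(data=data, kernel=(5,5), num_filter=20) =>
        mx.Activation(act_type=:tanh) =>
        mx.Pooling(pool_type=:max, kernel=(2,2), stride=(2,2))
conv2 = @mx.chain mx.Convolution(data=conv1, kernel=(5,5), num_filter=50) =>
        mx.Activation(act_type=:tanh) =>
        mx.Pooling(pool_type=:max, kernel=(2,2), stride=(2,2))
fc1   = @mx.chain mx.Flatten(data=conv2) =>
        mx.FullyConnected(num_hidden=500) =>
        mx.Activation(act_type=:tanh)
fc2   = mx.FullyConnected(data=fc1, num_hidden=10)
lenet = mx.Softmax(data=fc2, name=:softmax)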

#--------------------------------------------------------------------------------
# fit model
-dev = mx.Context(mx.GPU)
-estimator = mx.FeedForward(lenet, context=dev)
+estimator = mx.FeedForward(lenet, context=mx.gpu())
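# (if no GPU is available, context=mx.cpu() also works, just more slowly;
#  a list of devices such as [mx.gpu(0), mx.gpu(1)] enables data parallelism)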

# optimizer
-optimizer = mx.SGD(lr_scheduler=mx.FixedLearningRateScheduler(0.05),
-                   mom_scheduler=mx.FixedMomentumScheduler(0.9),
-                   weight_decay=0.00001)
+optimizer = mx.SGD(lr=0.05, momentum=0.9, weight_decay=0.00001)

# fit parameters
mx.fit(estimator, optimizer, train_provider, epoch_stop=20, eval_data=eval_provider)
