Skip to content

Commit

Permalink
test: minor style changes for mlp-test (#340)
Browse files Browse the repository at this point in the history
  • Loading branch information
iblislin authored and pluskid committed Nov 22, 2017
1 parent b0556e6 commit 91a410e
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions examples/mnist/mlp-test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,30 @@
# features of MXNet.jl in this example in order to detect regression errors.

module MNISTTest

using MXNet
using Base.Test

include("mnist-data.jl")

function get_mnist_mlp()
mlp = @mx.chain mx.Variable(:data) =>
@mx.chain mx.Variable(:data) =>
mx.FullyConnected(name=:fc1, num_hidden=128) =>
mx.Activation(name=:relu1, act_type=:relu) =>
mx.FullyConnected(name=:fc2, num_hidden=64) =>
mx.Activation(name=:relu2, act_type=:relu) =>
mx.FullyConnected(name=:fc3, num_hidden=10) =>
mx.SoftmaxOutput(name=:softmax)
return mlp
end

function get_mnist_data(batch_size=100)
return get_mnist_providers(batch_size)
end
get_mnist_data(batch_size = 100) = get_mnist_providers(batch_size)

function mnist_fit_and_predict(optimizer, initializer, n_epoch)
mlp = get_mnist_mlp()
train_provider, eval_provider = get_mnist_data()

# setup model
model = mx.FeedForward(mlp, context=mx.cpu())
model = mx.FeedForward(mlp, context = mx.cpu())

# fit parameters
cp_prefix = "mnist-test-cp"
Expand Down Expand Up @@ -73,12 +71,25 @@ function mnist_fit_and_predict(optimizer, initializer, n_epoch)
end

function test_mnist_mlp()
info("MNIST::SGD")
@test mnist_fit_and_predict(mx.SGD(lr=0.1, momentum=0.9), mx.UniformInitializer(0.01), 2) > 90

info("MNIST::ADAM")
@test mnist_fit_and_predict(mx.ADAM(), mx.NormalInitializer(), 2) > 90

info("MNIST::AdaGrad")
@test mnist_fit_and_predict(mx.AdaGrad(), mx.NormalInitializer(), 2) > 90

info("MNIST::AdaDelta")
@test mnist_fit_and_predict(mx.AdaDelta(), mx.NormalInitializer(), 2) > 90

info("MNIST::AdaMax")
@test mnist_fit_and_predict(mx.AdaMax(), mx.NormalInitializer(), 2) > 90

info("MNIST::RMSProp")
@test mnist_fit_and_predict(mx.RMSProp(), mx.NormalInitializer(), 2) > 90

info("MNIST::Nadam")
@test mnist_fit_and_predict(mx.Nadam(), mx.NormalInitializer(), 2) > 90
end

Expand Down

0 comments on commit 91a410e

Please sign in to comment.