Skip to content

Commit

Permalink
add mx.chain macro
Browse files Browse the repository at this point in the history
  • Loading branch information
pluskid committed Oct 22, 2015
1 parent 478568c commit d15bc77
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 7 deletions.
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,13 @@


Julia wrapper of [MXNet](https://github.com/dmlc/mxnet).

```julia
mlp = @mx.chain mx.Variable(:data) =>
mx.FullyConnected(name=:fc1, num_hidden=128) =>
mx.Activation(name=:relu1, act_type=:relu) =>
mx.FullyConnected(name=:fc2, num_hidden=64) =>
mx.Activation(name=:relu2, act_type=:relu) =>
mx.FullyConnected(name=:fc3, num_hidden=10) =>
mx.Softmax(name=:softmax)
```
27 changes: 20 additions & 7 deletions examples/mnist/mlp.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,26 @@
using MXNet

#--------------------------------------------------------------------------------
# define MLP
data = mx.Variable(:data)
fc1 = mx.FullyConnected(data = data, name=:fc1, num_hidden=128)
act1 = mx.Activation(data = fc1, name=:relu1, act_type=:relu)
fc2 = mx.FullyConnected(data = act1, name=:fc2, num_hidden=64)
act2 = mx.Activation(data = fc2, name=:relu2, act_type=:relu)
fc3 = mx.FullyConnected(data = act2, name=:fc3, num_hidden=10)
mlp = mx.Softmax(data = fc3, name=:softmax)
# the following two ways are equivalent

#-- Option 1: explicit composition
# data = mx.Variable(:data)
# fc1 = mx.FullyConnected(data = data, name=:fc1, num_hidden=128)
# act1 = mx.Activation(data = fc1, name=:relu1, act_type=:relu)
# fc2 = mx.FullyConnected(data = act1, name=:fc2, num_hidden=64)
# act2 = mx.Activation(data = fc2, name=:relu2, act_type=:relu)
# fc3 = mx.FullyConnected(data = act2, name=:fc3, num_hidden=10)
# mlp = mx.Softmax(data = fc3, name=:softmax)

#-- Option 2: using the mx.chain macro
mlp = @mx.chain mx.Variable(:data) =>
mx.FullyConnected(name=:fc1, num_hidden=128) =>
mx.Activation(name=:relu1, act_type=:relu) =>
mx.FullyConnected(name=:fc2, num_hidden=64) =>
mx.Activation(name=:relu2, act_type=:relu) =>
mx.FullyConnected(name=:fc3, num_hidden=10) =>
mx.Softmax(name=:softmax)

# data provider
batch_size = 100
Expand Down
29 changes: 29 additions & 0 deletions src/symbol.jl
Original file line number Diff line number Diff line change
Expand Up @@ -349,3 +349,32 @@ function _import_atomic_symbol_creators()
_define_atomic_symbol_creator(creator_hdr)
end
end

################################################################################
# Utility macros to chain up symbols
################################################################################
macro chain(layers)
exprs = []
last_layer = nothing
function _chain_layer(layer, last_layer)
if isa(last_layer, Void)
layer
else
@assert(isa(layer, Expr) && layer.head == :call, "Do not know how to chain up $layer")
return Expr(:call, layer.args[1], last_layer, layer.args[2:end]...)
end
end
while true
if layers.head == :(=>)
new_layer = gensym()
push!(exprs, :($new_layer = $(_chain_layer(layers.args[1], last_layer))))
last_layer = new_layer
layers = layers.args[2]
else
push!(exprs, _chain_layer(layers, last_layer))
break
end
end
return Expr(:block, exprs...)
end

0 comments on commit d15bc77

Please sign in to comment.