Skip to content

Commit

Permalink
API variable -> Variable; group -> Group
Browse files Browse the repository at this point in the history
  • Loading branch information
pluskid committed Oct 22, 2015
1 parent 3a26544 commit 0c3d066
Show file tree
Hide file tree
Showing 8 changed files with 14 additions and 11 deletions.
2 changes: 1 addition & 1 deletion examples/cifar10/cifar10.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ end

#--------------------------------------------------------------------------------
# Actual architecture
data = mx.variable(:data)
data = mx.Variable(:data)
conv1 = conv_factory(data, 96, (3,3); pad=(1,1), act_type=:relu)
in3a = simple_factory(conv1, 32, 32)
in3b = simple_factory(in3a, 32, 48)
Expand Down
2 changes: 1 addition & 1 deletion examples/mnist/lenet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using MXNet
# define lenet

# input
data = mx.variable(:data)
data = mx.Variable(:data)

# first conv
conv1 = mx.Convolution(data=data, kernel=(5,5), num_filter=20)
Expand Down
2 changes: 1 addition & 1 deletion examples/mnist/mlp.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using MXNet

# define MLP
data = mx.variable(:data)
data = mx.Variable(:data)
fc1 = mx.FullyConnected(data = data, name=:fc1, num_hidden=128)
act1 = mx.Activation(data = fc1, name=:relu1, act_type=:relu)
fc2 = mx.FullyConnected(data = act1, name=:fc2, num_hidden=64)
Expand Down
3 changes: 3 additions & 0 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,9 @@ function .*(arg0 :: NDArray, arg :: Union{Real, NDArray})
ret = copy(arg0, context(arg0))
mul_to!(ret, arg)
end
function .*(arg0 :: Real, arg :: NDArray)
.*(arg, arg0)
end
# unlike *, we only allow type Real in arguments, because array-array * operator
# means matrix multiplication in Julia
function *(arg0 :: NDArray, arg :: Real)
Expand Down
4 changes: 2 additions & 2 deletions src/symbol.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,14 @@ function get_internals(self :: Symbol)
end

"Create a symbolic variable with the given name"
function variable(name :: Union{Base.Symbol, AbstractString})
function Variable(name :: Union{Base.Symbol, AbstractString})
hdr_ref = Ref{MX_handle}(0)
@mxcall(:MXSymbolCreateVariable, (char_p, Ref{MX_handle}), name, hdr_ref)
Symbol(MX_SymbolHandle(hdr_ref[]))
end

"Create a symbol that groups symbols together"
function group(symbols :: Symbol...)
function Group(symbols :: Symbol...)
handles = MX_handle[symbols...]
ref_hdr = Ref{MX_handle}(0)
@mxcall(:MXSymbolCreateGroup, (MX_uint, Ptr{MX_handle}, Ref{MX_handle}),
Expand Down
2 changes: 1 addition & 1 deletion test/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function rand_dims(max_ndim=6)
end

function mlp2()
data = mx.variable(:data)
data = mx.Variable(:data)
out = mx.FullyConnected(data=data, name=:fc1, num_hidden=1000)
out = mx.Activation(data=out, act_type=:relu)
out = mx.FullyConnected(data=out, name=:fc2, num_hidden=10)
Expand Down
4 changes: 2 additions & 2 deletions test/unittest/bind.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ function test_arithmetic(uf, gf)
shape = rand_dims()
info("Bind::arithmetic::$uf::dims = $shape")

lhs = mx.variable(:lhs)
rhs = mx.variable(:rhs)
lhs = mx.Variable(:lhs)
rhs = mx.Variable(:rhs)
ret = uf(lhs, rhs)
@test mx.list_arguments(ret) == [:lhs, :rhs]

Expand Down
6 changes: 3 additions & 3 deletions test/unittest/symbol.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ end
function test_internal()
info("Symbol::internal")

data = mx.variable(:data)
data = mx.Variable(:data)
oldfc = mx.FullyConnected(data=data, name=:fc1, num_hidden=10)
net1 = mx.FullyConnected(data=oldfc, name=:fc2, num_hidden=100)

Expand All @@ -33,7 +33,7 @@ end
function test_compose()
info("Symbol::compose")

data = mx.variable(:data)
data = mx.Variable(:data)
net1 = mx.FullyConnected(data=data, name=:fc1, num_hidden=10)
net1 = mx.FullyConnected(data=net1, name=:fc2, num_hidden=100)

Expand All @@ -42,7 +42,7 @@ function test_compose()
net2 = mx.FullyConnected(data=net2, name=:fc4, num_hidden=20)

composed = net2(fc3_data=net1, name=:composed)
multi_out = mx.group(composed, net1)
multi_out = mx.Group(composed, net1)
@test mx.list_outputs(multi_out) == [:composed_output, :fc2_output]
end

Expand Down

0 comments on commit 0c3d066

Please sign in to comment.