Permalink
Browse files

workaround for #55

  • Loading branch information...
pluskid committed Jan 16, 2016
1 parent 8edb94b commit 4c52eb8eb76c6239b0a447a5624c7b6c7c3586b6
Showing with 19 additions and 0 deletions.
  1. +7 −0 src/ndarray.jl
  2. +12 −0 test/unittest/ndarray.jl
View
@@ -940,6 +940,13 @@ function _import_ndarray_functions(;gen_docs=false)
_use_vars = Expr(:ref, :MX_handle, [symbol("in$i") for i=1:n_used_vars]...)
_scalars = Expr(:ref, :MX_float, [symbol("sca$i") for i=1:n_scalars]...)
_mut_vars = Expr(:ref, :MX_handle, [symbol("out$i") for i=1:n_mutate_vars]...)
# XXX: hacky way of solving the problem that the arguments of `dot` should be swapped
# See https://github.com/dmlc/MXNet.jl/issues/55
if func_name == :dot
_use_vars.args[2:end] = flipdim(_use_vars.args[2:end], 1)
end
stmt_call = Expr(:call, :_invoke_mxfunction, func_handle, _use_vars, _scalars, _mut_vars)
if n_mutate_vars == 1
stmt_ret = :(return out1)
View
@@ -260,6 +260,17 @@ function test_nd_as_jl()
@test reldiff(copy(z)[:,2:end], copy(x)[:,2:end]) < 1e-6
end
function test_dot()
dims1 = (2, 3)
dims2 = (3, 8)
info("NDArray::dot")
x = mx.zeros(dims1)
y = mx.zeros(dims2)
z = mx.dot(x, y)
@test size(z) == (2, 8)
end
################################################################################
# Run tests
@@ -276,5 +287,6 @@ test_saveload()
test_clip()
test_sqrt()
test_nd_as_jl()
test_dot()
end

0 comments on commit 4c52eb8

Please sign in to comment.