Skip to content

Commit

Permalink
base: merge _julia_to_mx_param into dump_mx_param (#296)
Browse files Browse the repository at this point in the history
  • Loading branch information
iblislin authored and pluskid committed Oct 25, 2017
1 parent a79e33c commit 71f2d40
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 26 deletions.
16 changes: 9 additions & 7 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,14 +157,16 @@ end
#
# TODO: find a better solution in case this cause issues in the future.
################################################################################
function dump_mx_param(val :: Any)
string(val)
end
function dump_mx_param{N,T<:Integer}(shape :: NTuple{N, T})
string(tuple(flipdim([shape...],1)...))
end
dump_mx_param(val::Any) = string(val)
dump_mx_param(val::Float64) = @sprintf("%.16e", val)
dump_mx_param(val::Float32) = @sprintf("%.8e", val)
dump_mx_param(val::Float16) = @sprintf("%.4e", val)
dump_mx_param{N, T<:Integer}(shape::NTuple{N, T}) =
string(tuple(flipdim([shape...], 1)...))

"""A convenient macro copied from Mocha.jl that could be used to define structs

"""
A convenient macro copied from Mocha.jl that could be used to define structs
with default values and type checks. For example
```julia
@defstruct MyStruct Any (
Expand Down
17 changes: 2 additions & 15 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1012,19 +1012,6 @@ end
ACCEPT_EMPTY_MUTATE_TARGET = (1 << 2)
)

function _julia_to_mx_param(val :: Any)
string(val)
end
function _julia_to_mx_param(val :: Float64)
@sprintf("%.16e", val)
end
function _julia_to_mx_param(val :: Float32)
@sprintf("%.8e", val)
end
function _julia_to_mx_param(val :: Float16)
@sprintf("%.4e", val)
end

# Import corresponding math functions from base so the automatically defined libmxnet
# functions can overload them
import Base: sqrt
Expand Down Expand Up @@ -1086,7 +1073,7 @@ function _get_ndarray_function_def(name :: String)
# and in libmxnet.
# See https://github.com/dmlc/MXNet.jl/pull/123
if $name == "transpose"
kwargs = Any[key != :axes ? (key, arg) : (key, reverse(map(i->length(arg)-i, arg))) for (key, arg) in kwargs]
kwargs = Any[key != :axes ? (key, arg) : (key, map(i->length(arg)-i, arg)) for (key, arg) in kwargs]
end

if length(output_vars) > 0
Expand All @@ -1100,7 +1087,7 @@ function _get_ndarray_function_def(name :: String)
num_outputs_p = [convert(Cint, num_outputs)]

kw_keys_str = String[string(x[1]) for x in kwargs]
kw_vals_str = String[_julia_to_mx_param(x[2]) for x in kwargs]
kw_vals_str = String[dump_mx_param(x[2]) for x in kwargs]

#op_handle = _get_cached_libmx_op_handle($(QuoteNode(name)))
op_handle = _get_cached_libmx_op_handle($(name))
Expand Down
6 changes: 2 additions & 4 deletions src/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@ julia> mx.rand(0, 1, mx.zeros(2, 2)) |> copy
```
"""
function rand!(low::Real, high::Real, out::NDArray)
# XXX: note we reverse shape because julia and libmx has different dim order
_random_uniform(NDArray, low=low, high=high, shape=reverse(size(out)), out=out)
_random_uniform(NDArray, low=low, high=high, shape=size(out), out=out)
end

"""
Expand Down Expand Up @@ -46,8 +45,7 @@ end
Draw random samples from a normal (Gaussian) distribution.
"""
function randn!(mean::Real, stdvar::Real, out::NDArray)
# XXX: note we reverse shape because julia and libmx has different dim order
_random_normal(NDArray, loc=mean, scale=stdvar, shape=reverse(size(out)), out=out)
_random_normal(NDArray, loc=mean, scale=stdvar, shape=size(out), out=out)
end

"""
Expand Down

0 comments on commit 71f2d40

Please sign in to comment.