Skip to content

Commit

Permalink
sym: fix printing of symbol generated via get_internals (#332)
Browse files Browse the repository at this point in the history
  • Loading branch information
iblislin authored and pluskid committed Nov 20, 2017
1 parent 010ea3c commit 1f50a14
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
8 changes: 7 additions & 1 deletion src/symbolic-node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,13 @@ function get_name(self :: mx.SymbolicNode)
success = Ref(0)
@mxcall(:MXSymbolGetName, (MX_handle, Ref{char_p}, Ref{Int}), self.handle.value, name, success)
@assert success[] != -1
return Symbol(unsafe_string(name[]))

str = name[]
if str == C_NULL # e.g. the symbol returned via get_internals
string(self.handle.value)
else
Symbol(unsafe_string(str))
end
end

Base.show(io::IO, sym::SymbolicNode) =
Expand Down
11 changes: 9 additions & 2 deletions test/unittest/symbolic-node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ function test_power()
end
end

info("NDArray::power::e.^x::x.^e")
info("SymbolicNode::power::e.^x::x.^e")
let x = mx.Variable(:x), A = [0 0 0; 0 0 0]
y = exec(e.^x; :x => A)[]
@test copy(y) ones(A)
Expand All @@ -486,7 +486,7 @@ function test_power()
end
end

info("NDArray::power::π.^x::x.^π")
info("SymbolicNode::power::π.^x::x.^π")
let x = mx.Variable(:x), A = Float32[1 2; 3 4]
let y = π.^x
z = exec(y; :x => A)[]
Expand All @@ -500,6 +500,12 @@ function test_power()
end
end # function test_power

function test_get_name()
info("SymbolicNode::get_name::with get_internals")
name = mx.get_name(mx.get_internals(mlp2())) # no error
@test contains(name, "Ptr")
end # function test_get_name

################################################################################
# Run tests
################################################################################
Expand All @@ -522,6 +528,7 @@ end # function test_power
test_mul()
test_div()
test_power()
test_get_name()
end

end

0 comments on commit 1f50a14

Please sign in to comment.