Skip to content

Commit

Permalink
ndarray: add Base.show(io, MIME"text/plain") (#347)
Browse files Browse the repository at this point in the history
* ndarray: add `Base.show(io, MIME"text/plain")`

e.g. make Array of NDArray show like this
```julia
julia> [mx.zeros(100)]
1-element Array{MXNet.mx.NDArray,1}:
 NDArray Float32[0.0, 0.0, 0.0  … 0.0, 0.0, 0.0]
```

* test cases
  • Loading branch information
iblislin committed Nov 27, 2017
1 parent cb06a9a commit 45d6279
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
12 changes: 9 additions & 3 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,15 @@ const NDArrayOrReal = Union{NDArray, Real}

@unfuse NDArray

function Base.show(io :: IO, arr :: NDArray)
println(io, "$(join(size(arr), "×")) mx.NDArray{$(eltype(arr))} @ $(context(arr)):")
Base.showarray(io, try_get_shared(arr, sync=:read), false, header=false)
function Base.show(io::IO, x::NDArray)
print(io, "NDArray ")
Base.showarray(io, try_get_shared(x, sync = :read), header = false)
end

# for REPL
function Base.show(io::IO, ::MIME{Symbol("text/plain")}, x::NDArray)
println(io, "$(join(size(x), "×")) mx.NDArray{$(eltype(x))} @ $(context(x)):")
Base.showarray(io, try_get_shared(x, sync = :read), false, header = false)
end

function Base.unsafe_convert(::Type{MX_handle}, obj::NDArray)
Expand Down
12 changes: 11 additions & 1 deletion test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -741,13 +741,23 @@ function test_transpose()
end

function test_show()
let str = sprint(show, mx.NDArray([1 2 3 4]))
info("NDArray::show::REPL")
let str = sprint(show, MIME"text/plain"(), mx.NDArray([1 2 3 4]))
@test contains(str, "1×4")
@test contains(str, "mx.NDArray")
@test contains(str, "Int64")
@test contains(str, "CPU")
@test match(r"1\s+2\s+3\s+4", str) != nothing
end

info("NDArray::show")
let str = sprint(show, mx.NDArray([1 2 3 4]))
@test str == "NDArray [1 2 3 4]"
end

let str = sprint(show, mx.zeros(4))
@test str == "NDArray Float32[0.0, 0.0, 0.0, 0.0]"
end
end

################################################################################
Expand Down

0 comments on commit 45d6279

Please sign in to comment.