Skip to content

Commit

Permalink
ndarray: add outer constrcutor for AbstractArray (#334)
Browse files Browse the repository at this point in the history
  • Loading branch information
iblislin authored and pluskid committed Nov 20, 2017
1 parent ceb7fbf commit 1a7887c
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 13 deletions.
7 changes: 3 additions & 4 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ mutable struct NDArray
end
end

NDArray(x::AbstractArray{T}) where {T<:DType} = copy(collect(x), cpu())
NDArray(x::Array{T}) where {T<:DType} = copy(x, cpu())

const NDArrayOrReal = Union{NDArray, Real}

@unfuse NDArray
Expand All @@ -107,10 +110,6 @@ function Base.show(io :: IO, arr :: NDArray)
Base.showarray(io, try_get_shared(arr, sync=:read), false, header=false)
end

function NDArray(data :: Array{T}) where T<:Real
copy(data, cpu())
end

function Base.unsafe_convert(::Type{MX_handle}, obj::NDArray)
Base.unsafe_convert(MX_handle, obj.handle)
end
Expand Down
33 changes: 24 additions & 9 deletions test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,20 @@ function rand_tensors(::Type{T}, dims::NTuple{N, Int}) where {N, T}
return (tensor, array)
end

function test_constructor()
info("NDArray::NDArray(x::AbstractArray)")
function check_absarray(x)
y = mx.NDArray(x)
@test ndims(x) == ndims(y)
@test eltype(x) == eltype(y)
@test x[3] == y[3][]
end

check_absarray(1:10)
check_absarray(1.0:10)
end # function test_constructor


function test_copy()
dims = rand_dims()
tensor = rand(mx.MX_float, dims)
Expand Down Expand Up @@ -87,7 +101,7 @@ end

function test_linear_idx()
info("NDArray::getindex::linear indexing")
let A = reshape(collect(1:30), 3, 10)
let A = reshape(1:30, 3, 10)
x = mx.NDArray(A)

@test copy(x) == A
Expand All @@ -104,7 +118,7 @@ function test_linear_idx()
@test_throws BoundsError x[42]
end

let A = reshape(collect(1:24), 3, 2, 4)
let A = reshape(1:24, 3, 2, 4)
x = mx.NDArray(A)

@test copy(x) == A
Expand All @@ -118,7 +132,7 @@ function test_linear_idx()
end

info("NDArray::setindex!::linear indexing")
let A = reshape(collect(1:24), 3, 2, 4)
let A = reshape(1:24, 3, 2, 4)
x = mx.NDArray(A)

@test copy(x) == A
Expand All @@ -136,7 +150,7 @@ end # function test_linear_idx

function test_first()
info("NDArray::first")
let A = reshape(collect(1:30), 3, 10)
let A = reshape(1:30, 3, 10)
x = mx.NDArray(A)

@test x[] == 1
Expand Down Expand Up @@ -613,7 +627,7 @@ end
function test_sum()
info("NDArray::sum")

let A = reshape(1.0:8, 2, 2, 2) |> collect, X = mx.NDArray(A)
let A = reshape(1.0:8, 2, 2, 2), X = mx.NDArray(A)
@test copy(sum(X))[] == sum(A)
@test copy(sum(X, 1)) == sum(A, 1)
@test copy(sum(X, 2)) == sum(A, 2)
Expand All @@ -626,7 +640,7 @@ end
function test_mean()
info("NDArray::mean")

let A = reshape(1.0:8, 2, 2, 2) |> collect, X = mx.NDArray(A)
let A = reshape(1.0:8, 2, 2, 2), X = mx.NDArray(A)
@test copy(mean(X))[] == mean(A)
@test copy(mean(X, 1)) == mean(A, 1)
@test copy(mean(X, 2)) == mean(A, 2)
Expand All @@ -639,7 +653,7 @@ end
function test_maximum()
info("NDArray::maximum")

let A = reshape(1.0:8, 2, 2, 2) |> collect, X = mx.NDArray(A)
let A = reshape(1.0:8, 2, 2, 2), X = mx.NDArray(A)
@test copy(maximum(X))[] == maximum(A)
@test copy(maximum(X, 1)) == maximum(A, 1)
@test copy(maximum(X, 2)) == maximum(A, 2)
Expand All @@ -652,7 +666,7 @@ end
function test_minimum()
info("NDArray::minimum")

let A = reshape(1.0:8, 2, 2, 2) |> collect, X = mx.NDArray(A)
let A = reshape(1.0:8, 2, 2, 2), X = mx.NDArray(A)
@test copy(minimum(X))[] == minimum(A)
@test copy(minimum(X, 1)) == minimum(A, 1)
@test copy(minimum(X, 2)) == minimum(A, 2)
Expand All @@ -665,7 +679,7 @@ end
function test_prod()
info("NDArray::prod")

let A = reshape(1.0:8, 2, 2, 2) |> collect, X = mx.NDArray(A)
let A = reshape(1.0:8, 2, 2, 2), X = mx.NDArray(A)
@test copy(prod(X))[] == prod(A)
@test copy(prod(X, 1)) == prod(A, 1)
@test copy(prod(X, 2)) == prod(A, 2)
Expand Down Expand Up @@ -740,6 +754,7 @@ end
# Run tests
################################################################################
@testset "NDArray Test" begin
test_constructor()
test_assign()
test_copy()
test_slice()
Expand Down

0 comments on commit 1a7887c

Please sign in to comment.