Skip to content

Commit

Permalink
ndarray: support transpose on 1D array (#375)
Browse files Browse the repository at this point in the history
Python doesn't have this functionality,
so I implement it via `reshape`.

```julia
julia> x = NDArray(Float32[1, 2, 3, 4])
4 mx.NDArray{Float32,1} @ CPU0:
 1.0
 2.0
 3.0
 4.0

julia> x'
1×4 mx.NDArray{Float32,2} @ CPU0:
 1.0  2.0  3.0  4.0
```
  • Loading branch information
iblislin authored and pluskid committed Dec 15, 2017
1 parent bfb1cc4 commit 6609616
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 7 deletions.
15 changes: 15 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,21 @@
4.0
```

* Transposing a column `NDArray` to a row `NDArray` is supported now. (#TBD)

```julia
julia> x = NDArray(Float32[1, 2, 3, 4])
4 mx.NDArray{Float32,1} @ CPU0:
1.0
2.0
3.0
4.0

julia> x'
1×4 mx.NDArray{Float32,2} @ CPU0:
1.0 2.0 3.0 4.0
```

## API Changes

### `NDArray`
Expand Down
3 changes: 2 additions & 1 deletion src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1114,7 +1114,8 @@ _mxsig[:reshape] = :(reshape(arr; shape = dim, reverse = !reverse))
@_remap dot(x::NDArray{T,N}, y::NDArray{S,N}) where {T,S,N} dot(y, x)

# See https://github.com/dmlc/MXNet.jl/pull/123
@_remap transpose(arr::NDArray) transpose(_only2d(arr))
@_remap transpose(arr::NDArray{T,1}) where T reshape(arr; shape = (1, length(arr)), reverse = true)
@_remap transpose(arr::NDArray{T,2}) where T transpose(arr)
@_remap permutedims(arr::NDArray, axes) transpose(arr; axes = length(axes) .- tuple(axes...))

@_remap prod(arr::NDArray) prod(arr)
Expand Down
5 changes: 0 additions & 5 deletions src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,11 +163,6 @@ function _format_signature(narg::Int, arg_names::Ref{char_pp})
return join([unsafe_string(name) for name in arg_names] , ", ")
end

@inline function _only2d(x)
@assert ndims(x) == 2
x
end

"""
libmxnet operators signature checker.
Expand Down
8 changes: 7 additions & 1 deletion test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -775,7 +775,13 @@ function test_fill()
end # function test_fill

function test_transpose()
info("NDArray::transpose")
info("NDArray::transpose::1D")
let A = rand(Float32, 4), x = NDArray(A)
@test size(x) == (4,)
@test size(x') == (1, 4)
end

info("NDArray::transpose::2D")
let A = rand(Float32, 2, 3), x = mx.NDArray(A)
@test size(x) == (2, 3)
@test size(x') == (3, 2)
Expand Down

0 comments on commit 6609616

Please sign in to comment.