Skip to content

Commit

Permalink
ndarray: support matrix/tensor multiplication (#384)
Browse files Browse the repository at this point in the history
  • Loading branch information
iblislin committed Dec 21, 2017
1 parent 883cdd3 commit bfeba81
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 5 deletions.
20 changes: 20 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,26 @@
1.0 2.0 3.0 4.0
```

* Matrix/tensor multiplication is supported now. (#TBD)

```julia
julia> x
2×3 mx.NDArray{Float32,2} @ CPU0:
1.0 2.0 3.0
4.0 5.0 6.0

julia> y
3 mx.NDArray{Float32,1} @ CPU0:
-1.0
-2.0
-3.0

julia> x * y
2 mx.NDArray{Float32,1} @ CPU0:
-14.0
-32.0
```

## API Changes

### `NDArray`
Expand Down
8 changes: 4 additions & 4 deletions src/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,7 @@ import Base: *
"""
.*(x, y)
Currently only multiplication a scalar with an `NDArray` is implemented.
Elementwise multiplication for `NDArray`.
"""
*(x::NDArray, y::Real) = _mul_scalar(x, scalar = y)
*(y::Real, x::NDArray) = _mul_scalar(x, scalar = y)
Expand All @@ -686,9 +686,9 @@ broadcast_(::typeof(*), x::NDArray, y::NDArray) = _mul(x, y)
"""
*(A::NDArray, B::NDArray)
Matrix (2D NDArray) multiplication.
Matrix/tensor multiplication.
"""
*(x::NDArray{T,2}, y::NDArray{S,2}) where {T,S} = dot(x, y)
*(x::NDArray{T}, y::NDArray{T}) where T = x y

"""
div_from!(dst::NDArray, arg::NDArrayOrReal)
Expand Down Expand Up @@ -1165,7 +1165,7 @@ _mxsig[:reshape] = :(reshape(arr; shape = dim, reverse = !reverse))
@_remap minimum(arr::NDArray, dims) min(arr; axis = 0 .- dims, keepdims = true)

# See https://github.com/dmlc/MXNet.jl/issues/55
@_remap dot(x::NDArray{T,N}, y::NDArray{S,N}) where {T,S,N} dot(y, x)
@_remap dot(x::NDArray, y::NDArray) dot(y, x)

# See https://github.com/dmlc/MXNet.jl/pull/123
@_remap transpose(arr::NDArray{T,1}) where T reshape(arr; shape = (1, length(arr)), reverse = true)
Expand Down
23 changes: 22 additions & 1 deletion test/unittest/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,28 @@ function test_dot()

x = mx.zeros(1, 2)
y = mx.zeros(1, 2, 3)
@test_throws MethodError dot(x, y)
@test_throws mx.MXError dot(x, y) # dimension mismatch

info("NDArray::matrix mul")
let
A = [1. 2 3; 4 5 6]
B = [-1., -2, -3]
x = NDArray(A)
y = NDArray(B)
z = x * y
@test copy(z) == A * B
@test size(z) == (2,)
end

let
A = [1. 2 3; 4 5 6]
B = [-1. -2; -3 -4; -5 -6]
x = NDArray(A)
y = NDArray(B)
z = x * y
@test copy(z) == A * B
@test size(z) == (2, 2)
end
end

function test_eltype()
Expand Down

0 comments on commit bfeba81

Please sign in to comment.