Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chase down matmul allocation from mul_to! #12

Merged
merged 1 commit into from
Oct 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,15 @@ trial2 = @benchmark MA.mul_to!($c2, $A2, $b2)
display(trial2)

BenchmarkTools.Trial:
memory estimate: 168 bytes
allocs estimate: 9
memory estimate: 48 bytes
allocs estimate: 3
--------------
minimum time: 928.306 μs (0.00% GC)
median time: 933.144 μs (0.00% GC)
mean time: 952.015 μs (0.00% GC)
maximum time: 1.910 ms (0.00% GC)
minimum time: 917.819 μs (0.00% GC)
median time: 999.239 μs (0.00% GC)
mean time: 1.042 ms (0.00% GC)
maximum time: 2.319 ms (0.00% GC)
--------------
samples: 5244
samples: 4791
evals/sample: 1
```

Expand Down
10 changes: 5 additions & 5 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@ struct NotMutable <: MutableTrait end
Return `IsMutable` to indicate an object of type `T` can be modified to be
equal to `op(args...)`.
"""
function mutability(T::Type, op, args::Type...)
function mutability(T::Type, op, args::Vararg{Type, N}) where N
if mutability(T) isa IsMutable && promote_operation(op, args...) == T
return IsMutable()
else
return NotMutable()
end
end
mutability(x, op, args...) = mutability(typeof(x), op, typeof.(args)...)
mutability(x, op, args::Vararg{Any, N}) where {N} = mutability(typeof(x), op, typeof.(args)...)
mutability(::Type) = NotMutable()

function mutable_operate_to_fallback(::NotMutable, output, op::Function, args...)
Expand Down Expand Up @@ -85,14 +85,14 @@ end

Returns the value of `op(args...)`, possibly modifying `output`.
"""
function operate_to!(output, op::Function, args...)
function operate_to!(output, op::Function, args::Vararg{Any, N}) where N
return operate_to_fallback!(mutability(output, op, args...), output, op, args...)
end

function operate_to_fallback!(::NotMutable, output, op::Function, args...)
function operate_to_fallback!(::NotMutable, output, op::Function, args::Vararg{Any, N}) where N
return op(args...)
end
function operate_to_fallback!(::IsMutable, output, op::Function, args...)
function operate_to_fallback!(::IsMutable, output, op::Function, args::Vararg{Any, N}) where N
return mutable_operate_to!(output, op, args...)
end

Expand Down
10 changes: 5 additions & 5 deletions src/linear_algebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,16 @@ function mutable_operate_to!(C::Vector, ::typeof(*), A::AbstractMatrix, B::Abstr
end

# We need a buffer to hold the intermediate multiplication.
mul_buffer = zero(zero(eltype(A)) * zero(eltype(B)))
for k = 1:mB
aoffs = (k-1)*Astride
mul_buffer = zero(promote_operation(*, eltype(A), eltype(B)))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the change that removed the 140 bytes of allocation.

@inbounds for k = Base.OneTo(mB)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, thanks for catching this

aoffs = (k-1) * Astride
b = B[k]
for i = 1:mA
for i = Base.OneTo(mA)
# `C[i] = muladd_buf!(mul_buffer, C[i], A[aoffs + i], b)`
mutable_buffered_operate!(mul_buffer, add_mul, C[i], A[aoffs + i], b)
end
end
end # @inbounds
end
return C
end

Expand Down
4 changes: 2 additions & 2 deletions src/shortcuts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ add!(args...) = operate!(+, args...)

Return the product of `b` and `c`, possibly modifying `a`.
"""
mul_to!(output, args...) = operate_to!(output, *, args...)
mul_to!(output, args::Vararg{Any, N}) where {N} = operate_to!(output, *, args...)

"""
mul!(a, b, ...)
Expand All @@ -34,7 +34,7 @@ Return `a + *(args...)`. Note that `add_mul(a, b, c) = muladd(b, c, a)`.
function add_mul end
add_mul(a, b, c) = muladd(b, c, a)

function promote_operation(::typeof(add_mul), T::Type, args::Type...)
function promote_operation(::typeof(add_mul), T::Type, args::Vararg{Type, N}) where N
return promote_operation(+, T, promote_operation(*, args...))
end

Expand Down
10 changes: 10 additions & 0 deletions test/matmul.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
"""
    alloc_test(f, n)

Test that a call to the zero-argument function `f` allocates exactly `n` bytes,
as reported by `@allocated`.
"""
function alloc_test(f, n)
    # Warm-up call so that compilation happens before the measured call.
    f()
    bytes = @allocated f()
    return @test n == bytes
end
@testset "Matrix multiplication" begin
@testset "matrix-vector product" begin
A = BigInt[1 1 1; 1 1 1; 1 1 1]
Expand All @@ -16,5 +20,11 @@
@test MA.mul_to!(y, A, x) == BigInt[3; 3; 3] && y == BigInt[3; 3; 3]
@test_throws DimensionMismatch MA.mul(BigInt[1 1; 1 1], BigInt[])
@test_throws DimensionMismatch MA.mul_to!(BigInt[], BigInt[1 1; 1 1], BigInt[1; 1])

# 40 bytes to create the buffer
# 8 bytes in the double for loop. FIXME: figure out why
alloc_test(() -> MA.mul_to!(y, A, x), 48)
alloc_test(() -> MA.operate_to!(y, *, A, x), 48)
alloc_test(() -> MA.mutable_operate_to!(y, *, A, x), 48)
end
end