Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mutable array operations #16

Merged
merged 5 commits into from
Nov 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions docs/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[DocStringExtensions]]
deps = ["LibGit2", "Markdown", "Pkg", "Test"]
git-tree-sha1 = "0513f1a8991e9d83255e0140aace0d0fc4486600"
git-tree-sha1 = "88bb0edb352b16608036faadcc071adda068582a"
uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
version = "0.8.0"
version = "0.8.1"

[[Documenter]]
deps = ["Base64", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
git-tree-sha1 = "4a84478277020abfff208cde31ba1aa68a5bc572"
deps = ["Base64", "Dates", "DocStringExtensions", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"]
git-tree-sha1 = "0be9bf63e854a2408c2ecd3c600d68d4d87d8a73"
uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
version = "0.23.0"
version = "0.24.2"

[[InteractiveUtils]]
deps = ["Markdown"]
Expand All @@ -48,9 +48,9 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[Parsers]]
deps = ["Dates", "Test"]
git-tree-sha1 = "db2b35dedab3c0e46dc15996d170af07a5ab91c9"
git-tree-sha1 = "0139ba59ce9bc680e2925aec5b7db79065d60556"
uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
version = "0.3.6"
version = "0.3.10"

[[Pkg]]
deps = ["Dates", "LibGit2", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
Expand Down
7 changes: 7 additions & 0 deletions docs/src/index.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
```@meta
CurrentModule = MutableArithmetics
DocTestSetup = quote
using MutableArithmetics
end
```

# MutableArithmetics.jl

```@index
Expand Down
9 changes: 3 additions & 6 deletions src/MutableArithmetics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,16 @@ include("Test/Test.jl")
# Implementation of the interface for Base types
import LinearAlgebra
const Scaling = Union{Number, LinearAlgebra.UniformScaling}
#mutable_copy(A::LinearAlgebra.Symmetric) = LinearAlgebra.Symmetric(mutable_copy(parent(A)), ifelse(A.uplo == 'U', :U, :L))
## Broadcast applies the transpose
#mutable_copy(A::LinearAlgebra.Transpose) = LinearAlgebra.Transpose(mutable_copy(parent(A)))
#mutable_copy(A::LinearAlgebra.Adjoint) = LinearAlgebra.Adjoint(mutable_copy(parent(A)))
scaling(x::Scaling) = x
include("bigint.jl")
include("linear_algebra.jl")

isequal_canonical(a, b) = a == b

include("rewrite.jl")

include("dispatch.jl")

include("sparse_arrays.jl")

include("dispatch.jl")

end # module
3 changes: 2 additions & 1 deletion src/bigint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ function mutable_operate_to!(output::BigInt, op::Union{typeof(*), typeof(+)},
end

# add_mul
buffer_for(::typeof(add_mul), args::Vararg{Type{BigInt}, N}) where {N} = BigInt()
function mutable_operate_to!(output::BigInt, ::typeof(add_mul), args::Vararg{BigInt, N}) where N
return mutable_buffered_operate_to!(BigInt(), output, add_mul, args...)
end
Expand All @@ -43,6 +44,6 @@ end
scaling_to_bigint(x::BigInt) = x
scaling_to_bigint(x::Number) = convert(BigInt, x)
scaling_to_bigint(J::LinearAlgebra.UniformScaling) = scaling_to_bigint(J.λ)
function mutable_operate_to!(output::BigInt, op::Function, args::Vararg{Scaling, N}) where N
function mutable_operate_to!(output::BigInt, op::Union{typeof(+), typeof(*), typeof(add_mul)}, args::Vararg{Scaling, N}) where N
return mutable_operate_to!(output, op, scaling_to_bigint.(args)...)
end
175 changes: 175 additions & 0 deletions src/dispatch.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,184 @@
# TODO: Intercepting "externally owned" method calls by dispatching on type parameters
# (rather than outermost wrapper type) is generally bad practice, but refactoring this code
# to use a different mechanism would be a lot of work. In the future, this interception code
# would be more easily/robustly replaced by using a tool like
# https://github.com/jrevels/Cassette.jl.

abstract type AbstractMutable end

function Base.sum(a::AbstractArray{<:AbstractMutable})
return mapreduce(identity, add!, a, init = zero(promote_operation(+, eltype(a), eltype(a))))
end

LinearAlgebra.dot(lhs::AbstractArray{<:AbstractMutable}, rhs::AbstractArray) = _dot(lhs, rhs)
LinearAlgebra.dot(lhs::AbstractArray, rhs::AbstractArray{<:AbstractMutable}) = _dot(lhs, rhs)
LinearAlgebra.dot(lhs::AbstractArray{<:AbstractMutable}, rhs::AbstractArray{<:AbstractMutable}) = _dot(lhs, rhs)

function _dot(x::AbstractArray, y::AbstractArray)
lx = length(x)
if lx != length(y)
throw(DimensionMismatch("first array has length $(lx) which does not match the length of the second, $(length(y))."))
end
if iszero(lx)
return LinearAlgebra.dot(zero(eltype(x)), zero(eltype(y)))
end

# We need a buffer to hold the intermediate multiplication.
mul_buffer = buffer_for(add_mul, eltype(x), eltype(y))

s = zero(promote_operation(add_mul, eltype(x), eltype(x), eltype(y)))

for (Ix, Iy) in zip(eachindex(x), eachindex(y))
s = @inbounds buffered_operate!(mul_buffer, add_mul, s, x[Ix], y[Iy])
end

return s
end

# Special-case because the the base version wants to do fill!(::Array{AbstractVariableRef}, zero(GenericAffExpr{Float64,eltype(x)}))
_one_indexed(A) = all(x -> isa(x, Base.OneTo), axes(A))
function LinearAlgebra.diagm(x::AbstractVector{<:AbstractMutable})
@assert _one_indexed(x) # `LinearAlgebra.diagm` doesn't work for non-one-indexed arrays in general.
ZeroType = promote_operation(zero, eltype(x))
return LinearAlgebra.diagm(0 => copyto!(similar(x, ZeroType), x))
end

###############################################################################
# Interception of Base's matrix/vector arithmetic machinery

# Redirect calls with `eltype(ret) <: AbstractMutable` to `_mul!` to
# replace it with an implementation more efficient than `generic_matmatmul!` and
# `generic_matvecmul!` since it takes into account the mutability of the arithmetic.
# We need `args...` because SparseArrays` also gives `α` and `β` arguments.

function _mul!(output, A, B, α, β)
# See SparseArrays/src/linalg.jl
if !isone(β)
if iszero(β)
mutable_operate!(zero, output)
else
rmul!(output, β)
end
end
return mutable_operate!(add_mul, output, A, B, α)
end
function _mul!(output, A, B, α)
mutable_operate!(zero, output)
return mutable_operate!(add_mul, output, A, B, α)
end
function _mul!(output, A, B)
mutable_operate!(zero, output)
return mutable_operate!(add_mul, output, A, B)
end

function LinearAlgebra.mul!(ret::AbstractMatrix{<:AbstractMutable},
A::AbstractVecOrMat, B::AbstractVecOrMat, args::Vararg{Any, N}) where N
_mul!(ret, A, B, args...)
end
function LinearAlgebra.mul!(ret::AbstractVector{<:AbstractMutable},
A::AbstractVecOrMat, B::AbstractVector, args...)
_mul!(ret, A, B, args...)
end
function LinearAlgebra.mul!(ret::AbstractVector{<:AbstractMutable},
A::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat},
B::AbstractVector, args...)
_mul!(ret, A, B, args...)
end
function LinearAlgebra.mul!(ret::AbstractVector{<:AbstractMutable},
A::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat},
B::AbstractVector, args...)
_mul!(ret, A, B, args...)
end
function LinearAlgebra.mul!(ret::AbstractMatrix{<:AbstractMutable},
A::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat},
B::AbstractMatrix, args...)
_mul!(ret, A, B, args...)
end
function LinearAlgebra.mul!(ret::AbstractMatrix{<:AbstractMutable},
A::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat},
B::AbstractMatrix, args...)
_mul!(ret, A, B, args...)
end
function LinearAlgebra.mul!(ret::AbstractMatrix{<:AbstractMutable},
A::AbstractMatrix,
B::LinearAlgebra.Transpose{<:Any,<:AbstractVecOrMat}, args...)
_mul!(ret, A, B, args...)
end
function LinearAlgebra.mul!(ret::AbstractMatrix{<:AbstractMutable},
A::AbstractMatrix,
B::LinearAlgebra.Adjoint{<:Any,<:AbstractVecOrMat}, args...)
_mul!(ret, A, B, args...)
end

# SparseArrays promotes the element types of `A` and `B` to the same type
# which always produce quadratic expressions for JuMP even if only one of them
# was affine and the other one constant. Moreover, it does not always go through
# `LinearAlgebra.mul!` which prevents us from using mutability of the arithmetic.
# For this reason we intercept the calls and redirect them to `mul`.

# A few are overwritten below but many more need to be redirected to `mul` in
# `linalg.jl`.

Base.:*(A::SparseMat{<:AbstractMutable}, x::StridedVector) = mul(A, x)
Base.:*(A::SparseMat, x::StridedVector{<:AbstractMutable}) = mul(A, x)
Base.:*(A::SparseMat{<:AbstractMutable}, x::StridedVector{<:AbstractMutable}) = mul(A, x)

Base.:*(A::SparseMat{<:AbstractMutable}, B::SparseMat{<:AbstractMutable}) = mul(A, B)
Base.:*(A::SparseMat{<:Any}, B::SparseMat{<:AbstractMutable}) = mul(A, B)
Base.:*(A::SparseMat{<:AbstractMutable}, B::SparseMat{<:Any}) = mul(A, B)

Base.:*(A::SparseMat{<:AbstractMutable}, B::LinearAlgebra.Adjoint{<:AbstractMutable, <:SparseMat}) = mul(A, B)
Base.:*(A::SparseMat{<:Any}, B::LinearAlgebra.Adjoint{<:AbstractMutable, <:SparseMat}) = mul(A, B)
Base.:*(A::SparseMat{<:AbstractMutable}, B::LinearAlgebra.Adjoint{<:Any, <:SparseMat}) = mul(A, B)

Base.:*(A::LinearAlgebra.Adjoint{<:AbstractMutable, <:SparseMat}, B::SparseMat{<:AbstractMutable}) = mul(A, B)
Base.:*(A::LinearAlgebra.Adjoint{<:Any, <:SparseMat}, B::SparseMat{<:AbstractMutable}) = mul(A, B)
Base.:*(A::LinearAlgebra.Adjoint{<:AbstractMutable, <:SparseMat}, B::SparseMat{<:Any}) = mul(A, B)

Base.:*(A::StridedMatrix{<:AbstractMutable}, B::SparseMat{<:AbstractMutable}) = mul(A, B)
Base.:*(A::StridedMatrix{<:Any}, B::SparseMat{<:AbstractMutable}) = mul(A, B)
Base.:*(A::StridedMatrix{<:AbstractMutable}, B::SparseMat{<:Any}) = mul(A, B)

Base.:*(A::SparseMat{<:AbstractMutable}, B::StridedMatrix{<:AbstractMutable}) = mul(A, B)
Base.:*(A::SparseMat{<:Any}, B::StridedMatrix{<:AbstractMutable}) = mul(A, B)
Base.:*(A::SparseMat{<:AbstractMutable}, B::StridedMatrix{<:Any}) = mul(A, B)

Base.:*(A::LinearAlgebra.Adjoint{<:AbstractMutable, <:SparseMat}, B::StridedMatrix{<:AbstractMutable}) = mul(A, B)
Base.:*(A::LinearAlgebra.Adjoint{<:Any, <:SparseMat}, B::StridedMatrix{<:AbstractMutable}) = mul(A, B)
Base.:*(A::LinearAlgebra.Adjoint{<:AbstractMutable, <:SparseMat}, B::StridedMatrix{<:Any}) = mul(A, B)

# Base doesn't define efficient fallbacks for sparse array arithmetic involving
# non-`<:Number` scalar elements, so we define some of these for `<:AbstractMutable` scalar
# elements here.

function Base.:*(A::Scaling, B::SparseMat{<:AbstractMutable})
return SparseMat(B.m, B.n, copy(B.colptr), copy(rowvals(B)), A .* nonzeros(B))
end

function Base.:*(A::SparseMat{<:AbstractMutable}, B::Scaling)
return SparseMat(A.m, A.n, copy(A.colptr), copy(rowvals(A)), nonzeros(A) .* B)
end

function Base.:*(A::AbstractMutable, B::SparseMat)
return SparseMat(B.m, B.n, copy(B.colptr), copy(rowvals(B)), A .* nonzeros(B))
end

function Base.:*(A::SparseMat, B::AbstractMutable)
return SparseMat(A.m, A.n, copy(A.colptr), copy(rowvals(A)), nonzeros(A) .* B)
end

function Base.:/(A::SparseMat{<:AbstractMutable}, B::Scaling)
return SparseMat(A.m, A.n, copy(A.colptr), copy(rowvals(A)), nonzeros(A) ./ B)
end

Base.:+(A::AbstractArray{<:AbstractMutable}) = A

# Fix https://github.com/JuliaLang/julia/issues/32374 as done in
# https://github.com/JuliaLang/julia/pull/32375. This hack should
# be removed once we drop Julia v1.0.
function Base.:-(A::LinearAlgebra.Symmetric{<:AbstractMutable})
return LinearAlgebra.Symmetric(-parent(A), LinearAlgebra.sym_uplo(A.uplo))
end
function Base.:-(A::LinearAlgebra.Hermitian{<:AbstractMutable})
return LinearAlgebra.Hermitian(-parent(A), LinearAlgebra.sym_uplo(A.uplo))
end
9 changes: 8 additions & 1 deletion src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,18 @@ function mutable_operate!(op::Function, args::Vararg{Any, N}) where N
mutable_operate_to!(args[1], op, args...)
end

buffer_for(::Function, args::Vararg{Type, N}) where {N} = nothing

"""
mutable_buffered_operate_to!(buffer, output, op::Function, args...)

Modify the value of `output` to be equal to the value of `op(args...)`,
possibly modifying `buffer`. Can only be called if
`mutability(output, op, args...)` returns `true`.
"""
function mutable_buffered_operate_to! end
function mutable_buffered_operate_to!(::Nothing, output, op::Function, args::Vararg{Any, N}) where N
return mutable_operate_to!(output, op, args...)
end

"""
mutable_buffered_operate!(buffer, op::Function, args...)
Expand All @@ -103,6 +107,9 @@ possibly modifying `buffer`. Can only be called if
function mutable_buffered_operate!(buffer, op::Function, args::Vararg{Any, N}) where N
mutable_buffered_operate_to!(buffer, args[1], op, args...)
end
function mutable_buffered_operate!(::Nothing, op::Function, args::Vararg{Any, N}) where N
return mutable_operate!(op, args...)
end

"""
operate_to!(output, op::Function, args...)
Expand Down
Loading