Skip to content

Commit

Permalink
Support broadcasting over structured block matrices (JuliaLang#53909)
Browse files Browse the repository at this point in the history
Fix JuliaLang#48664

After this, broadcasting over structured block matrices with
matrix-valued elements works:
```julia
julia> D = Diagonal([[1 2; 3 4], [5 6; 7 8]])
2×2 Diagonal{Matrix{Int64}, Vector{Matrix{Int64}}}:
 [1 2; 3 4]      ⋅     
     ⋅       [5 6; 7 8]

julia> D .+ D
2×2 Diagonal{Matrix{Int64}, Vector{Matrix{Int64}}}:
 [2 4; 6 8]      ⋅     
     ⋅       [10 12; 14 16]

julia> cos.(D)
2×2 Matrix{Matrix{Float64}}:
 [0.855423 -0.110876; -0.166315 0.689109]  [1.0 0.0; 0.0 1.0]
 [1.0 0.0; 0.0 1.0]                        [0.928384 -0.069963; -0.0816235 0.893403]
```
Such operations show up when using `BlockArrays`.

The implementation is a bit hacky as it uses `0I` as the zero element in
`fzero`, which isn't really the correct zero if the blocks are
rectangular. Nonetheless, this works, as `fzero` is only used to
determine if the structure is preserved.
  • Loading branch information
jishnub committed Apr 7, 2024
1 parent 1febcd6 commit 243ebc3
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 2 deletions.
6 changes: 4 additions & 2 deletions stdlib/LinearAlgebra/src/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ struct StructuredMatrixStyle{T} <: Broadcast.AbstractArrayStyle{2} end
StructuredMatrixStyle{T}(::Val{2}) where {T} = StructuredMatrixStyle{T}()
StructuredMatrixStyle{T}(::Val{N}) where {T,N} = Broadcast.DefaultArrayStyle{N}()

const StructuredMatrix = Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular}
for ST in Base.uniontypes(StructuredMatrix)
const StructuredMatrix{T} = Union{Diagonal{T},Bidiagonal{T},SymTridiagonal{T},Tridiagonal{T},LowerTriangular{T},UnitLowerTriangular{T},UpperTriangular{T},UnitUpperTriangular{T}}
for ST in (Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,LowerTriangular,UnitLowerTriangular,UpperTriangular,UnitUpperTriangular)
@eval Broadcast.BroadcastStyle(::Type{<:$ST}) = $(StructuredMatrixStyle{ST}())
end

Expand Down Expand Up @@ -133,6 +133,7 @@ fails as `zero(::Tuple{Int})` is not defined. However,
iszerodefined(::Type) = false
iszerodefined(::Type{<:Number}) = true
iszerodefined(::Type{<:AbstractArray{T}}) where T = iszerodefined(T)
iszerodefined(::Type{<:UniformScaling{T}}) where T = iszerodefined(T)

fzeropreserving(bc) = (v = fzero(bc); !ismissing(v) && (iszerodefined(typeof(v)) ? iszero(v) : v == 0))
# Like sparse matrices, we assume that the zero-preservation property of a broadcasted
Expand All @@ -144,6 +145,7 @@ fzero(::Type{T}) where T = T
fzero(r::Ref) = r[]
fzero(t::Tuple{Any}) = t[1]
fzero(S::StructuredMatrix) = zero(eltype(S))
fzero(::StructuredMatrix{<:AbstractMatrix{T}}) where {T<:Number} = haszero(T) ? zero(T)*I : missing
fzero(x) = missing
function fzero(bc::Broadcast.Broadcasted)
args = map(fzero, bc.args)
Expand Down
1 change: 1 addition & 0 deletions stdlib/LinearAlgebra/test/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ Random.seed!(1)
struct TypeWithZero end
Base.promote_rule(::Type{TypeWithoutZero}, ::Type{TypeWithZero}) = TypeWithZero
Base.convert(::Type{TypeWithZero}, ::TypeWithoutZero) = TypeWithZero()
Base.zero(x::Union{TypeWithoutZero, TypeWithZero}) = zero(typeof(x))
Base.zero(::Type{<:Union{TypeWithoutZero, TypeWithZero}}) = TypeWithZero()
LinearAlgebra.symmetric(::TypeWithoutZero, ::Symbol) = TypeWithoutZero()
LinearAlgebra.symmetric_type(::Type{TypeWithoutZero}) = TypeWithoutZero
Expand Down
58 changes: 58 additions & 0 deletions stdlib/LinearAlgebra/test/structuredbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,4 +280,62 @@ end
# structured broadcast with function returning non-number type
@test tuple.(Diagonal([1, 2])) == [(1,) (0,); (0,) (2,)]

@testset "broadcast over structured matrices with matrix elements" begin
function standardbroadcastingtests(D, T)
M = [x for x in D]
Dsum = D .+ D
@test Dsum isa T
@test Dsum == M .+ M
Dcopy = copy.(D)
@test Dcopy isa T
@test Dcopy == D
Df = float.(D)
@test Df isa T
@test Df == D
@test eltype(eltype(Df)) <: AbstractFloat
@test (x -> (x,)).(D) == (x -> (x,)).(M)
@test (x -> 1).(D) == ones(Int,size(D))
@test all(==(2), ndims.(D))
@test_throws MethodError size.(D)
end
@testset "Diagonal" begin
@testset "square" begin
A = [1 3; 2 4]
D = Diagonal([A, A])
standardbroadcastingtests(D, Diagonal)
@test sincos.(D) == sincos.(Matrix{eltype(D)}(D))
M = [x for x in D]
@test cos.(D) == cos.(M)
end

@testset "different-sized square blocks" begin
D = Diagonal([ones(3,3), fill(3.0,2,2)])
standardbroadcastingtests(D, Diagonal)
end

@testset "rectangular blocks" begin
D = Diagonal([ones(Bool,3,4), ones(Bool,2,3)])
standardbroadcastingtests(D, Diagonal)
end

@testset "incompatible sizes" begin
A = reshape(1:12, 4, 3)
B = reshape(1:12, 3, 4)
D1 = Diagonal(fill(A, 2))
D2 = Diagonal(fill(B, 2))
@test_throws DimensionMismatch D1 .+ D2
end
end
@testset "Bidiagonal" begin
A = [1 3; 2 4]
B = Bidiagonal(fill(A,3), fill(A,2), :U)
standardbroadcastingtests(B, Bidiagonal)
end
@testset "UpperTriangular" begin
A = [1 3; 2 4]
U = UpperTriangular([(i+j)*A for i in 1:3, j in 1:3])
standardbroadcastingtests(U, UpperTriangular)
end
end

end

0 comments on commit 243ebc3

Please sign in to comment.