Skip to content
Permalink
Browse files

Basic broadcasting has been implemented.

  • Loading branch information...
chriselrod committed May 26, 2019
1 parent 9539ca5 commit 74a216a6a919811db72defd4f42c459966173cea
Showing with 113 additions and 32 deletions.
  1. +1 −1 src/PaddedMatrices.jl
  2. +112 −31 src/broadcast.jl
@@ -121,6 +121,6 @@ include("linear_algebra.jl")
include("rand.jl")
include("utilities.jl")
include("seed_increments.jl")
#include("broadcast.jl")
include("broadcast.jl")

end # module
@@ -126,7 +126,7 @@ function LinearAlgebra.:×(A::TA, B::TB) where {T,TA <: AbstractPaddedVector{T},
end
# A `VectorMatrixProduct` is reported as a single row with one entry per
# column of its `B` field.
@inline Base.size(A::VectorMatrixProduct) = (1, size(A.B, 2))
@inline Base.axes(A::VectorMatrixProduct) = (Base.OneTo(1), Base.OneTo(size(A.B, 2)))

struct MatrixVectorProduct{T,TA <: AbstractPaddedMatrix{T}, TB <: AbstractPaddedVector{T}} <: AbstractProdct{T,1}
A::TA
B::TB
@@ -201,18 +201,90 @@ end
KernelBatches # Same pattern as matrix mul functions
end

# Supertype for the broadcast styles of padded arrays.
# `S` carries the size information and `A` the access pattern. `A` is left
# unconstrained so that subtypes declared with a free `A` parameter (such as
# `FixedSizePaddedMatrixDefaultStyle{S,A}`) can name this type as their supertype.
abstract type AbstractPaddedMatrixStyle{S,A} <: Broadcast.BroadcastStyle end

# Broadcast style used for fixed-size padded arrays. The struct has no fields:
# the size information `S` and the access pattern `A` live entirely in the
# type parameters, so they are available to generated functions.
struct FixedSizePaddedMatrixDefaultStyle{S,A} <: AbstractPaddedMatrixStyle{S,A} end

# Style of a single fixed-size padded array: its own size parameter `S`
# paired with the `LinearIndexing` access pattern. Combining it with other
# arguments may replace the pattern (see the two-argument `BroadcastStyle`
# methods and `combine_styles`).
Base.BroadcastStyle(::Type{<:AbstractFixedSizePaddedArray{S}}) where {S} = FixedSizePaddedMatrixDefaultStyle{S,LinearIndexing}()

# Resolve a pair of fixed-size padded styles by calling the two-argument
# `BroadcastStyle` method directly on `(s1, s2)`. The order is kept as given
# because the combined style records the size parameters positionally
# (`Tuple{S1,S2}`).
@inline function Base.Broadcast.result_style(
    s1::FixedSizePaddedMatrixDefaultStyle, s2::FixedSizePaddedMatrixDefaultStyle
)
    # s1, s2 is always the canonical order.
    return Base.Broadcast.BroadcastStyle(s1, s2)
end

# Fold the styles of all broadcast arguments into a single style.
#
# The resulting size parameter is `Tuple{S₁,…,S_N}` (one entry per argument, in
# order). The access pattern is chosen as follows:
#   * if the largest pattern among the arguments is `CartesianIndexing` or
#     `KernelBatches`, that pattern is used as is;
#   * if every argument is `LinearIndexing` and all non-scalar arguments
#     (size parameter other than `Tuple{}`) share the same size type,
#     `LinearIndexing` is kept;
#   * otherwise `LinearIndexing` is used when `reduce_size` yields a single
#     dimension and `CartesianIndexing` when it yields more.
@generated function Base.Broadcast.combine_styles(s::Vararg{FixedSizePaddedMatrixDefaultStyle,N}) where {N}
    # Inside the generator `s` holds the argument *types*, so the style
    # parameters are read from `parameters`.
    sizes = DataType[]
    pattern = LinearIndexing
    for k ∈ 1:N
        params = s[k].parameters
        push!(sizes, params[1])
        pattern = max(pattern, params[2])
    end
    S = Tuple{sizes...}
    if pattern == CartesianIndexing || pattern == KernelBatches
        return FixedSizePaddedMatrixDefaultStyle{S,pattern}()
    end
    sz = reduce_size(S)
    if pattern == LinearIndexing
        # Walk the non-scalar size types and compare each with its predecessor.
        # A plain loop is used because generated function bodies may not
        # contain closures or generators.
        previous = nothing
        same = true
        for Sₖ ∈ sizes
            Sₖ === Tuple{} && continue
            same &= (previous === nothing) || (Sₖ === previous)
            previous = Sₖ
        end
        same && return FixedSizePaddedMatrixDefaultStyle{S,LinearIndexing}()
    end
    chosen = length(sz) == 1 ? LinearIndexing : CartesianIndexing
    return FixedSizePaddedMatrixDefaultStyle{S,chosen}()
end

# Combine the styles of two fixed-size padded broadcast arguments.
#
# The result always records both size parameters as `Tuple{S1,S2}`; the work
# done here is choosing the access pattern. Because this is a generated
# function, the decision is made from the type parameters alone and the
# returned style instance is spliced in as a constant.
# NOTE(review): `reduce_size` is defined elsewhere; the code below assumes it
# returns an indexable collection of dimension sizes that supports `@view`
# and elementwise `.==` — confirm.
@generated function Base.Broadcast.BroadcastStyle(style1::FixedSizePaddedMatrixDefaultStyle{S1,A1},
style2::FixedSizePaddedMatrixDefaultStyle{S2,A2}) where {S1,S2,A1,A2}

# `KernelBatches` on either side wins outright; failing that, `CartesianIndexing` does.
if (A1 == KernelBatches) || (A2 == KernelBatches)
return FixedSizePaddedMatrixDefaultStyle{Tuple{S1,S2},KernelBatches}()
elseif (A1 == CartesianIndexing) || (A2 == CartesianIndexing)
return FixedSizePaddedMatrixDefaultStyle{Tuple{S1,S2},CartesianIndexing}()
end
sa1 = reduce_size(S1)
sa2 = reduce_size(S2)
l1 = length(sa1)
l2 = length(sa2)
# `BatchedColumnMajor` is only kept for at most two dimensions; with more,
# fall back to `CartesianIndexing`.
if ((A1 == BatchedColumnMajor) || (A2 == BatchedColumnMajor)) && (max(l1,l2)>2)
return FixedSizePaddedMatrixDefaultStyle{Tuple{S1,S2},CartesianIndexing}()
end
# Compare the two sizes dimension by dimension over the length of the shorter one.
equal_lengths = l1 == l2
if equal_lengths
equal_dims = sa1 .== sa2
elseif l1 < l2
equal_dims = sa1 .== @view(sa2[1:l1])
else
equal_dims = @view(sa1[1:l2]) .== sa2
end
# Same number of dimensions and every dimension equal: keep the larger of the two patterns.
if all(equal_dims) && equal_lengths
return FixedSizePaddedMatrixDefaultStyle{Tuple{S1,S2},max(A1,A2)}()
# At most two dimensions and matching first dimension: `BatchedColumnMajor`.
elseif (max(length(sa1), length(sa2)) == 2) && equal_dims[1]
return FixedSizePaddedMatrixDefaultStyle{Tuple{S1,S2},BatchedColumnMajor}()
else
# Anything else is handled with `CartesianIndexing`.
return FixedSizePaddedMatrixDefaultStyle{Tuple{S1,S2},CartesianIndexing}()
end
end

# Padded array on the left, zero-dimensional (scalar-like) argument on the
# right: the access pattern `A` is kept and the scalar is recorded as an
# empty size tuple in second position.
function Base.Broadcast.BroadcastStyle(
    style1::FixedSizePaddedMatrixDefaultStyle{S,A},
    style2::Base.Broadcast.DefaultArrayStyle{0}
) where {S,A}
    combined = Tuple{S,Tuple{}}
    return FixedSizePaddedMatrixDefaultStyle{combined,A}()
end
# Zero-dimensional (scalar-like) argument on the left, padded array on the
# right: the access pattern `A` is kept and the scalar is recorded as an
# empty size tuple in first position.
function Base.Broadcast.BroadcastStyle(
    style1::Base.Broadcast.DefaultArrayStyle{0},
    style2::FixedSizePaddedMatrixDefaultStyle{S,A}
) where {S,A}
    combined = Tuple{Tuple{},S}
    return FixedSizePaddedMatrixDefaultStyle{combined,A}()
end

# @generated function Base.BroadcastStyle(style1::AbstractPaddedMatrixStyle{S1,A1,C1}, style2::AbstractPaddedMatrixStyle{S2,A2,C2}) where {S1,A1,C1,S2,A2,C2}
# A3 = max(A1, A2)
@@ -246,28 +318,36 @@ end


# end
Base.BroadcastStyle(::Base.Broadcast.DefaultArrayStyle{0}, style::AbstractPaddedMatrixStyle) = style
# Base.BroadcastStyle(::Base.Broadcast.DefaultArrayStyle{0}, style::AbstractPaddedMatrixStyle) = style


# Allocate an uninitialised destination array for a broadcast with element
# type `T`.
#
# The destination size is obtained by reducing the combined size parameter `S`
# with `reduce_size`; the remaining array type parameters come from `calc_NPL`.
# NOTE(review): `N`, `R`, `L` are presumably the dimension count, padded
# leading dimension and total element count — confirm against `calc_NPL`.
@generated function Base.similar(
    bc::Base.Broadcast.Broadcasted{FixedSizePaddedMatrixDefaultStyle{S,A}}, ::Type{T}
) where {S,A,T}
    Salloc = reduce_size(S)
    N, R, L = calc_NPL(Salloc, T)
    ST = Tuple{Salloc...}
    return :(MutableFixedSizePaddedArray{$ST,$T,$N,$R,$L}(undef))
end
# Create the broadcast destination on top of caller-provided memory instead
# of allocating.
#
# Returns a tuple `(array, next)`: a `PtrArray` wrapping `mystack`, and the
# pointer advanced by `L * sizeof(T)` bytes, where `L` is the third value
# returned by `calc_NPL` for the reduced size.
@generated function Base.similar(
    bc::Base.Broadcast.Broadcasted{FixedSizePaddedMatrixDefaultStyle{S,A}}, mystack::Ptr{T}
) where {S,A,T}
    Salloc = reduce_size(S)
    N, R, L = calc_NPL(Salloc, T)
    ST = Tuple{Salloc...}
    # allocates from mystack, and advances that pointer.
    return :(PtrArray{$ST,$T,$N,$R,$L,true}(mystack), mystack + $(L * sizeof(T)))
end
# Materialize a broadcast into memory taken from `mystack`.
#
# Returns `(out, next)`, where `out` is the filled destination array and
# `next` is the stack pointer advanced past it (both as produced by
# `similar(bc, mystack)`).
@inline function Base.Broadcast.materialize(
    bc::Base.Broadcast.Broadcasted{FixedSizePaddedMatrixDefaultStyle{S,A}},
    mystack::Ptr{T}) where {S,A,T}
    out, mystack = similar(bc, mystack)
    # `materialize!` lives in the `Base.Broadcast` module; `Base.broadcast`
    # (lower case) is the `broadcast` function and has no such property.
    Base.Broadcast.materialize!(out, bc)
    return out, mystack
end
# Materialize a broadcast into a freshly allocated array.
#
# The element type is inferred from the broadcast function and its arguments,
# a destination of that element type is created with `similar`, filled in
# place, and returned.
@inline function Base.Broadcast.materialize(
    bc::Base.Broadcast.Broadcasted{FixedSizePaddedMatrixDefaultStyle{S,A}}
) where {S,A}
    ElType = Base.Broadcast.combine_eltypes(bc.f, bc.args)
    out = similar(bc, ElType)
    Base.Broadcast.materialize!(out, bc)
    return out
end

# Unwrap a `Ref`-protected broadcast argument, returning the value it holds.
@inline function extract(x::Base.RefValue)
    return getindex(x)
end
@@ -279,7 +359,7 @@ function broadcast_index_expression(SB, N)
q = quote end
preq = quote end
assign_to = gensym(:assign)
broadcast_index_expression!(q, preq, (SB.parameters)::DataType, inds, :bc, assign_to)
broadcast_index_expression!(q, preq, SB.parameters, inds, :bc, assign_to)
push!(q.args, Expr(:call, :setindex!, :out, assign_to, inds...))
inds, q, preq
end
@@ -289,11 +369,11 @@ function broadcast_index_expression!(q, preq, SBV, inds, bcsym, assign)
SBVₗ = ((SBV[l])::DataType).parameters
argsym = gensym(:arg)
if length(SBVₗ) == 0 # scalar argument
push!(preq.args, :($argsym = @inbounds PaddedMatrices.extract($bcsym.args[1])))
push!(preq.args, :($argsym = @inbounds PaddedMatrices.extract($bcsym.args[$l])))
push!(callexpr.args, argsym)
continue
else
push!(preq.args, :($argsym = @inbounds $bcsym.args[1]))
push!(preq.args, :($argsym = @inbounds $bcsym.args[$l]))
end
SBVₗ₁ = (SBVₗ[1])::Union{Int,DataType}
if SBVₗ₁ isa Int # array argument
@@ -321,7 +401,7 @@ function broadcast_linearindex_expression(SB, N)
q = quote end
preq = quote end
assign_to = gensym(:assign)
broadcast_linearindex_expression!(q, preq, (SB.parameters)::DataType, ind, :bc, assign_to)
broadcast_linearindex_expression!(q, preq, SB.parameters, ind, :bc, assign_to)
push!(q.args, Expr(:call, :setindex!, :out, assign_to, ind))
ind, q, preq
end
@@ -331,11 +411,11 @@ function broadcast_linearindex_expression!(q, preq, SBV, ind, bcsym, assign)
SBVₗ = ((SBV[l])::DataType).parameters
argsym = gensym(:arg)
if length(SBVₗ) == 0 # scalar argument
push!(preq.args, :($argsym = @inbounds PaddedMatrices.extract($bcsym.args[1])))
push!(preq.args, :($argsym = @inbounds PaddedMatrices.extract($bcsym.args[$l])))
push!(callexpr.args, argsym)
continue
else
push!(preq.args, :($argsym = @inbounds $bcsym.args[1]))
push!(preq.args, :($argsym = @inbounds $bcsym.args[$l]))
end
SBVₗ₁ = (SBVₗ[1])::Union{Int,DataType}
if SBVₗ₁ isa Int # array argument
@@ -350,10 +430,11 @@ function broadcast_linearindex_expression!(q, preq, SBV, ind, bcsym, assign)
nothing
end

function materialize_quote(S, A, C, T, SB, N, P)
function materialize_quote(S, A, T, SB, N, P)
if A == LinearIndexing
assigned_to, ind, loop_body, pre_loop = broadcast_index_expression(SB, N)
ind, loop_body, pre_loop = broadcast_index_expression(SB, N)
return quote
$(Expr(:meta,:inline))
$pre_loop
LoopVectorization.@vectorize $T for $ind ∈ 1:$(prod(S))
$loop_body
@@ -367,7 +448,7 @@ function materialize_quote(S, A, C, T, SB, N, P)
# and make sure those that shouldn't be reloaded are in fact not.
# Ie, which arrays are column vectors vs matrices,
# and are special functions like exp getting called, which will consume a lot of registers?
assigned_to, inds, loop_body, pre_loop = broadcast_index_expression(SB, N)
inds, loop_body, pre_loop = broadcast_index_expression(SB, N)
# M = S[1]
# L = S[2]
# Md, Mr = divrem(M,
@@ -384,11 +465,12 @@ function materialize_quote(S, A, C, T, SB, N, P)
end
end
return quote
$(Expr(:meta,:inline))
$pre_loop
$loop
end
elseif A == CartesianIndexing
assigned_to, inds, loop_body, pre_loop = broadcast_index_expression(SB, N)
inds, loop_body, pre_loop = broadcast_index_expression(SB, N)
loop = quote
LoopVectorization.@vectorize $T for $(inds[1]) ∈ 1:$(S[1])
$loop_body
@@ -402,27 +484,24 @@ function materialize_quote(S, A, C, T, SB, N, P)
end
end
return quote
$(Expr(:meta,:inline))
$pre_loop
$loop
end
elseif A == KernelBatches

throw("KernelBatches not yet implemented.")
end


end


# In-place materialization of a broadcast into a mutable fixed-size padded
# array.
#
# The loop code is produced at compile time by `materialize_quote` from the
# destination's size parameters (`S.parameters`), the access pattern `A`, the
# element type `T`, the combined argument sizes `SB`, and the destination's
# `N` and `P` parameters.
@generated function Base.Broadcast.materialize!(
    out::AbstractMutableFixedSizePaddedArray{S,T,N,P},
    bc::Base.Broadcast.Broadcasted{FixedSizePaddedMatrixDefaultStyle{SB,A}}
) where {S,A,T,SB,N,P}
    return materialize_quote(S.parameters, A, T, SB, N, P)
end



#=
function batched_column_major_quote(S, A, C, T)
end
@@ -443,7 +522,8 @@ function kernel_batch_quote(S, A, C, T)
end
end
end

=#
#=
@generated function Base.copyto!(
dest::AbstractMutableFixedSizePaddedArray{S1,T},
bc::Base.Broadcast.Broadcasted{PaddedMatrixStyle{S2,A,C}}
@@ -461,3 +541,4 @@ end
return kernel_bactch_quote(SV, A, C, T)
end
end
=#

0 comments on commit 74a216a

Please sign in to comment.
You can’t perform that action at this time.