In [1]:
using BenchmarkTools, StructArrays, LinearAlgebra, Profile, PProf, Zygote, Zygote.ChainRulesCore

In [2]:
struct Variable{D}
    d :: D    
end
mean(v::Variable) = mean(v.d)
var(v::Variable) = var(v.d)
mean_var(x) = mean(x), var(x)
struct Normal{T}
    m :: T
    v :: T
end
mean(n::Normal) = n.m
var(n::Normal) = n.v
Base.zero(::Variable{<:Normal{T}}) where { T } = Variable(Normal(zero(T), zero(T)))
Base.zero(::Type{Variable{T}}) where { T } = Variable(zero(T))
Base.zero(::Type{Normal{T}}) where { T } = Normal(zero(T), zero(T))
function Base.:*(n1::Variable{<:Normal}, n2::Variable{<:Normal})
    m1, v1 = mean_var(n1)
    m2, v2 = mean_var(n2)

    return Variable(Normal(m1 * m2, v1 * v2 + v2 * m1 * m1 + v1 * m2 * m2))
end
function Base.:+(n1::Variable{<:Normal}, n2::Variable{<:Normal})
    m1, v1 = mean_var(n1)
    m2, v2 = mean_var(n2)

    return Variable(Normal(m1 + m2, v1 + v2))
end
LinearAlgebra.dot(n1::Variable{<:Normal}, n2::Variable{<:Normal}) = n1 * n2
;

In [3]:
dimin = 1000
dimout = 500
T = Float32

W = Variable.(Normal.(randn(T,dimout,dimin), abs.(randn(T,dimout,dimin))))
b = Variable.(Normal.(randn(T,dimout), abs.(randn(T,dimout))))
x = Variable.(Normal.(randn(T,dimin), abs.(randn(T,dimin))))

W_sa = StructArray(W; unwrap = t -> t <: Normal)
b_sa = StructArray(b; unwrap = t -> t <: Normal)
x_sa = StructArray(x; unwrap = t -> t <: Normal)
;

In [4]:
function Base.:*(A::Matrix{T1}, x::Vector{T2}) where { T1 <: Variable, T2 <: Variable }
    y = zeros(T1, size(A, 1))
    for i = 1:size(A, 1), k = 1:size(A, 2)
        y[i] += A[i,k] * x[k]
    end
    return y
end

In [5]:
# function ChainRulesCore.rrule(::typeof(*), A::Matrix{T1}, x::Vector{T2}) where { T1 <: Variable, T2 <: Variable }
#     A * x, Δ -> (NoTangent(), NoTangent(), NoTangent())
# end

In [6]:
@benchmark $W * $x + $b

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m363.000 μs[22m[39m … [35m635.042 μs[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 0.00%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m370.979 μs               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m381.257 μs[22m[39m ± [32m 20.072 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m0.00% ± 0.00%

  [39m▆[39m█[39m▆[39m▆[39m▅[39m▃[34m▃[39m[39m▃[39m▂[39m▂[39m▂[39m▂[39m▂[39m▂[39m▃[32m▂[39m[39m▂[39m▂[39m▁[39m▁[39m▂[39m▃[39m▅[39m▅[39m▄[39m▄[39m▃[39m▃[39m▂[39m▁[39m▁[39m▁[39m▁[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▁[39m▁[39m▁[39m▂[39m▂[39m▁[39m▁[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▂
  [39m█[39m█[39m█[3

In [7]:
@benchmark $W_sa * $x_sa + $b_sa

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m82.417 μs[22m[39m … [35m118.792 μs[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 0.00%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m89.708 μs               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m90.497 μs[22m[39m ± [32m  3.086 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m0.00% ± 0.00%

  [39m [39m▁[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▆[39m█[39m▇[39m█[34m▇[39m[39m▇[32m▆[39m[39m▄[39m▃[39m▄[39m▃[39m▄[39m▃[39m▂[39m▂[39m▂[39m▁[39m▁[39m▁[39m [39m [39m [39m▁[39m▁[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▃
  [39m▇[39m█[39m█[39m▇[39m▅

In [8]:
abstract type Layer end
struct Dense{T1, T2} <: Layer
    W :: T1
    b :: T2
end
(l::Dense)(x) = l.W * x + l.b

In [9]:
layer = Dense(W, b)
layer_sa = Dense(W_sa, b_sa);

In [10]:
@benchmark layer($x)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m363.458 μs[22m[39m … [35m649.375 μs[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 0.00%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m393.667 μs               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m399.091 μs[22m[39m ± [32m 22.707 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m0.00% ± 0.00%

  [39m▆[39m▅[39m▄[39m▄[39m▂[39m▁[39m▂[39m▁[39m▂[39m▂[39m▂[39m▂[39m▂[39m▁[39m▁[39m▁[39m▁[39m▁[39m▅[39m█[39m▅[34m▆[39m[39m▄[39m▃[39m▃[32m▂[39m[39m▂[39m▂[39m▁[39m▁[39m▁[39m [39m [39m [39m [39m▁[39m▂[39m▃[39m▄[39m▅[39m▆[39m▆[39m▅[39m▄[39m▃[39m▂[39m▂[39m▂[39m▁[39m▁[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▂
  [39m█[39m█[39m█[3

In [11]:
@benchmark layer_sa($x_sa)

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m82.750 μs[22m[39m … [35m118.416 μs[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 0.00%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m89.792 μs               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m90.912 μs[22m[39m ± [32m  3.536 μs[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m0.00% ± 0.00%

  [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▆[39m█[39m█[39m█[34m▇[39m[39m▇[32m▅[39m[39m▄[39m▄[39m▃[39m▄[39m▃[39m▂[39m▂[39m▁[39m▂[39m▁[39m▁[39m [39m [39m▁[39m▁[39m▁[39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m▁[39m▁[39m [39m [39m [39m▁[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m▃
  [39m▇[39m▅[39m▃[39m▁[39m▁

In [12]:
dlayer = gradient(l -> sum(mean.(l(x))), layer)

ErrorException: Mutating arrays is not supported -- called setindex!(Vector{Variable{Normal{Float32}}}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations


In [16]:
dlayer_sa = gradient(l -> sum(mean.(l(x_sa))), layer_sa)

ErrorException: Mutating arrays is not supported -- called setindex!(StructVector{Variable{Normal{Float32}}, @NamedTuple{d::StructVector{Normal{Float32}, @NamedTuple{m::Vector{Float32}, v::Vector{Float32}}, Int64}}, Int64}, ...)
This error occurs when you ask Zygote to differentiate operations that change
the elements of arrays in place (e.g. setting values with x .= ...)

Possible fixes:
- avoid mutating operations (preferred)
- or read the documentation and solutions for this error
  https://fluxml.ai/Zygote.jl/latest/limitations
