In [1]:
using FunctionOperators, LinearAlgebra, Test, BenchmarkTools

┌ Info: Recompiling stale cache file /home/hakkelt/.julia/compiled/v1.2/FunctionOperators/2tgjG.ji for FunctionOperators [98ce4118-165c-488a-9b71-2bb5aff4e594]
└ @ Base loading.jl:1240


## Test all constructors

In [2]:
# Exercises every `FunctionOperator{T}` constructor form (keyword and positional,
# with and without `name` / `backw`) and the argument validation they perform.
@testset "Constructors - StructDefs.jl" begin
    # Well-formed calls: each variant must produce a `FunOp`.
    @testset "Proper constructors" begin
        @testset "Keyword constructors" begin
            @test FunctionOperator{Float64}(name = "Op₁",
                forw = x -> x, backw = x -> x,
                inDims = (1,), outDims = (1,)) isa FunOp
            @test FunctionOperator{Float64}(
                forw = x -> x, backw = x -> x,
                inDims = (1,), outDims = (1,)) isa FunOp
            @test FunctionOperator{Float64}(name = "Op₁",
                forw = x -> x,
                inDims = (1,), outDims = (1,)) isa FunOp
            @test FunctionOperator{Float64}(forw = x -> x,
                inDims = (1,), outDims = (1,)) isa FunOp
        end
        @testset "Positional constructors" begin
            @test FunctionOperator{Float64}("Op₁", x -> x, x -> x,
                (1,), (1,)) isa FunOp
            @test FunctionOperator{Float64}(x -> x, x -> x,
                (1,), (1,)) isa FunOp
            @test FunctionOperator{Float64}("Op₁", x -> x,
                (1,), (1,)) isa FunOp
            @test FunctionOperator{Float64}(x -> x,
                (1,), (1,)) isa FunOp
        end
    end
    # Omitting `forw`, `inDims` or `outDims` (in that order below) must raise an
    # ErrorException.
    @testset "Missing value" begin
        @test_throws ErrorException FunctionOperator{Float64}(name = "Op₁",
            backw = x -> x,
            inDims = (1,), outDims = (1,))
        @test_throws ErrorException FunctionOperator{Float64}(name = "Op₁",
            forw = x -> x, backw = x -> x,
            outDims = (1,))
        @test_throws ErrorException FunctionOperator{Float64}(name = "Op₁",
            forw = x -> x, backw = x -> x,
            inDims = (1,))
    end
    # A zero-argument `forw` must be rejected with an AssertionError. The lambda
    # bodies mention an undefined `x`; they are not expected to be invoked, so
    # only the number of arguments matters here.
    @testset "No arguments for forw" begin
        @testset "Keyword constructors" begin
            @test_throws AssertionError FunctionOperator{Float64}(name = "Op₁",
                forw = () -> x, backw = x -> x,
                inDims = (1,), outDims = (1,))
            @test_throws AssertionError FunctionOperator{Float64}(
                forw = () -> x, backw = x -> x,
                inDims = (1,), outDims = (1,))
            @test_throws AssertionError FunctionOperator{Float64}(name = "Op₁",
                forw = () -> x,
                inDims = (1,), outDims = (1,))
            @test_throws AssertionError FunctionOperator{Float64}(forw = () -> x,
                inDims = (1,), outDims = (1,))
        end
        @testset "Positional constructors" begin
            @test_throws AssertionError FunctionOperator{Float64}("Op₁", () -> x, x -> x,
                (1,), (1,))
            @test_throws AssertionError FunctionOperator{Float64}(() -> x, x -> x,
                (1,), (1,))
            @test_throws AssertionError FunctionOperator{Float64}("Op₁", () -> x,
                (1,), (1,))
            @test_throws AssertionError FunctionOperator{Float64}(() -> x,
                (1,), (1,))
        end
    end
    # A three-argument `forw` must likewise be rejected with an AssertionError.
    @testset "Too many arguments for forw" begin
        @testset "Keyword constructors" begin
            @test_throws AssertionError FunctionOperator{Float64}(name = "Op₁",
                forw = (x,y,z) -> x, backw = x -> x,
                inDims = (1,), outDims = (1,))
            @test_throws AssertionError FunctionOperator{Float64}(
                forw = (x,y,z) -> x, backw = x -> x,
                inDims = (1,), outDims = (1,))
            @test_throws AssertionError FunctionOperator{Float64}(name = "Op₁",
                forw = (x,y,z) -> x,
                inDims = (1,), outDims = (1,))
            @test_throws AssertionError FunctionOperator{Float64}(forw = (x,y,z) -> x,
                inDims = (1,), outDims = (1,))
        end
        @testset "Positional constructors" begin
            @test_throws AssertionError FunctionOperator{Float64}("Op₁", (x,y,z) -> x,
                x -> x, (1,), (1,))
            @test_throws AssertionError FunctionOperator{Float64}((x,y,z) -> x, x -> x,
                (1,), (1,))
            @test_throws AssertionError FunctionOperator{Float64}("Op₁", (x,y,z) -> x,
                (1,), (1,))
            @test_throws AssertionError FunctionOperator{Float64}((x,y,z) -> x,
                (1,), (1,))
        end
    end
    # Pairing a two-argument `forw` with a one-argument `backw` must be rejected.
    @testset "Forw has more arguments as backw" begin
        @testset "Keyword constructors" begin
            @test_throws AssertionError FunctionOperator{Float64}(name = "Op₁",
                forw = (x,y) -> x.^3, backw = x -> x,
                inDims = (1,), outDims = (1,))
            @test_throws AssertionError FunctionOperator{Float64}(
                forw = (x,y) -> x.^3, backw = x -> x,
                inDims = (1,), outDims = (1,))
        end
        @testset "Positional constructors" begin
            @test_throws AssertionError FunctionOperator{Float64}("Op₁", (x,y) -> x.^3,
                x -> x, (1,), (1,))
            @test_throws AssertionError FunctionOperator{Float64}((x,y) -> x.^3, x -> x,
                (1,), (1,))
        end
    end
    # Two unnamed operators built from identical arguments must not compare equal
    # (presumably because each receives a distinct default name, as the title says).
    @testset "Unique default name" begin
        @test FunctionOperator{Float64}(x->x,(1,),(1,)) ≠
            FunctionOperator{Float64}(x->x,(1,),(1,))
    end
    # Applying the adjoint of an operator built without `backw` must raise.
    @testset "Undefined backw" begin
        Op₁ = FunctionOperator{Float64}(x->x,(1,),(1,))
        @test_throws ErrorException Op₁' * [1.]
    end
end;

[37m[1mTest Summary:                | [22m[39m[32m[1mPass  [22m[39m[36m[1mTotal[22m[39m
Constructors - StructDefs.jl | [32m  33  [39m[36m   33[39m


In [3]:
# Tests for the shared helper checks: eltype reporting, and the TypeError /
# DimensionMismatch assertions raised when operators, operands or output buffers
# do not fit together.
@testset "Helpers.jl" begin
    Op₁ = FunctionOperator{Float64}(x -> x, (1,), (1,))
    Op₂ = FunctionOperator{Int64}(x -> x, (1,), (1,))    # eltype differs from Op₁
    Op₃ = FunctionOperator{Float64}(x -> x, (1,), (1,))  # fully compatible with Op₁
    Op₄ = FunctionOperator{Float64}(x -> x, (1,), (2,))  # outDims ≠ inDims
    buffer = Array{Float64}(undef, 1)
    @testset "eltype" begin
        @test eltype(Op₁) == Float64
        @test eltype(Op₁ * Op₃) == Float64
    end
    @testset "assertions" begin
        @testset "TypeError" begin
            # Combining operators (plain or composite) whose eltypes differ.
            @test_throws TypeError Op₁ * Op₂
            @test_throws TypeError Op₁ + Op₂
            @test_throws TypeError Op₁ - Op₂
            @test_throws TypeError (Op₁ + Op₁) * Op₂
            @test_throws TypeError (Op₁ + Op₁) + Op₂
            @test_throws TypeError (Op₁ + Op₁) - Op₂
            @test_throws TypeError Op₁ * (Op₂ + Op₂)
            @test_throws TypeError Op₁ + (Op₂ + Op₂)
            @test_throws TypeError Op₁ - (Op₂ + Op₂)
            @test_throws TypeError (Op₁ + Op₁) * (Op₂ + Op₂)
            @test_throws TypeError (Op₁ + Op₁) + (Op₂ + Op₂)
            @test_throws TypeError (Op₁ + Op₁) - (Op₂ + Op₂)
            # mul! with an output buffer or an input of the wrong eltype.
            # The composite cases use the type-compatible Op₁ * Op₃. With
            # Op₁ * Op₂ the composition itself throws (see above) before mul! is
            # reached, so mul!'s own eltype check would never be exercised.
            @test_throws TypeError mul!(Array{Int64}(undef, 1), Op₁, [1.])
            @test_throws TypeError mul!(Array{Int64}(undef, 1), Op₁ * Op₃, [1.])
            @test_throws TypeError mul!(buffer, Op₁, [1])
            @test_throws TypeError mul!(buffer, Op₁ * Op₃, [1])
        end
        @testset "DimensionError" begin
            # Operand of the wrong size.
            @test_throws DimensionMismatch Op₁ * [1.,2.]
            @test_throws DimensionMismatch Op₁ * Op₃ * [1.,2.]
            @test_throws DimensionMismatch (Op₁ + Op₃) * [1.,2.]
            @test_throws DimensionMismatch (Op₁ - Op₃) * [1.,2.]
            # Operators (or scaled identities) whose dimensions do not line up.
            @test_throws DimensionMismatch Op₃ * Op₄
            @test_throws DimensionMismatch (Op₄ + Op₃)
            @test_throws DimensionMismatch (Op₄ - Op₃)
            @test_throws DimensionMismatch (Op₄ + I)
            @test_throws DimensionMismatch (Op₄ - I)
            @test_throws DimensionMismatch (Op₄ + 3I)
            @test_throws DimensionMismatch (Op₄ - 3I)
            # mul! with a wrongly sized input or output buffer.
            @test_throws DimensionMismatch mul!(buffer, Op₁, [1., 2.])
            @test_throws DimensionMismatch mul!(buffer, Op₁ * Op₃, [1.,2.])
            @test_throws DimensionMismatch mul!(Array{Float64}(undef, 2), Op₁, [1.])
            @test_throws DimensionMismatch mul!(Array{Float64}(undef, 2), Op₁ * Op₃, [1.])
        end
    end
end;

[37m[1mTest Summary: | [22m[39m[32m[1mPass  [22m[39m[36m[1mTotal[22m[39m
Helpers.jl    | [32m  33  [39m[36m   33[39m


In [4]:
@testset "normalizeExpression - Auxiliary.jl" begin
    @testset "normalizeExpression" begin
        # Each check feeds algebraically equivalent operator expressions to the
        # normaliser and expects all of them to collapse to one canonical string.
        canon = FunctionOperators.normalizeExpression

        # Adjoints of products are expanded with the factor order reversed.
        @test canon("A * (B * C * D)' + (A * (B' * C * D)')'") ==
              canon("(B' * C * D) * A' + A * (B * C * D)'") ==
              "B' * C * D * A' + A * D' * C' * B'"
        # Terms of nested sums/differences get a fixed order and spacing.
        @test canon("A * (A + ((B -C) -E) + (e - f))") ==
              canon("A * (((B -E) -C)+A + (e - f))") ==
              "A * ((e - f) + A + ((B - E) - C))"
        # A right factor is distributed over a parenthesised sum ...
        @test canon("q + (A + ((B -C) -E) + (e - f)) * W") ==
              canon("(A + ((B -C) -E) + (e - f)) * W + q") ==
              "q + (e * W - f * W) + A * W + ((B * W - E * W) - C * W)"
        # ... while a left factor stays outside the parentheses.
        @test canon("A * (A + ((B -C) -E) + (e - f)) * W") ==
              canon("A * (((B -E) -C) * W + A * W + (e - f) * W)") ==
              "A * ((e * W - f * W) + A * W + ((B * W - E * W) - C * W))"
        # Products in which one or more factors are sums.
        @test canon("(A + B) * (C + D)") == "B * (D + C) + A * (D + C)"
        @test canon("(A + B) * (C * D)") == canon("(A + B) * C * D") ==
              "B * C * D + A * C * D"
        @test canon("(a + b) * (A + B) * (C + D)") ==
              "b * (B * (D + C) + A * (D + C)) + a * (B * (D + C) + A * (D + C))"
    end
end;

[37m[1mTest Summary: | [22m[39m[32m[1mPass  [22m[39m[36m[1mTotal[22m[39m
Auxiliary.jl  | [32m   7  [39m[36m    7[39m


In [8]:
# End-to-end check of the computation-tree builder and planner.
# Every composite built from the operators below, applied with `*`, must give the
# same result as a reference evaluation that applies the component operators one
# at a time. Composites whose normalised expressions coincide must also agree
# with each other.
@testset "BuildCompTree.jl + getPlan.jl" begin
    # Elementwise operators on 10×10 arrays; allocating versions (one-argument forw/backw) ...
    Op₁ = FunctionOperator{Float64}("Op₁",
        x -> x .^ 3, x -> cbrt.(x), (10,10), (10,10))
    Op₂ = FunctionOperator{Float64}("Op₂",
        x -> x .+ 2, x -> x .- 2, (10,10), (10,10))
    # ... and in-place twins whose forw/backw write into the buffer `b`.
    bOp₁ = FunctionOperator{Float64}("bOp₁",
        (b,x) -> b.=x.^3, (b,x) -> b.=cbrt.(x), (10,10), (10,10))
    bOp₂ = FunctionOperator{Float64}("bOp₂",
        (b,x) -> b.=x.+2, (b,x) -> b.=x.-2, (10,10), (10,10))
    # Operators mapping 10×10 → 10×10×5; `w` (1×1×5) scales the slices by 1…5.
    w = reshape([1 2 3 4 5], 1, 1, 5)
    Op₃ = FunctionOperator{Float64}("Op₃",
        x -> x .* w, x -> x[:,:,1], (10,10), (10,10,5))
    Op₄ = FunctionOperator{Float64}("Op₄",
        x -> repeat(x, outer=(1,1,5)), x -> x[:,:,5], (10,10), (10, 10, 5))
    bOp₃ = FunctionOperator{Float64}("bOp₃",
        (b,x) -> broadcast!(*, b, reshape(x, 10, 10, 1), w), (b,x) -> b .= x[:,:,1], (10,10), (10,10,5))
    bOp₄ = FunctionOperator{Float64}("bOp₄",
        (b,x) -> b.=repeat(x, outer=(1,1,5)), (b,x) -> b .= x[:,:,5], (10,10), (10, 10, 5))
    # Elementwise operators on 10×10×5 arrays.
    Op₅ = FunctionOperator{Float64}("Op₅",
        x -> -x .^ 3, x -> -cbrt.(x), (10,10,5), (10,10,5))
    Op₆ = FunctionOperator{Float64}("Op₆",
        x -> x .+ 5, x -> x .- 5, (10,10,5), (10,10,5))
    bOp₅ = FunctionOperator{Float64}("bOp₅",
        (b,x) -> b.=-x.^3, (b,x) -> b.=-cbrt.(x), (10,10,5), (10,10,5))
    bOp₆ = FunctionOperator{Float64}("bOp₆",
        (b,x) -> b.=x.+5, (b,x) -> b.=x.-5, (10,10,5), (10,10,5))
    @testset "Fidelity (manually checked)" begin
        @test Op₁ * (ones(10,10)*2) == ones(10,10)*8
        @test Op₁' * (ones(10,10)*8) == ones(10,10)*2
        @test Op₃ * Op₁ * (ones(10,10)*2) == ones(10,10)*8 .* w
        @test (Op₃ * Op₁)' * (ones(10,10)*8 .* w) == Op₁' * Op₃' * (ones(10,10)*8 .* w)
        @test (Op₃ * Op₁)' * (ones(10,10)*8 .* w) == ones(10,10)*2
        @test (Op₃ + Op₄) * (ones(10,10)*2) == Op₃ * (ones(10,10)*2) + Op₄ * (ones(10,10)*2)
        @test (Op₃ - Op₄) * (ones(10,10)*2) == Op₃ * (ones(10,10)*2) - Op₄ * (ones(10,10)*2)
        @test 3I * Op₃ * (ones(10,10)*2) == (ones(10,10)*2 .* w) * 3
        @test Op₃ * 3I * (ones(10,10)*2) == ones(10,10)*6 .* w
        @test (Op₃ * 3I)' * (ones(10,10)*2 .* w) == ones(10,10)*6
        @test (Op₅ + 3I) * (ones(10,10,5)*2) == Op₅ * (ones(10,10,5)*2) + (ones(10,10,5)*6)
        @test (Op₅ - 3I) * (ones(10,10,5)*2) == Op₅ * (ones(10,10,5)*2) - (ones(10,10,5)*6)
        @test (3I + Op₅) * (ones(10,10,5)*2) == (ones(10,10,5)*6) + Op₅ * (ones(10,10,5)*2)
        @test (3I - Op₅) * (ones(10,10,5)*2) == (ones(10,10,5)*6) - Op₅ * (ones(10,10,5)*2)
    end
    # Applying the adjoint of a sum/difference is expected to raise.
    @testset "Adjoint of addition/substraction" begin
        @test_throws ErrorException (Op₃ + Op₄)' * ones(10,10,5)
        @test_throws ErrorException (Op₃ - Op₄)' * ones(10,10,5)
    end
    @testset "Automated" begin
        @testset "Combine" begin
            # Every entry is a NamedTuple (op, forw, backw, hasAddOrSub).
            # `forw`/`backw` are reference implementations that apply the
            # component operators one at a time. Each combine* helper returns the
            # combined entry, or `missing` after asserting that the library
            # rejects the combination (mismatched dims, adjoint of a sum).
            function combineMul(item1, item2)
                if item1.op.inDims ≠ item2.op.outDims
                    @test_throws DimensionMismatch item1.op * item2.op
                    missing
                else
                    (op = item1.op * item2.op,
                    forw = x -> (x₁ = item2.op * x; item1.op * x₁),
                    backw = x -> (x₁ = item1.op' * x; item2.op' * x₁),
                    hasAddOrSub = item1.hasAddOrSub || item2.hasAddOrSub)
                end
            end
            function combineMulScalingRight(item1)
                (op = item1.op * 5I,
                forw = x -> (x₁ = 5 * x; item1.op * x₁),
                backw = x -> (x₁ = item1.op' * x; 5 * x₁),
                hasAddOrSub = item1.hasAddOrSub)
            end
            function combineMulScalingLeft(item1)
                (op = 4I * item1.op,
                forw = x -> (x₁ = item1.op * x; 4 * x₁),
                backw = x -> (x₁ = 4 * x; item1.op' * x₁),
                hasAddOrSub = item1.hasAddOrSub)
            end
            # Sums/differences never get an adjoint (combineAdjoint returns
            # `missing` for them), so their `backw` must never run.
            function combineAdd(item1, item2)
                if item1.op.inDims ≠ item2.op.inDims || item1.op.outDims ≠ item2.op.outDims
                    @test_throws DimensionMismatch item1.op + item2.op
                    missing
                else
                    (op = item1.op + item2.op,
                    forw = x -> (x₁ = item1.op * x; x₂ = item2.op * x; x₁ + x₂),
                    backw = x -> throw(AssertionError("This should not be invoked")),
                    hasAddOrSub = true)
                end
            end
            function combineAddScalingRight(item1)
                if item1.op.inDims ≠ item1.op.outDims
                    @test_throws DimensionMismatch item1.op + 6I
                    missing
                else
                    (op = item1.op + 6I,
                    forw = x -> (x₁ = item1.op * x; x₂ = 6 * x; x₁ + x₂),
                    backw = x -> throw(AssertionError("This should not be invoked")),
                    hasAddOrSub = true)
                end
            end
            function combineAddScalingLeft(item1)
                if item1.op.inDims ≠ item1.op.outDims
                    @test_throws DimensionMismatch 7I + item1.op
                    missing
                else
                    (op = 7I + item1.op,
                    forw = x -> (x₁ = 7 * x; x₂ = item1.op * x; x₁ + x₂),
                    backw = x -> throw(AssertionError("This should not be invoked")),
                    hasAddOrSub = true)
                end
            end
            function combineSub(item1, item2)
                if item1.op.inDims ≠ item2.op.inDims || item1.op.outDims ≠ item2.op.outDims
                    @test_throws DimensionMismatch item1.op - item2.op
                    missing
                else
                    (op = item1.op - item2.op,
                    forw = x -> (x₁ = item1.op * x; x₂ = item2.op * x; x₁ - x₂),
                    backw = x -> throw(AssertionError("This should not be invoked")),
                    hasAddOrSub = true)
                end
            end
            function combineSubScalingRight(item1)
                if item1.op.inDims ≠ item1.op.outDims
                    @test_throws DimensionMismatch item1.op - 3I
                    missing
                else
                    (op = item1.op - 3I,
                    forw = x -> (x₁ = item1.op * x; x₂ = 3 * x; x₁ - x₂),
                    backw = x -> throw(AssertionError("This should not be invoked")),
                    hasAddOrSub = true)
                end
            end
            function combineSubScalingLeft(item1)
                if item1.op.inDims ≠ item1.op.outDims
                    @test_throws DimensionMismatch 11I - item1.op
                    missing
                else
                    (op = 11I - item1.op,
                    forw = x -> (x₁ = 11 * x; x₂ = item1.op * x; x₁ - x₂),
                    backw = x -> throw(AssertionError("This should not be invoked")),
                    hasAddOrSub = true)
                end
            end
            # The adjoint swaps the reference forw/backw.
            function combineAdjoint(item1)
                if item1.hasAddOrSub
                    @test_throws ErrorException item1.op'
                    missing
                else
                    (op = item1.op',
                    forw = item1.backw,
                    backw = item1.forw,
                    hasAddOrSub = false)
                end
            end
            function allTypeOfCombinationOfOne(item1)
                [combineMulScalingRight(item1), combineMulScalingLeft(item1),
                 combineAddScalingRight(item1), combineAddScalingLeft(item1),
                 combineSubScalingRight(item1), combineSubScalingLeft(item1),
                 combineAdjoint(item1)]
            end
            function allTypeOfCombinationOfTwo(item1, item2)
                [combineMul(item1, item2), combineAdd(item1, item2), combineSub(item1, item2)]
            end
            # One round of every unary and pairwise combination. The original
            # entries are kept first, and rejected (`missing`) combinations are dropped.
            function Cartesian_product_with_itself(list)
                new_list1 = [allTypeOfCombinationOfOne(item1) for item1 in list]
                new_list2 = [allTypeOfCombinationOfTwo(item1, item2)
                    for item1 in list, item2 in list]
                collect(skipmissing(vcat(list, new_list1..., new_list2...)))
            end
            Ops = [Op₁, bOp₁, Op₃, bOp₃, Op₅, bOp₅]
            global list
            list = [(op = op, forw = x -> op * x, backw = x -> op' * x, hasAddOrSub = false)
                for op in Ops]
            list = Cartesian_product_with_itself(list)
            # Repeating the Cartesian step would be ideal, but it would never finish in practice:
            #list = Cartesian_product_with_itself(list)
            # Instead, add targeted combinations. `list` still starts with the
            # original operators, so list[1:4] are Op₁, bOp₁, Op₃, bOp₃. Products
            # with them are added in both orders, as are combinations with the
            # sums spec₁ = Op₁ + bOp₁ and spec₂ = bOp₁ + Op₁. Finally the adjoint
            # of every surviving entry is added.
            function combine_special_cases(list)
                new_list = []
                push!(new_list, [combineMul(item1, list[1]) for item1 in list]...)
                push!(new_list, [combineMul(list[1], item1) for item1 in list]...)
                push!(new_list, [combineMul(item1, list[2]) for item1 in list]...)
                push!(new_list, [combineMul(list[2], item1) for item1 in list]...)
                push!(new_list, [combineMul(item1, list[3]) for item1 in list]...)
                push!(new_list, [combineMul(list[3], item1) for item1 in list]...)
                push!(new_list, [combineMul(item1, list[4]) for item1 in list]...)
                push!(new_list, [combineMul(list[4], item1) for item1 in list]...)
                spec₁ = combineAdd(list[1], list[2])
                push!(new_list, [combineMul(item1, spec₁) for item1 in list]...)
                push!(new_list, [combineAdd(item1, spec₁) for item1 in list]...)
                push!(new_list, [combineMul(spec₁, item1) for item1 in list]...)
                spec₂ = combineAdd(list[2], list[1])
                push!(new_list, [combineMul(item1, spec₂) for item1 in list]...)
                # NOTE(review): this line sits in the spec₂ group but combines with
                # spec₁. Possibly intended to be spec₂; confirm.
                push!(new_list, [combineSub(item1, spec₁) for item1 in list]...)
                push!(new_list, [combineMul(spec₂, item1) for item1 in list]...)
                list = collect(skipmissing(vcat(list, new_list...)))
                new_list = []
                push!(new_list, [combineAdjoint(item1) for item1 in list]...)
                collect(skipmissing(vcat(list, new_list...)))
            end
            list = combine_special_cases(list)
            println("Number of generated operators: ", length(list))
        end
        @testset "Fidelity" begin
            global list
            data₁ = [sin(i+j) for i=1:10, j=1:10]
            data₂ = [sin(i+j+k) for i=1:10, j=1:10, k=1:5]
            # Test input whose size matches the operator's input dimensions.
            inputFor(op) = op.inDims == size(data₁) ? data₁ : data₂
            # Display name, with a trailing ' for the adjoint of a plain FunctionOperator.
            getName(op) = op isa FunctionOperator && op.adjoint ? op.name*"'" : op.name
            # Normalised expression with every "b" removed, so that in-place (bOpₖ)
            # and allocating (Opₖ) variants map to the same string.
            normName(str) = replace(FunctionOperators.normalizeExpression(getName(str)), "b" => "")
            results = [(name = getName(op.op),
                        normalized_name = normName(op.op),
                        mult_res = op.op * inputFor(op.op),
                        forw_res = op.forw(inputFor(op.op)),
                        plan = op.op isa FunctionOperator ? op.op.name : op.op.plan_string)
                            for op in list]
            # Each operator must agree with its reference evaluation.
            for res in results
                res.mult_res ≠ res.forw_res && println(res.name, ", ", res.plan)
                @test res.mult_res == res.forw_res
            end
            # Distinct operators with the same normalised expression must agree too.
            counter = 0
            for (i,res1) in enumerate(results), (j,res2) in enumerate(results)
                i >= j && continue
                if res1.normalized_name == res2.normalized_name
                    counter += 1
                    res1.mult_res ≠ res2.mult_res &&
                        println(i, ", ", res1.name, ", ", res1.plan, "\n",
                                j, ", ", res2.name, ", ", res2.plan)
                    @test res1.mult_res == res2.mult_res
                end
            end
            println("Pairwise matches between operators: ", counter, " (match means same functionality that checked for same result)")
        end
    end
end;

Number of generated operators: 752
Pairwise matches between operators: 2293 (match means same functionality that checked for same result)
[37m[1mTest Summary:                 | [22m[39m[32m[1mPass  [22m[39m[36m[1mTotal[22m[39m
BuildCompTree.jl + getPlan.jl | [32m4161  [39m[36m 4161[39m


In [1]:
using FunctionOperators
using LinearAlgebra
using Test

┌ Info: Precompiling FunctionOperators [98ce4118-165c-488a-9b71-2bb5aff4e594]
└ @ Base loading.jl:1242


In [3]:
# Tests for the `@♻` loop macro and its markers (🔝, 🔃, @🔃). The rewriting
# itself is defined in the package; these comments only describe what the
# assertions below check.
@testset "Macro - Auxiliary.jl" begin
    # After the loop, `🔝_1` holds the value of the 🔝-marked subexpression
    # (presumably hoisted out of the loop; confirm in Auxiliary.jl), and the
    # enclosing assignment still yields the ordinary result. Uses scalars.
    @testset "🔝 marker" begin
        result, var1, var2, var3 = rand(4)
        @♻ for i=1:5
            result = 🔝(var1 + var2) * var3
        end
        @test 🔝_1 == var1 + var2
        @test result == (var1 + var2) * var3
    end
    # `🔃_1` equals the current value of the 🔃-marked subexpression on every
    # iteration, even though `var2` is refilled each time (presumably a buffer
    # updated in place; confirm). Checked for a sum and for a matrix product.
    @testset "🔃 marker" begin
        result, var1, var2, var3 = rand(3,3), rand(3,3), rand(3,3), rand(3,3)
        @♻ for i=1:5
            var2 .= rand(3,3)
            result .= var1 * 🔃(var2 + var3)
            @test 🔃_1 == var2 + var3
        end
        @test result == var1 * (var2 + var3)
        @♻ for i=1:5
            result = var1 * 🔃(var2 * var3)
            @test 🔃_1 == var2 * var3
        end
        @test result == var1 * (var2 * var3)
    end
    # `@🔃` applied to a whole assignment still leaves the correct value in `result`.
    @testset "@🔃 marker" begin
        result, var1, var2 = rand(3,3), rand(3,3), rand(3,3)
        @♻ for i=1:5
            @🔃 result = var1 * var2
        end
        @test result == var1 * var2
    end
    # Markers nested inside each other and inside an `@🔃` assignment.
    @testset "nesting" begin
        result, var1, var2, var3 = rand(3,3), rand(3,3), rand(3,3), rand(3,3)
        @♻ for i=1:5
            @🔃 result = 🔃(🔝(var1 + var2) * var3)
            @test 🔃_1 == (var1 + var2) * var3
        end
        @test 🔝_1 == var1 + var2
        @test result == (var1 + var2) * var3
    end
end;

[37m[1mTest Summary:        | [22m[39m[32m[1mPass  [22m[39m[36m[1mTotal[22m[39m
Macro - Auxiliary.jl | [32m  22  [39m[36m   22[39m


In [14]:
# Benchmarks `mul!` with a composite of in-place operators against a hand-fused
# function that does the same arithmetic into preallocated buffers. The two
# timings are expected to agree within 5%.
@testset "Performance" begin
    data = [sin(i+j+k)^2 for i=1:300, j=1:300, k=1:50]
    # In-place elementwise square (forw) / square root (backw) on 300×300×50 arrays.
    bOp₁ = FunctionOperator{Float64}(name="Op₁",
        forw = (buffer, x) -> buffer .= x.^2,
        backw = (buffer, x) -> broadcast!(sqrt, buffer, x),
        inDims = (300, 300, 50), outDims = (300, 300, 50))
    # forw: 10 weighted copies of the input along a new 4th dimension.
    # backw: divide by the weights and sum over that dimension into the buffer.
    weights = [sin((i-j)*l) + 1 for i=1:300, j=1:300, k=1:50, l=1:10]
    bOp₂ = FunctionOperator{Float64}(name="Op₂",
        forw = (buffer,x) -> buffer .= reshape(x, 300, 300, 50, 1) .* weights,
        backw = (buffer,x) -> dropdims(sum!(reshape(buffer, 300, 300, 50, 1), x ./ weights), dims=4),
        inDims=(300, 300, 50), outDims=(300, 300, 50, 10))
    combined = bOp₂ * (bOp₁ - 2.5*I) * bOp₁'
    # Untimed first application, so compilation and plan construction stay out
    # of the measurement.
    combined * data
    output = Array{Float64}(undef, 300, 300, 50, 10)
    # Hand-written equivalent of `combined`: bOp₁' (sqrt), then (bOp₁ - 2.5I),
    # then bOp₂. It builds its own `weights` (same values as above), so the
    # returned closure captures a local array.
    function getAggregatedFunction()
        weights = [sin((i-j)*l) + 1 for i=1:300, j=1:300, k=1:50, l=1:10]
        buffer2 = Array{Float64}(undef, 300, 300, 50)
        buffer3 = Array{Float64}(undef, 300, 300, 50)
        buffer4 = Array{Float64}(undef, 300, 300, 50)
        (buffer, x) -> begin
            broadcast!(sqrt, buffer2, x)  # Of course, this two lines can be optimized to
            buffer3 .= buffer2 .^ 2       # (√x)^2 = |x|, but let's now avoid this fact
            broadcast!(-, buffer3, buffer3, broadcast!(*, buffer4, 2.5, buffer2))
            buffer .= reshape(buffer3, 300, 300, 50, 1) .* weights
        end
    end
    aggrFun = getAggregatedFunction()
    # `$`-interpolate every argument. BenchmarkTools evaluates the benchmark
    # expression at global scope, so bare names would either be timed as untyped
    # global lookups, which skews the ratio, or not be found at all when this
    # testset body runs in a local scope.
    t1 = @belapsed mul!($output, $combined, $data)
    t2 = @belapsed $aggrFun($output, $data)
    @test t1 / t2 ≈ 1 atol=0.05
end

[37m[1mTest Summary: | [22m[39m[32m[1mPass  [22m[39m[36m[1mTotal[22m[39m
Performance   | [32m   1  [39m[36m    1[39m


Test.DefaultTestSet("Performance", Any[], 1, false)

In [1]:
# Scratch operators for the interactive cells below.
# NOTE(review): each in-place twin bOpₖ gets the same display name as Opₖ, so
# printed names and plans cannot tell them apart. Confirm this is intended.

# Elementwise operators on 10×10 arrays: allocating versions ...
Op₁ = FunctionOperator{Float64}("Op₁", x -> x .^ 3, x -> cbrt.(x), (10, 10), (10, 10))
Op₂ = FunctionOperator{Float64}("Op₂", x -> x .+ 2, x -> x .- 2, (10, 10), (10, 10))
# ... and in-place versions that write into the buffer `b`.
bOp₁ = FunctionOperator{Float64}("Op₁",
    (b, x) -> (b .= x .^ 3), (b, x) -> (b .= cbrt.(x)), (10, 10), (10, 10))
bOp₂ = FunctionOperator{Float64}("Op₂",
    (b, x) -> (b .= x .+ 2), (b, x) -> (b .= x .- 2), (10, 10), (10, 10))

# 10×10 → 10×10×5 operators; `w` (1×1×5) scales the five slices by 1…5.
w = reshape([1 2 3 4 5], 1, 1, 5)
Op₃ = FunctionOperator{Float64}("Op₃",
    x -> x .* w, x -> x[:, :, 1], (10, 10), (10, 10, 5))
Op₄ = FunctionOperator{Float64}("Op₄",
    x -> repeat(x, outer = (1, 1, 5)), x -> x[:, :, 5], (10, 10), (10, 10, 5))
bOp₃ = FunctionOperator{Float64}("Op₃",
    (b, x) -> broadcast!(*, b, x, w), (b, x) -> (b .= x[:, :, 1]), (10, 10), (10, 10, 5))
bOp₄ = FunctionOperator{Float64}("Op₄",
    (b, x) -> (b .= repeat(x, outer = (1, 1, 5))), (b, x) -> (b .= x[:, :, 5]),
    (10, 10), (10, 10, 5))

# Elementwise operators on 10×10×5 arrays.
Op₅ = FunctionOperator{Float64}("Op₅",
    x -> -x .^ 3, x -> -cbrt.(x), (10, 10, 5), (10, 10, 5))
Op₆ = FunctionOperator{Float64}("Op₆",
    x -> x .+ 5, x -> x .- 5, (10, 10, 5), (10, 10, 5))
bOp₅ = FunctionOperator{Float64}("Op₅",
    (b, x) -> (b .= -x .^ 3), (b, x) -> (b .= -cbrt.(x)), (10, 10, 5), (10, 10, 5))
bOp₆ = FunctionOperator{Float64}("Op₆",
    (b, x) -> (b .= x .+ 5), (b, x) -> (b .= x .- 5), (10, 10, 5), (10, 10, 5))

┌ Info: Recompiling stale cache file /home/hakkelt/.julia/compiled/v1.2/FunctionOperators/2tgjG.ji for FunctionOperators [98ce4118-165c-488a-9b71-2bb5aff4e594]
└ @ Base loading.jl:1240


FunctionOperator with eltype Float64
    Name: Op₆
    Input dimensions: (10, 10, 5)
    Output dimensions: (10, 10, 5)

In [7]:
# Composite that mixes allocating (Op₁) and in-place (bOp₁) operators.
combined = (Op₁ - Op₁) + (Op₁ + bOp₁)

FunctionOperatorComposite with eltype Float64
    Name: ((Op₁ - Op₁) + (Op₁ + Op₁))
    Input dimensions: (10, 10)
    Output dimensions: (10, 10)
    Plan: no plan

In [8]:
# With verbose mode on, applying the composite logs its buffer allocations and
# the calculated plan.
FO_settings.verbose = true
combined * fill(1.0, 10, 10);

Allocation of buffer1, size: (10, 10)
Allocation of buffer2, size: (10, 10)
Allocation of buffer3, size: (10, 10)
Plan calculated: buffer1 .= (buffer2 .= x; broadcast!(+, buffer1, (buffer3 .= buffer2; broadcast!(-, buffer1, Op₁.forw(buffer3), Op₁.forw(buffer3))), (buffer3 .= buffer2; broadcast!(+, buffer1, Op₁.forw(buffer3), Op₁.forw(buffer1, buffer3)))))


In [4]:
# Display the composite operator (name, input/output dimensions and plan).
combined

FunctionOperatorComposite with eltype Float64
    Name: Op₃ * (Op₁ + (6*I))
    Input dimensions: (10, 10)
    Output dimensions: (10, 10, 5)
    Plan: Op₃.forw((buffer2 .= x; broadcast!(+, buffer3, Op₁.forw(buffer3, buffer2), broadcast!(*, buffer2, 6.0, buffer2))))

In [5]:
# Turn verbose allocation/plan logging back off.
FO_settings.verbose = false

false