In [1]:
using FunctionOperators
using LinearAlgebra
using Test

┌ Info: Recompiling stale cache file /home/hakkelt/.julia/compiled/v1.2/FunctionOperators/2tgjG.ji for FunctionOperators [98ce4118-165c-488a-9b71-2bb5aff4e594]
└ @ Base loading.jl:1240


## Test all constructors

In [2]:
# Constructor tests: every keyword/positional constructor form must succeed on
# valid input; missing required fields raise ErrorException; forw/backw with a
# bad number of arguments raise AssertionError.
@testset "Constructors - StructDefs.jl" begin
    @testset "Proper constructors" begin
        @testset "Keyword constructors" begin
            @test FunctionOperator{Float64}(name = "Op₁",
                forw = x -> x, backw = x -> x,
                inDims = (1,), outDims = (1,)) isa FunOp
            @test FunctionOperator{Float64}(
                forw = x -> x, backw = x -> x,
                inDims = (1,), outDims = (1,)) isa FunOp
            @test FunctionOperator{Float64}(name = "Op₁",
                forw = x -> x,
                inDims = (1,), outDims = (1,)) isa FunOp
            @test FunctionOperator{Float64}(forw = x -> x,
                inDims = (1,), outDims = (1,)) isa FunOp
        end
        @testset "Positional constructors" begin
            @test FunctionOperator{Float64}("Op₁", x -> x, x -> x,
                (1,), (1,)) isa FunOp
            @test FunctionOperator{Float64}(x -> x, x -> x,
                (1,), (1,)) isa FunOp
            @test FunctionOperator{Float64}("Op₁", x -> x,
                (1,), (1,)) isa FunOp
            @test FunctionOperator{Float64}(x -> x,
                (1,), (1,)) isa FunOp
        end
    end
    @testset "Missing value" begin
        # forw, inDims and outDims are each required.
        @test_throws ErrorException FunctionOperator{Float64}(name = "Op₁",
            backw = x -> x,
            inDims = (1,), outDims = (1,))
        @test_throws ErrorException FunctionOperator{Float64}(name = "Op₁",
            forw = x -> x, backw = x -> x,
            outDims = (1,))
        @test_throws ErrorException FunctionOperator{Float64}(name = "Op₁",
            forw = x -> x, backw = x -> x,
            inDims = (1,))
    end
    @testset "No arguments for forw" begin
        # The lambda bodies are never executed; only their arity matters.
        @testset "Keyword constructors" begin
            @test_throws AssertionError FunctionOperator{Float64}(name = "Op₁",
                forw = () -> x, backw = x -> x,
                inDims = (1,), outDims = (1,))
            @test_throws AssertionError FunctionOperator{Float64}(
                forw = () -> x, backw = x -> x,
                inDims = (1,), outDims = (1,))
            @test_throws AssertionError FunctionOperator{Float64}(name = "Op₁",
                forw = () -> x,
                inDims = (1,), outDims = (1,))
            @test_throws AssertionError FunctionOperator{Float64}(forw = () -> x,
                inDims = (1,), outDims = (1,))
        end
        @testset "Positional constructors" begin
            @test_throws AssertionError FunctionOperator{Float64}("Op₁", () -> x, x -> x,
                (1,), (1,))
            @test_throws AssertionError FunctionOperator{Float64}(() -> x, x -> x,
                (1,), (1,))
            @test_throws AssertionError FunctionOperator{Float64}("Op₁", () -> x,
                (1,), (1,))
            @test_throws AssertionError FunctionOperator{Float64}(() -> x,
                (1,), (1,))
        end
    end
    @testset "Too many arguments for forw" begin
        @testset "Keyword constructors" begin
            @test_throws AssertionError FunctionOperator{Float64}(name = "Op₁",
                forw = (x,y,z) -> x, backw = x -> x,
                inDims = (1,), outDims = (1,))
            @test_throws AssertionError FunctionOperator{Float64}(
                forw = (x,y,z) -> x, backw = x -> x,
                inDims = (1,), outDims = (1,))
            @test_throws AssertionError FunctionOperator{Float64}(name = "Op₁",
                forw = (x,y,z) -> x,
                inDims = (1,), outDims = (1,))
            @test_throws AssertionError FunctionOperator{Float64}(forw = (x,y,z) -> x,
                inDims = (1,), outDims = (1,))
        end
        @testset "Positional constructors" begin
            @test_throws AssertionError FunctionOperator{Float64}("Op₁", (x,y,z) -> x,
                x -> x, (1,), (1,))
            @test_throws AssertionError FunctionOperator{Float64}((x,y,z) -> x, x -> x,
                (1,), (1,))
            @test_throws AssertionError FunctionOperator{Float64}("Op₁", (x,y,z) -> x,
                (1,), (1,))
            @test_throws AssertionError FunctionOperator{Float64}((x,y,z) -> x,
                (1,), (1,))
        end
    end
    @testset "Forw has more arguments as backw" begin
        # A two-argument (buffered) forw paired with a one-argument backw is rejected.
        @testset "Keyword constructors" begin
            @test_throws AssertionError FunctionOperator{Float64}(name = "Op₁",
                forw = (x,y) -> x.^3, backw = x -> x,
                inDims = (1,), outDims = (1,))
            @test_throws AssertionError FunctionOperator{Float64}(
                forw = (x,y) -> x.^3, backw = x -> x,
                inDims = (1,), outDims = (1,))
        end
        @testset "Positional constructors" begin
            @test_throws AssertionError FunctionOperator{Float64}("Op₁", (x,y) -> x.^3,
                x -> x, (1,), (1,))
            @test_throws AssertionError FunctionOperator{Float64}((x,y) -> x.^3, x -> x,
                (1,), (1,))
        end
    end
    @testset "Unique default name" begin
        # Build both operators from the SAME forward function and identical
        # dims, so the only thing that can tell them apart is the generated
        # default name. The text/plain display prints the name together with
        # eltype and dimensions, so the two displays must differ. (Comparing
        # operators built from two separate `x->x` closures would pass even
        # if the names collided, because the closures differ.)
        idfun = x -> x
        @test repr("text/plain", FunctionOperator{Float64}(idfun, (1,), (1,))) ≠
            repr("text/plain", FunctionOperator{Float64}(idfun, (1,), (1,)))
    end
    @testset "Undefined backw" begin
        # Applying the adjoint of an operator constructed without backw must error.
        Op₁ = FunctionOperator{Float64}(x->x,(1,),(1,))
        @test_throws ErrorException Op₁' * [1.]
    end
end;

[37m[1mTest Summary:                | [22m[39m[32m[1mPass  [22m[39m[36m[1mTotal[22m[39m
Constructors - StructDefs.jl | [32m  33  [39m[36m   33[39m


In [3]:
# Tests for eltype propagation and for the type/dimension assertions performed
# when composing operators and when applying them with mul!.
@testset "Helpers.jl" begin
    Op₁ = FunctionOperator{Float64}(x -> x, (1,), (1,))
    Op₂ = FunctionOperator{Int64}(x -> x, (1,), (1,))   # eltype differs from Op₁
    Op₃ = FunctionOperator{Float64}(x -> x, (1,), (1,))   # compatible with Op₁
    Op₄ = FunctionOperator{Float64}(x -> x, (1,), (2,))   # output dims differ
    buffer = Array{Float64}(undef, 1)
    @testset "eltype" begin
        @test eltype(Op₁) == Float64
        @test eltype(Op₁ * Op₃) == Float64
    end
    @testset "assertions" begin
        @testset "TypeError" begin
            # Mixing eltypes in composition/addition/subtraction.
            @test_throws TypeError Op₁ * Op₂
            @test_throws TypeError Op₁ + Op₂
            @test_throws TypeError Op₁ - Op₂
            @test_throws TypeError (Op₁ + Op₁) * Op₂
            @test_throws TypeError (Op₁ + Op₁) + Op₂
            @test_throws TypeError (Op₁ + Op₁) - Op₂
            @test_throws TypeError Op₁ * (Op₂ + Op₂)
            @test_throws TypeError Op₁ + (Op₂ + Op₂)
            @test_throws TypeError Op₁ - (Op₂ + Op₂)
            @test_throws TypeError (Op₁ + Op₁) * (Op₂ + Op₂)
            @test_throws TypeError (Op₁ + Op₁) + (Op₂ + Op₂)
            @test_throws TypeError (Op₁ + Op₁) - (Op₂ + Op₂)
            # Wrong buffer/input eltype in mul!. The composite cases use the
            # type-compatible Op₁ * Op₃: composing with Op₂ would already throw
            # TypeError (asserted above) before mul!'s own check is reached.
            @test_throws TypeError mul!(Array{Int64}(undef, 1), Op₁, [1.])
            @test_throws TypeError mul!(Array{Int64}(undef, 1), Op₁ * Op₃, [1.])
            @test_throws TypeError mul!(buffer, Op₁, [1])
            @test_throws TypeError mul!(buffer, Op₁ * Op₃, [1])
        end
        @testset "DimensionError" begin
            @test_throws DimensionMismatch Op₁ * [1.,2.]
            @test_throws DimensionMismatch Op₁ * Op₃ * [1.,2.]
            @test_throws DimensionMismatch (Op₁ + Op₃) * [1.,2.]
            @test_throws DimensionMismatch (Op₁ - Op₃) * [1.,2.]
            @test_throws DimensionMismatch Op₃ * Op₄
            @test_throws DimensionMismatch (Op₄ + Op₃)
            @test_throws DimensionMismatch (Op₄ - Op₃)
            @test_throws DimensionMismatch (Op₄ + I)
            @test_throws DimensionMismatch (Op₄ - I)
            @test_throws DimensionMismatch (Op₄ + 3I)
            @test_throws DimensionMismatch (Op₄ - 3I)
            @test_throws DimensionMismatch mul!(buffer, Op₁, [1., 2.])
            @test_throws DimensionMismatch mul!(buffer, Op₁ * Op₃, [1.,2.])
            @test_throws DimensionMismatch mul!(Array{Float64}(undef, 2), Op₁, [1.])
            @test_throws DimensionMismatch mul!(Array{Float64}(undef, 2), Op₁ * Op₃, [1.])
        end
    end
end;

[37m[1mTest Summary: | [22m[39m[32m[1mPass  [22m[39m[36m[1mTotal[22m[39m
Helpers.jl    | [32m  33  [39m[36m   33[39m


In [4]:
# Tests that composed/summed operators (and UniformScaling terms) evaluate to
# the same result as applying the constituent operators by hand.
@testset "BuildCompTree.jl + getPlan.jl" begin
    # Each fixture comes in an allocating form (`Op…`: functions take `x`) and
    # a buffered form (`bOp…`: functions take `(b, x)` and write into `b`).
    # NOTE(review): the buffered fixtures, as well as Op₂ and Op₆, are built
    # here but not yet used by any @test below.
    Op₁ = FunctionOperator{Float64}("Op₁",
        x -> x .^ 3, x -> cbrt.(x), (10,10), (10,10))
    Op₂ = FunctionOperator{Float64}("Op₂",
        x -> x .+ 2, x -> x .- 2, (10,10), (10,10))
    bOp₁ = FunctionOperator{Float64}("Op₁",
        (b,x) -> b.=x.^3, (b,x) -> b.=cbrt.(x), (10,10), (10,10))
    bOp₂ = FunctionOperator{Float64}("Op₂",
        (b,x) -> b.=x.+2, (b,x) -> b.=x.-2, (10,10), (10,10))
    w = reshape([1 2 3 4 5], 1, 1, 5)
    Op₃ = FunctionOperator{Float64}("Op₃",
        x -> x .* w, x -> x[:,:,1], (10,10), (10,10,5))
    Op₄ = FunctionOperator{Float64}("Op₄",
        x -> repeat(x, outer=(1,1,5)), x -> x[:,:,5], (10,10), (10, 10, 5))
    # broadcast! takes the element-wise function first: this is b .= x .* w,
    # mirroring Op₃'s forw.
    bOp₃ = FunctionOperator{Float64}("Op₃",
        (b,x) -> broadcast!(*, b, x, w), (b,x) -> b .= x[:,:,1], (10,10), (10,10,5))
    bOp₄ = FunctionOperator{Float64}("Op₄",
        (b,x) -> b.=repeat(x, outer=(1,1,5)), (b,x) -> b .= x[:,:,5], (10,10), (10, 10, 5))
    Op₅ = FunctionOperator{Float64}("Op₅",
        x -> -x .^ 3, x -> -cbrt.(x), (10,10,5), (10,10,5))
    Op₆ = FunctionOperator{Float64}("Op₆",
        x -> x .+ 5, x -> x .- 5, (10,10,5), (10,10,5))
    bOp₅ = FunctionOperator{Float64}("Op₅",
        (b,x) -> b.=-x.^3, (b,x) -> b.=-cbrt.(x), (10,10,5), (10,10,5))
    bOp₆ = FunctionOperator{Float64}("Op₆",
        (b,x) -> b.=x.+5, (b,x) -> b.=x.-5, (10,10,5), (10,10,5))
    @testset "Fidelity (manually checked)" begin
        @test Op₁ * (ones(10,10)*2) == ones(10,10)*8
        @test Op₁' * (ones(10,10)*8) == ones(10,10)*2
        @test Op₃ * Op₁ * (ones(10,10)*2) == ones(10,10)*8 .* w
        @test (Op₃ * Op₁)' * (ones(10,10)*8 .* w) == Op₁' * Op₃' * (ones(10,10)*8 .* w)
        @test (Op₃ * Op₁)' * (ones(10,10)*8 .* w) == ones(10,10)*2
        @test (Op₃ + Op₄) * (ones(10,10)*2) == Op₃ * (ones(10,10)*2) + Op₄ * (ones(10,10)*2)
        @test (Op₃ - Op₄) * (ones(10,10)*2) == Op₃ * (ones(10,10)*2) - Op₄ * (ones(10,10)*2)
        @test 3I * Op₃ * (ones(10,10)*2) == (ones(10,10)*2 .* w) * 3
        @test Op₃ * 3I * (ones(10,10)*2) == ones(10,10)*6 .* w
        @test (Op₃ * 3I)' * (ones(10,10)*2 .* w) == ones(10,10)*6
        @test (Op₅ + 3I) * (ones(10,10,5)*2) == Op₅ * (ones(10,10,5)*2) + (ones(10,10,5)*6)
        @test (Op₅ - 3I) * (ones(10,10,5)*2) == Op₅ * (ones(10,10,5)*2) - (ones(10,10,5)*6)
        @test (3I + Op₅) * (ones(10,10,5)*2) == (ones(10,10,5)*6) + Op₅ * (ones(10,10,5)*2)
        @test (3I - Op₅) * (ones(10,10,5)*2) == (ones(10,10,5)*6) - Op₅ * (ones(10,10,5)*2)
    end
    @testset "Adjoint of addition" begin
        # Applying the adjoint of a sum of operators is expected to error.
        @test_throws ErrorException (Op₃ + Op₄)' * ones(10,10,5)
    end
end;

[37m[1mTest Summary:                 | [22m[39m[32m[1mPass  [22m[39m[36m[1mTotal[22m[39m
BuildCompTree.jl + getPlan.jl | [32m  15  [39m[36m   15[39m


In [64]:
# Element-wise cube on 300×300×50 arrays; backw applies the element-wise cube root.
Op₁ = let dims = (300, 300, 50)
    FunctionOperator{Float64}(name = "Op₁",
        forw = x -> x .^ 3,
        backw = x -> cbrt.(x),
        inDims = dims, outDims = dims)
end

FunctionOperator with eltype Float64
    Name: Op₁
    Input dimensions: (300, 300, 50)
    Output dimensions: (300, 300, 50)

In [6]:
# Buffered counterpart of Op₁: forw/backw receive a preallocated output
# `buffer` as their first argument and write the result into it.
# It must compute the same map as Op₁ (cube / cube root). The previous
# square / sqrt pair disagreed with Op₁ and threw DomainError on negative input.
bOp₁ = FunctionOperator{Float64}(name="Op₁",
    forw = (buffer, x) -> buffer .= x.^3,
    backw = (buffer, x) -> broadcast!(cbrt, buffer, x),
    inDims = (300, 300, 50), outDims = (300, 300, 50))

FunctionOperator with eltype Float64
    Name: Op₁
    Input dimensions: (300, 300, 50)
    Output dimensions: (300, 300, 50)

In [7]:
# 300×300×50×10 weight array. The comprehension index `k` does not appear in
# the expression, so the values are constant along the third axis.
weights = [1 + sin(l * (i - j)) for i in 1:300, j in 1:300, k in 1:50, l in 1:10]
Op₂ = FunctionOperator{Float64}("Op₂",
    # forw: give x a singleton 4th axis, then broadcast against the weights (3-D -> 4-D).
    x -> reshape(x, 300, 300, 50, 1) .* weights,
    # backw: divide out the weights and sum away the 4th axis (4-D -> 3-D).
    x -> dropdims(sum(x ./ weights, dims=4), dims=4),
    (300, 300, 50), (300, 300, 50, 10))

FunctionOperator with eltype Float64
    Name: Op₂
    Input dimensions: (300, 300, 50)
    Output dimensions: (300, 300, 50, 10)

In [9]:
# Buffered counterpart of Op₂: each function writes its result into `buffer`.
bOp₂ = FunctionOperator{Float64}(name="Op₂",
    forw = (buffer, x) -> (buffer .= reshape(x, 300, 300, 50, 1) .* weights),
    backw = function (buffer, x)
        # Reinterpret the 3-D buffer as 4-D with a singleton last axis (shares
        # memory), so sum! reduces over dim 4 directly into it.
        buffer4d = reshape(buffer, 300, 300, 50, 1)
        sum!(buffer4d, x ./ weights)
        return dropdims(buffer4d, dims=4)
    end,
    inDims=(300, 300, 50), outDims=(300, 300, 50, 10))

FunctionOperator with eltype Float64
    Name: Op₂
    Input dimensions: (300, 300, 50)
    Output dimensions: (300, 300, 50, 10)

In [11]:
# Element-wise doubling on 300×300×50×10 arrays; backw halves element-wise.
Op₃ = let dims = (300, 300, 50, 10)
    FunctionOperator{Float64}(name = "Op₃",
        forw = x -> 2 .* x,
        backw = x -> x ./ 2,
        inDims = dims, outDims = dims)
end

FunctionOperator with eltype Float64
    Name: Op₃
    Input dimensions: (300, 300, 50, 10)
    Output dimensions: (300, 300, 50, 10)

In [None]:
# Element-wise doubling on 300×300×50 arrays; backw halves element-wise.
# NOTE(review): this rebinds Op₃, replacing the 4-D (300, 300, 50, 10) Op₃
# defined in the preceding cell with a 3-D one. Confirm this is intended and
# that a distinct name was not meant.
Op₃ = let dims = (300, 300, 50)
    FunctionOperator{Float64}(name = "Op₃",
        forw = x -> 2 .* x,
        backw = x -> x ./ 2,
        inDims = dims, outDims = dims)
end