In [None]:
using Flux, Zygote

In [None]:
begin
    # Concatenate arrays along dimension 3 (the channel dimension in Flux's WHCN
    # convention).
    function channelcat(xs::AbstractArray...)
        return cat(xs...; dims=3)
    end
    # Concatenate several `(lin, nonlin)` pairs component-wise: all first elements are
    # concatenated together and all second elements together, giving back a
    # `(lin, nonlin)` tuple. Used as the SkipConnection combiner on such pairs.
    function channelcat(pairs::Tuple{AbstractArray,AbstractArray}...)
        return broadcast(channelcat, pairs...)
    end
end

In [None]:
begin
    # Custom Zygote reverse rule for `Iterators.product` (reached, for example, when
    # differentiating a comprehension with several `for` clauses).
    # NOTE(review): this appears to mirror, and therefore redefine, the rule that ships
    # inside Zygote itself — confirm it is still needed for the installed Zygote.
    # `_ndims`, `StaticGetter` and `accum` are unexported Zygote internals; because
    # `@adjoint` escapes its body they resolve in *this* module, so they must be
    # module-qualified or the pullback throws `UndefVarError` when called.
    Zygote.@adjoint function Iterators.product(xs...)
        # Cotangent made only of `nothing`: no gradient reaches any iterator.
        back(::AbstractArray{Nothing}) = nothing
        # Structural cotangent of the product-iterator struct: unwrap its field.
        back(dy::NamedTuple{(:iterators,)}) = dy.iterators
        # Array cotangent: one entry (a tuple with one slot per iterator) per element
        # of the product.
        function back(dy::AbstractArray)
            d = 1  # first dimension of `dy` owned by the current iterator
            ntuple(length(xs)) do n
                first(dy)[n] === nothing && return nothing
                nd = Zygote._ndims(xs[n])
                # Reduce over every dimension of `dy` not owned by xs[n].
                dims = ntuple(i -> i < d ? i : i + nd, ndims(dy) - nd)
                d += nd
                init = zero.(first(dy)[n])  # allows for tuples, which accum can add
                red = mapreduce(Zygote.StaticGetter{n}(), Zygote.accum, dy;
                                dims=dims, init=init)
                return reshape(red, axes(xs[n]))
            end
        end
        Iterators.product(xs...), back
    end
    
    # Loop-based counterpart of `ADQuadSemiLinear`: build the quadratic semi-linear
    # feature stack for a `(lin, nonlin)` pair.
    #
    # With nl = size(lin, 3) and nnl = size(nonlin, 3), this returns `(layers, nonlin)`,
    # where `layers` has nl*nnl + nl slices along dimension 3:
    #   slice i + (j-1)*nl  ->  lin[:, :, i, :] .* nonlin[:, :, j, :]
    #   slice nl*nnl + i    ->  lin[:, :, i, :]
    # This channel order (i varying fastest) is the same as `ADQuadSemiLinear`, so
    # the two can feed the same trained 1x1 convolution. Earlier versions used j
    # fastest, which silently permuted the channels relative to the AD variant.
    #
    # The output takes the promoted eltype of the inputs and the array type of `lin`
    # (via `similar`) instead of always being a CPU `Array{Float32}`.
    # It mutates a freshly allocated output, which Zygote cannot differentiate.
    # Assumes `lin` and `nonlin` agree (or broadcast) in every dimension except the
    # third — TODO confirm at call sites.
    function FastQuadSemiLinear((lin, nonlin))
        nl = size(lin, 3)
        nnl = size(nonlin, 3)
        T = promote_type(eltype(lin), eltype(nonlin))
        layers = similar(lin, T, size(lin, 1), size(lin, 2), nl * nnl + nl, size(lin, 4))
        # j outer, i inner, so product (i, j) lands at channel i + (j-1)*nl.
        for j in 1:nnl, i in 1:nl
            layers[:, :, i + (j - 1) * nl, :] .= selectdim(lin, 3, i) .* selectdim(nonlin, 3, j)
        end
        # Pass the linear channels through unchanged after all the products.
        for i in 1:nl
            layers[:, :, nl * nnl + i, :] .= selectdim(lin, 3, i)
        end
        return (layers, nonlin)
    end
    FastQuadSemiLinear(lin::AbstractArray, nonlin::AbstractArray) =
        FastQuadSemiLinear((lin, nonlin))

    # Differentiable quadratic semi-linear feature expansion for a `(lin, nonlin)`
    # pair.
    #
    # With nl = size(lin, 3) and nnl = size(nonlin, 3), this returns `(out, nonlin)`,
    # where `out` has nl*nnl + nl slices along dimension 3:
    #   slice i + (j-1)*nl  ->  lin[:, :, i, :] .* nonlin[:, :, j, :]
    #   slice nl*nnl + i    ->  lin[:, :, i, :]
    # All products come from a single broadcast over reshaped arrays. This replaces
    # splatting an nl x nnl comprehension of temporaries into varargs `cat`, and uses
    # only reshape/broadcast/cat, which Zygote can differentiate.
    function ADQuadSemiLinear((lin, nonlin))
        nl = size(lin, 3)
        nnl = size(nonlin, 3)
        # Singleton axes let one broadcast form every product:
        # (W, H, nl, 1, N) .* (W, H, 1, nnl, N) -> (W, H, nl, nnl, N).
        lin5 = reshape(lin, size(lin, 1), size(lin, 2), nl, 1, size(lin, 4))
        nonlin5 = reshape(nonlin, size(nonlin, 1), size(nonlin, 2), 1, nnl, size(nonlin, 4))
        prods = lin5 .* nonlin5
        # Merging dims 3 and 4 (column-major) puts product (i, j) at channel
        # i + (j-1)*nl, the same order the splatted comprehension produced.
        quad = reshape(prods, size(prods, 1), size(prods, 2), nl * nnl, size(prods, 5))
        out = cat(quad, lin; dims=3)
        return (out, nonlin)
    end

    ADQuadSemiLinear(lin::AbstractArray, nonlin::AbstractArray) =
        ADQuadSemiLinear((lin, nonlin))
    
    # Layer mapping a `(lin, nonlin)` pair with `input_features = (nlin, nnonlin)`
    # channels to a pair with `channels_out` linear channels and the nonlinear input
    # passed through unchanged.
    # `ADQuadSemiLinear` expands the linear path to nlin*nnonlin products plus the
    # nlin original channels, that is nlin*(nnonlin + 1) in total. A 1x1 convolution
    # then mixes them down to `channels_out`.
    function QuadSemiLinearConv(input_features, channels_out)
        nlin, nnonlin = input_features
        expanded = nlin * (nnonlin + 1)
        mixer = Conv((1, 1), expanded => channels_out; stride=1, pad=SamePad())
        return Chain(ADQuadSemiLinear, Parallel(tuple, mixer, identity))
    end
end

In [None]:
begin

    # One WienerNet level, mapping a `(lin, nonlin)` pair to a `(lin, nonlin)` pair.
    #
    # Stages, with N = Nfeatures:
    #   1. QuadSemiLinearConv((N, N), N) on the incoming pair.
    #   2. A SkipConnection whose branch:
    #      - downsamples both paths with stride-2 5x5 convolutions (relu on the
    #        nonlinear path);
    #      - runs `inner`;
    #      - upsamples both paths back to N channels with stride-2 5x5 transposed
    #        convolutions.
    #      The branch output is `channelcat`-ed with the skip input, giving 2N
    #      channels on each path.
    #   3. QuadSemiLinearConv((2N, 2N), N).
    # `inner` is either `identity`, which hands back N linear / N nonlinear channels,
    # or another base block. A base block emits N linear channels (stage 3's conv)
    # and 2N nonlinear channels (passed through from stage 2). The upsampling
    # convolutions are sized accordingly.
    function WienerNetBaseBlock(inner; Nfeatures=10)
        N = Nfeatures
        # Channel counts (linear, nonlinear) produced by `inner`.
        inner_lin, inner_nonlin = inner == identity ? (N, N) : (N, 2N)

        # Built in this order to keep parameter-initialisation order unchanged.
        pre_mix = QuadSemiLinearConv((N, N), N)
        downsample = Parallel(tuple,
            Conv((5, 5), N => N, stride=2, pad=SamePad()),
            Conv((5, 5), N => N, relu; stride=2, pad=SamePad()))
        upsample = Parallel(tuple,
            ConvTranspose((5, 5), inner_lin => N, stride=2, pad=SamePad()),
            ConvTranspose((5, 5), inner_nonlin => N, relu; stride=2, pad=SamePad()))
        post_mix = QuadSemiLinearConv((2N, 2N), N)

        skip = SkipConnection(Chain(downsample, inner, upsample), channelcat)
        return Chain(pre_mix, skip, post_mix)
    end
    # Innermost level: no nested block, just `identity` at the bottom of the U.
    WienerNetBaseBlock(; Nfeatures=10) = WienerNetBaseBlock(identity; Nfeatures=Nfeatures)
    
    # Wrap a WienerNet core so that it maps a 1-channel input to a 1-channel output.
    # - The input is lifted to `Nfeatures` channels twice, by two independent 5x5
    #   convolutions (one for the linear path, one for the nonlinear path).
    # - The resulting `(lin, nonlin)` pair is passed through `wienernet`.
    # - Only the linear output is kept; it is projected back to one channel with a
    #   5x5 transposed convolution.
    function WienerNetIO(wienernet; Nfeatures=10)
        lift = Parallel(tuple,
            Conv((5, 5), 1 => Nfeatures; stride=1, pad=SamePad()),
            Conv((5, 5), 1 => Nfeatures; stride=1, pad=SamePad()))
        keep_linear = ((lin, nonlin),) -> lin
        project = ConvTranspose((5, 5), Nfeatures => 1; stride=1, pad=SamePad())
        return Chain(lift, wienernet, keep_linear, project)
    end
end

In [None]:
"""
    WienerNet(; depth=5, Nfeatures=10)

Build a WienerNet: `depth` nested `WienerNetBaseBlock`s wrapped by `WienerNetIO`.

The innermost block has `identity` at its core. Each further level wraps the previous
one as its `inner` block. Every level and the I/O wrapper use the same `Nfeatures`.
The defaults reproduce the previous fixed architecture (5 levels, 10 features), so
`WienerNet()` is unchanged.

Each level downsamples with stride-2 convolutions and upsamples with stride-2 transposed
convolutions before concatenating with its skip input. Presumably the spatial input
size must therefore be divisible by `2^depth` — confirm.

# Throws
- `ArgumentError` if `depth < 1`.
"""
function WienerNet(; depth::Integer=5, Nfeatures::Integer=10)
    depth >= 1 || throw(ArgumentError("depth must be at least 1, got $depth"))
    net = WienerNetBaseBlock(; Nfeatures=Nfeatures)
    for _ in 2:depth
        net = WienerNetBaseBlock(net; Nfeatures=Nfeatures)
    end
    return WienerNetIO(net; Nfeatures=Nfeatures)
end