In [1]:
using Flux

In [2]:
begin
	## Pass-through helper: a single argument is returned unchanged, while any
	## other number of arguments (including none) is returned packed in a tuple.
	tuple_identity(a) = a
	tuple_identity(a...) = a
	## Concatenate `a` and `b` along their third dimension.
	function cat3(a, b)
		return cat(a, b; dims=3)
	end
	## Element-wise `cat3` of two 2-tuples of arrays: the first entries are
	## concatenated along dimension 3, and so are the second entries.
	tuplecat3(a::Tuple{AbstractArray,AbstractArray}, b::Tuple{AbstractArray,AbstractArray}) =
		(cat3(a[1], b[1]), cat3(a[2], b[2]))

	## Feature-count counterpart of the array method: concatenating along
	## dimension 3 adds the sizes of that dimension, so the counts are summed.
	## Accepts any `Integer` (not only `Int64`) so that counts of other integer
	## types, e.g. `Int32`, do not raise a `MethodError`.
	tuplecat3(a::Tuple{Integer,Integer}, b::Tuple{Integer,Integer}) = a .+ b
end

tuplecat3 (generic function with 2 methods)

In [3]:
begin
	## Quadratic semi-linear feature expansion along dimension 3.
	##
	## Given `(lin, nonlin)`, returns `(expanded, nonlin)` where `expanded`
	## stacks along dimension 3:
	##   1. every product `lin[:,:,i:i,..] .* nonlin[:,:,j:j,..]`, ordered with
	##      `i` as the outer index and `j` as the inner (fastest) index, then
	##   2. all slices of `lin` itself, in their original order.
	## `nonlin` is passed through untouched.
	function QuadSemiLinear((lin, nonlin)::Tuple{AbstractArray,AbstractArray})
		pairproduct(i, j) = selectdim(lin, 3, i:i) .* selectdim(nonlin, 3, j:j)

		## In the 2-d comprehension the first iterator (`j`) varies fastest, so
		## `vec` yields the order i-outer / j-inner.
		products = vec([pairproduct(i, j) for j in axes(nonlin, 3), i in axes(lin, 3)])

		## A single `cat` replaces the former `push!` loop + pairwise `reduce`:
		## appending every slice of `lin` in order is the same as appending
		## `lin`, it avoids in-place mutation (which Zygote cannot
		## differentiate), and it does not re-copy the growing result at
		## every step.
		return (cat(products..., lin; dims=3), nonlin)
	end

	## Convenience method taking the two arrays as separate arguments.
	QuadSemiLinear(lin::AbstractArray, nonlin::AbstractArray) = QuadSemiLinear((lin, nonlin))

	## Feature-count version: `a` linear and `b` nonlinear features expand to
	## `a*b + a` linear features; the nonlinear count is unchanged.
	QuadSemiLinear(a::Integer, b::Integer) = (a * b + a, b)
	
	## Build a layer that applies `QuadSemiLinear` and then a 1x1 convolution on
	## the first (linear) output only; the second output passes through
	## `identity`. `input_features` is splatted into the feature-count method of
	## `QuadSemiLinear` to size the convolution's input channels;
	## `channels_out` is its number of output channels.
	function QuadSemiLinearConv(input_features, channels_out)
		expanded_channels = first(QuadSemiLinear(input_features...))
		mixer = Conv((1,1), expanded_channels => channels_out; stride=1, pad=SamePad())
		return Chain(QuadSemiLinear, Parallel(tuple, mixer, identity))
	end

end

QuadSemiLinearConv (generic function with 1 method)

In [4]:
begin
	

	
	## Feature-count map of a `WienerNetBaseBlock` built with `Nfeatures`.
	##
	## Returns a function `(a, b) -> (linear_out, nonlinear_out)` giving the
	## number of channels the block produces for an input with `a` linear and
	## `b` nonlinear channels:
	##   * linear_out    = Nfeatures       (the final 1x1 convolution)
	##   * nonlinear_out = Nfeatures + b   (ConvTranspose branch output
	##                                      concatenated with the `b` skipped
	##                                      channels)
	##
	## `passthrough_dims` is kept for interface compatibility but does not
	## affect the result: inside the block the passthrough output is always
	## mapped back to `Nfeatures` channels by the ConvTranspose layers before
	## the skip concatenation. The previous formula added the passthrough's own
	## output width instead, which was only correct when the passthrough
	## preserved `(Nfeatures, Nfeatures)`.
	function WNBB_dims(passthrough_dims; Nfeatures=10)
		return ((_, b),) -> (Nfeatures, Nfeatures + b)
	end
		
	## Basic building block of WienerNet: a layer mapping a
	## (linear, nonlinear) pair of arrays to a (linear, nonlinear) pair.
	##
	## Arguments
	##   input_features   : (a, b) channel counts of the incoming pair
	##   passthrough      : layer applied between the strided Conv and the
	##                      ConvTranspose stages (e.g. `identity` or a nested
	##                      block)
	##   passthrough_dims : function giving the (linear, nonlinear) channel
	##                      counts `passthrough` produces for a given input
	##                      count pair
	##   Nfeatures        : working channel count of the block (keyword)
	##
	## Channel bookkeeping through the block:
	##   QuadSemiLinearConv : (a, b)                 -> (Nfeatures, b)
	##   strided Conv pair  : (Nfeatures, b)         -> (Nfeatures, Nfeatures)
	##   passthrough        : (Nfeatures, Nfeatures) -> passthrough_features
	##   ConvTranspose pair : passthrough_features   -> (Nfeatures, Nfeatures)
	##   skip concatenation : (Nfeatures, Nfeatures) ++ (Nfeatures, b)
	##                        = (2*Nfeatures, Nfeatures + b)
	##   QuadSemiLinearConv : (2*Nfeatures, Nfeatures + b) -> (Nfeatures, Nfeatures + b)
	function WienerNetBaseBlock(input_features, passthrough, passthrough_dims; Nfeatures=10)
		## Channels coming out of `passthrough`; these size the ConvTranspose inputs.
		passthrough_features = passthrough_dims((Nfeatures, Nfeatures))

		## The skip connection concatenates the ConvTranspose outputs, which
		## always have (Nfeatures, Nfeatures) channels, with the block's own
		## (Nfeatures, b) input. Using `passthrough_features` here (as before)
		## over-counts whenever the passthrough changes the channel count, and
		## the final convolution is then built for the wrong input size.
		skip_features = tuplecat3((Nfeatures, Nfeatures), (Nfeatures, input_features[2]))

		return Chain(

		QuadSemiLinearConv(input_features, Nfeatures),

		SkipConnection(
			Chain(
				## Downsample both branches by a factor of 2.
				Parallel(
					tuple,
					Conv((5,5), Nfeatures=>Nfeatures, stride=2, pad=SamePad()),
					Conv((5,5), input_features[2]=>Nfeatures, relu; stride=2, pad=SamePad()),
					),
				passthrough,
				## Upsample back and return to Nfeatures channels per branch.
				Parallel(
					tuple,
					ConvTranspose((5,5), passthrough_features[1]=>Nfeatures, stride=2, pad=SamePad()),
					ConvTranspose((5,5), passthrough_features[2]=>Nfeatures, relu; stride=2, pad=SamePad()),
					)
				),
			tuplecat3),

		QuadSemiLinearConv(skip_features, Nfeatures)

		)
	end
	
	## Cell script: build one base block nested inside another and show the
	## feature counts reported for the inner block.
	## Earlier one-line variant with Nfeatures=1, left disabled:
	#WienerNetBaseBlock((1,1), WienerNetBaseBlock((1,1), identity, identity), WNBB_dims(identity); Nfeatures=1)
	## Inner block: input feature counts (10,10), identity passthrough, default Nfeatures.
	inner = WienerNetBaseBlock((10,10), identity, identity)
	## Feature-count function for a block whose passthrough is the identity.
 	innerdims = WNBB_dims(identity)
	## Outer block: input feature counts (1,1), with `inner` as its passthrough
	## and `innerdims` describing that passthrough's output counts.
	outer = WienerNetBaseBlock((1,1), inner, innerdims)
	## Last evaluated expression of the cell, so its value is the cell output.
	innerdims((10,10))
	## Forward pass left disabled; `rand10` is not defined in the visible part
	## of the notebook.
	#outer((rand10(),rand10()))
end

(10, 20)