In [None]:
"""
    Convolution(matrix, kernel, f)

Holder for a 2-D convolution expressed as a matrix multiplication.

Intended to be filled in by the `@convolve` macro (not shown here) for a given input,
kernel, stride and dilation configuration. The struct is mutable, so its fields can be
reassigned after construction (e.g. to install a `kernel` function).
"""
mutable struct Convolution
    "Convolution matrix built for the expected input size."
    matrix::Any
    "Zero-argument function that recovers the convolution kernels from `matrix`."
    kernel::Any
    "Purpose not visible here; the visible code leaves it as `nothing`."
    f::Any
end

In [None]:
"""
    _dimcheck_convolutionkernel(inputsize, kernelsize)

Return `true` when a kernel of spatial size `kernelsize` fits inside an input of size
`inputsize` along both dimensions; otherwise throw an `ArgumentError`.

A kernel exactly as large as the input is accepted: it produces a single (1×1) output
position, which is consistent with the error message ("cannot be larger than").
"""
_dimcheck_convolutionkernel(inputsize::Tuple{Int, Int}, kernelsize::Tuple{Int, Int}) =
    all(inputsize .>= kernelsize) ||
    throw(ArgumentError("Kernel size $kernelsize cannot be larger than Input size $inputsize"))

"""
    _kernelsize(kernelsize, dilations)

Effective spatial extent of a kernel once dilated: `d * (k - 1) + 1` per dimension.
"""
_kernelsize(kernelsize::Tuple{Int, Int}, dilations::Tuple{Int, Int}) =
    map((k, d) -> d * (k - 1) + 1, kernelsize, dilations)
"""
    _featuremapsize(inputsize, kernelsize, strides)

Number of kernel placements per dimension: `(n - k) ÷ s + 1` (integer division).
"""
_featuremapsize(inputsize::Tuple{Int, Int}, kernelsize::Tuple{Int, Int}, strides::Tuple{Int, Int}) =
    map((n, k, s) -> (n - k) ÷ s + 1, inputsize, kernelsize, strides)
"""
    _convolutionkernelmargin(inputsize, kernelsize)

Per-dimension remainder of `inputsize` divided by `kernelsize`, as an `Int` tuple.
"""
_convolutionkernelmargin(inputsize::Tuple{Int, Int}, kernelsize::Tuple{Int, Int}) =
    map(rem, inputsize, kernelsize)


"""
    im2col(weight, input_size, kernel_size, strides, dilations, fm_size)

Build the convolution matrix for a single 2-D kernel `weight`.

Feature-map positions are visited with the first index as the outer loop and the second as
the inner loop. For each position, an all-zero canvas of size `input_size` gets `weight'`
written into the (dilated) window starting at that position. The canvas is then transposed
and flattened column-major, which equals flattening it row-major, and stored as the next
column of the result.

# Arguments
- `weight`: kernel matrix.
- `input_size`: spatial input size `(rows, cols)`.
- `kernel_size`: dilated kernel extent (as computed by `_kernelsize`).
- `strides`, `dilations`: per-dimension stride and dilation.
- `fm_size`: number of positions along each dimension.

# Returns
A `Float64` matrix of size `prod(input_size) × prod(fm_size)`. It is always a matrix, even
for a single position or for no positions.

NOTE(review): `weight'` is written into a window shaped like `weight`, so non-square kernels
fail here with a `DimensionMismatch`; the transposed layout appears to assume square inputs
and kernels — confirm whether other shapes need support.
"""
function im2col(weight, input_size::Tuple{Int, Int}, kernel_size::Tuple{Int, Int}, strides::Tuple{Int, Int}, dilations::Tuple{Int, Int}, fm_size::Tuple{Int, Int})
    # Preallocate the result instead of growing it with hcat (which copied all previous
    # columns on every iteration).
    cols = zeros(prod(input_size), prod(fm_size))
    canvas = zeros(input_size)
    col = 0
    for i in 1:strides[1]:(fm_size[1] * strides[1]), j in 1:strides[2]:(fm_size[2] * strides[2])
        fill!(canvas, 0)
        canvas[i:dilations[1]:(i + kernel_size[1] - 1), j:dilations[2]:(j + kernel_size[2] - 1)] = weight'
        col += 1
        # Same layout as the transposed-then-flattened canvas (`canvas'[:]`).
        cols[:, col] = vec(permutedims(canvas))
    end
    return cols
end

In [None]:
"""
    col2im(conv_matrix, input_size, weight_size, strides, dilations)

Recover a 2-D kernel from a convolution matrix laid out like the output of `im2col`.

`conv_matrix` has `prod(input_size)` rows and one column per feature-map position. Each
column holds an input-sized canvas, transposed and then flattened. Positions are ordered
with the first spatial index outer and the second inner, placed `strides` apart. For every
position, the dilated window at that offset is read back, transposed, and accumulated. The
result is the average over all positions. For a matrix built by `im2col` from a square
`weight`, this returns `weight` up to rounding.

# Arguments
- `conv_matrix`: `prod(input_size) × n_positions` matrix.
- `input_size`: spatial input size `(rows, cols)`.
- `weight_size`: undilated kernel size.
- `strides`, `dilations`: per-dimension stride and dilation used to build the matrix.

# Throws
- `ArgumentError` if the dilated kernel does not fit inside the input.
- `DimensionMismatch` if `conv_matrix` does not have one column per feature-map position.
"""
function col2im(conv_matrix, input_size::Tuple{Int, Int}, weight_size::Tuple{Int, Int}, strides::Tuple{Int, Int}, dilations::Tuple{Int, Int})
    kernel_size = _kernelsize(weight_size, dilations)
    _dimcheck_convolutionkernel(input_size, kernel_size)
    fm_size = _featuremapsize(input_size, kernel_size, strides)

    npositions = prod(fm_size)
    size(conv_matrix, 2) == npositions || throw(DimensionMismatch(
        "convolution matrix has $(size(conv_matrix, 2)) columns, expected $npositions (one per feature-map position)"))

    ck = zeros(weight_size)
    col = 0
    # Visit positions in the same order and with the same offsets as they were written
    # (the previous version derived window centres from `input_size .% kernel_size`, which
    # only matched by coincidence and broke for dilation > 1).
    for pr in 1:fm_size[1], pc in 1:fm_size[2]
        r0 = 1 + (pr - 1) * strides[1]
        c0 = 1 + (pc - 1) * strides[2]
        col += 1
        # The column is the canvas transposed then flattened, so reshaping it to the
        # transposed size gives `canvas'`. This also works for non-square inputs.
        flipped = reshape(view(conv_matrix, :, col), input_size[2], input_size[1])
        # Indexing `canvas'` as (cols, rows) reads the window already transposed, i.e.
        # `canvas[rows, cols]'`, undoing the transpose applied to the kernel on write.
        ck .+= view(flipped,
                    c0:dilations[2]:(c0 + kernel_size[2] - 1),
                    r0:dilations[1]:(r0 + kernel_size[1] - 1))
    end

    return ck ./ npositions
end

In [None]:
# Toy configuration: a batch of 10 inputs, each 13×13 with 6 channels, and a kernel bank
# of 3×3 kernels (6 along dimension 3, matching the input channels; 2 along dimension 4).
i = rand(13, 13, 6, 10)
k = rand(3, 3, 6, 2)
s = (1, 1)
d = (1, 1)

# Accept a single stride/dilation value by duplicating it into a (rows, cols) pair.
strides = if length(s) < 2
    (s, s)
else
    s
end
dilations = if length(d) < 2
    (d, d)
else
    d
end

input_size = (size(i, 1), size(i, 2))
kernel_size = _kernelsize((size(k, 1), size(k, 2)), dilations)

_dimcheck_convolutionkernel(input_size, kernel_size)
fm_size = _featuremapsize(input_size, kernel_size, strides)

In [None]:
# Build one convolution matrix per kernel slice k[:, :, ki, ci]: slices over the third
# dimension of `k` are stacked along dims=3, slices over the fourth along dims=4.
t1 = let
    per_ci = map(1:size(k)[4]) do ci
        mats = [im2col(k[:, :, ki, ci], input_size, kernel_size, strides, dilations, fm_size)
                for ki in 1:size(k)[3]]
        cat(mats...; dims=3)
    end
    cat(per_ci...; dims=4)
end
# Put feature-map positions first: (positions, input pixels, dim 3 of k, dim 4 of k).
cut1 = permutedims(t1, (2, 1, 3, 4))

In [None]:
cop = Convolution(cut1, nothing, nothing)
# Rebuild the kernel tensor from `cop.matrix` on demand, with one `col2im` call per kernel
# slice. The globals `cop`, `k`, `input_size`, `strides` and `dilations` are looked up
# when the function is called, not when it is defined.
cop.kernel = function ()
    slices_by_ci = map(1:size(k)[4]) do ci
        per_ki = [col2im(collect(cop.matrix[:, :, ki, ci])', input_size, size(k)[1:2], strides, dilations)
                  for ki in 1:size(k)[3]]
        cat(per_ki...; dims=3)
    end
    return cat(slices_by_ci...; dims=4)
end

In [None]:
# Reconstruct the kernel tensor from the stored convolution matrices; `col2im` averages
# the per-position readings. NOTE(review): for this configuration the result is expected
# to match `k` up to rounding — worth checking with `≈`.
cop.kernel()

In [None]:
"""
    convit(x, conv=cop, fmsize=fm_size)

Apply the convolution stored in `conv.matrix` to the batch `x` with one matrix product.

`x` is indexed as `(rows, cols, in_channels, batch)`. `conv.matrix` is a 4-D array of size
`(n_positions, rows*cols, in_channels, out_channels)`, as assembled from `im2col` output
(presumably — confirm for other producers). The result has size
`(fmsize..., out_channels, batch)`.

`conv` and `fmsize` default to the globals `cop` and `fm_size`, looked up at call time, so
`convit(x)` behaves as before. Pass them explicitly to avoid depending on globals.

# Throws
- `DimensionMismatch` if `x` does not flatten to `rows*cols*in_channels` values per sample,
  or if `prod(fmsize)` differs from the number of positions.
"""
function convit(x, conv=cop, fmsize=fm_size)
    batch = size(x, 4)
    npos, npix, cin, cout = size(conv.matrix)

    # Flatten each sample column-major over (rows, cols, channels); the matrix's second and
    # third dimensions are flattened in the same order below.
    flat_x = reshape(x, (prod(size(x)[1:3]), batch))
    size(flat_x, 1) == npix * cin || throw(DimensionMismatch(
        "input flattens to $(size(flat_x, 1)) values per sample, convolution matrix expects $(npix * cin)"))

    # Move out_channels next to positions so that each (position, out_channel) pair is one
    # row. A single product then yields an (npos, cout, batch) layout, identical to
    # concatenating the per-channel products along dims=2.
    weights = reshape(permutedims(conv.matrix, (1, 4, 2, 3)), npos * cout, npix * cin)
    return reshape(weights * flat_x, (fmsize..., cout, batch))
end

In [None]:
# Apply the stored convolution to the toy batch `i`. For the configuration above the
# result has size (fm_size..., output channels, batch) = (11, 11, 2, 10).
convit(i)