/
convolutions.jl
240 lines (189 loc) · 6.57 KB
/
convolutions.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
export conv, plan_conv, conv_psf, plan_conv_psf
export plan_conv_buffer, plan_conv_psf_buffer
"""
    conv(u, v[, dims])

Convolve `u` with `v` over `dims` dimensions with an FFT based method.
Note that this method introduces wrap-around artifacts without
proper padding/windowing.

# Arguments
* `u` is an array in real space.
* `v` is the array to be convolved in real space as well.
* `dims`: per default `ntuple(+, min(N, M))`, which means that we perform the
    convolution over all dimensions of that array which has fewer dimensions.
    If `dims` is an array with integers, we perform convolution
    only over these dimensions. E.g. `dims=[1,3]` would perform the convolution
    over the first and third dimension. Second dimension is not convolved.

If `u` and `v` are both real valued arrays we use `rfft` and hence
the output is real as well.
If either `u` or `v` is complex we use `fft` and output is hence complex.

# Examples
1D with FFT over all dimensions. We choose `v` to be a delta peak.
Therefore convolution should act as identity (up to the circular shift
given by the position of the peak).
```jldoctest
julia> u = [1 2 3 4 5]
1×5 Matrix{Int64}:
 1  2  3  4  5

julia> v = [0 0 1 0 0]
1×5 Matrix{Int64}:
 0  0  1  0  0

julia> conv(u, v)
1×5 Matrix{Float64}:
 4.0  5.0  1.0  2.0  3.0
```

2D with FFT with different `dims` arguments.
```jldoctest
julia> u = 1im .* [1 2 3; 4 5 6]
2×3 Matrix{Complex{Int64}}:
 0+1im  0+2im  0+3im
 0+4im  0+5im  0+6im

julia> v = [1im 0 0; 1im 0 0]
2×3 Matrix{Complex{Int64}}:
 0+1im  0+0im  0+0im
 0+1im  0+0im  0+0im

julia> conv(u, v)
2×3 Matrix{ComplexF64}:
 -5.0+0.0im  -7.0+0.0im  -9.0+0.0im
 -5.0+0.0im  -7.0+0.0im  -9.0+0.0im
```
"""
function conv(u::AbstractArray{T, N}, v::AbstractArray{D, M},
              dims=ntuple(+, min(N, M))) where {T, D, N, M}
    return ifft(fft(u, dims) .* fft(v, dims), dims)
end

function conv(u::AbstractArray{<:Real, N}, v::AbstractArray{<:Real, M},
              dims=ntuple(+, min(N, M))) where {N, M}
    # `rfft` halves the first transformed dimension, so `irfft` needs the original
    # length of `u` along `first(dims)` to reconstruct the full-size real output
    return irfft(rfft(u, dims) .* rfft(v, dims), size(u, first(dims)), dims)
end
"""
    conv_psf(u, psf[, dims])

Convolve `u` with a point spread function `psf` given with its center in the
middle of the array. `psf` is first `ifftshift`ed along `dims` (moving that
center to the first entry), so this is shorthand for
`conv(u, ifftshift(psf, dims), dims)`. For examples see [`conv`](@ref).
"""
function conv_psf(u::AbstractArray{T, N}, psf::AbstractArray{D, M},
                  dims=ntuple(+, min(N, M))) where {T, D, N, M}
    psf_origin_first = ifftshift(psf, dims)
    return conv(u, psf_origin_first, dims)
end
"""
    plan_conv(u, v [, dims]; kwargs...)

Pre-plan an optimized convolution for arrays shaped like `u` and `v` (based on
pre-planned FFTs) along the given dimensions `dims`.
`dims = ntuple(+, ndims(u))`, i.e. all dimensions of `u`, per default.
The 0 frequency of `u` must be located at the first entry.

We return two arguments:
The first one is `v_ft` (obtained by `fft(v)` or `rfft(v)`).
The second return is the convolution function `pconv`.
`pconv` itself has two arguments: `pconv(u, v_ft=v_ft)`, where `u` is the object
and `v_ft` the Fourier transform of the kernel (defaults to the `v_ft` returned here).
This function achieves faster convolution than `conv(u, v)`.
Depending on whether `u` is real or complex we do `rfft`s or `fft`s.

Additionally, it is possible to provide `flags=FFTW.MEASURE` as `kwargs`
to change the planning of the FFT. Planning in any mode without the
`FFTW.ESTIMATE` bit (e.g. `FFTW.MEASURE`, `FFTW.PATIENT`, `FFTW.EXHAUSTIVE`, or
combinations like `FFTW.MEASURE | FFTW.UNALIGNED`) may overwrite the planned array,
so in that case the plan is created on a copy of `u` and `u` itself stays unchanged.

# Examples
```jldoctest
julia> u = [1 2 3 4 5]
1×5 Matrix{Int64}:
 1  2  3  4  5

julia> v = [1 0 0 0 0]
1×5 Matrix{Int64}:
 1  0  0  0  0

julia> v_ft, pconv = plan_conv(u, v);

julia> pconv(u, v_ft)
1×5 Matrix{Float64}:
 1.0  2.0  3.0  4.0  5.0

julia> pconv(u)
1×5 Matrix{Float64}:
 1.0  2.0  3.0  4.0  5.0
```
"""
function plan_conv(u::AbstractArray{T1, N}, v::AbstractArray{T2, M}, dims=ntuple(+, N);
                   kwargs...) where {T1, T2, N, M}
    eltype_error(T1, T2)
    plan = get_plan(T1)
    # do the preplanning step
    # FFTW leaves the planned array untouched only when the ESTIMATE bit is set
    # (the default in FFTW.jl, and the bit FFTW.jl itself checks before allocating
    # real planning buffers). MEASURE (== 0), PATIENT, EXHAUSTIVE and combinations
    # like `MEASURE | UNALIGNED` may overwrite the input during planning. Hence copy!
    flags = get(kwargs, :flags, FFTW.ESTIMATE)
    P = if iszero(flags & FFTW.ESTIMATE)
        plan(copy(u), dims; kwargs...)
    else
        plan(u, dims; kwargs...)
    end
    v_ft = fft_or_rfft(T1)(v, dims)
    # construct the efficient conv function
    # P and P_inv can be understood like matrices
    # but their computation is fast
    # Rebind the captured values in a `let` so the closure captures fixed bindings, see
    # https://discourse.julialang.org/t/type-issue-with-captured-variables-let-workaround-failed/85661
    conv = let P = P,
               P_inv = inv(P),
               v_ft = v_ft
        conv(u, v_ft=v_ft) = p_conv_aux(P, P_inv, u, v_ft)
    end
    return v_ft, conv
end
"""
    plan_conv_buffer(u, v [, dims]; kwargs...)

Similar to [`plan_conv`](@ref) but instead uses buffers to prevent memory allocations.
Not AD friendly!

Returns `v_ft` (the transform of `v` by the pre-planned forward FFT) and a function
`pconv(u, v_ft=v_ft)`. `pconv` writes its result into one internal buffer and returns
that same buffer on every call. Copy the result if it must survive the next call.

If `flags` in `kwargs` selects an FFTW planning mode without the `FFTW.ESTIMATE` bit
(e.g. `FFTW.MEASURE`), planning may overwrite `u`. Its content is saved beforehand
and written back afterwards, so the caller's `u` is left unchanged.
"""
function plan_conv_buffer(u::AbstractArray{T1, N}, v::AbstractArray{T2, M},
                          dims=ntuple(+, N); kwargs...) where {T1, T2, N, M}
    eltype_error(T1, T2)
    plan = get_plan(T1)
    # do the preplanning step
    # Only FFTW's ESTIMATE mode is guaranteed not to touch the planned array; any
    # other mode (MEASURE, PATIENT, EXHAUSTIVE, combinations...) may overwrite `u`.
    # We plan on `u` itself rather than a copy because `P_u` is later applied to
    # arrays with `u`'s memory layout, and restore the content afterwards.
    flags = get(kwargs, :flags, FFTW.ESTIMATE)
    u_backup = iszero(flags & FFTW.ESTIMATE) ? copy(u) : nothing
    P_u = try
        plan(u, dims; kwargs...)
    finally
        # skip the write-back when `u` is unchanged (e.g. the backend planned on an
        # internal converted copy); this also keeps non-writable arrays working
        isnothing(u_backup) || isequal(u, u_backup) || copyto!(u, u_backup)
    end
    P_v = plan(v, dims)
    u_buff = P_u * u
    v_ft = P_v * v
    uv_buff = u_buff .* v_ft
    # for fourier space we need a new plan; `u .* v` is a fresh temporary,
    # so it does not matter if planning overwrites it
    P = plan(u .* v, dims; kwargs...)
    P_inv = inv(P)
    out_buff = P_inv * uv_buff
    # construct the efficient conv function
    # P and P_inv can be understood like matrices
    # but their computation is fast
    function conv(u, v_ft=v_ft)
        mul!(u_buff, P_u, u)
        uv_buff .= u_buff .* v_ft
        mul!(out_buff, P_inv, uv_buff)
        return out_buff
    end
    return v_ft, conv
end
"""
    plan_conv_psf_buffer(u, psf [, dims]; kwargs...)

Buffered counterpart of [`plan_conv_psf`](@ref). This is shorthand for
`plan_conv_buffer(u, ifftshift(psf, dims), dims; kwargs...)`.
`psf` must have the same element type as `u`. For examples see [`plan_conv`](@ref).
"""
function plan_conv_psf_buffer(u::AbstractArray{T, N}, psf::AbstractArray{T, M},
                              dims=ntuple(+, N); kwargs...) where {T, N, M}
    shifted_psf = ifftshift(psf, dims)
    return plan_conv_buffer(u, shifted_psf, dims; kwargs...)
end
"""
    plan_conv_psf(u, psf [, dims]; kwargs...)

Pre-plan a convolution with a point spread function `psf` whose center is in the
middle of the array. This is shorthand for
`plan_conv(u, ifftshift(psf, dims), dims; kwargs...)`.
`psf` must have the same element type as `u`. For examples see [`plan_conv`](@ref).
"""
function plan_conv_psf(u::AbstractArray{T, N}, psf::AbstractArray{T, M},
                       dims=ntuple(+, N); kwargs...) where {T, N, M}
    shifted_psf = ifftshift(psf, dims)
    return plan_conv(u, shifted_psf, dims; kwargs...)
end
"""
    p_conv_aux(P, P_inv, u, v_ft)

Apply the forward operator `P` to `u`, multiply the result pointwise with `v_ft`
and the factor `P_inv.scale`, and map it back with `P_inv.p`. The result is
`P_inv.p * ((P * u) .* v_ft .* P_inv.scale)`.
"""
function p_conv_aux(P, P_inv, u, v_ft)
    # assumes `P_inv` is an AbstractFFTs `ScaledPlan` as returned by `inv(P)`,
    # i.e. it exposes the unscaled plan `p` and the factor `scale` (TODO confirm)
    unscaled_inverse = P_inv.p
    u_ft = P * u
    normalization = P_inv.scale
    return unscaled_inverse * (u_ft .* v_ft .* normalization)
end
"""
    fft_or_rfft(T)

Return the forward FFT function to use for element type `T`: `rfft` when
`T <: Real`, and `fft` for every other element type.
"""
fft_or_rfft(::Type{<:Real}) = rfft
fft_or_rfft(::Type{T}) where {T} = fft
"""
    get_plan(T)

Return the FFT planning function to use for element type `T`: `plan_rfft`
when `T <: Real`, and `plan_fft` for every other element type.
"""
get_plan(::Type{<:Real}) = plan_rfft
get_plan(::Type{T}) where {T} = plan_fft