-
Notifications
You must be signed in to change notification settings - Fork 10
/
bilinearlens.jl
197 lines (155 loc) · 6.06 KB
/
bilinearlens.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
export BilinearLens
@doc doc"""
    BilinearLens(ϕ)

`BilinearLens` is a lensing operator that computes lensing with
bilinear interpolation. The action of the operator, as well as its
adjoint, inverse, inverse-adjoint, and gradient w.r.t. `ϕ` can all be
computed. The log-determinant of the operation is non-zero and can't
be computed.

Internally, `BilinearLens` forms a sparse matrix with the
interpolation weights, which can be applied and adjoint-ed extremely
fast (e.g. at least an order of magnitude faster than
[`LenseFlow`](@ref)). Inverse and inverse-adjoint lensing is somewhat
slower since it requires an iterative solve, here performed with the
[preconditioned generalized minimal
residual](https://en.wikipedia.org/wiki/Generalized_minimal_residual_method)
algorithm.

# Fields
- `ϕ`: the field the operator was constructed from (returned by `getϕ`).
- `sparse_repr`: the sparse matrix of interpolation weights, or the identity
  `I` when the operator was built from a `ϕ` with `norm(ϕ) == 0`.
- `anti_lensing_sparse_repr`: the `sparse_repr` of `BilinearLens(-ϕ)`, passed
  as the left preconditioner `Pl` to the iterative inverse solves. For a
  non-zero `ϕ` it is `nothing` until first requested through
  `get_anti_lensing_sparse_repr!`; the identity operator stores `I` directly.
"""
mutable struct BilinearLens{T,Φ<:Field{<:Any,T},S} <: ImplicitOp{T}
    ϕ :: Φ
    sparse_repr :: S
    # the struct is mutable so that this cache can be filled in after construction
    anti_lensing_sparse_repr :: Union{S, Nothing}
end
# Construct the operator for a flat-sky `ϕ` by assembling an (Nx*Ny)×(Nx*Ny)
# sparse matrix in which row I holds the 4 bilinear-interpolation weights
# associated with pixel I. The anti-lensing representation is left as `nothing`
# here and is only built on demand by `get_anti_lensing_sparse_repr!`.
# Throws (via `error`) if `ϕ` is non-zero and has `Nbatch > 1`.
function BilinearLens(ϕ::FlatField)

    # if ϕ == 0 then just return identity operator
    # (both the forward and the anti-lensing representation are set to `I`)
    if norm(ϕ) == 0
        return BilinearLens(ϕ,I,I)
    end

    @unpack Nbatch,Nx,Ny,Δx = ϕ
    # element type used for the interpolation weights `V` below
    T = real(ϕ.T)
    Nbatch > 1 && error("BilinearLens with batched ϕ not implemented yet.")

    # the (i,j)-th pixel is deflected to (ĩs[i],j̃s[j])
    # ∇*ϕ is divided by Δx (presumably the pixel width, giving offsets in units
    # of pixels — TODO confirm); adding 1:Ny along the first dimension and 1:Nx
    # along the second then turns the offsets into absolute, fractional pixel
    # coordinates. NOTE(review): `:Ix` looks like the pixel-space array of each
    # gradient component — confirm against the Field indexing API.
    j̃s,ĩs = getindex.((∇*ϕ)./Δx, :Ix)
    ĩs .= ĩs .+ (1:Ny)
    j̃s .= (j̃s' .+ (1:Nx))'

    # sub2ind converts a 2D index to 1D index, including wrapping at edges
    # (indices are wrapped into 1:N, so the grid is treated as periodic).
    # NOTE(review): `Base._sub2ind` is an internal, non-exported Base function.
    indexwrap(i,N) = mod(i - 1, N) + 1
    sub2ind(i,j) = Base._sub2ind((Ny,Nx),indexwrap(i,Ny),indexwrap(j,Nx))

    # compute the 4 non-zero entries in L[I,:] (ie the Ith row of the sparse
    # lensing representation, L) and add these to the sparse constructor
    # matrices, M, and V, accordingly. this function is split off so it can be
    # called directly or used as a CUDA kernel
    # Arguments: I is the 1-D pixel index, (ĩ, j̃) its deflected coordinates,
    # M receives the 4 column indices and V the 4 weights, both at 4I-3:4I.
    function compute_row!(I, ĩ, j̃, M, V)

        # (i,j) indices of the 4 nearest neighbors
        left,right = floor(Int,ĩ) .+ (0, 1)
        top,bottom = floor(Int,j̃) .+ (0, 1)

        # 1-D indices of the 4 nearest neighbors
        # (same neighbor order as the rows of `A` below)
        M[4I-3:4I] .= @SVector[sub2ind(left,top), sub2ind(right,top), sub2ind(left,bottom), sub2ind(right,bottom)]

        # weights of these neighbors in the bilinear interpolation
        # Each row of `A` is [1, Δx, Δy, Δx*Δy] evaluated at one neighbor's offset
        # from the deflected position, so A*c = (neighbor values) determines the
        # coefficients c of the bilinear form c₁ + c₂Δx + c₃Δy + c₄ΔxΔy. The
        # interpolated value at the deflected position (Δx = Δy = 0) is c₁, hence
        # the weights are the first row of inv(A).
        Δx⁻, Δx⁺ = ((left,right) .- ĩ)
        Δy⁻, Δy⁺ = ((top,bottom) .- j̃)
        A = @SMatrix[
            1 Δx⁻ Δy⁻ Δx⁻*Δy⁻;
            1 Δx⁺ Δy⁻ Δx⁺*Δy⁻;
            1 Δx⁻ Δy⁺ Δx⁻*Δy⁺;
            1 Δx⁺ Δy⁺ Δx⁺*Δy⁺
        ]
        V[4I-3:4I] .= inv(A)[1,:]

    end

    # a surprisingly large fraction of the computation for large Nside, so memoize it:
    # K holds the row index of every stored entry: 1,1,1,1,2,2,2,2,...,Nx*Ny
    @memoize getK(Nx,Ny) = Int32.((4:4*Nx*Ny+3) .÷ 4)

    # CPU
    function compute_sparse_repr(is_gpu_backed::Val{false})
        # copy K so the memoized vector is not the one handed to `sparse`
        K = Vector{Int32}(getK(Nx,Ny))
        M = similar(K)
        V = similar(K,T)
        # linear index I runs over all pixels of the (2-D) coordinate arrays
        for I in 1:length(ĩs)
            compute_row!(I, ĩs[I], j̃s[I], M, V)
        end
        sparse(K,M,V,Nx*Ny,Nx*Ny)
    end

    # GPU
    function compute_sparse_repr(is_gpu_backed::Val{true})
        K = CuVector{Cint}(getK(Nx,Ny))
        M = similar(K)
        V = similar(K,T)
        # each thread starts at its own thread index and advances by the block size.
        # NOTE(review): only threadIdx/blockDim are used (no block index), so this
        # appears to assume the `cuda` helper launches a single block — confirm.
        cuda(ĩs, j̃s, M, V; threads=256) do ĩs, j̃s, M, V
            index = threadIdx().x
            stride = blockDim().x
            for I in index:stride:length(ĩs)
                compute_row!(I, ĩs[I], j̃s[I], M, V)
            end
        end
        # assemble in COO form from (row, column, value) triplets, then convert to CSR
        CuSparseMatrixCSR(CuSparseMatrixCOO{T}(K,M,V,(Nx*Ny,Nx*Ny)))
    end

    # dispatch on whether ϕ is GPU-backed; anti-lensing repr is computed lazily later
    BilinearLens(ϕ, compute_sparse_repr(Val(is_gpu_backed(ϕ))), nothing)

end
# Lazily compute the sparse representation for anti-lensing.
#
# Returns the `sparse_repr` of `BilinearLens(-Lϕ.ϕ)`. The result is built on the
# first call and cached in `Lϕ.anti_lensing_sparse_repr` (hence the `!`), so
# later calls just return the stored value.
function get_anti_lensing_sparse_repr!(Lϕ::BilinearLens)
    # `===` is the identity test for the `nothing` singleton; `==` is a generic,
    # overloadable comparison and should not be used to check for `nothing`
    if Lϕ.anti_lensing_sparse_repr === nothing
        Lϕ.anti_lensing_sparse_repr = BilinearLens(-Lϕ.ϕ).sparse_repr
    end
    return Lϕ.anti_lensing_sparse_repr
end
# accessor for the field ϕ that the operator was built from
getϕ(Lϕ::BilinearLens) = Lϕ.ϕ

# calling an existing operator on a new ϕ builds a fresh operator for that ϕ
# (nothing from the existing operator is reused)
(Lϕ::BilinearLens)(ϕ::FlatField) = BilinearLens(ϕ)

# hash on the operator's type and ϕ only; the sparse representations are not
# hashed. The seed is typed `UInt` (not `UInt64`), which is the form Base's
# `hash(x)` calls, so the method is also reached on platforms where
# `UInt !== UInt64`.
hash(L::BilinearLens, h::UInt) = foldr(hash, (typeof(L), getϕ(L)), init=h)
# applying various forms of the operator

# Forward operation: multiply every 2-D slice of the Ł-converted field's array
# (indexed by dimensions 3 and 4) by the sparse interpolation matrix.
function *(Lϕ::BilinearLens, f::FlatField)
    L = Lϕ.sparse_repr
    # an operator holding the identity leaves the field untouched
    L === I && return f
    f_in = Ł(f)
    f_out = similar(f_in)
    @views for batch in 1:size(f.arr,4), pol in 1:size(f.arr,3)
        # flatten the (pol, batch) slice so it can be used as a vector
        dst = f_out.arr[:,:,pol,batch][:]
        src = f_in.arr[:,:,pol,batch][:]
        mul!(dst, L, src)
    end
    return f_out
end
# Adjoint operation: same slice-by-slice product as the forward method, but
# with the (lazy) adjoint of the sparse interpolation matrix.
function *(Lϕ::Adjoint{<:Any,<:BilinearLens}, f::FlatField)
    L = parent(Lϕ).sparse_repr
    # an operator holding the identity leaves the field untouched
    L === I && return f
    Lᴴ = L'
    f_in = Ł(f)
    f_out = similar(f_in)
    @views for batch in 1:size(f.arr,4), pol in 1:size(f.arr,3)
        # flatten the (pol, batch) slice so it can be used as a vector
        dst = f_out.arr[:,:,pol,batch][:]
        src = f_in.arr[:,:,pol,batch][:]
        mul!(dst, Lᴴ, src)
    end
    return f_out
end
# Inverse operation: solve `sparse_repr * x = slice` for every 2-D slice with
# GMRES (at most 5 iterations), passing the anti-lensing matrix as the left
# preconditioner `Pl`.
function \(Lϕ::BilinearLens, f̃::FlatField)
    L = Lϕ.sparse_repr
    # an operator holding the identity leaves the field untouched
    L === I && return f̃
    f_in = Ł(f̃)
    f_out = similar(f_in)
    @views for batch in 1:size(f_out.arr,4), pol in 1:size(f_out.arr,3)
        rhs = f_in.arr[:,:,pol,batch][:]
        sol = f_out.arr[:,:,pol,batch][:]
        # built on the first request, then served from the cache on the operator
        precond = get_anti_lensing_sparse_repr!(Lϕ)
        sol .= gmres(L, rhs, Pl = precond, maxiter = 5)
    end
    return f_out
end
# Inverse-adjoint operation: solve `sparse_repr' * x = slice` for every 2-D
# slice with GMRES (at most 5 iterations), passing the adjoint of the
# anti-lensing matrix as the left preconditioner `Pl`.
function \(Lϕ::Adjoint{<:Any,<:BilinearLens}, f̃::FlatField)
    op = parent(Lϕ)
    # an operator holding the identity leaves the field untouched
    op.sparse_repr === I && return f̃
    f_in = Ł(f̃)
    f_out = similar(f_in)
    @views for batch in 1:size(f_out.arr,4), pol in 1:size(f_out.arr,3)
        rhs = f_in.arr[:,:,pol,batch][:]
        sol = f_out.arr[:,:,pol,batch][:]
        # built on the first request, then served from the cache on the operator
        precond = get_anti_lensing_sparse_repr!(op)'
        sol .= gmres(op.sparse_repr', rhs, Pl = precond, maxiter = 5)
    end
    return f_out
end
# Extend `*` and `\` (for the operator and its adjoint) to FieldTuples by
# applying the operator to each entry of `f.fs` and re-wrapping the results.
for op in (:*, :\)
    @eval function ($op)(Lϕ::Union{BilinearLens, Adjoint{<:Any,<:BilinearLens}}, f::FieldTuple)
        components = map(f.fs) do fᵢ
            ($op)(Lϕ, fᵢ)
        end
        return FieldTuple(components)
    end
end
# gradients

# Differentiating through the constructor passes the incoming cotangent
# straight back as the cotangent of ϕ.
@adjoint function BilinearLens(ϕ)
    passthrough(Δ) = (Δ,)
    return BilinearLens(ϕ), passthrough
end
# Custom adjoint for applying the operator to a field.
#
# Returns the lensed field `f̃ = Lϕ * f` together with a pullback that maps the
# cotangent `Δ` of `f̃` to a tuple (cotangent for `Lϕ`, cotangent for `f`).
@adjoint function *(Lϕ::BilinearLens, f::Field{B}) where {B}
    f̃ = Lϕ * f
    function back(Δ)
        # cotangent for the operator: product of Δ with the gradient of the
        # lensed field, contracted with ∇'. `Ref` keeps the single field from
        # being iterated over when broadcasting against the components of ∇*f̃.
        δLϕ = ∇' * (Ref(spin_adjoint(Ł(Δ))) .* Ł(∇*f̃))
        # f ↦ Lϕ*f is linear in f, so its pullback is the adjoint operator
        # applied to Δ (not the forward operator). `B` is the basis parameter
        # of the input field; presumably `B(...)` converts the result back to
        # that basis — confirm.
        δf = B(Lϕ' * Δ)
        return δLϕ, δf
    end
    return f̃, back
end
# gpu

# Rebuild the operator from its field values after passing them through
# `adapt` for the requested storage.
function adapt_structure(storage, Lϕ::BilinearLens)
    adapted_fields = adapt(storage, fieldvalues(Lϕ))
    return BilinearLens(adapted_fields...)
end