-
Notifications
You must be signed in to change notification settings - Fork 10
/
flat_proj.jl
196 lines (152 loc) · 7.1 KB
/
flat_proj.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
# Marker supertype for metadata of flat-sky (2D Cartesian grid) projections.
abstract type FlatProj <: FieldMetadata end

# Metadata for a flat-sky Lambert projection: pixelization parameters plus
# precomputed Fourier-grid quantities reused across broadcasts. `T` is the
# numeric eltype, `V`/`M` the vector/matrix types of the precomputed arrays
# (CPU or GPU depending on `storage`).
struct ProjLambert{T, V<:AbstractVector{T}, M<:AbstractMatrix{T}} <: FlatProj
    # these must be the same to broadcast together
    Ny      :: Int                     # number of pixels in the y direction
    Nx      :: Int                     # number of pixels in the x direction
    θpix    :: Float64                 # pixel width in arcmin
    center  :: Tuple{Float64,Float64}  # map center — presumably sky coordinates in degrees; TODO(review) confirm
    # these can be different and still broadcast (including different types)
    storage                            # array *type* (e.g. Array, CuArray); untyped because it holds a type, not data
    Δx      :: T                       # pixel width in radians
    Ωpix    :: T                       # pixel solid angle (Δx²)
    nyquist :: T                       # Nyquist wavenumber (π/Δx)
    Δℓx     :: T                       # Fourier grid spacing in x
    Δℓy     :: T                       # Fourier grid spacing in y
    ℓy      :: V                       # y wavenumbers, non-negative-frequency half only (rfft layout)
    ℓx      :: V                       # x wavenumbers, full range in ifftshift order
    ℓmag    :: M                       # |ℓ| on the 2D (rfft-shaped) grid
    sin2ϕ   :: M                       # sin(2ϕ) where ϕ = angle(ℓx + i ℓy)
    cos2ϕ   :: M                       # cos(2ϕ)
end
# Floating-point element type to use for inputs of type `T`: the real part
# of `T`, widened to at least `Float32`.
function real_type(T)
    return promote_type(real(T), Float32)
end
# Extend `real_type` for Unitful quantities (registered lazily via Requires.jl,
# so it only applies when the user loads Unitful): strip the unit wrapper and
# recurse on the underlying numeric type.
@init @require Unitful="1986cc42-f94f-5a68-af5c-568840ba703d" real_type(::Type{<:Unitful.Quantity{T}}) where {T} = real_type(T)
# Keyword-argument convenience constructor: normalizes the argument types and
# forwards to the memoized positional constructor below.
function ProjLambert(;Ny, Nx, θpix=1, center=(0,0), T=Float32, storage=Array)
    return ProjLambert(Ny, Nx, Float64(θpix), Float64.(center), real_type(T), storage)
end
# Construct (and memoize) the full ProjLambert metadata for an (Ny × Nx) map
# with `θpix`-arcmin pixels, eltype `T`, and array type `storage`. Memoization
# means identical parameters return the exact same (`===`) object, which keeps
# the fast path in the promote_* methods hot and avoids recomputing the grids.
@memoize function ProjLambert(Ny, Nx, θpix, center, ::Type{T}, ::Type{storage}) where {T,storage}
    Δx       = T(deg2rad(θpix/60))   # pixel width in radians
    Δℓx      = T(2π/(Nx*Δx))         # Fourier grid spacing in x
    Δℓy      = T(2π/(Ny*Δx))         # Fourier grid spacing in y
    nyquist  = T(2π/(2Δx))           # Nyquist wavenumber
    Ωpix     = T(Δx^2)               # pixel solid angle
    # ℓy keeps only the non-negative-frequency half (rfft layout); ℓx is full
    ℓy       = adapt(storage, (ifftshift(-Ny÷2:(Ny-1)÷2) .* Δℓy)[1:Ny÷2+1])
    ℓx       = adapt(storage, (ifftshift(-Nx÷2:(Nx-1)÷2) .* Δℓx))
    ℓmag     = @. sqrt(ℓx'^2 + ℓy^2)
    ϕ        = @. angle(ℓx' + im*ℓy)
    sin2ϕ, cos2ϕ = @. sin(2ϕ), cos(2ϕ)
    if iseven(Ny)
        # mirror part of the last (Nyquist) row of sin2ϕ — presumably to
        # enforce the symmetry expected of a real field's rfft; TODO(review) confirm
        sin2ϕ[end, end:-1:(Nx÷2+2)] .= sin2ϕ[end, 2:Nx÷2]
    end
    ProjLambert(Ny,Nx,θpix,center,storage,Δx,Ωpix,nyquist,Δℓx,Δℓy,ℓy,ℓx,ℓmag,sin2ϕ,cos2ϕ)
end
# Short human-readable alias used when printing ProjLambert types.
function typealias_def(::Type{<:ProjLambert{T}}) where {T}
    return "ProjLambert{$T}"
end
### promotion

# used in broadcasting to decide the resulting metadata when
# broadcasting over two fields
function promote_metadata_strict(metadata₁::ProjLambert{T₁}, metadata₂::ProjLambert{T₂} ) where {T₁,T₂}
    # NOTE(review): only the grid (θpix, Ny, Nx) is compared — `center` and
    # `storage` are presumably irrelevant to broadcast compatibility; confirm.
    if (
        metadata₁.θpix === metadata₂.θpix &&
        metadata₁.Ny === metadata₂.Ny &&
        metadata₁.Nx === metadata₂.Nx
    )
        # always returning the "wider" metadata even if T₁==T₂ helps
        # inference and is optimized away anyway
        promote_type(T₁,T₂) == T₁ ? metadata₁ : metadata₂
    else
        error("""Can't broadcast two fields with the following differing metadata:
1: $(select(fields(metadata₁),(:θpix,:Ny,:Nx)))
2: $(select(fields(metadata₂),(:θpix,:Ny,:Nx)))
""")
    end
end
# used in non-broadcasted algebra to decide the resulting metadata
# when performing some operation across two fields. this is free to do
# more generic promotion than promote_metadata_strict (although this
# is currently not used, but in the future could include promoting
# resolution, etc...). the result should be a common metadata which we
# can convert both fields to then do a succesful broadcast
function promote_metadata_generic(metadata₁::ProjLambert, metadata₂::ProjLambert)
    return promote_metadata_strict(metadata₁, metadata₂)
end
### preprocessing

# defines how ImplicitFields and BatchedReals behave when broadcasted
# with ProjLambert fields. these can return arrays, but can also
# return `Broadcasted` objects which are spliced into the final
# broadcast, thus avoiding allocating any temporary arrays.

# A BatchedReal becomes a 1×1×1×B array (batch along dim 4, adapted to the
# field's storage type `V`); a plain scalar passes through unchanged.
function preprocess((_,proj)::Tuple{<:Any,<:ProjLambert{T,V}}, r::Real) where {T,V}
    if r isa BatchedReal
        return adapt(V, reshape(r.vals, 1, 1, 1, :))
    else
        return r
    end
end
# need custom adjoint here bc Δ can come back batched from the
# backward pass even though r was not batched on the forward pass
@adjoint function preprocess(m::Tuple{<:Any,<:ProjLambert{T,V}}, r::Real) where {T,V}
    # pullback: an array cotangent is collapsed back to a BatchedReal (taking
    # real parts); `nothing` for the non-differentiable metadata argument
    preprocess(m, r), Δ -> (nothing, Δ isa AbstractArray ? batch(real.(Δ[:])) : Δ)
end
# How a first-derivative diagonal operator (∇diag) behaves when broadcast
# against a ProjLambert field: multiplication by `prefactor * i * ℓ` along the
# requested coordinate, valid only in a Fourier basis.
function preprocess((_,proj)::Tuple{BaseFieldStyle{S,B},<:ProjLambert}, ∇d::∇diag) where {S,B}
    # Bug fix: the error message previously said "∇²", copy-pasted from the
    # ∇²diag method below — this method is the first derivative ∇.
    (B <: Union{Fourier,QUFourier,IQUFourier}) ||
        error("Can't broadcast ∇[$(∇d.coord)] as a $(typealias(B)), its not diagonal in this basis.")
    # turn both into 2D matrices so this function is type-stable
    # (reshape doesnt actually make a copy here, so this doesn't
    # impact performance)
    if ∇d.coord == 1
        broadcasted(*, ∇d.prefactor * im, reshape(proj.ℓx, 1, :))
    else
        broadcasted(*, ∇d.prefactor * im, reshape(proj.ℓy, :, 1))
    end
end
# How the Laplacian-diagonal operator (∇²diag) behaves when broadcast against
# a ProjLambert field: multiplication by (ℓx² + ℓy²), valid only in a Fourier
# basis (any sign/prefactor is applied by the caller).
function preprocess((_,proj)::Tuple{BaseFieldStyle{S,B},<:ProjLambert}, ::∇²diag) where {S,B}
    # Bug fix: the error message previously said "BandPass", copy-pasted from
    # the BandPass method below — this method handles ∇².
    (B <: Union{Fourier,<:Basis2Prod{<:Any,Fourier},<:Basis3Prod{<:Any,<:Any,Fourier}}) ||
        error("Can't broadcast ∇² as a $(typealias(B)), its not diagonal in this basis.")
    broadcasted(+, broadcasted(^, proj.ℓx', 2), broadcasted(^, proj.ℓy, 2))
end
# BandPass operators broadcast against ProjLambert fields as a 2D Fourier-space
# mask built by evaluating their 1D weights Wℓ on the |ℓ| grid.
preprocess((_,proj)::Tuple{<:Any,<:ProjLambert}, bp::BandPass) = Cℓ_to_2D(bp.Wℓ, proj)
# Evaluate a 1D spectrum `Cℓ` on the projection's 2D |ℓ| grid, mapping any
# NaNs to zero and converting to the projection's complex element type.
Cℓ_to_2D(Cℓ, proj::ProjLambert{T}) where {T} = Complex{T}.(nan2zero.(Cℓ.(proj.ℓmag)))
### adapting

# dont adapt the fields in proj, instead re-call into the memoized
# ProjLambert so we always get back the singleton ProjLambert object
# for the given set of parameters (helps reduce memory usage and
# speed-up subsequent broadcasts which would otherwise not hit the
# "===" branch of the "promote_*" methods)
function adapt_structure(storage, proj::ProjLambert{T}) where {T}
    T′ = eltype(storage)
    # an untyped storage (eltype Any) keeps the current eltype T
    Tnew = T′ == Any ? T : real(T′)
    return ProjLambert(;Ny=proj.Ny, Nx=proj.Nx, θpix=proj.θpix, T=Tnew, storage)
end
@doc doc"""
    pixwin(θpix, ℓ)

Pixel window function for square flat-sky pixels of width `θpix` (in arcmin),
evaluated at multipole(s) `ℓ`. This is the suppression of individual k-modes;
a power spectrum is suppressed by `pixwin^2`.
"""
function pixwin(θpix, ℓ)
    Δ = deg2rad(θpix/60)   # pixel width in radians
    # Julia's sinc(x) = sin(πx)/(πx), so this is sin(ℓΔ/2)/(ℓΔ/2)
    return @. sinc(ℓ*Δ/2π)
end
# ### serialization

# makes it so the arrays in ProjLambert objects aren't actually
# serialized, instead just (Ny, Nx, θpix, center, T) are stored, and
# deserializing just recreates the ProjLambert object, possibly from
# the memoized cache if it already exists.

# The minimal parameter set that fully determines a ProjLambert.
function _serialization_key(proj::ProjLambert{T}) where {T}
    return (;Ny=proj.Ny, Nx=proj.Nx, θpix=proj.θpix, center=proj.center, T, storage=proj.storage)
end
# Julia serialization

# Custom serialization: write only the parameter key (via
# `_serialization_key`), not the large precomputed arrays.
# Note: a dead `@unpack Ny, Nx, θpix, center = proj` line was removed — the
# unpacked locals were never used in this body.
function Serialization.serialize(s::AbstractSerializer, proj::ProjLambert)
    Serialization.writetag(s.io, Serialization.OBJECT_TAG)
    Serialization.serialize(s, ProjLambert)
    Serialization.serialize(s, _serialization_key(proj))
end
# Custom deserialization: read the parameter key back and reconstruct the
# ProjLambert, possibly returning the memoized singleton.
function Serialization.deserialize(s::AbstractSerializer, ::Type{ProjLambert})
    key = Serialization.deserialize(s)
    return ProjLambert(; key...)
end
# JLD2 serialization
# (always deserialize as Array)

# On-disk representation for JLD2: a Val tag plus the defining parameters
# (without `storage`, which is always restored as Array).
function JLD2.writeas(::Type{<:ProjLambert})
    NT = NamedTuple{(:Ny,:Nx,:θpix,:center,:T),Tuple{Int,Int,Float64,Tuple{Float64,Float64},DataType}}
    return Tuple{Val{ProjLambert},NT}
end
# Convert a ProjLambert to its JLD2 on-disk form: drop `storage` from the
# serialization key and pair it with the Val tag.
function JLD2.wconvert(::Type{<:Tuple{Val{ProjLambert},NamedTuple}}, proj::ProjLambert)
    key = delete(_serialization_key(proj), :storage)
    return (Val(ProjLambert), key)
end
# Reconstruct a ProjLambert from its JLD2 on-disk form, always on CPU (Array).
JLD2.rconvert(::Type{<:ProjLambert}, (_, key)::Tuple{Val{ProjLambert},NamedTuple}) =
    ProjLambert(; storage=Array, key...)