-
Notifications
You must be signed in to change notification settings - Fork 10
/
field_tuples.jl
121 lines (99 loc) · 5.57 KB
/
field_tuples.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
### FieldTuple types
# A FieldTuple is a thin wrapper around a Tuple or NamedTuple of Fields that
# itself behaves like a Field. Its basis parameter `B` is a `BasisProd` built
# from the component bases (in order), and its element type `T` is the
# promotion of the component element types.
struct FieldTuple{FS<:Union{Tuple,NamedTuple},B,T} <: Field{B,T}
    fs::FS
    function FieldTuple(fs::FS) where {FS<:Union{Tuple,NamedTuple}}
        components = values(fs)
        B = BasisProd{Tuple{map(basis, components)...}}
        T = promote_type(map(eltype, components)...)
        return new{FS,B,T}(fs)
    end
end
# Convenience outer constructors. `FieldTuple(; a=f1, b=f2, ...)` packs the
# keyword arguments into a NamedTuple, and `FieldTuple(f1, f2, fs...)` (at
# least two Fields) packs the positional arguments into a Tuple; both then
# forward to the inner constructor, which takes a single Tuple/NamedTuple.
function (::Type{FT})(; kwargs...) where {FT<:FieldTuple}
    named_components = (; kwargs...)
    return FT(named_components)
end
function (::Type{FT})(f1::Field, f2::Field, fs::Field...) where {FT<:FieldTuple}
    components = (f1, f2, fs...)
    return FT(components)
end
### printing
# `f[:]` flattens every component with `[:]` and concatenates the results
# into a single vector.
function getindex(f::FieldTuple, ::Colon)
    flattened = map(component -> component[:], values(f.fs))
    return vcat(flattened...)[:]
end
# Matrix-style (i, j) indexing of a diagonal operator over a FieldTuple:
# diagonal entries come from the flattened diagonal, off-diagonal entries
# from `diagzero`.
function getindex(D::DiagOp{<:FieldTuple}, i::Int, j::Int)
    i == j || return diagzero(D, i, j)
    return D.diag[:][i]
end
# Short type-alias strings: "Field-(name1,name2,...)-<component types>" for
# NamedTuple-backed FieldTuples and "Field-<N>-<component types>" for
# Tuple-backed ones. The second FieldTuple type parameter is the basis, so it
# is bound to `B` here (it is unused in the output).
function typealias_def(::Type{<:FieldTuple{NamedTuple{Names,FS},B}}) where {Names,FS<:Tuple,B}
    return "Field-($(join(map(string,Names),",")))-$FS"
end
function typealias_def(::Type{<:FieldTuple{FS,B}}) where {FS<:Tuple,B}
    return "Field-$(tuple_type_len(FS))-$FS"
end
### array interface
# A FieldTuple reports itself as a 1-d object whose length is the total
# number of entries across all of its components.
size(f::FieldTuple) = (mapreduce(length, +, f.fs, init=0),)
# Copy each component of `src` into the corresponding component of `dest`;
# returns `dest`.
copyto!(dest::FieldTuple, src::FieldTuple) = (map(copyto!, dest.fs, src.fs); dest)
# Iteration and Int/UnitRange indexing go over the wrapped components
# (not over individual scalar entries).
iterate(ft::FieldTuple, args...) = iterate(ft.fs, args...)
getindex(f::FieldTuple, i::Union{Int,UnitRange}) = getindex(f.fs, i)
# Fill every component in place with `x`; returns `ft`.
fill!(ft::FieldTuple, x) = (map(f -> fill!(f, x), ft.fs); ft)
# Adapt each component (e.g. to a different array storage) and rewrap.
adapt_structure(to, f::FieldTuple) = FieldTuple(map(f -> adapt(to, f), f.fs))
# Uninitialized FieldTuples with the same component structure, optionally
# with a new numeric element type `T`.
similar(ft::FieldTuple) = FieldTuple(map(similar, ft.fs))
similar(ft::FieldTuple, ::Type{T}) where {T<:Number} = FieldTuple(map(f -> similar(f, T), ft.fs))
# Sum of all entries of all components. Only the full reduction (`dims=:`)
# is supported; any other `dims` raises an error.
function sum(f::FieldTuple; dims=:)
    dims == (:) || error("sum(::FieldTuple, dims=$dims) not supported")
    return sum(sum, f.fs)
end
### broadcasting
# See base_fields.jl for more explanation of all these pieces; it's the
# exact same principle.
#
# FieldTupleStyle{S,Names} is the broadcast style of a FieldTuple:
#   S     - a Tuple type whose parameters are the broadcast-style types of the
#           components, one per component, in order
#   Names - the NamedTuple names for a NamedTuple-backed FieldTuple, or
#           `Nothing` for a plain-Tuple-backed one
struct FieldTupleStyle{S,Names} <: AbstractArrayStyle{1} end
# Style of a Tuple-backed FieldTuple: record each component's style type.
function BroadcastStyle(::Type{<:FieldTuple{TS}}) where {TS<:Tuple}
FieldTupleStyle{Tuple{map_tupleargs(typeof∘BroadcastStyle,TS)...}, Nothing}()
end
# Style of a NamedTuple-backed FieldTuple: same, but also record the names.
function BroadcastStyle(::Type{<:FieldTuple{NamedTuple{Names,TS}}}) where {Names,TS<:Tuple}
FieldTupleStyle{Tuple{map_tupleargs(typeof∘BroadcastStyle,TS)...}, Names}()
end
# Two FieldTuple styles with the same Names combine component-by-component
# using `result_style` on each pair of component styles.
function BroadcastStyle(::FieldTupleStyle{S₁,Names}, ::FieldTupleStyle{S₂,Names}) where {S₁,S₂,Names}
FieldTupleStyle{Tuple{map_tupleargs((s₁,s₂)->typeof(result_style(s₁(),s₂())), S₁, S₂)...}, Names}()
end
# Broadcasting together with a zero-dimensional (DefaultArrayStyle{0})
# argument keeps the FieldTuple style unchanged.
BroadcastStyle(S::FieldTupleStyle, ::DefaultArrayStyle{0}) = S
# Out-of-place broadcast, generated per style. For each component i it:
#   1. preprocesses `bc` with destination (component style instance,
#      FieldTupleComponent{i}()), which swaps FieldTuple arguments for their
#      i-th component (see `preprocess` below);
#   2. converts the result to a Broadcasted of the i-th component style;
#   3. materializes it.
# The per-component results are wrapped in a Tuple (Names == Nothing) or a
# NamedTuple{Names}, and then in a FieldTuple.
@generated function materialize(bc::Broadcasted{FieldTupleStyle{S,Names}}) where {S,Names}
wrapper = Names == Nothing ? :tuple : :(NamedTuple{$Names})
exprs = map_tupleargs(S, tuple(1:tuple_type_len(S)...)) do Sᵢ, i
:(materialize(convert(Broadcasted{$Sᵢ}, preprocess(($(S.parameters[i])(),FieldTupleComponent{$i}()), bc))))
end
:(FieldTuple($wrapper($(exprs...))))
end
# In-place broadcast: the same per-component construction as `materialize`,
# but each component is materialized into `dst.fs[i]`. Returns `dst`.
@generated function materialize!(dst::FieldTuple, bc::Broadcasted{FieldTupleStyle{S,Names}}) where {S,Names}
exprs = map_tupleargs(S, tuple(1:tuple_type_len(S)...)) do Sᵢ, i
:(materialize!(dst.fs[$i], convert(Broadcasted{$Sᵢ}, preprocess(($(S.parameters[i])(),FieldTupleComponent{$i}()), bc))))
end
:(begin $(exprs...) end; dst)
end
# Marker placed in the second slot of the `preprocess` destination tuple to
# say which component (i) is being extracted.
struct FieldTupleComponent{i} end
# A FieldTuple broadcast argument is replaced by its i-th component.
preprocess(::Tuple{<:Any,FieldTupleComponent{i}}, ft::FieldTuple) where {i} = ft.fs[i]
# A nested Broadcasted whose destination component style is itself a
# FieldTupleStyle is rebuilt with recursively preprocessed arguments.
# NOTE(review): `S` here binds to FieldTupleStyle's first type parameter, which
# is a Tuple type of component styles, so `S()` calls that Tuple type's
# constructor. Confirm this is intended (vs. e.g. using `dest[1]`) for nested
# FieldTuple components.
preprocess(dest::Tuple{FieldTupleStyle{S},<:Any}, bc::Broadcasted) where {S} =
broadcasted(S(), bc.f, preprocess_args(dest, bc.args)...)
### promotion
# Promote two FieldTuples component-by-component (i-th with i-th) and return
# the two promoted FieldTuples as a pair.
function promote(ft1::FieldTuple, ft2::FieldTuple)
    promoted_pairs = map(promote, ft1.fs, ft2.fs)
    lhs = FieldTuple(map(first, promoted_pairs))
    rhs = FieldTuple(map(last, promoted_pairs))
    return lhs, rhs
end
### conversion
# `Basis(ft)` returns the FieldTuple itself, unchanged.
Basis(ft::FieldTuple) = ft
# Converting to a single Basis / Basislike type converts every component to it.
function (::Type{B})(ft::FieldTuple) where {B<:Basis}
    return FieldTuple(map(B, ft.fs))
end
function (::Type{B})(ft::FieldTuple) where {B<:Basislike}
    return FieldTuple(map(B, ft.fs))
end
# Converting to a BasisProd{Bs} pairs each basis type in Bs with the
# corresponding component (via map_tupleargs) and converts it. NamedTuple-backed
# FieldTuples keep their names.
function (::Type{B})(ft::FieldTuple{<:Tuple}) where {Bs,B<:BasisProd{Bs}}
    converted = map_tupleargs((Bᵢ, fᵢ) -> Bᵢ(fᵢ), Bs, ft.fs)
    return FieldTuple(converted)
end
function (::Type{B})(ft::FieldTuple{<:NamedTuple{Names}}) where {Names,Bs,B<:BasisProd{Bs}}
    converted = map_tupleargs((Bᵢ, fᵢ) -> Bᵢ(fᵢ), Bs, values(ft.fs))
    return FieldTuple(NamedTuple{Names}(converted))
end
### properties
# Property access is forwarded to the wrapped Tuple/NamedTuple, except for
# `fs`, which returns the wrapped collection itself. Lifting the symbol into
# `Val` lets the `fs` case be picked out by dispatch.
function getproperty(f::FieldTuple, name::Symbol)
    return getproperty(f, Val(name))
end
function getproperty(f::FieldTuple, ::Val{:fs})
    return getfield(f, :fs)
end
function getproperty(f::FieldTuple, ::Val{name}) where {name}
    wrapped = getfield(f, :fs)
    return getproperty(wrapped, name)
end
# `fs` followed by the property names of the wrapped collection.
propertynames(f::FieldTuple) = (:fs, propertynames(getfield(f, :fs))...)
### simulation
# Draw white noise component-by-component, passing the same `rng` to each
# component's `white_noise`, and rewrap the results in a FieldTuple.
function white_noise(ξ::FieldTuple, rng::AbstractRNG)
    components = map(ξ.fs) do ξᵢ
        white_noise(ξᵢ, rng)
    end
    return FieldTuple(components)
end
### Diagonal-ops
# A method specific to FieldTuple is needed because we don't carry around the
# basis in a way that works with the default implementation. Both operations
# act block-wise: the i-th diagonal block multiplies (or left-divides) the
# i-th component.
function (*)(D::DiagOp{<:FieldTuple}, f::FieldTuple)
    blocks = map(Diagonal, D.diag.fs)
    return FieldTuple(map((Dᵢ, fᵢ) -> Dᵢ * fᵢ, blocks, f.fs))
end
function (\)(D::DiagOp{<:FieldTuple}, f::FieldTuple)
    blocks = map(Diagonal, D.diag.fs)
    return FieldTuple(map((Dᵢ, fᵢ) -> Dᵢ \ fᵢ, blocks, f.fs))
end
# Inner product of two FieldTuples: promote both first, then sum the
# component-wise dot products (0 when there are no components).
dot(a::FieldTuple, b::FieldTuple) = reduce(+, map(dot, getfield.(promote(a,b),:fs)...), init=0)
# Hash the concrete type together with the wrapped components, chaining from
# the caller-supplied seed `h`. This evaluates to
# hash(typeof(ft), hash(ft.fs, h)). Without `init=h`, foldr would call
# hash(typeof(ft), ft.fs), passing a Tuple as the UInt seed, which is a
# MethodError and also ignores `h`.
hash(ft::FieldTuple, h::UInt) = foldr(hash, (typeof(ft), ft.fs), init=h)
# logdet & trace
# Both decompose over components. The result is the sum of each component's
# diagonal logdet / trace, starting from 0, so a FieldTuple with no
# components gives 0.
function logdet(L::Diagonal{<:Union{Real,Complex}, <:FieldTuple})
    per_component = map(dᵢ -> logdet(Diagonal(dᵢ)), L.diag.fs)
    return reduce(+, per_component, init=0)
end
function tr(L::Diagonal{<:Union{Real,Complex}, <:FieldTuple})
    per_component = map(dᵢ -> tr(Diagonal(dᵢ)), L.diag.fs)
    return reduce(+, per_component, init=0)
end
# misc
# The batch length shared by all components. `only` throws if the components
# report more than one distinct batch length.
function batch_length(ft::FieldTuple)
    distinct_lengths = unique(map(batch_length, ft.fs))
    return only(distinct_lengths)
end
# The global RNG for a FieldTuple type (Tuple- or NamedTuple-backed). It must
# be the same for every component type, otherwise `only` throws.
function global_rng_for(::Type{<:FieldTuple{<:Union{FS,NamedTuple{Names,FS}}}}) where {Names,FS<:Tuple}
    component_rngs = map_tupleargs(global_rng_for, FS)
    return only(unique(component_rngs))
end