/
FeedbackTrees.jl
59 lines (50 loc) · 1.65 KB
/
FeedbackTrees.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
module FeedbackTrees
using Flux
import Flux: children, mapchildren
import Base: getindex, show
using MacroTools: @forward
using ..Splitters
using ..AbstractMergers
using ..AbstractFeedbackNets
export FeedbackTree
"""
    FeedbackTree(layers...)

A feedback network container holding an ordered tuple of `layers`, analogous
to a Flux `Chain` but with support for `Splitter` and `AbstractMerger` layers
that route state between timesteps. The layer tuple's concrete type is a type
parameter so that Flux can specialize on the layer composition.
"""
struct FeedbackTree{T<:Tuple} <: AbstractFeedbackNet
    layers::T
    # Inner constructor: capture the varargs tuple directly as the field.
    FeedbackTree(xs...) = new{typeof(xs)}(xs)
end
"""
    (c::FeedbackTree{Tuple{}})(h, x)

An empty `FeedbackTree` is the identity: it passes the hidden state `h` and
input `x` through unchanged.
"""
(c::FeedbackTree{Tuple{}})(h, x) = (h, x)
"""
(c::FeedbackTree)(h, x)
Apply a `FeedbackTree` to input `x` with hidden state `h`. `h` should take the
form of a dictionary mapping `Splitter` names to states.
"""
function (c::FeedbackTree)(h, x)
newh = Dict{String, Any}()
for layer ∈ c.layers
if layer isa Splitter
newh[splitname(layer)] = x
x = h[splitname(layer)]
elseif layer isa AbstractMerger
x = layer(x, h)
else
x = layer(x)
end
end
return newh, x
end # function (c::FeedbackTree)
# Flux integration: exposing the layer tuple as children lets Flux traverse
# the tree when collecting parameters or mapping devices (e.g. `gpu`).
children(c::FeedbackTree) = c.layers
# Rebuild a FeedbackTree from the transformed layers so the result keeps the
# same container type.
mapchildren(f, c::FeedbackTree) = FeedbackTree(map(f, c.layers)...)
# These overloads ensure that indexing / slicing etc. work with FeedbackTrees
# by forwarding the listed Base functions to the underlying layer tuple.
@forward FeedbackTree.layers Base.getindex, Base.length, Base.first, Base.last,
Base.iterate, Base.lastindex
# Indexing with a collection rewraps the selected layers in a new
# FeedbackTree instead of returning a bare tuple.
getindex(c::FeedbackTree, i::AbstractArray) = FeedbackTree(c.layers[i]...)
# Render a FeedbackTree like `FeedbackTree(layer1, layer2, ...)`, printing
# each layer separated by ", ".
function show(io::IO, c::FeedbackTree)
    print(io, "FeedbackTree(")
    for (k, layer) in enumerate(c.layers)
        k > 1 && print(io, ", ")
        print(io, layer)
    end
    print(io, ")")
end # function show
end # module FeedbackTrees