# _chainrules.jl — tests for the ChainRules integration (ProjectTo and rrules of keyless_unname)
@testset "chainrules.jl" begin
    # FiniteDifferences compares rrules against finite differences on flat vectors.
    # These methods flatten a KeyedArray through its underlying `data` and rebuild a
    # KeyedArray carrying the same keys from a flat vector.

    # Named variant: the keys are restored by dimension name, via `named_axiskeys`.
    function FiniteDifferences.to_vec(k::AxisKeys.KaNda)
        flat, rebuild_data = to_vec(k.data)
        return flat, x -> wrapdims(rebuild_data(x); AxisKeys.named_axiskeys(k)...)
    end

    # Positional variant: the keys are restored in dimension order, via `axiskeys`.
    function FiniteDifferences.to_vec(k::KeyedArray)
        flat, rebuild_data = to_vec(k.data)
        return flat, x -> wrapdims(rebuild_data(x), axiskeys(k)...)
    end

    @testset "ProjectTo" begin
        # Each case is (array size, named keys). Projecting the raw data must give back
        # the keyed wrapper, and NoTangent must pass through unchanged.
        for (dims, kw) in (((3,), (a=1:3,)), ((3, 4), (a=1:3, b='a':'d')))
            raw = rand(dims...)
            keyed = wrapdims(raw; kw...)
            proj = ProjectTo(keyed)
            @test proj(raw) == keyed
            @test proj(NoTangent()) == NoTangent()
        end
    end

    @testset "KeyedVector" begin
        data = rand(3)
        # The same vector in three forms: named keyed, unnamed keyed, plain.
        inputs = (wrapdims(data, a=1:3), wrapdims(data, 1:3), data)
        for x in inputs
            test_rrule(AxisKeys.keyless_unname, x; check_inferred=false)
        end
        # with matrix output tangent
        for x in inputs
            test_rrule(AxisKeys.keyless_unname, x;
                       output_tangent=rand(3, 1), check_inferred=false)
        end
    end

    @testset "KeyedMatrix" begin
        data = rand(3, 4)
        for x in (wrapdims(data, a=1:3, b='a':'d'), wrapdims(data, 1:3, 'a':'d'), data)
            test_rrule(AxisKeys.keyless_unname, x; check_inferred=false)
        end
    end
end