forked from apple/swift
-
Notifications
You must be signed in to change notification settings - Fork 1
/
autodiff_diagnostics.swift
117 lines (91 loc) · 3.73 KB
/
autodiff_diagnostics.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
// RUN: %target-swift-frontend -emit-sil -verify %s
//===----------------------------------------------------------------------===//
// Top-level (before primal/adjoint synthesis)
//===----------------------------------------------------------------------===//
/// Attempts to differentiate a closure parameter that is an ordinary
/// `(Float) -> Float` function value rather than an '@autodiff' one.
/// The compiler must reject this. A note is anchored on the `func foo` line
/// that declares `f`, and the error is anchored on the `gradient` call line.
// expected-note @+1 {{opaque non-'@autodiff' function is not differentiable}}
func foo(_ f: (Float) -> Float) -> Float {
// expected-error @+1 {{function is not differentiable}}
return gradient(at: 0, in: f)
}
//===----------------------------------------------------------------------===//
// Basic function
//===----------------------------------------------------------------------===//
/// Baseline case: a scalar `Float -> Float` function with no generics and no
/// control flow. Taking its gradient must produce no diagnostics at all.
func one_to_one_0(_ x: Float) -> Float {
return x + 2
}
// Positive control: -verify fails if this line emits any diagnostic.
_ = gradient(at: 0, in: one_to_one_0) // okay!
//===----------------------------------------------------------------------===//
// Generics
//===----------------------------------------------------------------------===//
/// Generic function explicitly marked `@differentiable()`. Its parameter `x`
/// and its result are both of the generic type `T`. The directives check that
/// the compiler rejects it with an error plus a note about parameters or
/// results of unknown size. Both diagnostics land on the `func generic` line.
/// (Keep the directive lines and the attribute adjacent: the `@+3`/`@+2`
/// offsets count the lines in between.)
// expected-note @+3 {{differentiating functions with parameters or result of unknown size is not supported yet}}
// expected-error @+2 {{function is not differentiable}}
@differentiable()
func generic<T: Differentiable & FloatingPoint>(_ x: T) -> T {
return x + 1
}
//===----------------------------------------------------------------------===//
// Non-differentiable stored properties
//===----------------------------------------------------------------------===//
/// Single-field fixture type wrapping one `Float`. `p` is a `let` stored
/// property. Reading it inside a differentiated closure is what the
/// top-level `gradient` call after the extension checks.
struct S {
let p: Float
}
// Hand-written `Differentiable` / `VectorNumeric` conformance for `S`:
// - `zero` wraps 0.
// - `+` and `-` act on `p`.
// - Scalar multiplication takes a `Float` on the left.
// - `S` serves as its own tangent and cotangent vector type.
extension S : Differentiable, VectorNumeric {
static var zero: S { return S(p: 0) }
typealias Scalar = Float
static func + (lhs: S, rhs: S) -> S { return S(p: lhs.p + rhs.p) }
static func - (lhs: S, rhs: S) -> S { return S(p: lhs.p - rhs.p) }
static func * (lhs: Float, rhs: S) -> S { return S(p: lhs * rhs.p) }
typealias TangentVector = S
typealias CotangentVector = S
}
// Differentiating `2 * s.p` with respect to `s` must fail. The closure reads
// the stored property `p`, which is reported as not differentiable, and the
// call as a whole is reported as not differentiable. Both diagnostics are
// anchored on the `gradient` line (the `@+2`/`@+1` offsets below).
// expected-error @+2 {{function is not differentiable}}
// expected-note @+1 {{property is not differentiable}}
_ = gradient(at: S(p: 0)) { s in 2 * s.p }
//===----------------------------------------------------------------------===//
// Function composition
//===----------------------------------------------------------------------===//
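// FIXME: Figure out why diagnostics no longer accumulate now that gradient
// synthesis has been removed. Once that is fixed, re-enable this section:
// delete the `#if false` / `#endif` pair and rename every "xpected" directive
// below back to "expected".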
#if false
/// Routes `x` through an optional local and force-unwraps it on return.
/// In the disabled directives, the unwrap is reported as control flow, and
/// the `gradient` call after the function is reported as not differentiable.
/// NOTE(review): presumably `maybe!` lowers to a nil-check branch, which is
/// why it counts as control flow — TODO confirm.
func uses_optionals(_ x: Float) -> Float {
var maybe: Float? = 10
maybe = x
// xpected-note @+1 {{differentiating control flow is not supported yet}}
return maybe!
}
_ = gradient(at: 0, in: uses_optionals) // xpected-error {{function is not differentiable}}
/// Identity on `Float`. It is differentiable by itself and is only used as
/// the target of the nested `gradient` call in `nested`.
func f0(_ x: Float) -> Float {
return x // okay!
}
/// Calls `gradient` from inside a function body. When this function is itself
/// reached while differentiating (via `middle`), the disabled note reports
/// that nested differentiation is unsupported.
func nested(_ x: Float) -> Float {
return gradient(at: x, in: f0) // xpected-note {{nested differentiation is not supported yet}}
}
/// Passes `x` through `uses_optionals` and then into `nested`. The disabled
/// note marks the `nested(y)` call as a hop on the path to the
/// unsupported nested differentiation.
func middle(_ x: Float) -> Float {
let y = uses_optionals(x)
return nested(y) // xpected-note {{when differentiating this function call}}
}
/// One more level of call nesting. The disabled note tracks the diagnostic
/// path through the `middle(x)` call.
func middle2(_ x: Float) -> Float {
return middle(x) // xpected-note {{when differentiating this function call}}
}
/// Outermost function of the chain func_to_diff -> middle2 -> middle -> nested.
/// The disabled note marks `middle2(x)` as the non-differentiable expression.
func func_to_diff(_ x: Float) -> Float {
return middle2(x) // xpected-note {{expression is not differentiable}}
}
/// Differentiates `func_to_diff`. The disabled directive places the error
/// here. The notes on func_to_diff, middle2, middle and nested describe the
/// call path that leads to it.
func calls_grad_of_nested(_ x: Float) -> Float {
return gradient(at: x, in: func_to_diff) // xpected-error {{function is not differentiable}}
}
//===----------------------------------------------------------------------===//
// Control flow
//===----------------------------------------------------------------------===//
/// Returns `x + 1` when `flag` is true and `x` otherwise, using an `if`/`else`
/// that assigns a `let` constant. The disabled note marks that `if` as
/// unsupported control flow when this function is differentiated.
func if_else(_ x: Float, _ flag: Bool) -> Float {
let y: Float
// Keep the directive immediately above `if flag`; it uses an `@+1` offset.
// xpected-note @+1 {{differentiating control flow is not supported yet}}
if flag {
y = x + 1
} else {
y = x
}
return y
}
// Differentiates a closure that calls `if_else`, so the control-flow note in
// `if_else` should fire.
// The closure forwards its parameter `x`. The original passed the constant 0,
// which made the closure's result independent of the differentiation input.
// Depending on the compiler's activity analysis, the call into `if_else` might
// then never be differentiated, and the intended diagnostic would not appear.
// xpected-error @+1 {{function is not differentiable}}
_ = gradient(at: 0) { x in if_else(x, true) }
#endif