Commit 64f5119

MaxSagebaum and hsutter authored
Documentation for autodiff. (#1424)
* Documentation for autodiff.
* Fix remarks and build errors.
* Trying to fix msvc.
* Update results for regression tests.
* Make AD warning be not an error so regression tests run

---------

Co-authored-by: Herb Sutter <herb.sutter@gmail.com>
1 parent a877eac commit 64f5119

29 files changed: +3465 -931 lines changed

docs/cpp2/metafunctions.md

Lines changed: 30 additions & 0 deletions
@@ -367,6 +367,36 @@ main: () = {

### For computational and functional types

#### `autodiff`

An `autodiff` type is extended so that derivatives can be computed: for each function and member function, the metafunction adds a differentiated version. **This is a proof-of-concept implementation. Expect it to break.**

A simple "hello diff" example:
```
ad: @autodiff type = {
    func: (x: double) -> (r: double) = {
        r = x * x;
    }
}

main: (args) = {
    x := 3.0;
    // Seed for the forward mode: the derivative of the input x.
    x_d := 1.0;

    // func_d is the differentiated version of func added by @autodiff.
    r := ad::func_d(x, x_d);

    // r.r_d holds the derivative of the return value r.
    std::cout << "Derivative of 'x*x' at (x)$ is (r.r_d)$" << std::endl;
}
```

The `@autodiff` metafunction mostly supports the forward mode of algorithmic differentiation. The reverse mode is only partly implemented and not yet well tested.
See [Supported autodiff features](../notes/autodiff_status.md) for a list of supported language features.
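
To illustrate what forward mode computes, here is a minimal Python sketch based on dual numbers. It is not the code that `@autodiff` generates; the `Dual` class and its field names are assumptions made purely for illustration. It mirrors the hello diff example: seeding the input derivative with 1.0 carries the derivative of `x * x` along with the value.

```python
from dataclasses import dataclass


@dataclass
class Dual:
    """A value together with its derivative (hypothetical helper type)."""
    value: float
    deriv: float

    def __mul__(self, other: "Dual") -> "Dual":
        # Product rule: (u * v)' = u' * v + u * v'
        return Dual(
            self.value * other.value,
            self.deriv * other.value + self.value * other.deriv,
        )


def func(x: Dual) -> Dual:
    # Same computation as the Cpp2 example: r = x * x
    return x * x


# Seed the input derivative with 1.0, like x_d in the example.
r = func(Dual(3.0, 1.0))
print(f"Derivative of 'x*x' at 3.0 is {r.deriv}")
```

The derivative is obtained in the same pass as the value, which is why forward mode needs one evaluation per input direction.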

Options can be given as text template arguments, e.g. `@autodiff<"reverse">` enables the reverse mode.

| Option | Description |
| --- | --- |
| `"reverse"` | Reverse mode algorithmic differentiation. Default suffix `_b`. |
| `"order=<n>"` | Higher order derivatives. `<n>` can be arbitrary. See `regression-tests/pure2-autodiff-higher-order.cpp2` for examples. |
| `"suffix=<s>"` | Change the forward mode suffix. Can be used to apply autodiff multiple times, e.g. `@autodiff @autodiff<"suffix=_d2">`. |
| `"rws_suffix=<s>"` | Change the reverse mode suffix. |
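
As a contrast to forward mode, the following Python sketch shows the idea behind reverse mode on a hand-written example. The function `func_b` and its signature are hypothetical: the name only echoes the default `_b` suffix and is not the interface that cppfront generates. The sketch evaluates the function first and then propagates an output adjoint `r_b` back to all inputs in a single sweep.

```python
def func(x: float, y: float) -> float:
    # Primal function: r = x + x * y
    return x + x * y


def func_b(x: float, y: float, r_b: float):
    """Hand-written adjoint of func (illustrative, not generated code).

    Returns the primal result and the adjoints of both inputs.
    """
    # Forward sweep: compute the primal value.
    r = x + x * y
    # Reverse sweep: apply the chain rule from the output to the inputs.
    x_b = (1.0 + y) * r_b  # dr/dx = 1 + y
    y_b = x * r_b          # dr/dy = x
    return r, x_b, y_b


# Seeding the output adjoint with 1.0 yields the full gradient at once.
r, x_b, y_b = func_b(2.0, 3.0, 1.0)
print(r, x_b, y_b)
```

One reverse sweep yields the derivatives with respect to every input, whereas forward mode would need one pass per input.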

#### `regex`

docs/notes/autodiff_status.md

Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
# Supported algorithmic differentiation (autodiff) features

The listings might be incomplete. If a feature is not listed, assume it is not supported. Algorithmic differentiation is applied via the [`autodiff` metafunction](../cpp2/metafunctions.md#autodiff). The planned features may be added in 2026, but do not wait for them: the autodiff feature is a proof-of-concept implementation.

**Reverse mode algorithmic differentiation is very experimental. Expect it to break.**

## Currently supported or planned features

| Description | Status forward | Status reverse |
| --- | --- | --- |
| Type definitions (structures) | <span style="color:green">Supported</span> | <span style="color:green">Supported</span> |
| Member values | <span style="color:green">Supported</span> | <span style="color:gray">Planned</span> |
| Member functions | <span style="color:green">Supported</span> | <span style="color:green">Supported</span> |
| Function arguments | <span style="color:green">Supported</span> | <span style="color:green">Supported</span> |
| Function return arguments | <span style="color:green">Supported</span> | <span style="color:green">Supported</span> |
| Addition and multiplication | <span style="color:green">Supported</span> | <span style="color:green">Supported</span> |
| Prefix addition and subtraction | <span style="color:green">Supported</span> | <span style="color:gray">Planned</span> |
| Static member function calls | <span style="color:green">Supported</span> | <span style="color:green">Supported</span> |
| Member function calls | <span style="color:green">Supported</span> | <span style="color:gray">Planned</span> |
| Function calls | <span style="color:green">Supported</span> | <span style="color:green">Supported</span> |
| Math functions (sin, cos, exp, sqrt) | <span style="color:green">Supported</span> | <span style="color:green">Supported</span> |
| If else | <span style="color:green">Supported</span> | <span style="color:gray">Planned</span> |
| Return statement | <span style="color:green">Supported</span> | <span style="color:gray">Planned</span> |
| Intermediate variables | <span style="color:green">Supported</span> | <span style="color:green">Supported</span> |
| Passive variables | <span style="color:green">Supported</span> | <span style="color:green">Supported</span> |
| While loop | <span style="color:green">Supported</span> | <span style="color:gray">Planned</span> |
| Do while loop | <span style="color:green">Supported</span> | <span style="color:gray">Planned</span> |
| For loop | <span style="color:green">Supported</span> | <span style="color:green">Supported</span> |
| Template arguments | <span style="color:gray">Planned</span> | <span style="color:gray">Planned</span> |
| Lambda functions | <span style="color:gray">Planned</span> | <span style="color:gray">Planned</span> |

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
x + x = 4.000000
x + x diff order 1 = 2.000000
x + x diff order 2 = 0.000000
x + x diff order 3 = 0.000000
x + x diff order 4 = 0.000000
x + x diff order 5 = 0.000000
x + x diff order 6 = 0.000000
0 - x = -2.000000
0 - x diff order 1 = -1.000000
0 - x diff order 2 = 0.000000
0 - x diff order 3 = 0.000000
0 - x diff order 4 = 0.000000
0 - x diff order 5 = 0.000000
0 - x diff order 6 = 0.000000
x^7 = 128.000000
x^7 diff order 1 = 448.000000
x^7 diff order 2 = 1344.000000
x^7 diff order 3 = 3360.000000
x^7 diff order 4 = 6720.000000
x^7 diff order 5 = 10080.000000
x^7 diff order 6 = 10080.000000
1/x = 0.500000
1/x diff order 1 = -0.250000
1/x diff order 2 = 0.250000
1/x diff order 3 = -0.375000
1/x diff order 4 = 0.750000
1/x diff order 5 = -1.875000
1/x diff order 6 = 5.625000
sqrt(x) = 1.414214
sqrt(x) diff order 1 = 0.353553
sqrt(x) diff order 2 = -0.088388
sqrt(x) diff order 3 = 0.066291
sqrt(x) diff order 4 = -0.082864
sqrt(x) diff order 5 = 0.145012
sqrt(x) diff order 6 = -0.326277
log(x) = 0.693147
log(x) diff order 1 = 0.500000
log(x) diff order 2 = -0.250000
log(x) diff order 3 = 0.250000
log(x) diff order 4 = -0.375000
log(x) diff order 5 = 0.750000
log(x) diff order 6 = -1.875000
exp(x) = 7.389056
exp(x) diff order 1 = 7.389056
exp(x) diff order 2 = 7.389056
exp(x) diff order 3 = 7.389056
exp(x) diff order 4 = 7.389056
exp(x) diff order 5 = 7.389056
exp(x) diff order 6 = 7.389056
sin(x) = 0.909297
sin(x) diff order 1 = -0.416147
sin(x) diff order 2 = -0.909297
sin(x) diff order 3 = 0.416147
sin(x) diff order 4 = 0.909297
sin(x) diff order 5 = -0.416147
sin(x) diff order 6 = -0.909297
cos(x) = -0.416147
cos(x) diff order 1 = -0.909297
cos(x) diff order 2 = 0.416147
cos(x) diff order 3 = 0.909297
cos(x) diff order 4 = -0.416147
cos(x) diff order 5 = -0.909297
cos(x) diff order 6 = 0.416147
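
The higher-order values listed above can be cross-checked independently of cppfront. The Python sketch below is an assumption-laden illustration, not the metafunction's implementation: it propagates truncated Taylor coefficients of `x(t) = x0 + t` through repeated multiplication and converts them to derivatives of `x^7` at `x = 2`. The helper names `taylor_mul` and `derivatives_of_power` are made up for this example.

```python
from math import factorial


def taylor_mul(a, b):
    """Multiply two truncated Taylor coefficient lists of equal length."""
    n = len(a)
    # Cauchy product, truncated at the highest order that is kept.
    return [sum(a[j] * b[k - j] for j in range(k + 1)) for k in range(n)]


def derivatives_of_power(x0, exponent, order):
    """Return [f, f', ..., f^(order)] for f(x) = x**exponent at x0."""
    # Taylor coefficients of x(t) = x0 + t up to t**order.
    x = [float(x0), 1.0] + [0.0] * (order - 1)
    result = [1.0] + [0.0] * order
    for _ in range(exponent):
        result = taylor_mul(result, x)
    # The k-th derivative is k! times the k-th Taylor coefficient.
    return [factorial(k) * c for k, c in enumerate(result)]


derivs = derivatives_of_power(2.0, 7, 6)
for k, d in enumerate(derivs):
    label = "x^7" if k == 0 else f"x^7 diff order {k}"
    print(f"{label} = {d:.6f}")
```

The printed values agree with the `x^7` lines of the regression output.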
Lines changed: 240 additions & 0 deletions
@@ -0,0 +1,240 @@
diff(x + y) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 5.000000
d1 = 3.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(x + y + x) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 7.000000
d1 = 4.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(x - y) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = -1.000000
d1 = -1.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(x - y - x) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = -3.000000
d1 = -2.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(x + y - x) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 3.000000
d1 = 2.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(x * y) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 6.000000
d1 = 7.000000
d2 = 4.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(x * y * x) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 12.000000
d1 = 20.000000
d2 = 22.000000
d3 = 12.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(x / y) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 0.666667
d1 = -0.111111
d2 = 0.148148
d3 = -0.296296
d4 = 0.790123
d5 = -2.633745
d6 = 10.534979
diff(x / y / y) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 0.222222
d1 = -0.185185
d2 = 0.296296
d3 = -0.691358
d4 = 2.106996
d5 = -7.901235
d6 = 35.116598
diff(x * y / x) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 3.000000
d1 = 2.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(x * (x + y)) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 10.000000
d1 = 11.000000
d2 = 6.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(x + x * y) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 8.000000
d1 = 8.000000
d2 = 4.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(+x + y)) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 5.000000
d1 = 3.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(-x + y)) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 1.000000
d1 = 1.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(x * func(x, y)) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 10.000000
d1 = 11.000000
d2 = 6.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(x * func_outer(x, y)) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 10.000000
d1 = 11.000000
d2 = 6.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(sin(x - y)) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = -0.841471
d1 = -0.540302
d2 = 0.841471
d3 = 0.540302
d4 = -0.841471
d5 = -0.540302
d6 = 0.841471
diff(if branch) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 2.000000
d1 = 1.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(if else branch) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 2.000000
d1 = 1.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(direct return) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 5.000000
d1 = 3.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(intermediate var) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 5.000000
d1 = 3.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(intermediate passive var) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 5.000000
d1 = 3.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(intermediate untyped) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 5.000000
d1 = 3.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(intermediate default init) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 5.000000
d1 = 3.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(intermediate no init) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 5.000000
d1 = 3.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(while loop) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 8.000000
d1 = 5.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(do while loop) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 8.000000
d1 = 5.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(for loop) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 5.000000
d1 = 3.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(tye_outer.a + y) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 5.000000
d1 = 3.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
diff(type_outer.add(y)) at (x = 2.000000, x_d = ( 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 ), y = 3.000000, y_d = ( 2.000000 0.000000 0.000000 0.000000 0.000000 0.000000 )):
r = 5.000000
d1 = 3.000000
d2 = 0.000000
d3 = 0.000000
d4 = 0.000000
d5 = 0.000000
d6 = 0.000000
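
One reading that is consistent with the numbers above is that `x_d` and `y_d` hold the first- to sixth-order seeds of the inputs, so that `x(t) = 2 + t` and `y(t) = 3 + 2t`, and that `d1` to `d6` are the derivatives of the result with respect to `t`. This interpretation is an assumption, not something the commit states. Under it, the following Python sketch reproduces the `diff(x * y)` and `diff(x * y * x)` entries using truncated Taylor arithmetic; all helper names are hypothetical.

```python
from math import factorial

ORDER = 6  # highest derivative order kept, matching d1..d6


def taylor_mul(a, b):
    """Cauchy product of two truncated Taylor coefficient lists."""
    return [sum(a[j] * b[k - j] for j in range(k + 1)) for k in range(len(a))]


def to_derivatives(coeffs):
    """Convert Taylor coefficients to derivatives: d_k = k! * c_k."""
    return [factorial(k) * c for k, c in enumerate(coeffs)]


# Assumed seeds: x = 2 with first-order seed 1, y = 3 with first-order seed 2.
x = [2.0, 1.0] + [0.0] * (ORDER - 1)
y = [3.0, 2.0] + [0.0] * (ORDER - 1)

xy = to_derivatives(taylor_mul(x, y))
xyx = to_derivatives(taylor_mul(taylor_mul(x, y), x))

print("x * y    :", xy)
print("x * y * x:", xyx)
```

The first entry of each list corresponds to `r` and the remaining entries to `d1` through `d6`.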
