-
Notifications
You must be signed in to change notification settings - Fork 910
/
transforms.h
224 lines (196 loc) · 7.96 KB
/
transforms.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <functional>
#include <optional>
#include <utility>
#include <vector>

#include "mlx/array.h"
namespace mlx::core {
// Asynchronous counterpart of `eval`; takes the arrays to evaluate by value.
// NOTE(review): presumably schedules evaluation of `outputs` without waiting
// for completion (inferred from the name only) — confirm against the
// definition.
void async_eval(std::vector<array> outputs);
// Evaluates the given arrays; `outputs` is taken by value. The variadic
// `eval` overload packs its arguments into a vector and forwards here.
// NOTE(review): whether this call blocks until the results are ready is not
// visible in this header — confirm against the definition.
void eval(std::vector<array> outputs);
template <typename... Arrays, typename = enable_for_arrays_t<Arrays...>>
void eval(Arrays&&... outputs) {
eval(std::vector<array>{std::forward<Arrays>(outputs)...});
}
/**
 * Computes the output and vector-Jacobian product (VJP) of a function.
 *
 * Computes the vector-Jacobian product of the vector of cotangents with the
 * Jacobian of the function evaluated at the primals.
 *
 * @param fun Function mapping a vector of arrays to a vector of arrays.
 * @param primals Inputs at which `fun` and its Jacobian are evaluated.
 * @param cotangents Cotangent arrays that multiply the Jacobian.
 * @return A pair whose first element is the vector of output arrays and
 *         whose second element is the vector of VJP arrays.
 **/
std::pair<std::vector<array>, std::vector<array>> vjp(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& primals,
const std::vector<array>& cotangents);
/**
 * Computes the output and vector-Jacobian product (VJP) of a unary function.
 *
 * Single-array counterpart of the vector overload: `fun` maps one array to
 * one array, and the result is the pair (output, VJP).
 *
 * @param fun Unary function to differentiate.
 * @param primal Input at which `fun` is evaluated.
 * @param cotangent Cotangent array paired with the output of `fun`.
 */
std::pair<array, array> vjp(
const std::function<array(const array&)>& fun,
const array& primal,
const array& cotangent);
/**
 * Computes the output and Jacobian-vector product (JVP) of a function.
 *
 * Computes the Jacobian-vector product of the Jacobian of the function
 * evaluated at the primals with the vector of tangents.
 *
 * @param fun Function mapping a vector of arrays to a vector of arrays.
 * @param primals Inputs at which `fun` and its Jacobian are evaluated.
 * @param tangents Tangent arrays multiplied by the Jacobian.
 * @return A pair whose first element is the vector of output arrays and
 *         whose second element is the vector of JVP arrays.
 **/
std::pair<std::vector<array>, std::vector<array>> jvp(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<array>& primals,
const std::vector<array>& tangents);
/**
 * Computes the output and Jacobian-vector product (JVP) of a unary function.
 *
 * Single-array counterpart of the vector overload: `fun` maps one array to
 * one array, and the result is the pair (output, JVP).
 *
 * @param fun Unary function to differentiate.
 * @param primal Input at which `fun` is evaluated.
 * @param tangent Tangent array paired with `primal`.
 */
std::pair<array, array> jvp(
const std::function<array(const array&)>& fun,
const array& primal,
const array& tangent);
// Return type of the general value_and_grad: a function which takes an input
// vector of arrays and returns a pair of vectors of arrays, one for the
// values and one for the gradients with respect to the first value.
using ValueAndGradFn =
std::function<std::pair<std::vector<array>, std::vector<array>>(
const std::vector<array>&)>;
// Return type of value_and_grad for functions that return a single array: a
// function which takes an input vector of arrays and returns a pair made of
// the value array and a vector of gradient arrays.
using SimpleValueAndGradFn = std::function<std::pair<array, std::vector<array>>(
const std::vector<array>&)>;
/**
 * Returns a function which computes the value and gradient of the input
 * function with respect to a vector of input arrays.
 *
 * @param fun Function mapping a vector of arrays to a vector of arrays.
 * @param argnums Indices of the inputs of `fun` to differentiate with
 *        respect to.
 * @return A `ValueAndGradFn` yielding the pair (values, gradients).
 **/
ValueAndGradFn value_and_grad(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<int>& argnums);
/**
 * Builds a value-and-gradient function for `fun` that differentiates with
 * respect to one input array, selected by `argnum` (defaults to 0).
 *
 * Delegates to the `argnums` overload with a single-element index list.
 **/
ValueAndGradFn inline value_and_grad(
    const std::function<std::vector<array>(const std::vector<array>&)>& fun,
    int argnum = 0) {
  const std::vector<int> single_argnum{argnum};
  return value_and_grad(fun, single_argnum);
}
/**
 * Returns a function which computes the value and gradient of the unary
 * input function.
 *
 * The returned callable evaluates `vjp(fun, input, array(1.0f))`, i.e. it
 * uses a cotangent of `array(1.0f)`, and yields the pair (value, gradient).
 * A copy of `fun` is stored inside the returned callable.
 **/
std::function<std::pair<array, array>(const array&)> inline value_and_grad(
    const std::function<array(const array&)>& fun) {
  // Bind the argument by const reference so that invoking the returned
  // callable does not make an extra copy of the input array.
  return [fun](const array& input) { return vjp(fun, input, array(1.0f)); };
}
/**
 * Returns a function which computes the value and gradient of a function
 * that maps a vector of arrays to a single array.
 *
 * The single-output `fun` is adapted to the vector-returning signature and
 * passed to the general `value_and_grad`; the first (only) value is then
 * unwrapped so the result is the pair (value, gradients).
 *
 * @param fun Function mapping a vector of arrays to one array.
 * @param argnums Indices of the inputs to differentiate with respect to.
 **/
SimpleValueAndGradFn inline value_and_grad(
    const std::function<array(const std::vector<array>&)>& fun,
    const std::vector<int>& argnums) {
  // Both lambdas take their argument by const reference so the input vector
  // is not copied on each invocation.
  return [fun, argnums](const std::vector<array>& inputs) {
    auto vector_fun = [fun](const std::vector<array>& args) {
      return std::vector<array>{fun(args)};
    };
    // NOTE(review): the transformed function is rebuilt on every call.
    // Hoisting it out of the lambda would change when any error raised by
    // the general overload surfaces — confirm before changing.
    auto result = value_and_grad(vector_fun, argnums)(inputs);
    // `result` is a local, so the gradient vector can be moved out of it.
    return std::make_pair(result.first[0], std::move(result.second));
  };
}
/**
 * Builds a value-and-gradient function for a single-array-valued `fun`,
 * differentiating with respect to the one input selected by `argnum`
 * (defaults to 0).
 *
 * Delegates to the `argnums` overload with a single-element index list.
 **/
SimpleValueAndGradFn inline value_and_grad(
    const std::function<array(const std::vector<array>&)>& fun,
    int argnum = 0) {
  const std::vector<int> single_argnum{argnum};
  return value_and_grad(fun, single_argnum);
}
/**
 * Returns a function which computes the gradient of the input function with
 * respect to a vector of input arrays.
 *
 * The function being differentiated takes a vector of arrays and returns an
 * array. The vector of `argnums` specifies which arguments to compute the
 * gradient with respect to. At least one argument must be specified.
 *
 * The result is obtained by calling `value_and_grad(fun, argnums)` and
 * keeping only the gradient half of the pair it produces.
 **/
std::function<std::vector<array>(const std::vector<array>&)> inline grad(
    const std::function<array(const std::vector<array>&)>& fun,
    const std::vector<int>& argnums) {
  return [value_and_grad_fn = value_and_grad(fun, argnums)](
             const std::vector<array>& inputs) {
    return value_and_grad_fn(inputs).second;
  };
}
/**
 * Returns a function which computes the gradient of the input function with
 * respect to a single input array.
 *
 * The function being differentiated takes a vector of arrays and returns an
 * array. The optional `argnum` index specifies which argument to compute the
 * gradient with respect to and defaults to 0.
 *
 * Delegates to the `argnums` overload with a single-element index list.
 **/
std::function<std::vector<array>(const std::vector<array>&)> inline grad(
    const std::function<array(const std::vector<array>&)>& fun,
    int argnum = 0) {
  const std::vector<int> single_argnum{argnum};
  return grad(fun, single_argnum);
}
/**
 * Returns a function which computes the gradient of the unary input function.
 *
 * The result is obtained by calling the unary `value_and_grad(fun)` and
 * keeping only the gradient half of the pair it produces.
 **/
std::function<array(const array&)> inline grad(
    const std::function<array(const array&)>& fun) {
  return [value_and_grad_fn = value_and_grad(fun)](const array& input) {
    return value_and_grad_fn(input).second;
  };
}
/**
 * Automatically vectorize a unary function over the requested axes.
 *
 * @param fun Unary function to vectorize.
 * @param in_axis Axis of the input to vectorize over (defaults to 0).
 * @param out_axis Axis of the output to vectorize over (defaults to 0).
 * @return A function with the same signature as `fun`.
 */
std::function<array(const array&)> vmap(
const std::function<array(const array&)>& fun,
int in_axis = 0,
int out_axis = 0);
/**
 * Automatically vectorize a binary function over the requested axes.
 *
 * @param fun Binary function to vectorize.
 * @param in_axis_a Axis of the first input to vectorize over (defaults to 0).
 * @param in_axis_b Axis of the second input to vectorize over (defaults to 0).
 * @param out_axis Axis of the output to vectorize over (defaults to 0).
 * @return A function with the same signature as `fun`.
 */
std::function<array(const array&, const array&)> vmap(
const std::function<array(const array&, const array&)>& fun,
int in_axis_a = 0,
int in_axis_b = 0,
int out_axis = 0);
/**
 * Automatically vectorize a function over the requested axes.
 *
 * The input function to `vmap` takes as an argument a vector of arrays and
 * returns a vector of arrays. Optionally specify the axes to vectorize over
 * with `in_axes` and `out_axes`, otherwise a default of 0 is used.
 * Returns a vectorized function with the same signature as the input
 * function.
 *
 * @param fun Function to vectorize.
 * @param in_axes Axes of the inputs to vectorize over; the default argument
 *        is an empty vector.
 * @param out_axes Axes of the outputs to vectorize over; the default
 *        argument is an empty vector.
 */
std::function<std::vector<array>(const std::vector<array>&)> vmap(
const std::function<std::vector<array>(const std::vector<array>&)>& fun,
const std::vector<int>& in_axes = {},
const std::vector<int>& out_axes = {});
/**
 * Redefine the transformations of `fun` according to the provided functions.
 *
 * Namely when calling the vjp of `fun` then `fun_vjp` will be called,
 * `fun_jvp` for the jvp and `fun_vmap` for vmap.
 *
 * If any transformation is not provided, then a default one is created by
 * calling `vjp`, `jvp` and `vmap` on the function directly. Each override
 * is optional and defaults to `std::nullopt`.
 *
 * NOTE(review): the roles of the override arguments are not visible in this
 * header. From the types alone: `fun_vjp` receives three vectors of arrays
 * and returns a vector of arrays; `fun_jvp` receives two vectors of arrays
 * plus a vector of ints and returns a vector of arrays; `fun_vmap` receives
 * a vector of arrays plus a vector of ints and returns a pair of (vector of
 * arrays, vector of ints). Confirm the intended meaning of each argument
 * against the definition.
 */
std::function<std::vector<array>(const std::vector<array>&)> custom_function(
std::function<std::vector<array>(const std::vector<array>&)> fun,
std::optional<std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<array>&)>> fun_vjp = std::nullopt,
std::optional<std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<int>&)>> fun_jvp = std::nullopt,
std::optional<std::function<std::pair<std::vector<array>, std::vector<int>>(
const std::vector<array>&,
const std::vector<int>&)>> fun_vmap = std::nullopt);
/**
 * Return a function that behaves exactly like `fun` but if the vjp of the
 * results is computed `fun_vjp` will be used instead of `vjp(fun, ...)`.
 *
 * @param fun Function whose vjp is being replaced.
 * @param fun_vjp Replacement vjp; it receives three vectors of arrays and
 *        returns a vector of arrays. NOTE(review): the role of each of the
 *        three arguments is not visible in this header — confirm against
 *        the definition.
 * @return A function with the same signature as `fun`.
 */
std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
std::function<std::vector<array>(const std::vector<array>&)> fun,
std::function<std::vector<array>(
const std::vector<array>&,
const std::vector<array>&,
const std::vector<array>&)> fun_vjp);
/**
 * Checkpoint the gradient of a function. Namely, discard all intermediate
 * state and recalculate it when we need to compute the gradient.
 *
 * @param fun Function to checkpoint, taken by value.
 * @return A function with the same signature as `fun`.
 */
std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
std::function<std::vector<array>(const std::vector<array>&)> fun);
} // namespace mlx::core