-
Notifications
You must be signed in to change notification settings - Fork 2.8k
/
gradient_builder_base.h
432 lines (340 loc) · 15 KB
/
gradient_builder_base.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <vector>
#include <string>
#include "core/framework/float8.h"
#include "core/framework/float16.h"
#include "core/graph/graph.h"
#include "core/util/math.h"
#include "orttraining/core/graph/graph_augmenter.h"
#include "orttraining/core/graph/gradient_config.h"
#include "orttraining/core/graph/recompute_graph_utils.h"
#include "orttraining/core/graph/gradient_definition_registry.h"
#include "onnx/defs/attr_proto_util.h"
#include "onnx/defs/tensor_proto_util.h"
namespace onnxruntime {
namespace training {
using Dimension = onnx::TensorShapeProto_Dimension;

// The NodeDefs that make up the backward subgraph emitted for one forward node.
using GradientDef = std::vector<NodeDef>;

// Broadcast helper for static shapes: inspects the dims of the two operands A and B and fills
// A_axes / B_axes with per-operand axes. Presumably these are the axes each operand's gradient must
// be reduced over to undo broadcasting — confirm against the implementation.
// `node_name` is optional (defaults to empty); presumably used for diagnostics only.
void ComputeBroadcastBackwardAxes(
    const std::vector<Dimension>& A_dims,
    const std::vector<Dimension>& B_dims,
    std::vector<int64_t>* A_axes,
    std::vector<int64_t>* B_axes,
    const std::string& node_name = "");

// Shape-agnostic counterpart of ComputeBroadcastBackwardAxes: instead of computing axes directly, it
// appends NodeDefs to `output` that compute them at runtime from `a_shape` / `b_shape`.
// NOTE(review): a_axes / b_axes are pointers, presumably so a caller can skip one side — confirm.
void ComputeBroadcastBackwardAxesDynamic(const ArgDef& a,
                                         const ArgDef& b,
                                         const ArgDef& a_shape,
                                         const ArgDef& b_shape,
                                         const ArgDef* a_axes,
                                         const ArgDef* b_axes,
                                         std::vector<NodeDef>& output);

// Writes the shape of `arg_def` into `shape`; failure is reported through the returned Status.
Status GetShape(const ArgDef& arg_def, std::vector<Dimension>& shape);

// Key used to look up a node in the gradient definition registry.
std::string GetGradientDefinitionKeyByNode(const Node& node);
class GradientBuilderBase {
public:
GradientBuilderBase(const GradientGraphConfiguration& gradient_graph_config,
Graph* graph,
const Node* node,
const std::unordered_set<std::string>& gradient_inputs,
const std::unordered_set<std::string>& gradient_outputs,
const logging::Logger& logger,
std::unordered_set<std::string>& stashed_tensors,
std::unordered_map<std::string, std::vector<int64_t>>& python_op_input_requires_grads)
: gradient_graph_config_(gradient_graph_config),
graph_(graph),
node_(node),
gradient_inputs_(gradient_inputs),
gradient_outputs_(gradient_outputs),
logger_(logger),
stashed_tensors_(stashed_tensors),
python_op_input_require_grad_info_(python_op_input_requires_grads) {
unique_node_prefix_ = CreateUniqueNodePrefix();
}
virtual ~GradientBuilderBase() {}
GradientDef GetGradientDefs() const {
GradientDef node_defs = GetGradientDefsImpl();
for (size_t i = 0; i < node_defs.size(); ++i) {
NodeDef& node_def = node_defs[i];
if (node_def.name.empty()) {
node_def.name = Name(node_def.op_type + "_" + std::to_string(i));
}
}
return node_defs;
}
static std::string GradientName(const std::string& name) {
return name + "_grad";
}
static std::string ExternalOutputName(const std::string& name) {
return name + "_external";
}
protected:
virtual GradientDef GetGradientDefsImpl() const = 0;
const GradientGraphConfiguration& GetGradientGraphConfiguration() const {
return gradient_graph_config_;
}
void RecordStashedTensor(const std::string& name) const {
stashed_tensors_.insert(name);
}
bool IsTensorStashed(const std::string& name) const {
return stashed_tensors_.find(name) != stashed_tensors_.end();
}
// i-th input of forward op
ArgDef I(const size_t i, bool record_stashing = true) const {
ORT_ENFORCE(i < node_->InputDefs().size());
const std::string& name = node_->InputDefs()[i]->Name();
const NodeArg* recomputed_nodearg = graph_->GetNodeArg(graph_utils::RecomputeName(name));
if (recomputed_nodearg) {
const Node* producer_node = graph_->GetProducerNode(name);
LOGS(logger_, INFO) << "Recomputed node arg found for " << producer_node->Name();
return ArgDef(recomputed_nodearg->Name(), recomputed_nodearg->TypeAsProto());
}
if (record_stashing) {
RecordStashedTensor(node_->InputDefs()[i]->Name());
}
return ArgDef(node_->InputDefs()[i]->Name(), node_->InputDefs()[i]->TypeAsProto());
}
// i-th output of forward op
ArgDef O(const size_t i, bool record_stashing = true) const {
ORT_ENFORCE(i < node_->OutputDefs().size());
const std::string& name = node_->OutputDefs()[i]->Name();
const NodeArg* recomputed_nodearg = graph_->GetNodeArg(graph_utils::RecomputeName(name));
if (recomputed_nodearg) {
const Node* producer_node = graph_->GetProducerNode(name);
LOGS(logger_, INFO) << "Recomputed node arg found for " << producer_node->Name();
return ArgDef(recomputed_nodearg->Name(), recomputed_nodearg->TypeAsProto());
}
if (record_stashing) {
RecordStashedTensor(node_->OutputDefs()[i]->Name());
}
return ArgDef(node_->OutputDefs()[i]->Name(), node_->OutputDefs()[i]->TypeAsProto());
}
// gradient of i-th input of forward op
ArgDef GI(const size_t i) const {
ORT_ENFORCE(i < node_->InputDefs().size());
return ArgDef(GradientName(node_->InputDefs()[i]->Name()), node_->InputDefs()[i]->TypeAsProto());
}
// gradient of i-th input of forward op - useful when gradient type does not match input type
ArgDef GI(const size_t i, const TypeProto* type) const {
ORT_ENFORCE(i < node_->InputDefs().size());
return ArgDef(GradientName(node_->InputDefs()[i]->Name()), type);
}
// gradient of i-th output of forward op
ArgDef GO(const size_t i) const {
ORT_ENFORCE(i < node_->OutputDefs().size());
return ArgDef(GradientName(node_->OutputDefs()[i]->Name()), node_->OutputDefs()[i]->TypeAsProto());
}
// intermediate argument
ArgDef IA(const std::string& argSuffix, const TypeProto* type_proto = nullptr) const {
return ArgDef(Name(argSuffix), type_proto);
}
// type of i-th input of forward op
const TypeProto* IType(const size_t i) const {
ORT_ENFORCE(i < node_->InputDefs().size());
return node_->InputDefs()[i]->TypeAsProto();
}
// type of i-th output of forward op
const TypeProto* OType(const size_t i) const {
ORT_ENFORCE(i < node_->OutputDefs().size());
return node_->OutputDefs()[i]->TypeAsProto();
}
// Element type of i-th input of forward op.
int IElemType(const size_t i) const {
return IType(i)->tensor_type().elem_type();
}
// Element type of i-th output of forward op.
int OElemType(const size_t i) const {
return OType(i)->tensor_type().elem_type();
}
int GetSrcNodeInputSize() const {
ORT_ENFORCE(node_ != nullptr);
return (int)node_->InputDefs().size();
}
int GetSrcNodeOutputSize() const {
ORT_ENFORCE(node_ != nullptr);
return (int)node_->OutputDefs().size();
}
// returns true if the input at index i of the node_ requires gradient
bool IsGradientRequiredForSrcNodeInput(const size_t i) const {
return i < node_->InputDefs().size() &&
gradient_outputs_.find(node_->InputDefs()[i]->Name()) != gradient_outputs_.end();
}
// returns true if the output at index i of the node_ has a gradient
bool IsGradientAvailableForSrcNodeOutput(const size_t i) const {
return i < node_->OutputDefs().size() &&
gradient_inputs_.find(node_->OutputDefs()[i]->Name()) != gradient_inputs_.end();
}
std::string Name(const std::string& name) const {
return unique_node_prefix_ + name;
}
const NodeAttributes& SrcNodeAttributes() const {
return node_->GetAttributes();
}
const std::string& SrcNodeOpType() const {
return node_->OpType();
}
int SrcNodeOpsetVersion() const {
return node_->Op()->since_version();
}
const std::string& SrcNodeDomain() const {
return node_->Op()->domain();
}
int OnnxOpSetVersion() const {
return graph_ != nullptr && graph_->DomainToVersionMap().find(kOnnxDomain) != graph_->DomainToVersionMap().end()
? graph_->DomainToVersionMap().at(kOnnxDomain)
: -1;
}
template <typename T>
static NodeDef ConstantVectorNode(const std::vector<T>& values, const std::string& arg_name) {
auto t_proto = ONNX_NAMESPACE::ToTensor<T>(values);
t_proto.add_dims(values.size());
return NodeDef("Constant",
{},
{ArgDef(arg_name, nullptr)},
{ONNX_NAMESPACE::MakeAttribute("value", t_proto)});
}
template <typename T>
static ONNX_NAMESPACE::TensorProto ScalarTensorProto(T value, std::vector<int64_t> shape) {
ORT_ENFORCE(shape.size() == 0 || (shape.size() == 1 && shape[0] == 1));
auto t_proto = ONNX_NAMESPACE::ToTensor<T>(value);
for (auto dim : shape) {
t_proto.add_dims(dim);
}
return t_proto;
}
template <typename T>
static NodeDef ConstantScalarNode(T value, std::vector<int64_t> shape, const std::string& arg_name) {
auto t_proto = ScalarTensorProto(value, shape);
return NodeDef("Constant",
{},
{ArgDef(arg_name, nullptr)},
{ONNX_NAMESPACE::MakeAttribute("value", t_proto)});
}
// We only support FP32, FP16 and BF16 for these constant nodes for now.
static NodeDef ConstantScalarNode(float value, const std::string& arg_name, int elem_type) {
if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
return ConstantScalarNode(MLFloat16(value), {1}, arg_name);
}
if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) {
return ConstantScalarNode(BFloat16(value), {1}, arg_name);
}
if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE) {
return ConstantScalarNode(double(value), {1}, arg_name);
}
#if !defined(DISABLE_FLOAT8_TYPES)
if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN) {
return ConstantScalarNode(Float8E4M3FN(value, true), {1}, arg_name);
}
if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ) {
return ConstantScalarNode(Float8E4M3FNUZ(value, true), {1}, arg_name);
}
if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2) {
return ConstantScalarNode(Float8E5M2(value, true), {1}, arg_name);
}
if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ) {
return ConstantScalarNode(Float8E5M2FNUZ(value, true), {1}, arg_name);
}
#endif
ORT_ENFORCE(elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT,
"Unsupported element type for constant node: ", elem_type);
return ConstantScalarNode(value, {1}, arg_name);
}
static ONNX_NAMESPACE::TensorProto ScalarTensorProtoByElemType(float value, int elem_type) {
if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
return ScalarTensorProto(MLFloat16(value), {1});
}
if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) {
return ScalarTensorProto(BFloat16(value), {1});
}
#if !defined(DISABLE_FLOAT8_TYPES)
if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN) {
return ScalarTensorProto(Float8E4M3FN(value, true), {1});
}
if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FNUZ) {
return ScalarTensorProto(Float8E4M3FNUZ(value, true), {1});
}
if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2) {
return ScalarTensorProto(Float8E5M2(value, true), {1});
}
if (elem_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2FNUZ) {
return ScalarTensorProto(Float8E5M2FNUZ(value, true), {1});
}
#endif
return ScalarTensorProto(value, {1});
}
static NodeDef ZeroConstantNode(int elem_type) {
return ConstantScalarNode(0.0f, "ZeroConstant_Type" + std::to_string(elem_type), elem_type);
}
static NodeDef HalfConstantNode(int elem_type) {
return ConstantScalarNode(0.5f, "HalfConstant_Type" + std::to_string(elem_type), elem_type);
}
static NodeDef OneConstantNode(int elem_type) {
return ConstantScalarNode(1.0f, "OneConstant_Type" + std::to_string(elem_type), elem_type);
}
void AddReduceSumNode(const ArgDef& input_arg_def,
const ArgDef& output_arg_def,
const std::vector<int64_t>& reduce_axes,
bool keep_dims,
std::vector<NodeDef>& output) const;
void HandleBroadcasting(const ArgDef& input_grad,
const ArgDef& target,
const ArgDef& output_grad,
const std::vector<int64_t>& reduce_axes,
std::vector<NodeDef>& output) const;
void HandleBroadcastingDynamic(const ArgDef& input_grad,
const ArgDef& target,
const ArgDef& target_shape,
const ArgDef& output_grad,
const ArgDef& reduce_axes,
std::vector<NodeDef>& output) const;
std::vector<NodeDef> GetBiasGeluGradNodes(
bool use_approximation,
const ArgDef& dY, const ArgDef& X, const ArgDef& B, // inputs
const ArgDef& dX, const ArgDef& dB, // outputs
const ArgDef& b_axes, const ArgDef& b_shape, const ArgDef& x_shape, // intermediate args
const std::string& node_name) const;
const std::string& NodeName() const { return node_->Name(); }
std::string GetGradientDefinitionKey() const { return GetGradientDefinitionKeyByNode(*node_); }
AttributeProto AttributeDefinitionToAttributeProto(const GradientNodeAttributeDefinition& attr_def) const;
void SetPythonOpRequireGradInfo(const std::string& node_name,
std::vector<int64_t> input_requires_grad_info) const;
private:
friend class GradientGraphBuilder;
std::string CreateUniqueNodePrefix() {
ORT_ENFORCE(node_ != nullptr);
auto name = node_->Name();
std::stringstream unique_prefix;
if (!name.empty()) {
unique_prefix << name << "_Grad/";
} else {
unique_prefix << graph_->GenerateNodeName(node_->OpType()) << "_Grad/";
}
return unique_prefix.str();
}
const GradientGraphConfiguration& gradient_graph_config_;
Graph* graph_;
const Node* node_;
std::string unique_node_prefix_;
// contains set of output arg names of node_ which is provided as gradient input to the bw node
std::unordered_set<std::string> gradient_inputs_;
// contains set of input arg names of node_ which requires gradient
std::unordered_set<std::string> gradient_outputs_;
const logging::Logger& logger_;
std::unordered_set<std::string>& stashed_tensors_;
std::unordered_map<std::string, std::vector<int64_t>>& python_op_input_require_grad_info_;
};
// Builder for operators whose backward pass needs no nodes: it always yields an empty GradientDef.
class EmptyGradientBuilder : public GradientBuilderBase {
  // Reuse the base-class constructor unchanged.
  using GradientBuilderBase::GradientBuilderBase;

  GradientDef GetGradientDefsImpl() const override {
    return {};  // value-initialized, i.e. an empty vector of NodeDefs
  }
};
// Builder for operators that must never be differentiated: asking it for gradient definitions
// throws through ORT_ENFORCE.
class UnSupportedGradientBuilder : public GradientBuilderBase {
// Reuse the base-class constructor unchanged.
using GradientBuilderBase::GradientBuilderBase;
// Always fails. ORT_ENFORCE(false, ...) throws, so no return statement is reached.
GradientDef GetGradientDefsImpl() const override {
ORT_ENFORCE(false, "Gradient should not be requested for this operator");
}
};
} // namespace training
} // namespace onnxruntime