/*
* Copyright 2022 Google LLC.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
syntax = "proto2";
package yggdrasil_decision_forests.model.proto;
import public "yggdrasil_decision_forests/model/hyperparameter.proto";
import "yggdrasil_decision_forests/dataset/weight.proto";
import "yggdrasil_decision_forests/model/abstract_model.proto";
import "yggdrasil_decision_forests/utils/distribute/distribute.proto";
// Specification of the computing resources used to perform an action (e.g.
// train a model, run a cross-validation, generate predictions). The deployment
// configuration does not impact the results (e.g. learned model).
//
// If not specified, most consumers will assume local computation with
// multiple threads.
message DeploymentConfig {
// Next ID: 10
// Path to temporary directory available to the training algorithm.
// Currently cache_path is only used (and required) by the
// distributed algorithms or if "try_resume_training=True" (for the
// snapshots).
//
// In case of distributed training, the "cache_path" should be accessible to
// both the manager and the workers (unless specified otherwise) -- so a
// machine-local path or an in-memory partition won't work.
optional string cache_path = 1;
// Number of threads.
optional int32 num_threads = 2 [default = 6];
// If true, try to resume an interrupted training using snapshots stored in
// the "cache_path". Not supported by all learning algorithms. Resuming
// training after changing the hyper-parameters might lead to failure when
// training is resumed.
optional bool try_resume_training = 6 [default = false];
// Indicative number of seconds in between snapshots when
// "try_resume_training=True". Might be ignored by some algorithms.
optional int64 resume_training_snapshot_interval_seconds = 7 [default = 1800];
// Number of threads to use for IO operations, e.g. reading a dataset from
// disk. Increasing this value can speed up IO operations when they are
// latency or CPU bound.
optional int32 num_io_threads = 8 [default = 10];
// Maximum number of snapshots to keep.
optional int32 max_kept_snapshots = 9 [default = 3];
// Computation distribution engine.
oneof execution {
// Local execution.
Local local = 3;
// Distribution using the Distribute interface.
// Note that the distribution strategy implementation selected in "distribute"
// needs to be linked into the binary if you are using the C++ API.
yggdrasil_decision_forests.distribute.proto.Config distribute = 5;
}
reserved 4;
message Local {}
}
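
// For illustration only (a sketch; the thread count and cache path below are
// hypothetical), a DeploymentConfig enabling resumable local training could be
// written in protobuf text format as:
//
//   num_threads: 16
//   try_resume_training: true
//   cache_path: "/tmp/ydf_cache"  # Required because try_resume_training=true.
//   local {}
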
// Training configuration.
// Contains all the configuration for the training of a model e.g. label, input
// features, hyper-parameters.
message TrainingConfig {
// Next ID: 16
// Identifier of the learner e.g. "RANDOM_FOREST".
// The learner should be registered, i.e. injected as a dependency of the
// binary. The list of available learners can be obtained with
// "AllRegisteredModels()" in "model_library.h".
optional string learner = 1;
// List of regular expressions over the dataset columns defining the input
// features of the model. If empty, all the columns (with the exception of the
// label and cv_group) will be added as input features.
repeated string features = 2;
// Label column.
optional string label = 3;
// Name of the column used to split the dataset for in-training
// cross-validation i.e. all the records with the same "cv_group" value are in
// the same cross-validation fold. If not specified, examples are randomly
// assigned to train and test. This field is ignored by learners that do not
// run in-training cross-validation.
optional string cv_group = 4;
// Task / problem solved by the model.
optional Task task = 5 [default = CLASSIFICATION];
// Weighting of the training examples. If not specified, the weight is
// assumed uniform.
optional dataset.proto.WeightDefinition weight_definition = 6;
// Random seed for the training of the model. Learners are expected to be
// deterministic given the random seed.
optional int64 random_seed = 7 [default = 123456];
// Column identifying the groups in a ranking task.
// For example, in a document/query ranking problem, the "ranking_group" will
// be the query.
//
// The ranking column can be either a HASH or a CATEGORICAL. HASH is
// recommended. If CATEGORICAL, ensure the dictionary is not pruned (i.e.
// minimum number of observations = 0 and maximum number of items = -1 =>
// infinity).
optional string ranking_group = 8;
// Maximum training duration expressed in seconds. If the learner does not
// support constraining the training time, training fails immediately. Each
// learning algorithm is free to use this parameter as it sees fit. Enabling a
// maximum training duration makes the model training non-deterministic.
optional double maximum_training_duration_seconds = 9;
reserved 10;
// Limits the in-memory size of the trained model. Different algorithms can
// enforce this limit differently. Serialized or compiled models are generally
// much smaller. This limit is fuzzy: the final model can be slightly larger.
optional int64 maximum_model_size_in_memory_in_bytes = 11;
// Categorical column identifying the treatment group in an uplift task.
// For example, whether a patient received a treatment in a study about the
// impact of a medication.
//
// Only binary treatments are currently supported.
optional string uplift_treatment = 12;
// Metadata of the model.
// Unspecified fields are automatically set. For example, if "metadata.date"
// is not set, it will be automatically set to the training date.
optional Metadata metadata = 13;
// Clear the model of any information that is not required for model serving.
// This includes debugging data, model interpretation data and other meta-data.
// The size of the serialized model can be reduced significantly (a 50% model
// size reduction is common). This parameter has no impact on the quality,
// serving speed or RAM usage of model serving.
optional bool pure_serving_model = 14 [default = false];
// Set of monotonic constraints between the model's input features and output.
repeated MonotonicConstraint monotonic_constraints = 15;
// Learner specific configuration/hyper-parameters.
// The message/extension is dependent on the "learner". For example, see
// "yggdrasil_decision_forests/learner/random_forest.proto" for the parameters
// of the "RANDOM_FOREST" learner.
//
// If not specified, all the learners are expected to have good default
// configuration/hyper-parameters.
//
// Common specialized hyper-parameters can be specified with a
// "GenericHyperParameters" proto. In this case, "GenericHyperParameters" will
// have higher priority than the extensions.
extensions 1000 to max;
}
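
// For illustration only (a sketch; the learner is the "RANDOM_FOREST" example
// from above and the column names are hypothetical), a minimal TrainingConfig
// could be written in protobuf text format as:
//
//   learner: "RANDOM_FOREST"
//   label: "income"
//   features: "age"
//   features: "hours_.*"  # Regular expression matching several columns.
//   task: CLASSIFICATION
//
// Leaving "features" empty would instead select every column except the label
// (and cv_group) as input feature.
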
// Resolution of column string names into column indices.
// The column indices are defined in a given dataspec, e.g. if
// dataspec.columns[5].name = "toto", then the column idx of "toto" is 5.
message TrainingConfigLinking {
// Next ID: 14
// Input features of the models.
repeated int32 features = 1 [packed = true];
// Features of type NUMERICAL.
repeated int32 numerical_features = 9 [packed = true];
// Label column.
optional int32 label = 2;
// Number of label categories (used for classification only).
optional int32 num_label_classes = 3;
// Index of the column matching "cv_group" in the "TrainingConfig".
optional int32 cv_group = 4;
optional dataset.proto.LinkedWeightDefinition weight_definition = 7;
// Index of the column matching "ranking_group" in the "TrainingConfig".
optional int32 ranking_group = 8 [default = -1];
// Index of the column matching "uplift_treatment" in the "TrainingConfig".
optional int32 uplift_treatment = 12 [default = -1];
// Data for specific dataset columns.
// This field is either empty, or contains exactly one value for each column
// in the dataset.
repeated PerColumn per_columns = 13;
}
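
// For illustration only (a sketch with hypothetical column names), given a
// dataspec where columns[0].name = "age", columns[1].name = "income" and
// columns[2].name = "hours", a TrainingConfig with label "income" and features
// "age" and "hours" could resolve to the following TrainingConfigLinking in
// protobuf text format:
//
//   features: 0
//   features: 2
//   label: 1
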
// A pre-defined set of hyper-parameters that outperforms the default
// hyper-parameters (either generally or in specific scenarios). Like the
// default hyper-parameters, existing pre-defined hyper-parameters cannot
// change.
message PredefinedHyperParameterTemplate {
// Name of the template. Should be unique for a given learning algorithm.
optional string name = 1;
// Version of the template.
optional int32 version = 2;
// Free text describing how this template was created.
optional string description = 3;
// Effective hyper-parameters.
optional GenericHyperParameters parameters = 4;
}
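
// For illustration only (a sketch; the template name and parameter are
// hypothetical, and the name/value layout of "GenericHyperParameters" is an
// assumption about the imported "hyperparameter.proto"), a template could be
// written in protobuf text format as:
//
//   name: "better_than_default"
//   version: 1
//   description: "Tuned on an internal benchmark."
//   parameters { fields { name: "num_trees" value { integer: 500 } } }
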
// "Capabilities" of a learner.
//
// Describes the capabilities/constraints/properties of a learner (all called
// "capabilities"). Capabilities are non-restrictive, i.e. enabling a
// capability cannot restrict the domain of use of a learner/model (e.g. use
// "support_tpu" instead of "require_tpu").
//
// Relying on a capability that a learner does not have raises an error.
message LearnerCapabilities {
// Does the learner support the "maximum_training_duration_seconds" parameter
// in the TrainingConfig.
optional bool support_max_training_duration = 1 [default = false];
// The learner can resume training of the model from the "cache_path" given in
// the deployment configuration.
optional bool resume_training = 2 [default = false];
// If true, the algorithm uses a validation dataset for training (e.g. for
// early stopping) and supports a validation dataset being passed to the
// training method (with the "valid_dataset" or "typed_valid_path" argument).
// If the learning algorithm has this capability and no validation dataset is
// given to the training function, the learning algorithm will extract a
// validation dataset from the training dataset.
optional bool support_validation_dataset = 3 [default = false];
// If true, the algorithm supports training datasets in the "partial cache
// dataset" format.
optional bool support_partial_cache_dataset_format = 4 [default = false];
// If true, the algorithm supports training with a maximum model size
// (maximum_model_size_in_memory_in_bytes).
optional bool support_max_model_size_in_memory = 5 [default = false];
// If true, the algorithm supports monotonic constraints over numerical
// features.
optional bool support_monotonic_constraints = 6 [default = false];
// If true, the learner requires a label. If false, the learner does not
// require a label.
optional bool require_label = 7 [default = true];
// If true, the learner supports custom losses.
optional bool support_custom_loss = 8 [default = false];
}
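
// For illustration only (a sketch; the combination of capabilities is
// hypothetical), a learner could declare its capabilities in protobuf text
// format as:
//
//   support_max_training_duration: true
//   resume_training: true
//   support_validation_dataset: true
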
// Monotonic constraints between the model's output and numerical input
// features.
message MonotonicConstraint {
// Regular expression over the input feature names.
optional string feature = 1;
optional Direction direction = 2 [default = INCREASING];
enum Direction {
// Ensure the model output is monotonically increasing (non-strict) with the
// feature.
INCREASING = 0;
// Ensure the model output is monotonically decreasing (non-strict) with the
// feature.
DECREASING = 1;
}
}
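
// For illustration only (a sketch; the feature pattern is hypothetical), a
// constraint forcing the model output to be non-decreasing in all features
// matching "age.*" could be added to a TrainingConfig in protobuf text format
// as:
//
//   monotonic_constraints {
//     feature: "age.*"
//     direction: INCREASING
//   }
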
message PerColumn {
// If set, the attribute has a monotonic constraint.
// Note: monotonic_constraint.feature might not be set.
optional MonotonicConstraint monotonic_constraint = 1;
}