-
Notifications
You must be signed in to change notification settings - Fork 1.6k
/
component.yaml
27 lines (26 loc) · 2.34 KB
/
component.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# Kubeflow Pipelines component specification for a Keras classifier trainer.
# Every input is forwarded to the container as a command-line flag (see
# implementation.container.args); the location of the trained model is
# reported through the `output_model_uri` output.
name: Keras - Train classifier
description: Trains classifier using Keras sequential model
inputs:
  - name: training_set_features_path
    type: {GcsPath: {data_type: TSV}}
    description: 'Local or GCS path to the training set features table.'
  - name: training_set_labels_path
    type: {GcsPath: {data_type: TSV}}
    description: >-
      Local or GCS path to the training set labels (each label is a class
      index from 0 to num-classes - 1).
  # TODO: Remove this GCS URI input and move it to `outputs` once artifact
  # passing support is checked in.
  - name: output_model_uri
    type: {GcsPath: {data_type: Keras model}}
    description: >-
      Local or GCS path specifying where to save the trained model. The model
      (topology + weights + optimizer state) is saved in HDF5 format and can
      be loaded back by calling keras.models.load_model
  # NOTE(review): typed as a GcsPath, but the description (and the
  # --model-config-json flag) suggest the value is the JSON text itself.
  # Confirm which one train.py expects.
  - name: model_config
    type: {GcsPath: {data_type: Keras model config json}}
    description: >-
      JSON string containing the serialized model structure. Can be obtained
      by calling model.to_json() on a Keras model.
  - name: number_of_classes
    type: Integer
    description: 'Number of classifier classes.'
  - name: number_of_epochs
    type: Integer
    default: '100'
    description: >-
      Number of epochs to train the model. An epoch is an iteration over the
      entire `x` and `y` data provided.
  - name: batch_size
    type: Integer
    default: '32'
    description: 'Number of samples per gradient update.'
outputs:
  # TODO: Remove the GCS URI indirection and make this a proper output once
  # artifact passing support is checked in.
  - name: output_model_uri
    type: {GcsPath: {data_type: Keras model}}
    description: >-
      GCS path where the trained model has been saved. The model (topology +
      weights + optimizer state) is saved in HDF5 format and can be loaded
      back by calling keras.models.load_model
implementation:
  container:
    image: gcr.io/ml-pipeline/sample/keras/train_classifier
    command: [python3, /pipelines/component/src/train.py]
    # Flag/value pairs. `inputValue` substitutes the value of an input;
    # `outputPath` substitutes a file path associated with the
    # `output_model_uri` output (presumably train.py writes the saved model's
    # path there — confirm against src/train.py).
    args:
      - --training-set-features-path
      - {inputValue: training_set_features_path}
      - --training-set-labels-path
      - {inputValue: training_set_labels_path}
      - --output-model-path
      - {inputValue: output_model_uri}
      - --model-config-json
      - {inputValue: model_config}
      - --num-classes
      - {inputValue: number_of_classes}
      - --num-epochs
      - {inputValue: number_of_epochs}
      - --batch-size
      - {inputValue: batch_size}
      - --output-model-path-file
      - {outputPath: output_model_uri}