forked from googleapis/python-aiplatform
/
create_training_pipeline_custom_training_managed_dataset_sample.py
103 lines (89 loc) · 3.58 KB
/
create_training_pipeline_custom_training_managed_dataset_sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# [START aiplatform_create_training_pipeline_custom_training_managed_dataset_sample]
from typing import Optional

from google.cloud import aiplatform
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value
def create_training_pipeline_custom_training_managed_dataset_sample(
    project: str,
    display_name: str,
    model_display_name: str,
    dataset_id: str,
    annotation_schema_uri: str,
    training_container_spec_image_uri: str,
    model_container_spec_image_uri: str,
    base_output_uri_prefix: str,
    location: str = "us-central1",
    api_endpoint: Optional[str] = None,
):
    """Create a custom-container training pipeline that reads a managed dataset.

    Training runs ``training_container_spec_image_uri`` on a single
    ``n1-standard-8`` worker. The resulting model is uploaded with a serving
    container that launches ``/bin/tensorflow_model_server``.

    Args:
        project: Google Cloud project ID that will own the pipeline.
        display_name: Display name of the training pipeline.
        model_display_name: Display name of the model uploaded after training.
        dataset_id: ID of the managed dataset used as training input.
        annotation_schema_uri: Cloud Storage URI of the annotation schema for
            the dataset's annotations.
        training_container_spec_image_uri: Container image that runs training.
        model_container_spec_image_uri: Container image used for the uploaded
            model's serving container.
        base_output_uri_prefix: Cloud Storage URI prefix. It is used both as
            the input data config's ``gcs_destination`` and as the training
            job's ``baseOutputDirectory``.
        location: Region in which the pipeline is created.
        api_endpoint: API endpoint to send the request to. When omitted it is
            derived from ``location`` as
            ``"{location}-aiplatform.googleapis.com"``. For the default
            location this is ``"us-central1-aiplatform.googleapis.com"``.

    Returns:
        The training pipeline returned by
        ``PipelineServiceClient.create_training_pipeline``. It is also printed
        to stdout.
    """
    # Keep the endpoint in the same region as ``parent`` (built from
    # ``location`` below) unless the caller explicitly overrides it. A fixed
    # us-central1 default would silently mismatch any other ``location``.
    if api_endpoint is None:
        api_endpoint = f"{location}-aiplatform.googleapis.com"
    client_options = {"api_endpoint": api_endpoint}
    # Initialize client that will be used to create and send requests.
    # This client only needs to be created once, and can be reused for
    # multiple requests.
    client = aiplatform.gapic.PipelineServiceClient(client_options=client_options)

    # input_data_config: the dataset/annotations to train on and the Cloud
    # Storage destination for the training data.
    input_data_config = {
        "dataset_id": dataset_id,
        "annotation_schema_uri": annotation_schema_uri,
        "gcs_destination": {"output_uri_prefix": base_output_uri_prefix},
    }

    # training_task_definition: schema describing a custom training job.
    custom_task_definition = (
        "gs://google-cloud-aiplatform/schema/"
        "trainingjob/definition/custom_task_1.0.0.yaml"
    )

    # training_task_inputs: this dict is converted to a free-form protobuf
    # Value (below), so its keys use the task schema's camelCase JSON names
    # rather than the snake_case used for typed fields elsewhere.
    training_container_spec = {
        "imageUri": training_container_spec_image_uri,
        # AIP_MODEL_DIR is set by the service according to baseOutputDirectory.
        "args": ["--model-dir=$(AIP_MODEL_DIR)"],
    }
    training_worker_pool_spec = {
        "replicaCount": 1,
        "machineSpec": {"machineType": "n1-standard-8"},
        "containerSpec": training_container_spec,
    }
    training_task_inputs_dict = {
        "workerPoolSpecs": [training_worker_pool_spec],
        "baseOutputDirectory": {"outputUriPrefix": base_output_uri_prefix},
    }
    training_task_inputs = json_format.ParseDict(training_task_inputs_dict, Value())

    # model_to_upload: serving container launched for the trained model.
    # The $(AIP_...) references are expanded by the service; presumably
    # AIP_STORAGE_URI points at the trained model artifacts (verify against
    # the Vertex AI custom-container docs).
    model_container_spec = {
        "image_uri": model_container_spec_image_uri,
        "command": ["/bin/tensorflow_model_server"],
        "args": [
            "--model_name=$(AIP_MODEL)",
            "--model_base_path=$(AIP_STORAGE_URI)",
            "--rest_api_port=8080",
            "--port=8500",
            # ~1 year between filesystem polls; presumably meant to stop the
            # server from re-polling for new model versions.
            "--file_system_poll_wait_seconds=31540000",
        ],
    }
    model = {
        "display_name": model_display_name,
        "container_spec": model_container_spec,
    }

    training_pipeline = {
        "display_name": display_name,
        "input_data_config": input_data_config,
        "training_task_definition": custom_task_definition,
        "training_task_inputs": training_task_inputs,
        "model_to_upload": model,
    }
    parent = f"projects/{project}/locations/{location}"
    response = client.create_training_pipeline(
        parent=parent, training_pipeline=training_pipeline
    )
    print("response:", response)
    return response
# [END aiplatform_create_training_pipeline_custom_training_managed_dataset_sample]