Skip to content
This repository was archived by the owner on Sep 9, 2023. It is now read-only.

Commit cf0b763

Browse files
feat: adds ValueConverter utility and demo samples (#108)
* feat: adds value converter utility class and demo samples * feat: samples updated for EJCL * fix: removed local file references * feat: adds ValueConverter tests Co-authored-by: yoshi-code-bot <70984784+yoshi-code-bot@users.noreply.github.com>
1 parent e8d357a commit cf0b763

File tree

7 files changed

+262
-30
lines changed

7 files changed

+262
-30
lines changed

Diff for: google-cloud-aiplatform/pom.xml

+9
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,15 @@
8383
<classifier>testlib</classifier>
8484
<scope>test</scope>
8585
</dependency>
86+
<dependency>
87+
<groupId>com.google.protobuf</groupId>
88+
<artifactId>protobuf-java-util</artifactId>
89+
</dependency>
90+
<dependency>
91+
<groupId>com.google.code.gson</groupId>
92+
<artifactId>gson</artifactId>
93+
<scope>test</scope>
94+
</dependency>
8695
</dependencies>
8796

8897
<profiles>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Copyright 2020 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.cloud.aiplatform.util;
18+
19+
import com.google.protobuf.InvalidProtocolBufferException;
20+
import com.google.protobuf.Message;
21+
import com.google.protobuf.Value;
22+
import com.google.protobuf.util.JsonFormat;
23+
24+
/**
25+
* Exposes utility methods for converting AI Platform messages to and from
26+
* {@com.google.protobuf.Value} objects.
27+
*/
28+
public class ValueConverter {
29+
30+
/** An empty {@com.google.protobuf.Value} message. */
31+
public static final Value EMPTY_VALUE = Value.newBuilder().build();
32+
33+
/**
34+
* Converts a message type to a {@com.google.protobuf.Value}.
35+
*
36+
* @param message the message to convert
37+
* @return the message as a {@com.google.protobuf.Value}
38+
* @throws InvalidProtocolBufferException
39+
*/
40+
public static Value toValue(Message message) throws InvalidProtocolBufferException {
41+
String jsonString = JsonFormat.printer().print(message);
42+
Value.Builder value = Value.newBuilder();
43+
JsonFormat.parser().merge(jsonString, value);
44+
return value.build();
45+
}
46+
47+
/**
48+
* Converts a {@com.google.protobuf.Value} to a {@com.google.protobuf.Message} of the provided
49+
* {@com.google.protobuf.Message.Builder}.
50+
*
51+
* @param messageBuilder a builder for the message type
52+
* @param value the Value to convert to a message
53+
* @return the value as a message
54+
* @throws InvalidProtocolBufferException
55+
*/
56+
public static Message fromValue(Message.Builder messageBuilder, Value value)
57+
throws InvalidProtocolBufferException {
58+
String valueString = JsonFormat.printer().print(value);
59+
JsonFormat.parser().merge(valueString, messageBuilder);
60+
return messageBuilder.build();
61+
}
62+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
/*
2+
* Copyright 2020 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.cloud.aiplatform.util;
18+
19+
import static org.junit.Assert.assertEquals;
20+
import static org.junit.Assert.assertThrows;
21+
22+
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageClassificationInputs;
23+
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageClassificationInputs.ModelType;
24+
import com.google.gson.JsonObject;
25+
import com.google.protobuf.InvalidProtocolBufferException;
26+
import com.google.protobuf.MapEntry;
27+
import com.google.protobuf.Struct;
28+
import com.google.protobuf.Value;
29+
import com.google.protobuf.util.JsonFormat;
30+
import java.util.Collection;
31+
import org.junit.Test;
32+
import org.junit.function.ThrowingRunnable;
33+
34+
public class ValueConverterTest {
35+
36+
@Test
37+
public void testValueConverterToValue() throws InvalidProtocolBufferException {
38+
AutoMlImageClassificationInputs testObjectInputs =
39+
AutoMlImageClassificationInputs.newBuilder()
40+
.setModelType(ModelType.CLOUD)
41+
.setBudgetMilliNodeHours(8000)
42+
.setMultiLabel(true)
43+
.setDisableEarlyStopping(false)
44+
.build();
45+
46+
Value actualConvertedValue = ValueConverter.toValue(testObjectInputs);
47+
48+
Struct actualStruct = actualConvertedValue.getStructValue();
49+
assertEquals(3, actualStruct.getFieldsCount());
50+
51+
Collection<Object> innerFields = actualStruct.getAllFields().values();
52+
Collection<MapEntry> fieldEntries = (Collection<MapEntry>) innerFields.toArray()[0];
53+
54+
MapEntry actualBoolValueEntry = null;
55+
MapEntry actualStringValueEntry = null;
56+
MapEntry actualNumberValueEntry = null;
57+
58+
for (MapEntry entry : fieldEntries) {
59+
String key = entry.getKey().toString();
60+
if (key.equals("multiLabel")) {
61+
actualBoolValueEntry = entry;
62+
} else if (key.equals("modelType")) {
63+
actualStringValueEntry = entry;
64+
} else if (key.equals("budgetMilliNodeHours")) {
65+
actualNumberValueEntry = entry;
66+
}
67+
}
68+
69+
Value actualBoolValue = (Value) actualBoolValueEntry.getValue();
70+
assertEquals(testObjectInputs.getMultiLabel(), actualBoolValue.getBoolValue());
71+
72+
Value actualStringValue = (Value) actualStringValueEntry.getValue();
73+
assertEquals("CLOUD", actualStringValue.getStringValue());
74+
75+
Value actualNumberValue = (Value) actualNumberValueEntry.getValue();
76+
// protobuf stores int64 values as strings rather than numbers
77+
long actualNumber = Long.parseLong(actualNumberValue.getStringValue());
78+
assertEquals(testObjectInputs.getBudgetMilliNodeHours(), actualNumber);
79+
}
80+
81+
@Test
82+
public void testValueConverterFromValue() throws InvalidProtocolBufferException {
83+
84+
JsonObject testJsonInputs = new JsonObject();
85+
testJsonInputs.addProperty("multi_label", true);
86+
testJsonInputs.addProperty("model_type", "CLOUD");
87+
testJsonInputs.addProperty("budget_milli_node_hours", 8000);
88+
89+
Value.Builder valueBuilder = Value.newBuilder();
90+
JsonFormat.parser().merge(testJsonInputs.toString(), valueBuilder);
91+
Value testValueInputs = valueBuilder.build();
92+
93+
AutoMlImageClassificationInputs actualInputs =
94+
(AutoMlImageClassificationInputs)
95+
ValueConverter.fromValue(AutoMlImageClassificationInputs.newBuilder(), testValueInputs);
96+
97+
assertEquals(8000, actualInputs.getBudgetMilliNodeHours());
98+
assertEquals(true, actualInputs.getMultiLabel());
99+
assertEquals(ModelType.CLOUD, actualInputs.getModelType());
100+
}
101+
102+
@Test
103+
public void testValueConverterFromValueWithBadInputs() throws InvalidProtocolBufferException {
104+
JsonObject testBadJsonInputs = new JsonObject();
105+
testBadJsonInputs.addProperty("wrong_key", "some_value");
106+
107+
Value.Builder badValueBuilder = Value.newBuilder();
108+
JsonFormat.parser().merge(testBadJsonInputs.toString(), badValueBuilder);
109+
final Value testBadValueInputs = badValueBuilder.build();
110+
111+
assertThrows(
112+
InvalidProtocolBufferException.class,
113+
new ThrowingRunnable() {
114+
@Override
115+
public void run() throws Throwable {
116+
AutoMlImageClassificationInputs actualBadInput =
117+
(AutoMlImageClassificationInputs)
118+
ValueConverter.fromValue(
119+
AutoMlImageClassificationInputs.newBuilder(), testBadValueInputs);
120+
}
121+
});
122+
}
123+
}

Diff for: samples/snippets/pom.xml

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
<dependency>
2828
<groupId>com.google.cloud</groupId>
2929
<artifactId>google-cloud-aiplatform</artifactId>
30-
<version>0.1.0</version>
30+
<version>0.1.1-SNAPSHOT</version>
3131
</dependency>
3232
<!-- [END aiplatform_install_with_bom] -->
3333
<dependency>

Diff for: samples/snippets/src/main/java/aiplatform/CreateTrainingPipelineImageClassificationSample.java

+11-9
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
package aiplatform;
1818

1919
// [START aiplatform_create_training_pipeline_image_classification_sample]
20-
20+
import com.google.cloud.aiplatform.util.ValueConverter;
2121
import com.google.cloud.aiplatform.v1beta1.DeployedModelRef;
2222
import com.google.cloud.aiplatform.v1beta1.EnvVar;
2323
import com.google.cloud.aiplatform.v1beta1.ExplanationMetadata;
@@ -38,8 +38,8 @@
3838
import com.google.cloud.aiplatform.v1beta1.SampledShapleyAttribution;
3939
import com.google.cloud.aiplatform.v1beta1.TimestampSplit;
4040
import com.google.cloud.aiplatform.v1beta1.TrainingPipeline;
41-
import com.google.protobuf.Value;
42-
import com.google.protobuf.util.JsonFormat;
41+
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageClassificationInputs;
42+
import com.google.cloud.aiplatform.v1beta1.schema.trainingjob.definition.AutoMlImageClassificationInputs.ModelType;
4343
import com.google.rpc.Status;
4444
import java.io.IOException;
4545

@@ -74,11 +74,13 @@ static void createTrainingPipelineImageClassificationSample(
7474
+ "automl_image_classification_1.0.0.yaml";
7575
LocationName locationName = LocationName.of(project, location);
7676

77-
String jsonString =
78-
"{\"multiLabel\": false, \"modelType\": \"CLOUD\", \"budgetMilliNodeHours\": 8000,"
79-
+ " \"disableEarlyStopping\": false}";
80-
Value.Builder trainingTaskInputs = Value.newBuilder();
81-
JsonFormat.parser().merge(jsonString, trainingTaskInputs);
77+
AutoMlImageClassificationInputs autoMlImageClassificationInputs =
78+
AutoMlImageClassificationInputs.newBuilder()
79+
.setModelType(ModelType.CLOUD)
80+
.setMultiLabel(false)
81+
.setBudgetMilliNodeHours(8000)
82+
.setDisableEarlyStopping(false)
83+
.build();
8284

8385
InputDataConfig trainingInputDataConfig =
8486
InputDataConfig.newBuilder().setDatasetId(datasetId).build();
@@ -87,7 +89,7 @@ static void createTrainingPipelineImageClassificationSample(
8789
TrainingPipeline.newBuilder()
8890
.setDisplayName(trainingPipelineDisplayName)
8991
.setTrainingTaskDefinition(trainingTaskDefinition)
90-
.setTrainingTaskInputs(trainingTaskInputs)
92+
.setTrainingTaskInputs(ValueConverter.toValue(autoMlImageClassificationInputs))
9193
.setInputDataConfig(trainingInputDataConfig)
9294
.setModelToUpload(model)
9395
.build();

Diff for: samples/snippets/src/main/java/aiplatform/PredictImageClassificationSample.java

+31-9
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,15 @@
1919
// [START aiplatform_predict_image_classification_sample]
2020

2121
import com.google.api.client.util.Base64;
22+
import com.google.cloud.aiplatform.util.ValueConverter;
2223
import com.google.cloud.aiplatform.v1beta1.EndpointName;
2324
import com.google.cloud.aiplatform.v1beta1.PredictResponse;
2425
import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient;
2526
import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings;
27+
import com.google.cloud.aiplatform.v1beta1.schema.predict.instance.ImageClassificationPredictionInstance;
28+
import com.google.cloud.aiplatform.v1beta1.schema.predict.params.ImageClassificationPredictionParams;
29+
import com.google.cloud.aiplatform.v1beta1.schema.predict.prediction.ClassificationPredictionResult;
2630
import com.google.protobuf.Value;
27-
import com.google.protobuf.util.JsonFormat;
2831
import java.io.IOException;
2932
import java.nio.charset.StandardCharsets;
3033
import java.nio.file.Files;
@@ -60,23 +63,42 @@ static void predictImageClassification(String project, String fileName, String e
6063
byte[] contents = Base64.encodeBase64(Files.readAllBytes(Paths.get(fileName)));
6164
String content = new String(contents, StandardCharsets.UTF_8);
6265

63-
Value parameter = Value.newBuilder().setNumberValue(0).setNumberValue(5).build();
64-
65-
String contentDict = "{\"content\": \"" + content + "\"}";
66-
Value.Builder instance = Value.newBuilder();
67-
JsonFormat.parser().merge(contentDict, instance);
66+
ImageClassificationPredictionInstance predictionInstance =
67+
ImageClassificationPredictionInstance.newBuilder()
68+
.setContent(content)
69+
.build();
6870

6971
List<Value> instances = new ArrayList<>();
70-
instances.add(instance.build());
72+
instances.add(ValueConverter.toValue(predictionInstance));
73+
74+
ImageClassificationPredictionParams predictionParams =
75+
ImageClassificationPredictionParams.newBuilder()
76+
.setConfidenceThreshold((float) 0.5)
77+
.setMaxPredictions(5)
78+
.build();
7179

7280
PredictResponse predictResponse =
73-
predictionServiceClient.predict(endpointName, instances, parameter);
81+
predictionServiceClient.predict(endpointName, instances,
82+
ValueConverter.toValue(predictionParams));
7483
System.out.println("Predict Image Classification Response");
7584
System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId());
7685

7786
System.out.println("Predictions");
7887
for (Value prediction : predictResponse.getPredictionsList()) {
79-
System.out.format("\tPrediction: %s\n", prediction);
88+
89+
ClassificationPredictionResult.Builder resultBuilder =
90+
ClassificationPredictionResult.newBuilder();
91+
// Display names and confidences values correspond to
92+
// IDs in the ID list.
93+
ClassificationPredictionResult result =
94+
(ClassificationPredictionResult) ValueConverter.fromValue(resultBuilder, prediction);
95+
int counter = 0;
96+
for (Long id : result.getIdsList()) {
97+
System.out.printf("Label ID: %d\n", id);
98+
System.out.printf("Label: %s\n", result.getDisplayNames(counter));
99+
System.out.printf("Confidence: %.4f\n", result.getConfidences(counter));
100+
counter++;
101+
}
80102
}
81103
}
82104
}

Diff for: samples/snippets/src/main/java/aiplatform/PredictTextClassificationSingleLabelSample.java

+25-11
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,14 @@
1717
package aiplatform;
1818

1919
// [START aiplatform_predict_text_classification_sample]
20-
20+
import com.google.cloud.aiplatform.util.ValueConverter;
2121
import com.google.cloud.aiplatform.v1beta1.EndpointName;
2222
import com.google.cloud.aiplatform.v1beta1.PredictResponse;
2323
import com.google.cloud.aiplatform.v1beta1.PredictionServiceClient;
2424
import com.google.cloud.aiplatform.v1beta1.PredictionServiceSettings;
25+
import com.google.cloud.aiplatform.v1beta1.schema.predict.instance.TextClassificationPredictionInstance;
26+
import com.google.cloud.aiplatform.v1beta1.schema.predict.prediction.ClassificationPredictionResult;
2527
import com.google.protobuf.Value;
26-
import com.google.protobuf.util.JsonFormat;
2728
import java.io.IOException;
2829
import java.util.ArrayList;
2930
import java.util.List;
@@ -52,25 +53,38 @@ static void predictTextClassificationSingleLabel(
5253
try (PredictionServiceClient predictionServiceClient =
5354
PredictionServiceClient.create(predictionServiceSettings)) {
5455
String location = "us-central1";
55-
String jsonString = "{\"content\": \"" + content + "\"}";
56-
5756
EndpointName endpointName = EndpointName.of(project, location, endpointId);
5857

59-
Value parameter = Value.newBuilder().setNumberValue(0).setNumberValue(5).build();
60-
Value.Builder instance = Value.newBuilder();
61-
JsonFormat.parser().merge(jsonString, instance);
58+
TextClassificationPredictionInstance predictionInstance = TextClassificationPredictionInstance
59+
.newBuilder()
60+
.setContent(content)
61+
.build();
6262

6363
List<Value> instances = new ArrayList<>();
64-
instances.add(instance.build());
64+
instances.add(ValueConverter.toValue(predictionInstance));
6565

6666
PredictResponse predictResponse =
67-
predictionServiceClient.predict(endpointName, instances, parameter);
67+
predictionServiceClient.predict(endpointName, instances, ValueConverter.EMPTY_VALUE);
6868
System.out.println("Predict Text Classification Response");
6969
System.out.format("\tDeployed Model Id: %s\n", predictResponse.getDeployedModelId());
7070

71-
System.out.println("Predictions");
71+
System.out.println("Predictions:\n\n");
7272
for (Value prediction : predictResponse.getPredictionsList()) {
73-
System.out.format("\tPrediction: %s\n", prediction);
73+
74+
ClassificationPredictionResult.Builder resultBuilder =
75+
ClassificationPredictionResult.newBuilder();
76+
77+
// Display names and confidences values correspond to
78+
// IDs in the ID list.
79+
ClassificationPredictionResult result =
80+
(ClassificationPredictionResult) ValueConverter.fromValue(resultBuilder, prediction);
81+
int counter = 0;
82+
for (Long id : result.getIdsList()) {
83+
System.out.printf("Label ID: %d\n", id);
84+
System.out.printf("Label: %s\n", result.getDisplayNames(counter));
85+
System.out.printf("Confidence: %.4f\n", result.getConfidences(counter));
86+
counter++;
87+
}
7488
}
7589
}
7690
}

0 commit comments

Comments
 (0)