From e0d6332c354968a8da8f37c7df9ead86f048f7f9 Mon Sep 17 00:00:00 2001 From: Chamikara Jayalath Date: Sat, 4 Sep 2021 14:20:26 -0700 Subject: [PATCH] [BEAM-12769] Adds support for expanding a Java cross-language transform using the class name and builder methods (#15343) * Adds support for expanding a Java cross-language transform using the class name and builder methods * Adds an allowlist and adds support for annotations * Fix tests * Address CheckerFramework errors * Adds license * Addresses reviewer comments. * Apply suggestions from code review Co-authored-by: Lukasz Cwik * Addresses reviewer comments. * Updated the proto to include a single schema/payload for constructor and each builder method. Updated the implementation accordingly and added additional tests. * Some doc updates and few other minor updates. * Addressing reviewer comments Co-authored-by: Lukasz Cwik --- .../src/main/proto/external_transforms.proto | 63 + sdks/java/expansion-service/build.gradle | 3 + .../expansion/service/ExpansionService.java | 36 +- .../service/ExpansionServiceOptions.java | 75 ++ .../JavaClassLookupTransformProvider.java | 526 ++++++++ .../service/MultiLanguageBuilderMethod.java | 31 + .../MultiLanguageConstructorMethod.java | 31 + .../service/ExpansionServiceTest.java | 16 +- .../JavaCLassLookupTransformProviderTest.java | 1111 +++++++++++++++++ .../src/test/resources/test_allowlist.yaml | 67 + 10 files changed, 1941 insertions(+), 18 deletions(-) create mode 100644 sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceOptions.java create mode 100644 sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/JavaClassLookupTransformProvider.java create mode 100644 sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/MultiLanguageBuilderMethod.java create mode 100644 sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/MultiLanguageConstructorMethod.java 
create mode 100644 sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/JavaCLassLookupTransformProviderTest.java create mode 100644 sdks/java/expansion-service/src/test/resources/test_allowlist.yaml diff --git a/model/pipeline/src/main/proto/external_transforms.proto b/model/pipeline/src/main/proto/external_transforms.proto index f2d47a17693a7..a528e565e4b01 100644 --- a/model/pipeline/src/main/proto/external_transforms.proto +++ b/model/pipeline/src/main/proto/external_transforms.proto @@ -29,6 +29,7 @@ option java_package = "org.apache.beam.model.pipeline.v1"; option java_outer_classname = "ExternalTransforms"; import "schema.proto"; +import "beam_runner_api.proto"; // A configuration payload for an external transform. // Used as the payload of ExternalTransform as part of an ExpansionRequest. @@ -40,3 +41,65 @@ message ExternalConfigurationPayload { // schema. bytes payload = 2; } + +// Defines specific expansion methods that may be used to expand cross-language +// transforms. +// Has to be set as the URN of the transform of the expansion request. +message ExpansionMethods { + enum Enum { + // Expand a Java transform using specified constructor and builder methods. + // Transform payload will be of type JavaClassLookupPayload. + JAVA_CLASS_LOOKUP = 0 [(org.apache.beam.model.pipeline.v1.beam_urn) = + "beam:expansion:payload:java_class_lookup:v1"]; + } +} + +// A configuration payload for an external transform. +// Used to define a Java transform that can be directly instantiated by a Java +// expansion service. +message JavaClassLookupPayload { + // Name of the Java transform class. + string class_name = 1; + + // A static method to construct the initial instance of the transform. + // If not provided, the transform should be instantiated using a class + // constructor. + string constructor_method = 2; + + // The top level fields of the schema represent the method parameters in + // order. 
+ // Where parameter names are available, top level field names are also + // verified against the method parameters for a match. + Schema constructor_schema = 3; + + // A payload which can be decoded using beam:coder:row:v1 and the provided + // constructor schema. + bytes constructor_payload = 4; + + // Set of builder methods and corresponding parameters to apply after the + // transform object is constructed. + // When constructing the transform object, the given builder methods will be + // applied in order. + repeated BuilderMethod builder_methods = 5; +} + +// This represents a builder method of the transform class that should be +// applied in order after instantiating the initial transform object. +// Each builder method may take zero or more parameters and has to return an +// instance of the transform object. +message BuilderMethod { + // Name of the builder method. + string name = 1; + + // The top level fields of the schema represent the method parameters in + // order. + // Where parameter names are available, top level field names are also + // verified against the method parameters for a match. + Schema schema = 2; + + // A payload which can be decoded using beam:coder:row:v1 and the builder + // method schema. 
+ bytes payload = 3; +} + + diff --git a/sdks/java/expansion-service/build.gradle b/sdks/java/expansion-service/build.gradle index a6263036a2bad..2a0ffd0a7b27f 100644 --- a/sdks/java/expansion-service/build.gradle +++ b/sdks/java/expansion-service/build.gradle @@ -38,6 +38,9 @@ dependencies { compile project(path: ":sdks:java:core", configuration: "shadow") compile project(path: ":runners:core-construction-java") compile project(path: ":runners:java-fn-execution") + compile library.java.jackson_annotations + compile library.java.jackson_databind + compile library.java.jackson_dataformat_yaml compile library.java.vendored_grpc_1_36_0 compile library.java.vendored_guava_26_0_jre compile library.java.slf4j_api diff --git a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java index eaa1cbe3a7422..6e1f3d30d64d7 100644 --- a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java +++ b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionService.java @@ -17,6 +17,7 @@ */ package org.apache.beam.sdk.expansion.service; +import static org.apache.beam.runners.core.construction.BeamUrns.getUrn; import static org.apache.beam.runners.core.construction.resources.PipelineResources.detectClassPathResourcesToStage; import static org.apache.beam.sdk.util.Preconditions.checkArgumentNotNull; @@ -35,8 +36,10 @@ import java.util.stream.Collectors; import org.apache.beam.model.expansion.v1.ExpansionApi; import org.apache.beam.model.expansion.v1.ExpansionServiceGrpc; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods; import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExternalConfigurationPayload; import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.model.pipeline.v1.SchemaApi; import 
org.apache.beam.runners.core.construction.Environments; import org.apache.beam.runners.core.construction.PipelineTranslation; import org.apache.beam.runners.core.construction.RehydratedComponents; @@ -49,6 +52,7 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.RowCoder; import org.apache.beam.sdk.expansion.ExternalTransformRegistrar; +import org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProvider.AllowList; import org.apache.beam.sdk.options.ExperimentalOptions; import org.apache.beam.sdk.options.PipelineOptions; import org.apache.beam.sdk.options.PipelineOptionsFactory; @@ -70,6 +74,7 @@ import org.apache.beam.sdk.values.POutput; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p36p0.io.grpc.Server; import org.apache.beam.vendor.grpc.v1p36p0.io.grpc.ServerBuilder; import org.apache.beam.vendor.grpc.v1p36p0.io.grpc.stub.StreamObserver; @@ -172,8 +177,8 @@ private static Class getConfigClass( return configurationClass; } - private static Row decodeRow(ExternalConfigurationPayload payload) { - Schema payloadSchema = SchemaTranslation.schemaFromProto(payload.getSchema()); + static Row decodeConfigObjectRow(SchemaApi.Schema schema, ByteString payload) { + Schema payloadSchema = SchemaTranslation.schemaFromProto(schema); if (payloadSchema.getFieldCount() == 0) { return Row.withSchema(Schema.of()).build(); @@ -200,7 +205,7 @@ private static Row decodeRow(ExternalConfigurationPayload payload) { Row configRow; try { - configRow = RowCoder.of(payloadSchema).decode(payload.getPayload().newInput()); + configRow = RowCoder.of(payloadSchema).decode(payload.newInput()); } catch (IOException e) { throw new RuntimeException("Error decoding payload", e); } @@ -247,7 +252,7 @@ private static ConfigT payloadToConfigSchema( SerializableFunction fromRowFunc = 
SCHEMA_REGISTRY.getFromRowFunction(configurationClass); - Row payloadRow = decodeRow(payload); + Row payloadRow = decodeConfigObjectRow(payload.getSchema(), payload.getPayload()); if (!payloadRow.getSchema().assignableTo(configSchema)) { throw new IllegalArgumentException( @@ -263,7 +268,7 @@ private static ConfigT payloadToConfigSchema( private static ConfigT payloadToConfigSetters( ExternalConfigurationPayload payload, Class configurationClass) throws ReflectiveOperationException { - Row configRow = decodeRow(payload); + Row configRow = decodeConfigObjectRow(payload.getSchema(), payload.getPayload()); Constructor constructor = configurationClass.getDeclaredConstructor(); constructor.setAccessible(true); @@ -459,13 +464,22 @@ private Map loadRegisteredTransforms() { } })); - @Nullable - TransformProvider transformProvider = - getRegisteredTransforms().get(request.getTransform().getSpec().getUrn()); - if (transformProvider == null) { - throw new UnsupportedOperationException( - "Unknown urn: " + request.getTransform().getSpec().getUrn()); + String urn = request.getTransform().getSpec().getUrn(); + + TransformProvider transformProvider = null; + if (getUrn(ExpansionMethods.Enum.JAVA_CLASS_LOOKUP).equals(urn)) { + AllowList allowList = + pipelineOptions.as(ExpansionServiceOptions.class).getJavaClassLookupAllowlist(); + assert allowList != null; + transformProvider = new JavaClassLookupTransformProvider(allowList); + } else { + transformProvider = getRegisteredTransforms().get(urn); + if (transformProvider == null) { + throw new UnsupportedOperationException( + "Unknown urn: " + request.getTransform().getSpec().getUrn()); + } } + Map> outputs = transformProvider.apply( pipeline, diff --git a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceOptions.java b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceOptions.java new file mode 100644 index 0000000000000..79e870cd07f2a --- 
/dev/null +++ b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/ExpansionServiceOptions.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.expansion.service; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; +import java.io.File; +import java.io.IOException; +import java.util.ArrayList; +import org.apache.beam.sdk.expansion.service.JavaClassLookupTransformProvider.AllowList; +import org.apache.beam.sdk.options.Default; +import org.apache.beam.sdk.options.DefaultValueFactory; +import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.PipelineOptions; + +/** Options used to configure the {@link ExpansionService}. 
*/ +public interface ExpansionServiceOptions extends PipelineOptions { + + @Description("Allow list for Java class based transform expansion") + @Default.InstanceFactory(JavaClassLookupAllowListFactory.class) + AllowList getJavaClassLookupAllowlist(); + + void setJavaClassLookupAllowlist(AllowList file); + + @Description("Allow list file for Java class based transform expansion") + String getJavaClassLookupAllowlistFile(); + + void setJavaClassLookupAllowlistFile(String file); + + /** + * Loads the allow list from {@link #getJavaClassLookupAllowlistFile}, defaulting to an empty + * {@link JavaClassLookupTransformProvider.AllowList}. + */ + class JavaClassLookupAllowListFactory implements DefaultValueFactory { + + @Override + public AllowList create(PipelineOptions options) { + String allowListFile = + options.as(ExpansionServiceOptions.class).getJavaClassLookupAllowlistFile(); + if (allowListFile != null) { + ObjectMapper mapper = new ObjectMapper(new YAMLFactory()); + File allowListFileObj = new File(allowListFile); + if (!allowListFileObj.exists()) { + throw new IllegalArgumentException( + "Allow list file " + allowListFile + " does not exist"); + } + try { + return mapper.readValue(allowListFileObj, AllowList.class); + } catch (IOException e) { + throw new IllegalArgumentException( + "Could not load the provided allowlist file " + allowListFile, e); + } + } + + // By default produces an empty allow-list. 
+ return new AutoValue_JavaClassLookupTransformProvider_AllowList( + JavaClassLookupTransformProvider.ALLOW_LIST_VERSION, new ArrayList<>()); + } + } +} diff --git a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/JavaClassLookupTransformProvider.java b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/JavaClassLookupTransformProvider.java new file mode 100644 index 0000000000000..d32c7e4207a8d --- /dev/null +++ b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/JavaClassLookupTransformProvider.java @@ -0,0 +1,526 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.expansion.service; + +import static org.apache.beam.runners.core.construction.BeamUrns.getUrn; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.auto.value.AutoValue; +import edu.umd.cs.findbugs.annotations.SuppressFBWarnings; +import java.io.IOException; +import java.lang.annotation.Annotation; +import java.lang.reflect.Array; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.BuilderMethod; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.JavaClassLookupPayload; +import org.apache.beam.model.pipeline.v1.RunnerApi.FunctionSpec; +import org.apache.beam.model.pipeline.v1.SchemaApi; +import org.apache.beam.repackaged.core.org.apache.commons.lang3.ClassUtils; +import org.apache.beam.sdk.coders.RowCoder; +import org.apache.beam.sdk.expansion.service.ExpansionService.TransformProvider; +import org.apache.beam.sdk.schemas.JavaFieldSchema; +import org.apache.beam.sdk.schemas.NoSuchSchemaException; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.Field; +import org.apache.beam.sdk.schemas.Schema.TypeName; +import org.apache.beam.sdk.schemas.SchemaRegistry; +import org.apache.beam.sdk.schemas.SchemaTranslation; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.util.common.ReflectHelpers; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.POutput; +import 
org.apache.beam.sdk.values.Row; +import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.InvalidProtocolBufferException; +import org.checkerframework.checker.nullness.qual.Nullable; + +/** + * A transform provider that can be used to directly instantiate a transform using Java class name + * and builder methods. + * + * @param input {@link PInput} type of the transform + * @param output {@link POutput} type of the transform + */ +@SuppressWarnings({"argument.type.incompatible", "assignment.type.incompatible"}) +@SuppressFBWarnings("UWF_UNWRITTEN_PUBLIC_OR_PROTECTED_FIELD") +class JavaClassLookupTransformProvider + implements TransformProvider { + + public static final String ALLOW_LIST_VERSION = "v1"; + private static final SchemaRegistry SCHEMA_REGISTRY = SchemaRegistry.createDefault(); + private final AllowList allowList; + + public JavaClassLookupTransformProvider(AllowList allowList) { + if (!allowList.getVersion().equals(ALLOW_LIST_VERSION)) { + throw new IllegalArgumentException("Unknown allow-list version"); + } + this.allowList = allowList; + } + + @Override + public PTransform getTransform(FunctionSpec spec) { + JavaClassLookupPayload payload; + try { + payload = JavaClassLookupPayload.parseFrom(spec.getPayload()); + } catch (InvalidProtocolBufferException e) { + throw new IllegalArgumentException( + "Invalid payload type for URN " + getUrn(ExpansionMethods.Enum.JAVA_CLASS_LOOKUP), e); + } + + String className = payload.getClassName(); + try { + AllowedClass allowlistClass = null; + if (this.allowList != null) { + for (AllowedClass cls : this.allowList.getAllowedClasses()) { + if (cls.getClassName().equals(className)) { + if (allowlistClass != null) { + throw new IllegalArgumentException( + "Found two matching allowlist classes " + allowlistClass + " and " + cls); + } + allowlistClass = cls; + } + } + } + if (allowlistClass == null) { + throw new 
UnsupportedOperationException( + "The provided allow list does not enable expanding a transform class by the name " + + className + + "."); + } + Class> transformClass = + (Class>) + ReflectHelpers.findClassLoader().loadClass(className); + PTransform transform; + Row constructorRow = + decodeRow(payload.getConstructorSchema(), payload.getConstructorPayload()); + if (payload.getConstructorMethod().isEmpty()) { + Constructor[] constructors = transformClass.getConstructors(); + Constructor> constructor = + findMappingConstructor(constructors, payload); + Object[] parameterValues = + getParameterValues( + constructor.getParameters(), + constructorRow, + constructor.getGenericParameterTypes()); + transform = (PTransform) constructor.newInstance(parameterValues); + } else { + Method[] methods = transformClass.getMethods(); + Method method = findMappingConstructorMethod(methods, payload, allowlistClass); + Object[] parameterValues = + getParameterValues( + method.getParameters(), constructorRow, method.getGenericParameterTypes()); + transform = (PTransform) method.invoke(null /* static */, parameterValues); + } + return applyBuilderMethods(transform, payload, allowlistClass); + } catch (ClassNotFoundException e) { + throw new IllegalArgumentException("Could not find class " + className, e); + } catch (InstantiationException + | IllegalArgumentException + | IllegalAccessException + | InvocationTargetException e) { + throw new IllegalArgumentException("Could not instantiate class " + className, e); + } + } + + private PTransform applyBuilderMethods( + PTransform transform, + JavaClassLookupPayload payload, + AllowedClass allowListClass) { + for (BuilderMethod builderMethod : payload.getBuilderMethodsList()) { + Method method = getMethod(transform, builderMethod, allowListClass); + try { + Row builderMethodRow = decodeRow(builderMethod.getSchema(), builderMethod.getPayload()); + transform = + (PTransform) + method.invoke( + transform, + getParameterValues( + 
method.getParameters(), + builderMethodRow, + method.getGenericParameterTypes())); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new IllegalArgumentException( + "Could not invoke the builder method " + + builderMethod + + " on transform " + + transform + + " with parameter schema " + + builderMethod.getSchema(), + e); + } + } + + return transform; + } + + private boolean isBuilderMethodForName( + Method method, String nameFromPayload, AllowedClass allowListClass) { + // Lookup based on method annotations + for (Annotation annotation : method.getAnnotations()) { + if (annotation instanceof MultiLanguageBuilderMethod) { + if (nameFromPayload.equals(((MultiLanguageBuilderMethod) annotation).name())) { + if (allowListClass.getAllowedBuilderMethods().contains(nameFromPayload)) { + return true; + } else { + throw new RuntimeException( + "Builder method " + nameFromPayload + " has to be explicitly allowed"); + } + } + } + } + + // Lookup based on the method name. + boolean match = method.getName().equals(nameFromPayload); + String consideredMethodName = method.getName(); + + // We provide a simplification for common Java builder pattern naming convention where builder + // methods start with "with". In this case, for a builder method name in the form "withXyz", + // users may just use "xyz". If additional updates to the method name are needed the transform + // has to be updated by adding annotations. 
+ if (!match && consideredMethodName.length() > 4 && consideredMethodName.startsWith("with")) { + consideredMethodName = + consideredMethodName.substring(4, 5).toLowerCase() + consideredMethodName.substring(5); + match = consideredMethodName.equals(nameFromPayload); + } + if (match && !allowListClass.getAllowedBuilderMethods().contains(consideredMethodName)) { + throw new RuntimeException( + "Builder method name " + consideredMethodName + " has to be explicitly allowed"); + } + return match; + } + + private Method getMethod( + PTransform transform, + BuilderMethod builderMethod, + AllowedClass allowListClass) { + + Row builderMethodRow = decodeRow(builderMethod.getSchema(), builderMethod.getPayload()); + + List matchingMethods = + Arrays.stream(transform.getClass().getMethods()) + .filter(m -> isBuilderMethodForName(m, builderMethod.getName(), allowListClass)) + .filter(m -> parametersCompatible(m.getParameters(), builderMethodRow)) + .filter(m -> PTransform.class.isAssignableFrom(m.getReturnType())) + .collect(Collectors.toList()); + + if (matchingMethods.size() != 1) { + throw new RuntimeException( + "Expected to find exactly one matching method in transform " + + transform + + " for BuilderMethod" + + builderMethod + + " but found " + + matchingMethods.size()); + } + return matchingMethods.get(0); + } + + private static boolean isPrimitiveOrWrapperOrString(java.lang.Class type) { + return ClassUtils.isPrimitiveOrWrapper(type) || type == String.class; + } + + private Schema getParameterSchema(Class parameterClass) { + Schema parameterSchema; + try { + parameterSchema = SCHEMA_REGISTRY.getSchema(parameterClass); + } catch (NoSuchSchemaException e) { + + SCHEMA_REGISTRY.registerSchemaProvider(parameterClass, new JavaFieldSchema()); + try { + parameterSchema = SCHEMA_REGISTRY.getSchema(parameterClass); + } catch (NoSuchSchemaException e1) { + throw new RuntimeException(e1); + } + if (parameterSchema != null && parameterSchema.getFieldCount() == 0) { + throw new 
RuntimeException( + "Could not determine a valid schema for parameter class " + parameterClass); + } + } + return parameterSchema; + } + + private boolean parametersCompatible( + java.lang.reflect.Parameter[] methodParameters, Row constructorRow) { + Schema constructorSchema = constructorRow.getSchema(); + if (methodParameters.length != constructorSchema.getFieldCount()) { + return false; + } + + for (int i = 0; i < methodParameters.length; i++) { + java.lang.reflect.Parameter parameterFromReflection = methodParameters[i]; + Field parameterFromPayload = constructorSchema.getField(i); + + String paramNameFromReflection = parameterFromReflection.getName(); + if (!paramNameFromReflection.startsWith("arg") + && !paramNameFromReflection.equals(parameterFromPayload.getName())) { + // Parameter name through reflection is from the class file (not through synthesizing, + // hence we can validate names) + return false; + } + + Class parameterClass = parameterFromReflection.getType(); + if (isPrimitiveOrWrapperOrString(parameterClass)) { + continue; + } + + // We perform additional validation for arrays and non-primitive types. + if (parameterClass.isArray()) { + Class arrayFieldClass = parameterClass.getComponentType(); + if (parameterFromPayload.getType().getTypeName() != TypeName.ARRAY) { + throw new RuntimeException( + "Expected a schema with a single array field but received " + + parameterFromPayload.getType().getTypeName()); + } + + // Following is a best-effort validation that may not cover all cases. Idea is to resolve + // ambiguities as much as possible to determine an exact match for the given set of + // parameters. If there are ambiguities, the expansion will fail. 
+ if (!isPrimitiveOrWrapperOrString(arrayFieldClass)) { + @Nullable Collection values = constructorRow.getArray(i); + Schema arrayFieldSchema = getParameterSchema(arrayFieldClass); + if (arrayFieldSchema == null) { + throw new RuntimeException("Could not determine a schema for type " + arrayFieldClass); + } + if (values != null && !values.isEmpty()) { // An empty array has no first item to check. + @Nullable Row firstItem = values.iterator().next(); + if (firstItem != null && !(firstItem.getSchema().assignableTo(arrayFieldSchema))) { + return false; + } + } + } + } else if (constructorRow.getValue(i) instanceof Row) { + @Nullable Row parameterRow = constructorRow.getRow(i); + Schema schema = getParameterSchema(parameterClass); + if (schema == null) { + throw new RuntimeException("Could not determine a schema for type " + parameterClass); + } + if (parameterRow != null && !parameterRow.getSchema().assignableTo(schema)) { + return false; + } + } + } + return true; + } + + private @Nullable Object getDecodedValueFromRow( + Class type, Object valueFromRow, @Nullable Type genericType) { + if (isPrimitiveOrWrapperOrString(type)) { + if (!isPrimitiveOrWrapperOrString(valueFromRow.getClass())) { + throw new IllegalArgumentException( + "Expected a Java primitive value but received " + valueFromRow); + } + return valueFromRow; + } else if (type.isArray()) { + Class arrayComponentClass = type.getComponentType(); + return getDecodedArrayValueFromRow(arrayComponentClass, valueFromRow); + } else if (Collection.class.isAssignableFrom(type)) { + List originalList = (List) valueFromRow; + List decodedList = new ArrayList<>(); + for (Object obj : originalList) { + if (genericType instanceof ParameterizedType) { + Class elementType = + (Class) ((ParameterizedType) genericType).getActualTypeArguments()[0]; + decodedList.add(getDecodedValueFromRow(elementType, obj, null)); + } else { + throw new RuntimeException("Could not determine the generic type of the list"); + } + } + return decodedList; + } else if (valueFromRow instanceof Row) { + Row row = 
(Row) valueFromRow; + SerializableFunction fromRowFunc; + try { + fromRowFunc = SCHEMA_REGISTRY.getFromRowFunction(type); + } catch (NoSuchSchemaException e) { + throw new IllegalArgumentException( + "Could not determine the row function for class " + type, e); + } + return fromRowFunc.apply(row); + } + throw new RuntimeException("Could not decode the value from Row " + valueFromRow); + } + + private Object[] getParameterValues( + java.lang.reflect.Parameter[] parameters, Row constrtuctorRow, Type[] genericTypes) { + ArrayList parameterValues = new ArrayList<>(); + for (int i = 0; i < parameters.length; ++i) { + java.lang.reflect.Parameter parameter = parameters[i]; + Class parameterClass = parameter.getType(); + Object parameterValue = + getDecodedValueFromRow(parameterClass, constrtuctorRow.getValue(i), genericTypes[i]); + parameterValues.add(parameterValue); + } + + return parameterValues.toArray(); + } + + private Object[] getDecodedArrayValueFromRow(Class arrayComponentType, Object valueFromRow) { + List originalValues = (List) valueFromRow; + List decodedValues = new ArrayList<>(); + for (Object obj : originalValues) { + decodedValues.add(getDecodedValueFromRow(arrayComponentType, obj, null)); + } + + // We have to construct and return an array of the correct type. Otherwise Java reflection + // constructor/method invocations that use the returned value may consider the array as varargs + // (different parameters). 
+ Object valueTypeArray = Array.newInstance(arrayComponentType, decodedValues.size()); + for (int i = 0; i < decodedValues.size(); i++) { + Array.set(valueTypeArray, i, arrayComponentType.cast(decodedValues.get(i))); + } + return (Object[]) valueTypeArray; + } + + private Constructor> findMappingConstructor( + Constructor[] constructors, JavaClassLookupPayload payload) { + Row constructorRow = decodeRow(payload.getConstructorSchema(), payload.getConstructorPayload()); + + List> mappingConstructors = + Arrays.stream(constructors) + .filter(c -> c.getParameterCount() == payload.getConstructorSchema().getFieldsCount()) + .filter(c -> parametersCompatible(c.getParameters(), constructorRow)) + .collect(Collectors.toList()); + if (mappingConstructors.size() != 1) { + throw new RuntimeException( + "Expected to find a single mapping constructor but found " + mappingConstructors.size()); + } + return (Constructor>) mappingConstructors.get(0); + } + + private boolean isConstructorMethodForName( + Method method, String nameFromPayload, AllowedClass allowListClass) { + for (Annotation annotation : method.getAnnotations()) { + if (annotation instanceof MultiLanguageConstructorMethod) { + if (nameFromPayload.equals(((MultiLanguageConstructorMethod) annotation).name())) { + if (allowListClass.getAllowedConstructorMethods().contains(nameFromPayload)) { + return true; + } else { + throw new RuntimeException( + "Constructor method " + nameFromPayload + " needs to be explicitly allowed"); + } + } + } + } + if (method.getName().equals(nameFromPayload)) { + if (allowListClass.getAllowedConstructorMethods().contains(nameFromPayload)) { + return true; + } else { + throw new RuntimeException( + "Constructor method " + nameFromPayload + " needs to be explicitly allowed"); + } + } + return false; + } + + private Method findMappingConstructorMethod( + Method[] methods, JavaClassLookupPayload payload, AllowedClass allowListClass) { + + Row constructorRow = 
decodeRow(payload.getConstructorSchema(), payload.getConstructorPayload()); + + List mappingConstructorMethods = + Arrays.stream(methods) + .filter( + m -> isConstructorMethodForName(m, payload.getConstructorMethod(), allowListClass)) + .filter(m -> m.getParameterCount() == payload.getConstructorSchema().getFieldsCount()) + .filter(m -> parametersCompatible(m.getParameters(), constructorRow)) + .collect(Collectors.toList()); + + if (mappingConstructorMethods.size() != 1) { + throw new RuntimeException( + "Expected to find a single mapping constructor method but found " + + mappingConstructorMethods.size() + + " Payload was " + + payload); + } + return mappingConstructorMethods.get(0); + } + + @AutoValue + public abstract static class AllowList { + + public abstract String getVersion(); + + public abstract List getAllowedClasses(); + + @JsonCreator + static AllowList create( + @JsonProperty("version") String version, + @JsonProperty("allowedClasses") @javax.annotation.Nullable + List allowedClasses) { + if (allowedClasses == null) { + allowedClasses = new ArrayList<>(); + } + return new AutoValue_JavaClassLookupTransformProvider_AllowList(version, allowedClasses); + } + } + + @AutoValue + public abstract static class AllowedClass { + + public abstract String getClassName(); + + public abstract List getAllowedBuilderMethods(); + + public abstract List getAllowedConstructorMethods(); + + @JsonCreator + static AllowedClass create( + @JsonProperty("className") String className, + @JsonProperty("allowedBuilderMethods") @javax.annotation.Nullable + List allowedBuilderMethods, + @JsonProperty("allowedConstructorMethods") @javax.annotation.Nullable + List allowedConstructorMethods) { + if (allowedBuilderMethods == null) { + allowedBuilderMethods = new ArrayList<>(); + } + if (allowedConstructorMethods == null) { + allowedConstructorMethods = new ArrayList<>(); + } + return new AutoValue_JavaClassLookupTransformProvider_AllowedClass( + className, allowedBuilderMethods, 
allowedConstructorMethods); + } + } + /** Decodes payload into a Row using RowCoder for the schema translated from proto. A schema without fields yields an empty Row and payload is not read; an IOException is rethrown as RuntimeException. */ + static Row decodeRow(SchemaApi.Schema schema, ByteString payload) { + Schema payloadSchema = SchemaTranslation.schemaFromProto(schema); + + if (payloadSchema.getFieldCount() == 0) { + return Row.withSchema(Schema.of()).build(); + } + + Row row; + try { + row = RowCoder.of(payloadSchema).decode(payload.newInput()); + } catch (IOException e) { + throw new RuntimeException("Error decoding payload", e); + } + return row; + } +} diff --git a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/MultiLanguageBuilderMethod.java b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/MultiLanguageBuilderMethod.java new file mode 100644 index 0000000000000..3ee9ef5a7dac4 --- /dev/null +++ b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/MultiLanguageBuilderMethod.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.expansion.service; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +/** Runtime-retained, method-only annotation that attaches a name to a builder method. NOTE(review): presumably the name a JavaClassLookupPayload builder method may use to refer to the annotated method; confirm against JavaClassLookupTransformProvider. */ +@Documented +@Target({ElementType.METHOD}) +@Retention(RetentionPolicy.RUNTIME) +public @interface MultiLanguageBuilderMethod { + String name(); +} diff --git a/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/MultiLanguageConstructorMethod.java b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/MultiLanguageConstructorMethod.java new file mode 100644 index 0000000000000..e89f460edd86c --- /dev/null +++ b/sdks/java/expansion-service/src/main/java/org/apache/beam/sdk/expansion/service/MultiLanguageConstructorMethod.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.expansion.service; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +/** Runtime-retained, method-only annotation giving a constructor method an alias. isConstructorMethodForName accepts a payload constructor method name equal to name(), provided the allowlist entry also lists that name. */ +@Documented +@Target({ElementType.METHOD}) +@Retention(RetentionPolicy.RUNTIME) +public @interface MultiLanguageConstructorMethod { + /** Alias compared against the constructor method name from the payload. */ String name(); +} diff --git a/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceTest.java b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceTest.java index 5e2a243dbd6ff..e8ecf469b784e 100644 --- a/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceTest.java +++ b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/ExpansionServiceTest.java @@ -90,7 +90,8 @@ public class ExpansionServiceTest { /** Registers a single test transformation. 
*/ @AutoService(ExpansionService.ExpansionServiceRegistrar.class) - public static class TestTransforms implements ExpansionService.ExpansionServiceRegistrar { + public static class TestTransformRegistrar implements ExpansionService.ExpansionServiceRegistrar { + @Override public Map knownTransforms() { return ImmutableMap.of(TEST_URN, spec -> Count.perElement()); @@ -140,9 +141,9 @@ public void testConstruct() { } @Test - public void testConstructGenerateSequence() { + public void testConstructGenerateSequenceWithRegistration() { ExternalTransforms.ExternalConfigurationPayload payload = - encodeRow( + encodeRowIntoExternalConfigurationPayload( Row.withSchema( Schema.of( Field.of("start", FieldType.INT64), @@ -176,7 +177,7 @@ public void testConstructGenerateSequence() { @Test public void testCompoundCodersForExternalConfiguration_setters() throws Exception { ExternalTransforms.ExternalConfigurationPayload externalConfig = - encodeRow( + encodeRowIntoExternalConfigurationPayload( Row.withSchema( Schema.of( Field.nullable("config_key1", FieldType.INT64), @@ -253,7 +254,7 @@ public void setConfigKey4(@Nullable Map> configKey4) { @Test public void testCompoundCodersForExternalConfiguration_schemas() throws Exception { ExternalTransforms.ExternalConfigurationPayload externalConfig = - encodeRow( + encodeRowIntoExternalConfigurationPayload( Row.withSchema( Schema.of( Field.nullable("configKey1", FieldType.INT64), @@ -320,7 +321,7 @@ abstract static class TestConfigSchema { @Test public void testExternalConfiguration_simpleSchema() throws Exception { ExternalTransforms.ExternalConfigurationPayload externalConfig = - encodeRow( + encodeRowIntoExternalConfigurationPayload( Row.withSchema( Schema.of( Field.of("bar", FieldType.STRING), @@ -350,7 +351,8 @@ abstract static class TestConfigSimpleSchema { abstract List getList(); } - private static ExternalTransforms.ExternalConfigurationPayload encodeRow(Row row) { + private static ExternalTransforms.ExternalConfigurationPayload 
+ encodeRowIntoExternalConfigurationPayload(Row row) { ByteString.Output outputStream = ByteString.newOutput(); try { SchemaCoder.of(row.getSchema()).encode(row, outputStream); diff --git a/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/JavaCLassLookupTransformProviderTest.java b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/JavaCLassLookupTransformProviderTest.java new file mode 100644 index 0000000000000..52441087d4ce3 --- /dev/null +++ b/sdks/java/expansion-service/src/test/java/org/apache/beam/sdk/expansion/service/JavaCLassLookupTransformProviderTest.java @@ -0,0 +1,1111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.beam.sdk.expansion.service; + +import static org.apache.beam.runners.core.construction.BeamUrns.getUrn; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.containsString; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import java.io.File; +import java.io.IOException; +import java.io.Serializable; +import java.net.URL; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import org.apache.beam.model.expansion.v1.ExpansionApi; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.BuilderMethod; +import org.apache.beam.model.pipeline.v1.ExternalTransforms.ExpansionMethods; +import org.apache.beam.model.pipeline.v1.RunnerApi; +import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload; +import org.apache.beam.model.pipeline.v1.SchemaApi; +import org.apache.beam.runners.core.construction.ParDoTranslation; +import org.apache.beam.runners.core.construction.PipelineTranslation; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.Field; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.schemas.SchemaCoder; +import org.apache.beam.sdk.schemas.SchemaTranslation; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import 
org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.ByteString; +import org.apache.beam.vendor.grpc.v1p36p0.com.google.protobuf.InvalidProtocolBufferException; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap; +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.Resources; +import org.hamcrest.Matchers; +import org.junit.BeforeClass; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for {@link JavaClassLookupTransformProvider}. */ +@RunWith(JUnit4.class) +public class JavaCLassLookupTransformProviderTest { + + private static final String TEST_URN = "test:beam:transforms:count"; + + private static final String TEST_NAME = "TestName"; + + private static final String TEST_NAMESPACE = "namespace"; + + private static ExpansionService expansionService; + /** Registers ExpansionServiceOptions and creates the shared ExpansionService, passing the test_allowlist.yaml resource path through --javaClassLookupAllowlistFile. */ + @BeforeClass + public static void setupExpansionService() { + PipelineOptionsFactory.register(ExpansionServiceOptions.class); + URL allowListFile = Resources.getResource("./test_allowlist.yaml"); + System.out.println("Exists: " + new File(allowListFile.getPath()).exists()); + expansionService = + new ExpansionService( + new String[] {"--javaClassLookupAllowlistFile=" + allowListFile.getPath()}); + } + /** DoFn that stores the values handed to its constructor in fields, one per configurable property of the fixture transforms. */ + static class DummyDoFn extends DoFn { + String strField1; + String strField2; + int intField1; + Double doubleWrapperField; + String[] strArrayField; + DummyComplexType complexTypeField; + DummyComplexType[] complexTypeArrayField; + List strListField; + List complexTypeListField; + + private DummyDoFn( + String strField1, + String strField2, + int intField1, + Double doubleWrapperField, + String[] strArrayField, + DummyComplexType complexTypeField, + DummyComplexType[] complexTypeArrayField, + List strListField, + List complexTypeListField) { + this.intField1 = intField1; + this.strField1 = strField1; + this.strField2 = 
strField2; + this.doubleWrapperField = doubleWrapperField; + this.strArrayField = strArrayField; + this.complexTypeField = complexTypeField; + this.complexTypeArrayField = complexTypeArrayField; + this.strListField = strListField; + this.complexTypeListField = complexTypeListField; + } + /** Passes every element through unchanged. */ + @ProcessElement + public void processElement(ProcessContext c) { + c.output(c.element()); + } + } + /** Serializable holder of one String and one int; equality and hash code are computed from both fields. */ + public static class DummyComplexType implements Serializable { + String complexTypeStrField; + int complexTypeIntField; + + public DummyComplexType() {} + + public DummyComplexType(String complexTypeStrField, int complexTypeIntField) { + this.complexTypeStrField = complexTypeStrField; + this.complexTypeIntField = complexTypeIntField; + } + + @Override + public int hashCode() { + return this.complexTypeStrField.hashCode() + this.complexTypeIntField * 31; + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof DummyComplexType)) { + return false; + } + DummyComplexType toCompare = (DummyComplexType) obj; + return (this.complexTypeIntField == toCompare.complexTypeIntField) + && (this.complexTypeStrField.equals(toCompare.complexTypeStrField)); + } + } + /** Base fixture: creates "aaa", "bbb", "ccc" and applies a DummyDoFn constructed from this transform's fields. */ + public static class DummyTransform extends PTransform> { + String strField1; + String strField2; + int intField1; + Double doubleWrapperField; + String[] strArrayField; + DummyComplexType complexTypeField; + DummyComplexType[] complexTypeArrayField; + List strListField; + List complexTypeListField; + + @Override + public PCollection expand(PBegin input) { + return input + .apply("MyCreateTransform", Create.of("aaa", "bbb", "ccc")) + .apply( + "MyParDoTransform", + ParDo.of( + new DummyDoFn( + this.strField1, + this.strField2, + this.intField1, + this.doubleWrapperField, + this.strArrayField, + this.complexTypeField, + this.complexTypeArrayField, + this.strListField, + this.complexTypeListField))); + } + } + /** Fixture configured through a public constructor that takes strField1. */ + public static class DummyTransformWithConstructor extends DummyTransform { + + public 
DummyTransformWithConstructor(String strField1) { + this.strField1 = strField1; + } + } + /** Fixture with a strField1 constructor and fluent withStrField2/withIntField1 methods that return this. */ + public static class DummyTransformWithConstructorAndBuilderMethods extends DummyTransform { + + public DummyTransformWithConstructorAndBuilderMethods(String strField1) { + this.strField1 = strField1; + } + + public DummyTransformWithConstructorAndBuilderMethods withStrField2(String strField2) { + this.strField2 = strField2; + return this; + } + + public DummyTransformWithConstructorAndBuilderMethods withIntField1(int intField1) { + this.intField1 = intField1; + return this; + } + } + /** Fixture whose builder method withFields takes two arguments (String, int). */ + public static class DummyTransformWithMultiArgumentBuilderMethod extends DummyTransform { + + public DummyTransformWithMultiArgumentBuilderMethod(String strField1) { + this.strField1 = strField1; + } + + public DummyTransformWithMultiArgumentBuilderMethod withFields( + String strField2, int intField1) { + this.strField2 = strField2; + this.intField1 = intField1; + return this; + } + } + /** Fixture whose constructor takes two String arguments. */ + public static class DummyTransformWithMultiArgumentConstructor extends DummyTransform { + + public DummyTransformWithMultiArgumentConstructor(String strField1, String strField2) { + this.strField1 = strField1; + this.strField2 = strField2; + } + } + /** Fixture exposing a static factory from(String) that sets strField1. */ + public static class DummyTransformWithConstructorMethod extends DummyTransform { + + public static DummyTransformWithConstructorMethod from(String strField1) { + DummyTransformWithConstructorMethod transform = new DummyTransformWithConstructorMethod(); + transform.strField1 = strField1; + return transform; + } + } + /** Fixture combining a static factory from(String) with fluent builder methods. */ + public static class DummyTransformWithConstructorMethodAndBuilderMethods extends DummyTransform { + + public static DummyTransformWithConstructorMethodAndBuilderMethods from(String strField1) { + DummyTransformWithConstructorMethodAndBuilderMethods transform = + new DummyTransformWithConstructorMethodAndBuilderMethods(); + transform.strField1 = strField1; + return transform; + } + + public DummyTransformWithConstructorMethodAndBuilderMethods 
withStrField2(String strField2) { + this.strField2 = strField2; + return this; + } + + public DummyTransformWithConstructorMethodAndBuilderMethods withIntField1(int intField1) { + this.intField1 = intField1; + return this; + } + } + /** Fixture whose factory and builder methods carry annotation names ("create_transform", "abc", "xyz") that differ from their Java method names. */ + public static class DummyTransformWithMultiLanguageAnnotations extends DummyTransform { + + @MultiLanguageConstructorMethod(name = "create_transform") + public static DummyTransformWithMultiLanguageAnnotations from(String strField1) { + DummyTransformWithMultiLanguageAnnotations transform = + new DummyTransformWithMultiLanguageAnnotations(); + transform.strField1 = strField1; + return transform; + } + + @MultiLanguageBuilderMethod(name = "abc") + public DummyTransformWithMultiLanguageAnnotations withStrField2(String strField2) { + this.strField2 = strField2; + return this; + } + + @MultiLanguageBuilderMethod(name = "xyz") + public DummyTransformWithMultiLanguageAnnotations withIntField1(int intField1) { + this.intField1 = intField1; + return this; + } + } + /** Fixture with a builder method that takes a boxed Double. */ + public static class DummyTransformWithWrapperTypes extends DummyTransform { + public DummyTransformWithWrapperTypes(String strField1) { + this.strField1 = strField1; + } + + public DummyTransformWithWrapperTypes withDoubleWrapperField(Double doubleWrapperField) { + this.doubleWrapperField = doubleWrapperField; + return this; + } + } + /** Fixture with a builder method that takes a DummyComplexType. */ + public static class DummyTransformWithComplexTypes extends DummyTransform { + public DummyTransformWithComplexTypes(String strField1) { + this.strField1 = strField1; + } + + public DummyTransformWithComplexTypes withComplexTypeField(DummyComplexType complexTypeField) { + this.complexTypeField = complexTypeField; + return this; + } + } + /** Fixture with a builder method that takes a String[]. */ + public static class DummyTransformWithArray extends DummyTransform { + public DummyTransformWithArray(String strField1) { + this.strField1 = strField1; + } + + public DummyTransformWithArray withStrArrayField(String[] strArrayField) { + this.strArrayField = strArrayField; + return this; + } + } + /** Fixture with a builder method that takes a List of strings. */ + public static 
class DummyTransformWithList extends DummyTransform { + public DummyTransformWithList(String strField1) { + this.strField1 = strField1; + } + + public DummyTransformWithList withStrListField(List strListField) { + this.strListField = strListField; + return this; + } + } + /** Fixture with a builder method that takes a DummyComplexType[]. */ + public static class DummyTransformWithComplexTypeArray extends DummyTransform { + public DummyTransformWithComplexTypeArray(String strField1) { + this.strField1 = strField1; + } + + public DummyTransformWithComplexTypeArray withComplexTypeArrayField( + DummyComplexType[] complexTypeArrayField) { + this.complexTypeArrayField = complexTypeArrayField; + return this; + } + } + /** Fixture with a builder method that sets complexTypeListField from a List. */ + public static class DummyTransformWithComplexTypeList extends DummyTransform { + public DummyTransformWithComplexTypeList(String strField1) { + this.strField1 = strField1; + } + + public DummyTransformWithComplexTypeList withComplexTypeListField( + List complexTypeListField) { + this.complexTypeListField = complexTypeListField; + return this; + } + } + /** Sends payload to the expansion service under the JAVA_CLASS_LOOKUP URN, asserts the shape of the expanded transform, and checks the fields named in fieldsToVerify on the DummyDoFn recovered from the response. */ + void testClassLookupExpansionRequestConstruction( + ExternalTransforms.JavaClassLookupPayload payload, Map fieldsToVerify) { + Pipeline p = Pipeline.create(); + + RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); + + ExpansionApi.ExpansionRequest request = + ExpansionApi.ExpansionRequest.newBuilder() + .setComponents(pipelineProto.getComponents()) + .setTransform( + RunnerApi.PTransform.newBuilder() + .setUniqueName(TEST_NAME) + .setSpec( + RunnerApi.FunctionSpec.newBuilder() + .setUrn(getUrn(ExpansionMethods.Enum.JAVA_CLASS_LOOKUP)) + .setPayload(payload.toByteString()))) + .setNamespace(TEST_NAMESPACE) + .build(); + ExpansionApi.ExpansionResponse response = expansionService.expand(request); + RunnerApi.PTransform expandedTransform = response.getTransform(); + assertEquals(TEST_NAMESPACE + TEST_NAME, expandedTransform.getUniqueName()); + assertThat(expandedTransform.getInputsCount(), Matchers.is(0)); + assertThat(expandedTransform.getOutputsCount(), 
Matchers.is(1)); + assertEquals(2, expandedTransform.getSubtransformsCount()); + /* NOTE(review): repeats the assertion on the previous line. */ assertEquals(2, expandedTransform.getSubtransformsCount()); + assertThat( + expandedTransform.getSubtransforms(0), + anyOf(containsString("MyCreateTransform"), containsString("MyParDoTransform"))); + assertThat( + expandedTransform.getSubtransforms(1), + anyOf(containsString("MyCreateTransform"), containsString("MyParDoTransform"))); + /* Find the transform whose id contains "ParMultiDo-Dummy-"; if several match, the last one iterated is kept. */ + org.apache.beam.model.pipeline.v1.RunnerApi.PTransform userParDoTransform = null; + for (String transformId : response.getComponents().getTransformsMap().keySet()) { + if (transformId.contains("ParMultiDo-Dummy-")) { + userParDoTransform = response.getComponents().getTransformsMap().get(transformId); + } + } + assertNotNull(userParDoTransform); + ParDoPayload parDoPayload = null; + try { + parDoPayload = ParDoPayload.parseFrom(userParDoTransform.getSpec().getPayload()); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + assertNotNull(parDoPayload); + DummyDoFn doFn = + (DummyDoFn) + ParDoTranslation.doFnWithExecutionInformationFromProto(parDoPayload.getDoFn()) + .getDoFn(); + /* NOTE(review): prints to stdout on every call; looks like leftover debugging output. */ System.out.println("DoFn" + doFn); + /* Each branch below checks one key of fieldsToVerify and records it as verified. */ + List verifiedFields = new ArrayList<>(); + if (fieldsToVerify.keySet().contains("strField1")) { + assertEquals(doFn.strField1, fieldsToVerify.get("strField1")); + verifiedFields.add("strField1"); + } + if (fieldsToVerify.keySet().contains("strField2")) { + assertEquals(doFn.strField2, fieldsToVerify.get("strField2")); + verifiedFields.add("strField2"); + } + if (fieldsToVerify.keySet().contains("intField1")) { + assertEquals(doFn.intField1, fieldsToVerify.get("intField1")); + verifiedFields.add("intField1"); + } + if (fieldsToVerify.keySet().contains("doubleWrapperField")) { + assertEquals(doFn.doubleWrapperField, fieldsToVerify.get("doubleWrapperField")); + verifiedFields.add("doubleWrapperField"); + } + if (fieldsToVerify.containsKey("complexTypeStrField")) { + assertEquals( + 
doFn.complexTypeField.complexTypeStrField, fieldsToVerify.get("complexTypeStrField")); + verifiedFields.add("complexTypeStrField"); + } + if (fieldsToVerify.containsKey("complexTypeIntField")) { + assertEquals( + doFn.complexTypeField.complexTypeIntField, fieldsToVerify.get("complexTypeIntField")); + verifiedFields.add("complexTypeIntField"); + } + + if (fieldsToVerify.keySet().contains("strArrayField")) { + assertArrayEquals(doFn.strArrayField, (String[]) fieldsToVerify.get("strArrayField")); + verifiedFields.add("strArrayField"); + } + + if (fieldsToVerify.keySet().contains("strListField")) { + assertEquals(doFn.strListField, (List) fieldsToVerify.get("strListField")); + verifiedFields.add("strListField"); + } + + if (fieldsToVerify.keySet().contains("complexTypeArrayField")) { + assertArrayEquals( + doFn.complexTypeArrayField, + (DummyComplexType[]) fieldsToVerify.get("complexTypeArrayField")); + verifiedFields.add("complexTypeArrayField"); + } + + if (fieldsToVerify.keySet().contains("complexTypeListField")) { + assertEquals(doFn.complexTypeListField, (List) fieldsToVerify.get("complexTypeListField")); + verifiedFields.add("complexTypeListField"); + } + /* Fail when a requested key matched none of the branches above. */ + List unverifiedFields = new ArrayList<>(fieldsToVerify.keySet()); + unverifiedFields.removeAll(verifiedFields); + if (!unverifiedFields.isEmpty()) { + throw new RuntimeException("Failed to verify some fields: " + unverifiedFields); + } + } + /** Instantiates DummyTransformWithConstructor from a one-field constructor row and checks strField1. */ + @Test + public void testJavaClassLookupWithConstructor() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructor"); + + Row constructorRow = + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(); + + payloadBuilder.setConstructorSchema(getProtoSchemaFromRow(constructorRow)); + 
payloadBuilder.setConstructorPayload(getProtoPayloadFromRow(constructorRow)); + + testClassLookupExpansionRequestConstruction( + payloadBuilder.build(), ImmutableMap.of("strField1", "test_str_1")); + } + /** Uses constructor method "from" on DummyTransformWithConstructorMethod and checks strField1. */ + @Test + public void testJavaClassLookupWithConstructorMethod() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethod"); + + payloadBuilder.setConstructorMethod("from"); + Row constructorRow = + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(); + + payloadBuilder.setConstructorSchema(getProtoSchemaFromRow(constructorRow)); + payloadBuilder.setConstructorPayload(getProtoPayloadFromRow(constructorRow)); + + testClassLookupExpansionRequestConstruction( + payloadBuilder.build(), ImmutableMap.of("strField1", "test_str_1")); + } + /** Constructor row plus builder methods withStrField2 and withIntField1; checks all three fields. */ + @Test + public void testJavaClassLookupWithConstructorAndBuilderMethods() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorAndBuilderMethods"); + + Row constructorRow = + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(); + + payloadBuilder.setConstructorSchema(getProtoSchemaFromRow(constructorRow)); + payloadBuilder.setConstructorPayload(getProtoPayloadFromRow(constructorRow)); + + BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("withStrField2"); + Row builderMethodRow = + Row.withSchema(Schema.of(Field.of("strField2", FieldType.STRING))) + .withFieldValue("strField2", "test_str_2") + .build(); + 
builderMethodBuilder.setSchema(getProtoSchemaFromRow(builderMethodRow)); + builderMethodBuilder.setPayload(getProtoPayloadFromRow(builderMethodRow)); + + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("withIntField1"); + builderMethodRow = + Row.withSchema(Schema.of(Field.of("intField1", FieldType.INT32))) + .withFieldValue("intField1", 10) + .build(); + builderMethodBuilder.setSchema(getProtoSchemaFromRow(builderMethodRow)); + builderMethodBuilder.setPayload(getProtoPayloadFromRow(builderMethodRow)); + + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + testClassLookupExpansionRequestConstruction( + payloadBuilder.build(), + ImmutableMap.of("strField1", "test_str_1", "strField2", "test_str_2", "intField1", 10)); + } + /** Two-field constructor row mapped onto the two-String constructor; checks strField1 and strField2. */ + @Test + public void testJavaClassLookupWithMultiArgumentConstructor() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithMultiArgumentConstructor"); + + Row constructorRow = + Row.withSchema( + Schema.of( + Field.of("strField1", FieldType.STRING), + Field.of("strField2", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .withFieldValue("strField2", "test_str_2") + .build(); + + payloadBuilder.setConstructorSchema(getProtoSchemaFromRow(constructorRow)); + payloadBuilder.setConstructorPayload(getProtoPayloadFromRow(constructorRow)); + + testClassLookupExpansionRequestConstruction( + payloadBuilder.build(), + ImmutableMap.of("strField1", "test_str_1", "strField2", "test_str_2")); + } + /** Single builder method withFields receiving a two-field row; checks strField1, strField2 and intField1. */ + @Test + public void testJavaClassLookupWithMultiArgumentBuilderMethod() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + 
"org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithMultiArgumentBuilderMethod"); + + Row constructorRow = + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(); + + payloadBuilder.setConstructorSchema(getProtoSchemaFromRow(constructorRow)); + payloadBuilder.setConstructorPayload(getProtoPayloadFromRow(constructorRow)); + + BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("withFields"); + Row builderMethodRow = + Row.withSchema( + Schema.of( + Field.of("strField2", FieldType.STRING), + Field.of("intField1", FieldType.INT32))) + .withFieldValue("strField2", "test_str_2") + .withFieldValue("intField1", 10) + .build(); + builderMethodBuilder.setSchema(getProtoSchemaFromRow(builderMethodRow)); + builderMethodBuilder.setPayload(getProtoPayloadFromRow(builderMethodRow)); + + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + testClassLookupExpansionRequestConstruction( + payloadBuilder.build(), + ImmutableMap.of("strField1", "test_str_1", "strField2", "test_str_2", "intField1", 10)); + } + /** DOUBLE schema field passed to a builder method declared with a boxed Double parameter. */ + @Test + public void testJavaClassLookupWithWrapperTypes() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithWrapperTypes"); + + Row constructorRow = + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(); + + payloadBuilder.setConstructorSchema(getProtoSchemaFromRow(constructorRow)); + payloadBuilder.setConstructorPayload(getProtoPayloadFromRow(constructorRow)); + + BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("withDoubleWrapperField"); + Row builderMethodRow = + 
Row.withSchema(Schema.of(Field.of("doubleWrapperField", FieldType.DOUBLE))) + .withFieldValue("doubleWrapperField", 123.56) + .build(); + builderMethodBuilder.setSchema(getProtoSchemaFromRow(builderMethodRow)); + builderMethodBuilder.setPayload(getProtoPayloadFromRow(builderMethodRow)); + + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + testClassLookupExpansionRequestConstruction( + payloadBuilder.build(), ImmutableMap.of("doubleWrapperField", 123.56)); + } + /** Row-typed schema field passed to a builder method declared with a DummyComplexType parameter. */ + @Test + public void testJavaClassLookupWithComplexTypes() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithComplexTypes"); + + Row constructorRow = + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(); + payloadBuilder.setConstructorSchema(getProtoSchemaFromRow(constructorRow)); + payloadBuilder.setConstructorPayload(getProtoPayloadFromRow(constructorRow)); + + Schema complexTypeSchema = + Schema.builder() + .addStringField("complexTypeStrField") + .addInt32Field("complexTypeIntField") + .build(); + + BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("withComplexTypeField"); + + Row builderMethodParamRow = + Row.withSchema(complexTypeSchema) + .withFieldValue("complexTypeStrField", "complex_type_str_1") + .withFieldValue("complexTypeIntField", 123) + .build(); + + Schema builderMethodSchema = + Schema.builder().addRowField("complexTypeField", complexTypeSchema).build(); + Row builderMethodRow = + Row.withSchema(builderMethodSchema) + .withFieldValue("complexTypeField", builderMethodParamRow) + .build(); + builderMethodBuilder.setSchema(getProtoSchemaFromRow(builderMethodRow)); + builderMethodBuilder.setPayload(getProtoPayloadFromRow(builderMethodRow)); + + 
payloadBuilder.addBuilderMethods(builderMethodBuilder); + + testClassLookupExpansionRequestConstruction( + payloadBuilder.build(), + ImmutableMap.of("complexTypeStrField", "complex_type_str_1", "complexTypeIntField", 123)); + } + /** Array-of-STRING schema field passed to a builder method declared with a String[] parameter. */ + @Test + public void testJavaClassLookupWithSimpleArrayType() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithArray"); + + Row constructorRow = + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(); + + payloadBuilder.setConstructorSchema(getProtoSchemaFromRow(constructorRow)); + payloadBuilder.setConstructorPayload(getProtoPayloadFromRow(constructorRow)); + + BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("withStrArrayField"); + + Schema builderMethodSchema = + Schema.builder().addArrayField("strArrayField", FieldType.STRING).build(); + + Row builderMethodRow = + Row.withSchema(builderMethodSchema) + .withFieldValue( + "strArrayField", ImmutableList.of("test_str_1", "test_str_2", "test_str_3")) + .build(); + + builderMethodBuilder.setSchema(getProtoSchemaFromRow(builderMethodRow)); + builderMethodBuilder.setPayload(getProtoPayloadFromRow(builderMethodRow)); + + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + String[] resultArray = {"test_str_1", "test_str_2", "test_str_3"}; + testClassLookupExpansionRequestConstruction( + payloadBuilder.build(), ImmutableMap.of("strArrayField", resultArray)); + } + /** Iterable-of-STRING schema field passed to a builder method declared with a List parameter. */ + @Test + public void testJavaClassLookupWithSimpleListType() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + 
"org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithList"); + + Row constructorRow = + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(); + + payloadBuilder.setConstructorSchema(getProtoSchemaFromRow(constructorRow)); + payloadBuilder.setConstructorPayload(getProtoPayloadFromRow(constructorRow)); + + BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("withStrListField"); + + Schema builderMethodSchema = + Schema.builder().addIterableField("strListField", FieldType.STRING).build(); + + Row builderMethodRow = + Row.withSchema(builderMethodSchema) + .withFieldValue( + "strListField", ImmutableList.of("test_str_1", "test_str_2", "test_str_3")) + .build(); + + builderMethodBuilder.setSchema(getProtoSchemaFromRow(builderMethodRow)); + builderMethodBuilder.setPayload(getProtoPayloadFromRow(builderMethodRow)); + + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + List resultList = new ArrayList<>(); + resultList.add("test_str_1"); + resultList.add("test_str_2"); + resultList.add("test_str_3"); + testClassLookupExpansionRequestConstruction( + payloadBuilder.build(), ImmutableMap.of("strListField", resultList)); + } + /** Array-of-row schema field passed to a builder method declared with a DummyComplexType[] parameter. */ + @Test + public void testJavaClassLookupWithComplexArrayType() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithComplexTypeArray"); + + Schema complexTypeSchema = + Schema.builder() + .addStringField("complexTypeStrField") + .addInt32Field("complexTypeIntField") + .build(); + + Schema builderMethodSchema = + Schema.builder() + .addArrayField("complexTypeArrayField", FieldType.row(complexTypeSchema)) + .build(); + + Row constructorRow = + 
Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(); + + payloadBuilder.setConstructorSchema(getProtoSchemaFromRow(constructorRow)); + payloadBuilder.setConstructorPayload(getProtoPayloadFromRow(constructorRow)); + + List complexTypeList = new ArrayList<>(); + complexTypeList.add( + Row.withSchema(complexTypeSchema) + .withFieldValue("complexTypeStrField", "complex_type_str_1") + .withFieldValue("complexTypeIntField", 123) + .build()); + complexTypeList.add( + Row.withSchema(complexTypeSchema) + .withFieldValue("complexTypeStrField", "complex_type_str_2") + .withFieldValue("complexTypeIntField", 456) + .build()); + complexTypeList.add( + Row.withSchema(complexTypeSchema) + .withFieldValue("complexTypeStrField", "complex_type_str_3") + .withFieldValue("complexTypeIntField", 789) + .build()); + + BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("withComplexTypeArrayField"); + + Row builderMethodRow = + Row.withSchema(builderMethodSchema) + .withFieldValue("complexTypeArrayField", complexTypeList) + .build(); + + builderMethodBuilder.setSchema(getProtoSchemaFromRow(builderMethodRow)); + builderMethodBuilder.setPayload(getProtoPayloadFromRow(builderMethodRow)); + + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + ArrayList resultList = new ArrayList<>(); + resultList.add(new DummyComplexType("complex_type_str_1", 123)); + resultList.add(new DummyComplexType("complex_type_str_2", 456)); + resultList.add(new DummyComplexType("complex_type_str_3", 789)); + + testClassLookupExpansionRequestConstruction( + payloadBuilder.build(), + ImmutableMap.of("complexTypeArrayField", resultList.toArray(new DummyComplexType[0]))); + } + + @Test + public void testJavaClassLookupWithComplexListType() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + 
payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithComplexTypeList"); + + Schema complexTypeSchema = + Schema.builder() + .addStringField("complexTypeStrField") + .addInt32Field("complexTypeIntField") + .build(); + + Schema builderMethodSchema = + Schema.builder() + .addIterableField("complexTypeListField", FieldType.row(complexTypeSchema)) + .build(); + + Row constructorRow = + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(); + + payloadBuilder.setConstructorSchema(getProtoSchemaFromRow(constructorRow)); + payloadBuilder.setConstructorPayload(getProtoPayloadFromRow(constructorRow)); + + List complexTypeList = new ArrayList<>(); + complexTypeList.add( + Row.withSchema(complexTypeSchema) + .withFieldValue("complexTypeStrField", "complex_type_str_1") + .withFieldValue("complexTypeIntField", 123) + .build()); + complexTypeList.add( + Row.withSchema(complexTypeSchema) + .withFieldValue("complexTypeStrField", "complex_type_str_2") + .withFieldValue("complexTypeIntField", 456) + .build()); + complexTypeList.add( + Row.withSchema(complexTypeSchema) + .withFieldValue("complexTypeStrField", "complex_type_str_3") + .withFieldValue("complexTypeIntField", 789) + .build()); + + BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("withComplexTypeListField"); + + Row builderMethodRow = + Row.withSchema(builderMethodSchema) + .withFieldValue("complexTypeListField", complexTypeList) + .build(); + + builderMethodBuilder.setSchema(getProtoSchemaFromRow(builderMethodRow)); + builderMethodBuilder.setPayload(getProtoPayloadFromRow(builderMethodRow)); + + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + ArrayList resultList = new ArrayList<>(); + resultList.add(new DummyComplexType("complex_type_str_1", 123)); + resultList.add(new DummyComplexType("complex_type_str_2", 456)); + 
resultList.add(new DummyComplexType("complex_type_str_3", 789)); + + testClassLookupExpansionRequestConstruction( + payloadBuilder.build(), ImmutableMap.of("complexTypeListField", resultList)); + } + + @Test + public void testJavaClassLookupWithConstructorMethodAndBuilderMethods() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethodAndBuilderMethods"); + payloadBuilder.setConstructorMethod("from"); + + Row constructorRow = + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(); + + payloadBuilder.setConstructorSchema(getProtoSchemaFromRow(constructorRow)); + payloadBuilder.setConstructorPayload(getProtoPayloadFromRow(constructorRow)); + + BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("withStrField2"); + + Row builderMethodRow = + Row.withSchema(Schema.of(Field.of("strField2", FieldType.STRING))) + .withFieldValue("strField2", "test_str_2") + .build(); + builderMethodBuilder.setSchema(getProtoSchemaFromRow(builderMethodRow)); + builderMethodBuilder.setPayload(getProtoPayloadFromRow(builderMethodRow)); + + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("withIntField1"); + + builderMethodRow = + Row.withSchema(Schema.of(Field.of("intField1", FieldType.INT32))) + .withFieldValue("intField1", 10) + .build(); + builderMethodBuilder.setSchema(getProtoSchemaFromRow(builderMethodRow)); + builderMethodBuilder.setPayload(getProtoPayloadFromRow(builderMethodRow)); + + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + testClassLookupExpansionRequestConstruction( + payloadBuilder.build(), + ImmutableMap.of("strField1", "test_str_1", 
"strField2", "test_str_2", "intField1", 10)); + } + + @Test + public void testJavaClassLookupWithSimplifiedBuilderMethodNames() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethodAndBuilderMethods"); + payloadBuilder.setConstructorMethod("from"); + + Row constructorRow = + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(); + + payloadBuilder.setConstructorSchema(getProtoSchemaFromRow(constructorRow)); + payloadBuilder.setConstructorPayload(getProtoPayloadFromRow(constructorRow)); + + BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("strField2"); + Row builderMethodRow = + Row.withSchema(Schema.of(Field.of("strField2", FieldType.STRING))) + .withFieldValue("strField2", "test_str_2") + .build(); + builderMethodBuilder.setSchema(getProtoSchemaFromRow(builderMethodRow)); + builderMethodBuilder.setPayload(getProtoPayloadFromRow(builderMethodRow)); + + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("intField1"); + builderMethodRow = + Row.withSchema(Schema.of(Field.of("intField1", FieldType.INT32))) + .withFieldValue("intField1", 10) + .build(); + builderMethodBuilder.setSchema(getProtoSchemaFromRow(builderMethodRow)); + builderMethodBuilder.setPayload(getProtoPayloadFromRow(builderMethodRow)); + + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + testClassLookupExpansionRequestConstruction( + payloadBuilder.build(), + ImmutableMap.of("strField1", "test_str_1", "strField2", "test_str_2", "intField1", 10)); + } + + @Test + public void testJavaClassLookupWithAnnotations() { + ExternalTransforms.JavaClassLookupPayload.Builder 
payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithMultiLanguageAnnotations"); + payloadBuilder.setConstructorMethod("create_transform"); + Row constructorRow = + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(); + + payloadBuilder.setConstructorSchema(getProtoSchemaFromRow(constructorRow)); + payloadBuilder.setConstructorPayload(getProtoPayloadFromRow(constructorRow)); + + BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("abc"); + Row builderMethodRow = + Row.withSchema(Schema.of(Field.of("strField2", FieldType.STRING))) + .withFieldValue("strField2", "test_str_2") + .build(); + builderMethodBuilder.setSchema(getProtoSchemaFromRow(builderMethodRow)); + builderMethodBuilder.setPayload(getProtoPayloadFromRow(builderMethodRow)); + + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("xyz"); + builderMethodRow = + Row.withSchema(Schema.of(Field.of("intField1", FieldType.INT32))) + .withFieldValue("intField1", 10) + .build(); + builderMethodBuilder.setSchema(getProtoSchemaFromRow(builderMethodRow)); + builderMethodBuilder.setPayload(getProtoPayloadFromRow(builderMethodRow)); + + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + testClassLookupExpansionRequestConstruction( + payloadBuilder.build(), + ImmutableMap.of("strField1", "test_str_1", "strField2", "test_str_2", "intField1", 10)); + } + + @Test + public void testJavaClassLookupClassNotAvailable() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + 
"org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$UnavailableClass"); + Row constructorRow = + Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(); + + payloadBuilder.setConstructorSchema(getProtoSchemaFromRow(constructorRow)); + payloadBuilder.setConstructorPayload(getProtoPayloadFromRow(constructorRow)); + + RuntimeException thrown = + assertThrows( + RuntimeException.class, + () -> + testClassLookupExpansionRequestConstruction( + payloadBuilder.build(), ImmutableMap.of())); + assertTrue(thrown.getMessage().contains("does not enable")); + } + + @Test + public void testJavaClassLookupIncorrectConstructionParameter() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructor"); + Row constructorRow = + Row.withSchema(Schema.of(Field.of("incorrectField", FieldType.STRING))) + .withFieldValue("incorrectField", "test_str_1") + .build(); + + payloadBuilder.setConstructorSchema(getProtoSchemaFromRow(constructorRow)); + payloadBuilder.setConstructorPayload(getProtoPayloadFromRow(constructorRow)); + + RuntimeException thrown = + assertThrows( + RuntimeException.class, + () -> + testClassLookupExpansionRequestConstruction( + payloadBuilder.build(), ImmutableMap.of())); + assertTrue(thrown.getMessage().contains("Expected to find a single mapping constructor")); + } + + @Test + public void testJavaClassLookupIncorrectBuilderMethodParameter() { + ExternalTransforms.JavaClassLookupPayload.Builder payloadBuilder = + ExternalTransforms.JavaClassLookupPayload.newBuilder(); + payloadBuilder.setClassName( + "org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorAndBuilderMethods"); + Row constructorRow = + 
Row.withSchema(Schema.of(Field.of("strField1", FieldType.STRING))) + .withFieldValue("strField1", "test_str_1") + .build(); + + payloadBuilder.setConstructorSchema(getProtoSchemaFromRow(constructorRow)); + payloadBuilder.setConstructorPayload(getProtoPayloadFromRow(constructorRow)); + + BuilderMethod.Builder builderMethodBuilder = BuilderMethod.newBuilder(); + builderMethodBuilder.setName("withStrField2"); + Row builderMethodRow = + Row.withSchema(Schema.of(Field.of("incorrectParam", FieldType.STRING))) + .withFieldValue("incorrectParam", "test_str_2") + .build(); + builderMethodBuilder.setSchema(getProtoSchemaFromRow(builderMethodRow)); + builderMethodBuilder.setPayload(getProtoPayloadFromRow(builderMethodRow)); + + payloadBuilder.addBuilderMethods(builderMethodBuilder); + + RuntimeException thrown = + assertThrows( + RuntimeException.class, + () -> + testClassLookupExpansionRequestConstruction( + payloadBuilder.build(), ImmutableMap.of())); + assertTrue(thrown.getMessage().contains("Expected to find exactly one matching method")); + } + + private SchemaApi.Schema getProtoSchemaFromRow(Row row) { + return SchemaTranslation.schemaToProto(row.getSchema(), true); + } + + private ByteString getProtoPayloadFromRow(Row row) { + ByteString.Output outputStream = ByteString.newOutput(); + try { + SchemaCoder.of(row.getSchema()).encode(row, outputStream); + } catch (IOException e) { + throw new RuntimeException(e); + } + + return outputStream.toByteString(); + } +} diff --git a/sdks/java/expansion-service/src/test/resources/test_allowlist.yaml b/sdks/java/expansion-service/src/test/resources/test_allowlist.yaml new file mode 100644 index 0000000000000..ad11523be3f02 --- /dev/null +++ b/sdks/java/expansion-service/src/test/resources/test_allowlist.yaml @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +version: v1 +allowedClasses: +- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructor +- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethod + allowedConstructorMethods: + - from +- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorAndBuilderMethods + allowedBuilderMethods: + - withStrField2 + - withIntField1 +- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithMultiArgumentConstructor +- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithMultiArgumentBuilderMethod + allowedBuilderMethods: + - withFields +- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithConstructorMethodAndBuilderMethods + allowedConstructorMethods: + - from + allowedBuilderMethods: + - withStrField2 + - withIntField1 + - strField2 + - intField1 +- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithMultiLanguageAnnotations + allowedConstructorMethods: + - create_transform + allowedBuilderMethods: + - abc + - 
xyz +- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithWrapperTypes + allowedBuilderMethods: + - withDoubleWrapperField +- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithComplexTypes + allowedBuilderMethods: + - withComplexTypeField +- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithArray + allowedBuilderMethods: + - withStrArrayField +- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithList + allowedBuilderMethods: + - withStrListField +- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithComplexTypeArray + allowedBuilderMethods: + - withComplexTypeArrayField +- className: org.apache.beam.sdk.expansion.service.JavaCLassLookupTransformProviderTest$DummyTransformWithComplexTypeList + allowedBuilderMethods: + - withComplexTypeListField + +