From e4f8fe905d3c182416e1b69bf60bfbc0e8348f8a Mon Sep 17 00:00:00 2001 From: Chen Zhiling Date: Tue, 7 Apr 2020 18:45:45 +0800 Subject: [PATCH 1/4] Add general storage API and refactor existing store implementations (#567) * Add storage interfaces, basic file structure (#529) * Add storage interfaces, basic file structure * Apply spotless, add comments * Move parseResponse and isEmpty to response object * Make changes to write interface to be more beam-like * Pass feature specs to the retriever * Pass feature specs to online retriever * Add FeatureSetRequest * Add mistakenly removed TestUtil * Add mistakenly removed TestUtil * Add BigQuery storage (#546) * Add Redis storage implementation (#547) * Add Redis storage * Remove staleness check; can be checked at the service level * Remove staleness related tests * Add dependencies to top level pom * Clean up code * Change serving and ingestion to use storage API (#553) * Change serving and ingestion to use storage API * Remove extra exclusion clause * Storage refactor API and docstring tweaks (#569) * API and docstring tweaks * Fix javadoc linting errors * Apply spotless * Fix javadoc formatting * Drop result from HistoricalRetrievalResult constructors * Change pipeline to use DeadletterSink API (#586) * Add better code docs to storage refactor (#601) * Add better code documentation, make GetFeastServingInfo independent of retriever * Make getStagingLocation method of historical retriever * Apply spotless * Clean up dependencies, remove exclusions at serving (#607) * Clean up OnlineServingService code (#605) * Clean up OnlineServingService code to be more readable * Revert Metrics * Rename storage API packages to nouns --- ingestion/pom.xml | 18 + .../main/java/feast/ingestion/ImportJob.java | 78 ++-- .../ingestion/transform/ReadFromSource.java | 2 +- .../transform/ValidateFeatureRows.java | 9 +- .../WriteFailedElementToBigQuery.java | 2 +- .../ingestion/transform/WriteToStore.java | 168 ------- .../fn/KafkaRecordToFeatureRowDoFn.java | 3 +- .../transform/fn/ValidateFeatureRowDoFn.java | 2 +- .../WriteDeadletterRowMetricsDoFn.java | 2 +- .../metrics/WriteFailureMetricsTransform.java | 52 +++ ...java => WriteSuccessMetricsTransform.java} | 78 ++-- .../java/feast/ingestion/utils/SpecUtil.java | 7 +- .../java/feast/ingestion/utils/StoreUtil.java | 181 +------- .../feast/ingestion/values/FeatureSet.java | 6 +- .../redis/FeatureRowToRedisMutationDoFn.java | 116 ----- .../store/serving/redis/RedisCustomIO.java | 341 -------------- .../java/feast/ingestion/ImportJobTest.java | 4 +- .../transform/ValidateFeatureRowsTest.java | 159 +++---- .../feast/ingestion/utils/StoreUtilTest.java | 211 --------- .../serving/redis/RedisCustomIOTest.java | 238 ---------- .../src/test/java/feast/test/TestUtil.java | 48 +- pom.xml | 2 + serving/pom.xml | 43 +- .../configuration/ServingServiceConfig.java | 41 +- .../service/BigQueryServingService.java | 282 ------------ .../service/HistoricalServingService.java | 119 +++++ .../serving/service/OnlineServingService.java | 176 ++++++++ .../serving/service/RedisServingService.java | 345 -------------- .../feast/serving/service/ServingService.java | 63 +++ .../serving/specs/CachedSpecService.java | 1 + .../main/java/feast/serving/util/RefUtil.java | 8 + .../service/CachedSpecServiceTest.java | 2 +- ...est.java => OnlineServingServiceTest.java} | 215 ++------- storage/api/pom.xml | 72 +++ .../api/retriever}/FeatureSetRequest.java | 9 +- .../retriever/HistoricalRetrievalResult.java | 100 ++++ 
.../api/retriever/HistoricalRetriever.java | 49 ++ .../api/retriever/OnlineRetriever.java | 40 ++ .../storage/api/writer/DeadletterSink.java | 38 ++ .../storage/api/writer/FailedElement.java | 83 ++++ .../feast/storage/api/writer/FeatureSink.java | 54 +++ .../feast/storage/api/writer/WriteResult.java | 97 ++++ .../common}/retry/BackOffExecutor.java | 2 +- .../storage/common}/retry/Retriable.java | 2 +- .../storage/common/testing/TestUtil.java | 188 ++++++++ storage/connectors/bigquery/pom.xml | 94 ++++ .../connectors/bigquery/common/TypeUtil.java | 66 +++ .../BigQueryHistoricalRetriever.java | 426 ++++++++++-------- .../retriever/FeatureSetQueryInfo.java | 8 +- .../bigquery/retriever}/QueryTemplater.java | 20 +- .../bigquery/retriever}/SubqueryCallable.java | 24 +- .../writer/BigQueryDeadletterSink.java | 133 ++++++ .../bigquery/writer/BigQueryFeatureSink.java | 188 ++++++++ .../bigquery/writer/BigQueryWrite.java | 107 +++++ .../writer}/FeatureRowToTableRow.java | 2 +- .../bigquery/writer}/GetTableDestination.java | 2 +- .../schemas/deadletter_table_schema.json | 34 ++ .../resources/templates/join_featuresets.sql | 24 + .../templates/single_featureset_pit_join.sql | 90 ++++ storage/connectors/pom.xml | 51 +++ storage/connectors/redis/pom.xml | 82 ++++ .../redis/retriever}/FeatureRowDecoder.java | 2 +- .../redis/retriever/RedisOnlineRetriever.java | 204 +++++++++ .../redis/writer/RedisCustomIO.java | 292 ++++++++++++ .../redis/writer/RedisFeatureSink.java | 74 +++ .../redis/writer}/RedisIngestionClient.java | 4 +- .../RedisStandaloneIngestionClient.java | 9 +- .../retriever}/FeatureRowDecoderTest.java | 2 +- .../retriever/RedisOnlineRetrieverTest.java | 262 +++++++++++ .../connectors/redis/test/TestUtil.java | 44 ++ .../redis/writer/RedisFeatureSinkTest.java | 422 +++++++++-------- 71 files changed, 3711 insertions(+), 2711 deletions(-) delete mode 100644 ingestion/src/main/java/feast/ingestion/transform/WriteToStore.java create mode 100644 ingestion/src/main/java/feast/ingestion/transform/metrics/WriteFailureMetricsTransform.java rename ingestion/src/main/java/feast/ingestion/transform/metrics/{WriteMetricsTransform.java => WriteSuccessMetricsTransform.java} (65%) delete mode 100644 ingestion/src/main/java/feast/store/serving/redis/FeatureRowToRedisMutationDoFn.java delete mode 100644 ingestion/src/main/java/feast/store/serving/redis/RedisCustomIO.java delete mode 100644 ingestion/src/test/java/feast/ingestion/utils/StoreUtilTest.java delete mode 100644 ingestion/src/test/java/feast/store/serving/redis/RedisCustomIOTest.java delete mode 100644 serving/src/main/java/feast/serving/service/BigQueryServingService.java create mode 100644 serving/src/main/java/feast/serving/service/HistoricalServingService.java create mode 100644 serving/src/main/java/feast/serving/service/OnlineServingService.java delete mode 100644 serving/src/main/java/feast/serving/service/RedisServingService.java rename serving/src/test/java/feast/serving/service/{RedisServingServiceTest.java => OnlineServingServiceTest.java} (72%) create mode 100644 storage/api/pom.xml rename {serving/src/main/java/feast/serving/specs => storage/api/src/main/java/feast/storage/api/retriever}/FeatureSetRequest.java (84%) create mode 100644 storage/api/src/main/java/feast/storage/api/retriever/HistoricalRetrievalResult.java create mode 100644 storage/api/src/main/java/feast/storage/api/retriever/HistoricalRetriever.java create mode 100644 storage/api/src/main/java/feast/storage/api/retriever/OnlineRetriever.java create mode 100644 
storage/api/src/main/java/feast/storage/api/writer/DeadletterSink.java create mode 100644 storage/api/src/main/java/feast/storage/api/writer/FailedElement.java create mode 100644 storage/api/src/main/java/feast/storage/api/writer/FeatureSink.java create mode 100644 storage/api/src/main/java/feast/storage/api/writer/WriteResult.java rename {ingestion/src/main/java/feast => storage/api/src/main/java/feast/storage/common}/retry/BackOffExecutor.java (98%) rename {ingestion/src/main/java/feast => storage/api/src/main/java/feast/storage/common}/retry/Retriable.java (95%) create mode 100644 storage/api/src/main/java/feast/storage/common/testing/TestUtil.java create mode 100644 storage/connectors/bigquery/pom.xml create mode 100644 storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/common/TypeUtil.java rename serving/src/main/java/feast/serving/store/bigquery/BatchRetrievalQueryRunnable.java => storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/retriever/BigQueryHistoricalRetriever.java (52%) rename serving/src/main/java/feast/serving/store/bigquery/model/FeatureSetInfo.java => storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/retriever/FeatureSetQueryInfo.java (90%) rename {serving/src/main/java/feast/serving/store/bigquery => storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/retriever}/QueryTemplater.java (90%) rename {serving/src/main/java/feast/serving/store/bigquery => storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/retriever}/SubqueryCallable.java (70%) create mode 100644 storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/writer/BigQueryDeadletterSink.java create mode 100644 storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/writer/BigQueryFeatureSink.java create mode 100644 storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/writer/BigQueryWrite.java rename {ingestion/src/main/java/feast/store/serving/bigquery => storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/writer}/FeatureRowToTableRow.java (98%) rename {ingestion/src/main/java/feast/store/serving/bigquery => storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/writer}/GetTableDestination.java (97%) create mode 100644 storage/connectors/bigquery/src/main/resources/schemas/deadletter_table_schema.json create mode 100644 storage/connectors/bigquery/src/main/resources/templates/join_featuresets.sql create mode 100644 storage/connectors/bigquery/src/main/resources/templates/single_featureset_pit_join.sql create mode 100644 storage/connectors/pom.xml create mode 100644 storage/connectors/redis/pom.xml rename {serving/src/main/java/feast/serving/encoding => storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever}/FeatureRowDecoder.java (98%) create mode 100644 storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisOnlineRetriever.java create mode 100644 storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisCustomIO.java create mode 100644 storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisFeatureSink.java rename {ingestion/src/main/java/feast/store/serving/redis => storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer}/RedisIngestionClient.java (92%) rename {ingestion/src/main/java/feast/store/serving/redis => 
storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer}/RedisStandaloneIngestionClient.java (93%) rename {serving/src/test/java/feast/serving/encoding => storage/connectors/redis/src/test/java/feast/storage/connectors/redis/retriever}/FeatureRowDecoderTest.java (98%) create mode 100644 storage/connectors/redis/src/test/java/feast/storage/connectors/redis/retriever/RedisOnlineRetrieverTest.java create mode 100644 storage/connectors/redis/src/test/java/feast/storage/connectors/redis/test/TestUtil.java rename ingestion/src/test/java/feast/store/serving/redis/FeatureRowToRedisMutationDoFnTest.java => storage/connectors/redis/src/test/java/feast/storage/connectors/redis/writer/RedisFeatureSinkTest.java (50%) diff --git a/ingestion/pom.xml b/ingestion/pom.xml index 47204c33f29..9386d066bfd 100644 --- a/ingestion/pom.xml +++ b/ingestion/pom.xml @@ -101,6 +101,24 @@ <version>${project.version}</version> </dependency> + + <dependency> + <groupId>dev.feast</groupId> + <artifactId>feast-storage-api</artifactId> + <version>${project.version}</version> + </dependency> + + <dependency> + <groupId>dev.feast</groupId> + <artifactId>feast-storage-connector-redis</artifactId> + <version>${project.version}</version> + </dependency> + + <dependency> + <groupId>dev.feast</groupId> + <artifactId>feast-storage-connector-bigquery</artifactId> + <version>${project.version}</version> + </dependency> <dependency> <groupId>com.google.auto.value</groupId> <artifactId>auto-value-annotations</artifactId> diff --git a/ingestion/src/main/java/feast/ingestion/ImportJob.java b/ingestion/src/main/java/feast/ingestion/ImportJob.java index c4973ce3cae..ef6039e5536 100644 --- a/ingestion/src/main/java/feast/ingestion/ImportJob.java +++ b/ingestion/src/main/java/feast/ingestion/ImportJob.java @@ -17,9 +17,11 @@ package feast.ingestion; import static feast.ingestion.utils.SpecUtil.getFeatureSetReference; +import static feast.ingestion.utils.StoreUtil.getFeatureSink; import com.google.protobuf.InvalidProtocolBufferException; import feast.core.FeatureSetProto.FeatureSet; +import feast.core.FeatureSetProto.FeatureSetSpec; import feast.core.SourceProto.Source; import feast.core.StoreProto.Store; import feast.ingestion.options.BZip2Decompressor; @@ -27,13 +29,14 @@ import feast.ingestion.options.StringListStreamConverter; import feast.ingestion.transform.ReadFromSource; import feast.ingestion.transform.ValidateFeatureRows; -import feast.ingestion.transform.WriteFailedElementToBigQuery; -import feast.ingestion.transform.WriteToStore; -import feast.ingestion.transform.metrics.WriteMetricsTransform; -import feast.ingestion.utils.ResourceUtil; +import feast.ingestion.transform.metrics.WriteFailureMetricsTransform; +import feast.ingestion.transform.metrics.WriteSuccessMetricsTransform; import feast.ingestion.utils.SpecUtil; -import feast.ingestion.utils.StoreUtil; -import feast.ingestion.values.FailedElement; +import feast.storage.api.writer.DeadletterSink; +import feast.storage.api.writer.FailedElement; +import feast.storage.api.writer.FeatureSink; +import feast.storage.api.writer.WriteResult; +import feast.storage.connectors.bigquery.writer.BigQueryDeadletterSink; import feast.types.FeatureRowProto.FeatureRow; import java.io.IOException; import java.util.HashMap; @@ -93,17 +96,24 @@ public static PipelineResult runPipeline(ImportOptions options) throws IOExcepti SpecUtil.getSubscribedFeatureSets(store.getSubscriptionsList(), featureSets); // Generate tags by key - Map<String, FeatureSet> featureSetsByKey = new HashMap<>(); + Map<String, FeatureSetSpec> featureSetSpecsByKey = new HashMap<>(); subscribedFeatureSets.stream() .forEach( fs -> { - String ref = getFeatureSetReference(fs); - featureSetsByKey.put(ref, fs); + String ref = getFeatureSetReference(fs.getSpec()); + featureSetSpecsByKey.put(ref, fs.getSpec()); }); + FeatureSink featureSink = getFeatureSink(store, featureSetSpecsByKey); + // TODO: make the
source part of the job initialisation options Source source = subscribedFeatureSets.get(0).getSpec().getSource(); + for (FeatureSet featureSet : subscribedFeatureSets) { + // Ensure Store has valid configuration and Feast can access it. + featureSink.prepareWrite(featureSet); + } + // Step 1. Read messages from Feast Source as FeatureRow. PCollectionTuple convertedFeatureRows = pipeline.apply( @@ -114,58 +124,48 @@ public static PipelineResult runPipeline(ImportOptions options) throws IOExcepti .setFailureTag(DEADLETTER_OUT) .build()); - for (FeatureSet featureSet : subscribedFeatureSets) { - // Ensure Store has valid configuration and Feast can access it. - StoreUtil.setupStore(store, featureSet); - } - // Step 2. Validate incoming FeatureRows PCollectionTuple validatedRows = convertedFeatureRows .get(FEATURE_ROW_OUT) .apply( ValidateFeatureRows.newBuilder() - .setFeatureSets(featureSetsByKey) + .setFeatureSetSpecs(featureSetSpecsByKey) .setSuccessTag(FEATURE_ROW_OUT) .setFailureTag(DEADLETTER_OUT) .build()); // Step 3. Write FeatureRow to the corresponding Store. - validatedRows - .get(FEATURE_ROW_OUT) - .apply( - "WriteFeatureRowToStore", - WriteToStore.newBuilder().setFeatureSets(featureSetsByKey).setStore(store).build()); + WriteResult writeFeatureRows = + validatedRows.get(FEATURE_ROW_OUT).apply("WriteFeatureRowToStore", featureSink.writer()); // Step 4. Write FailedElements to a dead letter table in BigQuery. if (options.getDeadLetterTableSpec() != null) { + // TODO: make deadletter destination type configurable + DeadletterSink deadletterSink = + new BigQueryDeadletterSink(options.getDeadLetterTableSpec()); + convertedFeatureRows .get(DEADLETTER_OUT) - .apply( - "WriteFailedElements_ReadFromSource", - WriteFailedElementToBigQuery.newBuilder() - .setJsonSchema(ResourceUtil.getDeadletterTableSchemaJson()) - .setTableSpec(options.getDeadLetterTableSpec()) - .build()); + .apply("WriteFailedElements_ReadFromSource", deadletterSink.write()); validatedRows .get(DEADLETTER_OUT) - .apply( - "WriteFailedElements_ValidateRows", - WriteFailedElementToBigQuery.newBuilder() - .setJsonSchema(ResourceUtil.getDeadletterTableSchemaJson()) - .setTableSpec(options.getDeadLetterTableSpec()) - .build()); + .apply("WriteFailedElements_ValidateRows", deadletterSink.write()); + + writeFeatureRows + .getFailedInserts() + .apply("WriteFailedElements_WriteFeatureRowToStore", deadletterSink.write()); } // Step 5. Write metrics to a metrics sink. 
- validatedRows.apply( - "WriteMetrics", - WriteMetricsTransform.newBuilder() - .setStoreName(store.getName()) - .setSuccessTag(FEATURE_ROW_OUT) - .setFailureTag(DEADLETTER_OUT) - .build()); + writeFeatureRows + .getSuccessfulInserts() + .apply("WriteSuccessMetrics", WriteSuccessMetricsTransform.create(store.getName())); + + writeFeatureRows + .getFailedInserts() + .apply("WriteFailureMetrics", WriteFailureMetricsTransform.create(store.getName())); } return pipeline.run(); diff --git a/ingestion/src/main/java/feast/ingestion/transform/ReadFromSource.java b/ingestion/src/main/java/feast/ingestion/transform/ReadFromSource.java index 65e95b287dc..fb013d0375f 100644 --- a/ingestion/src/main/java/feast/ingestion/transform/ReadFromSource.java +++ b/ingestion/src/main/java/feast/ingestion/transform/ReadFromSource.java @@ -21,7 +21,7 @@ import feast.core.SourceProto.Source; import feast.core.SourceProto.SourceType; import feast.ingestion.transform.fn.KafkaRecordToFeatureRowDoFn; -import feast.ingestion.values.FailedElement; +import feast.storage.api.writer.FailedElement; import feast.types.FeatureRowProto.FeatureRow; import org.apache.beam.sdk.io.kafka.KafkaIO; import org.apache.beam.sdk.transforms.PTransform; diff --git a/ingestion/src/main/java/feast/ingestion/transform/ValidateFeatureRows.java b/ingestion/src/main/java/feast/ingestion/transform/ValidateFeatureRows.java index 5ca6a710f62..06df06c074c 100644 --- a/ingestion/src/main/java/feast/ingestion/transform/ValidateFeatureRows.java +++ b/ingestion/src/main/java/feast/ingestion/transform/ValidateFeatureRows.java @@ -19,8 +19,8 @@ import com.google.auto.value.AutoValue; import feast.core.FeatureSetProto; import feast.ingestion.transform.fn.ValidateFeatureRowDoFn; -import feast.ingestion.values.FailedElement; import feast.ingestion.values.FeatureSet; +import feast.storage.api.writer.FailedElement; import feast.types.FeatureRowProto.FeatureRow; import java.util.Map; import java.util.stream.Collectors; @@ -36,7 +36,7 @@ public abstract class ValidateFeatureRows extends PTransform<PCollection<FeatureRow>, PCollectionTuple> { - public abstract Map<String, FeatureSetProto.FeatureSet> getFeatureSets(); + public abstract Map<String, FeatureSetProto.FeatureSetSpec> getFeatureSetSpecs(); public abstract TupleTag<FeatureRow> getSuccessTag(); @@ -49,7 +49,8 @@ public static Builder newBuilder() { @AutoValue.Builder public abstract static class Builder { - public abstract Builder setFeatureSets(Map<String, FeatureSetProto.FeatureSet> featureSets); + public abstract Builder setFeatureSetSpecs( + Map<String, FeatureSetProto.FeatureSetSpec> featureSets); public abstract Builder setSuccessTag(TupleTag<FeatureRow> successTag); @@ -62,7 +63,7 @@ public abstract static class Builder { public PCollectionTuple expand(PCollection<FeatureRow> input) { Map<String, FeatureSet> featureSets = - getFeatureSets().entrySet().stream() + getFeatureSetSpecs().entrySet().stream() .map(e -> Pair.of(e.getKey(), new FeatureSet(e.getValue()))) .collect(Collectors.toMap(Pair::getLeft, Pair::getRight)); diff --git a/ingestion/src/main/java/feast/ingestion/transform/WriteFailedElementToBigQuery.java b/ingestion/src/main/java/feast/ingestion/transform/WriteFailedElementToBigQuery.java index cda590b21aa..0da281790c5 100644 --- a/ingestion/src/main/java/feast/ingestion/transform/WriteFailedElementToBigQuery.java +++ b/ingestion/src/main/java/feast/ingestion/transform/WriteFailedElementToBigQuery.java @@ -18,7 +18,7 @@ import com.google.api.services.bigquery.model.TableRow; import com.google.auto.value.AutoValue; -import feast.ingestion.values.FailedElement; +import feast.storage.api.writer.FailedElement; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO; import 
org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition; diff --git a/ingestion/src/main/java/feast/ingestion/transform/WriteToStore.java b/ingestion/src/main/java/feast/ingestion/transform/WriteToStore.java deleted file mode 100644 index 4e9082f5554..00000000000 --- a/ingestion/src/main/java/feast/ingestion/transform/WriteToStore.java +++ /dev/null @@ -1,168 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * Copyright 2018-2019 The Feast Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package feast.ingestion.transform; - -import com.google.api.services.bigquery.model.TableDataInsertAllResponse.InsertErrors; -import com.google.api.services.bigquery.model.TableRow; -import com.google.auto.value.AutoValue; -import feast.core.FeatureSetProto.FeatureSet; -import feast.core.StoreProto.Store; -import feast.core.StoreProto.Store.BigQueryConfig; -import feast.core.StoreProto.Store.StoreType; -import feast.ingestion.options.ImportOptions; -import feast.ingestion.utils.ResourceUtil; -import feast.ingestion.values.FailedElement; -import feast.store.serving.bigquery.FeatureRowToTableRow; -import feast.store.serving.bigquery.GetTableDestination; -import feast.store.serving.redis.FeatureRowToRedisMutationDoFn; -import feast.store.serving.redis.RedisCustomIO; -import feast.types.FeatureRowProto.FeatureRow; -import java.io.IOException; -import java.util.Map; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.Method; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryInsertError; -import org.apache.beam.sdk.io.gcp.bigquery.InsertRetryPolicy; -import org.apache.beam.sdk.io.gcp.bigquery.WriteResult; -import org.apache.beam.sdk.metrics.Counter; -import org.apache.beam.sdk.metrics.Metrics; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.MapElements; -import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PDone; -import org.apache.beam.sdk.values.TypeDescriptors; -import org.slf4j.Logger; - -@AutoValue -public abstract class WriteToStore extends PTransform<PCollection<FeatureRow>, PDone> { - - private static final Logger log = org.slf4j.LoggerFactory.getLogger(WriteToStore.class); - - public static final String METRIC_NAMESPACE = "WriteToStore"; - public static final String ELEMENTS_WRITTEN_METRIC = "elements_written"; - - private static final Counter elementsWritten = - Metrics.counter(METRIC_NAMESPACE, ELEMENTS_WRITTEN_METRIC); - - public abstract Store getStore(); - - public abstract Map<String, FeatureSet> getFeatureSets(); - - public static Builder newBuilder() { - return new AutoValue_WriteToStore.Builder(); - } - - @AutoValue.Builder - public
abstract static class Builder { - - public abstract Builder setStore(Store store); - - public abstract Builder setFeatureSets(Map<String, FeatureSet> featureSets); - - public abstract WriteToStore build(); - } - - @Override - public PDone expand(PCollection<FeatureRow> input) { - ImportOptions options = input.getPipeline().getOptions().as(ImportOptions.class); - StoreType storeType = getStore().getType(); - - switch (storeType) { - case REDIS: - PCollection<FailedElement> redisWriteResult = - input - .apply( - "FeatureRowToRedisMutation", - ParDo.of(new FeatureRowToRedisMutationDoFn(getFeatureSets()))) - .apply("WriteRedisMutationToRedis", RedisCustomIO.write(getStore())); - if (options.getDeadLetterTableSpec() != null) { - redisWriteResult.apply( - WriteFailedElementToBigQuery.newBuilder() - .setTableSpec(options.getDeadLetterTableSpec()) - .setJsonSchema(ResourceUtil.getDeadletterTableSchemaJson()) - .build()); - } - break; - case BIGQUERY: - BigQueryConfig bigqueryConfig = getStore().getBigqueryConfig(); - - WriteResult bigqueryWriteResult = - input.apply( - "WriteTableRowToBigQuery", - BigQueryIO.<FeatureRow>write() - .to( - new GetTableDestination( - bigqueryConfig.getProjectId(), bigqueryConfig.getDatasetId())) - .withFormatFunction(new FeatureRowToTableRow(options.getJobName())) - .withCreateDisposition(CreateDisposition.CREATE_NEVER) - .withWriteDisposition(WriteDisposition.WRITE_APPEND) - .withExtendedErrorInfo() - .withMethod(Method.STREAMING_INSERTS) - .withFailedInsertRetryPolicy(InsertRetryPolicy.retryTransientErrors())); - - if (options.getDeadLetterTableSpec() != null) { - bigqueryWriteResult - .getFailedInsertsWithErr() - .apply( - "WrapBigQueryInsertionError", - ParDo.of( - new DoFn<BigQueryInsertError, FailedElement>() { - @ProcessElement - public void processElement(ProcessContext context) { - InsertErrors error = context.element().getError(); - TableRow row = context.element().getRow(); - try { - context.output( - FailedElement.newBuilder() - .setErrorMessage(error.toPrettyString()) - .setPayload(row.toPrettyString()) - .setJobName(context.getPipelineOptions().getJobName()) - .setTransformName("WriteTableRowToBigQuery") - .build()); - } catch (IOException e) { - log.error(e.getMessage()); - } - } - })) - .apply( - WriteFailedElementToBigQuery.newBuilder() - .setTableSpec(options.getDeadLetterTableSpec()) - .setJsonSchema(ResourceUtil.getDeadletterTableSchemaJson()) - .build()); - } - break; - default: - log.error("Store type '{}' is not supported.
No Feature Row will be written.", storeType); - break; - } - - input.apply( - "IncrementWriteToStoreElementsWrittenCounter", - MapElements.into(TypeDescriptors.booleans()) - .via( - (FeatureRow row) -> { - elementsWritten.inc(); - return true; - })); - - return PDone.in(input.getPipeline()); - } -} diff --git a/ingestion/src/main/java/feast/ingestion/transform/fn/KafkaRecordToFeatureRowDoFn.java b/ingestion/src/main/java/feast/ingestion/transform/fn/KafkaRecordToFeatureRowDoFn.java index 25aafd6ee71..b332c0ca09a 100644 --- a/ingestion/src/main/java/feast/ingestion/transform/fn/KafkaRecordToFeatureRowDoFn.java +++ b/ingestion/src/main/java/feast/ingestion/transform/fn/KafkaRecordToFeatureRowDoFn.java @@ -18,8 +18,7 @@ import com.google.auto.value.AutoValue; import com.google.protobuf.InvalidProtocolBufferException; -import feast.ingestion.transform.ReadFromSource.Builder; -import feast.ingestion.values.FailedElement; +import feast.storage.api.writer.FailedElement; import feast.types.FeatureRowProto.FeatureRow; import java.util.Base64; import org.apache.beam.sdk.io.kafka.KafkaRecord; diff --git a/ingestion/src/main/java/feast/ingestion/transform/fn/ValidateFeatureRowDoFn.java b/ingestion/src/main/java/feast/ingestion/transform/fn/ValidateFeatureRowDoFn.java index c31d3c535e9..85ac3c86faa 100644 --- a/ingestion/src/main/java/feast/ingestion/transform/fn/ValidateFeatureRowDoFn.java +++ b/ingestion/src/main/java/feast/ingestion/transform/fn/ValidateFeatureRowDoFn.java @@ -17,9 +17,9 @@ package feast.ingestion.transform.fn; import com.google.auto.value.AutoValue; -import feast.ingestion.values.FailedElement; import feast.ingestion.values.FeatureSet; import feast.ingestion.values.Field; +import feast.storage.api.writer.FailedElement; import feast.types.FeatureRowProto.FeatureRow; import feast.types.FieldProto; import feast.types.ValueProto.Value.ValCase; diff --git a/ingestion/src/main/java/feast/ingestion/transform/metrics/WriteDeadletterRowMetricsDoFn.java b/ingestion/src/main/java/feast/ingestion/transform/metrics/WriteDeadletterRowMetricsDoFn.java index 687670c5cf0..b4338cda09b 100644 --- a/ingestion/src/main/java/feast/ingestion/transform/metrics/WriteDeadletterRowMetricsDoFn.java +++ b/ingestion/src/main/java/feast/ingestion/transform/metrics/WriteDeadletterRowMetricsDoFn.java @@ -20,7 +20,7 @@ import com.timgroup.statsd.NonBlockingStatsDClient; import com.timgroup.statsd.StatsDClient; import com.timgroup.statsd.StatsDClientException; -import feast.ingestion.values.FailedElement; +import feast.storage.api.writer.FailedElement; import org.apache.beam.sdk.transforms.DoFn; import org.slf4j.Logger; diff --git a/ingestion/src/main/java/feast/ingestion/transform/metrics/WriteFailureMetricsTransform.java b/ingestion/src/main/java/feast/ingestion/transform/metrics/WriteFailureMetricsTransform.java new file mode 100644 index 00000000000..65a27fa8bf4 --- /dev/null +++ b/ingestion/src/main/java/feast/ingestion/transform/metrics/WriteFailureMetricsTransform.java @@ -0,0 +1,52 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2019 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package feast.ingestion.transform.metrics; + +import com.google.auto.value.AutoValue; +import feast.ingestion.options.ImportOptions; +import feast.storage.api.writer.FailedElement; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PDone; + +@AutoValue +public abstract class WriteFailureMetricsTransform + extends PTransform<PCollection<FailedElement>, PDone> { + + public abstract String getStoreName(); + + public static WriteFailureMetricsTransform create(String storeName) { + return new AutoValue_WriteFailureMetricsTransform(storeName); + } + + @Override + public PDone expand(PCollection<FailedElement> input) { + ImportOptions options = input.getPipeline().getOptions().as(ImportOptions.class); + if ("statsd".equals(options.getMetricsExporterType())) { + input.apply( + "WriteDeadletterMetrics", + ParDo.of( + WriteDeadletterRowMetricsDoFn.newBuilder() + .setStatsdHost(options.getStatsdHost()) + .setStatsdPort(options.getStatsdPort()) + .setStoreName(getStoreName()) + .build())); + } + return PDone.in(input.getPipeline()); + } +} diff --git a/ingestion/src/main/java/feast/ingestion/transform/metrics/WriteMetricsTransform.java b/ingestion/src/main/java/feast/ingestion/transform/metrics/WriteSuccessMetricsTransform.java similarity index 65% rename from ingestion/src/main/java/feast/ingestion/transform/metrics/WriteMetricsTransform.java rename to ingestion/src/main/java/feast/ingestion/transform/metrics/WriteSuccessMetricsTransform.java index 8a5869d78ec..37eed7455a9 100644 --- a/ingestion/src/main/java/feast/ingestion/transform/metrics/WriteMetricsTransform.java +++ b/ingestion/src/main/java/feast/ingestion/transform/metrics/WriteSuccessMetricsTransform.java @@ -18,68 +18,54 @@ import com.google.auto.value.AutoValue; import feast.ingestion.options.ImportOptions; -import feast.ingestion.values.FailedElement; import feast.types.FeatureRowProto.FeatureRow; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.GroupByKey; -import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; +import org.apache.beam.sdk.transforms.*; import org.apache.beam.sdk.transforms.windowing.FixedWindows; import org.apache.beam.sdk.transforms.windowing.Window; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.PDone; -import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.TypeDescriptors; import org.joda.time.Duration; @AutoValue -public abstract class WriteMetricsTransform extends PTransform<PCollectionTuple, PDone> { +public abstract class WriteSuccessMetricsTransform + extends PTransform<PCollection<FeatureRow>, PDone> { - public abstract String getStoreName(); - - public abstract TupleTag<FeatureRow> getSuccessTag(); - - public abstract TupleTag<FailedElement> getFailureTag(); - - public static Builder newBuilder() { - return new AutoValue_WriteMetricsTransform.Builder(); - } - - 
@AutoValue.Builder - public abstract static class Builder { + public static final String METRIC_NAMESPACE = "WriteToStoreSuccess"; + public static final String ELEMENTS_WRITTEN_METRIC = "elements_written"; + private static final Counter elementsWritten = + Metrics.counter(METRIC_NAMESPACE, ELEMENTS_WRITTEN_METRIC); - public abstract Builder setStoreName(String storeName); - - public abstract Builder setSuccessTag(TupleTag<FeatureRow> successTag); - - public abstract Builder setFailureTag(TupleTag<FailedElement> failureTag); + public abstract String getStoreName(); - public abstract WriteMetricsTransform build(); + public static WriteSuccessMetricsTransform create(String storeName) { + return new AutoValue_WriteSuccessMetricsTransform(storeName); } @Override - public PDone expand(PCollectionTuple input) { + public PDone expand(PCollection<FeatureRow> input) { ImportOptions options = input.getPipeline().getOptions().as(ImportOptions.class); + + input.apply( + "IncrementSuccessfulWriteToStoreElementsWrittenCounter", + MapElements.into(TypeDescriptors.booleans()) + .via( + (FeatureRow row) -> { + elementsWritten.inc(); + return true; + })); + switch (options.getMetricsExporterType()) { case "statsd": - input - .get(getFailureTag()) - .apply( - "WriteDeadletterMetrics", - ParDo.of( - WriteDeadletterRowMetricsDoFn.newBuilder() - .setStatsdHost(options.getStatsdHost()) - .setStatsdPort(options.getStatsdPort()) - .setStoreName(getStoreName()) - .build())); // Fixed window is applied so the metric collector will not be overwhelmed with the metrics // data. For validation, only summaries of the values are usually required vs the actual // values. PCollection<KV<String, Iterable<FeatureRow>>> validRowsGroupedByRef = input - .get(getSuccessTag()) .apply( "FixedWindow", Window.into( @@ -119,15 +105,13 @@ public void processElement( return PDone.in(input.getPipeline()); case "none": default: - input - .get(getSuccessTag()) - .apply( - "Noop", - ParDo.of( - new DoFn<FeatureRow, Void>() { - @ProcessElement - public void processElement(ProcessContext c) {} - })); + input.apply( + "Noop", + ParDo.of( + new DoFn<FeatureRow, Void>() { + @ProcessElement + public void processElement(ProcessContext c) {} + })); return PDone.in(input.getPipeline()); } } diff --git a/ingestion/src/main/java/feast/ingestion/utils/SpecUtil.java b/ingestion/src/main/java/feast/ingestion/utils/SpecUtil.java index 9163c5b2d6f..f28dfc9ee39 100644 --- a/ingestion/src/main/java/feast/ingestion/utils/SpecUtil.java +++ b/ingestion/src/main/java/feast/ingestion/utils/SpecUtil.java @@ -33,9 +33,10 @@ public class SpecUtil { - public static String getFeatureSetReference(FeatureSet featureSet) { - FeatureSetSpec spec = featureSet.getSpec(); - return String.format("%s/%s:%d", spec.getProject(), spec.getName(), spec.getVersion()); + public static String getFeatureSetReference(FeatureSetSpec featureSetSpec) { + return String.format( + "%s/%s:%d", + featureSetSpec.getProject(), featureSetSpec.getName(), featureSetSpec.getVersion()); } /** Get only feature set specs that matches the subscription */ diff --git a/ingestion/src/main/java/feast/ingestion/utils/StoreUtil.java b/ingestion/src/main/java/feast/ingestion/utils/StoreUtil.java index a02b8626945..1b884333818 100644 --- a/ingestion/src/main/java/feast/ingestion/utils/StoreUtil.java +++ b/ingestion/src/main/java/feast/ingestion/utils/StoreUtil.java @@ -18,39 +18,16 @@ import static feast.types.ValueProto.ValueType; -import com.google.cloud.bigquery.BigQuery; -import com.google.cloud.bigquery.BigQueryOptions; -import com.google.cloud.bigquery.DatasetId; -import com.google.cloud.bigquery.DatasetInfo;
-import com.google.cloud.bigquery.Field; -import com.google.cloud.bigquery.Field.Builder; -import com.google.cloud.bigquery.Field.Mode; -import com.google.cloud.bigquery.Schema; import com.google.cloud.bigquery.StandardSQLTypeName; -import com.google.cloud.bigquery.StandardTableDefinition; -import com.google.cloud.bigquery.Table; -import com.google.cloud.bigquery.TableDefinition; -import com.google.cloud.bigquery.TableId; -import com.google.cloud.bigquery.TableInfo; -import com.google.cloud.bigquery.TimePartitioning; -import com.google.cloud.bigquery.TimePartitioning.Type; -import com.google.common.collect.ImmutableMap; -import feast.core.FeatureSetProto.EntitySpec; -import feast.core.FeatureSetProto.FeatureSet; import feast.core.FeatureSetProto.FeatureSetSpec; -import feast.core.FeatureSetProto.FeatureSpec; import feast.core.StoreProto.Store; -import feast.core.StoreProto.Store.RedisConfig; import feast.core.StoreProto.Store.StoreType; +import feast.storage.api.writer.FeatureSink; +import feast.storage.connectors.bigquery.writer.BigQueryFeatureSink; +import feast.storage.connectors.redis.writer.RedisFeatureSink; import feast.types.ValueProto.ValueType.Enum; -import io.lettuce.core.RedisClient; -import io.lettuce.core.RedisConnectionException; -import io.lettuce.core.RedisURI; -import java.util.ArrayList; import java.util.HashMap; -import java.util.List; import java.util.Map; -import org.apache.commons.lang3.tuple.Pair; import org.slf4j.Logger; // TODO: Create partitioned table by default @@ -101,155 +78,19 @@ public class StoreUtil { VALUE_TYPE_TO_STANDARD_SQL_TYPE.put(Enum.BOOL_LIST, StandardSQLTypeName.BOOL); } - public static void setupStore(Store store, FeatureSet featureSet) { + public static FeatureSink getFeatureSink( + Store store, Map<String, FeatureSetSpec> featureSetSpecs) { StoreType storeType = store.getType(); switch (storeType) { case REDIS: - StoreUtil.checkRedisConnection(store.getRedisConfig()); - break; + return RedisFeatureSink.builder() + .setRedisConfig(store.getRedisConfig()) + .setFeatureSetSpecs(featureSetSpecs) + .build(); case BIGQUERY: - StoreUtil.setupBigQuery( - featureSet, - store.getBigqueryConfig().getProjectId(), - store.getBigqueryConfig().getDatasetId(), - BigQueryOptions.getDefaultInstance().getService()); - break; + return BigQueryFeatureSink.fromConfig(store.getBigqueryConfig()); default: - log.warn("Store type '{}' is unsupported", storeType); - break; + throw new RuntimeException(String.format("Store type '%s' is unsupported", storeType)); } } - - @SuppressWarnings("DuplicatedCode") - public static TableDefinition createBigQueryTableDefinition(FeatureSetSpec featureSetSpec) { - List<Field> fields = new ArrayList<>(); - log.info("Table will have the following fields:"); - - for (EntitySpec entitySpec : featureSetSpec.getEntitiesList()) { - Builder builder = - Field.newBuilder( - entitySpec.getName(), VALUE_TYPE_TO_STANDARD_SQL_TYPE.get(entitySpec.getValueType())); - if (entitySpec.getValueType().name().toLowerCase().endsWith("_list")) { - builder.setMode(Mode.REPEATED); - } - Field field = builder.build(); - log.info("- {}", field.toString()); - fields.add(field); - } - for (FeatureSpec featureSpec : featureSetSpec.getFeaturesList()) { - Builder builder = - Field.newBuilder( - featureSpec.getName(), - VALUE_TYPE_TO_STANDARD_SQL_TYPE.get(featureSpec.getValueType())); - if (featureSpec.getValueType().name().toLowerCase().endsWith("_list")) { - builder.setMode(Mode.REPEATED); - } - Field field = builder.build(); - log.info("- {}", field.toString()); - fields.add(field); - } - - //
Refer to protos/feast/core/Store.proto for reserved fields in BigQuery. - Map<String, Pair<StandardSQLTypeName, String>> - reservedFieldNameToPairOfStandardSQLTypeAndDescription = - ImmutableMap.of( - "event_timestamp", - Pair.of(StandardSQLTypeName.TIMESTAMP, BIGQUERY_EVENT_TIMESTAMP_FIELD_DESCRIPTION), - "created_timestamp", - Pair.of( - StandardSQLTypeName.TIMESTAMP, BIGQUERY_CREATED_TIMESTAMP_FIELD_DESCRIPTION), - "job_id", - Pair.of(StandardSQLTypeName.STRING, BIGQUERY_JOB_ID_FIELD_DESCRIPTION)); - for (Map.Entry<String, Pair<StandardSQLTypeName, String>> entry : - reservedFieldNameToPairOfStandardSQLTypeAndDescription.entrySet()) { - Field field = - Field.newBuilder(entry.getKey(), entry.getValue().getLeft()) - .setDescription(entry.getValue().getRight()) - .build(); - log.info("- {}", field.toString()); - fields.add(field); - } - - TimePartitioning timePartitioning = - TimePartitioning.newBuilder(Type.DAY).setField("event_timestamp").build(); - log.info("Table partitioning: " + timePartitioning.toString()); - - return StandardTableDefinition.newBuilder() - .setTimePartitioning(timePartitioning) - .setSchema(Schema.of(fields)) - .build(); - } - - /** - * This method ensures that, given a FeatureSetSpec object, the relevant BigQuery table is created - * with the correct schema. - * - *
<p>
Refer to protos/feast/core/Store.proto for the derivation of the table name and schema from - * a FeatureSetSpec object. - * - * @param featureSet FeatureSet object - * @param bigqueryProjectId BigQuery project id - * @param bigqueryDatasetId BigQuery dataset id - * @param bigquery BigQuery service object - */ - public static void setupBigQuery( - FeatureSet featureSet, - String bigqueryProjectId, - String bigqueryDatasetId, - BigQuery bigquery) { - - FeatureSetSpec featureSetSpec = featureSet.getSpec(); - // Ensure BigQuery dataset exists. - DatasetId datasetId = DatasetId.of(bigqueryProjectId, bigqueryDatasetId); - if (bigquery.getDataset(datasetId) == null) { - log.info("Creating dataset '{}' in project '{}'", datasetId.getDataset(), bigqueryProjectId); - bigquery.create(DatasetInfo.of(datasetId)); - } - - String tableName = - String.format( - "%s_%s_v%d", - featureSetSpec.getProject(), featureSetSpec.getName(), featureSetSpec.getVersion()) - .replaceAll("-", "_"); - TableId tableId = TableId.of(bigqueryProjectId, datasetId.getDataset(), tableName); - - // Return if there is an existing table - Table table = bigquery.getTable(tableId); - if (table != null) { - log.info( - "Writing to existing BigQuery table '{}:{}.{}'", - bigqueryProjectId, - datasetId.getDataset(), - tableName); - return; - } - - log.info( - "Creating table '{}' in dataset '{}' in project '{}'", - tableId.getTable(), - datasetId.getDataset(), - bigqueryProjectId); - TableDefinition tableDefinition = createBigQueryTableDefinition(featureSet.getSpec()); - TableInfo tableInfo = TableInfo.of(tableId, tableDefinition); - bigquery.create(tableInfo); - } - - /** - * Ensure Redis is accessible, else throw a RuntimeException. - * - * @param redisConfig Plase refer to feast.core.Store proto - */ - public static void checkRedisConnection(RedisConfig redisConfig) { - RedisClient redisClient = - RedisClient.create(RedisURI.create(redisConfig.getHost(), redisConfig.getPort())); - try { - redisClient.connect(); - } catch (RedisConnectionException e) { - throw new RuntimeException( - String.format( - "Failed to connect to Redis at host: '%s' port: '%d'. 
Please check that your Redis is running and accessible from Feast.", - redisConfig.getHost(), redisConfig.getPort())); - } - redisClient.shutdown(); - } } diff --git a/ingestion/src/main/java/feast/ingestion/values/FeatureSet.java b/ingestion/src/main/java/feast/ingestion/values/FeatureSet.java index bf07bcec966..758fbd0ba31 100644 --- a/ingestion/src/main/java/feast/ingestion/values/FeatureSet.java +++ b/ingestion/src/main/java/feast/ingestion/values/FeatureSet.java @@ -34,9 +34,9 @@ public class FeatureSet implements Serializable { private final Map<String, Field> fields; - public FeatureSet(FeatureSetProto.FeatureSet featureSet) { - this.reference = getFeatureSetReference(featureSet); - this.fields = getFieldsByName(featureSet.getSpec()); + public FeatureSet(FeatureSetProto.FeatureSetSpec featureSetSpec) { + this.reference = getFeatureSetReference(featureSetSpec); + this.fields = getFieldsByName(featureSetSpec); } public String getReference() { diff --git a/ingestion/src/main/java/feast/store/serving/redis/FeatureRowToRedisMutationDoFn.java b/ingestion/src/main/java/feast/store/serving/redis/FeatureRowToRedisMutationDoFn.java deleted file mode 100644 index 1f5a0f19677..00000000000 --- a/ingestion/src/main/java/feast/store/serving/redis/FeatureRowToRedisMutationDoFn.java +++ /dev/null @@ -1,116 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * Copyright 2018-2019 The Feast Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License.
- */ -package feast.store.serving.redis; - -import feast.core.FeatureSetProto.EntitySpec; -import feast.core.FeatureSetProto.FeatureSet; -import feast.core.FeatureSetProto.FeatureSetSpec; -import feast.core.FeatureSetProto.FeatureSpec; -import feast.storage.RedisProto.RedisKey; -import feast.storage.RedisProto.RedisKey.Builder; -import feast.store.serving.redis.RedisCustomIO.Method; -import feast.store.serving.redis.RedisCustomIO.RedisMutation; -import feast.types.FeatureRowProto.FeatureRow; -import feast.types.FieldProto.Field; -import feast.types.ValueProto; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; -import org.apache.beam.sdk.transforms.DoFn; -import org.slf4j.Logger; - -public class FeatureRowToRedisMutationDoFn extends DoFn<FeatureRow, RedisMutation> { - - private static final Logger log = - org.slf4j.LoggerFactory.getLogger(FeatureRowToRedisMutationDoFn.class); - private Map<String, FeatureSet> featureSets; - - public FeatureRowToRedisMutationDoFn(Map<String, FeatureSet> featureSets) { - this.featureSets = featureSets; - } - - private RedisKey getKey(FeatureRow featureRow) { - FeatureSet featureSet = featureSets.get(featureRow.getFeatureSet()); - List<String> entityNames = - featureSet.getSpec().getEntitiesList().stream() - .map(EntitySpec::getName) - .sorted() - .collect(Collectors.toList()); - - Map<String, Field> entityFields = new HashMap<>(); - Builder redisKeyBuilder = RedisKey.newBuilder().setFeatureSet(featureRow.getFeatureSet()); - for (Field field : featureRow.getFieldsList()) { - if (entityNames.contains(field.getName())) { - entityFields.putIfAbsent( - field.getName(), - Field.newBuilder().setName(field.getName()).setValue(field.getValue()).build()); - } - } - for (String entityName : entityNames) { - if (entityFields.containsKey(entityName)) { - redisKeyBuilder.addEntities(entityFields.get(entityName)); - } - } - return redisKeyBuilder.build(); - } - - private byte[] getValue(FeatureRow featureRow) { - FeatureSetSpec spec = featureSets.get(featureRow.getFeatureSet()).getSpec(); - - List<String> featureNames = - spec.getFeaturesList().stream().map(FeatureSpec::getName).collect(Collectors.toList()); - Map<String, Field> fieldValueOnlyMap = - featureRow.getFieldsList().stream() - .filter(field -> featureNames.contains(field.getName())) - .distinct() - .collect( - Collectors.toMap( - Field::getName, - field -> Field.newBuilder().setValue(field.getValue()).build())); - - List<Field> values = - featureNames.stream() - .sorted() - .map( - featureName -> - fieldValueOnlyMap.getOrDefault( - featureName, - Field.newBuilder().setValue(ValueProto.Value.getDefaultInstance()).build())) - .collect(Collectors.toList()); - - return FeatureRow.newBuilder() - .setEventTimestamp(featureRow.getEventTimestamp()) - .addAllFields(values) - .build() - .toByteArray(); - } - - /** Output a redis mutation object for every feature in the feature row. 
*/ - @ProcessElement - public void processElement(ProcessContext context) { - FeatureRow featureRow = context.element(); - try { - byte[] key = getKey(featureRow).toByteArray(); - byte[] value = getValue(featureRow); - RedisMutation redisMutation = new RedisMutation(Method.SET, key, value, null, null); - context.output(redisMutation); - } catch (Exception e) { - log.error(e.getMessage(), e); - } - } -} diff --git a/ingestion/src/main/java/feast/store/serving/redis/RedisCustomIO.java b/ingestion/src/main/java/feast/store/serving/redis/RedisCustomIO.java deleted file mode 100644 index 633c2eb551d..00000000000 --- a/ingestion/src/main/java/feast/store/serving/redis/RedisCustomIO.java +++ /dev/null @@ -1,341 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * Copyright 2018-2019 The Feast Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package feast.store.serving.redis; - -import feast.core.StoreProto; -import feast.ingestion.values.FailedElement; -import feast.retry.Retriable; -import io.lettuce.core.RedisConnectionException; -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.concurrent.ExecutionException; -import org.apache.avro.reflect.Nullable; -import org.apache.beam.sdk.coders.AvroCoder; -import org.apache.beam.sdk.coders.DefaultCoder; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.PTransform; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.windowing.GlobalWindow; -import org.apache.beam.sdk.values.PCollection; -import org.apache.commons.lang3.exception.ExceptionUtils; -import org.joda.time.Instant; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -public class RedisCustomIO { - - private static final int DEFAULT_BATCH_SIZE = 1000; - private static final int DEFAULT_TIMEOUT = 2000; - - private static final Logger log = LoggerFactory.getLogger(RedisCustomIO.class); - - private RedisCustomIO() {} - - public static Write write(StoreProto.Store store) { - return new Write(store); - } - - public enum Method { - - /** - * Use APPEND command. If key already exists and is a string, this command appends the value at - * the end of the string. - */ - APPEND, - - /** Use SET command. If key already holds a value, it is overwritten. */ - SET, - - /** - * Use LPUSH command. Insert value at the head of the list stored at key. If key does not exist, - * it is created as empty list before performing the push operations. When key holds a value - * that is not a list, an error is returned. - */ - LPUSH, - - /** - * Use RPUSH command. Insert value at the tail of the list stored at key. If key does not exist, - * it is created as empty list before performing the push operations. When key holds a value - * that is not a list, an error is returned. - */ - RPUSH, - - /** - * Use SADD command. Insert value into a set with a defined key. If key does not exist, it is - * created as empty set before performing the add operations. 
When key holds a value that is not - * a set, an error is returned. - */ - SADD, - - /** - * Use ZADD command. Adds all the specified members with the specified scores to the sorted set - * stored at key. It is possible to specify multiple score / member pairs. If a specified member - * is already a member of the sorted set, the score is updated and the element reinserted at the - * right position to ensure the correct ordering. - */ - ZADD - } - - @DefaultCoder(AvroCoder.class) - public static class RedisMutation { - - private Method method; - private byte[] key; - private byte[] value; - @Nullable private Long expiryMillis; - @Nullable private Long score; - - public RedisMutation() {} - - public RedisMutation( - Method method, - byte[] key, - byte[] value, - @Nullable Long expiryMillis, - @Nullable Long score) { - this.method = method; - this.key = key; - this.value = value; - this.expiryMillis = expiryMillis; - this.score = score; - } - - public Method getMethod() { - return method; - } - - public void setMethod(Method method) { - this.method = method; - } - - public byte[] getKey() { - return key; - } - - public void setKey(byte[] key) { - this.key = key; - } - - public byte[] getValue() { - return value; - } - - public void setValue(byte[] value) { - this.value = value; - } - - @Nullable - public Long getExpiryMillis() { - return expiryMillis; - } - - public void setExpiryMillis(@Nullable Long expiryMillis) { - this.expiryMillis = expiryMillis; - } - - @Nullable - public Long getScore() { - return score; - } - - public void setScore(@Nullable Long score) { - this.score = score; - } - } - - /** ServingStoreWrite data to a Redis server. */ - public static class Write - extends PTransform<PCollection<RedisMutation>, PCollection<FailedElement>> { - - private WriteDoFn dofn; - - private Write(StoreProto.Store store) { - this.dofn = new WriteDoFn(store); - } - - public Write withBatchSize(int batchSize) { - this.dofn.withBatchSize(batchSize); - return this; - } - - public Write withTimeout(int timeout) { - this.dofn.withTimeout(timeout); - return this; - } - - @Override - public PCollection<FailedElement> expand(PCollection<RedisMutation> input) { - return input.apply(ParDo.of(dofn)); - } - - public static class WriteDoFn extends DoFn<RedisMutation, FailedElement> { - - private final List<RedisMutation> mutations = new ArrayList<>(); - private int batchSize = DEFAULT_BATCH_SIZE; - private int timeout = DEFAULT_TIMEOUT; - private RedisIngestionClient redisIngestionClient; - - WriteDoFn(StoreProto.Store store) { - if (store.getType() == StoreProto.Store.StoreType.REDIS) - this.redisIngestionClient = new RedisStandaloneIngestionClient(store.getRedisConfig()); - } - - public WriteDoFn withBatchSize(int batchSize) { - if (batchSize > 0) { - this.batchSize = batchSize; - } - return this; - } - - public WriteDoFn withTimeout(int timeout) { - if (timeout > 0) { - this.timeout = timeout; - } - return this; - } - - @Setup - public void setup() { - this.redisIngestionClient.setup(); - } - - @StartBundle - public void startBundle() { - try { - redisIngestionClient.connect(); - } catch (RedisConnectionException e) { - log.error("Connection to redis cannot be established ", e); - } - mutations.clear(); - } - - private void executeBatch() throws Exception { - this.redisIngestionClient - .getBackOffExecutor() - .execute( - new Retriable() { - @Override - public void execute() throws ExecutionException, InterruptedException { - if (!redisIngestionClient.isConnected()) { - redisIngestionClient.connect(); - } - mutations.forEach( - mutation -> { - writeRecord(mutation); - if (mutation.getExpiryMillis() != null - && 
mutation.getExpiryMillis() > 0) { - redisIngestionClient.pexpire( - mutation.getKey(), mutation.getExpiryMillis()); - } - }); - redisIngestionClient.sync(); - mutations.clear(); - } - - @Override - public Boolean isExceptionRetriable(Exception e) { - return e instanceof RedisConnectionException; - } - - @Override - public void cleanUpAfterFailure() {} - }); - } - - private FailedElement toFailedElement( - RedisMutation mutation, Exception exception, String jobName) { - return FailedElement.newBuilder() - .setJobName(jobName) - .setTransformName("RedisCustomIO") - .setPayload(Arrays.toString(mutation.getValue())) - .setErrorMessage(exception.getMessage()) - .setStackTrace(ExceptionUtils.getStackTrace(exception)) - .build(); - } - - @ProcessElement - public void processElement(ProcessContext context) { - RedisMutation mutation = context.element(); - mutations.add(mutation); - if (mutations.size() >= batchSize) { - try { - executeBatch(); - } catch (Exception e) { - mutations.forEach( - failedMutation -> { - FailedElement failedElement = - toFailedElement(failedMutation, e, context.getPipelineOptions().getJobName()); - context.output(failedElement); - }); - mutations.clear(); - } - } - } - - private void writeRecord(RedisMutation mutation) { - switch (mutation.getMethod()) { - case APPEND: - redisIngestionClient.append(mutation.getKey(), mutation.getValue()); - return; - case SET: - redisIngestionClient.set(mutation.getKey(), mutation.getValue()); - return; - case LPUSH: - redisIngestionClient.lpush(mutation.getKey(), mutation.getValue()); - return; - case RPUSH: - redisIngestionClient.rpush(mutation.getKey(), mutation.getValue()); - return; - case SADD: - redisIngestionClient.sadd(mutation.getKey(), mutation.getValue()); - return; - case ZADD: - redisIngestionClient.zadd(mutation.getKey(), mutation.getScore(), mutation.getValue()); - return; - default: - throw new UnsupportedOperationException( - String.format("Not implemented writing records for %s", mutation.getMethod())); - } - } - - @FinishBundle - public void finishBundle(FinishBundleContext context) - throws IOException, InterruptedException { - if (mutations.size() > 0) { - try { - executeBatch(); - } catch (Exception e) { - mutations.forEach( - failedMutation -> { - FailedElement failedElement = - toFailedElement(failedMutation, e, context.getPipelineOptions().getJobName()); - context.output(failedElement, Instant.now(), GlobalWindow.INSTANCE); - }); - mutations.clear(); - } - } - } - - @Teardown - public void teardown() { - redisIngestionClient.shutdown(); - } - } - } -} diff --git a/ingestion/src/test/java/feast/ingestion/ImportJobTest.java b/ingestion/src/test/java/feast/ingestion/ImportJobTest.java index 0b000df0f59..13df73e96a4 100644 --- a/ingestion/src/test/java/feast/ingestion/ImportJobTest.java +++ b/ingestion/src/test/java/feast/ingestion/ImportJobTest.java @@ -188,8 +188,8 @@ public void runPipeline_ShouldWriteToRedisCorrectlyGivenValidSpecAndFeatureRow() IntStream.range(0, IMPORT_JOB_SAMPLE_FEATURE_ROW_SIZE) .forEach( i -> { - FeatureRow randomRow = TestUtil.createRandomFeatureRow(featureSet); - RedisKey redisKey = TestUtil.createRedisKey(featureSet, randomRow); + FeatureRow randomRow = TestUtil.createRandomFeatureRow(featureSet.getSpec()); + RedisKey redisKey = TestUtil.createRedisKey(featureSet.getSpec(), randomRow); input.add(randomRow); List<Field> fields = randomRow.getFieldsList().stream() diff --git a/ingestion/src/test/java/feast/ingestion/transform/ValidateFeatureRowsTest.java 
b/ingestion/src/test/java/feast/ingestion/transform/ValidateFeatureRowsTest.java index 5c9860ed97f..3737a736168 100644 --- a/ingestion/src/test/java/feast/ingestion/transform/ValidateFeatureRowsTest.java +++ b/ingestion/src/test/java/feast/ingestion/transform/ValidateFeatureRowsTest.java @@ -16,13 +16,10 @@ */ package feast.ingestion.transform; -import static org.junit.Assert.*; - import feast.core.FeatureSetProto.EntitySpec; -import feast.core.FeatureSetProto.FeatureSet; import feast.core.FeatureSetProto.FeatureSetSpec; import feast.core.FeatureSetProto.FeatureSpec; -import feast.ingestion.values.FailedElement; +import feast.storage.api.writer.FailedElement; import feast.test.TestUtil; import feast.types.FeatureRowProto.FeatureRow; import feast.types.FieldProto.Field; @@ -52,73 +49,57 @@ public class ValidateFeatureRowsTest { @Test public void shouldWriteSuccessAndFailureTagsCorrectly() { - FeatureSet fs1 = - FeatureSet.newBuilder() - .setSpec( - FeatureSetSpec.newBuilder() - .setName("feature_set") - .setVersion(1) - .setProject("myproject") - .addEntities( - EntitySpec.newBuilder() - .setName("entity_id_primary") - .setValueType(Enum.INT32) - .build()) - .addEntities( - EntitySpec.newBuilder() - .setName("entity_id_secondary") - .setValueType(Enum.STRING) - .build()) - .addFeatures( - FeatureSpec.newBuilder() - .setName("feature_1") - .setValueType(Enum.STRING) - .build()) - .addFeatures( - FeatureSpec.newBuilder() - .setName("feature_2") - .setValueType(Enum.INT64) - .build())) + FeatureSetSpec fs1 = + FeatureSetSpec.newBuilder() + .setName("feature_set") + .setVersion(1) + .setProject("myproject") + .addEntities( + EntitySpec.newBuilder() + .setName("entity_id_primary") + .setValueType(Enum.INT32) + .build()) + .addEntities( + EntitySpec.newBuilder() + .setName("entity_id_secondary") + .setValueType(Enum.STRING) + .build()) + .addFeatures( + FeatureSpec.newBuilder().setName("feature_1").setValueType(Enum.STRING).build()) + .addFeatures( + FeatureSpec.newBuilder().setName("feature_2").setValueType(Enum.INT64).build()) .build(); - FeatureSet fs2 = - FeatureSet.newBuilder() - .setSpec( - FeatureSetSpec.newBuilder() - .setName("feature_set") - .setVersion(2) - .setProject("myproject") - .addEntities( - EntitySpec.newBuilder() - .setName("entity_id_primary") - .setValueType(Enum.INT32) - .build()) - .addEntities( - EntitySpec.newBuilder() - .setName("entity_id_secondary") - .setValueType(Enum.STRING) - .build()) - .addFeatures( - FeatureSpec.newBuilder() - .setName("feature_1") - .setValueType(Enum.STRING) - .build()) - .addFeatures( - FeatureSpec.newBuilder() - .setName("feature_2") - .setValueType(Enum.INT64) - .build())) + FeatureSetSpec fs2 = + FeatureSetSpec.newBuilder() + .setName("feature_set") + .setVersion(2) + .setProject("myproject") + .addEntities( + EntitySpec.newBuilder() + .setName("entity_id_primary") + .setValueType(Enum.INT32) + .build()) + .addEntities( + EntitySpec.newBuilder() + .setName("entity_id_secondary") + .setValueType(Enum.STRING) + .build()) + .addFeatures( + FeatureSpec.newBuilder().setName("feature_1").setValueType(Enum.STRING).build()) + .addFeatures( + FeatureSpec.newBuilder().setName("feature_2").setValueType(Enum.INT64).build()) .build(); - Map featureSets = new HashMap<>(); - featureSets.put("myproject/feature_set:1", fs1); - featureSets.put("myproject/feature_set:2", fs2); + Map featureSetSpecs = new HashMap<>(); + featureSetSpecs.put("myproject/feature_set:1", fs1); + featureSetSpecs.put("myproject/feature_set:2", fs2); List input = new ArrayList<>(); 
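    // One random row is generated per spec version registered above; rows from both
    // versions of "feature_set" are expected to land in the success tag.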
List expected = new ArrayList<>(); - for (FeatureSet featureSet : featureSets.values()) { - FeatureRow randomRow = TestUtil.createRandomFeatureRow(featureSet); + for (FeatureSetSpec featureSetSpec : featureSetSpecs.values()) { + FeatureRow randomRow = TestUtil.createRandomFeatureRow(featureSetSpec); input.add(randomRow); expected.add(randomRow); } @@ -132,7 +113,7 @@ public void shouldWriteSuccessAndFailureTagsCorrectly() { ValidateFeatureRows.newBuilder() .setFailureTag(FAILURE_TAG) .setSuccessTag(SUCCESS_TAG) - .setFeatureSets(featureSets) + .setFeatureSetSpecs(featureSetSpecs) .build()); PAssert.that(output.get(SUCCESS_TAG)).containsInAnyOrder(expected); @@ -143,36 +124,28 @@ public void shouldWriteSuccessAndFailureTagsCorrectly() { @Test public void shouldExcludeUnregisteredFields() { - FeatureSet fs1 = - FeatureSet.newBuilder() - .setSpec( - FeatureSetSpec.newBuilder() - .setName("feature_set") - .setVersion(1) - .setProject("myproject") - .addEntities( - EntitySpec.newBuilder() - .setName("entity_id_primary") - .setValueType(Enum.INT32) - .build()) - .addEntities( - EntitySpec.newBuilder() - .setName("entity_id_secondary") - .setValueType(Enum.STRING) - .build()) - .addFeatures( - FeatureSpec.newBuilder() - .setName("feature_1") - .setValueType(Enum.STRING) - .build()) - .addFeatures( - FeatureSpec.newBuilder() - .setName("feature_2") - .setValueType(Enum.INT64) - .build())) + FeatureSetSpec fs1 = + FeatureSetSpec.newBuilder() + .setName("feature_set") + .setVersion(1) + .setProject("myproject") + .addEntities( + EntitySpec.newBuilder() + .setName("entity_id_primary") + .setValueType(Enum.INT32) + .build()) + .addEntities( + EntitySpec.newBuilder() + .setName("entity_id_secondary") + .setValueType(Enum.STRING) + .build()) + .addFeatures( + FeatureSpec.newBuilder().setName("feature_1").setValueType(Enum.STRING).build()) + .addFeatures( + FeatureSpec.newBuilder().setName("feature_2").setValueType(Enum.INT64).build()) .build(); - Map featureSets = new HashMap<>(); + Map featureSets = new HashMap<>(); featureSets.put("myproject/feature_set:1", fs1); List input = new ArrayList<>(); @@ -196,7 +169,7 @@ public void shouldExcludeUnregisteredFields() { ValidateFeatureRows.newBuilder() .setFailureTag(FAILURE_TAG) .setSuccessTag(SUCCESS_TAG) - .setFeatureSets(featureSets) + .setFeatureSetSpecs(featureSets) .build()); PAssert.that(output.get(SUCCESS_TAG)).containsInAnyOrder(expected); diff --git a/ingestion/src/test/java/feast/ingestion/utils/StoreUtilTest.java b/ingestion/src/test/java/feast/ingestion/utils/StoreUtilTest.java deleted file mode 100644 index 82988121bc8..00000000000 --- a/ingestion/src/test/java/feast/ingestion/utils/StoreUtilTest.java +++ /dev/null @@ -1,211 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * Copyright 2018-2019 The Feast Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package feast.ingestion.utils; - -import static feast.types.ValueProto.ValueType.Enum.*; - -import com.google.cloud.bigquery.BigQuery; -import com.google.cloud.bigquery.Field; -import com.google.cloud.bigquery.Field.Mode; -import com.google.cloud.bigquery.Schema; -import com.google.cloud.bigquery.StandardSQLTypeName; -import feast.core.FeatureSetProto.EntitySpec; -import feast.core.FeatureSetProto.FeatureSet; -import feast.core.FeatureSetProto.FeatureSetSpec; -import feast.core.FeatureSetProto.FeatureSpec; -import java.util.Arrays; -import org.junit.Assert; -import org.junit.Test; -import org.mockito.Mockito; - -public class StoreUtilTest { - - @Test - public void setupBigQuery_shouldCreateTable_givenValidFeatureSetSpec() { - FeatureSet featureSet = - FeatureSet.newBuilder() - .setSpec( - FeatureSetSpec.newBuilder() - .setName("feature_set_1") - .setVersion(1) - .setProject("feast-project") - .addEntities(EntitySpec.newBuilder().setName("entity_1").setValueType(INT32)) - .addFeatures(FeatureSpec.newBuilder().setName("feature_1").setValueType(INT32)) - .addFeatures( - FeatureSpec.newBuilder().setName("feature_2").setValueType(STRING_LIST))) - .build(); - BigQuery mockedBigquery = Mockito.mock(BigQuery.class); - StoreUtil.setupBigQuery(featureSet, "project-1", "dataset_1", mockedBigquery); - } - - @Test - public void createBigQueryTableDefinition_shouldCreateCorrectSchema_givenValidFeatureSetSpec() { - FeatureSetSpec input = - FeatureSetSpec.newBuilder() - .addAllEntities( - Arrays.asList( - EntitySpec.newBuilder().setName("bytes_entity").setValueType(BYTES).build(), - EntitySpec.newBuilder().setName("string_entity").setValueType(STRING).build(), - EntitySpec.newBuilder().setName("int32_entity").setValueType(INT32).build(), - EntitySpec.newBuilder().setName("int64_entity").setValueType(INT64).build(), - EntitySpec.newBuilder().setName("double_entity").setValueType(DOUBLE).build(), - EntitySpec.newBuilder().setName("float_entity").setValueType(FLOAT).build(), - EntitySpec.newBuilder().setName("bool_entity").setValueType(BOOL).build(), - EntitySpec.newBuilder() - .setName("bytes_list_entity") - .setValueType(BYTES_LIST) - .build(), - EntitySpec.newBuilder() - .setName("string_list_entity") - .setValueType(STRING_LIST) - .build(), - EntitySpec.newBuilder() - .setName("int32_list_entity") - .setValueType(INT32_LIST) - .build(), - EntitySpec.newBuilder() - .setName("int64_list_entity") - .setValueType(INT64_LIST) - .build(), - EntitySpec.newBuilder() - .setName("double_list_entity") - .setValueType(DOUBLE_LIST) - .build(), - EntitySpec.newBuilder() - .setName("float_list_entity") - .setValueType(FLOAT_LIST) - .build(), - EntitySpec.newBuilder() - .setName("bool_list_entity") - .setValueType(BOOL_LIST) - .build())) - .addAllFeatures( - Arrays.asList( - FeatureSpec.newBuilder().setName("bytes_feature").setValueType(BYTES).build(), - FeatureSpec.newBuilder().setName("string_feature").setValueType(STRING).build(), - FeatureSpec.newBuilder().setName("int32_feature").setValueType(INT32).build(), - FeatureSpec.newBuilder().setName("int64_feature").setValueType(INT64).build(), - FeatureSpec.newBuilder().setName("double_feature").setValueType(DOUBLE).build(), - FeatureSpec.newBuilder().setName("float_feature").setValueType(FLOAT).build(), - FeatureSpec.newBuilder().setName("bool_feature").setValueType(BOOL).build(), - FeatureSpec.newBuilder() - .setName("bytes_list_feature") - .setValueType(BYTES_LIST) - .build(), - FeatureSpec.newBuilder() - .setName("string_list_feature") - 
.setValueType(STRING_LIST) - .build(), - FeatureSpec.newBuilder() - .setName("int32_list_feature") - .setValueType(INT32_LIST) - .build(), - FeatureSpec.newBuilder() - .setName("int64_list_feature") - .setValueType(INT64_LIST) - .build(), - FeatureSpec.newBuilder() - .setName("double_list_feature") - .setValueType(DOUBLE_LIST) - .build(), - FeatureSpec.newBuilder() - .setName("float_list_feature") - .setValueType(FLOAT_LIST) - .build(), - FeatureSpec.newBuilder() - .setName("bool_list_feature") - .setValueType(BOOL_LIST) - .build())) - .build(); - - Schema actual = StoreUtil.createBigQueryTableDefinition(input).getSchema(); - - Schema expected = - Schema.of( - Arrays.asList( - // Fields from entity - Field.newBuilder("bytes_entity", StandardSQLTypeName.BYTES).build(), - Field.newBuilder("string_entity", StandardSQLTypeName.STRING).build(), - Field.newBuilder("int32_entity", StandardSQLTypeName.INT64).build(), - Field.newBuilder("int64_entity", StandardSQLTypeName.INT64).build(), - Field.newBuilder("double_entity", StandardSQLTypeName.FLOAT64).build(), - Field.newBuilder("float_entity", StandardSQLTypeName.FLOAT64).build(), - Field.newBuilder("bool_entity", StandardSQLTypeName.BOOL).build(), - Field.newBuilder("bytes_list_entity", StandardSQLTypeName.BYTES) - .setMode(Mode.REPEATED) - .build(), - Field.newBuilder("string_list_entity", StandardSQLTypeName.STRING) - .setMode(Mode.REPEATED) - .build(), - Field.newBuilder("int32_list_entity", StandardSQLTypeName.INT64) - .setMode(Mode.REPEATED) - .build(), - Field.newBuilder("int64_list_entity", StandardSQLTypeName.INT64) - .setMode(Mode.REPEATED) - .build(), - Field.newBuilder("double_list_entity", StandardSQLTypeName.FLOAT64) - .setMode(Mode.REPEATED) - .build(), - Field.newBuilder("float_list_entity", StandardSQLTypeName.FLOAT64) - .setMode(Mode.REPEATED) - .build(), - Field.newBuilder("bool_list_entity", StandardSQLTypeName.BOOL) - .setMode(Mode.REPEATED) - .build(), - // Fields from feature - Field.newBuilder("bytes_feature", StandardSQLTypeName.BYTES).build(), - Field.newBuilder("string_feature", StandardSQLTypeName.STRING).build(), - Field.newBuilder("int32_feature", StandardSQLTypeName.INT64).build(), - Field.newBuilder("int64_feature", StandardSQLTypeName.INT64).build(), - Field.newBuilder("double_feature", StandardSQLTypeName.FLOAT64).build(), - Field.newBuilder("float_feature", StandardSQLTypeName.FLOAT64).build(), - Field.newBuilder("bool_feature", StandardSQLTypeName.BOOL).build(), - Field.newBuilder("bytes_list_feature", StandardSQLTypeName.BYTES) - .setMode(Mode.REPEATED) - .build(), - Field.newBuilder("string_list_feature", StandardSQLTypeName.STRING) - .setMode(Mode.REPEATED) - .build(), - Field.newBuilder("int32_list_feature", StandardSQLTypeName.INT64) - .setMode(Mode.REPEATED) - .build(), - Field.newBuilder("int64_list_feature", StandardSQLTypeName.INT64) - .setMode(Mode.REPEATED) - .build(), - Field.newBuilder("double_list_feature", StandardSQLTypeName.FLOAT64) - .setMode(Mode.REPEATED) - .build(), - Field.newBuilder("float_list_feature", StandardSQLTypeName.FLOAT64) - .setMode(Mode.REPEATED) - .build(), - Field.newBuilder("bool_list_feature", StandardSQLTypeName.BOOL) - .setMode(Mode.REPEATED) - .build(), - // Reserved fields - Field.newBuilder("event_timestamp", StandardSQLTypeName.TIMESTAMP) - .setDescription(StoreUtil.BIGQUERY_EVENT_TIMESTAMP_FIELD_DESCRIPTION) - .build(), - Field.newBuilder("created_timestamp", StandardSQLTypeName.TIMESTAMP) - .setDescription(StoreUtil.BIGQUERY_CREATED_TIMESTAMP_FIELD_DESCRIPTION) - 
.build(), - Field.newBuilder("job_id", StandardSQLTypeName.STRING) - .setDescription(StoreUtil.BIGQUERY_JOB_ID_FIELD_DESCRIPTION) - .build())); - - Assert.assertEquals(expected, actual); - } -} diff --git a/ingestion/src/test/java/feast/store/serving/redis/RedisCustomIOTest.java b/ingestion/src/test/java/feast/store/serving/redis/RedisCustomIOTest.java deleted file mode 100644 index 75663d24a6a..00000000000 --- a/ingestion/src/test/java/feast/store/serving/redis/RedisCustomIOTest.java +++ /dev/null @@ -1,238 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * Copyright 2018-2019 The Feast Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package feast.store.serving.redis; - -import static feast.test.TestUtil.field; -import static org.hamcrest.CoreMatchers.equalTo; -import static org.hamcrest.MatcherAssert.assertThat; - -import feast.core.StoreProto; -import feast.storage.RedisProto.RedisKey; -import feast.store.serving.redis.RedisCustomIO.Method; -import feast.store.serving.redis.RedisCustomIO.RedisMutation; -import feast.types.FeatureRowProto.FeatureRow; -import feast.types.ValueProto.ValueType.Enum; -import io.lettuce.core.RedisClient; -import io.lettuce.core.RedisURI; -import io.lettuce.core.api.StatefulRedisConnection; -import io.lettuce.core.api.sync.RedisStringCommands; -import io.lettuce.core.codec.ByteArrayCodec; -import java.io.IOException; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.concurrent.ScheduledFuture; -import java.util.concurrent.ScheduledThreadPoolExecutor; -import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; -import org.apache.beam.sdk.testing.PAssert; -import org.apache.beam.sdk.testing.TestPipeline; -import org.apache.beam.sdk.transforms.Count; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.values.PCollection; -import org.junit.After; -import org.junit.Before; -import org.junit.Rule; -import org.junit.Test; -import redis.embedded.Redis; -import redis.embedded.RedisServer; - -public class RedisCustomIOTest { - @Rule public transient TestPipeline p = TestPipeline.create(); - - private static String REDIS_HOST = "localhost"; - private static int REDIS_PORT = 51234; - private Redis redis; - private RedisClient redisClient; - private RedisStringCommands sync; - - @Before - public void setUp() throws IOException { - redis = new RedisServer(REDIS_PORT); - redis.start(); - redisClient = - RedisClient.create(new RedisURI(REDIS_HOST, REDIS_PORT, java.time.Duration.ofMillis(2000))); - StatefulRedisConnection connection = redisClient.connect(new ByteArrayCodec()); - sync = connection.sync(); - } - - @After - public void teardown() { - redisClient.shutdown(); - redis.stop(); - } - - @Test - public void shouldWriteToRedis() { - StoreProto.Store.RedisConfig redisConfig = - StoreProto.Store.RedisConfig.newBuilder().setHost(REDIS_HOST).setPort(REDIS_PORT).build(); - HashMap kvs = new LinkedHashMap<>(); - kvs.put( - RedisKey.newBuilder() - 
.setFeatureSet("fs:1") - .addEntities(field("entity", 1, Enum.INT64)) - .build(), - FeatureRow.newBuilder() - .setFeatureSet("fs:1") - .addFields(field("entity", 1, Enum.INT64)) - .addFields(field("feature", "one", Enum.STRING)) - .build()); - kvs.put( - RedisKey.newBuilder() - .setFeatureSet("fs:1") - .addEntities(field("entity", 2, Enum.INT64)) - .build(), - FeatureRow.newBuilder() - .setFeatureSet("fs:1") - .addFields(field("entity", 2, Enum.INT64)) - .addFields(field("feature", "two", Enum.STRING)) - .build()); - - List featureRowWrites = - kvs.entrySet().stream() - .map( - kv -> - new RedisMutation( - Method.SET, - kv.getKey().toByteArray(), - kv.getValue().toByteArray(), - null, - null)) - .collect(Collectors.toList()); - - StoreProto.Store store = - StoreProto.Store.newBuilder() - .setRedisConfig(redisConfig) - .setType(StoreProto.Store.StoreType.REDIS) - .build(); - p.apply(Create.of(featureRowWrites)).apply(RedisCustomIO.write(store)); - p.run(); - - kvs.forEach( - (key, value) -> { - byte[] actual = sync.get(key.toByteArray()); - assertThat(actual, equalTo(value.toByteArray())); - }); - } - - @Test(timeout = 10000) - public void shouldRetryFailConnection() throws InterruptedException { - StoreProto.Store.RedisConfig redisConfig = - StoreProto.Store.RedisConfig.newBuilder() - .setHost(REDIS_HOST) - .setPort(REDIS_PORT) - .setMaxRetries(4) - .setInitialBackoffMs(2000) - .build(); - HashMap kvs = new LinkedHashMap<>(); - kvs.put( - RedisKey.newBuilder() - .setFeatureSet("fs:1") - .addEntities(field("entity", 1, Enum.INT64)) - .build(), - FeatureRow.newBuilder() - .setFeatureSet("fs:1") - .addFields(field("entity", 1, Enum.INT64)) - .addFields(field("feature", "one", Enum.STRING)) - .build()); - - List featureRowWrites = - kvs.entrySet().stream() - .map( - kv -> - new RedisMutation( - Method.SET, - kv.getKey().toByteArray(), - kv.getValue().toByteArray(), - null, - null)) - .collect(Collectors.toList()); - - StoreProto.Store store = - StoreProto.Store.newBuilder() - .setRedisConfig(redisConfig) - .setType(StoreProto.Store.StoreType.REDIS) - .build(); - PCollection failedElementCount = - p.apply(Create.of(featureRowWrites)) - .apply(RedisCustomIO.write(store)) - .apply(Count.globally()); - - redis.stop(); - final ScheduledThreadPoolExecutor redisRestartExecutor = new ScheduledThreadPoolExecutor(1); - ScheduledFuture scheduledRedisRestart = - redisRestartExecutor.schedule( - () -> { - redis.start(); - }, - 3, - TimeUnit.SECONDS); - - PAssert.that(failedElementCount).containsInAnyOrder(0L); - p.run(); - scheduledRedisRestart.cancel(true); - - kvs.forEach( - (key, value) -> { - byte[] actual = sync.get(key.toByteArray()); - assertThat(actual, equalTo(value.toByteArray())); - }); - } - - @Test - public void shouldProduceFailedElementIfRetryExceeded() { - StoreProto.Store.RedisConfig redisConfig = - StoreProto.Store.RedisConfig.newBuilder().setHost(REDIS_HOST).setPort(REDIS_PORT).build(); - HashMap kvs = new LinkedHashMap<>(); - kvs.put( - RedisKey.newBuilder() - .setFeatureSet("fs:1") - .addEntities(field("entity", 1, Enum.INT64)) - .build(), - FeatureRow.newBuilder() - .setFeatureSet("fs:1") - .addFields(field("entity", 1, Enum.INT64)) - .addFields(field("feature", "one", Enum.STRING)) - .build()); - - List featureRowWrites = - kvs.entrySet().stream() - .map( - kv -> - new RedisMutation( - Method.SET, - kv.getKey().toByteArray(), - kv.getValue().toByteArray(), - null, - null)) - .collect(Collectors.toList()); - - StoreProto.Store store = - StoreProto.Store.newBuilder() - 
.setRedisConfig(redisConfig) - .setType(StoreProto.Store.StoreType.REDIS) - .build(); - PCollection failedElementCount = - p.apply(Create.of(featureRowWrites)) - .apply(RedisCustomIO.write(store)) - .apply(Count.globally()); - - redis.stop(); - PAssert.that(failedElementCount).containsInAnyOrder(1L); - p.run(); - } -} diff --git a/ingestion/src/test/java/feast/test/TestUtil.java b/ingestion/src/test/java/feast/test/TestUtil.java index 3cad39e3ec5..2cd3242fb00 100644 --- a/ingestion/src/test/java/feast/test/TestUtil.java +++ b/ingestion/src/test/java/feast/test/TestUtil.java @@ -21,20 +21,13 @@ import com.google.protobuf.ByteString; import com.google.protobuf.util.Timestamps; import feast.core.FeatureSetProto.FeatureSet; -import feast.ingestion.transform.WriteToStore; +import feast.core.FeatureSetProto.FeatureSetSpec; +import feast.ingestion.transform.metrics.WriteSuccessMetricsTransform; import feast.storage.RedisProto.RedisKey; import feast.types.FeatureRowProto.FeatureRow; import feast.types.FeatureRowProto.FeatureRow.Builder; import feast.types.FieldProto.Field; -import feast.types.ValueProto.BoolList; -import feast.types.ValueProto.BytesList; -import feast.types.ValueProto.DoubleList; -import feast.types.ValueProto.FloatList; -import feast.types.ValueProto.Int32List; -import feast.types.ValueProto.Int64List; -import feast.types.ValueProto.StringList; -import feast.types.ValueProto.Value; -import feast.types.ValueProto.ValueType; +import feast.types.ValueProto.*; import java.io.IOException; import java.net.DatagramPacket; import java.net.DatagramSocket; @@ -174,12 +167,15 @@ public static void publishFeatureRowsToKafka( /** * Create a Feature Row with random value according to the FeatureSetSpec * - *

See {@link #createRandomFeatureRow(FeatureSet, int)} + *

See {@link #createRandomFeatureRow(FeatureSetSpec, int)} + * + * @param featureSetSpec {@link FeatureSetSpec} + * @return {@link FeatureRow} */ - public static FeatureRow createRandomFeatureRow(FeatureSet featureSet) { + public static FeatureRow createRandomFeatureRow(FeatureSetSpec featureSetSpec) { ThreadLocalRandom random = ThreadLocalRandom.current(); int randomStringSizeMaxSize = 12; - return createRandomFeatureRow(featureSet, random.nextInt(0, randomStringSizeMaxSize) + 4); + return createRandomFeatureRow(featureSetSpec, random.nextInt(0, randomStringSizeMaxSize) + 4); } /** @@ -188,18 +184,18 @@ public static FeatureRow createRandomFeatureRow(FeatureSet featureSet) { *

The Feature Row created contains fields according to the entities and features defined in * FeatureSet, matching the value type of the field, with randomized value for testing. * - * @param featureSet {@link FeatureSet} + * @param featureSetSpec {@link FeatureSetSpec} * @param randomStringSize number of characters for the generated random string * @return {@link FeatureRow} */ - public static FeatureRow createRandomFeatureRow(FeatureSet featureSet, int randomStringSize) { + public static FeatureRow createRandomFeatureRow( + FeatureSetSpec featureSetSpec, int randomStringSize) { Builder builder = FeatureRow.newBuilder() - .setFeatureSet(getFeatureSetReference(featureSet)) + .setFeatureSet(getFeatureSetReference(featureSetSpec)) .setEventTimestamp(Timestamps.fromMillis(System.currentTimeMillis())); - featureSet - .getSpec() + featureSetSpec .getEntitiesList() .forEach( field -> { @@ -210,8 +206,7 @@ public static FeatureRow createRandomFeatureRow(FeatureSet featureSet, int rando .build()); }); - featureSet - .getSpec() + featureSetSpec .getFeaturesList() .forEach( field -> { @@ -301,15 +296,14 @@ public static Value createRandomValue(ValueType.Enum type, int randomStringSize) *

The entities in the created {@link RedisKey} will contain the value with matching field name * in the {@link FeatureRow} * - * @param featureSet {@link FeatureSet} + * @param featureSetSpec {@link FeatureSetSpec} * @param row {@link FeatureSet} * @return {@link RedisKey} */ - public static RedisKey createRedisKey(FeatureSet featureSet, FeatureRow row) { + public static RedisKey createRedisKey(FeatureSetSpec featureSetSpec, FeatureRow row) { RedisKey.Builder builder = - RedisKey.newBuilder().setFeatureSet(getFeatureSetReference(featureSet)); - featureSet - .getSpec() + RedisKey.newBuilder().setFeatureSet(getFeatureSetReference(featureSetSpec)); + featureSetSpec .getEntitiesList() .forEach( entityField -> @@ -452,7 +446,9 @@ public static void waitUntilAllElementsAreWrittenToStore( } String writeToStoreMetric = - WriteToStore.METRIC_NAMESPACE + ":" + WriteToStore.ELEMENTS_WRITTEN_METRIC; + WriteSuccessMetricsTransform.METRIC_NAMESPACE + + ":" + + WriteSuccessMetricsTransform.ELEMENTS_WRITTEN_METRIC; long committed = 0; long maxSystemTimeMillis = System.currentTimeMillis() + maxWaitDuration.getMillis(); diff --git a/pom.xml b/pom.xml index 3abb0eb9ace..649ef01865b 100644 --- a/pom.xml +++ b/pom.xml @@ -29,6 +29,8 @@ datatypes/java + storage/api + storage/connectors ingestion core serving diff --git a/serving/pom.xml b/serving/pom.xml index 4cc02dc4510..1390bfdc80c 100644 --- a/serving/pom.xml +++ b/serving/pom.xml @@ -76,6 +76,24 @@ ${project.version} + + dev.feast + feast-storage-api + ${project.version} + + + + dev.feast + feast-storage-connector-redis + ${project.version} + + + + dev.feast + feast-storage-connector-bigquery + ${project.version} + + org.slf4j @@ -114,6 +132,7 @@ io.github.lognet grpc-spring-boot-starter + org.springframework.boot @@ -136,17 +155,6 @@ protobuf-java-util - - io.pebbletemplates - pebble - 3.1.0 - - - - io.lettuce - lettuce-core - - com.google.guava @@ -180,12 +188,14 @@ simpleclient 0.8.0 + io.prometheus simpleclient_hotspot 0.8.0 + io.prometheus @@ -198,17 +208,6 @@ 0.8.0 - - - com.google.cloud - google-cloud-bigquery - - - - com.google.cloud - google-cloud-storage - - com.google.auto.value auto-value-annotations diff --git a/serving/src/main/java/feast/serving/configuration/ServingServiceConfig.java b/serving/src/main/java/feast/serving/configuration/ServingServiceConfig.java index d0ea058baf4..28df853e224 100644 --- a/serving/src/main/java/feast/serving/configuration/ServingServiceConfig.java +++ b/serving/src/main/java/feast/serving/configuration/ServingServiceConfig.java @@ -25,12 +25,12 @@ import feast.core.StoreProto.Store.RedisConfig; import feast.core.StoreProto.Store.Subscription; import feast.serving.FeastProperties; -import feast.serving.service.BigQueryServingService; -import feast.serving.service.JobService; -import feast.serving.service.NoopJobService; -import feast.serving.service.RedisServingService; -import feast.serving.service.ServingService; +import feast.serving.service.*; import feast.serving.specs.CachedSpecService; +import feast.storage.api.retriever.HistoricalRetriever; +import feast.storage.api.retriever.OnlineRetriever; +import feast.storage.connectors.bigquery.retriever.BigQueryHistoricalRetriever; +import feast.storage.connectors.redis.retriever.RedisOnlineRetriever; import io.opentracing.Tracer; import java.util.Map; import org.slf4j.Logger; @@ -79,9 +79,9 @@ public ServingService servingService( switch (store.getType()) { case REDIS: - servingService = - new RedisServingService( - storeConfiguration.getServingRedisConnection(), 
specService, tracer); + OnlineRetriever redisRetriever = + new RedisOnlineRetriever(storeConfiguration.getServingRedisConnection()); + servingService = new OnlineServingService(redisRetriever, specService, tracer); break; case BIGQUERY: BigQueryConfig bqConfig = store.getBigqueryConfig(); @@ -104,17 +104,20 @@ public ServingService servingService( throw new IllegalArgumentException( "Unable to instantiate jobService for BigQuery store."); } - servingService = - new BigQueryServingService( - bigquery, - bqConfig.getProjectId(), - bqConfig.getDatasetId(), - specService, - jobService, - jobStagingLocation, - feastProperties.getJobs().getBigqueryInitialRetryDelaySecs(), - feastProperties.getJobs().getBigqueryTotalTimeoutSecs(), - storage); + + HistoricalRetriever bqRetriever = + BigQueryHistoricalRetriever.builder() + .setBigquery(bigquery) + .setDatasetId(bqConfig.getDatasetId()) + .setProjectId(bqConfig.getProjectId()) + .setJobStagingLocation(jobStagingLocation) + .setInitialRetryDelaySecs( + feastProperties.getJobs().getBigqueryInitialRetryDelaySecs()) + .setTotalTimeoutSecs(feastProperties.getJobs().getBigqueryTotalTimeoutSecs()) + .setStorage(storage) + .build(); + + servingService = new HistoricalServingService(bqRetriever, specService, jobService); break; case CASSANDRA: case UNRECOGNIZED: diff --git a/serving/src/main/java/feast/serving/service/BigQueryServingService.java b/serving/src/main/java/feast/serving/service/BigQueryServingService.java deleted file mode 100644 index 8e3b7ae53e4..00000000000 --- a/serving/src/main/java/feast/serving/service/BigQueryServingService.java +++ /dev/null @@ -1,282 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * Copyright 2018-2019 The Feast Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package feast.serving.service; - -import static feast.serving.store.bigquery.QueryTemplater.createEntityTableUUIDQuery; -import static feast.serving.store.bigquery.QueryTemplater.generateFullTableName; - -import com.google.cloud.RetryOption; -import com.google.cloud.bigquery.BigQuery; -import com.google.cloud.bigquery.BigQueryException; -import com.google.cloud.bigquery.Field; -import com.google.cloud.bigquery.FormatOptions; -import com.google.cloud.bigquery.Job; -import com.google.cloud.bigquery.JobInfo; -import com.google.cloud.bigquery.LoadJobConfiguration; -import com.google.cloud.bigquery.QueryJobConfiguration; -import com.google.cloud.bigquery.Schema; -import com.google.cloud.bigquery.Table; -import com.google.cloud.bigquery.TableId; -import com.google.cloud.bigquery.TableInfo; -import com.google.cloud.storage.Storage; -import feast.serving.ServingAPIProto; -import feast.serving.ServingAPIProto.DataFormat; -import feast.serving.ServingAPIProto.DatasetSource; -import feast.serving.ServingAPIProto.FeastServingType; -import feast.serving.ServingAPIProto.GetBatchFeaturesRequest; -import feast.serving.ServingAPIProto.GetBatchFeaturesResponse; -import feast.serving.ServingAPIProto.GetFeastServingInfoRequest; -import feast.serving.ServingAPIProto.GetFeastServingInfoResponse; -import feast.serving.ServingAPIProto.GetJobRequest; -import feast.serving.ServingAPIProto.GetJobResponse; -import feast.serving.ServingAPIProto.GetOnlineFeaturesRequest; -import feast.serving.ServingAPIProto.GetOnlineFeaturesResponse; -import feast.serving.ServingAPIProto.JobStatus; -import feast.serving.ServingAPIProto.JobType; -import feast.serving.specs.CachedSpecService; -import feast.serving.specs.FeatureSetRequest; -import feast.serving.store.bigquery.BatchRetrievalQueryRunnable; -import feast.serving.store.bigquery.QueryTemplater; -import feast.serving.store.bigquery.model.FeatureSetInfo; -import io.grpc.Status; -import java.util.List; -import java.util.Optional; -import java.util.UUID; -import java.util.stream.Collectors; -import org.slf4j.Logger; -import org.threeten.bp.Duration; - -public class BigQueryServingService implements ServingService { - - public static final long TEMP_TABLE_EXPIRY_DURATION_MS = Duration.ofDays(1).toMillis(); - private static final Logger log = org.slf4j.LoggerFactory.getLogger(BigQueryServingService.class); - - private final BigQuery bigquery; - private final String projectId; - private final String datasetId; - private final CachedSpecService specService; - private final JobService jobService; - private final String jobStagingLocation; - private final int initialRetryDelaySecs; - private final int totalTimeoutSecs; - private final Storage storage; - - public BigQueryServingService( - BigQuery bigquery, - String projectId, - String datasetId, - CachedSpecService specService, - JobService jobService, - String jobStagingLocation, - int initialRetryDelaySecs, - int totalTimeoutSecs, - Storage storage) { - this.bigquery = bigquery; - this.projectId = projectId; - this.datasetId = datasetId; - this.specService = specService; - this.jobService = jobService; - this.jobStagingLocation = jobStagingLocation; - this.initialRetryDelaySecs = initialRetryDelaySecs; - this.totalTimeoutSecs = totalTimeoutSecs; - this.storage = storage; - } - - /** {@inheritDoc} */ - @Override - public GetFeastServingInfoResponse getFeastServingInfo( - GetFeastServingInfoRequest getFeastServingInfoRequest) { - return GetFeastServingInfoResponse.newBuilder() - .setType(FeastServingType.FEAST_SERVING_TYPE_BATCH) 
- .setJobStagingLocation(jobStagingLocation) - .build(); - } - - /** {@inheritDoc} */ - @Override - public GetOnlineFeaturesResponse getOnlineFeatures(GetOnlineFeaturesRequest getFeaturesRequest) { - throw Status.UNIMPLEMENTED.withDescription("Method not implemented").asRuntimeException(); - } - - /** {@inheritDoc} */ - @Override - public GetBatchFeaturesResponse getBatchFeatures(GetBatchFeaturesRequest getFeaturesRequest) { - List featureSetRequests = - specService.getFeatureSets(getFeaturesRequest.getFeaturesList()); - - Table entityTable; - String entityTableName; - try { - entityTable = loadEntities(getFeaturesRequest.getDatasetSource()); - - TableId entityTableWithUUIDs = generateUUIDs(entityTable); - entityTableName = generateFullTableName(entityTableWithUUIDs); - } catch (Exception e) { - throw Status.INTERNAL - .withDescription("Unable to load entity dataset to Bigquery") - .asRuntimeException(); - } - - Schema entityTableSchema = entityTable.getDefinition().getSchema(); - List entityNames = - entityTableSchema.getFields().stream() - .map(Field::getName) - .filter(name -> !name.equals("event_timestamp")) - .collect(Collectors.toList()); - - List featureSetInfos = QueryTemplater.getFeatureSetInfos(featureSetRequests); - - String feastJobId = UUID.randomUUID().toString(); - ServingAPIProto.Job feastJob = - ServingAPIProto.Job.newBuilder() - .setId(feastJobId) - .setType(JobType.JOB_TYPE_DOWNLOAD) - .setStatus(JobStatus.JOB_STATUS_PENDING) - .build(); - jobService.upsert(feastJob); - - new Thread( - BatchRetrievalQueryRunnable.builder() - .setEntityTableName(entityTableName) - .setBigquery(bigquery) - .setStorage(storage) - .setJobService(jobService) - .setProjectId(projectId) - .setDatasetId(datasetId) - .setFeastJobId(feastJobId) - .setEntityTableColumnNames(entityNames) - .setFeatureSetInfos(featureSetInfos) - .setJobStagingLocation(jobStagingLocation) - .setInitialRetryDelaySecs(initialRetryDelaySecs) - .setTotalTimeoutSecs(totalTimeoutSecs) - .build()) - .start(); - - return GetBatchFeaturesResponse.newBuilder().setJob(feastJob).build(); - } - - /** {@inheritDoc} */ - @Override - public GetJobResponse getJob(GetJobRequest getJobRequest) { - Optional job = jobService.get(getJobRequest.getJob().getId()); - if (!job.isPresent()) { - throw Status.NOT_FOUND - .withDescription(String.format("Job not found: %s", getJobRequest.getJob().getId())) - .asRuntimeException(); - } - return GetJobResponse.newBuilder().setJob(job.get()).build(); - } - - private Table loadEntities(DatasetSource datasetSource) { - Table loadedEntityTable; - switch (datasetSource.getDatasetSourceCase()) { - case FILE_SOURCE: - try { - // Currently only AVRO format is supported - - if (datasetSource.getFileSource().getDataFormat() != DataFormat.DATA_FORMAT_AVRO) { - throw Status.INVALID_ARGUMENT - .withDescription("Invalid file format, only AVRO is supported.") - .asRuntimeException(); - } - - TableId tableId = TableId.of(projectId, datasetId, createTempTableName()); - log.info("Loading entity rows to: {}.{}.{}", projectId, datasetId, tableId.getTable()); - - LoadJobConfiguration loadJobConfiguration = - LoadJobConfiguration.of( - tableId, datasetSource.getFileSource().getFileUrisList(), FormatOptions.avro()); - loadJobConfiguration = - loadJobConfiguration.toBuilder().setUseAvroLogicalTypes(true).build(); - Job job = bigquery.create(JobInfo.of(loadJobConfiguration)); - waitForJob(job); - - TableInfo expiry = - bigquery - .getTable(tableId) - .toBuilder() - .setExpirationTime(System.currentTimeMillis() + 
TEMP_TABLE_EXPIRY_DURATION_MS) - .build(); - bigquery.update(expiry); - - loadedEntityTable = bigquery.getTable(tableId); - if (!loadedEntityTable.exists()) { - throw new RuntimeException( - "Unable to create entity dataset table, table already exists"); - } - return loadedEntityTable; - } catch (Exception e) { - log.error("Exception has occurred in loadEntities method: ", e); - throw Status.INTERNAL - .withDescription("Failed to load entity dataset into store: " + e.toString()) - .withCause(e) - .asRuntimeException(); - } - case DATASETSOURCE_NOT_SET: - default: - throw Status.INVALID_ARGUMENT - .withDescription("Data source must be set.") - .asRuntimeException(); - } - } - - private TableId generateUUIDs(Table loadedEntityTable) { - try { - String uuidQuery = - createEntityTableUUIDQuery(generateFullTableName(loadedEntityTable.getTableId())); - QueryJobConfiguration queryJobConfig = - QueryJobConfiguration.newBuilder(uuidQuery) - .setDestinationTable(TableId.of(projectId, datasetId, createTempTableName())) - .build(); - Job queryJob = bigquery.create(JobInfo.of(queryJobConfig)); - Job completedJob = waitForJob(queryJob); - TableInfo expiry = - bigquery - .getTable(queryJobConfig.getDestinationTable()) - .toBuilder() - .setExpirationTime(System.currentTimeMillis() + TEMP_TABLE_EXPIRY_DURATION_MS) - .build(); - bigquery.update(expiry); - queryJobConfig = completedJob.getConfiguration(); - return queryJobConfig.getDestinationTable(); - } catch (InterruptedException | BigQueryException e) { - throw Status.INTERNAL - .withDescription("Failed to load entity dataset into store") - .withCause(e) - .asRuntimeException(); - } - } - - private Job waitForJob(Job queryJob) throws InterruptedException { - Job completedJob = - queryJob.waitFor( - RetryOption.initialRetryDelay(Duration.ofSeconds(initialRetryDelaySecs)), - RetryOption.totalTimeout(Duration.ofSeconds(totalTimeoutSecs))); - if (completedJob == null) { - throw Status.INTERNAL.withDescription("Job no longer exists").asRuntimeException(); - } else if (completedJob.getStatus().getError() != null) { - throw Status.INTERNAL - .withDescription("Job failed: " + completedJob.getStatus().getError()) - .asRuntimeException(); - } - return completedJob; - } - - public static String createTempTableName() { - return "_" + UUID.randomUUID().toString().replace("-", ""); - } -} diff --git a/serving/src/main/java/feast/serving/service/HistoricalServingService.java b/serving/src/main/java/feast/serving/service/HistoricalServingService.java new file mode 100644 index 00000000000..cc6df1b6b5a --- /dev/null +++ b/serving/src/main/java/feast/serving/service/HistoricalServingService.java @@ -0,0 +1,119 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2019 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package feast.serving.service; + +import feast.serving.ServingAPIProto; +import feast.serving.ServingAPIProto.*; +import feast.serving.ServingAPIProto.Job.Builder; +import feast.serving.specs.CachedSpecService; +import feast.storage.api.retriever.FeatureSetRequest; +import feast.storage.api.retriever.HistoricalRetrievalResult; +import feast.storage.api.retriever.HistoricalRetriever; +import io.grpc.Status; +import java.util.List; +import java.util.Optional; +import java.util.UUID; +import org.slf4j.Logger; + +public class HistoricalServingService implements ServingService { + + private static final Logger log = + org.slf4j.LoggerFactory.getLogger(HistoricalServingService.class); + + private final HistoricalRetriever retriever; + private final CachedSpecService specService; + private final JobService jobService; + + public HistoricalServingService( + HistoricalRetriever retriever, CachedSpecService specService, JobService jobService) { + this.retriever = retriever; + this.specService = specService; + this.jobService = jobService; + } + + /** {@inheritDoc} */ + @Override + public GetFeastServingInfoResponse getFeastServingInfo( + GetFeastServingInfoRequest getFeastServingInfoRequest) { + return GetFeastServingInfoResponse.newBuilder() + .setType(FeastServingType.FEAST_SERVING_TYPE_BATCH) + .setJobStagingLocation(retriever.getStagingLocation()) + .build(); + } + + /** {@inheritDoc} */ + @Override + public GetOnlineFeaturesResponse getOnlineFeatures(GetOnlineFeaturesRequest getFeaturesRequest) { + throw Status.UNIMPLEMENTED.withDescription("Method not implemented").asRuntimeException(); + } + + /** {@inheritDoc} */ + @Override + public GetBatchFeaturesResponse getBatchFeatures(GetBatchFeaturesRequest getFeaturesRequest) { + List featureSetRequests = + specService.getFeatureSets(getFeaturesRequest.getFeaturesList()); + String retrievalId = UUID.randomUUID().toString(); + Job runningJob = + Job.newBuilder() + .setId(retrievalId) + .setType(JobType.JOB_TYPE_DOWNLOAD) + .setStatus(JobStatus.JOB_STATUS_RUNNING) + .build(); + jobService.upsert(runningJob); + Thread thread = + new Thread( + new Runnable() { + @Override + public void run() { + HistoricalRetrievalResult result = + retriever.getHistoricalFeatures( + retrievalId, getFeaturesRequest.getDatasetSource(), featureSetRequests); + jobService.upsert(resultToJob(result)); + } + }); + thread.start(); + + return GetBatchFeaturesResponse.newBuilder().setJob(runningJob).build(); + } + + /** {@inheritDoc} */ + @Override + public GetJobResponse getJob(GetJobRequest getJobRequest) { + Optional job = jobService.get(getJobRequest.getJob().getId()); + if (!job.isPresent()) { + throw Status.NOT_FOUND + .withDescription(String.format("Job not found: %s", getJobRequest.getJob().getId())) + .asRuntimeException(); + } + return GetJobResponse.newBuilder().setJob(job.get()).build(); + } + + private Job resultToJob(HistoricalRetrievalResult result) { + Builder builder = + Job.newBuilder() + .setId(result.getId()) + .setType(JobType.JOB_TYPE_DOWNLOAD) + .setStatus(result.getStatus()); + if (result.hasError()) { + return builder.setError(result.getError()).build(); + } + return builder + .addAllFileUris(result.getFileUris()) + .setDataFormat(result.getDataFormat()) + .build(); + } +} diff --git a/serving/src/main/java/feast/serving/service/OnlineServingService.java b/serving/src/main/java/feast/serving/service/OnlineServingService.java new file mode 100644 index 00000000000..30addd2b9f2 --- /dev/null +++ 
b/serving/src/main/java/feast/serving/service/OnlineServingService.java
@@ -0,0 +1,176 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2019 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.serving.service;
+
+import com.google.common.collect.Maps;
+import com.google.protobuf.Duration;
+import feast.serving.ServingAPIProto.*;
+import feast.serving.ServingAPIProto.GetOnlineFeaturesRequest.EntityRow;
+import feast.serving.ServingAPIProto.GetOnlineFeaturesResponse.FieldValues;
+import feast.serving.specs.CachedSpecService;
+import feast.serving.util.Metrics;
+import feast.serving.util.RefUtil;
+import feast.storage.api.retriever.FeatureSetRequest;
+import feast.storage.api.retriever.OnlineRetriever;
+import feast.types.FeatureRowProto.FeatureRow;
+import feast.types.ValueProto.Value;
+import io.grpc.Status;
+import io.opentracing.Scope;
+import io.opentracing.Tracer;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+import org.slf4j.Logger;
+
+public class OnlineServingService implements ServingService {
+
+  private static final Logger log = org.slf4j.LoggerFactory.getLogger(OnlineServingService.class);
+  private final CachedSpecService specService;
+  private final Tracer tracer;
+  private final OnlineRetriever retriever;
+
+  public OnlineServingService(
+      OnlineRetriever retriever, CachedSpecService specService, Tracer tracer) {
+    this.retriever = retriever;
+    this.specService = specService;
+    this.tracer = tracer;
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  public GetFeastServingInfoResponse getFeastServingInfo(
+      GetFeastServingInfoRequest getFeastServingInfoRequest) {
+    return GetFeastServingInfoResponse.newBuilder()
+        .setType(FeastServingType.FEAST_SERVING_TYPE_ONLINE)
+        .build();
+  }
+
+  /** {@inheritDoc} */
+  @Override
+  public GetOnlineFeaturesResponse getOnlineFeatures(GetOnlineFeaturesRequest request) {
+    try (Scope scope = tracer.buildSpan("getOnlineFeatures").startActive(true)) {
+      GetOnlineFeaturesResponse.Builder getOnlineFeaturesResponseBuilder =
+          GetOnlineFeaturesResponse.newBuilder();
+      List<FeatureSetRequest> featureSetRequests =
+          specService.getFeatureSets(request.getFeaturesList());
+      List<EntityRow> entityRows = request.getEntityRowsList();
+      Map<EntityRow, Map<String, Value>> featureValuesMap =
+          entityRows.stream()
+              .collect(Collectors.toMap(row -> row, row -> Maps.newHashMap(row.getFieldsMap())));
+      // Get all feature rows from the retriever. Each feature row list corresponds to a single
+      // feature set request.
+      List<List<FeatureRow>> featureRows =
+          retriever.getOnlineFeatures(entityRows, featureSetRequests);
+
+      // For each feature set request, read the feature rows returned by the retriever, and
+      // populate the featureValuesMap with the feature values corresponding to that entity row.
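+      // Note that this pairing is positional: featureRows.get(fsIdx) must line up with
+      // featureSetRequests.get(fsIdx), and featureRowsForFs.get(entityRowIdx) with
+      // entityRows.get(entityRowIdx), so the retriever is expected to return results in
+      // request order.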
+      for (var fsIdx = 0; fsIdx < featureRows.size(); fsIdx++) {
+        List<FeatureRow> featureRowsForFs = featureRows.get(fsIdx);
+        FeatureSetRequest featureSetRequest = featureSetRequests.get(fsIdx);
+
+        String project = featureSetRequest.getSpec().getProject();
+
+        // In order to return values containing the same feature references provided by the user,
+        // we reuse the feature references in the request as the keys in the featureValuesMap
+        Map<String, FeatureReference> refsByName = featureSetRequest.getFeatureRefsByName();
+
+        // Each feature row returned (per feature set request) corresponds to a given entity row.
+        // For each feature row, update the featureValuesMap.
+        for (var entityRowIdx = 0; entityRowIdx < entityRows.size(); entityRowIdx++) {
+          FeatureRow featureRow = featureRowsForFs.get(entityRowIdx);
+          EntityRow entityRow = entityRows.get(entityRowIdx);
+
+          // If the row is stale, put an empty value into the featureValuesMap.
+          if (isStale(featureSetRequest, entityRow, featureRow)) {
+            featureSetRequest
+                .getFeatureReferences()
+                .parallelStream()
+                .forEach(
+                    ref -> {
+                      populateStaleKeyCountMetrics(project, ref);
+                      featureValuesMap
+                          .get(entityRow)
+                          .put(RefUtil.generateFeatureStringRef(ref), Value.newBuilder().build());
+                    });
+
+          } else {
+            populateRequestCountMetrics(featureSetRequest);
+
+            // Populate the featureValuesMap at this entityRow with the values in the feature
+            // row.
+            featureRow.getFieldsList().stream()
+                .filter(field -> refsByName.containsKey(field.getName()))
+                .forEach(
+                    field -> {
+                      FeatureReference ref = refsByName.get(field.getName());
+                      String id = RefUtil.generateFeatureStringRef(ref);
+                      featureValuesMap.get(entityRow).put(id, field.getValue());
+                    });
+          }
+        }
+      }
+
+      List<FieldValues> fieldValues =
+          featureValuesMap.values().stream()
+              .map(valueMap -> FieldValues.newBuilder().putAllFields(valueMap).build())
+              .collect(Collectors.toList());
+      return getOnlineFeaturesResponseBuilder.addAllFieldValues(fieldValues).build();
+    }
+  }
+
+  private void populateStaleKeyCountMetrics(String project, FeatureReference ref) {
+    Metrics.staleKeyCount
+        .labels(project, RefUtil.generateFeatureStringRefWithoutProject(ref))
+        .inc();
+  }
+
+  private void populateRequestCountMetrics(FeatureSetRequest featureSetRequest) {
+    String project = featureSetRequest.getSpec().getProject();
+    featureSetRequest
+        .getFeatureReferences()
+        .parallelStream()
+        .forEach(
+            ref ->
+                Metrics.requestCount
+                    .labels(project, RefUtil.generateFeatureStringRefWithoutProject(ref))
+                    .inc());
+  }
+
+  @Override
+  public GetBatchFeaturesResponse getBatchFeatures(GetBatchFeaturesRequest getFeaturesRequest) {
+    throw Status.UNIMPLEMENTED.withDescription("Method not implemented").asRuntimeException();
+  }
+
+  @Override
+  public GetJobResponse getJob(GetJobRequest getJobRequest) {
+    throw Status.UNIMPLEMENTED.withDescription("Method not implemented").asRuntimeException();
+  }
+
+  private boolean isStale(
+      FeatureSetRequest featureSetRequest, EntityRow entityRow, FeatureRow featureRow) {
+    Duration maxAge = featureSetRequest.getSpec().getMaxAge();
+    if (maxAge.equals(Duration.getDefaultInstance())) {
+      return false;
+    }
+    long givenTimestamp = entityRow.getEntityTimestamp().getSeconds();
+    if (givenTimestamp == 0) {
+      givenTimestamp = System.currentTimeMillis() / 1000;
+    }
+    long timeDifference = givenTimestamp - featureRow.getEventTimestamp().getSeconds();
+    return timeDifference > maxAge.getSeconds();
+  }
+}
diff --git a/serving/src/main/java/feast/serving/service/RedisServingService.java
b/serving/src/main/java/feast/serving/service/RedisServingService.java deleted file mode 100644 index 78d9d9cebe4..00000000000 --- a/serving/src/main/java/feast/serving/service/RedisServingService.java +++ /dev/null @@ -1,345 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * Copyright 2018-2019 The Feast Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package feast.serving.service; - -import static feast.serving.util.Metrics.invalidEncodingCount; -import static feast.serving.util.Metrics.missingKeyCount; -import static feast.serving.util.Metrics.requestCount; -import static feast.serving.util.Metrics.requestLatency; -import static feast.serving.util.Metrics.staleKeyCount; -import static feast.serving.util.RefUtil.generateFeatureSetStringRef; -import static feast.serving.util.RefUtil.generateFeatureStringRef; - -import com.google.common.collect.Maps; -import com.google.protobuf.AbstractMessageLite; -import com.google.protobuf.Duration; -import com.google.protobuf.InvalidProtocolBufferException; -import feast.core.FeatureSetProto.EntitySpec; -import feast.core.FeatureSetProto.FeatureSetSpec; -import feast.serving.ServingAPIProto.FeastServingType; -import feast.serving.ServingAPIProto.FeatureReference; -import feast.serving.ServingAPIProto.GetBatchFeaturesRequest; -import feast.serving.ServingAPIProto.GetBatchFeaturesResponse; -import feast.serving.ServingAPIProto.GetFeastServingInfoRequest; -import feast.serving.ServingAPIProto.GetFeastServingInfoResponse; -import feast.serving.ServingAPIProto.GetJobRequest; -import feast.serving.ServingAPIProto.GetJobResponse; -import feast.serving.ServingAPIProto.GetOnlineFeaturesRequest; -import feast.serving.ServingAPIProto.GetOnlineFeaturesRequest.EntityRow; -import feast.serving.ServingAPIProto.GetOnlineFeaturesResponse; -import feast.serving.ServingAPIProto.GetOnlineFeaturesResponse.FieldValues; -import feast.serving.encoding.FeatureRowDecoder; -import feast.serving.specs.CachedSpecService; -import feast.serving.specs.FeatureSetRequest; -import feast.serving.util.RefUtil; -import feast.storage.RedisProto.RedisKey; -import feast.types.FeatureRowProto.FeatureRow; -import feast.types.FieldProto.Field; -import feast.types.ValueProto.Value; -import io.grpc.Status; -import io.lettuce.core.api.StatefulRedisConnection; -import io.lettuce.core.api.sync.RedisCommands; -import io.opentracing.Scope; -import io.opentracing.Tracer; -import java.util.List; -import java.util.Map; -import java.util.concurrent.ExecutionException; -import java.util.stream.Collectors; -import org.slf4j.Logger; - -public class RedisServingService implements ServingService { - - private static final Logger log = org.slf4j.LoggerFactory.getLogger(RedisServingService.class); - private final CachedSpecService specService; - private final Tracer tracer; - private final RedisCommands syncCommands; - - public RedisServingService( - StatefulRedisConnection connection, - CachedSpecService specService, - Tracer tracer) { - this.syncCommands = connection.sync(); - 
this.specService = specService; - this.tracer = tracer; - } - - /** {@inheritDoc} */ - @Override - public GetFeastServingInfoResponse getFeastServingInfo( - GetFeastServingInfoRequest getFeastServingInfoRequest) { - return GetFeastServingInfoResponse.newBuilder() - .setType(FeastServingType.FEAST_SERVING_TYPE_ONLINE) - .build(); - } - - /** {@inheritDoc} */ - @Override - public GetOnlineFeaturesResponse getOnlineFeatures(GetOnlineFeaturesRequest request) { - try (Scope scope = tracer.buildSpan("Redis-getOnlineFeatures").startActive(true)) { - GetOnlineFeaturesResponse.Builder getOnlineFeaturesResponseBuilder = - GetOnlineFeaturesResponse.newBuilder(); - - List entityRows = request.getEntityRowsList(); - Map> featureValuesMap = - entityRows.stream() - .collect(Collectors.toMap(row -> row, row -> Maps.newHashMap(row.getFieldsMap()))); - List featureSetRequests = - specService.getFeatureSets(request.getFeaturesList()); - for (FeatureSetRequest featureSetRequest : featureSetRequests) { - - List featureSetEntityNames = - featureSetRequest.getSpec().getEntitiesList().stream() - .map(EntitySpec::getName) - .collect(Collectors.toList()); - - List redisKeys = - getRedisKeys(featureSetEntityNames, entityRows, featureSetRequest.getSpec()); - - try { - sendAndProcessMultiGet(redisKeys, entityRows, featureValuesMap, featureSetRequest); - } catch (InvalidProtocolBufferException | ExecutionException e) { - throw Status.INTERNAL - .withDescription("Unable to parse protobuf while retrieving feature") - .withCause(e) - .asRuntimeException(); - } - } - List fieldValues = - featureValuesMap.values().stream() - .map(valueMap -> FieldValues.newBuilder().putAllFields(valueMap).build()) - .collect(Collectors.toList()); - return getOnlineFeaturesResponseBuilder.addAllFieldValues(fieldValues).build(); - } - } - - @Override - public GetBatchFeaturesResponse getBatchFeatures(GetBatchFeaturesRequest getFeaturesRequest) { - throw Status.UNIMPLEMENTED.withDescription("Method not implemented").asRuntimeException(); - } - - @Override - public GetJobResponse getJob(GetJobRequest getJobRequest) { - throw Status.UNIMPLEMENTED.withDescription("Method not implemented").asRuntimeException(); - } - - /** - * Build the redis keys for retrieval from the store. - * - * @param featureSetEntityNames entity names that actually belong to the featureSet - * @param entityRows entity values to retrieve for - * @param featureSetSpec featureSetSpec of the features to retrieve - * @return list of RedisKeys - */ - private List getRedisKeys( - List featureSetEntityNames, - List entityRows, - FeatureSetSpec featureSetSpec) { - try (Scope scope = tracer.buildSpan("Redis-makeRedisKeys").startActive(true)) { - String featureSetRef = generateFeatureSetStringRef(featureSetSpec); - List redisKeys = - entityRows.stream() - .map(row -> makeRedisKey(featureSetRef, featureSetEntityNames, row)) - .collect(Collectors.toList()); - return redisKeys; - } - } - - /** - * Create {@link RedisKey} - * - * @param featureSet featureSet reference of the feature. E.g. 
feature_set_1:1 - * @param featureSetEntityNames entity names that belong to the featureSet - * @param entityRow entityRow to build the key from - * @return {@link RedisKey} - */ - private RedisKey makeRedisKey( - String featureSet, List featureSetEntityNames, EntityRow entityRow) { - RedisKey.Builder builder = RedisKey.newBuilder().setFeatureSet(featureSet); - Map fieldsMap = entityRow.getFieldsMap(); - featureSetEntityNames.sort(String::compareTo); - for (int i = 0; i < featureSetEntityNames.size(); i++) { - String entityName = featureSetEntityNames.get(i); - - if (!fieldsMap.containsKey(entityName)) { - throw Status.INVALID_ARGUMENT - .withDescription( - String.format( - "Entity row fields \"%s\" does not contain required entity field \"%s\"", - fieldsMap.keySet().toString(), entityName)) - .asRuntimeException(); - } - - builder.addEntities( - Field.newBuilder().setName(entityName).setValue(fieldsMap.get(entityName))); - } - return builder.build(); - } - - private void sendAndProcessMultiGet( - List redisKeys, - List entityRows, - Map> featureValuesMap, - FeatureSetRequest featureSetRequest) - throws InvalidProtocolBufferException, ExecutionException { - - List values = sendMultiGet(redisKeys); - long startTime = System.currentTimeMillis(); - try (Scope scope = tracer.buildSpan("Redis-processResponse").startActive(true)) { - FeatureSetSpec spec = featureSetRequest.getSpec(); - - Map nullValues = - featureSetRequest.getFeatureReferences().stream() - .collect( - Collectors.toMap( - RefUtil::generateFeatureStringRef, - featureReference -> Value.newBuilder().build())); - - for (int i = 0; i < values.size(); i++) { - EntityRow entityRow = entityRows.get(i); - Map featureValues = featureValuesMap.get(entityRow); - - byte[] value = values.get(i); - if (value == null) { - featureSetRequest - .getFeatureReferences() - .parallelStream() - .forEach( - request -> - missingKeyCount - .labels( - spec.getProject(), - String.format("%s:%d", request.getName(), request.getVersion())) - .inc()); - featureValues.putAll(nullValues); - continue; - } - - FeatureRow featureRow = FeatureRow.parseFrom(value); - String featureSetRef = redisKeys.get(i).getFeatureSet(); - FeatureRowDecoder decoder = - new FeatureRowDecoder(featureSetRef, specService.getFeatureSetSpec(featureSetRef)); - if (decoder.isEncoded(featureRow)) { - if (decoder.isEncodingValid(featureRow)) { - featureRow = decoder.decode(featureRow); - } else { - featureSetRequest - .getFeatureReferences() - .parallelStream() - .forEach( - request -> - invalidEncodingCount - .labels( - spec.getProject(), - String.format("%s:%d", request.getName(), request.getVersion())) - .inc()); - featureValues.putAll(nullValues); - continue; - } - } - - boolean stale = isStale(featureSetRequest, entityRow, featureRow); - if (stale) { - featureSetRequest - .getFeatureReferences() - .parallelStream() - .forEach( - request -> - staleKeyCount - .labels( - spec.getProject(), - String.format("%s:%d", request.getName(), request.getVersion())) - .inc()); - featureValues.putAll(nullValues); - continue; - } - - featureSetRequest - .getFeatureReferences() - .parallelStream() - .forEach( - request -> - requestCount - .labels( - spec.getProject(), - String.format("%s:%d", request.getName(), request.getVersion())) - .inc()); - - Map featureNames = - featureSetRequest.getFeatureReferences().stream() - .collect( - Collectors.toMap( - FeatureReference::getName, featureReference -> featureReference)); - featureRow.getFieldsList().stream() - .filter(field -> 
featureNames.keySet().contains(field.getName())) - .forEach( - field -> { - FeatureReference ref = featureNames.get(field.getName()); - String id = generateFeatureStringRef(ref); - featureValues.put(id, field.getValue()); - }); - } - } finally { - requestLatency - .labels("processResponse") - .observe((System.currentTimeMillis() - startTime) / 1000); - } - } - - private boolean isStale( - FeatureSetRequest featureSetRequest, EntityRow entityRow, FeatureRow featureRow) { - if (featureSetRequest.getSpec().getMaxAge().equals(Duration.getDefaultInstance())) { - return false; - } - long givenTimestamp = entityRow.getEntityTimestamp().getSeconds(); - if (givenTimestamp == 0) { - givenTimestamp = System.currentTimeMillis() / 1000; - } - long timeDifference = givenTimestamp - featureRow.getEventTimestamp().getSeconds(); - return timeDifference > featureSetRequest.getSpec().getMaxAge().getSeconds(); - } - - /** - * Send a list of get request as an mget - * - * @param keys list of {@link RedisKey} - * @return list of {@link FeatureRow} in primitive byte representation for each {@link RedisKey} - */ - private List sendMultiGet(List keys) { - try (Scope scope = tracer.buildSpan("Redis-sendMultiGet").startActive(true)) { - long startTime = System.currentTimeMillis(); - try { - byte[][] binaryKeys = - keys.stream() - .map(AbstractMessageLite::toByteArray) - .collect(Collectors.toList()) - .toArray(new byte[0][0]); - return syncCommands.mget(binaryKeys).stream() - .map(keyValue -> keyValue.getValueOrElse(null)) - .collect(Collectors.toList()); - } catch (Exception e) { - throw Status.NOT_FOUND - .withDescription("Unable to retrieve feature from Redis") - .withCause(e) - .asRuntimeException(); - } finally { - requestLatency - .labels("sendMultiGet") - .observe((System.currentTimeMillis() - startTime) / 1000d); - } - } - } -} diff --git a/serving/src/main/java/feast/serving/service/ServingService.java b/serving/src/main/java/feast/serving/service/ServingService.java index 83adcb73ba8..5e662229eeb 100644 --- a/serving/src/main/java/feast/serving/service/ServingService.java +++ b/serving/src/main/java/feast/serving/service/ServingService.java @@ -26,12 +26,75 @@ import feast.serving.ServingAPIProto.GetOnlineFeaturesResponse; public interface ServingService { + /** + * Get information about the Feast serving deployment. + * + *
<p>
For Bigquery deployments, this includes the default job staging location to load + * intermediate files to. Otherwise, this method only returns the current Feast Serving backing + * store type. + * + * @param getFeastServingInfoRequest {@link GetFeastServingInfoRequest} + * @return {@link GetFeastServingInfoResponse} + */ GetFeastServingInfoResponse getFeastServingInfo( GetFeastServingInfoRequest getFeastServingInfoRequest); + /** + * Get features from an online serving store, given a list of {@link + * feast.serving.ServingAPIProto.FeatureReference}s to retrieve, and list of {@link + * feast.serving.ServingAPIProto.GetOnlineFeaturesRequest.EntityRow}s to join the retrieved values + * to. + * + *
<p>
Features can be queried across feature sets, but each {@link + * feast.serving.ServingAPIProto.GetOnlineFeaturesRequest.EntityRow} must contain all entities for + * all feature sets included in the request. + * + *
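<p>A minimal request sketch (illustrative only; the project, feature, and entity names below are
 + * assumed for the example and are not part of this change):
 + *
 + * <pre>{@code
 + * GetOnlineFeaturesRequest request =
 + *     GetOnlineFeaturesRequest.newBuilder()
 + *         .addFeatures(FeatureReference.newBuilder().setProject("project").setName("feature1"))
 + *         .addEntityRows(
 + *             EntityRow.newBuilder()
 + *                 .putFields("entity1", Value.newBuilder().setInt64Val(1).build()))
 + *         .build();
 + * GetOnlineFeaturesResponse response = servingService.getOnlineFeatures(request);
 + * }</pre>
 + *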
<p>
This request is fulfilled synchronously. + * + * @param getFeaturesRequest {@link GetOnlineFeaturesRequest} containing list of {@link + * feast.serving.ServingAPIProto.FeatureReference}s to retrieve and list of {@link + * feast.serving.ServingAPIProto.GetOnlineFeaturesRequest.EntityRow}s to join the retrieved + * values to. + * @return {@link GetOnlineFeaturesResponse} with list of {@link + * feast.serving.ServingAPIProto.GetOnlineFeaturesResponse.FieldValues} for each {@link + * feast.serving.ServingAPIProto.GetOnlineFeaturesRequest.EntityRow} supplied. + */ GetOnlineFeaturesResponse getOnlineFeatures(GetOnlineFeaturesRequest getFeaturesRequest); + /** + * Get features from a batch serving store, given a list of {@link + * feast.serving.ServingAPIProto.FeatureReference}s to retrieve, and {@link + * feast.serving.ServingAPIProto.DatasetSource} pointing to remote location of dataset to join + * retrieved features to. All columns in the provided dataset will be preserved in the output + * dataset. + * + *
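<p>A request sketch (illustrative only; the {@code setDatasetSource} builder field name is an
 + * assumption based on the {@code DatasetSource} described above):
 + *
 + * <pre>{@code
 + * GetBatchFeaturesRequest request =
 + *     GetBatchFeaturesRequest.newBuilder()
 + *         .addFeatures(FeatureReference.newBuilder().setProject("project").setName("feature1"))
 + *         .setDatasetSource(datasetSource) // remote entity dataset to join features to
 + *         .build();
 + * GetBatchFeaturesResponse response = servingService.getBatchFeatures(request);
 + * }</pre>
 + *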
<p>
Due to the potential size of batch retrieval requests, this request is fulfilled + * asynchronously, and returns a retrieval job id, which when supplied to {@link + * #getJob(GetJobRequest)} will return the status of the retrieval job. + * + * @param getFeaturesRequest {@link GetBatchFeaturesRequest} containing a list of {@link + * feast.serving.ServingAPIProto.FeatureReference}s to retrieve, and {@link + * feast.serving.ServingAPIProto.DatasetSource} pointing to remote location of dataset to join + * retrieved features to. + * @return {@link GetBatchFeaturesResponse} containing reference to a retrieval {@link + * feast.serving.ServingAPIProto.Job}. + */ GetBatchFeaturesResponse getBatchFeatures(GetBatchFeaturesRequest getFeaturesRequest); + /** + * Get the status of a retrieval job from a batch serving store. + * + *
<p>
The client should check the status of the returned job periodically by calling {@link + * #getJob(GetJobRequest)} to determine if the job has completed successfully or with an error. If + * the job completes successfully, i.e. status = JOB_STATUS_DONE with no error, then the client can + * check the file_uris for the location to download feature values data. The client is assumed to + * have access to these file URIs. + * + *
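<p>A polling sketch (illustrative only; the {@code setJob} field name is assumed, and
 + * retry/backoff and exception handling are elided):
 + *
 + * <pre>{@code
 + * // job: the retrieval Job reference returned by getBatchFeatures
 + * GetJobResponse jobResponse;
 + * do {
 + *   Thread.sleep(1000);
 + *   jobResponse = servingService.getJob(GetJobRequest.newBuilder().setJob(job).build());
 + * } while (jobResponse.getJob().getStatus() != JobStatus.JOB_STATUS_DONE);
 + * }</pre>
 + *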
<p>
If an error occurred during retrieval, the {@link GetJobResponse} will also contain the + * error that resulted in termination. + * + * @param getJobRequest {@link GetJobRequest} containing reference to a retrieval job + * @return {@link GetJobResponse} + */ GetJobResponse getJob(GetJobRequest getJobRequest); } diff --git a/serving/src/main/java/feast/serving/specs/CachedSpecService.java b/serving/src/main/java/feast/serving/specs/CachedSpecService.java index 12a8242da13..47f4934d52c 100644 --- a/serving/src/main/java/feast/serving/specs/CachedSpecService.java +++ b/serving/src/main/java/feast/serving/specs/CachedSpecService.java @@ -36,6 +36,7 @@ import feast.core.StoreProto.Store.Subscription; import feast.serving.ServingAPIProto.FeatureReference; import feast.serving.exception.SpecRetrievalException; +import feast.storage.api.retriever.FeatureSetRequest; import io.grpc.StatusRuntimeException; import io.prometheus.client.Gauge; import java.io.IOException; diff --git a/serving/src/main/java/feast/serving/util/RefUtil.java b/serving/src/main/java/feast/serving/util/RefUtil.java index 74de3e65620..c3bcb0827a2 100644 --- a/serving/src/main/java/feast/serving/util/RefUtil.java +++ b/serving/src/main/java/feast/serving/util/RefUtil.java @@ -28,6 +28,14 @@ public static String generateFeatureStringRef(FeatureReference featureReference) return ref; } + public static String generateFeatureStringRefWithoutProject(FeatureReference featureReference) { + String ref = String.format("%s", featureReference.getName()); + if (featureReference.getVersion() > 0) { + return ref + String.format(":%d", featureReference.getVersion()); + } + return ref; + } + public static String generateFeatureSetStringRef(FeatureSetSpec featureSetSpec) { String ref = String.format("%s/%s", featureSetSpec.getProject(), featureSetSpec.getName()); if (featureSetSpec.getVersion() > 0) { diff --git a/serving/src/test/java/feast/serving/service/CachedSpecServiceTest.java b/serving/src/test/java/feast/serving/service/CachedSpecServiceTest.java index abeb44bd731..01c9304bda0 100644 --- a/serving/src/test/java/feast/serving/service/CachedSpecServiceTest.java +++ b/serving/src/test/java/feast/serving/service/CachedSpecServiceTest.java @@ -37,7 +37,7 @@ import feast.serving.ServingAPIProto.FeatureReference; import feast.serving.specs.CachedSpecService; import feast.serving.specs.CoreSpecService; -import feast.serving.specs.FeatureSetRequest; +import feast.storage.api.retriever.FeatureSetRequest; import java.io.BufferedWriter; import java.io.File; import java.io.FileWriter; diff --git a/serving/src/test/java/feast/serving/service/RedisServingServiceTest.java b/serving/src/test/java/feast/serving/service/OnlineServingServiceTest.java similarity index 72% rename from serving/src/test/java/feast/serving/service/RedisServingServiceTest.java rename to serving/src/test/java/feast/serving/service/OnlineServingServiceTest.java index 05a24d3fe6a..b78fcb69170 100644 --- a/serving/src/test/java/feast/serving/service/RedisServingServiceTest.java +++ b/serving/src/test/java/feast/serving/service/OnlineServingServiceTest.java @@ -22,7 +22,6 @@ import static org.mockito.MockitoAnnotations.initMocks; import com.google.common.collect.Lists; -import com.google.protobuf.AbstractMessageLite; import com.google.protobuf.Duration; import com.google.protobuf.Timestamp; import feast.core.FeatureSetProto.EntitySpec; @@ -33,17 +32,16 @@ import feast.serving.ServingAPIProto.GetOnlineFeaturesResponse; import 
feast.serving.ServingAPIProto.GetOnlineFeaturesResponse.FieldValues; import feast.serving.specs.CachedSpecService; -import feast.serving.specs.FeatureSetRequest; -import feast.storage.RedisProto.RedisKey; +import feast.storage.api.retriever.FeatureSetRequest; +import feast.storage.connectors.redis.retriever.RedisOnlineRetriever; import feast.types.FeatureRowProto.FeatureRow; import feast.types.FieldProto.Field; import feast.types.ValueProto.Value; -import io.lettuce.core.KeyValue; -import io.lettuce.core.api.StatefulRedisConnection; -import io.lettuce.core.api.sync.RedisCommands; import io.opentracing.Tracer; import io.opentracing.Tracer.SpanBuilder; -import java.util.*; +import java.util.Collections; +import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import org.junit.Before; import org.junit.Test; @@ -51,44 +49,20 @@ import org.mockito.Mock; import org.mockito.Mockito; -public class RedisServingServiceTest { +public class OnlineServingServiceTest { @Mock CachedSpecService specService; @Mock Tracer tracer; - @Mock StatefulRedisConnection connection; + @Mock RedisOnlineRetriever retriever; - @Mock RedisCommands syncCommands; - - private RedisServingService redisServingService; - private byte[][] redisKeyList; + private OnlineServingService onlineServingService; @Before public void setUp() { initMocks(this); - when(connection.sync()).thenReturn(syncCommands); - redisServingService = new RedisServingService(connection, specService, tracer); - redisKeyList = - Lists.newArrayList( - RedisKey.newBuilder() - .setFeatureSet("project/featureSet:1") - .addAllEntities( - Lists.newArrayList( - Field.newBuilder().setName("entity1").setValue(intValue(1)).build(), - Field.newBuilder().setName("entity2").setValue(strValue("a")).build())) - .build(), - RedisKey.newBuilder() - .setFeatureSet("project/featureSet:1") - .addAllEntities( - Lists.newArrayList( - Field.newBuilder().setName("entity1").setValue(intValue(2)).build(), - Field.newBuilder().setName("entity2").setValue(strValue("b")).build())) - .build()) - .stream() - .map(AbstractMessageLite::toByteArray) - .collect(Collectors.toList()) - .toArray(new byte[0][0]); + onlineServingService = new OnlineServingService(retriever, specService, tracer); } @Test @@ -148,14 +122,11 @@ public void shouldReturnResponseWithValuesIfKeysPresent() { .setSpec(getFeatureSetSpec()) .build(); - List> featureRowBytes = - featureRows.stream() - .map(x -> KeyValue.from(new byte[1], Optional.of(x.toByteArray()))) - .collect(Collectors.toList()); when(specService.getFeatureSets(request.getFeaturesList())) .thenReturn(Collections.singletonList(featureSetRequest)); - when(connection.sync()).thenReturn(syncCommands); - when(syncCommands.mget(redisKeyList)).thenReturn(featureRowBytes); + when(retriever.getOnlineFeatures( + request.getEntityRowsList(), Collections.singletonList(featureSetRequest))) + .thenReturn(Collections.singletonList(featureRows)); when(tracer.buildSpan(ArgumentMatchers.any())).thenReturn(Mockito.mock(SpanBuilder.class)); GetOnlineFeaturesResponse expected = @@ -173,100 +144,13 @@ public void shouldReturnResponseWithValuesIfKeysPresent() { .putFields("project/feature1:1", intValue(2)) .putFields("project/feature2:1", intValue(2))) .build(); - GetOnlineFeaturesResponse actual = redisServingService.getOnlineFeatures(request); + GetOnlineFeaturesResponse actual = onlineServingService.getOnlineFeatures(request); assertThat( responseToMapList(actual), containsInAnyOrder(responseToMapList(expected).toArray())); } @Test - public void 
shouldReturnResponseWithValuesWhenFeatureSetSpecHasUnspecifiedMaxAge() { - GetOnlineFeaturesRequest request = - GetOnlineFeaturesRequest.newBuilder() - .addFeatures( - FeatureReference.newBuilder() - .setName("feature1") - .setVersion(1) - .setProject("project") - .build()) - .addFeatures( - FeatureReference.newBuilder() - .setName("feature2") - .setVersion(1) - .setProject("project") - .build()) - .addEntityRows( - EntityRow.newBuilder() - .setEntityTimestamp(Timestamp.newBuilder().setSeconds(100)) - .putFields("entity1", intValue(1)) - .putFields("entity2", strValue("a"))) - .addEntityRows( - EntityRow.newBuilder() - .setEntityTimestamp(Timestamp.newBuilder().setSeconds(100)) - .putFields("entity1", intValue(2)) - .putFields("entity2", strValue("b"))) - .build(); - - List featureRows = - Lists.newArrayList( - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.newBuilder().setSeconds(2)) // much older timestamp - .addAllFields( - Lists.newArrayList( - Field.newBuilder().setName("entity1").setValue(intValue(1)).build(), - Field.newBuilder().setName("entity2").setValue(strValue("a")).build(), - Field.newBuilder().setName("feature1").setValue(intValue(1)).build(), - Field.newBuilder().setName("feature2").setValue(intValue(1)).build())) - .setFeatureSet("featureSet:1") - .build(), - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.newBuilder().setSeconds(15)) // much older timestamp - .addAllFields( - Lists.newArrayList( - Field.newBuilder().setName("entity1").setValue(intValue(2)).build(), - Field.newBuilder().setName("entity2").setValue(strValue("b")).build(), - Field.newBuilder().setName("feature1").setValue(intValue(2)).build(), - Field.newBuilder().setName("feature2").setValue(intValue(2)).build())) - .setFeatureSet("featureSet:1") - .build()); - - FeatureSetRequest featureSetRequest = - FeatureSetRequest.newBuilder() - .addAllFeatureReferences(request.getFeaturesList()) - .setSpec(getFeatureSetSpecWithNoMaxAge()) - .build(); - - List> featureRowBytes = - featureRows.stream() - .map(x -> KeyValue.from(new byte[1], Optional.of(x.toByteArray()))) - .collect(Collectors.toList()); - when(specService.getFeatureSets(request.getFeaturesList())) - .thenReturn(Collections.singletonList(featureSetRequest)); - when(connection.sync()).thenReturn(syncCommands); - when(syncCommands.mget(redisKeyList)).thenReturn(featureRowBytes); - when(tracer.buildSpan(ArgumentMatchers.any())).thenReturn(Mockito.mock(SpanBuilder.class)); - - GetOnlineFeaturesResponse expected = - GetOnlineFeaturesResponse.newBuilder() - .addFieldValues( - FieldValues.newBuilder() - .putFields("entity1", intValue(1)) - .putFields("entity2", strValue("a")) - .putFields("project/feature1:1", intValue(1)) - .putFields("project/feature2:1", intValue(1))) - .addFieldValues( - FieldValues.newBuilder() - .putFields("entity1", intValue(2)) - .putFields("entity2", strValue("b")) - .putFields("project/feature1:1", intValue(2)) - .putFields("project/feature2:1", intValue(2))) - .build(); - GetOnlineFeaturesResponse actual = redisServingService.getOnlineFeatures(request); - assertThat( - responseToMapList(actual), containsInAnyOrder(responseToMapList(expected).toArray())); - } - - @Test - public void shouldReturnKeysWithoutVersionifNotProvided() { + public void shouldReturnKeysWithoutVersionIfNotProvided() { GetOnlineFeaturesRequest request = GetOnlineFeaturesRequest.newBuilder() .addFeatures( @@ -318,14 +202,11 @@ public void shouldReturnKeysWithoutVersionifNotProvided() { .setSpec(getFeatureSetSpec()) .build(); - List> featureRowBytes 
= - featureRows.stream() - .map(x -> KeyValue.from(new byte[1], Optional.of(x.toByteArray()))) - .collect(Collectors.toList()); when(specService.getFeatureSets(request.getFeaturesList())) .thenReturn(Collections.singletonList(featureSetRequest)); - when(connection.sync()).thenReturn(syncCommands); - when(syncCommands.mget(redisKeyList)).thenReturn(featureRowBytes); + when(retriever.getOnlineFeatures( + request.getEntityRowsList(), Collections.singletonList(featureSetRequest))) + .thenReturn(Collections.singletonList(featureRows)); when(tracer.buildSpan(ArgumentMatchers.any())).thenReturn(Mockito.mock(SpanBuilder.class)); GetOnlineFeaturesResponse expected = @@ -343,7 +224,7 @@ public void shouldReturnKeysWithoutVersionifNotProvided() { .putFields("project/feature1:1", intValue(2)) .putFields("project/feature2", intValue(2))) .build(); - GetOnlineFeaturesResponse actual = redisServingService.getOnlineFeatures(request); + GetOnlineFeaturesResponse actual = onlineServingService.getOnlineFeatures(request); assertThat( responseToMapList(actual), containsInAnyOrder(responseToMapList(expected).toArray())); } @@ -383,27 +264,29 @@ public void shouldReturnResponseWithUnsetValuesIfKeysNotPresent() { .setSpec(getFeatureSetSpec()) .build(); - FeatureRow featureRowPresent = - FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.newBuilder().setSeconds(100)) - .addAllFields( - Lists.newArrayList( - Field.newBuilder().setName("entity1").setValue(intValue(1)).build(), - Field.newBuilder().setName("entity2").setValue(strValue("a")).build(), - Field.newBuilder().setName("feature1").setValue(intValue(1)).build(), - Field.newBuilder().setName("feature2").setValue(intValue(1)).build())) - .setFeatureSet("featureSet:1") - .build(); - - List> featureRowBytes = + List featureRows = Lists.newArrayList( - KeyValue.from(new byte[1], Optional.of(featureRowPresent.toByteArray())), - KeyValue.from(new byte[1], Optional.empty())); + FeatureRow.newBuilder() + .setEventTimestamp(Timestamp.newBuilder().setSeconds(100)) + .setFeatureSet("project/featureSet:1") + .addAllFields( + Lists.newArrayList( + Field.newBuilder().setName("feature1").setValue(intValue(1)).build(), + Field.newBuilder().setName("feature2").setValue(intValue(1)).build())) + .build(), + FeatureRow.newBuilder() + .setFeatureSet("project/featureSet:1") + .addAllFields( + Lists.newArrayList( + Field.newBuilder().setName("feature1").build(), + Field.newBuilder().setName("feature2").build())) + .build()); when(specService.getFeatureSets(request.getFeaturesList())) .thenReturn(Collections.singletonList(featureSetRequest)); - when(connection.sync()).thenReturn(syncCommands); - when(syncCommands.mget(redisKeyList)).thenReturn(featureRowBytes); + when(retriever.getOnlineFeatures( + request.getEntityRowsList(), Collections.singletonList(featureSetRequest))) + .thenReturn(Collections.singletonList(featureRows)); when(tracer.buildSpan(ArgumentMatchers.any())).thenReturn(Mockito.mock(SpanBuilder.class)); GetOnlineFeaturesResponse expected = @@ -421,7 +304,7 @@ public void shouldReturnResponseWithUnsetValuesIfKeysNotPresent() { .putFields("project/feature1:1", Value.newBuilder().build()) .putFields("project/feature2:1", Value.newBuilder().build())) .build(); - GetOnlineFeaturesResponse actual = redisServingService.getOnlineFeatures(request); + GetOnlineFeaturesResponse actual = onlineServingService.getOnlineFeatures(request); assertThat( responseToMapList(actual), containsInAnyOrder(responseToMapList(expected).toArray())); } @@ -487,14 +370,11 @@ public void 
shouldReturnResponseWithUnsetValuesIfMaxAgeIsExceeded() { .setSpec(spec) .build(); - List> featureRowBytes = - featureRows.stream() - .map(x -> KeyValue.from(new byte[1], Optional.of(x.toByteArray()))) - .collect(Collectors.toList()); when(specService.getFeatureSets(request.getFeaturesList())) .thenReturn(Collections.singletonList(featureSetRequest)); - when(connection.sync()).thenReturn(syncCommands); - when(syncCommands.mget(redisKeyList)).thenReturn(featureRowBytes); + when(retriever.getOnlineFeatures( + request.getEntityRowsList(), Collections.singletonList(featureSetRequest))) + .thenReturn(Collections.singletonList(featureRows)); when(tracer.buildSpan(ArgumentMatchers.any())).thenReturn(Mockito.mock(SpanBuilder.class)); GetOnlineFeaturesResponse expected = @@ -512,7 +392,7 @@ public void shouldReturnResponseWithUnsetValuesIfMaxAgeIsExceeded() { .putFields("project/feature1:1", Value.newBuilder().build()) .putFields("project/feature2:1", Value.newBuilder().build())) .build(); - GetOnlineFeaturesResponse actual = redisServingService.getOnlineFeatures(request); + GetOnlineFeaturesResponse actual = onlineServingService.getOnlineFeatures(request); assertThat( responseToMapList(actual), containsInAnyOrder(responseToMapList(expected).toArray())); } @@ -569,14 +449,11 @@ public void shouldFilterOutUndesiredRows() { .setSpec(getFeatureSetSpec()) .build(); - List> featureRowBytes = - featureRows.stream() - .map(x -> KeyValue.from(new byte[1], Optional.of(x.toByteArray()))) - .collect(Collectors.toList()); when(specService.getFeatureSets(request.getFeaturesList())) .thenReturn(Collections.singletonList(featureSetRequest)); - when(connection.sync()).thenReturn(syncCommands); - when(syncCommands.mget(redisKeyList)).thenReturn(featureRowBytes); + when(retriever.getOnlineFeatures( + request.getEntityRowsList(), Collections.singletonList(featureSetRequest))) + .thenReturn(Collections.singletonList(featureRows)); when(tracer.buildSpan(ArgumentMatchers.any())).thenReturn(Mockito.mock(SpanBuilder.class)); GetOnlineFeaturesResponse expected = @@ -592,7 +469,7 @@ public void shouldFilterOutUndesiredRows() { .putFields("entity2", strValue("b")) .putFields("project/feature1:1", intValue(2))) .build(); - GetOnlineFeaturesResponse actual = redisServingService.getOnlineFeatures(request); + GetOnlineFeaturesResponse actual = onlineServingService.getOnlineFeatures(request); assertThat( responseToMapList(actual), containsInAnyOrder(responseToMapList(expected).toArray())); } diff --git a/storage/api/pom.xml b/storage/api/pom.xml new file mode 100644 index 00000000000..c1648c7cfa1 --- /dev/null +++ b/storage/api/pom.xml @@ -0,0 +1,72 @@ + + + + dev.feast + feast-parent + ${revision} + ../.. 
+ + + 4.0.0 + feast-storage-api + + Feast Storage API + + + + + org.apache.maven.plugins + maven-dependency-plugin + + + + javax.annotation + + + + + + + + + + dev.feast + datatypes-java + ${project.version} + + + + org.apache.beam + beam-sdks-java-core + ${org.apache.beam.version} + + + + com.google.auto.value + auto-value-annotations + 1.6.6 + + + + com.google.auto.value + auto-value + 1.6.6 + provided + + + + org.apache.commons + commons-lang3 + 3.9 + + + + junit + junit + 4.12 + test + + + + diff --git a/serving/src/main/java/feast/serving/specs/FeatureSetRequest.java b/storage/api/src/main/java/feast/storage/api/retriever/FeatureSetRequest.java similarity index 84% rename from serving/src/main/java/feast/serving/specs/FeatureSetRequest.java rename to storage/api/src/main/java/feast/storage/api/retriever/FeatureSetRequest.java index 904630659d7..d181abfbe63 100644 --- a/serving/src/main/java/feast/serving/specs/FeatureSetRequest.java +++ b/storage/api/src/main/java/feast/storage/api/retriever/FeatureSetRequest.java @@ -14,13 +14,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package feast.serving.specs; +package feast.storage.api.retriever; import com.google.auto.value.AutoValue; import com.google.common.collect.ImmutableSet; import feast.core.FeatureSetProto.FeatureSetSpec; import feast.serving.ServingAPIProto.FeatureReference; import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; @AutoValue public abstract class FeatureSetRequest { @@ -50,4 +52,9 @@ public Builder addFeatureReference(FeatureReference featureReference) { public abstract FeatureSetRequest build(); } + + public Map getFeatureRefsByName() { + return getFeatureReferences().stream() + .collect(Collectors.toMap(FeatureReference::getName, featureReference -> featureReference)); + } } diff --git a/storage/api/src/main/java/feast/storage/api/retriever/HistoricalRetrievalResult.java b/storage/api/src/main/java/feast/storage/api/retriever/HistoricalRetrievalResult.java new file mode 100644 index 00000000000..a81ce776254 --- /dev/null +++ b/storage/api/src/main/java/feast/storage/api/retriever/HistoricalRetrievalResult.java @@ -0,0 +1,100 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package feast.storage.api.retriever; + +import com.google.auto.value.AutoValue; +import feast.serving.ServingAPIProto.DataFormat; +import feast.serving.ServingAPIProto.JobStatus; +import java.io.Serializable; +import java.util.List; +import javax.annotation.Nullable; + +/** Result of a historical feature retrieval request. 
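+ *
 + * <p>Store implementations are expected to construct results through the factory methods below,
 + * for example (an illustrative sketch; the output URI is a placeholder and the data format enum
 + * value is assumed):
 + *
 + * <pre>{@code
 + * return HistoricalRetrievalResult.success(
 + *     retrievalId, Arrays.asList("gs://bucket/output/part-0.avro"), DataFormat.DATA_FORMAT_AVRO);
 + * }</pre>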
*/ +@AutoValue +public abstract class HistoricalRetrievalResult implements Serializable { + + public abstract String getId(); + + public abstract JobStatus getStatus(); + + @Nullable + public abstract String getError(); + + @Nullable + public abstract List getFileUris(); + + @Nullable + public abstract DataFormat getDataFormat(); + + /** + * Instantiates a {@link HistoricalRetrievalResult} indicating that the retrieval was a failure, + * together with its associated error. + * + * @param id retrieval id identifying the retrieval request. + * @param error error that occurred + * @return {@link HistoricalRetrievalResult} + */ + public static HistoricalRetrievalResult error(String id, Exception error) { + return newBuilder() + .setId(id) + .setStatus(JobStatus.JOB_STATUS_DONE) + .setError(error.getMessage()) + .build(); + } + + /** + * Instantiates a {@link HistoricalRetrievalResult} indicating that the retrieval was a success, + * together with the location of the output. + * + * @param id retrieval id identifying the retrieval request + * @param fileUris list of output file URIs + * @param dataFormat data format of the output files + * @return + */ + public static HistoricalRetrievalResult success( + String id, List fileUris, DataFormat dataFormat) { + return newBuilder() + .setId(id) + .setStatus(JobStatus.JOB_STATUS_DONE) + .setFileUris(fileUris) + .setDataFormat(dataFormat) + .build(); + } + + static Builder newBuilder() { + return new AutoValue_HistoricalRetrievalResult.Builder(); + } + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setId(String id); + + abstract Builder setStatus(JobStatus jobStatus); + + abstract Builder setError(String error); + + abstract Builder setFileUris(List fileUris); + + abstract Builder setDataFormat(DataFormat dataFormat); + + abstract HistoricalRetrievalResult build(); + } + + public boolean hasError() { + return getError() != null; + } +} diff --git a/storage/api/src/main/java/feast/storage/api/retriever/HistoricalRetriever.java b/storage/api/src/main/java/feast/storage/api/retriever/HistoricalRetriever.java new file mode 100644 index 00000000000..95a89c1a3cb --- /dev/null +++ b/storage/api/src/main/java/feast/storage/api/retriever/HistoricalRetriever.java @@ -0,0 +1,49 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package feast.storage.api.retriever; + +import feast.serving.ServingAPIProto.DatasetSource; +import java.util.List; + +/** + * A historical retriever is a feature retriever that retrieves feature data corresponding to + * provided entities over a given period of time. + */ +public interface HistoricalRetriever { + + /** + * Get temporary staging location if applicable. If not applicable to this store, returns an empty + * string. + * + * @return staging location uri + */ + String getStagingLocation(); + + /** + * Get all features corresponding to the provided batch features request. 
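+ *
 + * <p>Caller-side sketch (illustrative only; uses only the methods defined in this change):
 + *
 + * <pre>{@code
 + * HistoricalRetrievalResult result =
 + *     retriever.getHistoricalFeatures(retrievalId, datasetSource, featureSetRequests);
 + * if (result.hasError()) {
 + *   // surface result.getError() to the user; otherwise result.getFileUris() holds the output
 + * }
 + * }</pre>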
+ * + * @param retrievalId String that uniquely identifies this retrieval request. + * @param datasetSource {@link DatasetSource} containing source to load the dataset containing + * entity columns. + * @param featureSetRequests List of {@link FeatureSetRequest} to feature references in the + * request tied to that feature set. + * @return {@link HistoricalRetrievalResult} if successful, contains the location of the results, + * else contains the error to be returned to the user. + */ + HistoricalRetrievalResult getHistoricalFeatures( + String retrievalId, DatasetSource datasetSource, List featureSetRequests); +} diff --git a/storage/api/src/main/java/feast/storage/api/retriever/OnlineRetriever.java b/storage/api/src/main/java/feast/storage/api/retriever/OnlineRetriever.java new file mode 100644 index 00000000000..5eb27b995ea --- /dev/null +++ b/storage/api/src/main/java/feast/storage/api/retriever/OnlineRetriever.java @@ -0,0 +1,40 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package feast.storage.api.retriever; + +import feast.serving.ServingAPIProto.GetOnlineFeaturesRequest.EntityRow; +import feast.types.FeatureRowProto.FeatureRow; +import java.util.List; + +/** + * An online retriever is a feature retriever that retrieves the latest feature data corresponding + * to provided entities. + */ +public interface OnlineRetriever { + + /** + * Get all values corresponding to the request. + * + * @param entityRows list of entity rows in the feature request + * @param featureSetRequests List of {@link FeatureSetRequest} to feature references in the + * request tied to that feature set. + * @return list of lists of {@link FeatureRow}s corresponding to each feature set request and + * entity row. + */ + List> getOnlineFeatures( + List entityRows, List featureSetRequests); +} diff --git a/storage/api/src/main/java/feast/storage/api/writer/DeadletterSink.java b/storage/api/src/main/java/feast/storage/api/writer/DeadletterSink.java new file mode 100644 index 00000000000..a07254bddb0 --- /dev/null +++ b/storage/api/src/main/java/feast/storage/api/writer/DeadletterSink.java @@ -0,0 +1,38 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package feast.storage.api.writer; + +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PDone; + +/** Interface for implementing user defined deadletter sinks to write failed elements to. */ +public interface DeadletterSink { + + /** + * Set up the deadletter sink for writes. This method will be called once during pipeline + * initialisation. + */ + void prepareWrite(); + + /** + * Get a {@link PTransform} that writes a collection of FailedElements to the deadletter sink. + * + * @return {@link PTransform} + */ + PTransform<PCollection<FailedElement>, PDone> write(); +} diff --git a/storage/api/src/main/java/feast/storage/api/writer/FailedElement.java b/storage/api/src/main/java/feast/storage/api/writer/FailedElement.java new file mode 100644 index 00000000000..d5823414772 --- /dev/null +++ b/storage/api/src/main/java/feast/storage/api/writer/FailedElement.java @@ -0,0 +1,83 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2019 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package feast.storage.api.writer; + +import com.google.auto.value.AutoValue; +import javax.annotation.Nullable; +import org.apache.beam.sdk.schemas.AutoValueSchema; +import org.apache.beam.sdk.schemas.annotations.DefaultSchema; +import org.joda.time.Instant; + +@AutoValue +// Use DefaultSchema annotation so this AutoValue class can be serialized by Beam +// https://issues.apache.org/jira/browse/BEAM-1891 +// https://github.com/apache/beam/pull/7334 +@DefaultSchema(AutoValueSchema.class) +public abstract class FailedElement { + public abstract Instant getTimestamp(); + + @Nullable + public abstract String getJobName(); + + @Nullable + public abstract String getProjectName(); + + @Nullable + public abstract String getFeatureSetName(); + + @Nullable + public abstract String getFeatureSetVersion(); + + @Nullable + public abstract String getTransformName(); + + @Nullable + public abstract String getPayload(); + + @Nullable + public abstract String getErrorMessage(); + + @Nullable + public abstract String getStackTrace(); + + public static Builder newBuilder() { + return new AutoValue_FailedElement.Builder().setTimestamp(Instant.now()); + } + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setTimestamp(Instant timestamp); + + public abstract Builder setProjectName(String projectName); + + public abstract Builder setFeatureSetName(String featureSetName); + + public abstract Builder setFeatureSetVersion(String featureSetVersion); + + public abstract Builder setJobName(String jobName); + + public abstract Builder setTransformName(String transformName); + + public abstract Builder setPayload(String payload); + + public abstract Builder setErrorMessage(String errorMessage); + + public abstract Builder setStackTrace(String stackTrace); + + public abstract FailedElement build(); + } +} diff --git a/storage/api/src/main/java/feast/storage/api/writer/FeatureSink.java
b/storage/api/src/main/java/feast/storage/api/writer/FeatureSink.java new file mode 100644 index 00000000000..3dfe7e8f103 --- /dev/null +++ b/storage/api/src/main/java/feast/storage/api/writer/FeatureSink.java @@ -0,0 +1,54 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package feast.storage.api.writer; + +import feast.core.FeatureSetProto; +import feast.types.FeatureRowProto.FeatureRow; +import java.io.Serializable; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.values.PCollection; + +/** Interface for implementing user defined feature sink functionality. */ +public interface FeatureSink extends Serializable { + + /** + * Set up storage backend for write. This method will be called once during pipeline + * initialisation. + * + *
<p>
Examples when schemas need to be updated: + * + *
<ul>
 + *   <li>when a new entity is registered, a table usually needs to be created
 + *   <li>when a new feature is registered, a column with appropriate data type usually needs to be
 + *       created
 + * </ul>
 + *
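+ * <p>A skeletal implementation sketch (illustrative only; the class name and write transform are
 + * assumptions, not part of this change):
 + *
 + * <pre>{@code
 + * public class MyFeatureSink implements FeatureSink {
 + *   public void prepareWrite(FeatureSetProto.FeatureSet featureSet) {
 + *     // e.g. create or update the destination table to match the feature set schema
 + *   }
 + *
 + *   public PTransform<PCollection<FeatureRow>, WriteResult> writer() {
 + *     return new MyWriteTransform(); // assumed transform that produces a WriteResult
 + *   }
 + * }
 + * }</pre>
 + *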
<p>
If the storage backend is a key-value or a schema-less database, however, there may not be a + * need to manage any schemas. + * + * @param featureSet Feature set to be written + */ + void prepareWrite(FeatureSetProto.FeatureSet featureSet); + + /** + * Get a {@link PTransform} that writes feature rows to the store, and returns a {@link + * WriteResult} that splits successful and failed inserts to be separately logged. + * + * @return {@link PTransform} + */ + PTransform, WriteResult> writer(); +} diff --git a/storage/api/src/main/java/feast/storage/api/writer/WriteResult.java b/storage/api/src/main/java/feast/storage/api/writer/WriteResult.java new file mode 100644 index 00000000000..e378c2b46a4 --- /dev/null +++ b/storage/api/src/main/java/feast/storage/api/writer/WriteResult.java @@ -0,0 +1,97 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package feast.storage.api.writer; + +import com.google.common.collect.ImmutableMap; +import feast.types.FeatureRowProto.FeatureRow; +import java.io.Serializable; +import java.util.Map; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.values.*; + +/** The result of a write transform. */ +public final class WriteResult implements Serializable, POutput { + + private final Pipeline pipeline; + private final PCollection successfulInserts; + private final PCollection failedInserts; + + private static TupleTag successfulInsertsTag = new TupleTag<>("successfulInserts"); + private static TupleTag failedInsertsTupleTag = new TupleTag<>("failedInserts"); + + /** + * Creates a {@link WriteResult} in the given {@link Pipeline}. + * + * @param pipeline {@link Pipeline} + * @param successfulInserts {@link PCollection} of {@link FeatureRow}s successfully inserted into + * the store + * @param failedInserts {@link PCollection} of {@link FailedElement}s + * @return {@link WriteResult} + */ + public static WriteResult in( + Pipeline pipeline, + PCollection successfulInserts, + PCollection failedInserts) { + return new WriteResult(pipeline, successfulInserts, failedInserts); + } + + private WriteResult( + Pipeline pipeline, + PCollection successfulInserts, + PCollection failedInserts) { + + this.pipeline = pipeline; + this.successfulInserts = successfulInserts; + this.failedInserts = failedInserts; + } + + /** + * Gets set of feature rows that were unsuccessfully written to the store. The failed feature rows + * are wrapped in FailedElement objects so implementations of WriteResult can be flexible in how + * errors are stored. + * + * @return FailedElements of unsuccessfully written feature rows + */ + public PCollection getFailedInserts() { + return failedInserts; + } + + /** + * Gets set of successfully written feature rows. 
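+ *
 + * <p>Pipeline-side sketch (illustrative only; the deadletter sink and the surrounding pipeline
 + * variables are assumed to be defined elsewhere):
 + *
 + * <pre>{@code
 + * WriteResult writeResult = featureRows.apply("WriteToStore", featureSink.writer());
 + * writeResult.getFailedInserts().apply("WriteDeadletter", deadletterSink.write());
 + * PCollection<FeatureRow> written = writeResult.getSuccessfulInserts();
 + * }</pre>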
+ * + * @return PCollection of feature rows successfully written to the store + */ + public PCollection getSuccessfulInserts() { + return successfulInserts; + } + + @Override + public Pipeline getPipeline() { + return pipeline; + } + + @Override + public Map, PValue> expand() { + return ImmutableMap.of( + successfulInsertsTag, successfulInserts, failedInsertsTupleTag, failedInserts); + } + + @Override + public void finishSpecifyingOutput( + String transformName, PInput input, PTransform transform) {} +} diff --git a/ingestion/src/main/java/feast/retry/BackOffExecutor.java b/storage/api/src/main/java/feast/storage/common/retry/BackOffExecutor.java similarity index 98% rename from ingestion/src/main/java/feast/retry/BackOffExecutor.java rename to storage/api/src/main/java/feast/storage/common/retry/BackOffExecutor.java index 344c65ac424..296582f8b35 100644 --- a/ingestion/src/main/java/feast/retry/BackOffExecutor.java +++ b/storage/api/src/main/java/feast/storage/common/retry/BackOffExecutor.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package feast.retry; +package feast.storage.common.retry; import java.io.Serializable; import org.apache.beam.sdk.util.BackOff; diff --git a/ingestion/src/main/java/feast/retry/Retriable.java b/storage/api/src/main/java/feast/storage/common/retry/Retriable.java similarity index 95% rename from ingestion/src/main/java/feast/retry/Retriable.java rename to storage/api/src/main/java/feast/storage/common/retry/Retriable.java index 30676fe8208..2c92c851758 100644 --- a/ingestion/src/main/java/feast/retry/Retriable.java +++ b/storage/api/src/main/java/feast/storage/common/retry/Retriable.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package feast.retry; +package feast.storage.common.retry; public interface Retriable { void execute() throws Exception; diff --git a/storage/api/src/main/java/feast/storage/common/testing/TestUtil.java b/storage/api/src/main/java/feast/storage/common/testing/TestUtil.java new file mode 100644 index 00000000000..6047a93dc17 --- /dev/null +++ b/storage/api/src/main/java/feast/storage/common/testing/TestUtil.java @@ -0,0 +1,188 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2019 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package feast.storage.common.testing; + +import com.google.protobuf.ByteString; +import com.google.protobuf.util.Timestamps; +import feast.core.FeatureSetProto.FeatureSet; +import feast.core.FeatureSetProto.FeatureSetSpec; +import feast.types.FeatureRowProto.FeatureRow; +import feast.types.FeatureRowProto.FeatureRow.Builder; +import feast.types.FieldProto.Field; +import feast.types.ValueProto.*; +import java.util.concurrent.ThreadLocalRandom; +import org.apache.commons.lang3.RandomStringUtils; + +@SuppressWarnings("WeakerAccess") +public class TestUtil { + + /** + * Create a Feature Row with random value according to the FeatureSetSpec + * + * @param featureSet {@link FeatureSet} + * @return {@link FeatureRow} + */ + public static FeatureRow createRandomFeatureRow(FeatureSet featureSet) { + ThreadLocalRandom random = ThreadLocalRandom.current(); + int randomStringSizeMaxSize = 12; + return createRandomFeatureRow(featureSet, random.nextInt(0, randomStringSizeMaxSize) + 4); + } + + /** + * Create a Feature Row with random value according to the FeatureSet. + * + *
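<p>Test-side sketch (illustrative only):
 + *
 + * <pre>{@code
 + * List<FeatureRow> rows =
 + *     IntStream.range(0, 100)
 + *         .mapToObj(i -> TestUtil.createRandomFeatureRow(featureSet, 10))
 + *         .collect(Collectors.toList());
 + * }</pre>
 + *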
<p>
The Feature Row created contains fields according to the entities and features defined in + * FeatureSet, matching the value type of the field, with randomized value for testing. + * + * @param featureSet {@link FeatureSet} + * @param randomStringSize number of characters for the generated random string + * @return {@link FeatureRow} + */ + public static FeatureRow createRandomFeatureRow(FeatureSet featureSet, int randomStringSize) { + Builder builder = + FeatureRow.newBuilder() + .setFeatureSet(getFeatureSetReference(featureSet)) + .setEventTimestamp(Timestamps.fromMillis(System.currentTimeMillis())); + + featureSet + .getSpec() + .getEntitiesList() + .forEach( + field -> { + builder.addFields( + Field.newBuilder() + .setName(field.getName()) + .setValue(createRandomValue(field.getValueType(), randomStringSize)) + .build()); + }); + + featureSet + .getSpec() + .getFeaturesList() + .forEach( + field -> { + builder.addFields( + Field.newBuilder() + .setName(field.getName()) + .setValue(createRandomValue(field.getValueType(), randomStringSize)) + .build()); + }); + + return builder.build(); + } + + private static String getFeatureSetReference(FeatureSet featureSet) { + FeatureSetSpec spec = featureSet.getSpec(); + return String.format("%s/%s:%d", spec.getProject(), spec.getName(), spec.getVersion()); + } + + /** + * Create a random Feast {@link Value} of {@link ValueType.Enum}. + * + * @param type {@link ValueType.Enum} + * @param randomStringSize number of characters for the generated random string + * @return {@link Value} + */ + public static Value createRandomValue(ValueType.Enum type, int randomStringSize) { + Value.Builder builder = Value.newBuilder(); + ThreadLocalRandom random = ThreadLocalRandom.current(); + + switch (type) { + case INVALID: + case UNRECOGNIZED: + throw new IllegalArgumentException("Invalid ValueType: " + type); + case BYTES: + builder.setBytesVal( + ByteString.copyFrom(RandomStringUtils.randomAlphanumeric(randomStringSize).getBytes())); + break; + case STRING: + builder.setStringVal(RandomStringUtils.randomAlphanumeric(randomStringSize)); + break; + case INT32: + builder.setInt32Val(random.nextInt()); + break; + case INT64: + builder.setInt64Val(random.nextLong()); + break; + case DOUBLE: + builder.setDoubleVal(random.nextDouble()); + break; + case FLOAT: + builder.setFloatVal(random.nextFloat()); + break; + case BOOL: + builder.setBoolVal(random.nextBoolean()); + break; + case BYTES_LIST: + builder.setBytesListVal( + BytesList.newBuilder() + .addVal( + ByteString.copyFrom( + RandomStringUtils.randomAlphanumeric(randomStringSize).getBytes())) + .build()); + break; + case STRING_LIST: + builder.setStringListVal( + StringList.newBuilder() + .addVal(RandomStringUtils.randomAlphanumeric(randomStringSize)) + .build()); + break; + case INT32_LIST: + builder.setInt32ListVal(Int32List.newBuilder().addVal(random.nextInt()).build()); + break; + case INT64_LIST: + builder.setInt64ListVal(Int64List.newBuilder().addVal(random.nextLong()).build()); + break; + case DOUBLE_LIST: + builder.setDoubleListVal(DoubleList.newBuilder().addVal(random.nextDouble()).build()); + break; + case FLOAT_LIST: + builder.setFloatListVal(FloatList.newBuilder().addVal(random.nextFloat()).build()); + break; + case BOOL_LIST: + builder.setBoolListVal(BoolList.newBuilder().addVal(random.nextBoolean()).build()); + break; + } + return builder.build(); + } + + /** + * Create a field object with given name and type. + * + * @param name of the field. + * @param value of the field. 
Should be compatible with the valuetype given. + * @param valueType type of the field. + * @return Field object + */ + public static Field field(String name, Object value, ValueType.Enum valueType) { + Field.Builder fieldBuilder = Field.newBuilder().setName(name); + switch (valueType) { + case INT32: + return fieldBuilder.setValue(Value.newBuilder().setInt32Val((int) value)).build(); + case INT64: + return fieldBuilder.setValue(Value.newBuilder().setInt64Val((int) value)).build(); + case FLOAT: + return fieldBuilder.setValue(Value.newBuilder().setFloatVal((float) value)).build(); + case DOUBLE: + return fieldBuilder.setValue(Value.newBuilder().setDoubleVal((double) value)).build(); + case STRING: + return fieldBuilder.setValue(Value.newBuilder().setStringVal((String) value)).build(); + default: + throw new IllegalStateException("Unexpected valueType: " + value.getClass()); + } + } +} diff --git a/storage/connectors/bigquery/pom.xml b/storage/connectors/bigquery/pom.xml new file mode 100644 index 00000000000..fab3739c43c --- /dev/null +++ b/storage/connectors/bigquery/pom.xml @@ -0,0 +1,94 @@ + + + + dev.feast + feast-storage-connectors + ${revision} + + + 4.0.0 + feast-storage-connector-bigquery + + Feast Storage Connector for BigQuery + + + + io.pebbletemplates + pebble + 3.1.0 + + + + + com.google.cloud + google-cloud-bigquery + + + + com.google.cloud + google-cloud-storage + + + + org.apache.beam + beam-sdks-java-io-google-cloud-platform + ${org.apache.beam.version} + + + com.google.cloud + google-cloud-spanner + + + com.google.cloud.bigtable + bigtable-client-core + + + + + + io.opencensus + opencensus-contrib-http-util + 0.21.0 + + + + com.google.auto.value + auto-value-annotations + 1.6.6 + + + + com.google.auto.value + auto-value + 1.6.6 + provided + + + + junit + junit + 4.12 + test + + + + org.apache.beam + beam-runners-direct-java + ${org.apache.beam.version} + test + + + + org.hamcrest + hamcrest-core + test + + + org.hamcrest + hamcrest-library + test + + + + diff --git a/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/common/TypeUtil.java b/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/common/TypeUtil.java new file mode 100644 index 00000000000..dcd13093177 --- /dev/null +++ b/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/common/TypeUtil.java @@ -0,0 +1,66 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package feast.storage.connectors.bigquery.common; + +import com.google.cloud.bigquery.StandardSQLTypeName; +import feast.types.ValueProto; +import java.util.HashMap; +import java.util.Map; + +public class TypeUtil { + + private static final Map + VALUE_TYPE_TO_STANDARD_SQL_TYPE = new HashMap<>(); + + static { + VALUE_TYPE_TO_STANDARD_SQL_TYPE.put(ValueProto.ValueType.Enum.BYTES, StandardSQLTypeName.BYTES); + VALUE_TYPE_TO_STANDARD_SQL_TYPE.put( + ValueProto.ValueType.Enum.STRING, StandardSQLTypeName.STRING); + VALUE_TYPE_TO_STANDARD_SQL_TYPE.put(ValueProto.ValueType.Enum.INT32, StandardSQLTypeName.INT64); + VALUE_TYPE_TO_STANDARD_SQL_TYPE.put(ValueProto.ValueType.Enum.INT64, StandardSQLTypeName.INT64); + VALUE_TYPE_TO_STANDARD_SQL_TYPE.put( + ValueProto.ValueType.Enum.DOUBLE, StandardSQLTypeName.FLOAT64); + VALUE_TYPE_TO_STANDARD_SQL_TYPE.put( + ValueProto.ValueType.Enum.FLOAT, StandardSQLTypeName.FLOAT64); + VALUE_TYPE_TO_STANDARD_SQL_TYPE.put(ValueProto.ValueType.Enum.BOOL, StandardSQLTypeName.BOOL); + VALUE_TYPE_TO_STANDARD_SQL_TYPE.put( + ValueProto.ValueType.Enum.BYTES_LIST, StandardSQLTypeName.BYTES); + VALUE_TYPE_TO_STANDARD_SQL_TYPE.put( + ValueProto.ValueType.Enum.STRING_LIST, StandardSQLTypeName.STRING); + VALUE_TYPE_TO_STANDARD_SQL_TYPE.put( + ValueProto.ValueType.Enum.INT32_LIST, StandardSQLTypeName.INT64); + VALUE_TYPE_TO_STANDARD_SQL_TYPE.put( + ValueProto.ValueType.Enum.INT64_LIST, StandardSQLTypeName.INT64); + VALUE_TYPE_TO_STANDARD_SQL_TYPE.put( + ValueProto.ValueType.Enum.DOUBLE_LIST, StandardSQLTypeName.FLOAT64); + VALUE_TYPE_TO_STANDARD_SQL_TYPE.put( + ValueProto.ValueType.Enum.FLOAT_LIST, StandardSQLTypeName.FLOAT64); + VALUE_TYPE_TO_STANDARD_SQL_TYPE.put( + ValueProto.ValueType.Enum.BOOL_LIST, StandardSQLTypeName.BOOL); + } + + /** + * Converts {@link feast.types.ValueProto.ValueType} to its corresponding {@link + * StandardSQLTypeName} + * + * @param valueType value type to convert + * @return {@link StandardSQLTypeName} + */ + public static StandardSQLTypeName toStandardSqlType(ValueProto.ValueType.Enum valueType) { + return VALUE_TYPE_TO_STANDARD_SQL_TYPE.get(valueType); + } +} diff --git a/serving/src/main/java/feast/serving/store/bigquery/BatchRetrievalQueryRunnable.java b/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/retriever/BigQueryHistoricalRetriever.java similarity index 52% rename from serving/src/main/java/feast/serving/store/bigquery/BatchRetrievalQueryRunnable.java rename to storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/retriever/BigQueryHistoricalRetriever.java index 61103af1092..27ba07e82ec 100644 --- a/serving/src/main/java/feast/serving/store/bigquery/BatchRetrievalQueryRunnable.java +++ b/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/retriever/BigQueryHistoricalRetriever.java @@ -1,6 +1,6 @@ /* * SPDX-License-Identifier: Apache-2.0 - * Copyright 2018-2019 The Feast Authors + * Copyright 2018-2020 The Feast Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,88 +14,46 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -package feast.serving.store.bigquery; +package feast.storage.connectors.bigquery.retriever; -import static feast.serving.service.BigQueryServingService.TEMP_TABLE_EXPIRY_DURATION_MS; -import static feast.serving.service.BigQueryServingService.createTempTableName; -import static feast.serving.store.bigquery.QueryTemplater.createTimestampLimitQuery; +import static feast.storage.connectors.bigquery.retriever.QueryTemplater.createEntityTableUUIDQuery; +import static feast.storage.connectors.bigquery.retriever.QueryTemplater.createTimestampLimitQuery; import com.google.auto.value.AutoValue; import com.google.cloud.RetryOption; -import com.google.cloud.bigquery.BigQuery; -import com.google.cloud.bigquery.BigQueryException; -import com.google.cloud.bigquery.DatasetId; -import com.google.cloud.bigquery.ExtractJobConfiguration; -import com.google.cloud.bigquery.FieldValueList; -import com.google.cloud.bigquery.Job; -import com.google.cloud.bigquery.JobInfo; -import com.google.cloud.bigquery.QueryJobConfiguration; -import com.google.cloud.bigquery.TableId; -import com.google.cloud.bigquery.TableInfo; -import com.google.cloud.bigquery.TableResult; +import com.google.cloud.bigquery.*; import com.google.cloud.storage.Blob; import com.google.cloud.storage.Storage; -import com.google.cloud.storage.Storage.BlobListOption; import feast.serving.ServingAPIProto; -import feast.serving.ServingAPIProto.DataFormat; -import feast.serving.ServingAPIProto.JobStatus; -import feast.serving.ServingAPIProto.JobType; -import feast.serving.service.JobService; -import feast.serving.store.bigquery.model.FeatureSetInfo; +import feast.serving.ServingAPIProto.DatasetSource; +import feast.storage.api.retriever.FeatureSetRequest; +import feast.storage.api.retriever.HistoricalRetrievalResult; +import feast.storage.api.retriever.HistoricalRetriever; import io.grpc.Status; import java.io.IOException; import java.util.ArrayList; import java.util.List; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorCompletionService; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; +import java.util.UUID; +import java.util.concurrent.*; +import java.util.stream.Collectors; +import org.slf4j.Logger; import org.threeten.bp.Duration; -/** - * BatchRetrievalQueryRunnable is a Runnable for running a BigQuery Feast batch retrieval job async. - * - *

It does the following, in sequence:
- *
- * 1. Retrieve the temporal bounds of the entity dataset provided. This will be used to filter
- *    the feature set tables when performing the feature retrieval.
- *
- * 2. For each of the feature sets requested, generate the subquery for doing a point-in-time
- *    correctness join of the features in the feature set to the entity table.
- *
- * 3. Run each of the subqueries in parallel and wait for them to complete. If any of the jobs
- *    are unsuccessful, the thread running the BatchRetrievalQueryRunnable catches the error and
- *    updates the job database.
- *
- * 4. When all the subquery jobs are complete, join the outputs of all the subqueries into a
- *    single table.
- *
5. Extract the output of the join to a remote file, and write the location of the remote file - * to the job database, and mark the retrieval job as successful. - */ @AutoValue -public abstract class BatchRetrievalQueryRunnable implements Runnable { +public abstract class BigQueryHistoricalRetriever implements HistoricalRetriever { - private static final long SUBQUERY_TIMEOUT_SECS = 900; // 15 minutes + private static final Logger log = + org.slf4j.LoggerFactory.getLogger(BigQueryHistoricalRetriever.class); - public abstract JobService jobService(); + public static final long TEMP_TABLE_EXPIRY_DURATION_MS = Duration.ofDays(1).toMillis(); + private static final long SUBQUERY_TIMEOUT_SECS = 900; // 15 minutes public abstract String projectId(); public abstract String datasetId(); - public abstract String feastJobId(); - public abstract BigQuery bigquery(); - public abstract List entityTableColumnNames(); - - public abstract List featureSetInfos(); - - public abstract String entityTableName(); - public abstract String jobStagingLocation(); public abstract int initialRetryDelaySecs(); @@ -105,55 +63,78 @@ public abstract class BatchRetrievalQueryRunnable implements Runnable { public abstract Storage storage(); public static Builder builder() { - return new AutoValue_BatchRetrievalQueryRunnable.Builder(); + return new AutoValue_BigQueryHistoricalRetriever.Builder(); } @AutoValue.Builder public abstract static class Builder { - public abstract Builder setJobService(JobService jobService); - public abstract Builder setProjectId(String projectId); public abstract Builder setDatasetId(String datasetId); - public abstract Builder setFeastJobId(String feastJobId); + public abstract Builder setJobStagingLocation(String jobStagingLocation); public abstract Builder setBigquery(BigQuery bigquery); - public abstract Builder setEntityTableColumnNames(List entityTableColumnNames); - - public abstract Builder setFeatureSetInfos(List featureSetInfos); - - public abstract Builder setEntityTableName(String entityTableName); - - public abstract Builder setJobStagingLocation(String jobStagingLocation); - public abstract Builder setInitialRetryDelaySecs(int initialRetryDelaySecs); public abstract Builder setTotalTimeoutSecs(int totalTimeoutSecs); public abstract Builder setStorage(Storage storage); - public abstract BatchRetrievalQueryRunnable build(); + public abstract BigQueryHistoricalRetriever build(); } @Override - public void run() { + public String getStagingLocation() { + return jobStagingLocation(); + } - // 1. Retrieve the temporal bounds of the entity dataset provided - FieldValueList timestampLimits = getTimestampLimits(entityTableName()); + @Override + public HistoricalRetrievalResult getHistoricalFeatures( + String retrievalId, DatasetSource datasetSource, List featureSetRequests) { + List featureSetQueryInfos = + QueryTemplater.getFeatureSetInfos(featureSetRequests); + + // 1. 
load entity table + Table entityTable; + String entityTableName; + try { + entityTable = loadEntities(datasetSource); + + TableId entityTableWithUUIDs = generateUUIDs(entityTable); + entityTableName = generateFullTableName(entityTableWithUUIDs); + } catch (Exception e) { + return HistoricalRetrievalResult.error( + retrievalId, + new RuntimeException( + String.format("Unable to load entity table to BigQuery: %s", e.toString()))); + } + + Schema entityTableSchema = entityTable.getDefinition().getSchema(); + List entityTableColumnNames = + entityTableSchema.getFields().stream() + .map(Field::getName) + .filter(name -> !name.equals("event_timestamp")) + .collect(Collectors.toList()); + + // 2. Retrieve the temporal bounds of the entity dataset provided + FieldValueList timestampLimits = getTimestampLimits(entityTableName); - // 2. Generate the subqueries - List featureSetQueries = generateQueries(timestampLimits); + // 3. Generate the subqueries + List featureSetQueries = + generateQueries(entityTableName, timestampLimits, featureSetQueryInfos); QueryJobConfiguration queryConfig; try { - // 3 & 4. Run the subqueries in parallel then collect the outputs - Job queryJob = runBatchQuery(featureSetQueries); + // 4. Run the subqueries in parallel then collect the outputs + Job queryJob = + runBatchQuery( + entityTableName, entityTableColumnNames, featureSetQueryInfos, featureSetQueries); queryConfig = queryJob.getConfiguration(); String exportTableDestinationUri = - String.format("%s/%s/*.avro", jobStagingLocation(), feastJobId()); + String.format("%s/%s/*.avro", jobStagingLocation(), retrievalId); // 5. Export the table // Hardcode the format to Avro for now @@ -162,60 +143,166 @@ public void run() { queryConfig.getDestinationTable(), exportTableDestinationUri, "Avro"); Job extractJob = bigquery().create(JobInfo.of(extractConfig)); waitForJob(extractJob); + } catch (BigQueryException | InterruptedException | IOException e) { - jobService() - .upsert( - ServingAPIProto.Job.newBuilder() - .setId(feastJobId()) - .setType(JobType.JOB_TYPE_DOWNLOAD) - .setStatus(JobStatus.JOB_STATUS_DONE) - .setError(e.getMessage()) - .build()); - return; + return HistoricalRetrievalResult.error(retrievalId, e); } - List fileUris = parseOutputFileURIs(); - - // 5. Update the job database - jobService() - .upsert( - ServingAPIProto.Job.newBuilder() - .setId(feastJobId()) - .setType(JobType.JOB_TYPE_DOWNLOAD) - .setStatus(JobStatus.JOB_STATUS_DONE) - .addAllFileUris(fileUris) - .setDataFormat(DataFormat.DATA_FORMAT_AVRO) - .build()); + List fileUris = parseOutputFileURIs(retrievalId); + + return HistoricalRetrievalResult.success( + retrievalId, fileUris, ServingAPIProto.DataFormat.DATA_FORMAT_AVRO); } - private List parseOutputFileURIs() { - String scheme = jobStagingLocation().substring(0, jobStagingLocation().indexOf("://")); - String stagingLocationNoScheme = - jobStagingLocation().substring(jobStagingLocation().indexOf("://") + 3); - String bucket = stagingLocationNoScheme.split("/")[0]; - List prefixParts = new ArrayList<>(); - prefixParts.add( - stagingLocationNoScheme.contains("/") && !stagingLocationNoScheme.endsWith("/") - ? 
stagingLocationNoScheme.substring(stagingLocationNoScheme.indexOf("/") + 1) - : ""); - prefixParts.add(feastJobId()); - String prefix = String.join("/", prefixParts) + "/"; + private TableId generateUUIDs(Table loadedEntityTable) { + try { + String uuidQuery = + createEntityTableUUIDQuery(generateFullTableName(loadedEntityTable.getTableId())); + QueryJobConfiguration queryJobConfig = + QueryJobConfiguration.newBuilder(uuidQuery) + .setDestinationTable(TableId.of(projectId(), datasetId(), createTempTableName())) + .build(); + Job queryJob = bigquery().create(JobInfo.of(queryJobConfig)); + Job completedJob = waitForJob(queryJob); + TableInfo expiry = + bigquery() + .getTable(queryJobConfig.getDestinationTable()) + .toBuilder() + .setExpirationTime(System.currentTimeMillis() + TEMP_TABLE_EXPIRY_DURATION_MS) + .build(); + bigquery().update(expiry); + queryJobConfig = completedJob.getConfiguration(); + return queryJobConfig.getDestinationTable(); + } catch (InterruptedException | BigQueryException e) { + throw Status.INTERNAL + .withDescription("Failed to load entity dataset into store") + .withCause(e) + .asRuntimeException(); + } + } - List fileUris = new ArrayList<>(); - for (Blob blob : storage().list(bucket, BlobListOption.prefix(prefix)).iterateAll()) { - fileUris.add(String.format("%s://%s/%s", scheme, blob.getBucket(), blob.getName())); + private FieldValueList getTimestampLimits(String entityTableName) { + QueryJobConfiguration getTimestampLimitsQuery = + QueryJobConfiguration.newBuilder(createTimestampLimitQuery(entityTableName)) + .setDefaultDataset(DatasetId.of(projectId(), datasetId())) + .setDestinationTable(TableId.of(projectId(), datasetId(), createTempTableName())) + .build(); + try { + Job job = bigquery().create(JobInfo.of(getTimestampLimitsQuery)); + TableResult getTimestampLimitsQueryResult = waitForJob(job).getQueryResults(); + TableInfo expiry = + bigquery() + .getTable(getTimestampLimitsQuery.getDestinationTable()) + .toBuilder() + .setExpirationTime(System.currentTimeMillis() + TEMP_TABLE_EXPIRY_DURATION_MS) + .build(); + bigquery().update(expiry); + FieldValueList result = null; + for (FieldValueList fields : getTimestampLimitsQueryResult.getValues()) { + result = fields; + } + if (result == null || result.get("min").isNull() || result.get("max").isNull()) { + throw new RuntimeException("query returned insufficient values"); + } + return result; + } catch (InterruptedException e) { + throw Status.INTERNAL + .withDescription("Unable to extract min and max timestamps from query") + .withCause(e) + .asRuntimeException(); + } + } + + private Table loadEntities(ServingAPIProto.DatasetSource datasetSource) { + Table loadedEntityTable; + switch (datasetSource.getDatasetSourceCase()) { + case FILE_SOURCE: + try { + // Currently only AVRO format is supported + if (datasetSource.getFileSource().getDataFormat() + != ServingAPIProto.DataFormat.DATA_FORMAT_AVRO) { + throw Status.INVALID_ARGUMENT + .withDescription("Invalid file format, only AVRO is supported.") + .asRuntimeException(); + } + + TableId tableId = TableId.of(projectId(), datasetId(), createTempTableName()); + log.info( + "Loading entity rows to: {}.{}.{}", projectId(), datasetId(), tableId.getTable()); + + LoadJobConfiguration loadJobConfiguration = + LoadJobConfiguration.of( + tableId, datasetSource.getFileSource().getFileUrisList(), FormatOptions.avro()); + loadJobConfiguration = + loadJobConfiguration.toBuilder().setUseAvroLogicalTypes(true).build(); + Job job = bigquery().create(JobInfo.of(loadJobConfiguration)); 
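+          // Block until the load job completes, then give the freshly loaded entity table a
+          // one-day expiry (TEMP_TABLE_EXPIRY_DURATION_MS) so abandoned staging tables are
+          // garbage-collected by BigQuery instead of accumulating in the dataset.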
+ waitForJob(job); + + TableInfo expiry = + bigquery() + .getTable(tableId) + .toBuilder() + .setExpirationTime(System.currentTimeMillis() + TEMP_TABLE_EXPIRY_DURATION_MS) + .build(); + bigquery().update(expiry); + + loadedEntityTable = bigquery().getTable(tableId); + if (!loadedEntityTable.exists()) { + throw new RuntimeException( + "Unable to create entity dataset table, table already exists"); + } + return loadedEntityTable; + } catch (Exception e) { + log.error("Exception has occurred in loadEntities method: ", e); + throw Status.INTERNAL + .withDescription("Failed to load entity dataset into store: " + e.toString()) + .withCause(e) + .asRuntimeException(); + } + case DATASETSOURCE_NOT_SET: + default: + throw Status.INVALID_ARGUMENT + .withDescription("Data source must be set.") + .asRuntimeException(); } - return fileUris; } - Job runBatchQuery(List featureSetQueries) + private List generateQueries( + String entityTableName, + FieldValueList timestampLimits, + List featureSetQueryInfos) { + List featureSetQueries = new ArrayList<>(); + try { + for (FeatureSetQueryInfo featureSetInfo : featureSetQueryInfos) { + String query = + QueryTemplater.createFeatureSetPointInTimeQuery( + featureSetInfo, + projectId(), + datasetId(), + entityTableName, + timestampLimits.get("min").getStringValue(), + timestampLimits.get("max").getStringValue()); + featureSetQueries.add(query); + } + } catch (IOException e) { + throw Status.INTERNAL + .withDescription("Unable to generate query for batch retrieval") + .withCause(e) + .asRuntimeException(); + } + return featureSetQueries; + } + + Job runBatchQuery( + String entityTableName, + List entityTableColumnNames, + List featureSetQueryInfos, + List featureSetQueries) throws BigQueryException, InterruptedException, IOException { ExecutorService executorService = Executors.newFixedThreadPool(featureSetQueries.size()); - ExecutorCompletionService executorCompletionService = + ExecutorCompletionService executorCompletionService = new ExecutorCompletionService<>(executorService); - List featureSetInfos = new ArrayList<>(); - // For each of the feature sets requested, start an async job joining the features in that // feature set to the provided entity table for (int i = 0; i < featureSetQueries.size(); i++) { @@ -227,28 +314,21 @@ Job runBatchQuery(List featureSetQueries) executorCompletionService.submit( SubqueryCallable.builder() .setBigquery(bigquery()) - .setFeatureSetInfo(featureSetInfos().get(i)) + .setFeatureSetInfo(featureSetQueryInfos.get(i)) .setSubqueryJob(subqueryJob) .build()); } + List completedFeatureSetQueryInfos = new ArrayList<>(); + for (int i = 0; i < featureSetQueries.size(); i++) { try { // Try to retrieve the outputs of all the jobs. The timeout here is a formality; // a stricter timeout is implemented in the actual SubqueryCallable. - FeatureSetInfo featureSetInfo = + FeatureSetQueryInfo featureSetInfo = executorCompletionService.take().get(SUBQUERY_TIMEOUT_SECS, TimeUnit.SECONDS); - featureSetInfos.add(featureSetInfo); + completedFeatureSetQueryInfos.add(featureSetInfo); } catch (InterruptedException | ExecutionException | TimeoutException e) { - jobService() - .upsert( - ServingAPIProto.Job.newBuilder() - .setId(feastJobId()) - .setType(JobType.JOB_TYPE_DOWNLOAD) - .setStatus(JobStatus.JOB_STATUS_DONE) - .setError(e.getMessage()) - .build()); - executorService.shutdownNow(); throw Status.INTERNAL .withDescription("Error running batch query") @@ -261,7 +341,7 @@ Job runBatchQuery(List featureSetQueries) // subqueries into a single table. 
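+    // Each completed FeatureSetQueryInfo returned by a SubqueryCallable carries the fully
+    // qualified name of the temporary table holding that feature set's point-in-time join
+    // output; the join query below stitches those tables back onto the entity table.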
String joinQuery = QueryTemplater.createJoinQuery( - featureSetInfos, entityTableColumnNames(), entityTableName()); + completedFeatureSetQueryInfos, entityTableColumnNames, entityTableName); QueryJobConfiguration queryJobConfig = QueryJobConfiguration.newBuilder(joinQuery) .setDestinationTable(TableId.of(projectId(), datasetId(), createTempTableName())) @@ -280,59 +360,24 @@ Job runBatchQuery(List featureSetQueries) return completedQueryJob; } - private List generateQueries(FieldValueList timestampLimits) { - List featureSetQueries = new ArrayList<>(); - try { - for (FeatureSetInfo featureSetInfo : featureSetInfos()) { - String query = - QueryTemplater.createFeatureSetPointInTimeQuery( - featureSetInfo, - projectId(), - datasetId(), - entityTableName(), - timestampLimits.get("min").getStringValue(), - timestampLimits.get("max").getStringValue()); - featureSetQueries.add(query); - } - } catch (IOException e) { - throw Status.INTERNAL - .withDescription("Unable to generate query for batch retrieval") - .withCause(e) - .asRuntimeException(); - } - return featureSetQueries; - } + private List parseOutputFileURIs(String feastJobId) { + String scheme = jobStagingLocation().substring(0, jobStagingLocation().indexOf("://")); + String stagingLocationNoScheme = + jobStagingLocation().substring(jobStagingLocation().indexOf("://") + 3); + String bucket = stagingLocationNoScheme.split("/")[0]; + List prefixParts = new ArrayList<>(); + prefixParts.add( + stagingLocationNoScheme.contains("/") && !stagingLocationNoScheme.endsWith("/") + ? stagingLocationNoScheme.substring(stagingLocationNoScheme.indexOf("/") + 1) + : ""); + prefixParts.add(feastJobId); + String prefix = String.join("/", prefixParts) + "/"; - private FieldValueList getTimestampLimits(String entityTableName) { - QueryJobConfiguration getTimestampLimitsQuery = - QueryJobConfiguration.newBuilder(createTimestampLimitQuery(entityTableName)) - .setDefaultDataset(DatasetId.of(projectId(), datasetId())) - .setDestinationTable(TableId.of(projectId(), datasetId(), createTempTableName())) - .build(); - try { - Job job = bigquery().create(JobInfo.of(getTimestampLimitsQuery)); - TableResult getTimestampLimitsQueryResult = waitForJob(job).getQueryResults(); - TableInfo expiry = - bigquery() - .getTable(getTimestampLimitsQuery.getDestinationTable()) - .toBuilder() - .setExpirationTime(System.currentTimeMillis() + TEMP_TABLE_EXPIRY_DURATION_MS) - .build(); - bigquery().update(expiry); - FieldValueList result = null; - for (FieldValueList fields : getTimestampLimitsQueryResult.getValues()) { - result = fields; - } - if (result == null || result.get("min").isNull() || result.get("max").isNull()) { - throw new RuntimeException("query returned insufficient values"); - } - return result; - } catch (InterruptedException e) { - throw Status.INTERNAL - .withDescription("Unable to extract min and max timestamps from query") - .withCause(e) - .asRuntimeException(); + List fileUris = new ArrayList<>(); + for (Blob blob : storage().list(bucket, Storage.BlobListOption.prefix(prefix)).iterateAll()) { + fileUris.add(String.format("%s://%s/%s", scheme, blob.getBucket(), blob.getName())); } + return fileUris; } private Job waitForJob(Job queryJob) throws InterruptedException { @@ -349,4 +394,13 @@ private Job waitForJob(Job queryJob) throws InterruptedException { } return completedJob; } + + public String generateFullTableName(TableId tableId) { + return String.format( + "%s.%s.%s", tableId.getProject(), tableId.getDataset(), tableId.getTable()); + } + + public String 
createTempTableName() { + return "_" + UUID.randomUUID().toString().replace("-", ""); + } } diff --git a/serving/src/main/java/feast/serving/store/bigquery/model/FeatureSetInfo.java b/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/retriever/FeatureSetQueryInfo.java similarity index 90% rename from serving/src/main/java/feast/serving/store/bigquery/model/FeatureSetInfo.java rename to storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/retriever/FeatureSetQueryInfo.java index 77c80ead0ea..5a7d56e9844 100644 --- a/serving/src/main/java/feast/serving/store/bigquery/model/FeatureSetInfo.java +++ b/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/retriever/FeatureSetQueryInfo.java @@ -14,11 +14,11 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package feast.serving.store.bigquery.model; +package feast.storage.connectors.bigquery.retriever; import java.util.List; -public class FeatureSetInfo { +public class FeatureSetQueryInfo { private final String project; private final String name; @@ -28,7 +28,7 @@ public class FeatureSetInfo { private final List features; private final String table; - public FeatureSetInfo( + public FeatureSetQueryInfo( String project, String name, int version, @@ -45,7 +45,7 @@ public FeatureSetInfo( this.table = table; } - public FeatureSetInfo(FeatureSetInfo featureSetInfo, String table) { + public FeatureSetQueryInfo(FeatureSetQueryInfo featureSetInfo, String table) { this.project = featureSetInfo.getProject(); this.name = featureSetInfo.getName(); diff --git a/serving/src/main/java/feast/serving/store/bigquery/QueryTemplater.java b/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/retriever/QueryTemplater.java similarity index 90% rename from serving/src/main/java/feast/serving/store/bigquery/QueryTemplater.java rename to storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/retriever/QueryTemplater.java index e3f1138db89..cba997b6ab0 100644 --- a/serving/src/main/java/feast/serving/store/bigquery/QueryTemplater.java +++ b/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/retriever/QueryTemplater.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package feast.serving.store.bigquery; +package feast.storage.connectors.bigquery.retriever; import com.google.cloud.bigquery.TableId; import com.google.protobuf.Duration; @@ -23,8 +23,7 @@ import feast.core.FeatureSetProto.EntitySpec; import feast.core.FeatureSetProto.FeatureSetSpec; import feast.serving.ServingAPIProto.FeatureReference; -import feast.serving.specs.FeatureSetRequest; -import feast.serving.store.bigquery.model.FeatureSetInfo; +import feast.storage.api.retriever.FeatureSetRequest; import java.io.IOException; import java.io.StringWriter; import java.io.Writer; @@ -67,13 +66,14 @@ public static String createEntityTableUUIDQuery(String leftTableName) { * Generate the information necessary for the sql templating for point in time correctness join to * the entity dataset for each feature set requested. * - * @param featureSetRequests List of feature sets requested + * @param featureSetRequests List of {@link FeatureSetRequest} containing a {@link FeatureSetSpec} + * and its corresponding {@link FeatureReference}s provided by the user. 
* @return List of FeatureSetInfos */ - public static List getFeatureSetInfos(List featureSetRequests) - throws IllegalArgumentException { + public static List getFeatureSetInfos( + List featureSetRequests) throws IllegalArgumentException { - List featureSetInfos = new ArrayList<>(); + List featureSetInfos = new ArrayList<>(); for (FeatureSetRequest featureSetRequest : featureSetRequests) { FeatureSetSpec spec = featureSetRequest.getSpec(); Duration maxAge = spec.getMaxAge(); @@ -84,7 +84,7 @@ public static List getFeatureSetInfos(List fe .map(FeatureReference::getName) .collect(Collectors.toList()); featureSetInfos.add( - new FeatureSetInfo( + new FeatureSetQueryInfo( spec.getProject(), spec.getName(), spec.getVersion(), @@ -109,7 +109,7 @@ public static List getFeatureSetInfos(List fe * @return point in time correctness join BQ SQL query */ public static String createFeatureSetPointInTimeQuery( - FeatureSetInfo featureSetInfo, + FeatureSetQueryInfo featureSetInfo, String projectId, String datasetId, String leftTableName, @@ -139,7 +139,7 @@ public static String createFeatureSetPointInTimeQuery( * @return query to join temporary feature set tables to the entity table */ public static String createJoinQuery( - List featureSetInfos, + List featureSetInfos, List entityTableColumnNames, String leftTableName) throws IOException { diff --git a/serving/src/main/java/feast/serving/store/bigquery/SubqueryCallable.java b/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/retriever/SubqueryCallable.java similarity index 70% rename from serving/src/main/java/feast/serving/store/bigquery/SubqueryCallable.java rename to storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/retriever/SubqueryCallable.java index 14026030b42..43a32cef504 100644 --- a/serving/src/main/java/feast/serving/store/bigquery/SubqueryCallable.java +++ b/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/retriever/SubqueryCallable.java @@ -14,19 +14,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package feast.serving.store.bigquery; +package feast.storage.connectors.bigquery.retriever; -import static feast.serving.service.BigQueryServingService.TEMP_TABLE_EXPIRY_DURATION_MS; -import static feast.serving.store.bigquery.QueryTemplater.generateFullTableName; +import static feast.storage.connectors.bigquery.retriever.BigQueryHistoricalRetriever.TEMP_TABLE_EXPIRY_DURATION_MS; +import static feast.storage.connectors.bigquery.retriever.QueryTemplater.generateFullTableName; import com.google.auto.value.AutoValue; -import com.google.cloud.bigquery.BigQuery; -import com.google.cloud.bigquery.BigQueryException; -import com.google.cloud.bigquery.Job; -import com.google.cloud.bigquery.QueryJobConfiguration; -import com.google.cloud.bigquery.TableId; -import com.google.cloud.bigquery.TableInfo; -import feast.serving.store.bigquery.model.FeatureSetInfo; +import com.google.cloud.bigquery.*; import java.util.concurrent.Callable; /** @@ -34,11 +28,11 @@ * updated with the reference to the table containing the results of the query. 
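+ * The subquery's destination table is given the standard temporary-table expiry
+ * (TEMP_TABLE_EXPIRY_DURATION_MS) before the updated FeatureSetQueryInfo is returned.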
*/ @AutoValue -public abstract class SubqueryCallable implements Callable { +public abstract class SubqueryCallable implements Callable { public abstract BigQuery bigquery(); - public abstract FeatureSetInfo featureSetInfo(); + public abstract FeatureSetQueryInfo featureSetInfo(); public abstract Job subqueryJob(); @@ -51,7 +45,7 @@ public abstract static class Builder { public abstract Builder setBigquery(BigQuery bigquery); - public abstract Builder setFeatureSetInfo(FeatureSetInfo featureSetInfo); + public abstract Builder setFeatureSetInfo(FeatureSetQueryInfo featureSetInfo); public abstract Builder setSubqueryJob(Job subqueryJob); @@ -59,7 +53,7 @@ public abstract static class Builder { } @Override - public FeatureSetInfo call() throws BigQueryException, InterruptedException { + public FeatureSetQueryInfo call() throws BigQueryException, InterruptedException { QueryJobConfiguration subqueryConfig; subqueryJob().waitFor(); subqueryConfig = subqueryJob().getConfiguration(); @@ -75,6 +69,6 @@ public FeatureSetInfo call() throws BigQueryException, InterruptedException { String fullTablePath = generateFullTableName(destinationTable); - return new FeatureSetInfo(featureSetInfo(), fullTablePath); + return new FeatureSetQueryInfo(featureSetInfo(), fullTablePath); } } diff --git a/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/writer/BigQueryDeadletterSink.java b/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/writer/BigQueryDeadletterSink.java new file mode 100644 index 00000000000..96364c96c78 --- /dev/null +++ b/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/writer/BigQueryDeadletterSink.java @@ -0,0 +1,133 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package feast.storage.connectors.bigquery.writer; + +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TimePartitioning; +import com.google.auto.value.AutoValue; +import com.google.common.io.Resources; +import feast.storage.api.writer.DeadletterSink; +import feast.storage.api.writer.FailedElement; +import java.nio.charset.StandardCharsets; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.CreateDisposition; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PDone; +import org.slf4j.Logger; + +public class BigQueryDeadletterSink implements DeadletterSink { + + private static final String DEADLETTER_SCHEMA_FILE_PATH = "schemas/deadletter_table_schema.json"; + private static final Logger log = org.slf4j.LoggerFactory.getLogger(BigQueryDeadletterSink.class); + private static final String TIMESTAMP_COLUMN = "timestamp"; + + private final String tableSpec; + private String jsonSchema; + + public BigQueryDeadletterSink(String tableSpec) { + + this.tableSpec = tableSpec; + try { + jsonSchema = + Resources.toString( + Resources.getResource(DEADLETTER_SCHEMA_FILE_PATH), StandardCharsets.UTF_8); + } catch (Exception e) { + log.error( + "Unable to read {} file from the resources folder!", DEADLETTER_SCHEMA_FILE_PATH, e); + } + } + + @Override + public void prepareWrite() {} + + @Override + public PTransform, PDone> write() { + return WriteFailedElement.newBuilder() + .setJsonSchema(jsonSchema) + .setTableSpec(tableSpec) + .build(); + } + + @AutoValue + public abstract static class WriteFailedElement + extends PTransform, PDone> { + + public abstract String getTableSpec(); + + public abstract String getJsonSchema(); + + public static Builder newBuilder() { + return new AutoValue_BigQueryDeadletterSink_WriteFailedElement.Builder(); + } + + @AutoValue.Builder + public abstract static class Builder { + + /** + * @param tableSpec Table spec should follow the format "PROJECT_ID:DATASET_ID.TABLE_ID". + * Table will be created if not exists. + */ + public abstract Builder setTableSpec(String tableSpec); + + /** + * @param jsonSchema JSON string describing the schema + * of the table. 
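+       * BigQueryDeadletterSink supplies the bundled schemas/deadletter_table_schema.json
+       * resource here by default.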
+ */ + public abstract Builder setJsonSchema(String jsonSchema); + + public abstract WriteFailedElement build(); + } + + @Override + public PDone expand(PCollection input) { + TimePartitioning partition = new TimePartitioning().setType("DAY"); + partition.setField(TIMESTAMP_COLUMN); + input + .apply("FailedElementToTableRow", ParDo.of(new FailedElementToTableRowFn())) + .apply( + "WriteFailedElementsToBigQuery", + BigQueryIO.writeTableRows() + .to(getTableSpec()) + .withJsonSchema(getJsonSchema()) + .withTimePartitioning(partition) + .withCreateDisposition(CreateDisposition.CREATE_IF_NEEDED) + .withWriteDisposition(WriteDisposition.WRITE_APPEND)); + return PDone.in(input.getPipeline()); + } + } + + public static class FailedElementToTableRowFn extends DoFn { + @ProcessElement + public void processElement(ProcessContext context) { + final FailedElement element = context.element(); + final TableRow tableRow = + new TableRow() + .set(TIMESTAMP_COLUMN, element.getTimestamp().toString()) + .set("job_name", element.getJobName()) + .set("transform_name", element.getTransformName()) + .set("payload", element.getPayload()) + .set("error_message", element.getErrorMessage()) + .set("stack_trace", element.getStackTrace()); + context.output(tableRow); + } + } +} diff --git a/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/writer/BigQueryFeatureSink.java b/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/writer/BigQueryFeatureSink.java new file mode 100644 index 00000000000..8860db2622a --- /dev/null +++ b/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/writer/BigQueryFeatureSink.java @@ -0,0 +1,188 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package feast.storage.connectors.bigquery.writer; + +import com.google.auto.value.AutoValue; +import com.google.cloud.bigquery.*; +import com.google.common.collect.ImmutableMap; +import feast.core.FeatureSetProto; +import feast.core.StoreProto.Store.BigQueryConfig; +import feast.storage.api.writer.FeatureSink; +import feast.storage.api.writer.WriteResult; +import feast.storage.connectors.bigquery.common.TypeUtil; +import feast.types.FeatureRowProto; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import org.apache.beam.repackaged.core.org.apache.commons.lang3.tuple.Pair; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.values.PCollection; +import org.slf4j.Logger; + +@AutoValue +public abstract class BigQueryFeatureSink implements FeatureSink { + private static final Logger log = org.slf4j.LoggerFactory.getLogger(BigQueryFeatureSink.class); + + // Column description for reserved fields + public static final String BIGQUERY_EVENT_TIMESTAMP_FIELD_DESCRIPTION = + "Event time for the FeatureRow"; + public static final String BIGQUERY_CREATED_TIMESTAMP_FIELD_DESCRIPTION = + "Processing time of the FeatureRow ingestion in Feast\""; + public static final String BIGQUERY_JOB_ID_FIELD_DESCRIPTION = + "Feast import job ID for the FeatureRow"; + + public abstract String getProjectId(); + + public abstract String getDatasetId(); + + public abstract BigQuery getBigQuery(); + + /** + * Initialize a {@link BigQueryFeatureSink.Builder} from a {@link BigQueryConfig}. This method + * initializes a {@link BigQuery} client with default options. Use the builder method to inject + * your own client. + * + * @param config {@link BigQueryConfig} + * @return {@link BigQueryFeatureSink.Builder} + */ + public static BigQueryFeatureSink fromConfig(BigQueryConfig config) { + return builder() + .setDatasetId(config.getDatasetId()) + .setProjectId(config.getProjectId()) + .setBigQuery(BigQueryOptions.getDefaultInstance().getService()) + .build(); + } + + public static Builder builder() { + return new AutoValue_BigQueryFeatureSink.Builder(); + } + + @AutoValue.Builder + public abstract static class Builder { + + public abstract Builder setProjectId(String projectId); + + public abstract Builder setDatasetId(String datasetId); + + public abstract Builder setBigQuery(BigQuery bigQuery); + + public abstract BigQueryFeatureSink build(); + } + + /** @param featureSet Feature set to be written */ + @Override + public void prepareWrite(FeatureSetProto.FeatureSet featureSet) { + BigQuery bigquery = getBigQuery(); + FeatureSetProto.FeatureSetSpec featureSetSpec = featureSet.getSpec(); + + DatasetId datasetId = DatasetId.of(getProjectId(), getDatasetId()); + if (bigquery.getDataset(datasetId) == null) { + log.info( + "Creating dataset '{}' in project '{}'", datasetId.getDataset(), datasetId.getProject()); + bigquery.create(DatasetInfo.of(datasetId)); + } + String tableName = + String.format( + "%s_%s_v%d", + featureSetSpec.getProject(), featureSetSpec.getName(), featureSetSpec.getVersion()) + .replaceAll("-", "_"); + TableId tableId = TableId.of(datasetId.getProject(), datasetId.getDataset(), tableName); + + // Return if there is an existing table + Table table = bigquery.getTable(tableId); + if (table != null) { + log.info( + "Writing to existing BigQuery table '{}:{}.{}'", + getProjectId(), + datasetId.getDataset(), + tableName); + return; + } + + log.info( + "Creating table '{}' in dataset '{}' in project '{}'", + tableId.getTable(), + datasetId.getDataset(), + 
datasetId.getProject()); + TableDefinition tableDefinition = createBigQueryTableDefinition(featureSet.getSpec()); + TableInfo tableInfo = TableInfo.of(tableId, tableDefinition); + bigquery.create(tableInfo); + } + + @Override + public PTransform, WriteResult> writer() { + return new BigQueryWrite(DatasetId.of(getProjectId(), getDatasetId())); + } + + private TableDefinition createBigQueryTableDefinition(FeatureSetProto.FeatureSetSpec spec) { + List fields = new ArrayList<>(); + log.info("Table will have the following fields:"); + + for (FeatureSetProto.EntitySpec entitySpec : spec.getEntitiesList()) { + Field.Builder builder = + Field.newBuilder( + entitySpec.getName(), TypeUtil.toStandardSqlType(entitySpec.getValueType())); + if (entitySpec.getValueType().name().toLowerCase().endsWith("_list")) { + builder.setMode(Field.Mode.REPEATED); + } + Field field = builder.build(); + log.info("- {}", field.toString()); + fields.add(field); + } + for (FeatureSetProto.FeatureSpec featureSpec : spec.getFeaturesList()) { + Field.Builder builder = + Field.newBuilder( + featureSpec.getName(), TypeUtil.toStandardSqlType(featureSpec.getValueType())); + if (featureSpec.getValueType().name().toLowerCase().endsWith("_list")) { + builder.setMode(Field.Mode.REPEATED); + } + Field field = builder.build(); + log.info("- {}", field.toString()); + fields.add(field); + } + + // Refer to protos/feast/core/Store.proto for reserved fields in BigQuery. + Map> + reservedFieldNameToPairOfStandardSQLTypeAndDescription = + ImmutableMap.of( + "event_timestamp", + Pair.of(StandardSQLTypeName.TIMESTAMP, BIGQUERY_EVENT_TIMESTAMP_FIELD_DESCRIPTION), + "created_timestamp", + Pair.of( + StandardSQLTypeName.TIMESTAMP, BIGQUERY_CREATED_TIMESTAMP_FIELD_DESCRIPTION), + "job_id", + Pair.of(StandardSQLTypeName.STRING, BIGQUERY_JOB_ID_FIELD_DESCRIPTION)); + for (Map.Entry> entry : + reservedFieldNameToPairOfStandardSQLTypeAndDescription.entrySet()) { + Field field = + Field.newBuilder(entry.getKey(), entry.getValue().getLeft()) + .setDescription(entry.getValue().getRight()) + .build(); + log.info("- {}", field.toString()); + fields.add(field); + } + + TimePartitioning timePartitioning = + TimePartitioning.newBuilder(TimePartitioning.Type.DAY).setField("event_timestamp").build(); + log.info("Table partitioning: " + timePartitioning.toString()); + + return StandardTableDefinition.newBuilder() + .setTimePartitioning(timePartitioning) + .setSchema(Schema.of(fields)) + .build(); + } +} diff --git a/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/writer/BigQueryWrite.java b/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/writer/BigQueryWrite.java new file mode 100644 index 00000000000..e3f5e5ae713 --- /dev/null +++ b/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/writer/BigQueryWrite.java @@ -0,0 +1,107 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package feast.storage.connectors.bigquery.writer; + +import com.google.api.services.bigquery.model.TableDataInsertAllResponse; +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.bigquery.DatasetId; +import feast.storage.api.writer.FailedElement; +import feast.storage.api.writer.WriteResult; +import feast.types.FeatureRowProto; +import java.io.IOException; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryInsertError; +import org.apache.beam.sdk.io.gcp.bigquery.InsertRetryPolicy; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.PCollection; +import org.slf4j.Logger; + +/** + * A {@link PTransform} that writes {@link FeatureRowProto FeatureRows} to the specified BigQuery + * dataset, and returns a {@link WriteResult} containing the unsuccessful writes. Since Bigquery + * does not output successful writes, we cannot emit those, and so no success metrics will be + * captured if this sink is used. + */ +public class BigQueryWrite + extends PTransform, WriteResult> { + private static final Logger log = org.slf4j.LoggerFactory.getLogger(BigQueryWrite.class); + + // Destination dataset + private DatasetId destination; + + public BigQueryWrite(DatasetId destination) { + this.destination = destination; + } + + @Override + public WriteResult expand(PCollection input) { + String jobName = input.getPipeline().getOptions().getJobName(); + org.apache.beam.sdk.io.gcp.bigquery.WriteResult bigqueryWriteResult = + input.apply( + "WriteTableRowToBigQuery", + BigQueryIO.write() + .to(new GetTableDestination(destination.getProject(), destination.getDataset())) + .withFormatFunction(new FeatureRowToTableRow(jobName)) + .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_NEVER) + .withWriteDisposition(BigQueryIO.Write.WriteDisposition.WRITE_APPEND) + .withExtendedErrorInfo() + .withMethod(BigQueryIO.Write.Method.STREAMING_INSERTS) + .withFailedInsertRetryPolicy(InsertRetryPolicy.retryTransientErrors())); + + PCollection failedElements = + bigqueryWriteResult + .getFailedInsertsWithErr() + .apply( + "WrapBigQueryInsertionError", + ParDo.of( + new DoFn() { + @ProcessElement + public void processElement(ProcessContext context) { + TableDataInsertAllResponse.InsertErrors error = + context.element().getError(); + TableRow row = context.element().getRow(); + try { + context.output( + FailedElement.newBuilder() + .setErrorMessage(error.toPrettyString()) + .setPayload(row.toPrettyString()) + .setJobName(context.getPipelineOptions().getJobName()) + .setTransformName("WriteTableRowToBigQuery") + .build()); + } catch (IOException e) { + log.error(e.getMessage()); + } + } + })); + + // Since BigQueryIO does not support emitting successful writes, we set successfulInserts to + // an empty stream, + // and no metrics will be collected. 
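+    // The DoFn below never calls context.output(), so the resulting PCollection is empty by
+    // construction.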
+ PCollection successfulInserts = + input.apply( + "dummy", + ParDo.of( + new DoFn() { + @ProcessElement + public void processElement(ProcessContext context) {} + })); + + return WriteResult.in(input.getPipeline(), successfulInserts, failedElements); + } +} diff --git a/ingestion/src/main/java/feast/store/serving/bigquery/FeatureRowToTableRow.java b/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/writer/FeatureRowToTableRow.java similarity index 98% rename from ingestion/src/main/java/feast/store/serving/bigquery/FeatureRowToTableRow.java rename to storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/writer/FeatureRowToTableRow.java index b89cf832910..12833b31b85 100644 --- a/ingestion/src/main/java/feast/store/serving/bigquery/FeatureRowToTableRow.java +++ b/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/writer/FeatureRowToTableRow.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package feast.store.serving.bigquery; +package feast.storage.connectors.bigquery.writer; import com.google.api.services.bigquery.model.TableRow; import com.google.protobuf.util.Timestamps; diff --git a/ingestion/src/main/java/feast/store/serving/bigquery/GetTableDestination.java b/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/writer/GetTableDestination.java similarity index 97% rename from ingestion/src/main/java/feast/store/serving/bigquery/GetTableDestination.java rename to storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/writer/GetTableDestination.java index eb37db94498..5903d36b858 100644 --- a/ingestion/src/main/java/feast/store/serving/bigquery/GetTableDestination.java +++ b/storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/writer/GetTableDestination.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package feast.store.serving.bigquery; +package feast.storage.connectors.bigquery.writer; import com.google.api.services.bigquery.model.TimePartitioning; import feast.types.FeatureRowProto.FeatureRow; diff --git a/storage/connectors/bigquery/src/main/resources/schemas/deadletter_table_schema.json b/storage/connectors/bigquery/src/main/resources/schemas/deadletter_table_schema.json new file mode 100644 index 00000000000..92381189073 --- /dev/null +++ b/storage/connectors/bigquery/src/main/resources/schemas/deadletter_table_schema.json @@ -0,0 +1,34 @@ +{ + "fields": [ + { + "name": "timestamp", + "type": "TIMESTAMP", + "mode": "REQUIRED" + }, + { + "name": "job_name", + "type": "STRING", + "mode": "NULLABLE" + }, + { + "name": "transform_name", + "type": "STRING", + "mode": "NULLABLE" + }, + { + "name": "payload", + "type": "STRING", + "mode": "NULLABLE" + }, + { + "name": "error_message", + "type": "STRING", + "mode": "NULLABLE" + }, + { + "name": "stack_trace", + "type": "STRING", + "mode": "NULLABLE" + } + ] +} \ No newline at end of file diff --git a/storage/connectors/bigquery/src/main/resources/templates/join_featuresets.sql b/storage/connectors/bigquery/src/main/resources/templates/join_featuresets.sql new file mode 100644 index 00000000000..60b7c7d7a12 --- /dev/null +++ b/storage/connectors/bigquery/src/main/resources/templates/join_featuresets.sql @@ -0,0 +1,24 @@ +/* + Joins the outputs of multiple point-in-time-correctness joins to a single table. 
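+  Each feature set's result table retains the uuid generated for every entity row, so the
+  tables can simply be combined below with LEFT JOIN ... USING (uuid).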
+ */ +WITH joined as ( +SELECT * FROM `{{ leftTableName }}` +{% for featureSet in featureSets %} +LEFT JOIN ( + SELECT + uuid, + {% for featureName in featureSet.features %} + {{ featureSet.project }}_{{ featureName }}_v{{ featureSet.version }}{% if loop.last %}{% else %}, {% endif %} + {% endfor %} + FROM `{{ featureSet.table }}` +) USING (uuid) +{% endfor %} +) SELECT + event_timestamp, + {{ entities | join(', ') }} + {% for featureSet in featureSets %} + {% for featureName in featureSet.features %} + ,{{ featureSet.project }}_{{ featureName }}_v{{ featureSet.version }} as {{ featureName }} + {% endfor %} + {% endfor %} +FROM joined \ No newline at end of file diff --git a/storage/connectors/bigquery/src/main/resources/templates/single_featureset_pit_join.sql b/storage/connectors/bigquery/src/main/resources/templates/single_featureset_pit_join.sql new file mode 100644 index 00000000000..fb4c555b529 --- /dev/null +++ b/storage/connectors/bigquery/src/main/resources/templates/single_featureset_pit_join.sql @@ -0,0 +1,90 @@ +/* + This query template performs the point-in-time correctness join for a single feature set table + to the provided entity table. + + 1. Concatenate the timestamp and entities from the feature set table with the entity dataset. + Feature values are joined to this table later for improved efficiency. + featureset_timestamp is equal to null in rows from the entity dataset. + */ +WITH union_features AS ( +SELECT + -- uuid is a unique identifier for each row in the entity dataset. Generated by `QueryTemplater.createEntityTableUUIDQuery` + uuid, + -- event_timestamp contains the timestamps to join onto + event_timestamp, + -- the feature_timestamp, i.e. the latest occurrence of the requested feature relative to the entity_dataset timestamp + NULL as {{ featureSet.project }}_{{ featureSet.name }}_v{{ featureSet.version }}_feature_timestamp, + -- created timestamp of the feature at the corresponding feature_timestamp + NULL as created_timestamp, + -- select only entities belonging to this feature set + {{ featureSet.entities | join(', ')}}, + -- boolean for filtering the dataset later + true AS is_entity_table +FROM `{{leftTableName}}` +UNION ALL +SELECT + NULL as uuid, + event_timestamp, + event_timestamp as {{ featureSet.project }}_{{ featureSet.name }}_v{{ featureSet.version }}_feature_timestamp, + created_timestamp, + {{ featureSet.entities | join(', ')}}, + false AS is_entity_table +FROM `{{projectId}}.{{datasetId}}.{{ featureSet.project }}_{{ featureSet.name }}_v{{ featureSet.version }}` WHERE event_timestamp <= '{{maxTimestamp}}' +{% if featureSet.maxAge == 0 %}{% else %}AND event_timestamp >= Timestamp_sub(TIMESTAMP '{{ minTimestamp }}', interval {{ featureSet.maxAge }} second){% endif %} +), +/* + 2. Window the data in the unioned dataset, partitioning by entity and ordering by event_timestamp, as + well as is_entity_table. + Within each window, back-fill the feature_timestamp - as a result of this, the null feature_timestamps + in the rows from the entity table should now contain the latest timestamps relative to the row's + event_timestamp. + + For rows where event_timestamp(provided datetime) - feature_timestamp > max age, set the + feature_timestamp to null. 
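+  For example, with maxAge = 3600 an entity row stamped 10:00:00 can only be joined to
+  feature rows whose event_timestamp falls in (09:00:00, 10:00:00].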
+ */ +joined AS ( +SELECT + uuid, + event_timestamp, + {{ featureSet.entities | join(', ')}}, + {% for featureName in featureSet.features %} + IF(event_timestamp >= {{ featureSet.project }}_{{ featureSet.name }}_v{{ featureSet.version }}_feature_timestamp {% if featureSet.maxAge == 0 %}{% else %}AND Timestamp_sub(event_timestamp, interval {{ featureSet.maxAge }} second) < {{ featureSet.project }}_{{ featureSet.name }}_v{{ featureSet.version }}_feature_timestamp{% endif %}, {{ featureSet.project }}_{{ featureName }}_v{{ featureSet.version }}, NULL) as {{ featureSet.project }}_{{ featureName }}_v{{ featureSet.version }}{% if loop.last %}{% else %}, {% endif %} + {% endfor %} +FROM ( +SELECT + uuid, + event_timestamp, + {{ featureSet.entities | join(', ')}}, + FIRST_VALUE(created_timestamp IGNORE NULLS) over w AS created_timestamp, + FIRST_VALUE({{ featureSet.project }}_{{ featureSet.name }}_v{{ featureSet.version }}_feature_timestamp IGNORE NULLS) over w AS {{ featureSet.project }}_{{ featureSet.name }}_v{{ featureSet.version }}_feature_timestamp, + is_entity_table +FROM union_features +WINDOW w AS (PARTITION BY {{ featureSet.entities | join(', ') }} ORDER BY event_timestamp DESC, is_entity_table DESC, created_timestamp DESC ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) +) +/* + 3. Select only the rows from the entity table, and join the features from the original feature set table + to the dataset using the entity values, feature_timestamp, and created_timestamps. + */ +LEFT JOIN ( +SELECT + event_timestamp as {{ featureSet.project }}_{{ featureSet.name }}_v{{ featureSet.version }}_feature_timestamp, + created_timestamp, + {{ featureSet.entities | join(', ')}}, + {% for featureName in featureSet.features %} + {{ featureName }} as {{ featureSet.project }}_{{ featureName }}_v{{ featureSet.version }}{% if loop.last %}{% else %}, {% endif %} + {% endfor %} +FROM `{{ projectId }}.{{ datasetId }}.{{ featureSet.project }}_{{ featureSet.name }}_v{{ featureSet.version }}` WHERE event_timestamp <= '{{maxTimestamp}}' +{% if featureSet.maxAge == 0 %}{% else %}AND event_timestamp >= Timestamp_sub(TIMESTAMP '{{ minTimestamp }}', interval {{ featureSet.maxAge }} second){% endif %} +) USING ({{ featureSet.project }}_{{ featureSet.name }}_v{{ featureSet.version }}_feature_timestamp, created_timestamp, {{ featureSet.entities | join(', ')}}) +WHERE is_entity_table +) +/* + 4. Finally, deduplicate the rows by selecting the first occurrence of each entity table row UUID. + */ +SELECT + k.* +FROM ( + SELECT ARRAY_AGG(row LIMIT 1)[OFFSET(0)] k + FROM joined row + GROUP BY uuid +) \ No newline at end of file diff --git a/storage/connectors/pom.xml b/storage/connectors/pom.xml new file mode 100644 index 00000000000..b52668a31a4 --- /dev/null +++ b/storage/connectors/pom.xml @@ -0,0 +1,51 @@ + + + + dev.feast + feast-parent + ${revision} + ../.. 
+ + + 4.0.0 + feast-storage-connectors + pom + + Feast Storage Connectors + + + redis + bigquery + + + + + + org.apache.maven.plugins + maven-dependency-plugin + + + + javax.annotation + + + + + + + + + dev.feast + datatypes-java + ${project.version} + + + + dev.feast + feast-storage-api + ${project.version} + + + + diff --git a/storage/connectors/redis/pom.xml b/storage/connectors/redis/pom.xml new file mode 100644 index 00000000000..6c50895bd20 --- /dev/null +++ b/storage/connectors/redis/pom.xml @@ -0,0 +1,82 @@ + + + + dev.feast + feast-storage-connectors + ${revision} + + + 4.0.0 + feast-storage-connector-redis + + Feast Storage Connector for Redis + + + + io.lettuce + lettuce-core + + + + org.apache.commons + commons-lang3 + 3.9 + + + + com.google.auto.value + auto-value-annotations + 1.6.6 + + + + com.google.auto.value + auto-value + 1.6.6 + provided + + + + org.mockito + mockito-core + 2.23.0 + test + + + + + com.github.kstyrc + embedded-redis + test + + + + org.apache.beam + beam-runners-direct-java + ${org.apache.beam.version} + test + + + + org.hamcrest + hamcrest-core + test + + + + org.hamcrest + hamcrest-library + test + + + + + junit + junit + 4.12 + test + + + + diff --git a/serving/src/main/java/feast/serving/encoding/FeatureRowDecoder.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/FeatureRowDecoder.java similarity index 98% rename from serving/src/main/java/feast/serving/encoding/FeatureRowDecoder.java rename to storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/FeatureRowDecoder.java index e70695d8c64..a5506028cbf 100644 --- a/serving/src/main/java/feast/serving/encoding/FeatureRowDecoder.java +++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/FeatureRowDecoder.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package feast.serving.encoding; +package feast.storage.connectors.redis.retriever; import feast.core.FeatureSetProto.FeatureSetSpec; import feast.core.FeatureSetProto.FeatureSpec; diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisOnlineRetriever.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisOnlineRetriever.java new file mode 100644 index 00000000000..c8bb33de5fd --- /dev/null +++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisOnlineRetriever.java @@ -0,0 +1,204 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package feast.storage.connectors.redis.retriever; + +import com.google.protobuf.AbstractMessageLite; +import com.google.protobuf.InvalidProtocolBufferException; +import feast.core.FeatureSetProto.EntitySpec; +import feast.core.FeatureSetProto.FeatureSetSpec; +import feast.serving.ServingAPIProto.FeatureReference; +import feast.serving.ServingAPIProto.GetOnlineFeaturesRequest.EntityRow; +import feast.storage.RedisProto.RedisKey; +import feast.storage.api.retriever.FeatureSetRequest; +import feast.storage.api.retriever.OnlineRetriever; +import feast.types.FeatureRowProto.FeatureRow; +import feast.types.FieldProto.Field; +import feast.types.ValueProto.Value; +import io.grpc.Status; +import io.lettuce.core.api.StatefulRedisConnection; +import io.lettuce.core.api.sync.RedisCommands; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.stream.Collectors; + +public class RedisOnlineRetriever implements OnlineRetriever { + + private final RedisCommands syncCommands; + + public RedisOnlineRetriever(StatefulRedisConnection connection) { + this.syncCommands = connection.sync(); + } + + /** + * Gets online features from redis. This method returns a list of {@link FeatureRow}s + * corresponding to each feature set spec. Each feature row in the list then corresponds to an + * {@link EntityRow} provided by the user. + * + * @param entityRows list of entity rows in the feature request + * @param featureSetRequests Map of {@link feast.core.FeatureSetProto.FeatureSetSpec} to feature + * references in the request tied to that feature set. + * @return List of List of {@link FeatureRow} + */ + @Override + public List> getOnlineFeatures( + List entityRows, List featureSetRequests) { + + List> featureRows = new ArrayList<>(); + for (FeatureSetRequest featureSetRequest : featureSetRequests) { + List redisKeys = buildRedisKeys(entityRows, featureSetRequest.getSpec()); + try { + List featureRowsForFeatureSet = + sendAndProcessMultiGet( + redisKeys, + featureSetRequest.getSpec(), + featureSetRequest.getFeatureReferences().asList()); + featureRows.add(featureRowsForFeatureSet); + } catch (InvalidProtocolBufferException | ExecutionException e) { + throw Status.INTERNAL + .withDescription("Unable to parse protobuf while retrieving feature") + .withCause(e) + .asRuntimeException(); + } + } + return featureRows; + } + + private List buildRedisKeys(List entityRows, FeatureSetSpec featureSetSpec) { + String featureSetRef = generateFeatureSetStringRef(featureSetSpec); + List featureSetEntityNames = + featureSetSpec.getEntitiesList().stream() + .map(EntitySpec::getName) + .collect(Collectors.toList()); + List redisKeys = + entityRows.stream() + .map(row -> makeRedisKey(featureSetRef, featureSetEntityNames, row)) + .collect(Collectors.toList()); + return redisKeys; + } + + /** + * Create {@link RedisKey} + * + * @param featureSet featureSet reference of the feature. E.g. 
+  /**
+   * Create a {@link RedisKey} for one entity row.
+   *
+   * @param featureSet feature set reference of the feature, e.g. feature_set_1:1
+   * @param featureSetEntityNames entity names that belong to the feature set
+   * @param entityRow entity row to build the key from
+   * @return {@link RedisKey}
+   */
+  private RedisKey makeRedisKey(
+      String featureSet, List<String> featureSetEntityNames, EntityRow entityRow) {
+    RedisKey.Builder builder = RedisKey.newBuilder().setFeatureSet(featureSet);
+    Map<String, Value> fieldsMap = entityRow.getFieldsMap();
+    featureSetEntityNames.sort(String::compareTo);
+    for (int i = 0; i < featureSetEntityNames.size(); i++) {
+      String entityName = featureSetEntityNames.get(i);
+
+      if (!fieldsMap.containsKey(entityName)) {
+        throw Status.INVALID_ARGUMENT
+            .withDescription(
+                String.format(
+                    "Entity row fields \"%s\" do not contain required entity field \"%s\"",
+                    fieldsMap.keySet().toString(), entityName))
+            .asRuntimeException();
+      }
+
+      builder.addEntities(
+          Field.newBuilder().setName(entityName).setValue(fieldsMap.get(entityName)));
+    }
+    return builder.build();
+  }
+
+  private List<FeatureRow> sendAndProcessMultiGet(
+      List<RedisKey> redisKeys,
+      FeatureSetSpec featureSetSpec,
+      List<FeatureReference> featureReferences)
+      throws InvalidProtocolBufferException, ExecutionException {
+
+    List<byte[]> values = sendMultiGet(redisKeys);
+    List<FeatureRow> featureRows = new ArrayList<>();
+
+    FeatureRow.Builder nullFeatureRowBuilder =
+        FeatureRow.newBuilder().setFeatureSet(generateFeatureSetStringRef(featureSetSpec));
+    for (FeatureReference featureReference : featureReferences) {
+      nullFeatureRowBuilder.addFields(Field.newBuilder().setName(featureReference.getName()));
+    }
+
+    for (int i = 0; i < values.size(); i++) {
+
+      byte[] value = values.get(i);
+      if (value == null) {
+        featureRows.add(nullFeatureRowBuilder.build());
+        continue;
+      }
+
+      FeatureRow featureRow = FeatureRow.parseFrom(value);
+      String featureSetRef = redisKeys.get(i).getFeatureSet();
+      FeatureRowDecoder decoder = new FeatureRowDecoder(featureSetRef, featureSetSpec);
+      if (decoder.isEncoded(featureRow)) {
+        if (decoder.isEncodingValid(featureRow)) {
+          featureRow = decoder.decode(featureRow);
+        } else {
+          featureRows.add(nullFeatureRowBuilder.build());
+          continue;
+        }
+      }
+
+      featureRows.add(featureRow);
+    }
+    return featureRows;
+  }
+
+  /**
+   * Sends a list of GET requests as a single MGET.
+   *
+   * @param keys list of {@link RedisKey}
+   * @return list of {@link FeatureRow} in primitive byte representation for each {@link RedisKey}
+   */
+  private List<byte[]> sendMultiGet(List<RedisKey> keys) {
+    try {
+      byte[][] binaryKeys =
+          keys.stream()
+              .map(AbstractMessageLite::toByteArray)
+              .collect(Collectors.toList())
+              .toArray(new byte[0][0]);
+      return syncCommands.mget(binaryKeys).stream()
+          .map(
+              keyValue -> {
+                if (keyValue == null) {
+                  return null;
+                }
+                return keyValue.getValueOrElse(null);
+              })
+          .collect(Collectors.toList());
+    } catch (Exception e) {
+      throw Status.NOT_FOUND
+          .withDescription("Unable to retrieve feature from Redis")
+          .withCause(e)
+          .asRuntimeException();
+    }
+  }
+
+  // TODO: Refactor this out to common package?
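+  // For illustration: a spec with project "project", name "featureSet" and version 1 yields the
+  // ref "project/featureSet:1"; with version unset (0) it would yield just "project/featureSet".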
+  private static String generateFeatureSetStringRef(FeatureSetSpec featureSetSpec) {
+    String ref = String.format("%s/%s", featureSetSpec.getProject(), featureSetSpec.getName());
+    if (featureSetSpec.getVersion() > 0) {
+      return ref + String.format(":%d", featureSetSpec.getVersion());
+    }
+    return ref;
+  }
+}
diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisCustomIO.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisCustomIO.java
new file mode 100644
index 00000000000..cfe7771b324
--- /dev/null
+++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisCustomIO.java
@@ -0,0 +1,292 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2019 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.storage.connectors.redis.writer;
+
+import feast.core.FeatureSetProto.EntitySpec;
+import feast.core.FeatureSetProto.FeatureSetSpec;
+import feast.core.FeatureSetProto.FeatureSpec;
+import feast.core.StoreProto.Store.RedisConfig;
+import feast.storage.RedisProto.RedisKey;
+import feast.storage.RedisProto.RedisKey.Builder;
+import feast.storage.api.writer.FailedElement;
+import feast.storage.api.writer.WriteResult;
+import feast.storage.common.retry.Retriable;
+import feast.types.FeatureRowProto.FeatureRow;
+import feast.types.FieldProto.Field;
+import feast.types.ValueProto;
+import io.lettuce.core.RedisConnectionException;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ExecutionException;
+import java.util.stream.Collectors;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.PCollectionTuple;
+import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.sdk.values.TupleTagList;
+import org.apache.commons.lang3.exception.ExceptionUtils;
+import org.joda.time.Instant;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class RedisCustomIO {
+
+  private static final int DEFAULT_BATCH_SIZE = 1000;
+  private static final int DEFAULT_TIMEOUT = 2000;
+
+  private static TupleTag<FeatureRow> successfulInsertsTag =
+      new TupleTag<>("successfulInserts") {};
+  private static TupleTag<FailedElement> failedInsertsTupleTag =
+      new TupleTag<>("failedInserts") {};
+
+  private static final Logger log = LoggerFactory.getLogger(RedisCustomIO.class);
+
+  private RedisCustomIO() {}
+
+  public static Write write(RedisConfig redisConfig, Map<String, FeatureSetSpec> featureSetSpecs) {
+    return new Write(redisConfig, featureSetSpecs);
+  }
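+  // A minimal usage sketch, assuming a PCollection<FeatureRow> named "rows" and a
+  // Map<String, FeatureSetSpec> named "specs" keyed by feature set ref (both hypothetical):
+  //
+  //   WriteResult result = rows.apply("WriteToRedis", RedisCustomIO.write(redisConfig, specs));
+  //   result.getFailedInserts();  // FailedElements for rows that exhausted their retries
+  //
+  // In this refactor the transform is normally obtained via RedisFeatureSink.writer() rather
+  // than constructed directly.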
+  /** Writes {@link FeatureRow}s to a Redis server. */
+  public static class Write extends PTransform<PCollection<FeatureRow>, WriteResult> {
+
+    private Map<String, FeatureSetSpec> featureSetSpecs;
+    private RedisConfig redisConfig;
+    private int batchSize;
+    private int timeout;
+
+    public Write(RedisConfig redisConfig, Map<String, FeatureSetSpec> featureSetSpecs) {
+
+      this.redisConfig = redisConfig;
+      this.featureSetSpecs = featureSetSpecs;
+    }
+
+    public Write withBatchSize(int batchSize) {
+      this.batchSize = batchSize;
+      return this;
+    }
+
+    public Write withTimeout(int timeout) {
+      this.timeout = timeout;
+      return this;
+    }
+
+    @Override
+    public WriteResult expand(PCollection<FeatureRow> input) {
+      PCollectionTuple redisWrite =
+          input.apply(
+              ParDo.of(new WriteDoFn(redisConfig, featureSetSpecs))
+                  .withOutputTags(successfulInsertsTag, TupleTagList.of(failedInsertsTupleTag)));
+      return WriteResult.in(
+          input.getPipeline(),
+          redisWrite.get(successfulInsertsTag),
+          redisWrite.get(failedInsertsTupleTag));
+    }
+
+    public static class WriteDoFn extends DoFn<FeatureRow, FeatureRow> {
+
+      private final List<FeatureRow> featureRows = new ArrayList<>();
+      private Map<String, FeatureSetSpec> featureSetSpecs;
+      private int batchSize = DEFAULT_BATCH_SIZE;
+      private int timeout = DEFAULT_TIMEOUT;
+      private RedisIngestionClient redisIngestionClient;
+
+      WriteDoFn(RedisConfig config, Map<String, FeatureSetSpec> featureSetSpecs) {
+
+        this.redisIngestionClient = new RedisStandaloneIngestionClient(config);
+        this.featureSetSpecs = featureSetSpecs;
+      }
+
+      public WriteDoFn withBatchSize(int batchSize) {
+        if (batchSize > 0) {
+          this.batchSize = batchSize;
+        }
+        return this;
+      }
+
+      public WriteDoFn withTimeout(int timeout) {
+        if (timeout > 0) {
+          this.timeout = timeout;
+        }
+        return this;
+      }
+
+      @Setup
+      public void setup() {
+        this.redisIngestionClient.setup();
+      }
+
+      @StartBundle
+      public void startBundle() {
+        try {
+          redisIngestionClient.connect();
+        } catch (RedisConnectionException e) {
+          log.error("Connection to Redis cannot be established ", e);
+        }
+        featureRows.clear();
+      }
+
+      private void executeBatch() throws Exception {
+        this.redisIngestionClient
+            .getBackOffExecutor()
+            .execute(
+                new Retriable() {
+                  @Override
+                  public void execute() throws ExecutionException, InterruptedException {
+                    if (!redisIngestionClient.isConnected()) {
+                      redisIngestionClient.connect();
+                    }
+                    featureRows.forEach(
+                        row -> {
+                          redisIngestionClient.set(getKey(row), getValue(row));
+                        });
+                    redisIngestionClient.sync();
+                  }
+
+                  @Override
+                  public Boolean isExceptionRetriable(Exception e) {
+                    return e instanceof RedisConnectionException;
+                  }
+
+                  @Override
+                  public void cleanUpAfterFailure() {}
+                });
+      }
+
+      private FailedElement toFailedElement(
+          FeatureRow featureRow, Exception exception, String jobName) {
+        return FailedElement.newBuilder()
+            .setJobName(jobName)
+            .setTransformName("RedisCustomIO")
+            .setPayload(featureRow.toString())
+            .setErrorMessage(exception.getMessage())
+            .setStackTrace(ExceptionUtils.getStackTrace(exception))
+            .build();
+      }
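+      // Key/value encoding sketch: the key is a RedisKey proto (feature set ref plus sorted
+      // entity fields); the value is a stripped FeatureRow proto whose fields keep only feature
+      // values, name-free and ordered by feature name, so the decoder on the read path can
+      // restore names from the FeatureSetSpec. E.g. for entities {entity_id_primary=1,
+      // entity_id_secondary="a"} and features {feature_1="strValue1", feature_2=1001}, the
+      // stored value holds just the two feature values in feature-name order.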
+      private byte[] getKey(FeatureRow featureRow) {
+        FeatureSetSpec featureSetSpec = featureSetSpecs.get(featureRow.getFeatureSet());
+        List<String> entityNames =
+            featureSetSpec.getEntitiesList().stream()
+                .map(EntitySpec::getName)
+                .sorted()
+                .collect(Collectors.toList());
+
+        Map<String, Field> entityFields = new HashMap<>();
+        Builder redisKeyBuilder = RedisKey.newBuilder().setFeatureSet(featureRow.getFeatureSet());
+        for (Field field : featureRow.getFieldsList()) {
+          if (entityNames.contains(field.getName())) {
+            entityFields.putIfAbsent(
+                field.getName(),
+                Field.newBuilder().setName(field.getName()).setValue(field.getValue()).build());
+          }
+        }
+        for (String entityName : entityNames) {
+          redisKeyBuilder.addEntities(entityFields.get(entityName));
+        }
+        return redisKeyBuilder.build().toByteArray();
+      }
+
+      private byte[] getValue(FeatureRow featureRow) {
+        FeatureSetSpec spec = featureSetSpecs.get(featureRow.getFeatureSet());
+
+        List<String> featureNames =
+            spec.getFeaturesList().stream().map(FeatureSpec::getName).collect(Collectors.toList());
+        Map<String, Field> fieldValueOnlyMap =
+            featureRow.getFieldsList().stream()
+                .filter(field -> featureNames.contains(field.getName()))
+                .distinct()
+                .collect(
+                    Collectors.toMap(
+                        Field::getName,
+                        field -> Field.newBuilder().setValue(field.getValue()).build()));
+
+        List<Field> values =
+            featureNames.stream()
+                .sorted()
+                .map(
+                    featureName ->
+                        fieldValueOnlyMap.getOrDefault(
+                            featureName,
+                            Field.newBuilder()
+                                .setValue(ValueProto.Value.getDefaultInstance())
+                                .build()))
+                .collect(Collectors.toList());
+
+        return FeatureRow.newBuilder()
+            .setEventTimestamp(featureRow.getEventTimestamp())
+            .addAllFields(values)
+            .build()
+            .toByteArray();
+      }
+
+      @ProcessElement
+      public void processElement(ProcessContext context) {
+        FeatureRow featureRow = context.element();
+        featureRows.add(featureRow);
+        if (featureRows.size() >= batchSize) {
+          try {
+            executeBatch();
+            featureRows.forEach(row -> context.output(successfulInsertsTag, row));
+            featureRows.clear();
+          } catch (Exception e) {
+            featureRows.forEach(
+                failedMutation -> {
+                  FailedElement failedElement =
+                      toFailedElement(failedMutation, e, context.getPipelineOptions().getJobName());
+                  context.output(failedInsertsTupleTag, failedElement);
+                });
+            featureRows.clear();
+          }
+        }
+      }
+
+      @FinishBundle
+      public void finishBundle(FinishBundleContext context)
+          throws IOException, InterruptedException {
+        if (featureRows.size() > 0) {
+          try {
+            executeBatch();
+            featureRows.forEach(
+                row ->
+                    context.output(
+                        successfulInsertsTag, row, Instant.now(), GlobalWindow.INSTANCE));
+            featureRows.clear();
+          } catch (Exception e) {
+            featureRows.forEach(
+                failedMutation -> {
+                  FailedElement failedElement =
+                      toFailedElement(failedMutation, e, context.getPipelineOptions().getJobName());
+                  context.output(
+                      failedInsertsTupleTag, failedElement, Instant.now(), GlobalWindow.INSTANCE);
+                });
+            featureRows.clear();
+          }
+        }
+      }
+
+      @Teardown
+      public void teardown() {
+        redisIngestionClient.shutdown();
+      }
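+      // Note on flush semantics (summary, not new behavior): processElement() flushes whenever a
+      // bundle accumulates batchSize rows, and finishBundle() flushes the remainder, so a bundle
+      // smaller than batchSize is still written exactly once at bundle completion.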
+    }
+  }
+}
diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisFeatureSink.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisFeatureSink.java
new file mode 100644
index 00000000000..63c8c68d9bb
--- /dev/null
+++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisFeatureSink.java
@@ -0,0 +1,74 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ * Copyright 2018-2020 The Feast Authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package feast.storage.connectors.redis.writer;
+
+import com.google.auto.value.AutoValue;
+import feast.core.FeatureSetProto.FeatureSet;
+import feast.core.FeatureSetProto.FeatureSetSpec;
+import feast.core.StoreProto.Store.RedisConfig;
+import feast.storage.api.writer.FeatureSink;
+import feast.storage.api.writer.WriteResult;
+import feast.types.FeatureRowProto.FeatureRow;
+import io.lettuce.core.RedisClient;
+import io.lettuce.core.RedisConnectionException;
+import io.lettuce.core.RedisURI;
+import java.util.Map;
+import org.apache.beam.sdk.transforms.PTransform;
+import org.apache.beam.sdk.values.PCollection;
+
+@AutoValue
+public abstract class RedisFeatureSink implements FeatureSink {
+
+  public abstract RedisConfig getRedisConfig();
+
+  public abstract Map<String, FeatureSetSpec> getFeatureSetSpecs();
+
+  public abstract Builder toBuilder();
+
+  public static Builder builder() {
+    return new AutoValue_RedisFeatureSink.Builder();
+  }
+
+  @AutoValue.Builder
+  public abstract static class Builder {
+    public abstract Builder setRedisConfig(RedisConfig redisConfig);
+
+    public abstract Builder setFeatureSetSpecs(Map<String, FeatureSetSpec> featureSetSpecs);
+
+    public abstract RedisFeatureSink build();
+  }
+
+  @Override
+  public void prepareWrite(FeatureSet featureSet) {
+    RedisClient redisClient =
+        RedisClient.create(RedisURI.create(getRedisConfig().getHost(), getRedisConfig().getPort()));
+    try {
+      redisClient.connect();
+    } catch (RedisConnectionException e) {
+      throw new RuntimeException(
+          String.format(
+              "Failed to connect to Redis at host: '%s' port: '%d'. Please check that your Redis is running and accessible from Feast.",
+              getRedisConfig().getHost(), getRedisConfig().getPort()));
+    }
+    redisClient.shutdown();
+  }
+
+  @Override
+  public PTransform<PCollection<FeatureRow>, WriteResult> writer() {
+    return new RedisCustomIO.Write(getRedisConfig(), getFeatureSetSpecs());
+  }
+}
diff --git a/ingestion/src/main/java/feast/store/serving/redis/RedisIngestionClient.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisIngestionClient.java
similarity index 92%
rename from ingestion/src/main/java/feast/store/serving/redis/RedisIngestionClient.java
rename to storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisIngestionClient.java
index d51eead53fb..6616a79aaca 100644
--- a/ingestion/src/main/java/feast/store/serving/redis/RedisIngestionClient.java
+++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisIngestionClient.java
@@ -14,9 +14,9 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
*/ -package feast.store.serving.redis; +package feast.storage.connectors.redis.writer; -import feast.retry.BackOffExecutor; +import feast.storage.common.retry.BackOffExecutor; import java.io.Serializable; public interface RedisIngestionClient extends Serializable { diff --git a/ingestion/src/main/java/feast/store/serving/redis/RedisStandaloneIngestionClient.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisStandaloneIngestionClient.java similarity index 93% rename from ingestion/src/main/java/feast/store/serving/redis/RedisStandaloneIngestionClient.java rename to storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisStandaloneIngestionClient.java index d95ebbbf64a..95bd7ad1516 100644 --- a/ingestion/src/main/java/feast/store/serving/redis/RedisStandaloneIngestionClient.java +++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/writer/RedisStandaloneIngestionClient.java @@ -14,12 +14,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package feast.store.serving.redis; +package feast.storage.connectors.redis.writer; import com.google.common.collect.Lists; import feast.core.StoreProto; -import feast.retry.BackOffExecutor; -import io.lettuce.core.*; +import feast.storage.common.retry.BackOffExecutor; +import io.lettuce.core.LettuceFutures; +import io.lettuce.core.RedisClient; +import io.lettuce.core.RedisFuture; +import io.lettuce.core.RedisURI; import io.lettuce.core.api.StatefulRedisConnection; import io.lettuce.core.api.async.RedisAsyncCommands; import io.lettuce.core.codec.ByteArrayCodec; diff --git a/serving/src/test/java/feast/serving/encoding/FeatureRowDecoderTest.java b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/retriever/FeatureRowDecoderTest.java similarity index 98% rename from serving/src/test/java/feast/serving/encoding/FeatureRowDecoderTest.java rename to storage/connectors/redis/src/test/java/feast/storage/connectors/redis/retriever/FeatureRowDecoderTest.java index 8f6c79ad66c..0f37e68941c 100644 --- a/serving/src/test/java/feast/serving/encoding/FeatureRowDecoderTest.java +++ b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/retriever/FeatureRowDecoderTest.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package feast.serving.encoding; +package feast.storage.connectors.redis.retriever; import static org.junit.Assert.*; diff --git a/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/retriever/RedisOnlineRetrieverTest.java b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/retriever/RedisOnlineRetrieverTest.java new file mode 100644 index 00000000000..11c216c5a0c --- /dev/null +++ b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/retriever/RedisOnlineRetrieverTest.java @@ -0,0 +1,262 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package feast.storage.connectors.redis.retriever; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.equalTo; +import static org.mockito.Mockito.when; +import static org.mockito.MockitoAnnotations.initMocks; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import com.google.protobuf.AbstractMessageLite; +import com.google.protobuf.Duration; +import com.google.protobuf.Timestamp; +import feast.core.FeatureSetProto.EntitySpec; +import feast.core.FeatureSetProto.FeatureSetSpec; +import feast.core.FeatureSetProto.FeatureSpec; +import feast.serving.ServingAPIProto.FeatureReference; +import feast.serving.ServingAPIProto.GetOnlineFeaturesRequest.EntityRow; +import feast.storage.RedisProto.RedisKey; +import feast.storage.api.retriever.FeatureSetRequest; +import feast.types.FeatureRowProto.FeatureRow; +import feast.types.FieldProto.Field; +import feast.types.ValueProto.Value; +import io.lettuce.core.KeyValue; +import io.lettuce.core.api.StatefulRedisConnection; +import io.lettuce.core.api.sync.RedisCommands; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; + +public class RedisOnlineRetrieverTest { + + @Mock StatefulRedisConnection connection; + + @Mock RedisCommands syncCommands; + + private RedisOnlineRetriever redisOnlineRetriever; + private byte[][] redisKeyList; + + @Before + public void setUp() { + initMocks(this); + when(connection.sync()).thenReturn(syncCommands); + redisOnlineRetriever = new RedisOnlineRetriever(connection); + redisKeyList = + Lists.newArrayList( + RedisKey.newBuilder() + .setFeatureSet("project/featureSet:1") + .addAllEntities( + Lists.newArrayList( + Field.newBuilder().setName("entity1").setValue(intValue(1)).build(), + Field.newBuilder().setName("entity2").setValue(strValue("a")).build())) + .build(), + RedisKey.newBuilder() + .setFeatureSet("project/featureSet:1") + .addAllEntities( + Lists.newArrayList( + Field.newBuilder().setName("entity1").setValue(intValue(2)).build(), + Field.newBuilder().setName("entity2").setValue(strValue("b")).build())) + .build()) + .stream() + .map(AbstractMessageLite::toByteArray) + .collect(Collectors.toList()) + .toArray(new byte[0][0]); + } + + @Test + public void shouldReturnResponseWithValuesIfKeysPresent() { + FeatureSetRequest featureSetRequest = + FeatureSetRequest.newBuilder() + .setSpec(getFeatureSetSpec()) + .addFeatureReference( + FeatureReference.newBuilder() + .setName("feature1") + .setVersion(1) + .setProject("project") + .build()) + .addFeatureReference( + FeatureReference.newBuilder() + .setName("feature2") + .setVersion(1) + .setProject("project") + .build()) + .build(); + List entityRows = + ImmutableList.of( + EntityRow.newBuilder() + .setEntityTimestamp(Timestamp.newBuilder().setSeconds(100)) + .putFields("entity1", intValue(1)) + .putFields("entity2", strValue("a")) + .build(), + EntityRow.newBuilder() + .setEntityTimestamp(Timestamp.newBuilder().setSeconds(100)) + .putFields("entity1", intValue(2)) + .putFields("entity2", strValue("b")) + .build()); + + List featureRows = + Lists.newArrayList( + FeatureRow.newBuilder() + .setEventTimestamp(Timestamp.newBuilder().setSeconds(100)) + .addAllFields( + Lists.newArrayList( + Field.newBuilder().setValue(intValue(1)).build(), + 
Field.newBuilder().setValue(intValue(1)).build())) + .build(), + FeatureRow.newBuilder() + .setEventTimestamp(Timestamp.newBuilder().setSeconds(100)) + .addAllFields( + Lists.newArrayList( + Field.newBuilder().setValue(intValue(2)).build(), + Field.newBuilder().setValue(intValue(2)).build())) + .build()); + + List> featureRowBytes = + featureRows.stream() + .map(x -> KeyValue.from(new byte[1], Optional.of(x.toByteArray()))) + .collect(Collectors.toList()); + + redisOnlineRetriever = new RedisOnlineRetriever(connection); + when(connection.sync()).thenReturn(syncCommands); + when(syncCommands.mget(redisKeyList)).thenReturn(featureRowBytes); + + List> expected = + List.of( + Lists.newArrayList( + FeatureRow.newBuilder() + .setEventTimestamp(Timestamp.newBuilder().setSeconds(100)) + .setFeatureSet("project/featureSet:1") + .addAllFields( + Lists.newArrayList( + Field.newBuilder().setName("feature1").setValue(intValue(1)).build(), + Field.newBuilder().setName("feature2").setValue(intValue(1)).build())) + .build(), + FeatureRow.newBuilder() + .setEventTimestamp(Timestamp.newBuilder().setSeconds(100)) + .setFeatureSet("project/featureSet:1") + .addAllFields( + Lists.newArrayList( + Field.newBuilder().setName("feature1").setValue(intValue(2)).build(), + Field.newBuilder().setName("feature2").setValue(intValue(2)).build())) + .build())); + + List> actual = + redisOnlineRetriever.getOnlineFeatures(entityRows, List.of(featureSetRequest)); + assertThat(actual, equalTo(expected)); + } + + @Test + public void shouldReturnResponseWithUnsetValuesIfKeysNotPresent() { + FeatureSetRequest featureSetRequest = + FeatureSetRequest.newBuilder() + .setSpec(getFeatureSetSpec()) + .addFeatureReference( + FeatureReference.newBuilder() + .setName("feature1") + .setVersion(1) + .setProject("project") + .build()) + .addFeatureReference( + FeatureReference.newBuilder() + .setName("feature2") + .setVersion(1) + .setProject("project") + .build()) + .build(); + List entityRows = + ImmutableList.of( + EntityRow.newBuilder() + .setEntityTimestamp(Timestamp.newBuilder().setSeconds(100)) + .putFields("entity1", intValue(1)) + .putFields("entity2", strValue("a")) + .build(), + EntityRow.newBuilder() + .setEntityTimestamp(Timestamp.newBuilder().setSeconds(100)) + .putFields("entity1", intValue(2)) + .putFields("entity2", strValue("b")) + .build()); + + List featureRows = + Lists.newArrayList( + FeatureRow.newBuilder() + .setEventTimestamp(Timestamp.newBuilder().setSeconds(100)) + .addAllFields( + Lists.newArrayList( + Field.newBuilder().setValue(intValue(1)).build(), + Field.newBuilder().setValue(intValue(1)).build())) + .build()); + + List> featureRowBytes = + featureRows.stream() + .map(x -> KeyValue.from(new byte[1], Optional.of(x.toByteArray()))) + .collect(Collectors.toList()); + featureRowBytes.add(null); + + redisOnlineRetriever = new RedisOnlineRetriever(connection); + when(connection.sync()).thenReturn(syncCommands); + when(syncCommands.mget(redisKeyList)).thenReturn(featureRowBytes); + + List> expected = + List.of( + Lists.newArrayList( + FeatureRow.newBuilder() + .setEventTimestamp(Timestamp.newBuilder().setSeconds(100)) + .setFeatureSet("project/featureSet:1") + .addAllFields( + Lists.newArrayList( + Field.newBuilder().setName("feature1").setValue(intValue(1)).build(), + Field.newBuilder().setName("feature2").setValue(intValue(1)).build())) + .build(), + FeatureRow.newBuilder() + .setFeatureSet("project/featureSet:1") + .addAllFields( + Lists.newArrayList( + Field.newBuilder().setName("feature1").build(), + 
Field.newBuilder().setName("feature2").build())) + .build())); + + List> actual = + redisOnlineRetriever.getOnlineFeatures(entityRows, List.of(featureSetRequest)); + assertThat(actual, equalTo(expected)); + } + + private Value intValue(int val) { + return Value.newBuilder().setInt64Val(val).build(); + } + + private Value strValue(String val) { + return Value.newBuilder().setStringVal(val).build(); + } + + private FeatureSetSpec getFeatureSetSpec() { + return FeatureSetSpec.newBuilder() + .setProject("project") + .setName("featureSet") + .setVersion(1) + .addEntities(EntitySpec.newBuilder().setName("entity1")) + .addEntities(EntitySpec.newBuilder().setName("entity2")) + .addFeatures(FeatureSpec.newBuilder().setName("feature1")) + .addFeatures(FeatureSpec.newBuilder().setName("feature2")) + .setMaxAge(Duration.newBuilder().setSeconds(30)) // default + .build(); + } +} diff --git a/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/test/TestUtil.java b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/test/TestUtil.java new file mode 100644 index 00000000000..66aba44bc20 --- /dev/null +++ b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/test/TestUtil.java @@ -0,0 +1,44 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2020 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package feast.storage.connectors.redis.test; + +import java.io.IOException; +import redis.embedded.RedisServer; + +public class TestUtil { + public static class LocalRedis { + + private static RedisServer server; + + /** + * Start local Redis for used in testing at "localhost" + * + * @param port port number + * @throws IOException if Redis failed to start + */ + public static void start(int port) throws IOException { + server = new RedisServer(port); + server.start(); + } + + public static void stop() { + if (server != null) { + server.stop(); + } + } + } +} diff --git a/ingestion/src/test/java/feast/store/serving/redis/FeatureRowToRedisMutationDoFnTest.java b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/writer/RedisFeatureSinkTest.java similarity index 50% rename from ingestion/src/test/java/feast/store/serving/redis/FeatureRowToRedisMutationDoFnTest.java rename to storage/connectors/redis/src/test/java/feast/storage/connectors/redis/writer/RedisFeatureSinkTest.java index 7db5e28ecb8..beeabc2c884 100644 --- a/ingestion/src/test/java/feast/store/serving/redis/FeatureRowToRedisMutationDoFnTest.java +++ b/storage/connectors/redis/src/test/java/feast/storage/connectors/redis/writer/RedisFeatureSinkTest.java @@ -1,6 +1,6 @@ /* * SPDX-License-Identifier: Apache-2.0 - * Copyright 2018-2020 The Feast Authors + * Copyright 2018-2019 The Feast Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -14,142 +14,259 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package feast.store.serving.redis; +package feast.storage.connectors.redis.writer; -import static org.junit.Assert.*; +import static feast.storage.common.testing.TestUtil.field; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.protobuf.Timestamp; -import feast.core.FeatureSetProto; import feast.core.FeatureSetProto.EntitySpec; import feast.core.FeatureSetProto.FeatureSetSpec; import feast.core.FeatureSetProto.FeatureSpec; +import feast.core.StoreProto; +import feast.core.StoreProto.Store.RedisConfig; import feast.storage.RedisProto.RedisKey; -import feast.store.serving.redis.RedisCustomIO.RedisMutation; import feast.types.FeatureRowProto.FeatureRow; import feast.types.FieldProto.Field; import feast.types.ValueProto.Value; import feast.types.ValueProto.ValueType.Enum; +import io.lettuce.core.RedisClient; +import io.lettuce.core.RedisURI; +import io.lettuce.core.api.StatefulRedisConnection; +import io.lettuce.core.api.sync.RedisStringCommands; +import io.lettuce.core.codec.ByteArrayCodec; +import java.io.IOException; import java.util.*; -import org.apache.beam.sdk.extensions.protobuf.ProtoCoder; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.TimeUnit; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Count; import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.PCollection; +import org.junit.After; +import org.junit.Before; import org.junit.Rule; import org.junit.Test; +import redis.embedded.Redis; +import redis.embedded.RedisServer; -public class FeatureRowToRedisMutationDoFnTest { - +public class RedisFeatureSinkTest { @Rule public transient TestPipeline p = TestPipeline.create(); - private FeatureSetProto.FeatureSet fs = - FeatureSetProto.FeatureSet.newBuilder() - .setSpec( - FeatureSetSpec.newBuilder() - .setName("feature_set") - .setVersion(1) - .addEntities( - EntitySpec.newBuilder() - .setName("entity_id_primary") - .setValueType(Enum.INT32) - .build()) - .addEntities( - EntitySpec.newBuilder() - .setName("entity_id_secondary") - .setValueType(Enum.STRING) - .build()) - .addFeatures( - FeatureSpec.newBuilder() - .setName("feature_1") - .setValueType(Enum.STRING) - .build()) - .addFeatures( - FeatureSpec.newBuilder() - .setName("feature_2") - .setValueType(Enum.INT64) - .build())) - .build(); - - @Test - public void shouldConvertRowWithDuplicateEntitiesToValidKey() { - Map featureSets = new HashMap<>(); - featureSets.put("feature_set", fs); - - FeatureRow offendingRow = - FeatureRow.newBuilder() - .setFeatureSet("feature_set") - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) - .addFields( - Field.newBuilder() - .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(2))) - .addFields( - Field.newBuilder() - .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) - .addFields( - Field.newBuilder() - .setName("feature_1") - 
.setValue(Value.newBuilder().setStringVal("strValue1"))) - .addFields( - Field.newBuilder() - .setName("feature_2") - .setValue(Value.newBuilder().setInt64Val(1001))) + private static String REDIS_HOST = "localhost"; + private static int REDIS_PORT = 51234; + private Redis redis; + private RedisClient redisClient; + private RedisStringCommands sync; + + private RedisFeatureSink redisFeatureSink; + + @Before + public void setUp() throws IOException { + redis = new RedisServer(REDIS_PORT); + redis.start(); + redisClient = + RedisClient.create(new RedisURI(REDIS_HOST, REDIS_PORT, java.time.Duration.ofMillis(2000))); + StatefulRedisConnection connection = redisClient.connect(new ByteArrayCodec()); + sync = connection.sync(); + + FeatureSetSpec spec1 = + FeatureSetSpec.newBuilder() + .setName("fs") + .setVersion(1) + .setProject("myproject") + .addEntities(EntitySpec.newBuilder().setName("entity").setValueType(Enum.INT64).build()) + .addFeatures( + FeatureSpec.newBuilder().setName("feature").setValueType(Enum.STRING).build()) .build(); - PCollection output = - p.apply(Create.of(Collections.singletonList(offendingRow))) - .setCoder(ProtoCoder.of(FeatureRow.class)) - .apply(ParDo.of(new FeatureRowToRedisMutationDoFn(featureSets))); - - RedisKey expectedKey = - RedisKey.newBuilder() - .setFeatureSet("feature_set") + FeatureSetSpec spec2 = + FeatureSetSpec.newBuilder() + .setName("feature_set") + .setProject("myproject") + .setVersion(1) .addEntities( - Field.newBuilder() + EntitySpec.newBuilder() .setName("entity_id_primary") - .setValue(Value.newBuilder().setInt32Val(1))) + .setValueType(Enum.INT32) + .build()) .addEntities( - Field.newBuilder() + EntitySpec.newBuilder() .setName("entity_id_secondary") - .setValue(Value.newBuilder().setStringVal("a"))) + .setValueType(Enum.STRING) + .build()) + .addFeatures( + FeatureSpec.newBuilder().setName("feature_1").setValueType(Enum.STRING).build()) + .addFeatures( + FeatureSpec.newBuilder().setName("feature_2").setValueType(Enum.INT64).build()) .build(); - FeatureRow expectedValue = + Map specMap = + ImmutableMap.of("myproject/fs:1", spec1, "myproject/feature_set:1", spec2); + StoreProto.Store.RedisConfig redisConfig = + StoreProto.Store.RedisConfig.newBuilder().setHost(REDIS_HOST).setPort(REDIS_PORT).build(); + + redisFeatureSink = + RedisFeatureSink.builder().setFeatureSetSpecs(specMap).setRedisConfig(redisConfig).build(); + } + + @After + public void teardown() { + redisClient.shutdown(); + redis.stop(); + } + + @Test + public void shouldWriteToRedis() { + + HashMap kvs = new LinkedHashMap<>(); + kvs.put( + RedisKey.newBuilder() + .setFeatureSet("myproject/fs:1") + .addEntities(field("entity", 1, Enum.INT64)) + .build(), FeatureRow.newBuilder() - .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) - .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("strValue1"))) - .addFields(Field.newBuilder().setValue(Value.newBuilder().setInt64Val(1001))) + .setEventTimestamp(Timestamp.getDefaultInstance()) + .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("one"))) + .build()); + kvs.put( + RedisKey.newBuilder() + .setFeatureSet("myproject/fs:1") + .addEntities(field("entity", 2, Enum.INT64)) + .build(), + FeatureRow.newBuilder() + .setEventTimestamp(Timestamp.getDefaultInstance()) + .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("two"))) + .build()); + + List featureRows = + ImmutableList.of( + FeatureRow.newBuilder() + .setFeatureSet("myproject/fs:1") + .addFields(field("entity", 1, 
Enum.INT64)) + .addFields(field("feature", "one", Enum.STRING)) + .build(), + FeatureRow.newBuilder() + .setFeatureSet("myproject/fs:1") + .addFields(field("entity", 2, Enum.INT64)) + .addFields(field("feature", "two", Enum.STRING)) + .build()); + + p.apply(Create.of(featureRows)).apply(redisFeatureSink.writer()); + p.run(); + + kvs.forEach( + (key, value) -> { + byte[] actual = sync.get(key.toByteArray()); + assertThat(actual, equalTo(value.toByteArray())); + }); + } + + @Test(timeout = 10000) + public void shouldRetryFailConnection() throws InterruptedException { + RedisConfig redisConfig = + RedisConfig.newBuilder() + .setHost(REDIS_HOST) + .setPort(REDIS_PORT) + .setMaxRetries(4) + .setInitialBackoffMs(2000) .build(); + redisFeatureSink = redisFeatureSink.toBuilder().setRedisConfig(redisConfig).build(); + + HashMap kvs = new LinkedHashMap<>(); + kvs.put( + RedisKey.newBuilder() + .setFeatureSet("myproject/fs:1") + .addEntities(field("entity", 1, Enum.INT64)) + .build(), + FeatureRow.newBuilder() + .setEventTimestamp(Timestamp.getDefaultInstance()) + .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("one"))) + .build()); + + List featureRows = + ImmutableList.of( + FeatureRow.newBuilder() + .setFeatureSet("myproject/fs:1") + .addFields(field("entity", 1, Enum.INT64)) + .addFields(field("feature", "one", Enum.STRING)) + .build()); + + PCollection failedElementCount = + p.apply(Create.of(featureRows)) + .apply(redisFeatureSink.writer()) + .getFailedInserts() + .apply(Count.globally()); + + redis.stop(); + final ScheduledThreadPoolExecutor redisRestartExecutor = new ScheduledThreadPoolExecutor(1); + ScheduledFuture scheduledRedisRestart = + redisRestartExecutor.schedule( + () -> { + redis.start(); + }, + 3, + TimeUnit.SECONDS); + + PAssert.that(failedElementCount).containsInAnyOrder(0L); + p.run(); + scheduledRedisRestart.cancel(true); + + kvs.forEach( + (key, value) -> { + byte[] actual = sync.get(key.toByteArray()); + assertThat(actual, equalTo(value.toByteArray())); + }); + } + + @Test + public void shouldProduceFailedElementIfRetryExceeded() { - PAssert.that(output) - .satisfies( - (SerializableFunction, Void>) - input -> { - input.forEach( - rm -> { - assert (Arrays.equals(rm.getKey(), expectedKey.toByteArray())); - assert (Arrays.equals(rm.getValue(), expectedValue.toByteArray())); - }); - return null; - }); + RedisConfig redisConfig = + RedisConfig.newBuilder().setHost(REDIS_HOST).setPort(REDIS_PORT + 1).build(); + redisFeatureSink = redisFeatureSink.toBuilder().setRedisConfig(redisConfig).build(); + + HashMap kvs = new LinkedHashMap<>(); + kvs.put( + RedisKey.newBuilder() + .setFeatureSet("myproject/fs:1") + .addEntities(field("entity", 1, Enum.INT64)) + .build(), + FeatureRow.newBuilder() + .setEventTimestamp(Timestamp.getDefaultInstance()) + .addFields(Field.newBuilder().setValue(Value.newBuilder().setStringVal("one"))) + .build()); + + List featureRows = + ImmutableList.of( + FeatureRow.newBuilder() + .setFeatureSet("myproject/fs:1") + .addFields(field("entity", 1, Enum.INT64)) + .addFields(field("feature", "one", Enum.STRING)) + .build()); + + PCollection failedElementCount = + p.apply(Create.of(featureRows)) + .apply(redisFeatureSink.writer()) + .getFailedInserts() + .apply(Count.globally()); + + redis.stop(); + PAssert.that(failedElementCount).containsInAnyOrder(1L); p.run(); } @Test - public void shouldConvertRowWithExtraEntitiesToValidKey() { - Map featureSets = new HashMap<>(); - featureSets.put("feature_set", fs); + public void 
shouldConvertRowWithDuplicateEntitiesToValidKey() { FeatureRow offendingRow = FeatureRow.newBuilder() - .setFeatureSet("feature_set") + .setFeatureSet("myproject/feature_set:1") .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) .addFields( Field.newBuilder() @@ -157,7 +274,7 @@ public void shouldConvertRowWithExtraEntitiesToValidKey() { .setValue(Value.newBuilder().setInt32Val(1))) .addFields( Field.newBuilder() - .setName("entity_id_invalid") + .setName("entity_id_primary") .setValue(Value.newBuilder().setInt32Val(2))) .addFields( Field.newBuilder() @@ -173,14 +290,9 @@ public void shouldConvertRowWithExtraEntitiesToValidKey() { .setValue(Value.newBuilder().setInt64Val(1001))) .build(); - PCollection output = - p.apply(Create.of(Collections.singletonList(offendingRow))) - .setCoder(ProtoCoder.of(FeatureRow.class)) - .apply(ParDo.of(new FeatureRowToRedisMutationDoFn(featureSets))); - RedisKey expectedKey = RedisKey.newBuilder() - .setFeatureSet("feature_set") + .setFeatureSet("myproject/feature_set:1") .addEntities( Field.newBuilder() .setName("entity_id_primary") @@ -198,28 +310,19 @@ public void shouldConvertRowWithExtraEntitiesToValidKey() { .addFields(Field.newBuilder().setValue(Value.newBuilder().setInt64Val(1001))) .build(); - PAssert.that(output) - .satisfies( - (SerializableFunction, Void>) - input -> { - input.forEach( - rm -> { - assert (Arrays.equals(rm.getKey(), expectedKey.toByteArray())); - assert (Arrays.equals(rm.getValue(), expectedValue.toByteArray())); - }); - return null; - }); + p.apply(Create.of(offendingRow)).apply(redisFeatureSink.writer()); + p.run(); + + byte[] actual = sync.get(expectedKey.toByteArray()); + assertThat(actual, equalTo(expectedValue.toByteArray())); } @Test public void shouldConvertRowWithOutOfOrderFieldsToValidKey() { - Map featureSets = new HashMap<>(); - featureSets.put("feature_set", fs); - FeatureRow offendingRow = FeatureRow.newBuilder() - .setFeatureSet("feature_set") + .setFeatureSet("myproject/feature_set:1") .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) .addFields( Field.newBuilder() @@ -239,14 +342,9 @@ public void shouldConvertRowWithOutOfOrderFieldsToValidKey() { .setValue(Value.newBuilder().setStringVal("strValue1"))) .build(); - PCollection output = - p.apply(Create.of(Collections.singletonList(offendingRow))) - .setCoder(ProtoCoder.of(FeatureRow.class)) - .apply(ParDo.of(new FeatureRowToRedisMutationDoFn(featureSets))); - RedisKey expectedKey = RedisKey.newBuilder() - .setFeatureSet("feature_set") + .setFeatureSet("myproject/feature_set:1") .addEntities( Field.newBuilder() .setName("entity_id_primary") @@ -267,28 +365,19 @@ public void shouldConvertRowWithOutOfOrderFieldsToValidKey() { .addAllFields(expectedFields) .build(); - PAssert.that(output) - .satisfies( - (SerializableFunction, Void>) - input -> { - input.forEach( - rm -> { - assert (Arrays.equals(rm.getKey(), expectedKey.toByteArray())); - assert (Arrays.equals(rm.getValue(), expectedValue.toByteArray())); - }); - return null; - }); + p.apply(Create.of(offendingRow)).apply(redisFeatureSink.writer()); + p.run(); + + byte[] actual = sync.get(expectedKey.toByteArray()); + assertThat(actual, equalTo(expectedValue.toByteArray())); } @Test public void shouldMergeDuplicateFeatureFields() { - Map featureSets = new HashMap<>(); - featureSets.put("feature_set", fs); - FeatureRow featureRowWithDuplicatedFeatureFields = FeatureRow.newBuilder() - .setFeatureSet("feature_set") + .setFeatureSet("myproject/feature_set:1") 
.setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) .addFields( Field.newBuilder() @@ -312,14 +401,9 @@ public void shouldMergeDuplicateFeatureFields() { .setValue(Value.newBuilder().setInt64Val(1001))) .build(); - PCollection output = - p.apply(Create.of(Collections.singletonList(featureRowWithDuplicatedFeatureFields))) - .setCoder(ProtoCoder.of(FeatureRow.class)) - .apply(ParDo.of(new FeatureRowToRedisMutationDoFn(featureSets))); - RedisKey expectedKey = RedisKey.newBuilder() - .setFeatureSet("feature_set") + .setFeatureSet("myproject/feature_set:1") .addEntities( Field.newBuilder() .setName("entity_id_primary") @@ -337,28 +421,19 @@ public void shouldMergeDuplicateFeatureFields() { .addFields(Field.newBuilder().setValue(Value.newBuilder().setInt64Val(1001))) .build(); - PAssert.that(output) - .satisfies( - (SerializableFunction, Void>) - input -> { - input.forEach( - rm -> { - assert (Arrays.equals(rm.getKey(), expectedKey.toByteArray())); - assert (Arrays.equals(rm.getValue(), expectedValue.toByteArray())); - }); - return null; - }); + p.apply(Create.of(featureRowWithDuplicatedFeatureFields)).apply(redisFeatureSink.writer()); + p.run(); + + byte[] actual = sync.get(expectedKey.toByteArray()); + assertThat(actual, equalTo(expectedValue.toByteArray())); } @Test public void shouldPopulateMissingFeatureValuesWithDefaultInstance() { - Map featureSets = new HashMap<>(); - featureSets.put("feature_set", fs); - FeatureRow featureRowWithDuplicatedFeatureFields = FeatureRow.newBuilder() - .setFeatureSet("feature_set") + .setFeatureSet("myproject/feature_set:1") .setEventTimestamp(Timestamp.newBuilder().setSeconds(10)) .addFields( Field.newBuilder() @@ -374,14 +449,9 @@ public void shouldPopulateMissingFeatureValuesWithDefaultInstance() { .setValue(Value.newBuilder().setStringVal("strValue1"))) .build(); - PCollection output = - p.apply(Create.of(Collections.singletonList(featureRowWithDuplicatedFeatureFields))) - .setCoder(ProtoCoder.of(FeatureRow.class)) - .apply(ParDo.of(new FeatureRowToRedisMutationDoFn(featureSets))); - RedisKey expectedKey = RedisKey.newBuilder() - .setFeatureSet("feature_set") + .setFeatureSet("myproject/feature_set:1") .addEntities( Field.newBuilder() .setName("entity_id_primary") @@ -399,17 +469,11 @@ public void shouldPopulateMissingFeatureValuesWithDefaultInstance() { .addFields(Field.newBuilder().setValue(Value.getDefaultInstance())) .build(); - PAssert.that(output) - .satisfies( - (SerializableFunction, Void>) - input -> { - input.forEach( - rm -> { - assert (Arrays.equals(rm.getKey(), expectedKey.toByteArray())); - assert (Arrays.equals(rm.getValue(), expectedValue.toByteArray())); - }); - return null; - }); + p.apply(Create.of(featureRowWithDuplicatedFeatureFields)).apply(redisFeatureSink.writer()); + p.run(); + + byte[] actual = sync.get(expectedKey.toByteArray()); + assertThat(actual, equalTo(expectedValue.toByteArray())); } } From e7482afcae03f8c4d0c3672d2ceb1a642d748d0b Mon Sep 17 00:00:00 2001 From: David Heryanto Date: Fri, 10 Apr 2020 17:31:46 +0800 Subject: [PATCH 2/4] Update Python SDK so FeatureSet can import Schema from Tensorflow metadata (#450) * Add skeleton for update/get schema in FeatureSet * Add update_schema method to FeatureSet - Update Field, Feature and Entity class with fields from presence_constraints, shape_type and domain_info * Update error message when domain ref is missing from top level schema * Add more assertion in test_update_schema before updating schema * Fix conflicting versions in package requirements * Add export_schema 
method to export schema from FeatureSet * Add exporting of Tensorflow metadata schema from FeatureSet. - Update documentation for properties in Field - Deduplication refactoring in FeatureSet * Remove changes to mypy generated codes * Revert changes to packages version in requirements-ci and setup.py They are not necessary for now and to avoid unexpected breaking changes. * Remove 'schema' param in 'from_proto' method in Entity and Feature. In import_tfx_schema method, the domain info is first made inline so there is no need to have schema level domain info when updating Feast Entity and Feature. Also added documentation to setter property methods in Field.py * Fix rebase errors, apply black * Remove unnecessary imports Co-authored-by: zhilingc --- sdk/python/feast/entity.py | 27 +- sdk/python/feast/feature.py | 38 +- sdk/python/feast/feature_set.py | 129 +- sdk/python/feast/field.py | 389 +++ sdk/python/feast/loaders/yaml.py | 3 +- sdk/python/feast/value_type.py | 23 + .../tensorflow_metadata/proto/v0/path_pb2.py | 69 - .../tensorflow_metadata/proto/v0/path_pb2.pyi | 52 - .../proto/v0/schema_pb2.py | 2256 ----------------- .../proto/v0/schema_pb2.pyi | 1063 -------- .../bikeshare_feature_set.yaml | 81 + .../tensorflow_metadata/bikeshare_schema.json | 136 + sdk/python/tests/test_feature_set.py | 101 +- 13 files changed, 917 insertions(+), 3450 deletions(-) delete mode 100644 sdk/python/tensorflow_metadata/proto/v0/path_pb2.py delete mode 100644 sdk/python/tensorflow_metadata/proto/v0/path_pb2.pyi delete mode 100644 sdk/python/tensorflow_metadata/proto/v0/schema_pb2.py delete mode 100644 sdk/python/tensorflow_metadata/proto/v0/schema_pb2.pyi create mode 100644 sdk/python/tests/data/tensorflow_metadata/bikeshare_feature_set.yaml create mode 100644 sdk/python/tests/data/tensorflow_metadata/bikeshare_schema.json diff --git a/sdk/python/feast/entity.py b/sdk/python/feast/entity.py index 5f823a754a0..9c5a027b974 100644 --- a/sdk/python/feast/entity.py +++ b/sdk/python/feast/entity.py @@ -29,7 +29,26 @@ def to_proto(self) -> EntityProto: Returns EntitySpec object """ value_type = ValueTypeProto.ValueType.Enum.Value(self.dtype.name) - return EntityProto(name=self.name, value_type=value_type) + return EntityProto( + name=self.name, + value_type=value_type, + presence=self.presence, + group_presence=self.group_presence, + shape=self.shape, + value_count=self.value_count, + domain=self.domain, + int_domain=self.int_domain, + float_domain=self.float_domain, + string_domain=self.string_domain, + bool_domain=self.bool_domain, + struct_domain=self.struct_domain, + natural_language_domain=self.natural_language_domain, + image_domain=self.image_domain, + mid_domain=self.mid_domain, + url_domain=self.url_domain, + time_domain=self.time_domain, + time_of_day_domain=self.time_of_day_domain, + ) @classmethod def from_proto(cls, entity_proto: EntityProto): @@ -42,4 +61,8 @@ def from_proto(cls, entity_proto: EntityProto): Returns: Entity object """ - return cls(name=entity_proto.name, dtype=ValueType(entity_proto.value_type)) + entity = cls(name=entity_proto.name, dtype=ValueType(entity_proto.value_type)) + entity.update_presence_constraints(entity_proto) + entity.update_shape_type(entity_proto) + entity.update_domain_info(entity_proto) + return entity diff --git a/sdk/python/feast/feature.py b/sdk/python/feast/feature.py index c9fc1cbff40..9c7ff20f9e2 100644 --- a/sdk/python/feast/feature.py +++ b/sdk/python/feast/feature.py @@ -24,9 +24,41 @@ class Feature(Field): def to_proto(self) -> FeatureProto: 
"""Converts Feature object to its Protocol Buffer representation""" value_type = ValueTypeProto.ValueType.Enum.Value(self.dtype.name) - return FeatureProto(name=self.name, value_type=value_type) + return FeatureProto( + name=self.name, + value_type=value_type, + presence=self.presence, + group_presence=self.group_presence, + shape=self.shape, + value_count=self.value_count, + domain=self.domain, + int_domain=self.int_domain, + float_domain=self.float_domain, + string_domain=self.string_domain, + bool_domain=self.bool_domain, + struct_domain=self.struct_domain, + natural_language_domain=self.natural_language_domain, + image_domain=self.image_domain, + mid_domain=self.mid_domain, + url_domain=self.url_domain, + time_domain=self.time_domain, + time_of_day_domain=self.time_of_day_domain, + ) @classmethod def from_proto(cls, feature_proto: FeatureProto): - """Converts Protobuf Feature to its SDK equivalent""" - return cls(name=feature_proto.name, dtype=ValueType(feature_proto.value_type)) + """ + + Args: + feature_proto: FeatureSpec protobuf object + + Returns: + Feature object + """ + feature = cls( + name=feature_proto.name, dtype=ValueType(feature_proto.value_type) + ) + feature.update_presence_constraints(feature_proto) + feature.update_shape_type(feature_proto) + feature.update_domain_info(feature_proto) + return feature diff --git a/sdk/python/feast/feature_set.py b/sdk/python/feast/feature_set.py index c4cedaf6b2a..c6104f47a08 100644 --- a/sdk/python/feast/feature_set.py +++ b/sdk/python/feast/feature_set.py @@ -11,18 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - - +import warnings from collections import OrderedDict -from typing import Dict, List, Optional +from typing import Dict +from typing import List, Optional import pandas as pd import pyarrow as pa from google.protobuf import json_format from google.protobuf.duration_pb2 import Duration from google.protobuf.json_format import MessageToJson +from google.protobuf.message import Message from pandas.api.types import is_datetime64_ns_dtype from pyarrow.lib import TimestampType +from tensorflow_metadata.proto.v0 import schema_pb2 from feast.core.FeatureSet_pb2 import FeatureSet as FeatureSetProto from feast.core.FeatureSet_pb2 import FeatureSetMeta as FeatureSetMetaProto @@ -657,6 +659,93 @@ def is_valid(self): if len(self.entities) == 0: raise ValueError(f"No entities found in feature set {self.name}") + def import_tfx_schema(self, schema: schema_pb2.Schema): + """ + Updates presence_constraints, shape_type and domain_info for all fields + (features and entities) in the FeatureSet from schema in the Tensorflow metadata. + + Args: + schema: Schema from Tensorflow metadata + + Returns: + None + + """ + _make_tfx_schema_domain_info_inline(schema) + for feature_from_tfx_schema in schema.feature: + if feature_from_tfx_schema.name in self._fields.keys(): + field = self._fields[feature_from_tfx_schema.name] + field.update_presence_constraints(feature_from_tfx_schema) + field.update_shape_type(feature_from_tfx_schema) + field.update_domain_info(feature_from_tfx_schema) + else: + warnings.warn( + f"The provided schema contains feature name '{feature_from_tfx_schema.name}' " + f"that does not exist in the FeatureSet '{self.name}' in Feast" + ) + + def export_tfx_schema(self) -> schema_pb2.Schema: + """ + Create a Tensorflow metadata schema from a FeatureSet. + + Returns: + Tensorflow metadata schema. 
+ + """ + schema = schema_pb2.Schema() + + # List of attributes to copy from fields in the FeatureSet to feature in + # Tensorflow metadata schema where the attribute name is the same. + attributes_to_copy_from_field_to_feature = [ + "name", + "presence", + "group_presence", + "shape", + "value_count", + "domain", + "int_domain", + "float_domain", + "string_domain", + "bool_domain", + "struct_domain", + "_natural_language_domain", + "image_domain", + "mid_domain", + "url_domain", + "time_domain", + "time_of_day_domain", + ] + + for _, field in self._fields.items(): + feature = schema_pb2.Feature() + for attr in attributes_to_copy_from_field_to_feature: + if getattr(field, attr) is None: + # This corresponds to an unset member in the proto Oneof field. + continue + if issubclass(type(getattr(feature, attr)), Message): + # Proto message field to copy is an "embedded" field, so MergeFrom() + # method must be used. + getattr(feature, attr).MergeFrom(getattr(field, attr)) + elif issubclass(type(getattr(feature, attr)), (int, str, bool)): + # Proto message field is a simple Python type, so setattr() + # can be used. + setattr(feature, attr, getattr(field, attr)) + else: + warnings.warn( + f"Attribute '{attr}' cannot be copied from Field " + f"'{field.name}' in FeatureSet '{self.name}' to a " + f"Feature in the Tensorflow metadata schema, because" + f"the type is neither a Protobuf message or Python " + f"int, str and bool" + ) + # "type" attr is handled separately because the attribute name is different + # ("dtype" in field and "type" in Feature) and "type" in Feature is only + # a subset of "dtype". + feature.type = field.dtype.to_tfx_schema_feature_type() + schema.feature.append(feature) + + return schema + @classmethod def from_yaml(cls, yml: str): """ @@ -855,6 +944,40 @@ def __hash__(self): return hash(repr(self)) +def _make_tfx_schema_domain_info_inline(schema: schema_pb2.Schema) -> None: + """ + Copy top level domain info defined at schema level into inline definition. + One use case is when importing domain info from Tensorflow metadata schema + into Feast features. Feast features do not have access to schema level information + so the domain info needs to be inline. 
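+
+    For illustration (hypothetical schema): if the schema declares a top-level
+    string_domain named "colors" and a feature whose domain_info is the reference
+    domain="colors", this function clears the top-level domain list and merges the
+    "colors" StringDomain inline into that feature's string_domain.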
@@ -855,6 +944,40 @@ def __hash__(self):
         return hash(repr(self))


+def _make_tfx_schema_domain_info_inline(schema: schema_pb2.Schema) -> None:
+    """
+    Copy top level domain info defined at schema level into inline definition.
+    One use case is when importing domain info from a Tensorflow metadata schema
+    into Feast features. Feast features do not have access to schema-level
+    information, so the domain info needs to be inline.
+
+    Args:
+        schema: Tensorflow metadata schema
+
+    Returns: None
+    """
+    # References to domains defined at schema level
+    domain_ref_to_string_domain = {d.name: d for d in schema.string_domain}
+    domain_ref_to_float_domain = {d.name: d for d in schema.float_domain}
+    domain_ref_to_int_domain = {d.name: d for d in schema.int_domain}
+
+    # With the references held above, it is safe to remove the domains defined
+    # at schema level
+    del schema.string_domain[:]
+    del schema.float_domain[:]
+    del schema.int_domain[:]
+
+    for feature in schema.feature:
+        domain_info_case = feature.WhichOneof("domain_info")
+        if domain_info_case == "domain":
+            domain_ref = feature.domain
+            if domain_ref in domain_ref_to_string_domain:
+                feature.string_domain.MergeFrom(domain_ref_to_string_domain[domain_ref])
+            elif domain_ref in domain_ref_to_float_domain:
+                feature.float_domain.MergeFrom(domain_ref_to_float_domain[domain_ref])
+            elif domain_ref in domain_ref_to_int_domain:
+                feature.int_domain.MergeFrom(domain_ref_to_int_domain[domain_ref])
+
+
 def _infer_pd_column_type(column, series, rows_to_sample):
     dtype = None
     sample_count = 0
diff --git a/sdk/python/feast/field.py b/sdk/python/feast/field.py
index 2efd4587ff0..be56823489b 100644
--- a/sdk/python/feast/field.py
+++ b/sdk/python/feast/field.py
@@ -11,8 +11,11 @@
 #  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
+from typing import Union

+from feast.core.FeatureSet_pb2 import EntitySpec, FeatureSpec
 from feast.value_type import ValueType
+from tensorflow_metadata.proto.v0 import schema_pb2


 class Field:
@@ -26,6 +29,22 @@ def __init__(self, name: str, dtype: ValueType):
         if not isinstance(dtype, ValueType):
             raise ValueError("dtype is not a valid ValueType")
         self._dtype = dtype
+        self._presence = None
+        self._group_presence = None
+        self._shape = None
+        self._value_count = None
+        self._domain = None
+        self._int_domain = None
+        self._float_domain = None
+        self._string_domain = None
+        self._bool_domain = None
+        self._struct_domain = None
+        self._natural_language_domain = None
+        self._image_domain = None
+        self._mid_domain = None
+        self._url_domain = None
+        self._time_domain = None
+        self._time_of_day_domain = None

     def __eq__(self, other):
         if self.name != other.name or self.dtype != other.dtype:
@@ -46,6 +65,354 @@ def dtype(self) -> ValueType:
         """
         return self._dtype

+    @property
+    def presence(self) -> schema_pb2.FeaturePresence:
+        """
+        Getter for presence of this field
+        """
+        return self._presence
+
+    @presence.setter
+    def presence(self, presence: schema_pb2.FeaturePresence):
+        """
+        Setter for presence of this field
+        """
+        if not isinstance(presence, schema_pb2.FeaturePresence):
+            raise TypeError("presence must be of FeaturePresence type")
+        self._clear_presence_constraints()
+        self._presence = presence
+
+    @property
+    def group_presence(self) -> schema_pb2.FeaturePresenceWithinGroup:
+        """
+        Getter for group_presence of this field
+        """
+        return self._group_presence
+
+    @group_presence.setter
+    def group_presence(self, group_presence: schema_pb2.FeaturePresenceWithinGroup):
+        """
+        Setter for group_presence of this field
+        """
+        if not isinstance(group_presence, schema_pb2.FeaturePresenceWithinGroup):
+            raise TypeError("group_presence must be of FeaturePresenceWithinGroup type")
+        self._clear_presence_constraints()
+        self._group_presence = group_presence
+
+    @property
+    def shape(self) -> schema_pb2.FixedShape:
+        """
+        Getter for shape of this field
+        """
+        return self._shape
+
+    @shape.setter
+    def shape(self, shape: schema_pb2.FixedShape):
+        """
+        Setter for shape of this field
+        """
+        if not isinstance(shape, schema_pb2.FixedShape):
+            raise TypeError("shape must be of FixedShape type")
+        self._clear_shape_type()
+        self._shape = shape
+
+    @property
+    def value_count(self) -> schema_pb2.ValueCount:
+        """
+        Getter for value_count of this field
+        """
+        return self._value_count
+
+    @value_count.setter
+    def value_count(self, value_count: schema_pb2.ValueCount):
+        """
+        Setter for value_count of this field
+        """
+        if not isinstance(value_count, schema_pb2.ValueCount):
+            raise TypeError("value_count must be of ValueCount type")
+        self._clear_shape_type()
+        self._value_count = value_count
+
+    @property
+    def domain(self) -> str:
+        """
+        Getter for domain of this field
+        """
+        return self._domain
+
+    @domain.setter
+    def domain(self, domain: str):
+        """
+        Setter for domain of this field
+        """
+        if not isinstance(domain, str):
+            raise TypeError("domain must be of str type")
+        self._clear_domain_info()
+        self._domain = domain
+
+    @property
+    def int_domain(self) -> schema_pb2.IntDomain:
+        """
+        Getter for int_domain of this field
+        """
+        return self._int_domain
+
+    @int_domain.setter
+    def int_domain(self, int_domain: schema_pb2.IntDomain):
+        """
+        Setter for int_domain of this field
+        """
+        if not isinstance(int_domain, schema_pb2.IntDomain):
+            raise TypeError("int_domain must be of IntDomain type")
+        self._clear_domain_info()
+        self._int_domain = int_domain
+
+    @property
+    def float_domain(self) -> schema_pb2.FloatDomain:
+        """
+        Getter for float_domain of this field
+        """
+        return self._float_domain
+
+    @float_domain.setter
+    def float_domain(self, float_domain: schema_pb2.FloatDomain):
+        """
+        Setter for float_domain of this field
+        """
+        if not isinstance(float_domain, schema_pb2.FloatDomain):
+            raise TypeError("float_domain must be of FloatDomain type")
+        self._clear_domain_info()
+        self._float_domain = float_domain
+
+    @property
+    def string_domain(self) -> schema_pb2.StringDomain:
+        """
+        Getter for string_domain of this field
+        """
+        return self._string_domain
+
+    @string_domain.setter
+    def string_domain(self, string_domain: schema_pb2.StringDomain):
+        """
+        Setter for string_domain of this field
+        """
+        if not isinstance(string_domain, schema_pb2.StringDomain):
+            raise TypeError("string_domain must be of StringDomain type")
+        self._clear_domain_info()
+        self._string_domain = string_domain
+
+    @property
+    def bool_domain(self) -> schema_pb2.BoolDomain:
+        """
+        Getter for bool_domain of this field
+        """
+        return self._bool_domain
+
+    @bool_domain.setter
+    def bool_domain(self, bool_domain: schema_pb2.BoolDomain):
+        """
+        Setter for bool_domain of this field
+        """
+        if not isinstance(bool_domain, schema_pb2.BoolDomain):
+            raise TypeError("bool_domain must be of BoolDomain type")
+        self._clear_domain_info()
+        self._bool_domain = bool_domain
+
+    @property
+    def struct_domain(self) -> schema_pb2.StructDomain:
+        """
+        Getter for struct_domain of this field
+        """
+        return self._struct_domain
+
+    @struct_domain.setter
+    def struct_domain(self, struct_domain: schema_pb2.StructDomain):
+        """
+        Setter for struct_domain of this field
+        """
+        if not isinstance(struct_domain, schema_pb2.StructDomain):
+            raise TypeError("struct_domain must be of StructDomain type")
+        self._clear_domain_info()
+        self._struct_domain = struct_domain
+
+    @property
+    def natural_language_domain(self) -> schema_pb2.NaturalLanguageDomain:
+        """
+        Getter for natural_language_domain of this field
+        """
+        return self._natural_language_domain
+
+    @natural_language_domain.setter
+    def natural_language_domain(
+        self, natural_language_domain: schema_pb2.NaturalLanguageDomain
+    ):
+        """
+        Setter for natural_language_domain of this field
+        """
+        if not isinstance(natural_language_domain, schema_pb2.NaturalLanguageDomain):
+            raise TypeError(
+                "natural_language_domain must be of NaturalLanguageDomain type"
+            )
+        self._clear_domain_info()
+        self._natural_language_domain = natural_language_domain
+
+    @property
+    def image_domain(self) -> schema_pb2.ImageDomain:
+        """
+        Getter for image_domain of this field
+        """
+        return self._image_domain
+
+    @image_domain.setter
+    def image_domain(self, image_domain: schema_pb2.ImageDomain):
+        """
+        Setter for image_domain of this field
+        """
+        if not isinstance(image_domain, schema_pb2.ImageDomain):
+            raise TypeError("image_domain must be of ImageDomain type")
+        self._clear_domain_info()
+        self._image_domain = image_domain
+
+    @property
+    def mid_domain(self) -> schema_pb2.MIDDomain:
+        """
+        Getter for mid_domain of this field
+        """
+        return self._mid_domain
+
+    @mid_domain.setter
+    def mid_domain(self, mid_domain: schema_pb2.MIDDomain):
+        """
+        Setter for mid_domain of this field
+        """
+        if not isinstance(mid_domain, schema_pb2.MIDDomain):
+            raise TypeError("mid_domain must be of MIDDomain type")
+        self._clear_domain_info()
+        self._mid_domain = mid_domain
+
+    @property
+    def url_domain(self) -> schema_pb2.URLDomain:
+        """
+        Getter for url_domain of this field
+        """
+        return self._url_domain
+
+    @url_domain.setter
+    def url_domain(self, url_domain: schema_pb2.URLDomain):
+        """
+        Setter for url_domain of this field
+        """
+        if not isinstance(url_domain, schema_pb2.URLDomain):
+            raise TypeError("url_domain must be of URLDomain type")
+        self._clear_domain_info()
+        self._url_domain = url_domain
+
+    @property
+    def time_domain(self) -> schema_pb2.TimeDomain:
+        """
+        Getter for time_domain of this field
+        """
+        return self._time_domain
+
+    @time_domain.setter
+    def time_domain(self, time_domain: schema_pb2.TimeDomain):
+        """
+        Setter for time_domain of this field
+        """
+        if not isinstance(time_domain, schema_pb2.TimeDomain):
+            raise TypeError("time_domain must be of TimeDomain type")
+        self._clear_domain_info()
+        self._time_domain = time_domain
+
+    @property
+    def time_of_day_domain(self) -> schema_pb2.TimeOfDayDomain:
+        """
+        Getter for time_of_day_domain of this field
+        """
+        return self._time_of_day_domain
+
+    @time_of_day_domain.setter
+    def time_of_day_domain(self, time_of_day_domain: schema_pb2.TimeOfDayDomain):
+        """
+        Setter for time_of_day_domain of this field
+        """
+        if not isinstance(time_of_day_domain, schema_pb2.TimeOfDayDomain):
+            raise TypeError("time_of_day_domain must be of TimeOfDayDomain type")
+        self._clear_domain_info()
+        self._time_of_day_domain = time_of_day_domain
+
+    def update_presence_constraints(
+        self, feature: Union[schema_pb2.Feature, EntitySpec, FeatureSpec]
+    ) -> None:
+        """
+        Update the presence constraints in this field from a Tensorflow Feature,
+        Feast EntitySpec or FeatureSpec
+
+        Args:
+            feature: Tensorflow Feature, Feast EntitySpec or FeatureSpec
+
+        Returns: None
+        """
+        presence_constraints_case = feature.WhichOneof("presence_constraints")
+        if presence_constraints_case == "presence":
+            self.presence = feature.presence
+        elif presence_constraints_case == "group_presence":
+            self.group_presence = feature.group_presence
+
+    def update_shape_type(
+        self, feature: Union[schema_pb2.Feature, EntitySpec, FeatureSpec]
+    ) -> None:
+        """
+        Update the shape type in this field from a Tensorflow Feature,
+        Feast EntitySpec or FeatureSpec
+
+        Args:
+            feature: Tensorflow Feature, Feast EntitySpec or FeatureSpec
+
+        Returns: None
+        """
+        shape_type_case = feature.WhichOneof("shape_type")
+        if shape_type_case == "shape":
+            self.shape = feature.shape
+        elif shape_type_case == "value_count":
+            self.value_count = feature.value_count
+
+    def update_domain_info(
+        self, feature: Union[schema_pb2.Feature, EntitySpec, FeatureSpec]
+    ) -> None:
+        """
+        Update the domain info in this field from a Tensorflow Feature,
+        Feast EntitySpec or FeatureSpec
+
+        Args:
+            feature: Tensorflow Feature, Feast EntitySpec or FeatureSpec
+
+        Returns: None
+        """
+        domain_info_case = feature.WhichOneof("domain_info")
+        if domain_info_case == "int_domain":
+            self.int_domain = feature.int_domain
+        elif domain_info_case == "float_domain":
+            self.float_domain = feature.float_domain
+        elif domain_info_case == "string_domain":
+            self.string_domain = feature.string_domain
+        elif domain_info_case == "bool_domain":
+            self.bool_domain = feature.bool_domain
+        elif domain_info_case == "struct_domain":
+            self.struct_domain = feature.struct_domain
+        elif domain_info_case == "natural_language_domain":
+            self.natural_language_domain = feature.natural_language_domain
+        elif domain_info_case == "image_domain":
+            self.image_domain = feature.image_domain
+        elif domain_info_case == "mid_domain":
+            self.mid_domain = feature.mid_domain
+        elif domain_info_case == "url_domain":
+            self.url_domain = feature.url_domain
+        elif domain_info_case == "time_domain":
+            self.time_domain = feature.time_domain
+        elif domain_info_case == "time_of_day_domain":
+            self.time_of_day_domain = feature.time_of_day_domain
+
     def to_proto(self):
         """
         Unimplemented to_proto method for a field. This should be extended.
@@ -57,3 +424,25 @@ def from_proto(self, proto):
         Unimplemented from_proto method for a field. This should be extended.
         """
         pass
+
+    def _clear_presence_constraints(self):
+        self._presence = None
+        self._group_presence = None
+
+    def _clear_shape_type(self):
+        self._shape = None
+        self._value_count = None
+
+    def _clear_domain_info(self):
+        self._domain = None
+        self._int_domain = None
+        self._float_domain = None
+        self._string_domain = None
+        self._bool_domain = None
+        self._struct_domain = None
+        self._natural_language_domain = None
+        self._image_domain = None
+        self._mid_domain = None
+        self._url_domain = None
+        self._time_domain = None
+        self._time_of_day_domain = None
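The setters added above enforce the proto oneof semantics eagerly on the Python side: assigning one member of a constraint group clears the other members of that group, and type checks fail fast rather than at serialization time. A minimal sketch, with an illustrative field name and bounds:

    from feast.field import Field
    from feast.value_type import ValueType
    from tensorflow_metadata.proto.v0 import schema_pb2

    field = Field(name="rating", dtype=ValueType.FLOAT)

    # float_domain and domain both belong to the domain_info group...
    field.float_domain = schema_pb2.FloatDomain(min=0.0, max=5.0)

    # ...so assigning domain clears float_domain via _clear_domain_info().
    field.domain = "rating_domain"
    assert field.float_domain is None
    assert field.domain == "rating_domain"

    # Wrong types are rejected immediately.
    try:
        field.shape = "not-a-shape"
    except TypeError:
        pass
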
diff --git a/sdk/python/feast/loaders/yaml.py b/sdk/python/feast/loaders/yaml.py
index 130a71a3d02..624bc47d49c 100644
--- a/sdk/python/feast/loaders/yaml.py
+++ b/sdk/python/feast/loaders/yaml.py
@@ -57,7 +57,8 @@ def _get_yaml_contents(yml: str) -> str:
         yml_content = yml
     else:
         raise Exception(
-            f"Invalid YAML provided. Please provide either a file path or YAML string: ${yml}"
+            f"Invalid YAML provided. Please provide either a file path or YAML string.\n"
+            f"Provided YAML: {yml}"
         )
     return yml_content
diff --git a/sdk/python/feast/value_type.py b/sdk/python/feast/value_type.py
index df315480ce7..687dccc7b7f 100644
--- a/sdk/python/feast/value_type.py
+++ b/sdk/python/feast/value_type.py
@@ -14,6 +14,8 @@

 import enum

+from tensorflow_metadata.proto.v0 import schema_pb2
+

 class ValueType(enum.Enum):
     """
@@ -35,3 +37,24 @@ class ValueType(enum.Enum):
     DOUBLE_LIST = 15
     FLOAT_LIST = 16
     BOOL_LIST = 17
+
+    def to_tfx_schema_feature_type(self) -> schema_pb2.FeatureType:
+        if self.value in [
+            ValueType.BYTES.value,
+            ValueType.STRING.value,
+            ValueType.BOOL.value,
+            ValueType.BYTES_LIST.value,
+            ValueType.STRING_LIST.value,
+            ValueType.INT32_LIST.value,
+            ValueType.INT64_LIST.value,
+            ValueType.DOUBLE_LIST.value,
+            ValueType.FLOAT_LIST.value,
+            ValueType.BOOL_LIST.value,
+        ]:
+            return schema_pb2.FeatureType.BYTES
+        elif self.value in [ValueType.INT32.value, ValueType.INT64.value]:
+            return schema_pb2.FeatureType.INT
+        elif self.value in [ValueType.DOUBLE.value, ValueType.FLOAT.value]:
+            return schema_pb2.FeatureType.FLOAT
+        else:
+            return schema_pb2.FeatureType.TYPE_UNKNOWN
diff --git a/sdk/python/tensorflow_metadata/proto/v0/path_pb2.py b/sdk/python/tensorflow_metadata/proto/v0/path_pb2.py
deleted file mode 100644
index 24850688592..00000000000
--- a/sdk/python/tensorflow_metadata/proto/v0/path_pb2.py
+++ /dev/null
@@ -1,69 +0,0 @@
-# -*- coding: utf-8 -*-
-# Generated by the protocol buffer compiler. DO NOT EDIT!
-# source: tensorflow_metadata/proto/v0/path.proto
-
-from google.protobuf import descriptor as _descriptor
-from google.protobuf import message as _message
-from google.protobuf import reflection as _reflection
-from google.protobuf import symbol_database as _symbol_database
-# @@protoc_insertion_point(imports)
-
-_sym_db = _symbol_database.Default()
-
-
-
-
-DESCRIPTOR = _descriptor.FileDescriptor(
-  name='tensorflow_metadata/proto/v0/path.proto',
-  package='tensorflow.metadata.v0',
-  syntax='proto2',
-  serialized_options=b'\n\032org.tensorflow.metadata.v0P\001\370\001\001',
-  serialized_pb=b'\n\'tensorflow_metadata/proto/v0/path.proto\x12\x16tensorflow.metadata.v0\"\x14\n\x04Path\x12\x0c\n\x04step\x18\x01 \x03(\tB!\n\x1aorg.tensorflow.metadata.v0P\x01\xf8\x01\x01'
-)
-
-
-
-
-_PATH = _descriptor.Descriptor(
-  name='Path',
-  full_name='tensorflow.metadata.v0.Path',
-  filename=None,
-  file=DESCRIPTOR,
-  containing_type=None,
-  fields=[
-    _descriptor.FieldDescriptor(
-      name='step', full_name='tensorflow.metadata.v0.Path.step', index=0,
-      number=1, type=9, cpp_type=9, label=3,
-      has_default_value=False, default_value=[],
-      message_type=None, enum_type=None, containing_type=None,
-      is_extension=False, extension_scope=None,
-      serialized_options=None, file=DESCRIPTOR),
-  ],
-  extensions=[
-  ],
-  nested_types=[],
-  enum_types=[
-  ],
-  serialized_options=None,
-  is_extendable=False,
-  syntax='proto2',
-  extension_ranges=[],
-  oneofs=[
-  ],
-  serialized_start=67,
-  serialized_end=87,
-)
-
-DESCRIPTOR.message_types_by_name['Path'] = _PATH
-_sym_db.RegisterFileDescriptor(DESCRIPTOR)
-
-Path = _reflection.GeneratedProtocolMessageType('Path', (_message.Message,), {
-  'DESCRIPTOR' : _PATH,
-  '__module__' : 'tensorflow_metadata.proto.v0.path_pb2'
-  # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.Path)
-  })
-_sym_db.RegisterMessage(Path)
-
-
-DESCRIPTOR._options = None
-# @@protoc_insertion_point(module_scope)
diff --git a/sdk/python/tensorflow_metadata/proto/v0/path_pb2.pyi
b/sdk/python/tensorflow_metadata/proto/v0/path_pb2.pyi deleted file mode 100644 index caf370bd372..00000000000 --- a/sdk/python/tensorflow_metadata/proto/v0/path_pb2.pyi +++ /dev/null @@ -1,52 +0,0 @@ -# @generated by generate_proto_mypy_stubs.py. Do not edit! -import sys -from google.protobuf.descriptor import ( - Descriptor as google___protobuf___descriptor___Descriptor, -) - -from google.protobuf.internal.containers import ( - RepeatedScalarFieldContainer as google___protobuf___internal___containers___RepeatedScalarFieldContainer, -) - -from google.protobuf.message import ( - Message as google___protobuf___message___Message, -) - -from typing import ( - Iterable as typing___Iterable, - Optional as typing___Optional, - Text as typing___Text, - Union as typing___Union, -) - -from typing_extensions import ( - Literal as typing_extensions___Literal, -) - - -builtin___bool = bool -builtin___bytes = bytes -builtin___float = float -builtin___int = int -if sys.version_info < (3,): - builtin___buffer = buffer - builtin___unicode = unicode - - -class Path(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - step = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[typing___Text] - - def __init__(self, - *, - step : typing___Optional[typing___Iterable[typing___Text]] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> Path: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> Path: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def ClearField(self, field_name: typing_extensions___Literal[u"step",b"step"]) -> None: ... diff --git a/sdk/python/tensorflow_metadata/proto/v0/schema_pb2.py b/sdk/python/tensorflow_metadata/proto/v0/schema_pb2.py deleted file mode 100644 index c27579f0e28..00000000000 --- a/sdk/python/tensorflow_metadata/proto/v0/schema_pb2.py +++ /dev/null @@ -1,2256 +0,0 @@ -# -*- coding: utf-8 -*- -# Generated by the protocol buffer compiler. DO NOT EDIT! 
-# source: tensorflow_metadata/proto/v0/schema.proto - -from google.protobuf.internal import enum_type_wrapper -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from google.protobuf import reflection as _reflection -from google.protobuf import symbol_database as _symbol_database -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2 -from tensorflow_metadata.proto.v0 import path_pb2 as tensorflow__metadata_dot_proto_dot_v0_dot_path__pb2 - - -DESCRIPTOR = _descriptor.FileDescriptor( - name='tensorflow_metadata/proto/v0/schema.proto', - package='tensorflow.metadata.v0', - syntax='proto2', - serialized_options=b'\n\032org.tensorflow.metadata.v0P\001\370\001\001', - serialized_pb=b'\n)tensorflow_metadata/proto/v0/schema.proto\x12\x16tensorflow.metadata.v0\x1a\x19google/protobuf/any.proto\x1a\'tensorflow_metadata/proto/v0/path.proto\"\xe2\x05\n\x06Schema\x12\x30\n\x07\x66\x65\x61ture\x18\x01 \x03(\x0b\x32\x1f.tensorflow.metadata.v0.Feature\x12=\n\x0esparse_feature\x18\x06 \x03(\x0b\x32%.tensorflow.metadata.v0.SparseFeature\x12\x41\n\x10weighted_feature\x18\x0c \x03(\x0b\x32\'.tensorflow.metadata.v0.WeightedFeature\x12;\n\rstring_domain\x18\x04 \x03(\x0b\x32$.tensorflow.metadata.v0.StringDomain\x12\x39\n\x0c\x66loat_domain\x18\t \x03(\x0b\x32#.tensorflow.metadata.v0.FloatDomain\x12\x35\n\nint_domain\x18\n \x03(\x0b\x32!.tensorflow.metadata.v0.IntDomain\x12\x1b\n\x13\x64\x65\x66\x61ult_environment\x18\x05 \x03(\t\x12\x36\n\nannotation\x18\x08 \x01(\x0b\x32\".tensorflow.metadata.v0.Annotation\x12G\n\x13\x64\x61taset_constraints\x18\x0b \x01(\x0b\x32*.tensorflow.metadata.v0.DatasetConstraints\x12\x62\n\x1btensor_representation_group\x18\r \x03(\x0b\x32=.tensorflow.metadata.v0.Schema.TensorRepresentationGroupEntry\x1as\n\x1eTensorRepresentationGroupEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12@\n\x05value\x18\x02 \x01(\x0b\x32\x31.tensorflow.metadata.v0.TensorRepresentationGroup:\x02\x38\x01\"\xdf\x0b\n\x07\x46\x65\x61ture\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x16\n\ndeprecated\x18\x02 \x01(\x08\x42\x02\x18\x01\x12;\n\x08presence\x18\x0e \x01(\x0b\x32\'.tensorflow.metadata.v0.FeaturePresenceH\x00\x12L\n\x0egroup_presence\x18\x11 \x01(\x0b\x32\x32.tensorflow.metadata.v0.FeaturePresenceWithinGroupH\x00\x12\x33\n\x05shape\x18\x17 \x01(\x0b\x32\".tensorflow.metadata.v0.FixedShapeH\x01\x12\x39\n\x0bvalue_count\x18\x05 \x01(\x0b\x32\".tensorflow.metadata.v0.ValueCountH\x01\x12\x31\n\x04type\x18\x06 \x01(\x0e\x32#.tensorflow.metadata.v0.FeatureType\x12\x10\n\x06\x64omain\x18\x07 \x01(\tH\x02\x12\x37\n\nint_domain\x18\t \x01(\x0b\x32!.tensorflow.metadata.v0.IntDomainH\x02\x12;\n\x0c\x66loat_domain\x18\n \x01(\x0b\x32#.tensorflow.metadata.v0.FloatDomainH\x02\x12=\n\rstring_domain\x18\x0b \x01(\x0b\x32$.tensorflow.metadata.v0.StringDomainH\x02\x12\x39\n\x0b\x62ool_domain\x18\r \x01(\x0b\x32\".tensorflow.metadata.v0.BoolDomainH\x02\x12=\n\rstruct_domain\x18\x1d \x01(\x0b\x32$.tensorflow.metadata.v0.StructDomainH\x02\x12P\n\x17natural_language_domain\x18\x18 \x01(\x0b\x32-.tensorflow.metadata.v0.NaturalLanguageDomainH\x02\x12;\n\x0cimage_domain\x18\x19 \x01(\x0b\x32#.tensorflow.metadata.v0.ImageDomainH\x02\x12\x37\n\nmid_domain\x18\x1a \x01(\x0b\x32!.tensorflow.metadata.v0.MIDDomainH\x02\x12\x37\n\nurl_domain\x18\x1b \x01(\x0b\x32!.tensorflow.metadata.v0.URLDomainH\x02\x12\x39\n\x0btime_domain\x18\x1c 
\x01(\x0b\x32\".tensorflow.metadata.v0.TimeDomainH\x02\x12\x45\n\x12time_of_day_domain\x18\x1e \x01(\x0b\x32\'.tensorflow.metadata.v0.TimeOfDayDomainH\x02\x12Q\n\x18\x64istribution_constraints\x18\x0f \x01(\x0b\x32/.tensorflow.metadata.v0.DistributionConstraints\x12\x36\n\nannotation\x18\x10 \x01(\x0b\x32\".tensorflow.metadata.v0.Annotation\x12\x42\n\x0fskew_comparator\x18\x12 \x01(\x0b\x32).tensorflow.metadata.v0.FeatureComparator\x12\x43\n\x10\x64rift_comparator\x18\x15 \x01(\x0b\x32).tensorflow.metadata.v0.FeatureComparator\x12\x16\n\x0ein_environment\x18\x14 \x03(\t\x12\x1a\n\x12not_in_environment\x18\x13 \x03(\t\x12?\n\x0flifecycle_stage\x18\x16 \x01(\x0e\x32&.tensorflow.metadata.v0.LifecycleStageB\x16\n\x14presence_constraintsB\x0c\n\nshape_typeB\r\n\x0b\x64omain_info\"X\n\nAnnotation\x12\x0b\n\x03tag\x18\x01 \x03(\t\x12\x0f\n\x07\x63omment\x18\x02 \x03(\t\x12,\n\x0e\x65xtra_metadata\x18\x03 \x03(\x0b\x32\x14.google.protobuf.Any\"X\n\x16NumericValueComparator\x12\x1e\n\x16min_fraction_threshold\x18\x01 \x01(\x01\x12\x1e\n\x16max_fraction_threshold\x18\x02 \x01(\x01\"\xe0\x01\n\x12\x44\x61tasetConstraints\x12U\n\x1dnum_examples_drift_comparator\x18\x01 \x01(\x0b\x32..tensorflow.metadata.v0.NumericValueComparator\x12W\n\x1fnum_examples_version_comparator\x18\x02 \x01(\x0b\x32..tensorflow.metadata.v0.NumericValueComparator\x12\x1a\n\x12min_examples_count\x18\x03 \x01(\x03\"d\n\nFixedShape\x12\x33\n\x03\x64im\x18\x02 \x03(\x0b\x32&.tensorflow.metadata.v0.FixedShape.Dim\x1a!\n\x03\x44im\x12\x0c\n\x04size\x18\x01 \x01(\x03\x12\x0c\n\x04name\x18\x02 \x01(\t\"&\n\nValueCount\x12\x0b\n\x03min\x18\x01 \x01(\x03\x12\x0b\n\x03max\x18\x02 \x01(\x03\"\xc5\x01\n\x0fWeightedFeature\x12\x0c\n\x04name\x18\x01 \x01(\t\x12-\n\x07\x66\x65\x61ture\x18\x02 \x01(\x0b\x32\x1c.tensorflow.metadata.v0.Path\x12\x34\n\x0eweight_feature\x18\x03 \x01(\x0b\x32\x1c.tensorflow.metadata.v0.Path\x12?\n\x0flifecycle_stage\x18\x04 \x01(\x0e\x32&.tensorflow.metadata.v0.LifecycleStage\"\x90\x04\n\rSparseFeature\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x16\n\ndeprecated\x18\x02 \x01(\x08\x42\x02\x18\x01\x12?\n\x0flifecycle_stage\x18\x07 \x01(\x0e\x32&.tensorflow.metadata.v0.LifecycleStage\x12=\n\x08presence\x18\x04 \x01(\x0b\x32\'.tensorflow.metadata.v0.FeaturePresenceB\x02\x18\x01\x12\x37\n\x0b\x64\x65nse_shape\x18\x05 \x01(\x0b\x32\".tensorflow.metadata.v0.FixedShape\x12I\n\rindex_feature\x18\x06 \x03(\x0b\x32\x32.tensorflow.metadata.v0.SparseFeature.IndexFeature\x12\x11\n\tis_sorted\x18\x08 \x01(\x08\x12I\n\rvalue_feature\x18\t \x01(\x0b\x32\x32.tensorflow.metadata.v0.SparseFeature.ValueFeature\x12\x35\n\x04type\x18\n \x01(\x0e\x32#.tensorflow.metadata.v0.FeatureTypeB\x02\x18\x01\x1a\x1c\n\x0cIndexFeature\x12\x0c\n\x04name\x18\x01 \x01(\t\x1a\x1c\n\x0cValueFeature\x12\x0c\n\x04name\x18\x01 \x01(\tJ\x04\x08\x0b\x10\x0c\"5\n\x17\x44istributionConstraints\x12\x1a\n\x0fmin_domain_mass\x18\x01 \x01(\x01:\x01\x31\"K\n\tIntDomain\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0b\n\x03min\x18\x03 \x01(\x03\x12\x0b\n\x03max\x18\x04 \x01(\x03\x12\x16\n\x0eis_categorical\x18\x05 \x01(\x08\"5\n\x0b\x46loatDomain\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0b\n\x03min\x18\x03 \x01(\x02\x12\x0b\n\x03max\x18\x04 \x01(\x02\"\x7f\n\x0cStructDomain\x12\x30\n\x07\x66\x65\x61ture\x18\x01 \x03(\x0b\x32\x1f.tensorflow.metadata.v0.Feature\x12=\n\x0esparse_feature\x18\x02 \x03(\x0b\x32%.tensorflow.metadata.v0.SparseFeature\"+\n\x0cStringDomain\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x03(\t\"C\n\nBoolDomain\x12\x0c\n\x04name\x18\x01 
\x01(\t\x12\x12\n\ntrue_value\x18\x02 \x01(\t\x12\x13\n\x0b\x66\x61lse_value\x18\x03 \x01(\t\"\x17\n\x15NaturalLanguageDomain\"\r\n\x0bImageDomain\"\x0b\n\tMIDDomain\"\x0b\n\tURLDomain\"\x8e\x02\n\nTimeDomain\x12\x17\n\rstring_format\x18\x01 \x01(\tH\x00\x12N\n\x0einteger_format\x18\x02 \x01(\x0e\x32\x34.tensorflow.metadata.v0.TimeDomain.IntegerTimeFormatH\x00\"\x8c\x01\n\x11IntegerTimeFormat\x12\x12\n\x0e\x46ORMAT_UNKNOWN\x10\x00\x12\r\n\tUNIX_DAYS\x10\x05\x12\x10\n\x0cUNIX_SECONDS\x10\x01\x12\x15\n\x11UNIX_MILLISECONDS\x10\x02\x12\x15\n\x11UNIX_MICROSECONDS\x10\x03\x12\x14\n\x10UNIX_NANOSECONDS\x10\x04\x42\x08\n\x06\x66ormat\"\xd1\x01\n\x0fTimeOfDayDomain\x12\x17\n\rstring_format\x18\x01 \x01(\tH\x00\x12X\n\x0einteger_format\x18\x02 \x01(\x0e\x32>.tensorflow.metadata.v0.TimeOfDayDomain.IntegerTimeOfDayFormatH\x00\"A\n\x16IntegerTimeOfDayFormat\x12\x12\n\x0e\x46ORMAT_UNKNOWN\x10\x00\x12\x13\n\x0fPACKED_64_NANOS\x10\x01\x42\x08\n\x06\x66ormat\":\n\x0f\x46\x65\x61turePresence\x12\x14\n\x0cmin_fraction\x18\x01 \x01(\x01\x12\x11\n\tmin_count\x18\x02 \x01(\x03\".\n\x1a\x46\x65\x61turePresenceWithinGroup\x12\x10\n\x08required\x18\x01 \x01(\x08\"!\n\x0cInfinityNorm\x12\x11\n\tthreshold\x18\x01 \x01(\x01\"P\n\x11\x46\x65\x61tureComparator\x12;\n\rinfinity_norm\x18\x01 \x01(\x0b\x32$.tensorflow.metadata.v0.InfinityNorm\"\xeb\x05\n\x14TensorRepresentation\x12P\n\x0c\x64\x65nse_tensor\x18\x01 \x01(\x0b\x32\x38.tensorflow.metadata.v0.TensorRepresentation.DenseTensorH\x00\x12_\n\x14varlen_sparse_tensor\x18\x02 \x01(\x0b\x32?.tensorflow.metadata.v0.TensorRepresentation.VarLenSparseTensorH\x00\x12R\n\rsparse_tensor\x18\x03 \x01(\x0b\x32\x39.tensorflow.metadata.v0.TensorRepresentation.SparseTensorH\x00\x1ao\n\x0c\x44\x65\x66\x61ultValue\x12\x15\n\x0b\x66loat_value\x18\x01 \x01(\x01H\x00\x12\x13\n\tint_value\x18\x02 \x01(\x03H\x00\x12\x15\n\x0b\x62ytes_value\x18\x03 \x01(\x0cH\x00\x12\x14\n\nuint_value\x18\x04 \x01(\x04H\x00\x42\x06\n\x04kind\x1a\xa7\x01\n\x0b\x44\x65nseTensor\x12\x13\n\x0b\x63olumn_name\x18\x01 \x01(\t\x12\x31\n\x05shape\x18\x02 \x01(\x0b\x32\".tensorflow.metadata.v0.FixedShape\x12P\n\rdefault_value\x18\x03 \x01(\x0b\x32\x39.tensorflow.metadata.v0.TensorRepresentation.DefaultValue\x1a)\n\x12VarLenSparseTensor\x12\x13\n\x0b\x63olumn_name\x18\x01 \x01(\t\x1a~\n\x0cSparseTensor\x12\x37\n\x0b\x64\x65nse_shape\x18\x01 \x01(\x0b\x32\".tensorflow.metadata.v0.FixedShape\x12\x1a\n\x12index_column_names\x18\x02 \x03(\t\x12\x19\n\x11value_column_name\x18\x03 \x01(\tB\x06\n\x04kind\"\xf2\x01\n\x19TensorRepresentationGroup\x12j\n\x15tensor_representation\x18\x01 \x03(\x0b\x32K.tensorflow.metadata.v0.TensorRepresentationGroup.TensorRepresentationEntry\x1ai\n\x19TensorRepresentationEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12;\n\x05value\x18\x02 \x01(\x0b\x32,.tensorflow.metadata.v0.TensorRepresentation:\x02\x38\x01*u\n\x0eLifecycleStage\x12\x11\n\rUNKNOWN_STAGE\x10\x00\x12\x0b\n\x07PLANNED\x10\x01\x12\t\n\x05\x41LPHA\x10\x02\x12\x08\n\x04\x42\x45TA\x10\x03\x12\x0e\n\nPRODUCTION\x10\x04\x12\x0e\n\nDEPRECATED\x10\x05\x12\x0e\n\nDEBUG_ONLY\x10\x06*J\n\x0b\x46\x65\x61tureType\x12\x10\n\x0cTYPE_UNKNOWN\x10\x00\x12\t\n\x05\x42YTES\x10\x01\x12\x07\n\x03INT\x10\x02\x12\t\n\x05\x46LOAT\x10\x03\x12\n\n\x06STRUCT\x10\x04\x42!\n\x1aorg.tensorflow.metadata.v0P\x01\xf8\x01\x01' - , - dependencies=[google_dot_protobuf_dot_any__pb2.DESCRIPTOR,tensorflow__metadata_dot_proto_dot_v0_dot_path__pb2.DESCRIPTOR,]) - -_LIFECYCLESTAGE = _descriptor.EnumDescriptor( - name='LifecycleStage', - 
full_name='tensorflow.metadata.v0.LifecycleStage', - filename=None, - file=DESCRIPTOR, - values=[ - _descriptor.EnumValueDescriptor( - name='UNKNOWN_STAGE', index=0, number=0, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='PLANNED', index=1, number=1, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='ALPHA', index=2, number=2, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='BETA', index=3, number=3, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='PRODUCTION', index=4, number=4, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='DEPRECATED', index=5, number=5, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='DEBUG_ONLY', index=6, number=6, - serialized_options=None, - type=None), - ], - containing_type=None, - serialized_options=None, - serialized_start=5865, - serialized_end=5982, -) -_sym_db.RegisterEnumDescriptor(_LIFECYCLESTAGE) - -LifecycleStage = enum_type_wrapper.EnumTypeWrapper(_LIFECYCLESTAGE) -_FEATURETYPE = _descriptor.EnumDescriptor( - name='FeatureType', - full_name='tensorflow.metadata.v0.FeatureType', - filename=None, - file=DESCRIPTOR, - values=[ - _descriptor.EnumValueDescriptor( - name='TYPE_UNKNOWN', index=0, number=0, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='BYTES', index=1, number=1, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='INT', index=2, number=2, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='FLOAT', index=3, number=3, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='STRUCT', index=4, number=4, - serialized_options=None, - type=None), - ], - containing_type=None, - serialized_options=None, - serialized_start=5984, - serialized_end=6058, -) -_sym_db.RegisterEnumDescriptor(_FEATURETYPE) - -FeatureType = enum_type_wrapper.EnumTypeWrapper(_FEATURETYPE) -UNKNOWN_STAGE = 0 -PLANNED = 1 -ALPHA = 2 -BETA = 3 -PRODUCTION = 4 -DEPRECATED = 5 -DEBUG_ONLY = 6 -TYPE_UNKNOWN = 0 -BYTES = 1 -INT = 2 -FLOAT = 3 -STRUCT = 4 - - -_TIMEDOMAIN_INTEGERTIMEFORMAT = _descriptor.EnumDescriptor( - name='IntegerTimeFormat', - full_name='tensorflow.metadata.v0.TimeDomain.IntegerTimeFormat', - filename=None, - file=DESCRIPTOR, - values=[ - _descriptor.EnumValueDescriptor( - name='FORMAT_UNKNOWN', index=0, number=0, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='UNIX_DAYS', index=1, number=5, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='UNIX_SECONDS', index=2, number=1, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='UNIX_MILLISECONDS', index=3, number=2, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='UNIX_MICROSECONDS', index=4, number=3, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='UNIX_NANOSECONDS', index=5, number=4, - serialized_options=None, - type=None), - ], - containing_type=None, - serialized_options=None, - serialized_start=4281, - serialized_end=4421, -) -_sym_db.RegisterEnumDescriptor(_TIMEDOMAIN_INTEGERTIMEFORMAT) - -_TIMEOFDAYDOMAIN_INTEGERTIMEOFDAYFORMAT = _descriptor.EnumDescriptor( - name='IntegerTimeOfDayFormat', - full_name='tensorflow.metadata.v0.TimeOfDayDomain.IntegerTimeOfDayFormat', - filename=None, - file=DESCRIPTOR, - 
values=[ - _descriptor.EnumValueDescriptor( - name='FORMAT_UNKNOWN', index=0, number=0, - serialized_options=None, - type=None), - _descriptor.EnumValueDescriptor( - name='PACKED_64_NANOS', index=1, number=1, - serialized_options=None, - type=None), - ], - containing_type=None, - serialized_options=None, - serialized_start=4568, - serialized_end=4633, -) -_sym_db.RegisterEnumDescriptor(_TIMEOFDAYDOMAIN_INTEGERTIMEOFDAYFORMAT) - - -_SCHEMA_TENSORREPRESENTATIONGROUPENTRY = _descriptor.Descriptor( - name='TensorRepresentationGroupEntry', - full_name='tensorflow.metadata.v0.Schema.TensorRepresentationGroupEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.metadata.v0.Schema.TensorRepresentationGroupEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.metadata.v0.Schema.TensorRepresentationGroupEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=761, - serialized_end=876, -) - -_SCHEMA = _descriptor.Descriptor( - name='Schema', - full_name='tensorflow.metadata.v0.Schema', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='feature', full_name='tensorflow.metadata.v0.Schema.feature', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='sparse_feature', full_name='tensorflow.metadata.v0.Schema.sparse_feature', index=1, - number=6, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='weighted_feature', full_name='tensorflow.metadata.v0.Schema.weighted_feature', index=2, - number=12, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='string_domain', full_name='tensorflow.metadata.v0.Schema.string_domain', index=3, - number=4, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='float_domain', full_name='tensorflow.metadata.v0.Schema.float_domain', index=4, - number=9, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, 
extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='int_domain', full_name='tensorflow.metadata.v0.Schema.int_domain', index=5, - number=10, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='default_environment', full_name='tensorflow.metadata.v0.Schema.default_environment', index=6, - number=5, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='annotation', full_name='tensorflow.metadata.v0.Schema.annotation', index=7, - number=8, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='dataset_constraints', full_name='tensorflow.metadata.v0.Schema.dataset_constraints', index=8, - number=11, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='tensor_representation_group', full_name='tensorflow.metadata.v0.Schema.tensor_representation_group', index=9, - number=13, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_SCHEMA_TENSORREPRESENTATIONGROUPENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=138, - serialized_end=876, -) - - -_FEATURE = _descriptor.Descriptor( - name='Feature', - full_name='tensorflow.metadata.v0.Feature', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.metadata.v0.Feature.name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='deprecated', full_name='tensorflow.metadata.v0.Feature.deprecated', index=1, - number=2, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=b'\030\001', file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='presence', full_name='tensorflow.metadata.v0.Feature.presence', index=2, - number=14, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='group_presence', full_name='tensorflow.metadata.v0.Feature.group_presence', index=3, - number=17, type=11, 
cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='shape', full_name='tensorflow.metadata.v0.Feature.shape', index=4, - number=23, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value_count', full_name='tensorflow.metadata.v0.Feature.value_count', index=5, - number=5, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='type', full_name='tensorflow.metadata.v0.Feature.type', index=6, - number=6, type=14, cpp_type=8, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='domain', full_name='tensorflow.metadata.v0.Feature.domain', index=7, - number=7, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='int_domain', full_name='tensorflow.metadata.v0.Feature.int_domain', index=8, - number=9, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='float_domain', full_name='tensorflow.metadata.v0.Feature.float_domain', index=9, - number=10, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='string_domain', full_name='tensorflow.metadata.v0.Feature.string_domain', index=10, - number=11, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='bool_domain', full_name='tensorflow.metadata.v0.Feature.bool_domain', index=11, - number=13, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='struct_domain', full_name='tensorflow.metadata.v0.Feature.struct_domain', index=12, - number=29, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='natural_language_domain', full_name='tensorflow.metadata.v0.Feature.natural_language_domain', index=13, - 
number=24, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='image_domain', full_name='tensorflow.metadata.v0.Feature.image_domain', index=14, - number=25, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='mid_domain', full_name='tensorflow.metadata.v0.Feature.mid_domain', index=15, - number=26, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='url_domain', full_name='tensorflow.metadata.v0.Feature.url_domain', index=16, - number=27, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='time_domain', full_name='tensorflow.metadata.v0.Feature.time_domain', index=17, - number=28, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='time_of_day_domain', full_name='tensorflow.metadata.v0.Feature.time_of_day_domain', index=18, - number=30, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='distribution_constraints', full_name='tensorflow.metadata.v0.Feature.distribution_constraints', index=19, - number=15, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='annotation', full_name='tensorflow.metadata.v0.Feature.annotation', index=20, - number=16, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='skew_comparator', full_name='tensorflow.metadata.v0.Feature.skew_comparator', index=21, - number=18, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='drift_comparator', full_name='tensorflow.metadata.v0.Feature.drift_comparator', index=22, - number=21, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - 
name='in_environment', full_name='tensorflow.metadata.v0.Feature.in_environment', index=23, - number=20, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='not_in_environment', full_name='tensorflow.metadata.v0.Feature.not_in_environment', index=24, - number=19, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='lifecycle_stage', full_name='tensorflow.metadata.v0.Feature.lifecycle_stage', index=25, - number=22, type=14, cpp_type=8, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - _descriptor.OneofDescriptor( - name='presence_constraints', full_name='tensorflow.metadata.v0.Feature.presence_constraints', - index=0, containing_type=None, fields=[]), - _descriptor.OneofDescriptor( - name='shape_type', full_name='tensorflow.metadata.v0.Feature.shape_type', - index=1, containing_type=None, fields=[]), - _descriptor.OneofDescriptor( - name='domain_info', full_name='tensorflow.metadata.v0.Feature.domain_info', - index=2, containing_type=None, fields=[]), - ], - serialized_start=879, - serialized_end=2382, -) - - -_ANNOTATION = _descriptor.Descriptor( - name='Annotation', - full_name='tensorflow.metadata.v0.Annotation', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='tag', full_name='tensorflow.metadata.v0.Annotation.tag', index=0, - number=1, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='comment', full_name='tensorflow.metadata.v0.Annotation.comment', index=1, - number=2, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='extra_metadata', full_name='tensorflow.metadata.v0.Annotation.extra_metadata', index=2, - number=3, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=2384, - serialized_end=2472, -) - - -_NUMERICVALUECOMPARATOR = _descriptor.Descriptor( - name='NumericValueComparator', - full_name='tensorflow.metadata.v0.NumericValueComparator', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='min_fraction_threshold', 
full_name='tensorflow.metadata.v0.NumericValueComparator.min_fraction_threshold', index=0, - number=1, type=1, cpp_type=5, label=1, - has_default_value=False, default_value=float(0), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='max_fraction_threshold', full_name='tensorflow.metadata.v0.NumericValueComparator.max_fraction_threshold', index=1, - number=2, type=1, cpp_type=5, label=1, - has_default_value=False, default_value=float(0), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=2474, - serialized_end=2562, -) - - -_DATASETCONSTRAINTS = _descriptor.Descriptor( - name='DatasetConstraints', - full_name='tensorflow.metadata.v0.DatasetConstraints', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='num_examples_drift_comparator', full_name='tensorflow.metadata.v0.DatasetConstraints.num_examples_drift_comparator', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='num_examples_version_comparator', full_name='tensorflow.metadata.v0.DatasetConstraints.num_examples_version_comparator', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='min_examples_count', full_name='tensorflow.metadata.v0.DatasetConstraints.min_examples_count', index=2, - number=3, type=3, cpp_type=2, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=2565, - serialized_end=2789, -) - - -_FIXEDSHAPE_DIM = _descriptor.Descriptor( - name='Dim', - full_name='tensorflow.metadata.v0.FixedShape.Dim', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='size', full_name='tensorflow.metadata.v0.FixedShape.Dim.size', index=0, - number=1, type=3, cpp_type=2, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.metadata.v0.FixedShape.Dim.name', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - 
is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=2858, - serialized_end=2891, -) - -_FIXEDSHAPE = _descriptor.Descriptor( - name='FixedShape', - full_name='tensorflow.metadata.v0.FixedShape', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='dim', full_name='tensorflow.metadata.v0.FixedShape.dim', index=0, - number=2, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_FIXEDSHAPE_DIM, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=2791, - serialized_end=2891, -) - - -_VALUECOUNT = _descriptor.Descriptor( - name='ValueCount', - full_name='tensorflow.metadata.v0.ValueCount', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='min', full_name='tensorflow.metadata.v0.ValueCount.min', index=0, - number=1, type=3, cpp_type=2, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='max', full_name='tensorflow.metadata.v0.ValueCount.max', index=1, - number=2, type=3, cpp_type=2, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=2893, - serialized_end=2931, -) - - -_WEIGHTEDFEATURE = _descriptor.Descriptor( - name='WeightedFeature', - full_name='tensorflow.metadata.v0.WeightedFeature', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.metadata.v0.WeightedFeature.name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='feature', full_name='tensorflow.metadata.v0.WeightedFeature.feature', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='weight_feature', full_name='tensorflow.metadata.v0.WeightedFeature.weight_feature', index=2, - number=3, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='lifecycle_stage', full_name='tensorflow.metadata.v0.WeightedFeature.lifecycle_stage', index=3, - number=4, type=14, cpp_type=8, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - 
is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=2934, - serialized_end=3131, -) - - -_SPARSEFEATURE_INDEXFEATURE = _descriptor.Descriptor( - name='IndexFeature', - full_name='tensorflow.metadata.v0.SparseFeature.IndexFeature', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.metadata.v0.SparseFeature.IndexFeature.name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=3598, - serialized_end=3626, -) - -_SPARSEFEATURE_VALUEFEATURE = _descriptor.Descriptor( - name='ValueFeature', - full_name='tensorflow.metadata.v0.SparseFeature.ValueFeature', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.metadata.v0.SparseFeature.ValueFeature.name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=3628, - serialized_end=3656, -) - -_SPARSEFEATURE = _descriptor.Descriptor( - name='SparseFeature', - full_name='tensorflow.metadata.v0.SparseFeature', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.metadata.v0.SparseFeature.name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='deprecated', full_name='tensorflow.metadata.v0.SparseFeature.deprecated', index=1, - number=2, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=b'\030\001', file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='lifecycle_stage', full_name='tensorflow.metadata.v0.SparseFeature.lifecycle_stage', index=2, - number=7, type=14, cpp_type=8, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='presence', full_name='tensorflow.metadata.v0.SparseFeature.presence', index=3, - number=4, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - 
serialized_options=b'\030\001', file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='dense_shape', full_name='tensorflow.metadata.v0.SparseFeature.dense_shape', index=4, - number=5, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='index_feature', full_name='tensorflow.metadata.v0.SparseFeature.index_feature', index=5, - number=6, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='is_sorted', full_name='tensorflow.metadata.v0.SparseFeature.is_sorted', index=6, - number=8, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value_feature', full_name='tensorflow.metadata.v0.SparseFeature.value_feature', index=7, - number=9, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='type', full_name='tensorflow.metadata.v0.SparseFeature.type', index=8, - number=10, type=14, cpp_type=8, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=b'\030\001', file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_SPARSEFEATURE_INDEXFEATURE, _SPARSEFEATURE_VALUEFEATURE, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=3134, - serialized_end=3662, -) - - -_DISTRIBUTIONCONSTRAINTS = _descriptor.Descriptor( - name='DistributionConstraints', - full_name='tensorflow.metadata.v0.DistributionConstraints', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='min_domain_mass', full_name='tensorflow.metadata.v0.DistributionConstraints.min_domain_mass', index=0, - number=1, type=1, cpp_type=5, label=1, - has_default_value=True, default_value=float(1), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=3664, - serialized_end=3717, -) - - -_INTDOMAIN = _descriptor.Descriptor( - name='IntDomain', - full_name='tensorflow.metadata.v0.IntDomain', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.metadata.v0.IntDomain.name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='min', 
full_name='tensorflow.metadata.v0.IntDomain.min', index=1, - number=3, type=3, cpp_type=2, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='max', full_name='tensorflow.metadata.v0.IntDomain.max', index=2, - number=4, type=3, cpp_type=2, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='is_categorical', full_name='tensorflow.metadata.v0.IntDomain.is_categorical', index=3, - number=5, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=3719, - serialized_end=3794, -) - - -_FLOATDOMAIN = _descriptor.Descriptor( - name='FloatDomain', - full_name='tensorflow.metadata.v0.FloatDomain', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.metadata.v0.FloatDomain.name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='min', full_name='tensorflow.metadata.v0.FloatDomain.min', index=1, - number=3, type=2, cpp_type=6, label=1, - has_default_value=False, default_value=float(0), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='max', full_name='tensorflow.metadata.v0.FloatDomain.max', index=2, - number=4, type=2, cpp_type=6, label=1, - has_default_value=False, default_value=float(0), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=3796, - serialized_end=3849, -) - - -_STRUCTDOMAIN = _descriptor.Descriptor( - name='StructDomain', - full_name='tensorflow.metadata.v0.StructDomain', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='feature', full_name='tensorflow.metadata.v0.StructDomain.feature', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='sparse_feature', full_name='tensorflow.metadata.v0.StructDomain.sparse_feature', index=1, - number=2, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - 
serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=3851, - serialized_end=3978, -) - - -_STRINGDOMAIN = _descriptor.Descriptor( - name='StringDomain', - full_name='tensorflow.metadata.v0.StringDomain', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.metadata.v0.StringDomain.name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.metadata.v0.StringDomain.value', index=1, - number=2, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=3980, - serialized_end=4023, -) - - -_BOOLDOMAIN = _descriptor.Descriptor( - name='BoolDomain', - full_name='tensorflow.metadata.v0.BoolDomain', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='name', full_name='tensorflow.metadata.v0.BoolDomain.name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='true_value', full_name='tensorflow.metadata.v0.BoolDomain.true_value', index=1, - number=2, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='false_value', full_name='tensorflow.metadata.v0.BoolDomain.false_value', index=2, - number=3, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=4025, - serialized_end=4092, -) - - -_NATURALLANGUAGEDOMAIN = _descriptor.Descriptor( - name='NaturalLanguageDomain', - full_name='tensorflow.metadata.v0.NaturalLanguageDomain', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=4094, - serialized_end=4117, -) - - -_IMAGEDOMAIN = _descriptor.Descriptor( - name='ImageDomain', - full_name='tensorflow.metadata.v0.ImageDomain', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - ], - extensions=[ - ], - 
nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=4119, - serialized_end=4132, -) - - -_MIDDOMAIN = _descriptor.Descriptor( - name='MIDDomain', - full_name='tensorflow.metadata.v0.MIDDomain', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=4134, - serialized_end=4145, -) - - -_URLDOMAIN = _descriptor.Descriptor( - name='URLDomain', - full_name='tensorflow.metadata.v0.URLDomain', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=4147, - serialized_end=4158, -) - - -_TIMEDOMAIN = _descriptor.Descriptor( - name='TimeDomain', - full_name='tensorflow.metadata.v0.TimeDomain', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='string_format', full_name='tensorflow.metadata.v0.TimeDomain.string_format', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='integer_format', full_name='tensorflow.metadata.v0.TimeDomain.integer_format', index=1, - number=2, type=14, cpp_type=8, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - _TIMEDOMAIN_INTEGERTIMEFORMAT, - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - _descriptor.OneofDescriptor( - name='format', full_name='tensorflow.metadata.v0.TimeDomain.format', - index=0, containing_type=None, fields=[]), - ], - serialized_start=4161, - serialized_end=4431, -) - - -_TIMEOFDAYDOMAIN = _descriptor.Descriptor( - name='TimeOfDayDomain', - full_name='tensorflow.metadata.v0.TimeOfDayDomain', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='string_format', full_name='tensorflow.metadata.v0.TimeOfDayDomain.string_format', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='integer_format', full_name='tensorflow.metadata.v0.TimeOfDayDomain.integer_format', index=1, - number=2, type=14, cpp_type=8, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - _TIMEOFDAYDOMAIN_INTEGERTIMEOFDAYFORMAT, - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - _descriptor.OneofDescriptor( - name='format', 
full_name='tensorflow.metadata.v0.TimeOfDayDomain.format', - index=0, containing_type=None, fields=[]), - ], - serialized_start=4434, - serialized_end=4643, -) - - -_FEATUREPRESENCE = _descriptor.Descriptor( - name='FeaturePresence', - full_name='tensorflow.metadata.v0.FeaturePresence', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='min_fraction', full_name='tensorflow.metadata.v0.FeaturePresence.min_fraction', index=0, - number=1, type=1, cpp_type=5, label=1, - has_default_value=False, default_value=float(0), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='min_count', full_name='tensorflow.metadata.v0.FeaturePresence.min_count', index=1, - number=2, type=3, cpp_type=2, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=4645, - serialized_end=4703, -) - - -_FEATUREPRESENCEWITHINGROUP = _descriptor.Descriptor( - name='FeaturePresenceWithinGroup', - full_name='tensorflow.metadata.v0.FeaturePresenceWithinGroup', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='required', full_name='tensorflow.metadata.v0.FeaturePresenceWithinGroup.required', index=0, - number=1, type=8, cpp_type=7, label=1, - has_default_value=False, default_value=False, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=4705, - serialized_end=4751, -) - - -_INFINITYNORM = _descriptor.Descriptor( - name='InfinityNorm', - full_name='tensorflow.metadata.v0.InfinityNorm', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='threshold', full_name='tensorflow.metadata.v0.InfinityNorm.threshold', index=0, - number=1, type=1, cpp_type=5, label=1, - has_default_value=False, default_value=float(0), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=4753, - serialized_end=4786, -) - - -_FEATURECOMPARATOR = _descriptor.Descriptor( - name='FeatureComparator', - full_name='tensorflow.metadata.v0.FeatureComparator', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='infinity_norm', full_name='tensorflow.metadata.v0.FeatureComparator.infinity_norm', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - 
serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=4788, - serialized_end=4868, -) - - -_TENSORREPRESENTATION_DEFAULTVALUE = _descriptor.Descriptor( - name='DefaultValue', - full_name='tensorflow.metadata.v0.TensorRepresentation.DefaultValue', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='float_value', full_name='tensorflow.metadata.v0.TensorRepresentation.DefaultValue.float_value', index=0, - number=1, type=1, cpp_type=5, label=1, - has_default_value=False, default_value=float(0), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='int_value', full_name='tensorflow.metadata.v0.TensorRepresentation.DefaultValue.int_value', index=1, - number=2, type=3, cpp_type=2, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='bytes_value', full_name='tensorflow.metadata.v0.TensorRepresentation.DefaultValue.bytes_value', index=2, - number=3, type=12, cpp_type=9, label=1, - has_default_value=False, default_value=b"", - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='uint_value', full_name='tensorflow.metadata.v0.TensorRepresentation.DefaultValue.uint_value', index=3, - number=4, type=4, cpp_type=4, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - _descriptor.OneofDescriptor( - name='kind', full_name='tensorflow.metadata.v0.TensorRepresentation.DefaultValue.kind', - index=0, containing_type=None, fields=[]), - ], - serialized_start=5158, - serialized_end=5269, -) - -_TENSORREPRESENTATION_DENSETENSOR = _descriptor.Descriptor( - name='DenseTensor', - full_name='tensorflow.metadata.v0.TensorRepresentation.DenseTensor', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='column_name', full_name='tensorflow.metadata.v0.TensorRepresentation.DenseTensor.column_name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='shape', full_name='tensorflow.metadata.v0.TensorRepresentation.DenseTensor.shape', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='default_value', full_name='tensorflow.metadata.v0.TensorRepresentation.DenseTensor.default_value', index=2, - number=3, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, 
containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=5272, - serialized_end=5439, -) - -_TENSORREPRESENTATION_VARLENSPARSETENSOR = _descriptor.Descriptor( - name='VarLenSparseTensor', - full_name='tensorflow.metadata.v0.TensorRepresentation.VarLenSparseTensor', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='column_name', full_name='tensorflow.metadata.v0.TensorRepresentation.VarLenSparseTensor.column_name', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=5441, - serialized_end=5482, -) - -_TENSORREPRESENTATION_SPARSETENSOR = _descriptor.Descriptor( - name='SparseTensor', - full_name='tensorflow.metadata.v0.TensorRepresentation.SparseTensor', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='dense_shape', full_name='tensorflow.metadata.v0.TensorRepresentation.SparseTensor.dense_shape', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='index_column_names', full_name='tensorflow.metadata.v0.TensorRepresentation.SparseTensor.index_column_names', index=1, - number=2, type=9, cpp_type=9, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value_column_name', full_name='tensorflow.metadata.v0.TensorRepresentation.SparseTensor.value_column_name', index=2, - number=3, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=5484, - serialized_end=5610, -) - -_TENSORREPRESENTATION = _descriptor.Descriptor( - name='TensorRepresentation', - full_name='tensorflow.metadata.v0.TensorRepresentation', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='dense_tensor', full_name='tensorflow.metadata.v0.TensorRepresentation.dense_tensor', index=0, - number=1, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='varlen_sparse_tensor', 
full_name='tensorflow.metadata.v0.TensorRepresentation.varlen_sparse_tensor', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='sparse_tensor', full_name='tensorflow.metadata.v0.TensorRepresentation.sparse_tensor', index=2, - number=3, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_TENSORREPRESENTATION_DEFAULTVALUE, _TENSORREPRESENTATION_DENSETENSOR, _TENSORREPRESENTATION_VARLENSPARSETENSOR, _TENSORREPRESENTATION_SPARSETENSOR, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - _descriptor.OneofDescriptor( - name='kind', full_name='tensorflow.metadata.v0.TensorRepresentation.kind', - index=0, containing_type=None, fields=[]), - ], - serialized_start=4871, - serialized_end=5618, -) - - -_TENSORREPRESENTATIONGROUP_TENSORREPRESENTATIONENTRY = _descriptor.Descriptor( - name='TensorRepresentationEntry', - full_name='tensorflow.metadata.v0.TensorRepresentationGroup.TensorRepresentationEntry', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='key', full_name='tensorflow.metadata.v0.TensorRepresentationGroup.TensorRepresentationEntry.key', index=0, - number=1, type=9, cpp_type=9, label=1, - has_default_value=False, default_value=b"".decode('utf-8'), - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - _descriptor.FieldDescriptor( - name='value', full_name='tensorflow.metadata.v0.TensorRepresentationGroup.TensorRepresentationEntry.value', index=1, - number=2, type=11, cpp_type=10, label=1, - has_default_value=False, default_value=None, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[], - enum_types=[ - ], - serialized_options=b'8\001', - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=5758, - serialized_end=5863, -) - -_TENSORREPRESENTATIONGROUP = _descriptor.Descriptor( - name='TensorRepresentationGroup', - full_name='tensorflow.metadata.v0.TensorRepresentationGroup', - filename=None, - file=DESCRIPTOR, - containing_type=None, - fields=[ - _descriptor.FieldDescriptor( - name='tensor_representation', full_name='tensorflow.metadata.v0.TensorRepresentationGroup.tensor_representation', index=0, - number=1, type=11, cpp_type=10, label=3, - has_default_value=False, default_value=[], - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None, - serialized_options=None, file=DESCRIPTOR), - ], - extensions=[ - ], - nested_types=[_TENSORREPRESENTATIONGROUP_TENSORREPRESENTATIONENTRY, ], - enum_types=[ - ], - serialized_options=None, - is_extendable=False, - syntax='proto2', - extension_ranges=[], - oneofs=[ - ], - serialized_start=5621, - serialized_end=5863, -) - -_SCHEMA_TENSORREPRESENTATIONGROUPENTRY.fields_by_name['value'].message_type = _TENSORREPRESENTATIONGROUP 
-_SCHEMA_TENSORREPRESENTATIONGROUPENTRY.containing_type = _SCHEMA -_SCHEMA.fields_by_name['feature'].message_type = _FEATURE -_SCHEMA.fields_by_name['sparse_feature'].message_type = _SPARSEFEATURE -_SCHEMA.fields_by_name['weighted_feature'].message_type = _WEIGHTEDFEATURE -_SCHEMA.fields_by_name['string_domain'].message_type = _STRINGDOMAIN -_SCHEMA.fields_by_name['float_domain'].message_type = _FLOATDOMAIN -_SCHEMA.fields_by_name['int_domain'].message_type = _INTDOMAIN -_SCHEMA.fields_by_name['annotation'].message_type = _ANNOTATION -_SCHEMA.fields_by_name['dataset_constraints'].message_type = _DATASETCONSTRAINTS -_SCHEMA.fields_by_name['tensor_representation_group'].message_type = _SCHEMA_TENSORREPRESENTATIONGROUPENTRY -_FEATURE.fields_by_name['presence'].message_type = _FEATUREPRESENCE -_FEATURE.fields_by_name['group_presence'].message_type = _FEATUREPRESENCEWITHINGROUP -_FEATURE.fields_by_name['shape'].message_type = _FIXEDSHAPE -_FEATURE.fields_by_name['value_count'].message_type = _VALUECOUNT -_FEATURE.fields_by_name['type'].enum_type = _FEATURETYPE -_FEATURE.fields_by_name['int_domain'].message_type = _INTDOMAIN -_FEATURE.fields_by_name['float_domain'].message_type = _FLOATDOMAIN -_FEATURE.fields_by_name['string_domain'].message_type = _STRINGDOMAIN -_FEATURE.fields_by_name['bool_domain'].message_type = _BOOLDOMAIN -_FEATURE.fields_by_name['struct_domain'].message_type = _STRUCTDOMAIN -_FEATURE.fields_by_name['natural_language_domain'].message_type = _NATURALLANGUAGEDOMAIN -_FEATURE.fields_by_name['image_domain'].message_type = _IMAGEDOMAIN -_FEATURE.fields_by_name['mid_domain'].message_type = _MIDDOMAIN -_FEATURE.fields_by_name['url_domain'].message_type = _URLDOMAIN -_FEATURE.fields_by_name['time_domain'].message_type = _TIMEDOMAIN -_FEATURE.fields_by_name['time_of_day_domain'].message_type = _TIMEOFDAYDOMAIN -_FEATURE.fields_by_name['distribution_constraints'].message_type = _DISTRIBUTIONCONSTRAINTS -_FEATURE.fields_by_name['annotation'].message_type = _ANNOTATION -_FEATURE.fields_by_name['skew_comparator'].message_type = _FEATURECOMPARATOR -_FEATURE.fields_by_name['drift_comparator'].message_type = _FEATURECOMPARATOR -_FEATURE.fields_by_name['lifecycle_stage'].enum_type = _LIFECYCLESTAGE -_FEATURE.oneofs_by_name['presence_constraints'].fields.append( - _FEATURE.fields_by_name['presence']) -_FEATURE.fields_by_name['presence'].containing_oneof = _FEATURE.oneofs_by_name['presence_constraints'] -_FEATURE.oneofs_by_name['presence_constraints'].fields.append( - _FEATURE.fields_by_name['group_presence']) -_FEATURE.fields_by_name['group_presence'].containing_oneof = _FEATURE.oneofs_by_name['presence_constraints'] -_FEATURE.oneofs_by_name['shape_type'].fields.append( - _FEATURE.fields_by_name['shape']) -_FEATURE.fields_by_name['shape'].containing_oneof = _FEATURE.oneofs_by_name['shape_type'] -_FEATURE.oneofs_by_name['shape_type'].fields.append( - _FEATURE.fields_by_name['value_count']) -_FEATURE.fields_by_name['value_count'].containing_oneof = _FEATURE.oneofs_by_name['shape_type'] -_FEATURE.oneofs_by_name['domain_info'].fields.append( - _FEATURE.fields_by_name['domain']) -_FEATURE.fields_by_name['domain'].containing_oneof = _FEATURE.oneofs_by_name['domain_info'] -_FEATURE.oneofs_by_name['domain_info'].fields.append( - _FEATURE.fields_by_name['int_domain']) -_FEATURE.fields_by_name['int_domain'].containing_oneof = _FEATURE.oneofs_by_name['domain_info'] -_FEATURE.oneofs_by_name['domain_info'].fields.append( - _FEATURE.fields_by_name['float_domain']) 
-_FEATURE.fields_by_name['float_domain'].containing_oneof = _FEATURE.oneofs_by_name['domain_info'] -_FEATURE.oneofs_by_name['domain_info'].fields.append( - _FEATURE.fields_by_name['string_domain']) -_FEATURE.fields_by_name['string_domain'].containing_oneof = _FEATURE.oneofs_by_name['domain_info'] -_FEATURE.oneofs_by_name['domain_info'].fields.append( - _FEATURE.fields_by_name['bool_domain']) -_FEATURE.fields_by_name['bool_domain'].containing_oneof = _FEATURE.oneofs_by_name['domain_info'] -_FEATURE.oneofs_by_name['domain_info'].fields.append( - _FEATURE.fields_by_name['struct_domain']) -_FEATURE.fields_by_name['struct_domain'].containing_oneof = _FEATURE.oneofs_by_name['domain_info'] -_FEATURE.oneofs_by_name['domain_info'].fields.append( - _FEATURE.fields_by_name['natural_language_domain']) -_FEATURE.fields_by_name['natural_language_domain'].containing_oneof = _FEATURE.oneofs_by_name['domain_info'] -_FEATURE.oneofs_by_name['domain_info'].fields.append( - _FEATURE.fields_by_name['image_domain']) -_FEATURE.fields_by_name['image_domain'].containing_oneof = _FEATURE.oneofs_by_name['domain_info'] -_FEATURE.oneofs_by_name['domain_info'].fields.append( - _FEATURE.fields_by_name['mid_domain']) -_FEATURE.fields_by_name['mid_domain'].containing_oneof = _FEATURE.oneofs_by_name['domain_info'] -_FEATURE.oneofs_by_name['domain_info'].fields.append( - _FEATURE.fields_by_name['url_domain']) -_FEATURE.fields_by_name['url_domain'].containing_oneof = _FEATURE.oneofs_by_name['domain_info'] -_FEATURE.oneofs_by_name['domain_info'].fields.append( - _FEATURE.fields_by_name['time_domain']) -_FEATURE.fields_by_name['time_domain'].containing_oneof = _FEATURE.oneofs_by_name['domain_info'] -_FEATURE.oneofs_by_name['domain_info'].fields.append( - _FEATURE.fields_by_name['time_of_day_domain']) -_FEATURE.fields_by_name['time_of_day_domain'].containing_oneof = _FEATURE.oneofs_by_name['domain_info'] -_ANNOTATION.fields_by_name['extra_metadata'].message_type = google_dot_protobuf_dot_any__pb2._ANY -_DATASETCONSTRAINTS.fields_by_name['num_examples_drift_comparator'].message_type = _NUMERICVALUECOMPARATOR -_DATASETCONSTRAINTS.fields_by_name['num_examples_version_comparator'].message_type = _NUMERICVALUECOMPARATOR -_FIXEDSHAPE_DIM.containing_type = _FIXEDSHAPE -_FIXEDSHAPE.fields_by_name['dim'].message_type = _FIXEDSHAPE_DIM -_WEIGHTEDFEATURE.fields_by_name['feature'].message_type = tensorflow__metadata_dot_proto_dot_v0_dot_path__pb2._PATH -_WEIGHTEDFEATURE.fields_by_name['weight_feature'].message_type = tensorflow__metadata_dot_proto_dot_v0_dot_path__pb2._PATH -_WEIGHTEDFEATURE.fields_by_name['lifecycle_stage'].enum_type = _LIFECYCLESTAGE -_SPARSEFEATURE_INDEXFEATURE.containing_type = _SPARSEFEATURE -_SPARSEFEATURE_VALUEFEATURE.containing_type = _SPARSEFEATURE -_SPARSEFEATURE.fields_by_name['lifecycle_stage'].enum_type = _LIFECYCLESTAGE -_SPARSEFEATURE.fields_by_name['presence'].message_type = _FEATUREPRESENCE -_SPARSEFEATURE.fields_by_name['dense_shape'].message_type = _FIXEDSHAPE -_SPARSEFEATURE.fields_by_name['index_feature'].message_type = _SPARSEFEATURE_INDEXFEATURE -_SPARSEFEATURE.fields_by_name['value_feature'].message_type = _SPARSEFEATURE_VALUEFEATURE -_SPARSEFEATURE.fields_by_name['type'].enum_type = _FEATURETYPE -_STRUCTDOMAIN.fields_by_name['feature'].message_type = _FEATURE -_STRUCTDOMAIN.fields_by_name['sparse_feature'].message_type = _SPARSEFEATURE -_TIMEDOMAIN.fields_by_name['integer_format'].enum_type = _TIMEDOMAIN_INTEGERTIMEFORMAT -_TIMEDOMAIN_INTEGERTIMEFORMAT.containing_type = _TIMEDOMAIN 
-_TIMEDOMAIN.oneofs_by_name['format'].fields.append( - _TIMEDOMAIN.fields_by_name['string_format']) -_TIMEDOMAIN.fields_by_name['string_format'].containing_oneof = _TIMEDOMAIN.oneofs_by_name['format'] -_TIMEDOMAIN.oneofs_by_name['format'].fields.append( - _TIMEDOMAIN.fields_by_name['integer_format']) -_TIMEDOMAIN.fields_by_name['integer_format'].containing_oneof = _TIMEDOMAIN.oneofs_by_name['format'] -_TIMEOFDAYDOMAIN.fields_by_name['integer_format'].enum_type = _TIMEOFDAYDOMAIN_INTEGERTIMEOFDAYFORMAT -_TIMEOFDAYDOMAIN_INTEGERTIMEOFDAYFORMAT.containing_type = _TIMEOFDAYDOMAIN -_TIMEOFDAYDOMAIN.oneofs_by_name['format'].fields.append( - _TIMEOFDAYDOMAIN.fields_by_name['string_format']) -_TIMEOFDAYDOMAIN.fields_by_name['string_format'].containing_oneof = _TIMEOFDAYDOMAIN.oneofs_by_name['format'] -_TIMEOFDAYDOMAIN.oneofs_by_name['format'].fields.append( - _TIMEOFDAYDOMAIN.fields_by_name['integer_format']) -_TIMEOFDAYDOMAIN.fields_by_name['integer_format'].containing_oneof = _TIMEOFDAYDOMAIN.oneofs_by_name['format'] -_FEATURECOMPARATOR.fields_by_name['infinity_norm'].message_type = _INFINITYNORM -_TENSORREPRESENTATION_DEFAULTVALUE.containing_type = _TENSORREPRESENTATION -_TENSORREPRESENTATION_DEFAULTVALUE.oneofs_by_name['kind'].fields.append( - _TENSORREPRESENTATION_DEFAULTVALUE.fields_by_name['float_value']) -_TENSORREPRESENTATION_DEFAULTVALUE.fields_by_name['float_value'].containing_oneof = _TENSORREPRESENTATION_DEFAULTVALUE.oneofs_by_name['kind'] -_TENSORREPRESENTATION_DEFAULTVALUE.oneofs_by_name['kind'].fields.append( - _TENSORREPRESENTATION_DEFAULTVALUE.fields_by_name['int_value']) -_TENSORREPRESENTATION_DEFAULTVALUE.fields_by_name['int_value'].containing_oneof = _TENSORREPRESENTATION_DEFAULTVALUE.oneofs_by_name['kind'] -_TENSORREPRESENTATION_DEFAULTVALUE.oneofs_by_name['kind'].fields.append( - _TENSORREPRESENTATION_DEFAULTVALUE.fields_by_name['bytes_value']) -_TENSORREPRESENTATION_DEFAULTVALUE.fields_by_name['bytes_value'].containing_oneof = _TENSORREPRESENTATION_DEFAULTVALUE.oneofs_by_name['kind'] -_TENSORREPRESENTATION_DEFAULTVALUE.oneofs_by_name['kind'].fields.append( - _TENSORREPRESENTATION_DEFAULTVALUE.fields_by_name['uint_value']) -_TENSORREPRESENTATION_DEFAULTVALUE.fields_by_name['uint_value'].containing_oneof = _TENSORREPRESENTATION_DEFAULTVALUE.oneofs_by_name['kind'] -_TENSORREPRESENTATION_DENSETENSOR.fields_by_name['shape'].message_type = _FIXEDSHAPE -_TENSORREPRESENTATION_DENSETENSOR.fields_by_name['default_value'].message_type = _TENSORREPRESENTATION_DEFAULTVALUE -_TENSORREPRESENTATION_DENSETENSOR.containing_type = _TENSORREPRESENTATION -_TENSORREPRESENTATION_VARLENSPARSETENSOR.containing_type = _TENSORREPRESENTATION -_TENSORREPRESENTATION_SPARSETENSOR.fields_by_name['dense_shape'].message_type = _FIXEDSHAPE -_TENSORREPRESENTATION_SPARSETENSOR.containing_type = _TENSORREPRESENTATION -_TENSORREPRESENTATION.fields_by_name['dense_tensor'].message_type = _TENSORREPRESENTATION_DENSETENSOR -_TENSORREPRESENTATION.fields_by_name['varlen_sparse_tensor'].message_type = _TENSORREPRESENTATION_VARLENSPARSETENSOR -_TENSORREPRESENTATION.fields_by_name['sparse_tensor'].message_type = _TENSORREPRESENTATION_SPARSETENSOR -_TENSORREPRESENTATION.oneofs_by_name['kind'].fields.append( - _TENSORREPRESENTATION.fields_by_name['dense_tensor']) -_TENSORREPRESENTATION.fields_by_name['dense_tensor'].containing_oneof = _TENSORREPRESENTATION.oneofs_by_name['kind'] -_TENSORREPRESENTATION.oneofs_by_name['kind'].fields.append( - _TENSORREPRESENTATION.fields_by_name['varlen_sparse_tensor']) 
-_TENSORREPRESENTATION.fields_by_name['varlen_sparse_tensor'].containing_oneof = _TENSORREPRESENTATION.oneofs_by_name['kind'] -_TENSORREPRESENTATION.oneofs_by_name['kind'].fields.append( - _TENSORREPRESENTATION.fields_by_name['sparse_tensor']) -_TENSORREPRESENTATION.fields_by_name['sparse_tensor'].containing_oneof = _TENSORREPRESENTATION.oneofs_by_name['kind'] -_TENSORREPRESENTATIONGROUP_TENSORREPRESENTATIONENTRY.fields_by_name['value'].message_type = _TENSORREPRESENTATION -_TENSORREPRESENTATIONGROUP_TENSORREPRESENTATIONENTRY.containing_type = _TENSORREPRESENTATIONGROUP -_TENSORREPRESENTATIONGROUP.fields_by_name['tensor_representation'].message_type = _TENSORREPRESENTATIONGROUP_TENSORREPRESENTATIONENTRY -DESCRIPTOR.message_types_by_name['Schema'] = _SCHEMA -DESCRIPTOR.message_types_by_name['Feature'] = _FEATURE -DESCRIPTOR.message_types_by_name['Annotation'] = _ANNOTATION -DESCRIPTOR.message_types_by_name['NumericValueComparator'] = _NUMERICVALUECOMPARATOR -DESCRIPTOR.message_types_by_name['DatasetConstraints'] = _DATASETCONSTRAINTS -DESCRIPTOR.message_types_by_name['FixedShape'] = _FIXEDSHAPE -DESCRIPTOR.message_types_by_name['ValueCount'] = _VALUECOUNT -DESCRIPTOR.message_types_by_name['WeightedFeature'] = _WEIGHTEDFEATURE -DESCRIPTOR.message_types_by_name['SparseFeature'] = _SPARSEFEATURE -DESCRIPTOR.message_types_by_name['DistributionConstraints'] = _DISTRIBUTIONCONSTRAINTS -DESCRIPTOR.message_types_by_name['IntDomain'] = _INTDOMAIN -DESCRIPTOR.message_types_by_name['FloatDomain'] = _FLOATDOMAIN -DESCRIPTOR.message_types_by_name['StructDomain'] = _STRUCTDOMAIN -DESCRIPTOR.message_types_by_name['StringDomain'] = _STRINGDOMAIN -DESCRIPTOR.message_types_by_name['BoolDomain'] = _BOOLDOMAIN -DESCRIPTOR.message_types_by_name['NaturalLanguageDomain'] = _NATURALLANGUAGEDOMAIN -DESCRIPTOR.message_types_by_name['ImageDomain'] = _IMAGEDOMAIN -DESCRIPTOR.message_types_by_name['MIDDomain'] = _MIDDOMAIN -DESCRIPTOR.message_types_by_name['URLDomain'] = _URLDOMAIN -DESCRIPTOR.message_types_by_name['TimeDomain'] = _TIMEDOMAIN -DESCRIPTOR.message_types_by_name['TimeOfDayDomain'] = _TIMEOFDAYDOMAIN -DESCRIPTOR.message_types_by_name['FeaturePresence'] = _FEATUREPRESENCE -DESCRIPTOR.message_types_by_name['FeaturePresenceWithinGroup'] = _FEATUREPRESENCEWITHINGROUP -DESCRIPTOR.message_types_by_name['InfinityNorm'] = _INFINITYNORM -DESCRIPTOR.message_types_by_name['FeatureComparator'] = _FEATURECOMPARATOR -DESCRIPTOR.message_types_by_name['TensorRepresentation'] = _TENSORREPRESENTATION -DESCRIPTOR.message_types_by_name['TensorRepresentationGroup'] = _TENSORREPRESENTATIONGROUP -DESCRIPTOR.enum_types_by_name['LifecycleStage'] = _LIFECYCLESTAGE -DESCRIPTOR.enum_types_by_name['FeatureType'] = _FEATURETYPE -_sym_db.RegisterFileDescriptor(DESCRIPTOR) - -Schema = _reflection.GeneratedProtocolMessageType('Schema', (_message.Message,), { - - 'TensorRepresentationGroupEntry' : _reflection.GeneratedProtocolMessageType('TensorRepresentationGroupEntry', (_message.Message,), { - 'DESCRIPTOR' : _SCHEMA_TENSORREPRESENTATIONGROUPENTRY, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.Schema.TensorRepresentationGroupEntry) - }) - , - 'DESCRIPTOR' : _SCHEMA, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.Schema) - }) -_sym_db.RegisterMessage(Schema) -_sym_db.RegisterMessage(Schema.TensorRepresentationGroupEntry) - -Feature = _reflection.GeneratedProtocolMessageType('Feature', 
(_message.Message,), { - 'DESCRIPTOR' : _FEATURE, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.Feature) - }) -_sym_db.RegisterMessage(Feature) - -Annotation = _reflection.GeneratedProtocolMessageType('Annotation', (_message.Message,), { - 'DESCRIPTOR' : _ANNOTATION, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.Annotation) - }) -_sym_db.RegisterMessage(Annotation) - -NumericValueComparator = _reflection.GeneratedProtocolMessageType('NumericValueComparator', (_message.Message,), { - 'DESCRIPTOR' : _NUMERICVALUECOMPARATOR, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.NumericValueComparator) - }) -_sym_db.RegisterMessage(NumericValueComparator) - -DatasetConstraints = _reflection.GeneratedProtocolMessageType('DatasetConstraints', (_message.Message,), { - 'DESCRIPTOR' : _DATASETCONSTRAINTS, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.DatasetConstraints) - }) -_sym_db.RegisterMessage(DatasetConstraints) - -FixedShape = _reflection.GeneratedProtocolMessageType('FixedShape', (_message.Message,), { - - 'Dim' : _reflection.GeneratedProtocolMessageType('Dim', (_message.Message,), { - 'DESCRIPTOR' : _FIXEDSHAPE_DIM, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.FixedShape.Dim) - }) - , - 'DESCRIPTOR' : _FIXEDSHAPE, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.FixedShape) - }) -_sym_db.RegisterMessage(FixedShape) -_sym_db.RegisterMessage(FixedShape.Dim) - -ValueCount = _reflection.GeneratedProtocolMessageType('ValueCount', (_message.Message,), { - 'DESCRIPTOR' : _VALUECOUNT, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.ValueCount) - }) -_sym_db.RegisterMessage(ValueCount) - -WeightedFeature = _reflection.GeneratedProtocolMessageType('WeightedFeature', (_message.Message,), { - 'DESCRIPTOR' : _WEIGHTEDFEATURE, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.WeightedFeature) - }) -_sym_db.RegisterMessage(WeightedFeature) - -SparseFeature = _reflection.GeneratedProtocolMessageType('SparseFeature', (_message.Message,), { - - 'IndexFeature' : _reflection.GeneratedProtocolMessageType('IndexFeature', (_message.Message,), { - 'DESCRIPTOR' : _SPARSEFEATURE_INDEXFEATURE, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.SparseFeature.IndexFeature) - }) - , - - 'ValueFeature' : _reflection.GeneratedProtocolMessageType('ValueFeature', (_message.Message,), { - 'DESCRIPTOR' : _SPARSEFEATURE_VALUEFEATURE, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.SparseFeature.ValueFeature) - }) - , - 'DESCRIPTOR' : _SPARSEFEATURE, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.SparseFeature) - }) -_sym_db.RegisterMessage(SparseFeature) -_sym_db.RegisterMessage(SparseFeature.IndexFeature) -_sym_db.RegisterMessage(SparseFeature.ValueFeature) - -DistributionConstraints = 
_reflection.GeneratedProtocolMessageType('DistributionConstraints', (_message.Message,), { - 'DESCRIPTOR' : _DISTRIBUTIONCONSTRAINTS, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.DistributionConstraints) - }) -_sym_db.RegisterMessage(DistributionConstraints) - -IntDomain = _reflection.GeneratedProtocolMessageType('IntDomain', (_message.Message,), { - 'DESCRIPTOR' : _INTDOMAIN, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.IntDomain) - }) -_sym_db.RegisterMessage(IntDomain) - -FloatDomain = _reflection.GeneratedProtocolMessageType('FloatDomain', (_message.Message,), { - 'DESCRIPTOR' : _FLOATDOMAIN, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.FloatDomain) - }) -_sym_db.RegisterMessage(FloatDomain) - -StructDomain = _reflection.GeneratedProtocolMessageType('StructDomain', (_message.Message,), { - 'DESCRIPTOR' : _STRUCTDOMAIN, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.StructDomain) - }) -_sym_db.RegisterMessage(StructDomain) - -StringDomain = _reflection.GeneratedProtocolMessageType('StringDomain', (_message.Message,), { - 'DESCRIPTOR' : _STRINGDOMAIN, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.StringDomain) - }) -_sym_db.RegisterMessage(StringDomain) - -BoolDomain = _reflection.GeneratedProtocolMessageType('BoolDomain', (_message.Message,), { - 'DESCRIPTOR' : _BOOLDOMAIN, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.BoolDomain) - }) -_sym_db.RegisterMessage(BoolDomain) - -NaturalLanguageDomain = _reflection.GeneratedProtocolMessageType('NaturalLanguageDomain', (_message.Message,), { - 'DESCRIPTOR' : _NATURALLANGUAGEDOMAIN, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.NaturalLanguageDomain) - }) -_sym_db.RegisterMessage(NaturalLanguageDomain) - -ImageDomain = _reflection.GeneratedProtocolMessageType('ImageDomain', (_message.Message,), { - 'DESCRIPTOR' : _IMAGEDOMAIN, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.ImageDomain) - }) -_sym_db.RegisterMessage(ImageDomain) - -MIDDomain = _reflection.GeneratedProtocolMessageType('MIDDomain', (_message.Message,), { - 'DESCRIPTOR' : _MIDDOMAIN, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.MIDDomain) - }) -_sym_db.RegisterMessage(MIDDomain) - -URLDomain = _reflection.GeneratedProtocolMessageType('URLDomain', (_message.Message,), { - 'DESCRIPTOR' : _URLDOMAIN, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.URLDomain) - }) -_sym_db.RegisterMessage(URLDomain) - -TimeDomain = _reflection.GeneratedProtocolMessageType('TimeDomain', (_message.Message,), { - 'DESCRIPTOR' : _TIMEDOMAIN, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.TimeDomain) - }) -_sym_db.RegisterMessage(TimeDomain) - -TimeOfDayDomain = _reflection.GeneratedProtocolMessageType('TimeOfDayDomain', (_message.Message,), { - 'DESCRIPTOR' : 
_TIMEOFDAYDOMAIN, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.TimeOfDayDomain) - }) -_sym_db.RegisterMessage(TimeOfDayDomain) - -FeaturePresence = _reflection.GeneratedProtocolMessageType('FeaturePresence', (_message.Message,), { - 'DESCRIPTOR' : _FEATUREPRESENCE, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.FeaturePresence) - }) -_sym_db.RegisterMessage(FeaturePresence) - -FeaturePresenceWithinGroup = _reflection.GeneratedProtocolMessageType('FeaturePresenceWithinGroup', (_message.Message,), { - 'DESCRIPTOR' : _FEATUREPRESENCEWITHINGROUP, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.FeaturePresenceWithinGroup) - }) -_sym_db.RegisterMessage(FeaturePresenceWithinGroup) - -InfinityNorm = _reflection.GeneratedProtocolMessageType('InfinityNorm', (_message.Message,), { - 'DESCRIPTOR' : _INFINITYNORM, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.InfinityNorm) - }) -_sym_db.RegisterMessage(InfinityNorm) - -FeatureComparator = _reflection.GeneratedProtocolMessageType('FeatureComparator', (_message.Message,), { - 'DESCRIPTOR' : _FEATURECOMPARATOR, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.FeatureComparator) - }) -_sym_db.RegisterMessage(FeatureComparator) - -TensorRepresentation = _reflection.GeneratedProtocolMessageType('TensorRepresentation', (_message.Message,), { - - 'DefaultValue' : _reflection.GeneratedProtocolMessageType('DefaultValue', (_message.Message,), { - 'DESCRIPTOR' : _TENSORREPRESENTATION_DEFAULTVALUE, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.TensorRepresentation.DefaultValue) - }) - , - - 'DenseTensor' : _reflection.GeneratedProtocolMessageType('DenseTensor', (_message.Message,), { - 'DESCRIPTOR' : _TENSORREPRESENTATION_DENSETENSOR, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.TensorRepresentation.DenseTensor) - }) - , - - 'VarLenSparseTensor' : _reflection.GeneratedProtocolMessageType('VarLenSparseTensor', (_message.Message,), { - 'DESCRIPTOR' : _TENSORREPRESENTATION_VARLENSPARSETENSOR, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.TensorRepresentation.VarLenSparseTensor) - }) - , - - 'SparseTensor' : _reflection.GeneratedProtocolMessageType('SparseTensor', (_message.Message,), { - 'DESCRIPTOR' : _TENSORREPRESENTATION_SPARSETENSOR, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.TensorRepresentation.SparseTensor) - }) - , - 'DESCRIPTOR' : _TENSORREPRESENTATION, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.TensorRepresentation) - }) -_sym_db.RegisterMessage(TensorRepresentation) -_sym_db.RegisterMessage(TensorRepresentation.DefaultValue) -_sym_db.RegisterMessage(TensorRepresentation.DenseTensor) -_sym_db.RegisterMessage(TensorRepresentation.VarLenSparseTensor) -_sym_db.RegisterMessage(TensorRepresentation.SparseTensor) - -TensorRepresentationGroup = 
_reflection.GeneratedProtocolMessageType('TensorRepresentationGroup', (_message.Message,), { - - 'TensorRepresentationEntry' : _reflection.GeneratedProtocolMessageType('TensorRepresentationEntry', (_message.Message,), { - 'DESCRIPTOR' : _TENSORREPRESENTATIONGROUP_TENSORREPRESENTATIONENTRY, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.TensorRepresentationGroup.TensorRepresentationEntry) - }) - , - 'DESCRIPTOR' : _TENSORREPRESENTATIONGROUP, - '__module__' : 'tensorflow_metadata.proto.v0.schema_pb2' - # @@protoc_insertion_point(class_scope:tensorflow.metadata.v0.TensorRepresentationGroup) - }) -_sym_db.RegisterMessage(TensorRepresentationGroup) -_sym_db.RegisterMessage(TensorRepresentationGroup.TensorRepresentationEntry) - - -DESCRIPTOR._options = None -_SCHEMA_TENSORREPRESENTATIONGROUPENTRY._options = None -_FEATURE.fields_by_name['deprecated']._options = None -_SPARSEFEATURE.fields_by_name['deprecated']._options = None -_SPARSEFEATURE.fields_by_name['presence']._options = None -_SPARSEFEATURE.fields_by_name['type']._options = None -_TENSORREPRESENTATIONGROUP_TENSORREPRESENTATIONENTRY._options = None -# @@protoc_insertion_point(module_scope) diff --git a/sdk/python/tensorflow_metadata/proto/v0/schema_pb2.pyi b/sdk/python/tensorflow_metadata/proto/v0/schema_pb2.pyi deleted file mode 100644 index d684e28c0c2..00000000000 --- a/sdk/python/tensorflow_metadata/proto/v0/schema_pb2.pyi +++ /dev/null @@ -1,1063 +0,0 @@ -# @generated by generate_proto_mypy_stubs.py. Do not edit! -import sys -from google.protobuf.any_pb2 import ( - Any as google___protobuf___any_pb2___Any, -) - -from google.protobuf.descriptor import ( - Descriptor as google___protobuf___descriptor___Descriptor, - EnumDescriptor as google___protobuf___descriptor___EnumDescriptor, -) - -from google.protobuf.internal.containers import ( - RepeatedCompositeFieldContainer as google___protobuf___internal___containers___RepeatedCompositeFieldContainer, - RepeatedScalarFieldContainer as google___protobuf___internal___containers___RepeatedScalarFieldContainer, -) - -from google.protobuf.message import ( - Message as google___protobuf___message___Message, -) - -from tensorflow_metadata.proto.v0.path_pb2 import ( - Path as tensorflow_metadata___proto___v0___path_pb2___Path, -) - -from typing import ( - Iterable as typing___Iterable, - List as typing___List, - Mapping as typing___Mapping, - MutableMapping as typing___MutableMapping, - Optional as typing___Optional, - Text as typing___Text, - Tuple as typing___Tuple, - Union as typing___Union, - cast as typing___cast, - overload as typing___overload, -) - -from typing_extensions import ( - Literal as typing_extensions___Literal, -) - - -builtin___bool = bool -builtin___bytes = bytes -builtin___float = float -builtin___int = int -builtin___str = str -if sys.version_info < (3,): - builtin___buffer = buffer - builtin___unicode = unicode - - -class LifecycleStage(builtin___int): - DESCRIPTOR: google___protobuf___descriptor___EnumDescriptor = ... - @classmethod - def Name(cls, number: builtin___int) -> builtin___str: ... - @classmethod - def Value(cls, name: builtin___str) -> 'LifecycleStage': ... - @classmethod - def keys(cls) -> typing___List[builtin___str]: ... - @classmethod - def values(cls) -> typing___List['LifecycleStage']: ... - @classmethod - def items(cls) -> typing___List[typing___Tuple[builtin___str, 'LifecycleStage']]: ... 
- UNKNOWN_STAGE = typing___cast('LifecycleStage', 0) - PLANNED = typing___cast('LifecycleStage', 1) - ALPHA = typing___cast('LifecycleStage', 2) - BETA = typing___cast('LifecycleStage', 3) - PRODUCTION = typing___cast('LifecycleStage', 4) - DEPRECATED = typing___cast('LifecycleStage', 5) - DEBUG_ONLY = typing___cast('LifecycleStage', 6) -UNKNOWN_STAGE = typing___cast('LifecycleStage', 0) -PLANNED = typing___cast('LifecycleStage', 1) -ALPHA = typing___cast('LifecycleStage', 2) -BETA = typing___cast('LifecycleStage', 3) -PRODUCTION = typing___cast('LifecycleStage', 4) -DEPRECATED = typing___cast('LifecycleStage', 5) -DEBUG_ONLY = typing___cast('LifecycleStage', 6) - -class FeatureType(builtin___int): - DESCRIPTOR: google___protobuf___descriptor___EnumDescriptor = ... - @classmethod - def Name(cls, number: builtin___int) -> builtin___str: ... - @classmethod - def Value(cls, name: builtin___str) -> 'FeatureType': ... - @classmethod - def keys(cls) -> typing___List[builtin___str]: ... - @classmethod - def values(cls) -> typing___List['FeatureType']: ... - @classmethod - def items(cls) -> typing___List[typing___Tuple[builtin___str, 'FeatureType']]: ... - TYPE_UNKNOWN = typing___cast('FeatureType', 0) - BYTES = typing___cast('FeatureType', 1) - INT = typing___cast('FeatureType', 2) - FLOAT = typing___cast('FeatureType', 3) - STRUCT = typing___cast('FeatureType', 4) -TYPE_UNKNOWN = typing___cast('FeatureType', 0) -BYTES = typing___cast('FeatureType', 1) -INT = typing___cast('FeatureType', 2) -FLOAT = typing___cast('FeatureType', 3) -STRUCT = typing___cast('FeatureType', 4) - -class Schema(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - class TensorRepresentationGroupEntry(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - key = ... # type: typing___Text - - @property - def value(self) -> TensorRepresentationGroup: ... - - def __init__(self, - *, - key : typing___Optional[typing___Text] = None, - value : typing___Optional[TensorRepresentationGroup] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> Schema.TensorRepresentationGroupEntry: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> Schema.TensorRepresentationGroupEntry: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def HasField(self, field_name: typing_extensions___Literal[u"key",b"key",u"value",b"value"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"key",b"key",u"value",b"value"]) -> None: ... - - default_environment = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[typing___Text] - - @property - def feature(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[Feature]: ... - - @property - def sparse_feature(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[SparseFeature]: ... - - @property - def weighted_feature(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[WeightedFeature]: ... - - @property - def string_domain(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[StringDomain]: ... 
- - @property - def float_domain(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[FloatDomain]: ... - - @property - def int_domain(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[IntDomain]: ... - - @property - def annotation(self) -> Annotation: ... - - @property - def dataset_constraints(self) -> DatasetConstraints: ... - - @property - def tensor_representation_group(self) -> typing___MutableMapping[typing___Text, TensorRepresentationGroup]: ... - - def __init__(self, - *, - feature : typing___Optional[typing___Iterable[Feature]] = None, - sparse_feature : typing___Optional[typing___Iterable[SparseFeature]] = None, - weighted_feature : typing___Optional[typing___Iterable[WeightedFeature]] = None, - string_domain : typing___Optional[typing___Iterable[StringDomain]] = None, - float_domain : typing___Optional[typing___Iterable[FloatDomain]] = None, - int_domain : typing___Optional[typing___Iterable[IntDomain]] = None, - default_environment : typing___Optional[typing___Iterable[typing___Text]] = None, - annotation : typing___Optional[Annotation] = None, - dataset_constraints : typing___Optional[DatasetConstraints] = None, - tensor_representation_group : typing___Optional[typing___Mapping[typing___Text, TensorRepresentationGroup]] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> Schema: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> Schema: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def HasField(self, field_name: typing_extensions___Literal[u"annotation",b"annotation",u"dataset_constraints",b"dataset_constraints"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"annotation",b"annotation",u"dataset_constraints",b"dataset_constraints",u"default_environment",b"default_environment",u"feature",b"feature",u"float_domain",b"float_domain",u"int_domain",b"int_domain",u"sparse_feature",b"sparse_feature",u"string_domain",b"string_domain",u"tensor_representation_group",b"tensor_representation_group",u"weighted_feature",b"weighted_feature"]) -> None: ... - -class Feature(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - name = ... # type: typing___Text - deprecated = ... # type: builtin___bool - type = ... # type: FeatureType - domain = ... # type: typing___Text - in_environment = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[typing___Text] - not_in_environment = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[typing___Text] - lifecycle_stage = ... # type: LifecycleStage - - @property - def presence(self) -> FeaturePresence: ... - - @property - def group_presence(self) -> FeaturePresenceWithinGroup: ... - - @property - def shape(self) -> FixedShape: ... - - @property - def value_count(self) -> ValueCount: ... - - @property - def int_domain(self) -> IntDomain: ... - - @property - def float_domain(self) -> FloatDomain: ... - - @property - def string_domain(self) -> StringDomain: ... - - @property - def bool_domain(self) -> BoolDomain: ... - - @property - def struct_domain(self) -> StructDomain: ... - - @property - def natural_language_domain(self) -> NaturalLanguageDomain: ... 
- - @property - def image_domain(self) -> ImageDomain: ... - - @property - def mid_domain(self) -> MIDDomain: ... - - @property - def url_domain(self) -> URLDomain: ... - - @property - def time_domain(self) -> TimeDomain: ... - - @property - def time_of_day_domain(self) -> TimeOfDayDomain: ... - - @property - def distribution_constraints(self) -> DistributionConstraints: ... - - @property - def annotation(self) -> Annotation: ... - - @property - def skew_comparator(self) -> FeatureComparator: ... - - @property - def drift_comparator(self) -> FeatureComparator: ... - - def __init__(self, - *, - name : typing___Optional[typing___Text] = None, - deprecated : typing___Optional[builtin___bool] = None, - presence : typing___Optional[FeaturePresence] = None, - group_presence : typing___Optional[FeaturePresenceWithinGroup] = None, - shape : typing___Optional[FixedShape] = None, - value_count : typing___Optional[ValueCount] = None, - type : typing___Optional[FeatureType] = None, - domain : typing___Optional[typing___Text] = None, - int_domain : typing___Optional[IntDomain] = None, - float_domain : typing___Optional[FloatDomain] = None, - string_domain : typing___Optional[StringDomain] = None, - bool_domain : typing___Optional[BoolDomain] = None, - struct_domain : typing___Optional[StructDomain] = None, - natural_language_domain : typing___Optional[NaturalLanguageDomain] = None, - image_domain : typing___Optional[ImageDomain] = None, - mid_domain : typing___Optional[MIDDomain] = None, - url_domain : typing___Optional[URLDomain] = None, - time_domain : typing___Optional[TimeDomain] = None, - time_of_day_domain : typing___Optional[TimeOfDayDomain] = None, - distribution_constraints : typing___Optional[DistributionConstraints] = None, - annotation : typing___Optional[Annotation] = None, - skew_comparator : typing___Optional[FeatureComparator] = None, - drift_comparator : typing___Optional[FeatureComparator] = None, - in_environment : typing___Optional[typing___Iterable[typing___Text]] = None, - not_in_environment : typing___Optional[typing___Iterable[typing___Text]] = None, - lifecycle_stage : typing___Optional[LifecycleStage] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> Feature: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> Feature: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def HasField(self, field_name: typing_extensions___Literal[u"annotation",b"annotation",u"bool_domain",b"bool_domain",u"deprecated",b"deprecated",u"distribution_constraints",b"distribution_constraints",u"domain",b"domain",u"domain_info",b"domain_info",u"drift_comparator",b"drift_comparator",u"float_domain",b"float_domain",u"group_presence",b"group_presence",u"image_domain",b"image_domain",u"int_domain",b"int_domain",u"lifecycle_stage",b"lifecycle_stage",u"mid_domain",b"mid_domain",u"name",b"name",u"natural_language_domain",b"natural_language_domain",u"presence",b"presence",u"presence_constraints",b"presence_constraints",u"shape",b"shape",u"shape_type",b"shape_type",u"skew_comparator",b"skew_comparator",u"string_domain",b"string_domain",u"struct_domain",b"struct_domain",u"time_domain",b"time_domain",u"time_of_day_domain",b"time_of_day_domain",u"type",b"type",u"url_domain",b"url_domain",u"value_count",b"value_count"]) -> builtin___bool: ... 
- def ClearField(self, field_name: typing_extensions___Literal[u"annotation",b"annotation",u"bool_domain",b"bool_domain",u"deprecated",b"deprecated",u"distribution_constraints",b"distribution_constraints",u"domain",b"domain",u"domain_info",b"domain_info",u"drift_comparator",b"drift_comparator",u"float_domain",b"float_domain",u"group_presence",b"group_presence",u"image_domain",b"image_domain",u"in_environment",b"in_environment",u"int_domain",b"int_domain",u"lifecycle_stage",b"lifecycle_stage",u"mid_domain",b"mid_domain",u"name",b"name",u"natural_language_domain",b"natural_language_domain",u"not_in_environment",b"not_in_environment",u"presence",b"presence",u"presence_constraints",b"presence_constraints",u"shape",b"shape",u"shape_type",b"shape_type",u"skew_comparator",b"skew_comparator",u"string_domain",b"string_domain",u"struct_domain",b"struct_domain",u"time_domain",b"time_domain",u"time_of_day_domain",b"time_of_day_domain",u"type",b"type",u"url_domain",b"url_domain",u"value_count",b"value_count"]) -> None: ... - @typing___overload - def WhichOneof(self, oneof_group: typing_extensions___Literal[u"domain_info",b"domain_info"]) -> typing_extensions___Literal["domain","int_domain","float_domain","string_domain","bool_domain","struct_domain","natural_language_domain","image_domain","mid_domain","url_domain","time_domain","time_of_day_domain"]: ... - @typing___overload - def WhichOneof(self, oneof_group: typing_extensions___Literal[u"presence_constraints",b"presence_constraints"]) -> typing_extensions___Literal["presence","group_presence"]: ... - @typing___overload - def WhichOneof(self, oneof_group: typing_extensions___Literal[u"shape_type",b"shape_type"]) -> typing_extensions___Literal["shape","value_count"]: ... - -class Annotation(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - tag = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[typing___Text] - comment = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[typing___Text] - - @property - def extra_metadata(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[google___protobuf___any_pb2___Any]: ... - - def __init__(self, - *, - tag : typing___Optional[typing___Iterable[typing___Text]] = None, - comment : typing___Optional[typing___Iterable[typing___Text]] = None, - extra_metadata : typing___Optional[typing___Iterable[google___protobuf___any_pb2___Any]] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> Annotation: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> Annotation: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def ClearField(self, field_name: typing_extensions___Literal[u"comment",b"comment",u"extra_metadata",b"extra_metadata",u"tag",b"tag"]) -> None: ... - -class NumericValueComparator(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - min_fraction_threshold = ... # type: builtin___float - max_fraction_threshold = ... # type: builtin___float - - def __init__(self, - *, - min_fraction_threshold : typing___Optional[builtin___float] = None, - max_fraction_threshold : typing___Optional[builtin___float] = None, - ) -> None: ... 
- if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> NumericValueComparator: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> NumericValueComparator: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def HasField(self, field_name: typing_extensions___Literal[u"max_fraction_threshold",b"max_fraction_threshold",u"min_fraction_threshold",b"min_fraction_threshold"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"max_fraction_threshold",b"max_fraction_threshold",u"min_fraction_threshold",b"min_fraction_threshold"]) -> None: ... - -class DatasetConstraints(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - min_examples_count = ... # type: builtin___int - - @property - def num_examples_drift_comparator(self) -> NumericValueComparator: ... - - @property - def num_examples_version_comparator(self) -> NumericValueComparator: ... - - def __init__(self, - *, - num_examples_drift_comparator : typing___Optional[NumericValueComparator] = None, - num_examples_version_comparator : typing___Optional[NumericValueComparator] = None, - min_examples_count : typing___Optional[builtin___int] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> DatasetConstraints: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> DatasetConstraints: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def HasField(self, field_name: typing_extensions___Literal[u"min_examples_count",b"min_examples_count",u"num_examples_drift_comparator",b"num_examples_drift_comparator",u"num_examples_version_comparator",b"num_examples_version_comparator"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"min_examples_count",b"min_examples_count",u"num_examples_drift_comparator",b"num_examples_drift_comparator",u"num_examples_version_comparator",b"num_examples_version_comparator"]) -> None: ... - -class FixedShape(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - class Dim(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - size = ... # type: builtin___int - name = ... # type: typing___Text - - def __init__(self, - *, - size : typing___Optional[builtin___int] = None, - name : typing___Optional[typing___Text] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> FixedShape.Dim: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> FixedShape.Dim: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def HasField(self, field_name: typing_extensions___Literal[u"name",b"name",u"size",b"size"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"name",b"name",u"size",b"size"]) -> None: ... 
- - - @property - def dim(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[FixedShape.Dim]: ... - - def __init__(self, - *, - dim : typing___Optional[typing___Iterable[FixedShape.Dim]] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> FixedShape: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> FixedShape: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def ClearField(self, field_name: typing_extensions___Literal[u"dim",b"dim"]) -> None: ... - -class ValueCount(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - min = ... # type: builtin___int - max = ... # type: builtin___int - - def __init__(self, - *, - min : typing___Optional[builtin___int] = None, - max : typing___Optional[builtin___int] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> ValueCount: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> ValueCount: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def HasField(self, field_name: typing_extensions___Literal[u"max",b"max",u"min",b"min"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"max",b"max",u"min",b"min"]) -> None: ... - -class WeightedFeature(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - name = ... # type: typing___Text - lifecycle_stage = ... # type: LifecycleStage - - @property - def feature(self) -> tensorflow_metadata___proto___v0___path_pb2___Path: ... - - @property - def weight_feature(self) -> tensorflow_metadata___proto___v0___path_pb2___Path: ... - - def __init__(self, - *, - name : typing___Optional[typing___Text] = None, - feature : typing___Optional[tensorflow_metadata___proto___v0___path_pb2___Path] = None, - weight_feature : typing___Optional[tensorflow_metadata___proto___v0___path_pb2___Path] = None, - lifecycle_stage : typing___Optional[LifecycleStage] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> WeightedFeature: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> WeightedFeature: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def HasField(self, field_name: typing_extensions___Literal[u"feature",b"feature",u"lifecycle_stage",b"lifecycle_stage",u"name",b"name",u"weight_feature",b"weight_feature"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"feature",b"feature",u"lifecycle_stage",b"lifecycle_stage",u"name",b"name",u"weight_feature",b"weight_feature"]) -> None: ... - -class SparseFeature(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - class IndexFeature(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - name = ... 
# type: typing___Text - - def __init__(self, - *, - name : typing___Optional[typing___Text] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> SparseFeature.IndexFeature: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> SparseFeature.IndexFeature: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def HasField(self, field_name: typing_extensions___Literal[u"name",b"name"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"name",b"name"]) -> None: ... - - class ValueFeature(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - name = ... # type: typing___Text - - def __init__(self, - *, - name : typing___Optional[typing___Text] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> SparseFeature.ValueFeature: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> SparseFeature.ValueFeature: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def HasField(self, field_name: typing_extensions___Literal[u"name",b"name"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"name",b"name"]) -> None: ... - - name = ... # type: typing___Text - deprecated = ... # type: builtin___bool - lifecycle_stage = ... # type: LifecycleStage - is_sorted = ... # type: builtin___bool - type = ... # type: FeatureType - - @property - def presence(self) -> FeaturePresence: ... - - @property - def dense_shape(self) -> FixedShape: ... - - @property - def index_feature(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[SparseFeature.IndexFeature]: ... - - @property - def value_feature(self) -> SparseFeature.ValueFeature: ... - - def __init__(self, - *, - name : typing___Optional[typing___Text] = None, - deprecated : typing___Optional[builtin___bool] = None, - lifecycle_stage : typing___Optional[LifecycleStage] = None, - presence : typing___Optional[FeaturePresence] = None, - dense_shape : typing___Optional[FixedShape] = None, - index_feature : typing___Optional[typing___Iterable[SparseFeature.IndexFeature]] = None, - is_sorted : typing___Optional[builtin___bool] = None, - value_feature : typing___Optional[SparseFeature.ValueFeature] = None, - type : typing___Optional[FeatureType] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> SparseFeature: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> SparseFeature: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def HasField(self, field_name: typing_extensions___Literal[u"dense_shape",b"dense_shape",u"deprecated",b"deprecated",u"is_sorted",b"is_sorted",u"lifecycle_stage",b"lifecycle_stage",u"name",b"name",u"presence",b"presence",u"type",b"type",u"value_feature",b"value_feature"]) -> builtin___bool: ... 
- def ClearField(self, field_name: typing_extensions___Literal[u"dense_shape",b"dense_shape",u"deprecated",b"deprecated",u"index_feature",b"index_feature",u"is_sorted",b"is_sorted",u"lifecycle_stage",b"lifecycle_stage",u"name",b"name",u"presence",b"presence",u"type",b"type",u"value_feature",b"value_feature"]) -> None: ... - -class DistributionConstraints(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - min_domain_mass = ... # type: builtin___float - - def __init__(self, - *, - min_domain_mass : typing___Optional[builtin___float] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> DistributionConstraints: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> DistributionConstraints: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def HasField(self, field_name: typing_extensions___Literal[u"min_domain_mass",b"min_domain_mass"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"min_domain_mass",b"min_domain_mass"]) -> None: ... - -class IntDomain(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - name = ... # type: typing___Text - min = ... # type: builtin___int - max = ... # type: builtin___int - is_categorical = ... # type: builtin___bool - - def __init__(self, - *, - name : typing___Optional[typing___Text] = None, - min : typing___Optional[builtin___int] = None, - max : typing___Optional[builtin___int] = None, - is_categorical : typing___Optional[builtin___bool] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> IntDomain: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> IntDomain: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def HasField(self, field_name: typing_extensions___Literal[u"is_categorical",b"is_categorical",u"max",b"max",u"min",b"min",u"name",b"name"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"is_categorical",b"is_categorical",u"max",b"max",u"min",b"min",u"name",b"name"]) -> None: ... - -class FloatDomain(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - name = ... # type: typing___Text - min = ... # type: builtin___float - max = ... # type: builtin___float - - def __init__(self, - *, - name : typing___Optional[typing___Text] = None, - min : typing___Optional[builtin___float] = None, - max : typing___Optional[builtin___float] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> FloatDomain: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> FloatDomain: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def HasField(self, field_name: typing_extensions___Literal[u"max",b"max",u"min",b"min",u"name",b"name"]) -> builtin___bool: ... 
- def ClearField(self, field_name: typing_extensions___Literal[u"max",b"max",u"min",b"min",u"name",b"name"]) -> None: ... - -class StructDomain(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - - @property - def feature(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[Feature]: ... - - @property - def sparse_feature(self) -> google___protobuf___internal___containers___RepeatedCompositeFieldContainer[SparseFeature]: ... - - def __init__(self, - *, - feature : typing___Optional[typing___Iterable[Feature]] = None, - sparse_feature : typing___Optional[typing___Iterable[SparseFeature]] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> StructDomain: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> StructDomain: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def ClearField(self, field_name: typing_extensions___Literal[u"feature",b"feature",u"sparse_feature",b"sparse_feature"]) -> None: ... - -class StringDomain(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - name = ... # type: typing___Text - value = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[typing___Text] - - def __init__(self, - *, - name : typing___Optional[typing___Text] = None, - value : typing___Optional[typing___Iterable[typing___Text]] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> StringDomain: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> StringDomain: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def HasField(self, field_name: typing_extensions___Literal[u"name",b"name"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"name",b"name",u"value",b"value"]) -> None: ... - -class BoolDomain(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - name = ... # type: typing___Text - true_value = ... # type: typing___Text - false_value = ... # type: typing___Text - - def __init__(self, - *, - name : typing___Optional[typing___Text] = None, - true_value : typing___Optional[typing___Text] = None, - false_value : typing___Optional[typing___Text] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> BoolDomain: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> BoolDomain: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def HasField(self, field_name: typing_extensions___Literal[u"false_value",b"false_value",u"name",b"name",u"true_value",b"true_value"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"false_value",b"false_value",u"name",b"name",u"true_value",b"true_value"]) -> None: ... 
- -class NaturalLanguageDomain(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - - def __init__(self, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> NaturalLanguageDomain: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> NaturalLanguageDomain: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - -class ImageDomain(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - - def __init__(self, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> ImageDomain: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> ImageDomain: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - -class MIDDomain(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - - def __init__(self, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> MIDDomain: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> MIDDomain: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - -class URLDomain(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - - def __init__(self, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> URLDomain: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> URLDomain: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - -class TimeDomain(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - class IntegerTimeFormat(builtin___int): - DESCRIPTOR: google___protobuf___descriptor___EnumDescriptor = ... - @classmethod - def Name(cls, number: builtin___int) -> builtin___str: ... - @classmethod - def Value(cls, name: builtin___str) -> 'TimeDomain.IntegerTimeFormat': ... - @classmethod - def keys(cls) -> typing___List[builtin___str]: ... - @classmethod - def values(cls) -> typing___List['TimeDomain.IntegerTimeFormat']: ... - @classmethod - def items(cls) -> typing___List[typing___Tuple[builtin___str, 'TimeDomain.IntegerTimeFormat']]: ... 
- FORMAT_UNKNOWN = typing___cast('TimeDomain.IntegerTimeFormat', 0) - UNIX_DAYS = typing___cast('TimeDomain.IntegerTimeFormat', 5) - UNIX_SECONDS = typing___cast('TimeDomain.IntegerTimeFormat', 1) - UNIX_MILLISECONDS = typing___cast('TimeDomain.IntegerTimeFormat', 2) - UNIX_MICROSECONDS = typing___cast('TimeDomain.IntegerTimeFormat', 3) - UNIX_NANOSECONDS = typing___cast('TimeDomain.IntegerTimeFormat', 4) - FORMAT_UNKNOWN = typing___cast('TimeDomain.IntegerTimeFormat', 0) - UNIX_DAYS = typing___cast('TimeDomain.IntegerTimeFormat', 5) - UNIX_SECONDS = typing___cast('TimeDomain.IntegerTimeFormat', 1) - UNIX_MILLISECONDS = typing___cast('TimeDomain.IntegerTimeFormat', 2) - UNIX_MICROSECONDS = typing___cast('TimeDomain.IntegerTimeFormat', 3) - UNIX_NANOSECONDS = typing___cast('TimeDomain.IntegerTimeFormat', 4) - - string_format = ... # type: typing___Text - integer_format = ... # type: TimeDomain.IntegerTimeFormat - - def __init__(self, - *, - string_format : typing___Optional[typing___Text] = None, - integer_format : typing___Optional[TimeDomain.IntegerTimeFormat] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> TimeDomain: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> TimeDomain: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def HasField(self, field_name: typing_extensions___Literal[u"format",b"format",u"integer_format",b"integer_format",u"string_format",b"string_format"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"format",b"format",u"integer_format",b"integer_format",u"string_format",b"string_format"]) -> None: ... - def WhichOneof(self, oneof_group: typing_extensions___Literal[u"format",b"format"]) -> typing_extensions___Literal["string_format","integer_format"]: ... - -class TimeOfDayDomain(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - class IntegerTimeOfDayFormat(builtin___int): - DESCRIPTOR: google___protobuf___descriptor___EnumDescriptor = ... - @classmethod - def Name(cls, number: builtin___int) -> builtin___str: ... - @classmethod - def Value(cls, name: builtin___str) -> 'TimeOfDayDomain.IntegerTimeOfDayFormat': ... - @classmethod - def keys(cls) -> typing___List[builtin___str]: ... - @classmethod - def values(cls) -> typing___List['TimeOfDayDomain.IntegerTimeOfDayFormat']: ... - @classmethod - def items(cls) -> typing___List[typing___Tuple[builtin___str, 'TimeOfDayDomain.IntegerTimeOfDayFormat']]: ... - FORMAT_UNKNOWN = typing___cast('TimeOfDayDomain.IntegerTimeOfDayFormat', 0) - PACKED_64_NANOS = typing___cast('TimeOfDayDomain.IntegerTimeOfDayFormat', 1) - FORMAT_UNKNOWN = typing___cast('TimeOfDayDomain.IntegerTimeOfDayFormat', 0) - PACKED_64_NANOS = typing___cast('TimeOfDayDomain.IntegerTimeOfDayFormat', 1) - - string_format = ... # type: typing___Text - integer_format = ... # type: TimeOfDayDomain.IntegerTimeOfDayFormat - - def __init__(self, - *, - string_format : typing___Optional[typing___Text] = None, - integer_format : typing___Optional[TimeOfDayDomain.IntegerTimeOfDayFormat] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> TimeOfDayDomain: ... 
- else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> TimeOfDayDomain: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def HasField(self, field_name: typing_extensions___Literal[u"format",b"format",u"integer_format",b"integer_format",u"string_format",b"string_format"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"format",b"format",u"integer_format",b"integer_format",u"string_format",b"string_format"]) -> None: ... - def WhichOneof(self, oneof_group: typing_extensions___Literal[u"format",b"format"]) -> typing_extensions___Literal["string_format","integer_format"]: ... - -class FeaturePresence(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - min_fraction = ... # type: builtin___float - min_count = ... # type: builtin___int - - def __init__(self, - *, - min_fraction : typing___Optional[builtin___float] = None, - min_count : typing___Optional[builtin___int] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> FeaturePresence: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> FeaturePresence: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def HasField(self, field_name: typing_extensions___Literal[u"min_count",b"min_count",u"min_fraction",b"min_fraction"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"min_count",b"min_count",u"min_fraction",b"min_fraction"]) -> None: ... - -class FeaturePresenceWithinGroup(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - required = ... # type: builtin___bool - - def __init__(self, - *, - required : typing___Optional[builtin___bool] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> FeaturePresenceWithinGroup: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> FeaturePresenceWithinGroup: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def HasField(self, field_name: typing_extensions___Literal[u"required",b"required"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"required",b"required"]) -> None: ... - -class InfinityNorm(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - threshold = ... # type: builtin___float - - def __init__(self, - *, - threshold : typing___Optional[builtin___float] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> InfinityNorm: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> InfinityNorm: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... 
- def HasField(self, field_name: typing_extensions___Literal[u"threshold",b"threshold"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"threshold",b"threshold"]) -> None: ... - -class FeatureComparator(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - - @property - def infinity_norm(self) -> InfinityNorm: ... - - def __init__(self, - *, - infinity_norm : typing___Optional[InfinityNorm] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> FeatureComparator: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> FeatureComparator: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def HasField(self, field_name: typing_extensions___Literal[u"infinity_norm",b"infinity_norm"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"infinity_norm",b"infinity_norm"]) -> None: ... - -class TensorRepresentation(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - class DefaultValue(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - float_value = ... # type: builtin___float - int_value = ... # type: builtin___int - bytes_value = ... # type: builtin___bytes - uint_value = ... # type: builtin___int - - def __init__(self, - *, - float_value : typing___Optional[builtin___float] = None, - int_value : typing___Optional[builtin___int] = None, - bytes_value : typing___Optional[builtin___bytes] = None, - uint_value : typing___Optional[builtin___int] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> TensorRepresentation.DefaultValue: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> TensorRepresentation.DefaultValue: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def HasField(self, field_name: typing_extensions___Literal[u"bytes_value",b"bytes_value",u"float_value",b"float_value",u"int_value",b"int_value",u"kind",b"kind",u"uint_value",b"uint_value"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"bytes_value",b"bytes_value",u"float_value",b"float_value",u"int_value",b"int_value",u"kind",b"kind",u"uint_value",b"uint_value"]) -> None: ... - def WhichOneof(self, oneof_group: typing_extensions___Literal[u"kind",b"kind"]) -> typing_extensions___Literal["float_value","int_value","bytes_value","uint_value"]: ... - - class DenseTensor(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - column_name = ... # type: typing___Text - - @property - def shape(self) -> FixedShape: ... - - @property - def default_value(self) -> TensorRepresentation.DefaultValue: ... - - def __init__(self, - *, - column_name : typing___Optional[typing___Text] = None, - shape : typing___Optional[FixedShape] = None, - default_value : typing___Optional[TensorRepresentation.DefaultValue] = None, - ) -> None: ... 
- if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> TensorRepresentation.DenseTensor: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> TensorRepresentation.DenseTensor: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def HasField(self, field_name: typing_extensions___Literal[u"column_name",b"column_name",u"default_value",b"default_value",u"shape",b"shape"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"column_name",b"column_name",u"default_value",b"default_value",u"shape",b"shape"]) -> None: ... - - class VarLenSparseTensor(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - column_name = ... # type: typing___Text - - def __init__(self, - *, - column_name : typing___Optional[typing___Text] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> TensorRepresentation.VarLenSparseTensor: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> TensorRepresentation.VarLenSparseTensor: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def HasField(self, field_name: typing_extensions___Literal[u"column_name",b"column_name"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"column_name",b"column_name"]) -> None: ... - - class SparseTensor(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - index_column_names = ... # type: google___protobuf___internal___containers___RepeatedScalarFieldContainer[typing___Text] - value_column_name = ... # type: typing___Text - - @property - def dense_shape(self) -> FixedShape: ... - - def __init__(self, - *, - dense_shape : typing___Optional[FixedShape] = None, - index_column_names : typing___Optional[typing___Iterable[typing___Text]] = None, - value_column_name : typing___Optional[typing___Text] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> TensorRepresentation.SparseTensor: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> TensorRepresentation.SparseTensor: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def HasField(self, field_name: typing_extensions___Literal[u"dense_shape",b"dense_shape",u"value_column_name",b"value_column_name"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"dense_shape",b"dense_shape",u"index_column_names",b"index_column_names",u"value_column_name",b"value_column_name"]) -> None: ... - - - @property - def dense_tensor(self) -> TensorRepresentation.DenseTensor: ... - - @property - def varlen_sparse_tensor(self) -> TensorRepresentation.VarLenSparseTensor: ... - - @property - def sparse_tensor(self) -> TensorRepresentation.SparseTensor: ... 
- - def __init__(self, - *, - dense_tensor : typing___Optional[TensorRepresentation.DenseTensor] = None, - varlen_sparse_tensor : typing___Optional[TensorRepresentation.VarLenSparseTensor] = None, - sparse_tensor : typing___Optional[TensorRepresentation.SparseTensor] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> TensorRepresentation: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> TensorRepresentation: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def HasField(self, field_name: typing_extensions___Literal[u"dense_tensor",b"dense_tensor",u"kind",b"kind",u"sparse_tensor",b"sparse_tensor",u"varlen_sparse_tensor",b"varlen_sparse_tensor"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"dense_tensor",b"dense_tensor",u"kind",b"kind",u"sparse_tensor",b"sparse_tensor",u"varlen_sparse_tensor",b"varlen_sparse_tensor"]) -> None: ... - def WhichOneof(self, oneof_group: typing_extensions___Literal[u"kind",b"kind"]) -> typing_extensions___Literal["dense_tensor","varlen_sparse_tensor","sparse_tensor"]: ... - -class TensorRepresentationGroup(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - class TensorRepresentationEntry(google___protobuf___message___Message): - DESCRIPTOR: google___protobuf___descriptor___Descriptor = ... - key = ... # type: typing___Text - - @property - def value(self) -> TensorRepresentation: ... - - def __init__(self, - *, - key : typing___Optional[typing___Text] = None, - value : typing___Optional[TensorRepresentation] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> TensorRepresentationGroup.TensorRepresentationEntry: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> TensorRepresentationGroup.TensorRepresentationEntry: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def HasField(self, field_name: typing_extensions___Literal[u"key",b"key",u"value",b"value"]) -> builtin___bool: ... - def ClearField(self, field_name: typing_extensions___Literal[u"key",b"key",u"value",b"value"]) -> None: ... - - - @property - def tensor_representation(self) -> typing___MutableMapping[typing___Text, TensorRepresentation]: ... - - def __init__(self, - *, - tensor_representation : typing___Optional[typing___Mapping[typing___Text, TensorRepresentation]] = None, - ) -> None: ... - if sys.version_info >= (3,): - @classmethod - def FromString(cls, s: builtin___bytes) -> TensorRepresentationGroup: ... - else: - @classmethod - def FromString(cls, s: typing___Union[builtin___bytes, builtin___buffer, builtin___unicode]) -> TensorRepresentationGroup: ... - def MergeFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def CopyFrom(self, other_msg: google___protobuf___message___Message) -> None: ... - def ClearField(self, field_name: typing_extensions___Literal[u"tensor_representation",b"tensor_representation"]) -> None: ... 
diff --git a/sdk/python/tests/data/tensorflow_metadata/bikeshare_feature_set.yaml b/sdk/python/tests/data/tensorflow_metadata/bikeshare_feature_set.yaml new file mode 100644 index 00000000000..daa0a35f0ab --- /dev/null +++ b/sdk/python/tests/data/tensorflow_metadata/bikeshare_feature_set.yaml @@ -0,0 +1,81 @@ +spec: + name: bikeshare + entities: + - name: station_id + valueType: INT64 + intDomain: + min: 1 + max: 5000 + presence: + minFraction: 1.0 + minCount: 1 + shape: + dim: + - size: 1 + features: + - name: location + valueType: STRING + stringDomain: + name: location + value: + - (30.24258, -97.71726) + - (30.24472, -97.72336) + - (30.24891, -97.75019) + presence: + minFraction: 1.0 + minCount: 1 + shape: + dim: + - size: 1 + - name: name + valueType: STRING + stringDomain: + name: name + value: + - 10th & Red River + - 11th & Salina + - 11th & San Jacinto + - 13th & San Antonio + - 17th & Guadalupe + presence: + minFraction: 1.0 + minCount: 1 + shape: + dim: + - size: 1 + - name: status + valueType: STRING + stringDomain: + name: status + value: + - "active" + - "closed" + presence: + minFraction: 1.0 + minCount: 1 + shape: + dim: + - size: 1 + - name: latitude + valueType: DOUBLE + floatDomain: + min: 100.0 + max: 105.0 + presence: + minFraction: 1.0 + minCount: 1 + shape: + dim: + - size: 1 + - name: longitude + valueType: DOUBLE + floatDomain: + min: 102.0 + max: 105.0 + presence: + minFraction: 1.0 + minCount: 1 + shape: + dim: + - size: 1 + maxAge: 3600s diff --git a/sdk/python/tests/data/tensorflow_metadata/bikeshare_schema.json b/sdk/python/tests/data/tensorflow_metadata/bikeshare_schema.json new file mode 100644 index 00000000000..e7a886053c1 --- /dev/null +++ b/sdk/python/tests/data/tensorflow_metadata/bikeshare_schema.json @@ -0,0 +1,136 @@ +{ + "feature": [ + { + "name": "location", + "type": "BYTES", + "domain": "location", + "presence": { + "minFraction": 1.0, + "minCount": "1" + }, + "shape": { + "dim": [ + { + "size": "1" + } + ] + } + }, + { + "name": "name", + "type": "BYTES", + "domain": "name", + "presence": { + "minFraction": 1.0, + "minCount": "1" + }, + "shape": { + "dim": [ + { + "size": "1" + } + ] + } + }, + { + "name": "status", + "type": "BYTES", + "domain": "status", + "presence": { + "minFraction": 1.0, + "minCount": "1" + }, + "shape": { + "dim": [ + { + "size": "1" + } + ] + } + }, + { + "name": "latitude", + "type": "FLOAT", + "float_domain": { + "min": 100.0, + "max": 105.0 + }, + "presence": { + "minFraction": 1.0, + "minCount": "1" + }, + "shape": { + "dim": [ + { + "size": "1" + } + ] + } + }, + { + "name": "longitude", + "type": "FLOAT", + "presence": { + "minFraction": 1.0, + "minCount": "1" + }, + "float_domain": { + "min": 102.0, + "max": 105.0 + }, + "shape": { + "dim": [ + { + "size": "1" + } + ] + } + }, + { + "name": "station_id", + "type": "INT", + "presence": { + "minFraction": 1.0, + "minCount": "1" + }, + "int_domain": { + "min": 1, + "max": 5000 + }, + "shape": { + "dim": [ + { + "size": "1" + } + ] + } + } + ], + "stringDomain": [ + { + "name": "location", + "value": [ + "(30.24258, -97.71726)", + "(30.24472, -97.72336)", + "(30.24891, -97.75019)" + ] + }, + { + "name": "name", + "value": [ + "10th & Red River", + "11th & Salina", + "11th & San Jacinto", + "13th & San Antonio", + "17th & Guadalupe" + ] + }, + { + "name": "status", + "value": [ + "active", + "closed" + ] + } + ] +} \ No newline at end of file diff --git a/sdk/python/tests/test_feature_set.py b/sdk/python/tests/test_feature_set.py index bd31d712bb3..6f087d98bbf 100644 
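The YAML feature set and the JSON TFX schema above describe the same bikeshare data and drive the round-trip tests added below. As a minimal, hypothetical usage sketch of the new SDK calls those tests exercise (assuming the tensorflow-metadata package is installed and the fixture path below is resolved relative to the tests folder; the variable names are illustrative only):

    import pathlib

    from google.protobuf import json_format
    from tensorflow_metadata.proto.v0 import schema_pb2

    from feast.entity import Entity
    from feast.feature_set import Feature, FeatureSet
    from feast.value_type import ValueType

    # Parse the JSON fixture into a TFX Schema proto.
    schema = schema_pb2.Schema()
    json_format.Parse(
        pathlib.Path("data/tensorflow_metadata/bikeshare_schema.json").read_text(),
        schema,
    )

    # Copy presence, shape, and domain info from the schema onto the
    # matching entities/features of a FeatureSet, in place.
    feature_set = FeatureSet(
        name="bikeshare",
        entities=[Entity(name="station_id", dtype=ValueType.INT64)],
        features=[Feature(name="status", dtype=ValueType.STRING)],
    )
    feature_set.import_tfx_schema(schema)

    # Round-trip back out to a TFX Schema proto.
    exported = feature_set.export_tfx_schema()
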
--- a/sdk/python/tests/test_feature_set.py +++ b/sdk/python/tests/test_feature_set.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import pathlib from concurrent import futures from datetime import datetime @@ -18,12 +19,19 @@ import pandas as pd import pytest import pytz +from google.protobuf import json_format +from tensorflow_metadata.proto.v0 import schema_pb2 import dataframes import feast.core.CoreService_pb2_grpc as Core from feast.client import Client from feast.entity import Entity -from feast.feature_set import Feature, FeatureSet, FeatureSetRef +from feast.feature_set import ( + Feature, + FeatureSet, + FeatureSetRef, + _make_tfx_schema_domain_info_inline, +) from feast.value_type import ValueType from feast_core_server import CoreServicer @@ -168,6 +176,97 @@ def test_add_features_from_df_success( assert len(my_feature_set.features) == feature_count assert len(my_feature_set.entities) == entity_count + def test_import_tfx_schema(self): + tests_folder = pathlib.Path(__file__).parent + test_input_schema_json = open( + tests_folder / "data" / "tensorflow_metadata" / "bikeshare_schema.json" + ).read() + test_input_schema = schema_pb2.Schema() + json_format.Parse(test_input_schema_json, test_input_schema) + + feature_set = FeatureSet( + name="bikeshare", + entities=[Entity(name="station_id", dtype=ValueType.INT64)], + features=[ + Feature(name="name", dtype=ValueType.STRING), + Feature(name="status", dtype=ValueType.STRING), + Feature(name="latitude", dtype=ValueType.FLOAT), + Feature(name="longitude", dtype=ValueType.FLOAT), + Feature(name="location", dtype=ValueType.STRING), + ], + ) + + # Before update + for entity in feature_set.entities: + assert entity.presence is None + assert entity.shape is None + for feature in feature_set.features: + assert feature.presence is None + assert feature.shape is None + assert feature.string_domain is None + assert feature.float_domain is None + assert feature.int_domain is None + + feature_set.import_tfx_schema(test_input_schema) + + # After update + for entity in feature_set.entities: + assert entity.presence is not None + assert entity.shape is not None + for feature in feature_set.features: + assert feature.presence is not None + assert feature.shape is not None + if feature.name in ["location", "name", "status"]: + assert feature.string_domain is not None + elif feature.name in ["latitude", "longitude"]: + assert feature.float_domain is not None + elif feature.name in ["station_id"]: + assert feature.int_domain is not None + + def test_export_tfx_schema(self): + tests_folder = pathlib.Path(__file__).parent + test_input_feature_set = FeatureSet.from_yaml( + str( + tests_folder + / "data" + / "tensorflow_metadata" + / "bikeshare_feature_set.yaml" + ) + ) + + expected_schema_json = open( + tests_folder / "data" / "tensorflow_metadata" / "bikeshare_schema.json" + ).read() + expected_schema = schema_pb2.Schema() + json_format.Parse(expected_schema_json, expected_schema) + _make_tfx_schema_domain_info_inline(expected_schema) + + actual_schema = test_input_feature_set.export_tfx_schema() + + assert len(actual_schema.feature) == len(expected_schema.feature) + for actual, expected in zip(actual_schema.feature, expected_schema.feature): + assert actual.SerializeToString() == expected.SerializeToString() + + +def make_tfx_schema_domain_info_inline(schema): + # Copy top-level domain info defined in the 
schema to inline definition. + # One use case is in FeatureSet which does not have access to the top-level domain + # info. + domain_ref_to_string_domain = {d.name: d for d in schema.string_domain} + domain_ref_to_float_domain = {d.name: d for d in schema.float_domain} + domain_ref_to_int_domain = {d.name: d for d in schema.int_domain} + + for feature in schema.feature: + domain_info_case = feature.WhichOneof("domain_info") + if domain_info_case == "domain": + domain_ref = feature.domain + if domain_ref in domain_ref_to_string_domain: + feature.string_domain.MergeFrom(domain_ref_to_string_domain[domain_ref]) + elif domain_ref in domain_ref_to_float_domain: + feature.float_domain.MergeFrom(domain_ref_to_float_domain[domain_ref]) + elif domain_ref in domain_ref_to_int_domain: + feature.int_domain.MergeFrom(domain_ref_to_int_domain[domain_ref]) + class TestFeatureSetRef: def test_from_feature_set(self): From 381dd59a158a9eb716e4a55e8b421ed07acf5643 Mon Sep 17 00:00:00 2001 From: Willem Pienaar Date: Sat, 11 Apr 2020 10:50:25 +0800 Subject: [PATCH 3/4] Add validation to Core configuration and fix version loading Refactor, document, and validate Feast Core Properties Refactor FeastProperties to support nested store configuration Localize all store configuration in Serving to Spring configuration Various configuration updates * Allow Feast Serving to use typed properties instead of maps * Reuse Feast Core Store model in serving * Remove redundant config classes for Redis * Update Serving Beans and Config classes to use new configuration getters * Remove hot-loading from store configuration. This reduces a bit of flexibility, but simplifies the code and configuration --- core/pom.xml | 24 ++ .../feast/core/config/FeastProperties.java | 234 ++++++++++++++- .../core/config/FeatureStreamConfig.java | 9 +- .../java/feast/core/config/JobConfig.java | 17 +- .../core/job/dataflow/DataflowJobManager.java | 8 +- .../core/service/JobCoordinatorService.java | 11 +- core/src/main/resources/application.yml | 1 - .../service/JobCoordinatorServiceTest.java | 19 +- pom.xml | 8 + protos/feast/core/Store.proto | 3 + serving/pom.xml | 21 +- .../java/feast/serving/FeastProperties.java | 191 ------------ .../feast/serving/ServingApplication.java | 1 + .../ContextClosedHandler.java | 2 +- .../feast/serving/config/FeastProperties.java | 283 ++++++++++++++++++ .../InstrumentationConfig.java | 3 +- .../JobServiceConfig.java | 21 +- .../JobStoreConfig.java} | 27 +- .../ServingApiConfiguration.java | 2 +- .../ServingServiceConfig.java | 64 ++-- .../SpecServiceConfig.java | 14 +- .../redis/JobStoreRedisConfig.java | 68 ----- .../redis/ServingStoreRedisConfig.java | 62 ---- .../ServingServiceGRpcController.java | 2 +- .../ServingServiceRestController.java | 2 +- .../service/RedisBackedJobService.java | 5 + .../serving/specs/CachedSpecService.java | 39 +-- .../util/mappers/YamlToProtoMapper.java | 22 +- serving/src/main/resources/application.yml | 67 ++--- .../ServingServiceGRpcControllerTest.java | 2 +- .../service/CachedSpecServiceTest.java | 29 +- .../service/RedisBackedJobServiceTest.java | 3 +- .../redis/retriever/RedisOnlineRetriever.java | 11 + 33 files changed, 702 insertions(+), 573 deletions(-) delete mode 100644 serving/src/main/java/feast/serving/FeastProperties.java rename serving/src/main/java/feast/serving/{configuration => config}/ContextClosedHandler.java (96%) create mode 100644 serving/src/main/java/feast/serving/config/FeastProperties.java rename serving/src/main/java/feast/serving/{configuration =>
config}/InstrumentationConfig.java (96%) rename serving/src/main/java/feast/serving/{configuration => config}/JobServiceConfig.java (62%) rename serving/src/main/java/feast/serving/{configuration/StoreConfiguration.java => config/JobStoreConfig.java} (57%) rename serving/src/main/java/feast/serving/{configuration => config}/ServingApiConfiguration.java (97%) rename serving/src/main/java/feast/serving/{configuration => config}/ServingServiceConfig.java (64%) rename serving/src/main/java/feast/serving/{configuration => config}/SpecServiceConfig.java (90%) delete mode 100644 serving/src/main/java/feast/serving/configuration/redis/JobStoreRedisConfig.java delete mode 100644 serving/src/main/java/feast/serving/configuration/redis/ServingStoreRedisConfig.java diff --git a/core/pom.xml b/core/pom.xml index 7961b45074b..f4fb6c659c0 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -38,6 +38,14 @@ false + + + build-info + + build-info + + + @@ -207,5 +215,21 @@ jaxb-api + + javax.validation + validation-api + 2.0.0.Final + + + org.hibernate.validator + hibernate-validator + 6.1.2.Final + + + org.hibernate.validator + hibernate-validator-annotation-processor + 6.1.2.Final + + diff --git a/core/src/main/java/feast/core/config/FeastProperties.java b/core/src/main/java/feast/core/config/FeastProperties.java index b9c787b6c77..59324d9567e 100644 --- a/core/src/main/java/feast/core/config/FeastProperties.java +++ b/core/src/main/java/feast/core/config/FeastProperties.java @@ -16,53 +16,257 @@ */ package feast.core.config; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import feast.core.config.FeastProperties.JobProperties.RunnerOptions; +import feast.core.config.FeastProperties.StreamProperties.FeatureStreamOptions; +import java.util.Arrays; +import java.util.HashMap; import java.util.Map; +import java.util.Objects; +import java.util.Set; +import javax.annotation.PostConstruct; +import javax.validation.ConstraintViolation; +import javax.validation.ConstraintViolationException; +import javax.validation.Validation; +import javax.validation.Validator; +import javax.validation.ValidatorFactory; +import javax.validation.constraints.AssertTrue; +import javax.validation.constraints.NotBlank; +import javax.validation.constraints.NotNull; +import javax.validation.constraints.Positive; import lombok.Getter; import lombok.Setter; +import org.hibernate.validator.constraints.URL; +import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.info.BuildProperties; @Getter @Setter @ConfigurationProperties(prefix = "feast", ignoreInvalidFields = true) public class FeastProperties { - private String version; - private JobProperties jobs; + @Autowired + public FeastProperties(BuildProperties buildProperties) { + setVersion(buildProperties.getVersion()); + } + + public FeastProperties() { + setVersion("unknown"); + } + + /* Feast Core Build Version */ + @NotBlank private String version; + + /* Population job properties */ + @NotNull private JobProperties jobs; + + @NotNull + /* Feast Kafka stream properties */ private StreamProperties stream; @Getter @Setter public static class JobProperties { + @NotBlank + /* Apache Beam runner type. 
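Feast submits and manages feature population jobs on this runner.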
Possible options: DirectRunner, DataflowRunner */ private String runner; - private Map<String, String> options; + + /* Apache Beam runner options for population jobs */ + private RunnerOptions runnerOptions; + + /* (Optional) Additional arguments to pass to Beam population jobs */ + private Map<String, String> extraRunnerOptions; + + @NotNull + /* Population job metric properties */ private MetricsProperties metrics; - private JobUpdatesProperties updates; - } - @Getter - @Setter - public static class JobUpdatesProperties { + /* Timeout in seconds for each attempt to update or submit a new job to the runner */ + @Positive private long jobUpdateTimeout; + + /* Job update polling interval in milliseconds. How frequently Feast will update running jobs. */ + @Positive private long pollingIntervalMillis; + + /** Apache Beam runner options for population jobs */ + @Getter + @Setter + public static class RunnerOptions { + + /* (Dataflow Runner Only) Project id to use when launching jobs. */ + @NotBlank private String project; + + /* (Dataflow Runner Only) The Google Compute Engine region for creating Dataflow jobs. */ + @NotBlank private String region; + + /* (Dataflow Runner Only) GCP availability zone for operations. */ + @NotBlank private String zone; + + /* (Dataflow Runner Only) Run the job as a specific service account, instead of the default GCE robot. */ + @NotBlank private String serviceAccount; + + /* (Dataflow Runner Only) GCE network for launching workers. */ + @NotBlank private String network; + + /* (Dataflow Runner Only) GCE subnetwork for launching workers. */ + @NotBlank private String subnetwork; + + /* (Dataflow Runner Only) Machine type to use for Dataflow worker VMs. */ + private String workerMachineType; + + /* (Dataflow Runner Only) The autoscaling algorithm to use for the worker pool. */ + private String autoscalingAlgorithm; + + /* (Dataflow Runner Only) Specifies whether worker pools should be started with public IP addresses. */ + private Boolean usePublicIps; + + /** + * (Dataflow Runner Only) A pipeline-level default location for storing temporary files. + * Supports Google Cloud Storage locations, e.g. gs://bucket/object + */ + @NotBlank private String tempLocation; + + /* (Dataflow Runner Only) The maximum number of workers to use for the worker pool. */ + private Integer maxNumWorkers; - private long timeoutSeconds; - private long pollingIntervalMillis; + /** + * (Direct Runner Only) Controls the amount of target parallelism the DirectRunner will use. + * Defaults to the greater of the number of available processors and 3. Must be a value + * greater than zero. + */ + private Integer targetParallelism; + + /* BigQuery table specification, e.g. PROJECT_ID:DATASET_ID.TABLE_ID */ + private String deadLetterTableSpec; + } + + public Map<String, String> getRunnerOptionsMap() { + // First collect the existing "extra options" + Map<String, String> combinedOptions = new HashMap<>(getExtraRunnerOptions()); + + // Convert all fields in RunnerOptions to a Map<String, String> and merge + ObjectMapper oMapper = new ObjectMapper(); + combinedOptions.putAll( + oMapper.convertValue( + getRunnerOptions(), new TypeReference<Map<String, String>>() {})); + + return combinedOptions; + } } + @AssertTrue + public boolean isValidJobRunnerSelected() { + String[] validRunners = new String[] {"DataflowRunner", "DirectRunner"}; + return Arrays.asList(validRunners).contains(getJobs().getRunner()); + } + + /** Properties used to configure Feast's managed Kafka feature stream. 
*/ @Getter @Setter public static class StreamProperties { - private String type; - private Map<String, String> options; + /* Feature stream type. Only "kafka" is supported. */ + @NotBlank private String type; + + /* Feature stream options */ + @NotNull private FeatureStreamOptions options; + + /** Feature stream options */ + @Getter + @Setter + public static class FeatureStreamOptions { + + /* Kafka topic to use for feature sets without source topics. */ + @NotBlank private String topic = "feast-features"; + + /** + * Comma-separated list of Kafka bootstrap servers. Used for feature sets without a defined + * source. + */ + @NotBlank private String bootstrapServers = "localhost:9092"; + + /* Replication factor for the managed feature stream Kafka topic. */ + @Positive private short replicationFactor = 1; + + /* Number of Kafka partitions to use for the managed feature stream. */ + @Positive private int partitions = 1; + } } + @AssertTrue + public boolean isValidStreamTypeSelected() { + return Objects.equals(getStream().getType(), "kafka"); + } + + /** Feast population job metrics */ @Getter @Setter public static class MetricsProperties { + /* Population job metrics enabled */ private boolean enabled; - private String type; - private String host; - private int port; + + /* Metric type. Possible options: statsd */ + @NotBlank private String type; + + /* Host of metric sink */ + @URL private String host; + + /* Port of metric sink */ + @Positive private int port; + } + + /** + * Validates all FeastProperties. This method runs after properties have been initialized, and + * validates each nested properties class individually and conditionally. + */ + @PostConstruct + public void validate() { + ValidatorFactory factory = Validation.buildDefaultValidatorFactory(); + Validator validator = factory.getValidator(); + + // Validate root fields in FeastProperties + Set<ConstraintViolation<FeastProperties>> violations = validator.validate(this); + if (!violations.isEmpty()) { + throw new ConstraintViolationException(violations); + } + + // Validate Stream properties + Set<ConstraintViolation<StreamProperties>> streamPropertyViolations = + validator.validate(getStream()); + if (!streamPropertyViolations.isEmpty()) { + throw new ConstraintViolationException(streamPropertyViolations); + } + + // Validate Stream Options + Set<ConstraintViolation<FeatureStreamOptions>> featureStreamOptionsViolations = + validator.validate(getStream().getOptions()); + if (!featureStreamOptionsViolations.isEmpty()) { + throw new ConstraintViolationException(featureStreamOptionsViolations); + } + + // Validate JobProperties + Set<ConstraintViolation<JobProperties>> jobPropertiesViolations = validator.validate(getJobs()); + if (!jobPropertiesViolations.isEmpty()) { + throw new ConstraintViolationException(jobPropertiesViolations); + } + + // Validate RunnerOptions + Set<ConstraintViolation<RunnerOptions>> runnerOptionsViolations = + validator.validate(getJobs().getRunnerOptions()); + if (!runnerOptionsViolations.isEmpty()) { + throw new ConstraintViolationException(runnerOptionsViolations); + } + + // Validate MetricsProperties + if (getJobs().getMetrics().isEnabled()) { + Set<ConstraintViolation<MetricsProperties>> jobMetricViolations = + validator.validate(getJobs().getMetrics()); + if (!jobMetricViolations.isEmpty()) { + throw new ConstraintViolationException(jobMetricViolations); + } + } } } diff --git a/core/src/main/java/feast/core/config/FeatureStreamConfig.java b/core/src/main/java/feast/core/config/FeatureStreamConfig.java index 45de359ac76..44f0e0e0993 100644 --- a/core/src/main/java/feast/core/config/FeatureStreamConfig.java +++ b/core/src/main/java/feast/core/config/FeatureStreamConfig.java @@ -48,8 +48,8 @@ public Source getDefaultSource(FeastProperties feastProperties) { 
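Editor's note: the @PostConstruct validation cascade above is the standard programmatic Bean Validation pattern. A minimal, self-contained sketch, assuming hibernate-validator (added to core/pom.xml above) plus an EL implementation on the classpath; MyProps is a hypothetical class, not part of this patch:

    import java.util.Set;
    import javax.validation.ConstraintViolation;
    import javax.validation.ConstraintViolationException;
    import javax.validation.Validation;
    import javax.validation.Validator;
    import javax.validation.constraints.NotBlank;

    public class ValidationSketch {
      static class MyProps {
        @NotBlank String runner = ""; // blank on purpose: produces one violation
      }

      public static void main(String[] args) {
        Validator validator = Validation.buildDefaultValidatorFactory().getValidator();
        Set<ConstraintViolation<MyProps>> violations = validator.validate(new MyProps());
        if (!violations.isEmpty()) {
          // Same failure mode FeastProperties.validate() uses at startup
          throw new ConstraintViolationException(violations);
        }
      }
    }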
SourceType featureStreamType = SourceType.valueOf(streamProperties.getType().toUpperCase()); switch (featureStreamType) { case KAFKA: - String bootstrapServers = streamProperties.getOptions().get("bootstrapServers"); - String topicName = streamProperties.getOptions().get("topic"); + String bootstrapServers = streamProperties.getOptions().getBootstrapServers(); + String topicName = streamProperties.getOptions().getTopic(); Map map = new HashMap<>(); map.put(AdminClientConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers); map.put( @@ -59,9 +59,8 @@ public Source getDefaultSource(FeastProperties feastProperties) { NewTopic newTopic = new NewTopic( topicName, - Integer.valueOf(streamProperties.getOptions().getOrDefault("numPartitions", "1")), - Short.valueOf( - streamProperties.getOptions().getOrDefault("replicationFactor", "1"))); + streamProperties.getOptions().getPartitions(), + streamProperties.getOptions().getReplicationFactor()); CreateTopicsResult createTopicsResult = client.createTopics(Collections.singleton(newTopic)); try { diff --git a/core/src/main/java/feast/core/config/JobConfig.java b/core/src/main/java/feast/core/config/JobConfig.java index 728fc0545bf..85641681bff 100644 --- a/core/src/main/java/feast/core/config/JobConfig.java +++ b/core/src/main/java/feast/core/config/JobConfig.java @@ -23,7 +23,6 @@ import com.google.api.services.dataflow.DataflowScopes; import com.google.common.base.Strings; import feast.core.config.FeastProperties.JobProperties; -import feast.core.config.FeastProperties.JobUpdatesProperties; import feast.core.job.JobManager; import feast.core.job.Runner; import feast.core.job.dataflow.DataflowJobManager; @@ -31,7 +30,6 @@ import feast.core.job.direct.DirectRunnerJobManager; import java.io.IOException; import java.security.GeneralSecurityException; -import java.util.HashMap; import java.util.Map; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; @@ -55,10 +53,7 @@ public JobManager getJobManager( JobProperties jobProperties = feastProperties.getJobs(); Runner runner = Runner.fromString(jobProperties.getRunner()); - if (jobProperties.getOptions() == null) { - jobProperties.setOptions(new HashMap<>()); - } - Map jobOptions = jobProperties.getOptions(); + Map jobOptions = jobProperties.getRunnerOptionsMap(); switch (runner) { case DATAFLOW: if (Strings.isNullOrEmpty(jobOptions.getOrDefault("region", null)) @@ -77,7 +72,7 @@ public JobManager getJobManager( credential); return new DataflowJobManager( - dataflow, jobProperties.getOptions(), jobProperties.getMetrics()); + dataflow, jobProperties.getRunnerOptionsMap(), jobProperties.getMetrics()); } catch (IOException e) { throw new IllegalStateException( "Unable to find credential required for Dataflow monitoring API", e); @@ -88,7 +83,7 @@ public JobManager getJobManager( } case DIRECT: return new DirectRunnerJobManager( - jobProperties.getOptions(), directJobRegistry, jobProperties.getMetrics()); + jobProperties.getRunnerOptionsMap(), directJobRegistry, jobProperties.getMetrics()); default: throw new IllegalArgumentException("Unsupported runner: " + jobProperties.getRunner()); } @@ -99,10 +94,4 @@ public JobManager getJobManager( public DirectJobRegistry directJobRegistry() { return new DirectJobRegistry(); } - - /** Extracts job update options from feast core options. 
*/ - @Bean - public JobUpdatesProperties jobUpdatesProperties(FeastProperties feastProperties) { - return feastProperties.getJobs().getUpdates(); - } } diff --git a/core/src/main/java/feast/core/job/dataflow/DataflowJobManager.java b/core/src/main/java/feast/core/job/dataflow/DataflowJobManager.java index c2313d75ecc..9dc3dc0b57c 100644 --- a/core/src/main/java/feast/core/job/dataflow/DataflowJobManager.java +++ b/core/src/main/java/feast/core/job/dataflow/DataflowJobManager.java @@ -60,12 +60,12 @@ public class DataflowJobManager implements JobManager { private final MetricsProperties metrics; public DataflowJobManager( - Dataflow dataflow, Map defaultOptions, MetricsProperties metricsProperties) { - this.defaultOptions = defaultOptions; + Dataflow dataflow, Map runnerOptions, MetricsProperties metricsProperties) { + this.defaultOptions = runnerOptions; this.dataflow = dataflow; this.metrics = metricsProperties; - this.projectId = defaultOptions.get("project"); - this.location = defaultOptions.get("region"); + this.projectId = runnerOptions.get("project"); + this.location = runnerOptions.get("region"); } @Override diff --git a/core/src/main/java/feast/core/service/JobCoordinatorService.java b/core/src/main/java/feast/core/service/JobCoordinatorService.java index b66d181022e..24115883ed2 100644 --- a/core/src/main/java/feast/core/service/JobCoordinatorService.java +++ b/core/src/main/java/feast/core/service/JobCoordinatorService.java @@ -24,7 +24,8 @@ import feast.core.FeatureSetProto.FeatureSetStatus; import feast.core.StoreProto; import feast.core.StoreProto.Store.Subscription; -import feast.core.config.FeastProperties.JobUpdatesProperties; +import feast.core.config.FeastProperties; +import feast.core.config.FeastProperties.JobProperties; import feast.core.dao.FeatureSetRepository; import feast.core.dao.JobRepository; import feast.core.job.JobManager; @@ -58,7 +59,7 @@ public class JobCoordinatorService { private FeatureSetRepository featureSetRepository; private SpecService specService; private JobManager jobManager; - private JobUpdatesProperties jobUpdatesProperties; + private JobProperties jobProperties; @Autowired public JobCoordinatorService( @@ -66,12 +67,12 @@ public JobCoordinatorService( FeatureSetRepository featureSetRepository, SpecService specService, JobManager jobManager, - JobUpdatesProperties jobUpdatesProperties) { + FeastProperties feastProperties) { this.jobRepository = jobRepository; this.featureSetRepository = featureSetRepository; this.specService = specService; this.jobManager = jobManager; - this.jobUpdatesProperties = jobUpdatesProperties; + this.jobProperties = feastProperties.getJobs(); } /** @@ -121,7 +122,7 @@ public void Poll() throws InvalidProtocolBufferException { store, originalJob, jobManager, - jobUpdatesProperties.getTimeoutSeconds())); + jobProperties.getJobUpdateTimeout())); }); } } diff --git a/core/src/main/resources/application.yml b/core/src/main/resources/application.yml index ee060fffc95..84aa79a6fc4 100644 --- a/core/src/main/resources/application.yml +++ b/core/src/main/resources/application.yml @@ -23,7 +23,6 @@ grpc: enable-reflection: true feast: -# version: @project.version@ jobs: # Runner type for feature population jobs. Currently supported runner types are # DirectRunner and DataflowRunner. 
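Editor's note: getRunnerOptionsMap() above is what lets DataflowJobManager read typed RunnerOptions fields as plain map keys ("project", "region"). A self-contained sketch of the Jackson flattening involved; FlatOptions and all values here are illustrative stand-ins, not Feast classes:

    import com.fasterxml.jackson.core.type.TypeReference;
    import com.fasterxml.jackson.databind.ObjectMapper;
    import java.util.HashMap;
    import java.util.Map;

    public class RunnerOptionsSketch {
      public static class FlatOptions {
        public String project = "my-gcp-project";
        public String region = "us-central1";
      }

      public static void main(String[] args) {
        Map<String, String> extra = new HashMap<>(); // stands in for extraRunnerOptions
        extra.put("labels", "team=feast");

        Map<String, String> combined = new HashMap<>(extra);
        combined.putAll(
            new ObjectMapper()
                .convertValue(new FlatOptions(), new TypeReference<Map<String, String>>() {}));

        // Prints the merged map, e.g. {project=my-gcp-project, region=us-central1, labels=team=feast}
        System.out.println(combined);
      }
    }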
diff --git a/core/src/test/java/feast/core/service/JobCoordinatorServiceTest.java b/core/src/test/java/feast/core/service/JobCoordinatorServiceTest.java index aa71f201dde..aed889af86e 100644 --- a/core/src/test/java/feast/core/service/JobCoordinatorServiceTest.java +++ b/core/src/test/java/feast/core/service/JobCoordinatorServiceTest.java @@ -39,7 +39,8 @@ import feast.core.StoreProto.Store.RedisConfig; import feast.core.StoreProto.Store.StoreType; import feast.core.StoreProto.Store.Subscription; -import feast.core.config.FeastProperties.JobUpdatesProperties; +import feast.core.config.FeastProperties; +import feast.core.config.FeastProperties.JobProperties; import feast.core.dao.FeatureSetRepository; import feast.core.dao.JobRepository; import feast.core.job.JobManager; @@ -65,13 +66,15 @@ public class JobCoordinatorServiceTest { @Mock SpecService specService; @Mock FeatureSetRepository featureSetRepository; - private JobUpdatesProperties jobUpdatesProperties; + private FeastProperties feastProperties; @Before public void setUp() { initMocks(this); - jobUpdatesProperties = new JobUpdatesProperties(); - jobUpdatesProperties.setTimeoutSeconds(5); + feastProperties = new FeastProperties(); + JobProperties jobProperties = new JobProperties(); + jobProperties.setJobUpdateTimeout(5); + feastProperties.setJobs(jobProperties); } @Test @@ -79,7 +82,7 @@ public void shouldDoNothingIfNoStoresFound() throws InvalidProtocolBufferExcepti when(specService.listStores(any())).thenReturn(ListStoresResponse.newBuilder().build()); JobCoordinatorService jcs = new JobCoordinatorService( - jobRepository, featureSetRepository, specService, jobManager, jobUpdatesProperties); + jobRepository, featureSetRepository, specService, jobManager, feastProperties); jcs.Poll(); verify(jobRepository, times(0)).saveAndFlush(any()); } @@ -105,7 +108,7 @@ public void shouldDoNothingIfNoMatchingFeatureSetsFound() throws InvalidProtocol .thenReturn(ListFeatureSetsResponse.newBuilder().build()); JobCoordinatorService jcs = new JobCoordinatorService( - jobRepository, featureSetRepository, specService, jobManager, jobUpdatesProperties); + jobRepository, featureSetRepository, specService, jobManager, feastProperties); jcs.Poll(); verify(jobRepository, times(0)).saveAndFlush(any()); } @@ -196,7 +199,7 @@ public void shouldGenerateAndSubmitJobsIfAny() throws InvalidProtocolBufferExcep JobCoordinatorService jcs = new JobCoordinatorService( - jobRepository, featureSetRepository, specService, jobManager, jobUpdatesProperties); + jobRepository, featureSetRepository, specService, jobManager, feastProperties); jcs.Poll(); verify(jobRepository, times(1)).saveAndFlush(jobArgCaptor.capture()); Job actual = jobArgCaptor.getValue(); @@ -318,7 +321,7 @@ public void shouldGroupJobsBySource() throws InvalidProtocolBufferException { JobCoordinatorService jcs = new JobCoordinatorService( - jobRepository, featureSetRepository, specService, jobManager, jobUpdatesProperties); + jobRepository, featureSetRepository, specService, jobManager, feastProperties); jcs.Poll(); verify(jobRepository, times(2)).saveAndFlush(jobArgCaptor.capture()); diff --git a/pom.xml b/pom.xml index 649ef01865b..8fe300e1ff7 100644 --- a/pom.xml +++ b/pom.xml @@ -486,6 +486,14 @@ true + + + build-info + + build-info + + + diff --git a/protos/feast/core/Store.proto b/protos/feast/core/Store.proto index 931a9d46b69..de9af0a99fe 100644 --- a/protos/feast/core/Store.proto +++ b/protos/feast/core/Store.proto @@ -120,6 +120,9 @@ message Store { message BigQueryConfig { string project_id = 
1; string dataset_id = 2; + string staging_location = 3; + int32 initial_retry_delay_seconds = 4; + int32 total_timeout_seconds = 5; } message CassandraConfig { diff --git a/serving/pom.xml b/serving/pom.xml index 1390bfdc80c..1036f437de3 100644 --- a/serving/pom.xml +++ b/serving/pom.xml @@ -34,7 +34,7 @@ spring-plugins Spring Plugins - http://repo.spring.io/plugins-release + https://repo.spring.io/plugins-release @@ -46,6 +46,14 @@ false + + + build-info + + build-info + + + org.apache.maven.plugins @@ -76,6 +84,12 @@ ${project.version} + + dev.feast + feast-core + ${project.version} + + dev.feast feast-storage-api @@ -259,6 +273,11 @@ embedded-redis test + + org.projectlombok + lombok + compile + diff --git a/serving/src/main/java/feast/serving/FeastProperties.java b/serving/src/main/java/feast/serving/FeastProperties.java deleted file mode 100644 index 505d7d03301..00000000000 --- a/serving/src/main/java/feast/serving/FeastProperties.java +++ /dev/null @@ -1,191 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * Copyright 2018-2019 The Feast Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package feast.serving; - -// Feast configuration properties that maps Feast configuration from default application.yml file to -// a Java object. 
-// https://www.baeldung.com/configuration-properties-in-spring-boot -// https://docs.spring.io/spring-boot/docs/current/reference/html/boot-features-external-config.html#boot-features-external-config-typesafe-configuration-properties - -import java.util.Map; -import org.springframework.boot.context.properties.ConfigurationProperties; - -@ConfigurationProperties(prefix = "feast") -public class FeastProperties { - private String version; - private String coreHost; - private int coreGrpcPort; - private StoreProperties store; - private JobProperties jobs; - private TracingProperties tracing; - - public String getVersion() { - return this.version; - } - - public String getCoreHost() { - return this.coreHost; - } - - public int getCoreGrpcPort() { - return this.coreGrpcPort; - } - - public StoreProperties getStore() { - return this.store; - } - - public JobProperties getJobs() { - return this.jobs; - } - - public TracingProperties getTracing() { - return this.tracing; - } - - public void setVersion(String version) { - this.version = version; - } - - public void setCoreHost(String coreHost) { - this.coreHost = coreHost; - } - - public void setCoreGrpcPort(int coreGrpcPort) { - this.coreGrpcPort = coreGrpcPort; - } - - public void setStore(StoreProperties store) { - this.store = store; - } - - public void setJobs(JobProperties jobs) { - this.jobs = jobs; - } - - public void setTracing(TracingProperties tracing) { - this.tracing = tracing; - } - - public static class StoreProperties { - private String configPath; - private int redisPoolMaxSize; - private int redisPoolMaxIdle; - - public String getConfigPath() { - return this.configPath; - } - - public int getRedisPoolMaxSize() { - return this.redisPoolMaxSize; - } - - public int getRedisPoolMaxIdle() { - return this.redisPoolMaxIdle; - } - - public void setConfigPath(String configPath) { - this.configPath = configPath; - } - - public void setRedisPoolMaxSize(int redisPoolMaxSize) { - this.redisPoolMaxSize = redisPoolMaxSize; - } - - public void setRedisPoolMaxIdle(int redisPoolMaxIdle) { - this.redisPoolMaxIdle = redisPoolMaxIdle; - } - } - - public static class JobProperties { - private String stagingLocation; - private int bigqueryInitialRetryDelaySecs; - private int bigqueryTotalTimeoutSecs; - private String storeType; - private Map storeOptions; - - public String getStagingLocation() { - return this.stagingLocation; - } - - public int getBigqueryInitialRetryDelaySecs() { - return bigqueryInitialRetryDelaySecs; - } - - public int getBigqueryTotalTimeoutSecs() { - return bigqueryTotalTimeoutSecs; - } - - public String getStoreType() { - return this.storeType; - } - - public Map getStoreOptions() { - return this.storeOptions; - } - - public void setStagingLocation(String stagingLocation) { - this.stagingLocation = stagingLocation; - } - - public void setBigqueryInitialRetryDelaySecs(int bigqueryInitialRetryDelaySecs) { - this.bigqueryInitialRetryDelaySecs = bigqueryInitialRetryDelaySecs; - } - - public void setBigqueryTotalTimeoutSecs(int bigqueryTotalTimeoutSecs) { - this.bigqueryTotalTimeoutSecs = bigqueryTotalTimeoutSecs; - } - - public void setStoreType(String storeType) { - this.storeType = storeType; - } - - public void setStoreOptions(Map storeOptions) { - this.storeOptions = storeOptions; - } - } - - public static class TracingProperties { - private boolean enabled; - private String tracerName; - private String serviceName; - - public boolean isEnabled() { - return this.enabled; - } - - public String getTracerName() { - return 
this.tracerName; - } - - public String getServiceName() { - return this.serviceName; - } - - public void setEnabled(boolean enabled) { - this.enabled = enabled; - } - - public void setTracerName(String tracerName) { - this.tracerName = tracerName; - } - - public void setServiceName(String serviceName) { - this.serviceName = serviceName; - } - } -} diff --git a/serving/src/main/java/feast/serving/ServingApplication.java b/serving/src/main/java/feast/serving/ServingApplication.java index ae9bb87a0b5..064f7b3e8d8 100644 --- a/serving/src/main/java/feast/serving/ServingApplication.java +++ b/serving/src/main/java/feast/serving/ServingApplication.java @@ -16,6 +16,7 @@ */ package feast.serving; +import feast.serving.config.FeastProperties; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.context.properties.EnableConfigurationProperties; diff --git a/serving/src/main/java/feast/serving/configuration/ContextClosedHandler.java b/serving/src/main/java/feast/serving/config/ContextClosedHandler.java similarity index 96% rename from serving/src/main/java/feast/serving/configuration/ContextClosedHandler.java rename to serving/src/main/java/feast/serving/config/ContextClosedHandler.java index a4f6d64d84f..2bc97439f38 100644 --- a/serving/src/main/java/feast/serving/configuration/ContextClosedHandler.java +++ b/serving/src/main/java/feast/serving/config/ContextClosedHandler.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package feast.serving.configuration; +package feast.serving.config; import java.util.concurrent.ScheduledExecutorService; import org.springframework.beans.factory.annotation.Autowired; diff --git a/serving/src/main/java/feast/serving/config/FeastProperties.java b/serving/src/main/java/feast/serving/config/FeastProperties.java new file mode 100644 index 00000000000..088eef958d5 --- /dev/null +++ b/serving/src/main/java/feast/serving/config/FeastProperties.java @@ -0,0 +1,283 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * Copyright 2018-2019 The Feast Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package feast.serving.config; + +// Feast configuration properties that maps Feast configuration from default application.yml file to +// a Java object. 
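// Editor's illustration: a minimal application.yml fragment that Spring's relaxed binding
// would map onto this class (key names and values are assumptions, not taken from this patch):
//
//   feast:
//     core-host: localhost
//     core-grpc-port: 6565
//     tracing:
//       enabled: false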
+// https://www.baeldung.com/configuration-properties-in-spring-boot +// https://docs.spring.io/spring-boot/docs/current/reference/html/boot-features-external-config.html#boot-features-external-config-typesafe-configuration-properties + +import javax.validation.constraints.NotBlank; +import javax.validation.constraints.Positive; +import org.apache.logging.log4j.core.config.plugins.validation.constraints.ValidHost; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.boot.info.BuildProperties; + +/** Feast Serving properties. */ +@ConfigurationProperties(prefix = "feast", ignoreInvalidFields = true) +public class FeastProperties { + + /** + * Instantiates a new Feast Serving properties. + * + * @param buildProperties the build properties + */ + @Autowired + public FeastProperties(BuildProperties buildProperties) { + setVersion(buildProperties.getVersion()); + } + + public FeastProperties() {} + + /* Feast Serving build version */ + @NotBlank private String version = "unknown"; + + /* Feast Core host to connect to. */ + @ValidHost @NotBlank private String coreHost; + + /* Feast Core port to connect to. */ + @Positive private int coreGrpcPort; + + /** + * The "store" string should contain a YAML representation of the store configuration. Store + * configurations can be seen in protos/feast/core/Store.proto + */ + private feast.core.model.Store store; + + /* Job Store properties to retain state of async jobs. */ + private JobStoreProperties jobStore; + + /* Metric tracing properties. */ + private TracingProperties tracing; + + /** + * Gets Serving store configuration deserialiazed as a {@link feast.core.model.Store}. + * + * @return the store + */ + public feast.core.model.Store getStore() { + return store; + } + + /** + * Gets Feast Serving build version. + * + * @return the build version + */ + public String getVersion() { + return version; + } + + /** + * Sets build version + * + * @param version the build version + */ + public void setVersion(String version) { + this.version = version; + } + + /** + * Gets Feast Core host. + * + * @return Feast Core host + */ + public String getCoreHost() { + return coreHost; + } + + /** + * Sets Feast Core host to connect to. + * + * @param coreHost Feast Core host + */ + public void setCoreHost(String coreHost) { + this.coreHost = coreHost; + } + + /** + * Gets Feast Core gRPC port. + * + * @return Port + */ + public int getCoreGrpcPort() { + return coreGrpcPort; + } + + /** + * Sets Feast Core gRPC port. + * + * @param coreGrpcPort gRPC port of Feast Core + */ + public void setCoreGrpcPort(int coreGrpcPort) { + this.coreGrpcPort = coreGrpcPort; + } + + /** + * Sets store properties. + * + * @param store properties comes from a YAML string + */ + public void setStore(feast.core.model.Store store) { + this.store = store; + } + + /** + * Gets job store properties + * + * @return the job store properties + */ + public JobStoreProperties getJobStore() { + return jobStore; + } + + /** + * Set job store properties + * + * @param jobStore Job store properties to set + */ + public void setJobStore(JobStoreProperties jobStore) { + this.jobStore = jobStore; + } + + /** + * Gets tracing properties + * + * @return tracing properties + */ + public TracingProperties getTracing() { + return tracing; + } + + public void setTracing(TracingProperties tracing) { + this.tracing = tracing; + } + + /** The type Job store properties. 
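 * <p>Editor's note: via Spring's relaxed binding these fields would typically come from an
 * application.yml fragment such as the following (illustrative values, not part of this patch):
 *
 * <pre>{@code
 * feast:
 *   job-store:
 *     redis-host: localhost
 *     redis-port: 6379
 * }</pre>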
*/ + public static class JobStoreProperties { + + /** Job Store Redis Host */ + private String redisHost; + + /** Job Store Redis Port */ + private int redisPort; + + /** + * Gets redis host. + * + * @return the redis host + */ + public String getRedisHost() { + return redisHost; + } + + /** + * Sets redis host. + * + * @param redisHost the redis host + */ + public void setRedisHost(String redisHost) { + this.redisHost = redisHost; + } + + /** + * Gets redis port. + * + * @return the redis port + */ + public int getRedisPort() { + return redisPort; + } + + /** + * Sets redis port. + * + * @param redisPort the redis port + */ + public void setRedisPort(int redisPort) { + this.redisPort = redisPort; + } + } + + /** Trace metric collection properties */ + public static class TracingProperties { + + /** Tracing enabled/disabled */ + private boolean enabled; + + /** Name of tracer to use (only "jaeger") */ + private String tracerName; + + /** Service name uniquely identifies this Feast Serving deployment */ + private String serviceName; + + /** + * Is tracing enabled + * + * @return boolean flag + */ + public boolean isEnabled() { + return enabled; + } + + /** + * Sets tracing enabled or disabled. + * + * @param enabled flag + */ + public void setEnabled(boolean enabled) { + this.enabled = enabled; + } + + /** + * Gets tracer name ('jaeger') + * + * @return the tracer name + */ + public String getTracerName() { + return tracerName; + } + + /** + * Sets tracer name. + * + * @param tracerName the tracer name + */ + public void setTracerName(String tracerName) { + this.tracerName = tracerName; + } + + /** + * Gets the service name. The service name uniquely identifies this Feast serving instance. + * + * @return the service name + */ + public String getServiceName() { + return serviceName; + } + + /** + * Sets service name. + * + * @param serviceName the service name + */ + public void setServiceName(String serviceName) { + this.serviceName = serviceName; + } + } +} diff --git a/serving/src/main/java/feast/serving/configuration/InstrumentationConfig.java b/serving/src/main/java/feast/serving/config/InstrumentationConfig.java similarity index 96% rename from serving/src/main/java/feast/serving/configuration/InstrumentationConfig.java rename to serving/src/main/java/feast/serving/config/InstrumentationConfig.java index 2cd284829c4..30269c5d0ec 100644 --- a/serving/src/main/java/feast/serving/configuration/InstrumentationConfig.java +++ b/serving/src/main/java/feast/serving/config/InstrumentationConfig.java @@ -14,9 +14,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package feast.serving.configuration; +package feast.serving.config; -import feast.serving.FeastProperties; import io.opentracing.Tracer; import io.opentracing.noop.NoopTracerFactory; import io.prometheus.client.exporter.MetricsServlet; diff --git a/serving/src/main/java/feast/serving/configuration/JobServiceConfig.java b/serving/src/main/java/feast/serving/config/JobServiceConfig.java similarity index 62% rename from serving/src/main/java/feast/serving/configuration/JobServiceConfig.java rename to serving/src/main/java/feast/serving/config/JobServiceConfig.java index fa94dab8329..f94a9c28c6e 100644 --- a/serving/src/main/java/feast/serving/configuration/JobServiceConfig.java +++ b/serving/src/main/java/feast/serving/config/JobServiceConfig.java @@ -14,10 +14,9 @@ * See the License for the specific language governing permissions and * limitations under the License. 
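Editor's note on TracingProperties above: a sketch of how those fields could drive tracer construction. The io.jaegertracing call is an assumption for illustration; the actual wiring lives in InstrumentationConfig:

    Tracer tracer =
        props.getTracing().isEnabled()
                && "jaeger".equals(props.getTracing().getTracerName())
            ? io.jaegertracing.Configuration.fromEnv(props.getTracing().getServiceName())
                .getTracer()
            : NoopTracerFactory.create();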
*/ -package feast.serving.configuration; +package feast.serving.config; import feast.core.StoreProto.Store.StoreType; -import feast.serving.FeastProperties; import feast.serving.service.JobService; import feast.serving.service.NoopJobService; import feast.serving.service.RedisBackedJobService; @@ -29,24 +28,10 @@ public class JobServiceConfig { @Bean - public JobService jobService( - FeastProperties feastProperties, - CachedSpecService specService, - StoreConfiguration storeConfiguration) { + public JobService jobService(CachedSpecService specService, JobStoreConfig jobStoreConfig) { if (!specService.getStore().getType().equals(StoreType.BIGQUERY)) { return new NoopJobService(); } - StoreType storeType = StoreType.valueOf(feastProperties.getJobs().getStoreType()); - switch (storeType) { - case REDIS: - return new RedisBackedJobService(storeConfiguration.getJobStoreRedisConnection()); - case INVALID: - case BIGQUERY: - case CASSANDRA: - case UNRECOGNIZED: - default: - throw new IllegalArgumentException( - String.format("Unsupported store type '%s' for job store", storeType)); - } + return new RedisBackedJobService(jobStoreConfig); } } diff --git a/serving/src/main/java/feast/serving/configuration/StoreConfiguration.java b/serving/src/main/java/feast/serving/config/JobStoreConfig.java similarity index 57% rename from serving/src/main/java/feast/serving/configuration/StoreConfiguration.java rename to serving/src/main/java/feast/serving/config/JobStoreConfig.java index 84dc7b7f8d4..02bef55ddae 100644 --- a/serving/src/main/java/feast/serving/configuration/StoreConfiguration.java +++ b/serving/src/main/java/feast/serving/config/JobStoreConfig.java @@ -14,31 +14,30 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package feast.serving.configuration; +package feast.serving.config; +import io.lettuce.core.RedisClient; +import io.lettuce.core.RedisURI; import io.lettuce.core.api.StatefulRedisConnection; -import org.springframework.beans.factory.ObjectProvider; +import io.lettuce.core.codec.ByteArrayCodec; +import io.lettuce.core.resource.DefaultClientResources; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Configuration; @Configuration -public class StoreConfiguration { +public class JobStoreConfig { - // We can define other store specific beans here - // These beans can be autowired or can be created in this class. 
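// Editor's sketch: the jobService bean above reduces to this selection (hypothetical local
// variables); only a BigQuery serving store gets a real, Redis-backed job service:
//   JobService jobService =
//       specService.getStore().getType().equals(StoreType.BIGQUERY)
//           ? new RedisBackedJobService(jobStoreConfig)
//           : new NoopJobService();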
- private final StatefulRedisConnection servingRedisConnection; private final StatefulRedisConnection jobStoreRedisConnection; @Autowired - public StoreConfiguration( - ObjectProvider> servingRedisConnection, - ObjectProvider> jobStoreRedisConnection) { - this.servingRedisConnection = servingRedisConnection.getIfAvailable(); - this.jobStoreRedisConnection = jobStoreRedisConnection.getIfAvailable(); - } + public JobStoreConfig(FeastProperties feastProperties) { + RedisURI uri = + RedisURI.create( + feastProperties.getJobStore().getRedisHost(), + feastProperties.getJobStore().getRedisPort()); - public StatefulRedisConnection getServingRedisConnection() { - return servingRedisConnection; + jobStoreRedisConnection = + RedisClient.create(DefaultClientResources.create(), uri).connect(new ByteArrayCodec()); } public StatefulRedisConnection getJobStoreRedisConnection() { diff --git a/serving/src/main/java/feast/serving/configuration/ServingApiConfiguration.java b/serving/src/main/java/feast/serving/config/ServingApiConfiguration.java similarity index 97% rename from serving/src/main/java/feast/serving/configuration/ServingApiConfiguration.java rename to serving/src/main/java/feast/serving/config/ServingApiConfiguration.java index 539b25a0fcd..ce4fe134373 100644 --- a/serving/src/main/java/feast/serving/configuration/ServingApiConfiguration.java +++ b/serving/src/main/java/feast/serving/config/ServingApiConfiguration.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package feast.serving.configuration; +package feast.serving.config; import java.util.List; import org.springframework.beans.factory.annotation.Autowired; diff --git a/serving/src/main/java/feast/serving/configuration/ServingServiceConfig.java b/serving/src/main/java/feast/serving/config/ServingServiceConfig.java similarity index 64% rename from serving/src/main/java/feast/serving/configuration/ServingServiceConfig.java rename to serving/src/main/java/feast/serving/config/ServingServiceConfig.java index 28df853e224..376f26f81af 100644 --- a/serving/src/main/java/feast/serving/configuration/ServingServiceConfig.java +++ b/serving/src/main/java/feast/serving/config/ServingServiceConfig.java @@ -14,25 +14,26 @@ * See the License for the specific language governing permissions and * limitations under the License. 
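Editor's note on JobStoreConfig above: the Lettuce calls it uses compose as follows (a standalone sketch; host and port are placeholders):

    RedisURI uri = RedisURI.create("localhost", 6379);
    StatefulRedisConnection<byte[], byte[]> connection =
        RedisClient.create(DefaultClientResources.create(), uri).connect(new ByteArrayCodec());
    // RedisBackedJobService then issues synchronous commands via connection.sync()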
*/ -package feast.serving.configuration; +package feast.serving.config; import com.google.cloud.bigquery.BigQuery; import com.google.cloud.bigquery.BigQueryOptions; import com.google.cloud.storage.Storage; import com.google.cloud.storage.StorageOptions; -import feast.core.StoreProto.Store; +import com.google.protobuf.InvalidProtocolBufferException; +import feast.core.StoreProto; import feast.core.StoreProto.Store.BigQueryConfig; -import feast.core.StoreProto.Store.RedisConfig; -import feast.core.StoreProto.Store.Subscription; -import feast.serving.FeastProperties; -import feast.serving.service.*; +import feast.serving.service.HistoricalServingService; +import feast.serving.service.JobService; +import feast.serving.service.NoopJobService; +import feast.serving.service.OnlineServingService; +import feast.serving.service.ServingService; import feast.serving.specs.CachedSpecService; import feast.storage.api.retriever.HistoricalRetriever; import feast.storage.api.retriever.OnlineRetriever; import feast.storage.connectors.bigquery.retriever.BigQueryHistoricalRetriever; import feast.storage.connectors.redis.retriever.RedisOnlineRetriever; import io.opentracing.Tracer; -import java.util.Map; import org.slf4j.Logger; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -42,52 +43,27 @@ public class ServingServiceConfig { private static final Logger log = org.slf4j.LoggerFactory.getLogger(ServingServiceConfig.class); - private Store setStoreConfig(Store.Builder builder, Map options) { - switch (builder.getType()) { - case REDIS: - RedisConfig redisConfig = - RedisConfig.newBuilder() - .setHost(options.get("host")) - .setPort(Integer.parseInt(options.get("port"))) - .build(); - return builder.setRedisConfig(redisConfig).build(); - case BIGQUERY: - BigQueryConfig bqConfig = - BigQueryConfig.newBuilder() - .setProjectId(options.get("projectId")) - .setDatasetId(options.get("datasetId")) - .build(); - return builder.setBigqueryConfig(bqConfig).build(); - case CASSANDRA: - default: - throw new IllegalArgumentException( - String.format( - "Unsupported store %s provided, only REDIS or BIGQUERY are currently supported.", - builder.getType())); - } - } - @Bean public ServingService servingService( FeastProperties feastProperties, CachedSpecService specService, JobService jobService, - Tracer tracer, - StoreConfiguration storeConfiguration) { + Tracer tracer) + throws InvalidProtocolBufferException { ServingService servingService = null; - Store store = specService.getStore(); + StoreProto.Store store = feastProperties.getStore().toProto(); switch (store.getType()) { case REDIS: - OnlineRetriever redisRetriever = - new RedisOnlineRetriever(storeConfiguration.getServingRedisConnection()); + OnlineRetriever redisRetriever = new RedisOnlineRetriever(store.getRedisConfig()); servingService = new OnlineServingService(redisRetriever, specService, tracer); break; case BIGQUERY: BigQueryConfig bqConfig = store.getBigqueryConfig(); + String jobStagingLocation = bqConfig.getStagingLocation(); BigQuery bigquery = BigQueryOptions.getDefaultInstance().getService(); Storage storage = StorageOptions.getDefaultInstance().getService(); - String jobStagingLocation = feastProperties.getJobs().getStagingLocation(); + if (!jobStagingLocation.contains("://")) { throw new IllegalArgumentException( String.format("jobStagingLocation is not a valid URI: %s", jobStagingLocation)); @@ -110,10 +86,9 @@ public ServingService servingService( .setBigquery(bigquery) 
.setDatasetId(bqConfig.getDatasetId()) .setProjectId(bqConfig.getProjectId()) - .setJobStagingLocation(jobStagingLocation) - .setInitialRetryDelaySecs( - feastProperties.getJobs().getBigqueryInitialRetryDelaySecs()) - .setTotalTimeoutSecs(feastProperties.getJobs().getBigqueryTotalTimeoutSecs()) + .setJobStagingLocation(bqConfig.getStagingLocation()) + .setInitialRetryDelaySecs(bqConfig.getInitialRetryDelaySeconds()) + .setTotalTimeoutSecs(bqConfig.getTotalTimeoutSeconds()) .setStorage(storage) .build(); @@ -130,9 +105,4 @@ public ServingService servingService( return servingService; } - - private Subscription parseSubscription(String subscription) { - String[] split = subscription.split(":"); - return Subscription.newBuilder().setName(split[0]).setVersion(split[1]).build(); - } } diff --git a/serving/src/main/java/feast/serving/configuration/SpecServiceConfig.java b/serving/src/main/java/feast/serving/config/SpecServiceConfig.java similarity index 90% rename from serving/src/main/java/feast/serving/configuration/SpecServiceConfig.java rename to serving/src/main/java/feast/serving/config/SpecServiceConfig.java index 26ebfa956ca..2682e176d35 100644 --- a/serving/src/main/java/feast/serving/configuration/SpecServiceConfig.java +++ b/serving/src/main/java/feast/serving/config/SpecServiceConfig.java @@ -14,13 +14,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package feast.serving.configuration; +package feast.serving.config; -import feast.serving.FeastProperties; +import com.google.protobuf.InvalidProtocolBufferException; +import feast.core.StoreProto; import feast.serving.specs.CachedSpecService; import feast.serving.specs.CoreSpecService; -import java.nio.file.Path; -import java.nio.file.Paths; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; @@ -58,10 +57,11 @@ public ScheduledExecutorService cachedSpecServiceScheduledExecutorService( } @Bean - public CachedSpecService specService(FeastProperties feastProperties) { + public CachedSpecService specService(FeastProperties feastProperties) + throws InvalidProtocolBufferException { CoreSpecService coreService = new CoreSpecService(feastCoreHost, feastCorePort); - Path path = Paths.get(feastProperties.getStore().getConfigPath()); - CachedSpecService cachedSpecStorage = new CachedSpecService(coreService, path); + StoreProto.Store storeProto = feastProperties.getStore().toProto(); + CachedSpecService cachedSpecStorage = new CachedSpecService(coreService, storeProto); try { cachedSpecStorage.populateCache(); } catch (Exception e) { diff --git a/serving/src/main/java/feast/serving/configuration/redis/JobStoreRedisConfig.java b/serving/src/main/java/feast/serving/configuration/redis/JobStoreRedisConfig.java deleted file mode 100644 index 77d9262bcb3..00000000000 --- a/serving/src/main/java/feast/serving/configuration/redis/JobStoreRedisConfig.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * Copyright 2018-2020 The Feast Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package feast.serving.configuration.redis; - -import com.google.common.base.Enums; -import feast.core.StoreProto; -import feast.serving.FeastProperties; -import io.lettuce.core.RedisClient; -import io.lettuce.core.RedisURI; -import io.lettuce.core.api.StatefulRedisConnection; -import io.lettuce.core.codec.ByteArrayCodec; -import io.lettuce.core.resource.ClientResources; -import io.lettuce.core.resource.DefaultClientResources; -import java.util.Map; -import org.springframework.beans.factory.ObjectProvider; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; - -@Configuration -public class JobStoreRedisConfig { - - @Bean(destroyMethod = "shutdown") - ClientResources jobStoreClientResources() { - return DefaultClientResources.create(); - } - - @Bean(destroyMethod = "shutdown") - RedisClient jobStoreRedisClient( - ClientResources jobStoreClientResources, FeastProperties feastProperties) { - StoreProto.Store.StoreType storeType = - Enums.getIfPresent( - StoreProto.Store.StoreType.class, feastProperties.getJobs().getStoreType()) - .orNull(); - if (storeType != StoreProto.Store.StoreType.REDIS) return null; - Map jobStoreConf = feastProperties.getJobs().getStoreOptions(); - // If job conf is empty throw StoreException - if (jobStoreConf == null - || jobStoreConf.get("host") == null - || jobStoreConf.get("host").isEmpty() - || jobStoreConf.get("port") == null - || jobStoreConf.get("port").isEmpty()) - throw new IllegalArgumentException("Store Configuration is not set"); - RedisURI uri = - RedisURI.create(jobStoreConf.get("host"), Integer.parseInt(jobStoreConf.get("port"))); - return RedisClient.create(jobStoreClientResources, uri); - } - - @Bean(destroyMethod = "close") - StatefulRedisConnection jobStoreRedisConnection( - ObjectProvider jobStoreRedisClient) { - if (jobStoreRedisClient.getIfAvailable() == null) return null; - return jobStoreRedisClient.getIfAvailable().connect(new ByteArrayCodec()); - } -} diff --git a/serving/src/main/java/feast/serving/configuration/redis/ServingStoreRedisConfig.java b/serving/src/main/java/feast/serving/configuration/redis/ServingStoreRedisConfig.java deleted file mode 100644 index 17a50eef6d6..00000000000 --- a/serving/src/main/java/feast/serving/configuration/redis/ServingStoreRedisConfig.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * Copyright 2018-2020 The Feast Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package feast.serving.configuration.redis; - -import feast.core.StoreProto; -import feast.serving.specs.CachedSpecService; -import io.lettuce.core.RedisClient; -import io.lettuce.core.RedisURI; -import io.lettuce.core.api.StatefulRedisConnection; -import io.lettuce.core.codec.ByteArrayCodec; -import io.lettuce.core.resource.ClientResources; -import io.lettuce.core.resource.DefaultClientResources; -import org.springframework.beans.factory.ObjectProvider; -import org.springframework.context.annotation.*; - -@Configuration -public class ServingStoreRedisConfig { - - @Bean - StoreProto.Store.RedisConfig servingStoreRedisConf(CachedSpecService specService) { - if (specService.getStore().getType() != StoreProto.Store.StoreType.REDIS) return null; - return specService.getStore().getRedisConfig(); - } - - @Bean(destroyMethod = "shutdown") - ClientResources servingClientResources() { - return DefaultClientResources.create(); - } - - @Bean(destroyMethod = "shutdown") - RedisClient servingRedisClient( - ClientResources servingClientResources, - ObjectProvider servingStoreRedisConf) { - if (servingStoreRedisConf.getIfAvailable() == null) return null; - RedisURI redisURI = - RedisURI.create( - servingStoreRedisConf.getIfAvailable().getHost(), - servingStoreRedisConf.getIfAvailable().getPort()); - return RedisClient.create(servingClientResources, redisURI); - } - - @Bean(destroyMethod = "close") - StatefulRedisConnection servingRedisConnection( - ObjectProvider servingRedisClient) { - if (servingRedisClient.getIfAvailable() == null) return null; - return servingRedisClient.getIfAvailable().connect(new ByteArrayCodec()); - } -} diff --git a/serving/src/main/java/feast/serving/controller/ServingServiceGRpcController.java b/serving/src/main/java/feast/serving/controller/ServingServiceGRpcController.java index cc1f856d728..d4f220bdaa6 100644 --- a/serving/src/main/java/feast/serving/controller/ServingServiceGRpcController.java +++ b/serving/src/main/java/feast/serving/controller/ServingServiceGRpcController.java @@ -16,7 +16,6 @@ */ package feast.serving.controller; -import feast.serving.FeastProperties; import feast.serving.ServingAPIProto.GetBatchFeaturesRequest; import feast.serving.ServingAPIProto.GetBatchFeaturesResponse; import feast.serving.ServingAPIProto.GetFeastServingInfoRequest; @@ -26,6 +25,7 @@ import feast.serving.ServingAPIProto.GetOnlineFeaturesRequest; import feast.serving.ServingAPIProto.GetOnlineFeaturesResponse; import feast.serving.ServingServiceGrpc.ServingServiceImplBase; +import feast.serving.config.FeastProperties; import feast.serving.interceptors.GrpcMonitoringInterceptor; import feast.serving.service.ServingService; import feast.serving.util.RequestHelper; diff --git a/serving/src/main/java/feast/serving/controller/ServingServiceRestController.java b/serving/src/main/java/feast/serving/controller/ServingServiceRestController.java index b0e349fd6b0..344ab7cf3ae 100644 --- a/serving/src/main/java/feast/serving/controller/ServingServiceRestController.java +++ b/serving/src/main/java/feast/serving/controller/ServingServiceRestController.java @@ -18,11 +18,11 @@ import static feast.serving.util.mappers.ResponseJSONMapper.mapGetOnlineFeaturesResponse; -import feast.serving.FeastProperties; import feast.serving.ServingAPIProto.GetFeastServingInfoRequest; import feast.serving.ServingAPIProto.GetFeastServingInfoResponse; import feast.serving.ServingAPIProto.GetOnlineFeaturesRequest; import feast.serving.ServingAPIProto.GetOnlineFeaturesResponse; +import 
diff --git a/serving/src/main/java/feast/serving/service/RedisBackedJobService.java b/serving/src/main/java/feast/serving/service/RedisBackedJobService.java
index 0bf53630379..99933585b27 100644
--- a/serving/src/main/java/feast/serving/service/RedisBackedJobService.java
+++ b/serving/src/main/java/feast/serving/service/RedisBackedJobService.java
@@ -19,6 +19,7 @@
 import com.google.protobuf.util.JsonFormat;
 import feast.serving.ServingAPIProto.Job;
 import feast.serving.ServingAPIProto.Job.Builder;
+import feast.serving.config.JobStoreConfig;
 import io.lettuce.core.api.StatefulRedisConnection;
 import io.lettuce.core.api.sync.RedisCommands;
 import java.util.Optional;
@@ -37,6 +38,10 @@ public class RedisBackedJobService implements JobService {
   // and since users normally don't require info about relatively old jobs.
   private final int defaultExpirySeconds = (int) Duration.standardDays(1).getStandardSeconds();
 
+  public RedisBackedJobService(JobStoreConfig jobStoreConfig) {
+    this.syncCommand = jobStoreConfig.getJobStoreRedisConnection().sync();
+  }
+
   public RedisBackedJobService(StatefulRedisConnection<byte[], byte[]> connection) {
     this.syncCommand = connection.sync();
   }
diff --git a/serving/src/main/java/feast/serving/specs/CachedSpecService.java b/serving/src/main/java/feast/serving/specs/CachedSpecService.java
index 47f4934d52c..cf5eeabc7e3 100644
--- a/serving/src/main/java/feast/serving/specs/CachedSpecService.java
+++ b/serving/src/main/java/feast/serving/specs/CachedSpecService.java
@@ -18,7 +18,6 @@
 import static feast.serving.util.RefUtil.generateFeatureSetStringRef;
 import static feast.serving.util.RefUtil.generateFeatureStringRef;
-import static feast.serving.util.mappers.YamlToProtoMapper.yamlToStoreProto;
 import static java.util.Comparator.comparingInt;
 import static java.util.stream.Collectors.groupingBy;
 
@@ -27,11 +26,10 @@
 import com.google.common.cache.LoadingCache;
 import feast.core.CoreServiceProto.ListFeatureSetsRequest;
 import feast.core.CoreServiceProto.ListFeatureSetsResponse;
-import feast.core.CoreServiceProto.UpdateStoreRequest;
-import feast.core.CoreServiceProto.UpdateStoreResponse;
 import feast.core.FeatureSetProto.FeatureSet;
 import feast.core.FeatureSetProto.FeatureSetSpec;
 import feast.core.FeatureSetProto.FeatureSpec;
+import feast.core.StoreProto;
 import feast.core.StoreProto.Store;
 import feast.core.StoreProto.Store.Subscription;
 import feast.serving.ServingAPIProto.FeatureReference;
@@ -39,9 +37,6 @@
 import feast.storage.api.retriever.FeatureSetRequest;
 import io.grpc.StatusRuntimeException;
 import io.prometheus.client.Gauge;
-import java.io.IOException;
-import java.nio.file.Files;
-import java.nio.file.Path;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
@@ -59,7 +54,6 @@ public class CachedSpecService {
   private static final Logger log = org.slf4j.LoggerFactory.getLogger(CachedSpecService.class);
 
   private final CoreSpecService coreService;
-  private final Path configPath;
 
   private final Map<String, FeatureSetSpec> featureToFeatureSetMapping;
 
@@ -80,10 +74,9 @@ public class CachedSpecService {
           .help("epoch time of the last time the cache was updated")
           .register();
 
-  public CachedSpecService(CoreSpecService coreService, Path configPath) {
-    this.configPath = configPath;
+  public CachedSpecService(CoreSpecService coreService, StoreProto.Store store) {
     this.coreService = coreService;
-    this.store = updateStore(readConfig(configPath));
+    this.store = store;
 
     Map<String, FeatureSet> featureSets = getFeatureSetMap();
     featureToFeatureSetMapping =
@@ -152,7 +145,6 @@ public List<FeatureSetRequest> getFeatureSets(List<FeatureReference> featureRefe
    * from core to preload the cache.
    */
   public void populateCache() {
-    this.store = updateStore(readConfig(configPath));
     Map<String, FeatureSet> featureSetMap = getFeatureSetMap();
     featureSetCache.putAll(featureSetMap);
     featureToFeatureSetMapping.putAll(getFeatureToFeatureSetMapping(featureSetMap));
@@ -235,29 +227,4 @@ private Map<String, FeatureSetSpec> getFeatureToFeatureSetMapping(
         });
     return mapping;
   }
-
-  private Store readConfig(Path path) {
-    try {
-      List<String> fileContents = Files.readAllLines(path);
-      String yaml = fileContents.stream().reduce("", (l1, l2) -> l1 + "\n" + l2);
-      log.info("loaded store config at {}: \n{}", path.toString(), yaml);
-      return yamlToStoreProto(yaml);
-    } catch (IOException e) {
-      throw new RuntimeException(
-          String.format("Unable to read store config at %s", path.toAbsolutePath()), e);
-    }
-  }
-
-  private Store updateStore(Store store) {
-    UpdateStoreRequest request = UpdateStoreRequest.newBuilder().setStore(store).build();
-    try {
-      UpdateStoreResponse updateStoreResponse = coreService.updateStore(request);
-      if (!updateStoreResponse.getStore().equals(store)) {
-        throw new RuntimeException("Core store config not matching current store config");
-      }
-      return updateStoreResponse.getStore();
-    } catch (Exception e) {
-      throw new RuntimeException("Unable to update store configuration", e);
-    }
-  }
 }
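
The CachedSpecService change is the heart of this refactor: the service no longer reads a YAML file from a config path and pushes an UpdateStoreRequest to Feast Core on every refresh; it simply receives the active Store proto. A hedged wiring sketch (the surrounding class is hypothetical; the protos and constructor are from this patch):

```java
import feast.core.StoreProto.Store;
import feast.core.StoreProto.Store.RedisConfig;
import feast.core.StoreProto.Store.StoreType;
import feast.core.StoreProto.Store.Subscription;
import feast.serving.specs.CachedSpecService;
import feast.serving.specs.CoreSpecService;

class SpecServiceWiring {

  // Build the active store proto in code and hand it to CachedSpecService;
  // previously the service read a YAML file and synced it to Core itself.
  static CachedSpecService cachedSpecService(CoreSpecService coreService) {
    Store store =
        Store.newBuilder()
            .setName("serving")
            .setType(StoreType.REDIS)
            .setRedisConfig(RedisConfig.newBuilder().setHost("localhost").setPort(6379))
            .addSubscriptions(
                Subscription.newBuilder().setProject("*").setName("*").setVersion("*"))
            .build();
    return new CachedSpecService(coreService, store);
  }
}
```

Dropping the updateStore() round trip also removes a hidden side effect: constructing or refreshing the spec service no longer mutates store configuration in Feast Core.
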
diff --git a/serving/src/main/java/feast/serving/util/mappers/YamlToProtoMapper.java b/serving/src/main/java/feast/serving/util/mappers/YamlToProtoMapper.java
index 00ad1fabb1c..784a21552bd 100644
--- a/serving/src/main/java/feast/serving/util/mappers/YamlToProtoMapper.java
+++ b/serving/src/main/java/feast/serving/util/mappers/YamlToProtoMapper.java
@@ -19,19 +19,29 @@
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.fasterxml.jackson.dataformat.yaml.YAMLFactory;
 import com.google.protobuf.util.JsonFormat;
+import feast.core.StoreProto;
 import feast.core.StoreProto.Store;
 import feast.core.StoreProto.Store.Builder;
 import java.io.IOException;
+import org.slf4j.Logger;
 
 public class YamlToProtoMapper {
+
+  private static final Logger log = org.slf4j.LoggerFactory.getLogger(YamlToProtoMapper.class);
+
   private static final ObjectMapper yamlReader = new ObjectMapper(new YAMLFactory());
   private static final ObjectMapper jsonWriter = new ObjectMapper();
 
-  public static Store yamlToStoreProto(String yaml) throws IOException {
-    Object obj = yamlReader.readValue(yaml, Object.class);
-    String jsonString = jsonWriter.writeValueAsString(obj);
-    Builder builder = Store.newBuilder();
-    JsonFormat.parser().merge(jsonString, builder);
-    return builder.build();
+  public static Store yamlToStoreProto(String yaml) {
+    try {
+      Object obj = yamlReader.readValue(yaml, Object.class);
+      String jsonString = jsonWriter.writeValueAsString(obj);
+      Builder builder = Store.newBuilder();
+      JsonFormat.parser().merge(jsonString, builder);
+      return builder.build();
+    } catch (IOException e) {
+      log.error("Could not parse store configuration YAML", e);
+      return StoreProto.Store.getDefaultInstance();
+    }
   }
 }
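
Note the behavioral change in yamlToStoreProto: a parse failure no longer propagates as an IOException but is logged and mapped to Store.getDefaultInstance(). Remaining callers should therefore check for the default instance themselves, as in this illustrative round trip:

```java
import feast.core.StoreProto.Store;
import feast.serving.util.mappers.YamlToProtoMapper;

class StoreYamlExample {

  public static void main(String[] args) {
    String yaml =
        "name: serving\n"
            + "type: REDIS\n"
            + "redis_config:\n"
            + "  host: localhost\n"
            + "  port: 6379\n";
    Store store = YamlToProtoMapper.yamlToStoreProto(yaml);
    // A parse failure now yields Store.getDefaultInstance() instead of an
    // exception, so treat the default instance as the error signal.
    if (store.equals(Store.getDefaultInstance())) {
      throw new IllegalStateException("store YAML could not be parsed");
    }
    System.out.println(store.getName() + " -> " + store.getType());
  }
}
```

Returning a default instance keeps the signature clean, but it trades an explicit failure for a sentinel value that is easy to overlook.
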
diff --git a/serving/src/main/resources/application.yml b/serving/src/main/resources/application.yml
index 96713c80287..264a98320f9 100644
--- a/serving/src/main/resources/application.yml
+++ b/serving/src/main/resources/application.yml
@@ -1,7 +1,4 @@
 feast:
-  # This value is retrieved from project.version properties in pom.xml
-  # https://docs.spring.io/spring-boot/docs/current/reference/html/
-  version: @project.version@
   # GRPC service address for Feast Core
   # Feast Serving requires connection to Feast Core to retrieve and reload Feast metadata (e.g. FeatureSpecs, Store information)
   core-host: ${FEAST_CORE_HOST:localhost}
@@ -18,40 +15,38 @@ feast:
     service-name: feast_serving
 
   store:
-    # Path containing the store configuration for this serving store.
-    config-path: ${FEAST_STORE_CONFIG_PATH:serving/sample_redis_config.yml}
-    # If serving redis, the redis pool max size
-    redis-pool-max-size: ${FEAST_REDIS_POOL_MAX_SIZE:128}
-    # If serving redis, the redis pool max idle conns
-    redis-pool-max-idle: ${FEAST_REDIS_POOL_MAX_IDLE:16}
+    name: serving
+    type: REDIS # Alternatively: BIGQUERY
+    redis_config:
+      host: localhost
+      port: 6379
+    bigquery_config:
+      # GCP Project
+      project_id: my_project
 
-  jobs:
-    # staging-location specifies the URI to store intermediate files for batch serving.
-    # Feast Serving client is expected to have read access to this staging location
-    # to download the batch features.
-    #
-    # For example: gs://mybucket/myprefix
-    # Please omit the trailing slash in the URI.
-    staging-location: ${FEAST_JOB_STAGING_LOCATION:}
-    #
-    # Retry options for BigQuery jobs:
-    bigquery-initial-retry-delay-secs: 1
-    bigquery-total-timeout-secs: 21600
-    #
-    # Type of store to store job metadata. This only needs to be set if the
-    # serving store type is Bigquery.
-    store-type: ${FEAST_JOB_STORE_TYPE:}
-    #
-    # Job store connection options. If the job store is redis, the following items are required:
-    #
-    # store-options:
-    #   host: localhost
-    #   port: 6379
-    # Optionally, you can configure the connection pool with the following items:
-    #   max-conn: 8
-    #   max-idle: 8
-    #   max-wait-millis: 50
-    store-options: {}
+      # BigQuery Dataset Id
+      dataset_id: my_dataset
+
+      # staging-location specifies the URI to store intermediate files for batch serving.
+      # Feast Serving client is expected to have read access to this staging location
+      # to download the batch features.
+      # For example: gs://mybucket/myprefix
+      # Please omit the trailing slash in the URI.
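
The reshaped feast.store block (continued just below) binds naturally to typed Spring Boot configuration properties. The sketch here shows how such a block could bind; the actual feast.serving.config.FeastProperties class is not shown in this patch, so every name in it is an assumption for illustration only.

```java
import java.util.Map;
import org.springframework.boot.context.properties.ConfigurationProperties;

@ConfigurationProperties(prefix = "feast")
class FeastStorePropertiesSketch {

  private StoreProperties store = new StoreProperties();

  public StoreProperties getStore() { return store; }
  public void setStore(StoreProperties store) { this.store = store; }

  static class StoreProperties {
    private String name;
    private String type; // REDIS or BIGQUERY
    // Relaxed binding maps redis_config / bigquery_config onto camelCase fields.
    private Map<String, String> redisConfig;
    private Map<String, String> bigqueryConfig;

    public String getName() { return name; }
    public void setName(String name) { this.name = name; }
    public String getType() { return type; }
    public void setType(String type) { this.type = type; }
    public Map<String, String> getRedisConfig() { return redisConfig; }
    public void setRedisConfig(Map<String, String> m) { this.redisConfig = m; }
    public Map<String, String> getBigqueryConfig() { return bigqueryConfig; }
    public void setBigqueryConfig(Map<String, String> m) { this.bigqueryConfig = m; }
  }
}
```
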
+      staging-location: ${FEAST_JOB_STAGING_LOCATION:}
+
+      # Retry options for BigQuery retrieval jobs
+      bigquery-initial-retry-delay-secs: 1
+
+      # BigQuery timeout for retrieval jobs
+      bigquery-total-timeout-secs: 21600
+    subscriptions:
+      - name: "*"
+        project: "*"
+        version: "*"
+
+  job_store:
+    redis_host: localhost
+    redis_port: 6379
 
 grpc:
   # The port number Feast Serving GRPC service should listen on
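
diff --git a/serving/src/test/java/feast/serving/controller/ServingServiceGRpcControllerTest.java b/serving/src/test/java/feast/serving/controller/ServingServiceGRpcControllerTest.java
index f2c51bc7dde..d23f9da1d25 100644
--- a/serving/src/test/java/feast/serving/controller/ServingServiceGRpcControllerTest.java
+++ b/serving/src/test/java/feast/serving/controller/ServingServiceGRpcControllerTest.java
@@ -19,11 +19,11 @@
 import static org.mockito.MockitoAnnotations.initMocks;
 
 import com.google.protobuf.Timestamp;
-import feast.serving.FeastProperties;
 import feast.serving.ServingAPIProto.FeatureReference;
 import feast.serving.ServingAPIProto.GetOnlineFeaturesRequest;
 import feast.serving.ServingAPIProto.GetOnlineFeaturesRequest.EntityRow;
 import feast.serving.ServingAPIProto.GetOnlineFeaturesResponse;
+import feast.serving.config.FeastProperties;
 import feast.serving.service.ServingService;
 import feast.types.ValueProto.Value;
 import io.grpc.StatusRuntimeException;
diff --git a/serving/src/test/java/feast/serving/service/CachedSpecServiceTest.java b/serving/src/test/java/feast/serving/service/CachedSpecServiceTest.java
index 01c9304bda0..580b45a224b 100644
--- a/serving/src/test/java/feast/serving/service/CachedSpecServiceTest.java
+++ b/serving/src/test/java/feast/serving/service/CachedSpecServiceTest.java
@@ -46,7 +46,6 @@
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
-import org.junit.After;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
@@ -55,7 +54,6 @@
 public class CachedSpecServiceTest {
 
-  private File configFile;
   private Store store;
 
   @Rule public final ExpectedException expectedException = ExpectedException.none();
@@ -66,27 +64,9 @@ public class CachedSpecServiceTest {
   private CachedSpecService cachedSpecService;
 
   @Before
-  public void setUp() throws IOException {
+  public void setUp() {
     initMocks(this);
 
-    configFile = File.createTempFile("serving", ".yml");
-    String yamlString =
-        "name: SERVING\n"
-            + "type: REDIS\n"
-            + "redis_config:\n"
-            + "  host: localhost\n"
-            + "  port: 6379\n"
-            + "subscriptions:\n"
-            + "- project: project\n"
-            + "  name: fs1\n"
-            + "  version: \"*\"\n"
-            + "- project: project\n"
-            + "  name: fs2\n"
-            + "  version: \"*\"";
-    BufferedWriter writer = new BufferedWriter(new FileWriter(configFile));
-    writer.write(yamlString);
-    writer.close();
-
     store =
         Store.newBuilder()
             .setName("SERVING")
@@ -164,12 +144,7 @@ public void setUp() throws IOException {
             .build()))
         .thenReturn(ListFeatureSetsResponse.newBuilder().addAllFeatureSets(fs2FeatureSets).build());
 
-    cachedSpecService = new CachedSpecService(coreService, configFile.toPath());
-  }
-
-  @After
-  public void tearDown() {
-    configFile.delete();
+    cachedSpecService = new CachedSpecService(coreService, store);
   }
 
   @Test

Both the new YAML defaults and the stores built in tests use wildcard subscriptions (name, project, and version all set to "*"), which subscribe the store to every feature set. The matcher below is illustrative only, since the real matching logic lives in Feast Core, but it shows the intended semantics:

```java
import feast.core.StoreProto.Store.Subscription;

class SubscriptionMatcherSketch {

  // Illustrative only: a "*" entry matches any project or feature set name.
  static boolean matches(Subscription sub, String project, String name) {
    return ("*".equals(sub.getProject()) || sub.getProject().equals(project))
        && ("*".equals(sub.getName()) || sub.getName().equals(name));
  }

  public static void main(String[] args) {
    Subscription all =
        Subscription.newBuilder().setProject("*").setName("*").setVersion("*").build();
    System.out.println(matches(all, "my_project", "customer_features")); // true
  }
}
```
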
diff --git a/serving/src/test/java/feast/serving/service/RedisBackedJobServiceTest.java b/serving/src/test/java/feast/serving/service/RedisBackedJobServiceTest.java
index 34bc31d2c26..23626c2cb85 100644
--- a/serving/src/test/java/feast/serving/service/RedisBackedJobServiceTest.java
+++ b/serving/src/test/java/feast/serving/service/RedisBackedJobServiceTest.java
@@ -26,6 +26,7 @@
 import redis.embedded.RedisServer;
 
 public class RedisBackedJobServiceTest {
+
   private static Integer REDIS_PORT = 51235;
   private RedisServer redis;
 
@@ -41,7 +42,7 @@ public void teardown() {
   }
 
   @Test
-  public void shouldRecoverIfRedisConnectionIsLost() throws IOException {
+  public void shouldRecoverIfRedisConnectionIsLost() {
     RedisClient client = RedisClient.create(RedisURI.create("localhost", REDIS_PORT));
     RedisBackedJobService jobService =
         new RedisBackedJobService(client.connect(new ByteArrayCodec()));
diff --git a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisOnlineRetriever.java b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisOnlineRetriever.java
index c8bb33de5fd..99de7f9112c 100644
--- a/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisOnlineRetriever.java
+++ b/storage/connectors/redis/src/main/java/feast/storage/connectors/redis/retriever/RedisOnlineRetriever.java
@@ -20,6 +20,7 @@
 import com.google.protobuf.InvalidProtocolBufferException;
 import feast.core.FeatureSetProto.EntitySpec;
 import feast.core.FeatureSetProto.FeatureSetSpec;
+import feast.core.StoreProto.Store.RedisConfig;
 import feast.serving.ServingAPIProto.FeatureReference;
 import feast.serving.ServingAPIProto.GetOnlineFeaturesRequest.EntityRow;
 import feast.storage.RedisProto.RedisKey;
@@ -29,8 +30,11 @@
 import feast.types.FieldProto.Field;
 import feast.types.ValueProto.Value;
 import io.grpc.Status;
+import io.lettuce.core.RedisClient;
+import io.lettuce.core.RedisURI;
 import io.lettuce.core.api.StatefulRedisConnection;
 import io.lettuce.core.api.sync.RedisCommands;
+import io.lettuce.core.codec.ByteArrayCodec;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
@@ -45,6 +49,13 @@ public RedisOnlineRetriever(StatefulRedisConnection<byte[], byte[]> connection)
     this.syncCommands = connection.sync();
   }
 
+  public RedisOnlineRetriever(RedisConfig config) {
+    StatefulRedisConnection<byte[], byte[]> connection =
+        RedisClient.create(RedisURI.create(config.getHost(), config.getPort()))
+            .connect(new ByteArrayCodec());
+    this.syncCommands = connection.sync();
+  }
+
   /**
    * Gets online features from redis. This method returns a list of {@link FeatureRow}s
    * corresponding to each feature set spec. Each feature row in the list then corresponds to an
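
The new RedisOnlineRetriever(RedisConfig) constructor lets serving hand the retriever its config proto and have the retriever own the lettuce client, instead of wiring a StatefulRedisConnection bean through Spring first. Illustrative usage (the wiring class is hypothetical):

```java
import feast.core.StoreProto.Store.RedisConfig;
import feast.storage.connectors.redis.retriever.RedisOnlineRetriever;

class RetrieverWiring {

  // The retriever now builds and owns its lettuce client from the config
  // proto, so callers no longer manage a StatefulRedisConnection bean.
  static RedisOnlineRetriever onlineRetriever() {
    RedisConfig config = RedisConfig.newBuilder().setHost("localhost").setPort(6379).build();
    return new RedisOnlineRetriever(config);
  }
}
```

One trade-off of this convenience constructor: the client it creates has no destroyMethod-managed lifecycle, so shutdown responsibility moves into the connector.
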
From 03e567356a2b9f88f48be30709848a09b78fb8de Mon Sep 17 00:00:00 2001
From: Willem Pienaar
Date: Sat, 11 Apr 2020 11:09:34 +0800
Subject: [PATCH 4/4] Set a default build version for the "version" field in
 Feast Core FeastProperties

---
 core/src/main/java/feast/core/config/FeastProperties.java | 6 ++----
 .../java/feast/serving/service/CachedSpecServiceTest.java | 4 ----
 2 files changed, 2 insertions(+), 8 deletions(-)

diff --git a/core/src/main/java/feast/core/config/FeastProperties.java b/core/src/main/java/feast/core/config/FeastProperties.java
index 59324d9567e..b21ed254614 100644
--- a/core/src/main/java/feast/core/config/FeastProperties.java
+++ b/core/src/main/java/feast/core/config/FeastProperties.java
@@ -52,12 +52,10 @@ public FeastProperties(BuildProperties buildProperties) {
     setVersion(buildProperties.getVersion());
   }
 
-  public FeastProperties() {
-    setVersion("unknown");
-  }
+  public FeastProperties() {}
 
   /* Feast Core Build Version */
-  @NotBlank private String version;
+  @NotBlank private String version = "unknown";
 
   /* Population job properties */
   @NotNull private JobProperties jobs;
diff --git a/serving/src/test/java/feast/serving/service/CachedSpecServiceTest.java b/serving/src/test/java/feast/serving/service/CachedSpecServiceTest.java
index 580b45a224b..f4f795ed32f 100644
--- a/serving/src/test/java/feast/serving/service/CachedSpecServiceTest.java
+++ b/serving/src/test/java/feast/serving/service/CachedSpecServiceTest.java
@@ -38,10 +38,6 @@
 import feast.serving.specs.CachedSpecService;
 import feast.serving.specs.CoreSpecService;
 import feast.storage.api.retriever.FeatureSetRequest;
-import java.io.BufferedWriter;
-import java.io.File;
-import java.io.FileWriter;
-import java.io.IOException;
 import java.util.Collections;
 import java.util.LinkedHashMap;
 import java.util.List;
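
The PATCH 4/4 change moves the "unknown" fallback from the no-arg constructor into the field initializer, so both constructors share a single default and the BuildProperties constructor simply overwrites it. The pattern in isolation, with an illustrative class name:

```java
import org.springframework.boot.info.BuildProperties;

class VersionedProperties {

  // The field initializer supplies the fallback; the BuildProperties
  // constructor overwrites it when Spring Boot build info is available.
  private String version = "unknown";

  public VersionedProperties() {}

  public VersionedProperties(BuildProperties buildProperties) {
    this.version = buildProperties.getVersion();
  }

  public String getVersion() {
    return version;
  }
}
```
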