From d10c65de9e400fd235dd11110e13949d0d2b0d42 Mon Sep 17 00:00:00 2001 From: Keshi Dai Date: Mon, 7 Sep 2020 21:44:43 -0400 Subject: [PATCH] Add channel arguments to set max_receive_message_length in client config --- ml_metadata/metadata_store/metadata_store.py | 8 ++++++-- ml_metadata/metadata_store/metadata_store_test.py | 14 +++++++++++++- ml_metadata/proto/metadata_store.proto | 8 ++++++++ 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/ml_metadata/metadata_store/metadata_store.py b/ml_metadata/metadata_store/metadata_store.py index d0f35f0ad..4c67336c8 100644 --- a/ml_metadata/metadata_store/metadata_store.py +++ b/ml_metadata/metadata_store/metadata_store.py @@ -100,8 +100,12 @@ def _get_channel(self, config: metadata_store_pb2.MetadataStoreClientConfig): """ target = ':'.join([config.host, str(config.port)]) + channel_arguments = None + if config.channel_arguments.HasField('max_receive_message_length'): + channel_arguments = [('grpc.max_receive_message_length', config.channel_arguments.max_receive_message_length)] + if not config.HasField('ssl_config'): - return grpc.insecure_channel(target) + return grpc.insecure_channel(target, options=channel_arguments) root_certificates = None private_key = None @@ -116,7 +120,7 @@ def _get_channel(self, config: metadata_store_pb2.MetadataStoreClientConfig): str(config.ssl_config.server_cert).encode('ascii')) credentials = grpc.ssl_channel_credentials(root_certificates, private_key, certificate_chain) - return grpc.secure_channel(target, credentials) + return grpc.secure_channel(target, credentials, options=channel_arguments) def __del__(self): if self._using_db_connection and hasattr(self, '_metadata_store'): diff --git a/ml_metadata/metadata_store/metadata_store_test.py b/ml_metadata/metadata_store/metadata_store_test.py index 748dcd1dc..8e1dccee1 100644 --- a/ml_metadata/metadata_store/metadata_store_test.py +++ b/ml_metadata/metadata_store/metadata_store_test.py @@ -41,9 +41,10 @@ "The gRPC port number to use when use_grpc_backed is set to 'True'") -def _get_metadata_store(): +def _get_metadata_store(max_receive_message_length=4*1024*1024): if FLAGS.use_grpc_backend: grpc_connection_config = metadata_store_pb2.MetadataStoreClientConfig() + grpc_connection_config.channel_arguments.max_receive_message_length = max_receive_message_length if FLAGS.grpc_host is None: raise ValueError("grpc_host argument not set.") grpc_connection_config.host = FLAGS.grpc_host @@ -133,6 +134,17 @@ def test_connection_config_with_retry_options(self): store = metadata_store.MetadataStore(connection_config) self.assertEqual(store._max_num_retries, want_num_retries) + def test_connection_config_with_grpc_channel_arguments(self): + if FLAGS.use_grpc_backend: + # set max_receive_message_length to 0, and client should raise ResourceExhaustedError + store = _get_metadata_store(max_receive_message_length=0) + artifact_type_name = self._get_test_type_name() + artifact_type = _create_example_artifact_type(artifact_type_name) + with self.assertRaises(errors.ResourceExhaustedError): + store.put_artifact_type(artifact_type) + else: + self.skipTest("Skip test due to missing GRPC backend") + def test_put_artifact_type_get_artifact_type(self): store = _get_metadata_store() artifact_type_name = self._get_test_type_name() diff --git a/ml_metadata/proto/metadata_store.proto b/ml_metadata/proto/metadata_store.proto index 9d96cc35b..056f40588 100644 --- a/ml_metadata/proto/metadata_store.proto +++ b/ml_metadata/proto/metadata_store.proto @@ -567,6 +567,11 @@ message ConnectionConfig { optional RetryOptions retry_options = 4; } +message GrpcChannelArguments { + // Maximum message length that the channel can receive. + optional int64 max_receive_message_length = 1; +} + // Configuration for the gRPC metadata store client. message MetadataStoreClientConfig { // The hostname or IP address of the gRPC server. Must be specified. @@ -590,6 +595,9 @@ message MetadataStoreClientConfig { // Configuration for a secure gRPC channel. // If not given, insecure connection is used. optional SSLConfig ssl_config = 3; + + // GRPC channel arguments + optional GrpcChannelArguments channel_arguments = 4; } // Configuration for the gRPC metadata store server.