Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions ml_metadata/metadata_store/metadata_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,12 @@ def _get_channel(self, config: metadata_store_pb2.MetadataStoreClientConfig):
"""
target = ':'.join([config.host, str(config.port)])

channel_arguments = None
if config.channel_arguments.HasField('max_receive_message_length'):
channel_arguments = [('grpc.max_receive_message_length', config.channel_arguments.max_receive_message_length)]

if not config.HasField('ssl_config'):
return grpc.insecure_channel(target)
return grpc.insecure_channel(target, options=channel_arguments)

root_certificates = None
private_key = None
Expand All @@ -116,7 +120,7 @@ def _get_channel(self, config: metadata_store_pb2.MetadataStoreClientConfig):
str(config.ssl_config.server_cert).encode('ascii'))
credentials = grpc.ssl_channel_credentials(root_certificates, private_key,
certificate_chain)
return grpc.secure_channel(target, credentials)
return grpc.secure_channel(target, credentials, options=channel_arguments)

def __del__(self):
if self._using_db_connection and hasattr(self, '_metadata_store'):
Expand Down
14 changes: 13 additions & 1 deletion ml_metadata/metadata_store/metadata_store_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@
"The gRPC port number to use when use_grpc_backed is set to 'True'")


def _get_metadata_store():
def _get_metadata_store(max_receive_message_length=4*1024*1024):
if FLAGS.use_grpc_backend:
grpc_connection_config = metadata_store_pb2.MetadataStoreClientConfig()
grpc_connection_config.channel_arguments.max_receive_message_length = max_receive_message_length
if FLAGS.grpc_host is None:
raise ValueError("grpc_host argument not set.")
grpc_connection_config.host = FLAGS.grpc_host
Expand Down Expand Up @@ -133,6 +134,17 @@ def test_connection_config_with_retry_options(self):
store = metadata_store.MetadataStore(connection_config)
self.assertEqual(store._max_num_retries, want_num_retries)

def test_connection_config_with_grpc_channel_arguments(self):
if FLAGS.use_grpc_backend:
# set max_receive_message_length to 0, and client should raise ResourceExhaustedError
store = _get_metadata_store(max_receive_message_length=0)
artifact_type_name = self._get_test_type_name()
artifact_type = _create_example_artifact_type(artifact_type_name)
with self.assertRaises(errors.ResourceExhaustedError):
store.put_artifact_type(artifact_type)
else:
self.skipTest("Skip test due to missing GRPC backend")

def test_put_artifact_type_get_artifact_type(self):
store = _get_metadata_store()
artifact_type_name = self._get_test_type_name()
Expand Down
8 changes: 8 additions & 0 deletions ml_metadata/proto/metadata_store.proto
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,11 @@ message ConnectionConfig {
optional RetryOptions retry_options = 4;
}

message GrpcChannelArguments {
// Maximum message length that the channel can receive.
optional int64 max_receive_message_length = 1;
}

// Configuration for the gRPC metadata store client.
message MetadataStoreClientConfig {
// The hostname or IP address of the gRPC server. Must be specified.
Expand All @@ -590,6 +595,9 @@ message MetadataStoreClientConfig {
// Configuration for a secure gRPC channel.
// If not given, insecure connection is used.
optional SSLConfig ssl_config = 3;

// GRPC channel arguments
optional GrpcChannelArguments channel_arguments = 4;
}

// Configuration for the gRPC metadata store server.
Expand Down