Skip to content

Commit

Permalink
Add option to respect the marshaller specified in gRPC MethodDescript…
Browse files Browse the repository at this point in the history
…or (#5630)

Motivation:

- This PR adds option to use marshaller that specified in gRPC `MethodDescriptor`
- Related Issue #5103
  - Partial solved: Option to use marshaller that specified in gRPC `MethodDescriptor`
  - Unsolved part: Provide way to add custom marshaller

Modifications:

- New option `useMethodMarshaller`
  - default value is `false`
- Add validate logic for `GrpcServiceBuilder` and `GrpcClientBuilder` to check that `unsafeWrapDeserializedBuffer` and `useMethodMarshaller` are mutually exclusive

Result:
- Have new option `useMethodMarshaller`
- Throw `IllegalStateException` when both `unsafeWrapDeserializedBuffer` and `useMethodMarshaller` are enabled
<!--
Visit this URL to learn more about how to write a pull request description:
https://armeria.dev/community/developer-guide#how-to-write-pull-request-description
-->
  • Loading branch information
jaeseung-bae committed May 16, 2024
1 parent 1737482 commit 8ab4284
Show file tree
Hide file tree
Showing 17 changed files with 430 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import static com.linecorp.armeria.client.grpc.GrpcClientOptions.MAX_INBOUND_MESSAGE_SIZE_BYTES;
import static com.linecorp.armeria.client.grpc.GrpcClientOptions.MAX_OUTBOUND_MESSAGE_SIZE_BYTES;
import static com.linecorp.armeria.client.grpc.GrpcClientOptions.UNSAFE_WRAP_RESPONSE_BUFFERS;
import static com.linecorp.armeria.client.grpc.GrpcClientOptions.USE_METHOD_MARSHALLER;
import static java.util.Objects.requireNonNull;

import java.net.URI;
Expand Down Expand Up @@ -81,6 +82,7 @@
import io.grpc.Codec;
import io.grpc.Compressor;
import io.grpc.DecompressorRegistry;
import io.grpc.MethodDescriptor;
import io.grpc.ServiceDescriptor;
import io.grpc.Status;

Expand Down Expand Up @@ -284,9 +286,31 @@ public GrpcClientBuilder callCredentials(CallCredentials callCredentials) {
*/
@UnstableApi
public GrpcClientBuilder enableUnsafeWrapResponseBuffers(boolean enableUnsafeWrapResponseBuffers) {
final ClientOptions options = buildOptions();
if (options.get(USE_METHOD_MARSHALLER)) {
throw new IllegalStateException(
"'unsafeWrapRequestBuffers' and 'useMethodMarshaller' are mutually exclusive."
);
}
return option(UNSAFE_WRAP_RESPONSE_BUFFERS.newValue(enableUnsafeWrapResponseBuffers));
}

/**
* Sets whether to respect the marshaller specified in gRPC {@link MethodDescriptor}.
* If disabled, the default marshaller will be used, which is more efficient.
* This property is disabled by default.
*/
@UnstableApi
public GrpcClientBuilder useMethodMarshaller(boolean useMethodMarshaller) {
final ClientOptions options = buildOptions();
if (options.get(GrpcClientOptions.UNSAFE_WRAP_RESPONSE_BUFFERS)) {
throw new IllegalStateException(
"'unsafeWrapRequestBuffers' and 'useMethodMarshaller' are mutually exclusive."
);
}
return option(USE_METHOD_MARSHALLER.newValue(useMethodMarshaller));
}

/**
* Sets the factory that creates a {@link GrpcJsonMarshaller} that serializes and deserializes request or
* response messages to and from JSON depending on the {@link SerializationFormat}. The returned
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import io.grpc.Codec;
import io.grpc.Compressor;
import io.grpc.DecompressorRegistry;
import io.grpc.MethodDescriptor;
import io.grpc.ServiceDescriptor;
import io.grpc.Status;

Expand Down Expand Up @@ -176,5 +177,13 @@ public final class GrpcClientOptions {
ClientOption.define("EXCEPTION_HANDLER",
(ctx, cause, metadata) -> GrpcStatus.fromThrowable(cause));

/**
* Sets whether to respect the marshaller specified in gRPC {@link MethodDescriptor}.
* If disabled, the default marshaller will be used, which is more efficient.
* This option is disabled by default.
*/
public static final ClientOption<Boolean> USE_METHOD_MARSHALLER =
ClientOption.define("USE_METHOD_MARSHALLER", false);

private GrpcClientOptions() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ final class ArmeriaChannel extends Channel implements ClientBuilderParams, Unwra
private final DecompressorRegistry decompressorRegistry;
private final CallCredentials credentials0;
private final GrpcExceptionHandlerFunction exceptionHandler;
private final boolean useMethodMarshaller;

ArmeriaChannel(ClientBuilderParams params,
HttpClient httpClient,
Expand All @@ -119,6 +120,7 @@ final class ArmeriaChannel extends Channel implements ClientBuilderParams, Unwra
maxOutboundMessageSizeBytes = options.get(GrpcClientOptions.MAX_OUTBOUND_MESSAGE_SIZE_BYTES);
maxInboundMessageSizeBytes = maxInboundMessageSizeBytes(options);
unsafeWrapResponseBuffers = options.get(GrpcClientOptions.UNSAFE_WRAP_RESPONSE_BUFFERS);
useMethodMarshaller = options.get(GrpcClientOptions.USE_METHOD_MARSHALLER);
compressor = options.get(GrpcClientOptions.COMPRESSOR);
decompressorRegistry = options.get(GrpcClientOptions.DECOMPRESSOR_REGISTRY);
credentials0 = options.get(GrpcClientOptions.CALL_CREDENTIALS);
Expand Down Expand Up @@ -180,7 +182,8 @@ public <I, O> ClientCall<I, O> newCall(MethodDescriptor<I, O> method, CallOption
serializationFormat,
jsonMarshaller,
unsafeWrapResponseBuffers,
exceptionHandler);
exceptionHandler,
useMethodMarshaller);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ final class ArmeriaClientCall<I, O> extends ClientCall<I, O>
SerializationFormat serializationFormat,
@Nullable GrpcJsonMarshaller jsonMarshaller,
boolean unsafeWrapResponseBuffers,
GrpcExceptionHandlerFunction exceptionHandler) {
GrpcExceptionHandlerFunction exceptionHandler,
boolean useMethodMarshaller) {
this.ctx = ctx;
this.endpointGroup = endpointGroup;
this.httpClient = httpClient;
Expand All @@ -184,7 +185,7 @@ final class ArmeriaClientCall<I, O> extends ClientCall<I, O>

requestFramer = new ArmeriaMessageFramer(ctx.alloc(), maxOutboundMessageSizeBytes, grpcWebText);
marshaller = new GrpcMessageMarshaller<>(ctx.alloc(), serializationFormat, method, jsonMarshaller,
unsafeWrapResponseBuffers);
unsafeWrapResponseBuffers, useMethodMarshaller);

if (callOptions.getExecutor() == null) {
executor = MoreExecutors.directExecutor();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,6 @@ public Object newClient(ClientBuilderParams params) {
} else {
jsonMarshaller = null;
}

final ArmeriaChannel armeriaChannel =
new ArmeriaChannel(newParams, httpClient, meterRegistry(), scheme.sessionProtocol(),
serializationFormat, jsonMarshaller, simpleMethodNames);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,14 @@ private enum MessageType {
private final MessageType responseType;
private final boolean unsafeWrapDeserializedBuffer;
private final boolean isProto;
private final boolean useMethodMarshaller;

public GrpcMessageMarshaller(ByteBufAllocator alloc,
SerializationFormat serializationFormat,
MethodDescriptor<I, O> method,
@Nullable GrpcJsonMarshaller jsonMarshaller,
boolean unsafeWrapDeserializedBuffer) {
boolean unsafeWrapDeserializedBuffer,
boolean useMethodMarshaller) {
this.alloc = requireNonNull(alloc, "alloc");
this.method = requireNonNull(method, "method");
this.unsafeWrapDeserializedBuffer = unsafeWrapDeserializedBuffer;
Expand All @@ -84,6 +86,7 @@ public GrpcMessageMarshaller(ByteBufAllocator alloc,
responseMarshaller = method.getResponseMarshaller();
requestType = marshallerType(requestMarshaller);
responseType = marshallerType(responseMarshaller);
this.useMethodMarshaller = useMethodMarshaller;
}

public ByteBuf serializeRequest(I message) throws IOException {
Expand Down Expand Up @@ -203,8 +206,17 @@ private <T> ByteBuf serializeProto(PrototypeMarshaller<T> marshaller, Message me
final ByteBuf buf = alloc.buffer(serializedSize);
boolean success = false;
try {
message.writeTo(CodedOutputStream.newInstance(buf.nioBuffer(0, serializedSize)));
buf.writerIndex(serializedSize);
if (useMethodMarshaller) {
final InputStream is = marshaller.stream((T) message);
try (ByteBufOutputStream os = new ByteBufOutputStream(buf)) {
ByteStreams.copy(is, os);
} finally {
is.close();
}
} else {
message.writeTo(CodedOutputStream.newInstance(buf.nioBuffer(0, serializedSize)));
buf.writerIndex(serializedSize);
}
success = true;
} finally {
if (!success) {
Expand Down Expand Up @@ -236,20 +248,25 @@ private <T> Message deserializeProto(PrototypeMarshaller<T> marshaller, ByteBuf
if (!buf.isReadable()) {
return prototype.getDefaultInstanceForType();
}
final CodedInputStream stream;
if (unsafeWrapDeserializedBuffer) {
stream = UnsafeByteOperations.unsafeWrap(buf.nioBuffer()).newCodedInput();
stream.enableAliasing(true);
} else {
stream = CodedInputStream.newInstance(buf.nioBuffer());
}
try {
final Message msg = prototype.getParserForType().parseFrom(stream);
try {
stream.checkLastTagWas(0);
} catch (InvalidProtocolBufferException e) {
e.setUnfinishedMessage(msg);
throw e;
final Message msg;
if (useMethodMarshaller) {
msg = (Message) marshaller.parse(new ByteBufInputStream(buf));
} else {
final CodedInputStream stream;
if (unsafeWrapDeserializedBuffer) {
stream = UnsafeByteOperations.unsafeWrap(buf.nioBuffer()).newCodedInput();
stream.enableAliasing(true);
} else {
stream = CodedInputStream.newInstance(buf.nioBuffer());
}
msg = prototype.getParserForType().parseFrom(stream);
try {
stream.checkLastTagWas(0);
} catch (InvalidProtocolBufferException e) {
e.setUnfinishedMessage(msg);
throw e;
}
}
return msg;
} catch (InvalidProtocolBufferException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ protected AbstractServerCall(HttpRequest req,
ResponseHeaders defaultHeaders,
@Nullable GrpcExceptionHandlerFunction exceptionHandler,
@Nullable Executor blockingExecutor,
boolean autoCompression) {
boolean autoCompression,
boolean useMethodMarshaller) {
requireNonNull(req, "req");
this.method = requireNonNull(method, "method");
this.simpleMethodName = requireNonNull(simpleMethodName, "simpleMethodName");
Expand All @@ -170,7 +171,7 @@ protected AbstractServerCall(HttpRequest req,
clientAcceptEncoding = req.headers().get(GrpcHeaderNames.GRPC_ACCEPT_ENCODING, "");
this.autoCompression = autoCompression;
marshaller = new GrpcMessageMarshaller<>(alloc, serializationFormat, method, jsonMarshaller,
unsafeWrapRequestBuffers);
unsafeWrapRequestBuffers, useMethodMarshaller);
this.unsafeWrapRequestBuffers = unsafeWrapRequestBuffers;
this.blockingExecutor = blockingExecutor;
defaultResponseHeaders = defaultHeaders;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ private static Map<String, GrpcJsonMarshaller> getJsonMarshallers(
private final boolean useBlockingTaskExecutor;
private final boolean unsafeWrapRequestBuffers;
private final boolean useClientTimeoutHeader;
private final boolean useMethodMarshaller;
private final String advertisedEncodingsHeader;
private final Map<SerializationFormat, ResponseHeaders> defaultHeaders;
@Nullable
Expand All @@ -154,7 +155,7 @@ private static Map<String, GrpcJsonMarshaller> getJsonMarshallers(
boolean useClientTimeoutHeader,
boolean lookupMethodFromAttribute,
@Nullable GrpcHealthCheckService grpcHealthCheckService,
boolean autoCompression) {
boolean autoCompression, boolean useMethodMarshaller) {
this.registry = requireNonNull(registry, "registry");
routes = ImmutableSet.copyOf(registry.methodsByRoute().keySet());
exchangeTypes = registry.methods().entrySet().stream()
Expand All @@ -173,6 +174,7 @@ private static Map<String, GrpcJsonMarshaller> getJsonMarshallers(
this.unsafeWrapRequestBuffers = unsafeWrapRequestBuffers;
this.lookupMethodFromAttribute = lookupMethodFromAttribute;
this.autoCompression = autoCompression;
this.useMethodMarshaller = useMethodMarshaller;

advertisedEncodingsHeader = String.join(",", decompressorRegistry.getAdvertisedMessageEncodings());

Expand Down Expand Up @@ -356,7 +358,8 @@ private <I, O> AbstractServerCall<I, O> newServerCall(
defaultHeaders.get(serializationFormat),
exceptionHandler,
blockingExecutor,
autoCompression);
autoCompression,
useMethodMarshaller);
} else {
return new StreamingServerCall<>(
req,
Expand All @@ -374,7 +377,8 @@ private <I, O> AbstractServerCall<I, O> newServerCall(
defaultHeaders.get(serializationFormat),
exceptionHandler,
blockingExecutor,
autoCompression);
autoCompression,
useMethodMarshaller);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ public final class GrpcServiceBuilder {

private boolean useClientTimeoutHeader = true;

private boolean useMethodMarshaller;

private boolean enableHealthCheckService;

private boolean autoCompression;
Expand Down Expand Up @@ -749,6 +751,11 @@ public GrpcServiceBuilder useBlockingTaskExecutor(boolean useBlockingTaskExecuto
* {@link GrpcSerializationFormats#PROTO_WEB_TEXT}.
*/
public GrpcServiceBuilder unsafeWrapRequestBuffers(boolean unsafeWrapRequestBuffers) {
if (unsafeWrapRequestBuffers && useMethodMarshaller) {
throw new IllegalStateException(
"'unsafeWrapRequestBuffers' and 'useMethodMarshaller' are mutually exclusive."
);
}
this.unsafeWrapRequestBuffers = unsafeWrapRequestBuffers;
return this;
}
Expand Down Expand Up @@ -824,6 +831,21 @@ public GrpcServiceBuilder autoCompression(boolean autoCompression) {
return this;
}

/**
* Sets whether to respect the marshaller specified in gRPC {@link MethodDescriptor}
* If not set, will use the default(false), which use more efficient way that reduce copy operation.
*/
@UnstableApi
public GrpcServiceBuilder useMethodMarshaller(boolean useMethodMarshaller) {
if (unsafeWrapRequestBuffers && useMethodMarshaller) {
throw new IllegalStateException(
"'unsafeWrapRequestBuffers' and 'useMethodMarshaller' are mutually exclusive."
);
}
this.useMethodMarshaller = useMethodMarshaller;
return this;
}

/**
* Sets the specified {@link GrpcExceptionHandlerFunction} that maps a {@link Throwable}
* to a gRPC {@link Status}.
Expand Down Expand Up @@ -1016,7 +1038,8 @@ public GrpcService build() {
useClientTimeoutHeader,
enableHttpJsonTranscoding, // The method definition might be set when transcoding is enabled.
grpcHealthCheckService,
autoCompression);
autoCompression,
useMethodMarshaller);
if (enableUnframedRequests) {
grpcService = new UnframedGrpcService(
grpcService, handlerRegistry,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,11 @@ final class StreamingServerCall<I, O> extends AbstractServerCall<I, O>
@Nullable GrpcJsonMarshaller jsonMarshaller, boolean unsafeWrapRequestBuffers,
ResponseHeaders defaultHeaders,
@Nullable GrpcExceptionHandlerFunction exceptionHandler,
@Nullable Executor blockingExecutor, boolean autoCompress) {
@Nullable Executor blockingExecutor, boolean autoCompress,
boolean useMethodMarshaller) {
super(req, method, simpleMethodName, compressorRegistry, decompressorRegistry, res,
maxResponseMessageLength, ctx, serializationFormat, jsonMarshaller, unsafeWrapRequestBuffers,
defaultHeaders, exceptionHandler, blockingExecutor, autoCompress);
defaultHeaders, exceptionHandler, blockingExecutor, autoCompress, useMethodMarshaller);
requireNonNull(req, "req");
this.method = requireNonNull(method, "method");
this.ctx = requireNonNull(ctx, "ctx");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,11 @@ final class UnaryServerCall<I, O> extends AbstractServerCall<I, O> {
ResponseHeaders defaultHeaders,
@Nullable GrpcExceptionHandlerFunction exceptionHandler,
@Nullable Executor blockingExecutor,
boolean autoCompress) {
boolean autoCompress,
boolean useMethodMarshaller) {
super(req, method, simpleMethodName, compressorRegistry, decompressorRegistry, res,
maxResponseMessageLength, ctx, serializationFormat, jsonMarshaller, unsafeWrapRequestBuffers,
defaultHeaders, exceptionHandler, blockingExecutor, autoCompress);
defaultHeaders, exceptionHandler, blockingExecutor, autoCompress, useMethodMarshaller);
requireNonNull(req, "req");
this.ctx = requireNonNull(ctx, "ctx");
final boolean grpcWebText = GrpcSerializationFormats.isGrpcWebText(serializationFormat);
Expand Down
Loading

0 comments on commit 8ab4284

Please sign in to comment.