diff --git a/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java b/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java index ddba5b8d5b1..211ef5366e9 100644 --- a/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java +++ b/protobuf-lite/src/main/java/io/grpc/protobuf/lite/ProtoLiteUtils.java @@ -81,7 +81,20 @@ public static void setExtensionRegistry(ExtensionRegistryLite newRegistry) { */ public static Marshaller marshaller(T defaultInstance) { // TODO(ejona): consider changing return type to PrototypeMarshaller (assuming ABI safe) - return new MessageMarshaller<>(defaultInstance); + return new MessageMarshaller<>(defaultInstance, -1); + } + + /** + * Creates a {@link Marshaller} for protos of the same type as {@code defaultInstance} and a + * custom limit for the recursion depth. Any negative number will leave the limit to its default + * value as defined by the protobuf library. + * + * @since 1.56.0 + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10108") + public static Marshaller marshallerWithRecursionLimit( + T defaultInstance, int recursionLimit) { + return new MessageMarshaller<>(defaultInstance, recursionLimit); } /** @@ -117,18 +130,20 @@ private ProtoLiteUtils() { private static final class MessageMarshaller implements PrototypeMarshaller { + private static final ThreadLocal> bufs = new ThreadLocal<>(); private final Parser parser; private final T defaultInstance; + private final int recursionLimit; @SuppressWarnings("unchecked") - MessageMarshaller(T defaultInstance) { - this.defaultInstance = defaultInstance; - parser = (Parser) defaultInstance.getParserForType(); + MessageMarshaller(T defaultInstance, int recursionLimit) { + this.defaultInstance = checkNotNull(defaultInstance, "defaultInstance cannot be null"); + this.parser = (Parser) defaultInstance.getParserForType(); + this.recursionLimit = recursionLimit; } - @SuppressWarnings("unchecked") @Override public Class getMessageClass() { @@ -211,6 +226,10 @@ public T parse(InputStream stream) { // when parsing. cis.setSizeLimit(Integer.MAX_VALUE); + if (recursionLimit >= 0) { + cis.setRecursionLimit(recursionLimit); + } + try { return parseFrom(cis); } catch (InvalidProtocolBufferException ipbe) { diff --git a/protobuf-lite/src/test/java/io/grpc/protobuf/lite/ProtoLiteUtilsTest.java b/protobuf-lite/src/test/java/io/grpc/protobuf/lite/ProtoLiteUtilsTest.java index 6ea836f96a7..5c25cb3b309 100644 --- a/protobuf-lite/src/test/java/io/grpc/protobuf/lite/ProtoLiteUtilsTest.java +++ b/protobuf-lite/src/test/java/io/grpc/protobuf/lite/ProtoLiteUtilsTest.java @@ -20,6 +20,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.fail; import com.google.common.io.ByteStreams; @@ -36,6 +37,7 @@ import io.grpc.Status; import io.grpc.StatusRuntimeException; import io.grpc.internal.GrpcUtil; +import io.grpc.testing.protobuf.SimpleRecursiveMessage; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -54,7 +56,7 @@ public class ProtoLiteUtilsTest { @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 @Rule public final ExpectedException thrown = ExpectedException.none(); - private Marshaller marshaller = ProtoLiteUtils.marshaller(Type.getDefaultInstance()); + private final Marshaller marshaller = ProtoLiteUtils.marshaller(Type.getDefaultInstance()); private Type proto = Type.newBuilder().setName("name").build(); @Test @@ -85,7 +87,7 @@ public void testInvalidatedMessage() throws Exception { } @Test - public void parseInvalid() throws Exception { + public void parseInvalid() { InputStream is = new ByteArrayInputStream(new byte[] {-127}); try { marshaller.parse(is); @@ -97,7 +99,7 @@ public void parseInvalid() throws Exception { } @Test - public void testMismatch() throws Exception { + public void testMismatch() { Marshaller enumMarshaller = ProtoLiteUtils.marshaller(Enum.getDefaultInstance()); // Enum's name and Type's name are both strings with tag 1. Enum altProto = Enum.newBuilder().setName(proto.getName()).build(); @@ -105,7 +107,7 @@ public void testMismatch() throws Exception { } @Test - public void introspection() throws Exception { + public void introspection() { Marshaller enumMarshaller = ProtoLiteUtils.marshaller(Enum.getDefaultInstance()); PrototypeMarshaller prototypeMarshaller = (PrototypeMarshaller) enumMarshaller; assertSame(Enum.getDefaultInstance(), prototypeMarshaller.getMessagePrototype()); @@ -219,7 +221,7 @@ public void extensionRegistry_notNull() { } @Test - public void parseFromKnowLengthInputStream() throws Exception { + public void parseFromKnowLengthInputStream() { Marshaller marshaller = ProtoLiteUtils.marshaller(Type.getDefaultInstance()); Type expect = Type.newBuilder().setName("expected name").build(); @@ -232,21 +234,106 @@ public void defaultMaxMessageSize() { assertEquals(GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE, ProtoLiteUtils.DEFAULT_MAX_MESSAGE_SIZE); } + @Test + public void testNullDefaultInstance() { + String expectedMessage = "defaultInstance cannot be null"; + assertThrows(expectedMessage, NullPointerException.class, + () -> ProtoLiteUtils.marshaller(null)); + + assertThrows(expectedMessage, NullPointerException.class, + () -> ProtoLiteUtils.marshallerWithRecursionLimit(null, 10) + ); + } + + @Test + public void givenPositiveLimit_testRecursionLimitExceeded() throws IOException { + Marshaller marshaller = ProtoLiteUtils.marshallerWithRecursionLimit( + SimpleRecursiveMessage.getDefaultInstance(), 10); + SimpleRecursiveMessage message = buildRecursiveMessage(12); + + assertRecursionLimitExceeded(marshaller, message); + } + + @Test + public void givenZeroLimit_testRecursionLimitExceeded() throws IOException { + Marshaller marshaller = ProtoLiteUtils.marshallerWithRecursionLimit( + SimpleRecursiveMessage.getDefaultInstance(), 0); + SimpleRecursiveMessage message = buildRecursiveMessage(1); + + assertRecursionLimitExceeded(marshaller, message); + } + + @Test + public void givenPositiveLimit_testRecursionLimitNotExceeded() throws IOException { + Marshaller marshaller = ProtoLiteUtils.marshallerWithRecursionLimit( + SimpleRecursiveMessage.getDefaultInstance(), 15); + SimpleRecursiveMessage message = buildRecursiveMessage(12); + + assertRecursionLimitNotExceeded(marshaller, message); + } + + @Test + public void givenZeroLimit_testRecursionLimitNotExceeded() throws IOException { + Marshaller marshaller = ProtoLiteUtils.marshallerWithRecursionLimit( + SimpleRecursiveMessage.getDefaultInstance(), 0); + SimpleRecursiveMessage message = buildRecursiveMessage(0); + + assertRecursionLimitNotExceeded(marshaller, message); + } + + @Test + public void testDefaultRecursionLimit() throws IOException { + Marshaller marshaller = ProtoLiteUtils.marshaller( + SimpleRecursiveMessage.getDefaultInstance()); + SimpleRecursiveMessage message = buildRecursiveMessage(100); + + assertRecursionLimitNotExceeded(marshaller, message); + } + + private static void assertRecursionLimitExceeded(Marshaller marshaller, + SimpleRecursiveMessage message) throws IOException { + InputStream is = marshaller.stream(message); + ByteArrayInputStream bais = new ByteArrayInputStream(ByteStreams.toByteArray(is)); + + assertThrows(StatusRuntimeException.class, () -> marshaller.parse(bais)); + } + + private static void assertRecursionLimitNotExceeded(Marshaller marshaller, + SimpleRecursiveMessage message) throws IOException { + InputStream is = marshaller.stream(message); + ByteArrayInputStream bais = new ByteArrayInputStream(ByteStreams.toByteArray(is)); + + assertEquals(message, marshaller.parse(bais)); + } + + private static SimpleRecursiveMessage buildRecursiveMessage(int depth) { + SimpleRecursiveMessage.Builder builder = SimpleRecursiveMessage.newBuilder() + .setValue("depth-" + depth); + for (int i = depth; i > 0; i--) { + builder = SimpleRecursiveMessage.newBuilder() + .setValue("depth-" + i) + .setMessage(builder.build()); + } + + return builder.build(); + } + private static class CustomKnownLengthInputStream extends InputStream implements KnownLength { + private int position = 0; - private byte[] source; + private final byte[] source; private CustomKnownLengthInputStream(byte[] source) { this.source = source; } @Override - public int available() throws IOException { + public int available() { return source.length - position; } @Override - public int read() throws IOException { + public int read() { if (position == source.length) { return -1; } diff --git a/protobuf-lite/src/test/proto/simple_recursive.proto b/protobuf-lite/src/test/proto/simple_recursive.proto new file mode 100644 index 00000000000..e3ff0f55634 --- /dev/null +++ b/protobuf-lite/src/test/proto/simple_recursive.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +package grpc.testing; + +option java_package = "io.grpc.testing.protobuf"; +option java_outer_classname = "SimpleRecursiveProto"; +option java_multiple_files = true; + +// A simple recursive message for testing purposes +message SimpleRecursiveMessage { + string value = 1; + SimpleRecursiveMessage message = 2; +} diff --git a/protobuf/src/main/java/io/grpc/protobuf/ProtoUtils.java b/protobuf/src/main/java/io/grpc/protobuf/ProtoUtils.java index c936b3c1b48..ebc708f522f 100644 --- a/protobuf/src/main/java/io/grpc/protobuf/ProtoUtils.java +++ b/protobuf/src/main/java/io/grpc/protobuf/ProtoUtils.java @@ -57,6 +57,19 @@ public static Marshaller marshaller(final T defaultInstan return ProtoLiteUtils.marshaller(defaultInstance); } + /** + * Creates a {@link Marshaller} for protos of the same type as {@code defaultInstance} and a + * custom limit for the recursion depth. Any negative number will leave the limit to its default + * value as defined by the protobuf library. + * + * @since 1.56.0 + */ + @ExperimentalApi("https://github.com/grpc/grpc-java/issues/10108") + public static Marshaller marshallerWithRecursionLimit(T defaultInstance, + int recursionLimit) { + return ProtoLiteUtils.marshallerWithRecursionLimit(defaultInstance, recursionLimit); + } + /** * Produce a metadata key for a generated protobuf type. * @@ -70,7 +83,7 @@ public static Metadata.Key keyForProto(T instance) { /** * Produce a metadata marshaller for a protobuf type. - * + * * @since 1.13.0 */ @ExperimentalApi("https://github.com/grpc/grpc-java/issues/4477")