Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,20 @@ public static void setExtensionRegistry(ExtensionRegistryLite newRegistry) {
*/
public static <T extends MessageLite> Marshaller<T> marshaller(T defaultInstance) {
// TODO(ejona): consider changing return type to PrototypeMarshaller (assuming ABI safe)
return new MessageMarshaller<>(defaultInstance);
return new MessageMarshaller<>(defaultInstance, -1);
}

/**
* Creates a {@link Marshaller} for protos of the same type as {@code defaultInstance} and a
* custom limit for the recursion depth. Any negative number will leave the limit to its default
* value as defined by the protobuf library.
*
* @since 1.56.0
*/
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/10108")
public static <T extends MessageLite> Marshaller<T> marshallerWithRecursionLimit(
T defaultInstance, int recursionLimit) {
return new MessageMarshaller<>(defaultInstance, recursionLimit);
}

/**
Expand Down Expand Up @@ -117,18 +130,20 @@ private ProtoLiteUtils() {

private static final class MessageMarshaller<T extends MessageLite>
implements PrototypeMarshaller<T> {

private static final ThreadLocal<Reference<byte[]>> bufs = new ThreadLocal<>();

private final Parser<T> parser;
private final T defaultInstance;
private final int recursionLimit;

@SuppressWarnings("unchecked")
MessageMarshaller(T defaultInstance) {
this.defaultInstance = defaultInstance;
parser = (Parser<T>) defaultInstance.getParserForType();
MessageMarshaller(T defaultInstance, int recursionLimit) {
this.defaultInstance = checkNotNull(defaultInstance, "defaultInstance cannot be null");
this.parser = (Parser<T>) defaultInstance.getParserForType();
this.recursionLimit = recursionLimit;
}


@SuppressWarnings("unchecked")
@Override
public Class<T> getMessageClass() {
Expand Down Expand Up @@ -211,6 +226,10 @@ public T parse(InputStream stream) {
// when parsing.
cis.setSizeLimit(Integer.MAX_VALUE);

if (recursionLimit >= 0) {
cis.setRecursionLimit(recursionLimit);
}

try {
return parseFrom(cis);
} catch (InvalidProtocolBufferException ipbe) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.fail;

import com.google.common.io.ByteStreams;
Expand All @@ -36,6 +37,7 @@
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.internal.GrpcUtil;
import io.grpc.testing.protobuf.SimpleRecursiveMessage;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
Expand All @@ -54,7 +56,7 @@ public class ProtoLiteUtilsTest {
@SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467
@Rule public final ExpectedException thrown = ExpectedException.none();

private Marshaller<Type> marshaller = ProtoLiteUtils.marshaller(Type.getDefaultInstance());
private final Marshaller<Type> marshaller = ProtoLiteUtils.marshaller(Type.getDefaultInstance());
private Type proto = Type.newBuilder().setName("name").build();

@Test
Expand Down Expand Up @@ -85,7 +87,7 @@ public void testInvalidatedMessage() throws Exception {
}

@Test
public void parseInvalid() throws Exception {
public void parseInvalid() {
InputStream is = new ByteArrayInputStream(new byte[] {-127});
try {
marshaller.parse(is);
Expand All @@ -97,15 +99,15 @@ public void parseInvalid() throws Exception {
}

@Test
public void testMismatch() throws Exception {
public void testMismatch() {
Marshaller<Enum> enumMarshaller = ProtoLiteUtils.marshaller(Enum.getDefaultInstance());
// Enum's name and Type's name are both strings with tag 1.
Enum altProto = Enum.newBuilder().setName(proto.getName()).build();
assertEquals(proto, marshaller.parse(enumMarshaller.stream(altProto)));
}

@Test
public void introspection() throws Exception {
public void introspection() {
Marshaller<Enum> enumMarshaller = ProtoLiteUtils.marshaller(Enum.getDefaultInstance());
PrototypeMarshaller<Enum> prototypeMarshaller = (PrototypeMarshaller<Enum>) enumMarshaller;
assertSame(Enum.getDefaultInstance(), prototypeMarshaller.getMessagePrototype());
Expand Down Expand Up @@ -219,7 +221,7 @@ public void extensionRegistry_notNull() {
}

@Test
public void parseFromKnowLengthInputStream() throws Exception {
public void parseFromKnowLengthInputStream() {
Marshaller<Type> marshaller = ProtoLiteUtils.marshaller(Type.getDefaultInstance());
Type expect = Type.newBuilder().setName("expected name").build();

Expand All @@ -232,21 +234,106 @@ public void defaultMaxMessageSize() {
assertEquals(GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE, ProtoLiteUtils.DEFAULT_MAX_MESSAGE_SIZE);
}

@Test
public void testNullDefaultInstance() {
String expectedMessage = "defaultInstance cannot be null";
assertThrows(expectedMessage, NullPointerException.class,
() -> ProtoLiteUtils.marshaller(null));

assertThrows(expectedMessage, NullPointerException.class,
() -> ProtoLiteUtils.marshallerWithRecursionLimit(null, 10)
);
}

@Test
public void givenPositiveLimit_testRecursionLimitExceeded() throws IOException {
Marshaller<SimpleRecursiveMessage> marshaller = ProtoLiteUtils.marshallerWithRecursionLimit(
SimpleRecursiveMessage.getDefaultInstance(), 10);
SimpleRecursiveMessage message = buildRecursiveMessage(12);

assertRecursionLimitExceeded(marshaller, message);
}

@Test
public void givenZeroLimit_testRecursionLimitExceeded() throws IOException {
Marshaller<SimpleRecursiveMessage> marshaller = ProtoLiteUtils.marshallerWithRecursionLimit(
SimpleRecursiveMessage.getDefaultInstance(), 0);
SimpleRecursiveMessage message = buildRecursiveMessage(1);

assertRecursionLimitExceeded(marshaller, message);
}

@Test
public void givenPositiveLimit_testRecursionLimitNotExceeded() throws IOException {
Marshaller<SimpleRecursiveMessage> marshaller = ProtoLiteUtils.marshallerWithRecursionLimit(
SimpleRecursiveMessage.getDefaultInstance(), 15);
SimpleRecursiveMessage message = buildRecursiveMessage(12);

assertRecursionLimitNotExceeded(marshaller, message);
}

@Test
public void givenZeroLimit_testRecursionLimitNotExceeded() throws IOException {
Marshaller<SimpleRecursiveMessage> marshaller = ProtoLiteUtils.marshallerWithRecursionLimit(
SimpleRecursiveMessage.getDefaultInstance(), 0);
SimpleRecursiveMessage message = buildRecursiveMessage(0);

assertRecursionLimitNotExceeded(marshaller, message);
}

@Test
public void testDefaultRecursionLimit() throws IOException {
Marshaller<SimpleRecursiveMessage> marshaller = ProtoLiteUtils.marshaller(
SimpleRecursiveMessage.getDefaultInstance());
SimpleRecursiveMessage message = buildRecursiveMessage(100);

assertRecursionLimitNotExceeded(marshaller, message);
}

private static void assertRecursionLimitExceeded(Marshaller<SimpleRecursiveMessage> marshaller,
SimpleRecursiveMessage message) throws IOException {
InputStream is = marshaller.stream(message);
ByteArrayInputStream bais = new ByteArrayInputStream(ByteStreams.toByteArray(is));

assertThrows(StatusRuntimeException.class, () -> marshaller.parse(bais));
}

private static void assertRecursionLimitNotExceeded(Marshaller<SimpleRecursiveMessage> marshaller,
SimpleRecursiveMessage message) throws IOException {
InputStream is = marshaller.stream(message);
ByteArrayInputStream bais = new ByteArrayInputStream(ByteStreams.toByteArray(is));

assertEquals(message, marshaller.parse(bais));
}

private static SimpleRecursiveMessage buildRecursiveMessage(int depth) {
SimpleRecursiveMessage.Builder builder = SimpleRecursiveMessage.newBuilder()
.setValue("depth-" + depth);
for (int i = depth; i > 0; i--) {
builder = SimpleRecursiveMessage.newBuilder()
.setValue("depth-" + i)
.setMessage(builder.build());
}

return builder.build();
}

private static class CustomKnownLengthInputStream extends InputStream implements KnownLength {

private int position = 0;
private byte[] source;
private final byte[] source;

private CustomKnownLengthInputStream(byte[] source) {
this.source = source;
}

@Override
public int available() throws IOException {
public int available() {
return source.length - position;
}

@Override
public int read() throws IOException {
public int read() {
if (position == source.length) {
return -1;
}
Expand Down
13 changes: 13 additions & 0 deletions protobuf-lite/src/test/proto/simple_recursive.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
syntax = "proto3";

package grpc.testing;

option java_package = "io.grpc.testing.protobuf";
option java_outer_classname = "SimpleRecursiveProto";
option java_multiple_files = true;

// A simple recursive message for testing purposes
message SimpleRecursiveMessage {
string value = 1;
SimpleRecursiveMessage message = 2;
}
15 changes: 14 additions & 1 deletion protobuf/src/main/java/io/grpc/protobuf/ProtoUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,19 @@ public static <T extends Message> Marshaller<T> marshaller(final T defaultInstan
return ProtoLiteUtils.marshaller(defaultInstance);
}

/**
* Creates a {@link Marshaller} for protos of the same type as {@code defaultInstance} and a
* custom limit for the recursion depth. Any negative number will leave the limit to its default
* value as defined by the protobuf library.
*
* @since 1.56.0
*/
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/10108")
public static <T extends Message> Marshaller<T> marshallerWithRecursionLimit(T defaultInstance,
int recursionLimit) {
return ProtoLiteUtils.marshallerWithRecursionLimit(defaultInstance, recursionLimit);
}

/**
* Produce a metadata key for a generated protobuf type.
*
Expand All @@ -70,7 +83,7 @@ public static <T extends Message> Metadata.Key<T> keyForProto(T instance) {

/**
* Produce a metadata marshaller for a protobuf type.
*
*
* @since 1.13.0
*/
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/4477")
Expand Down