Skip to content

Commit

Permalink
Speed up Protobuf generation through object reuse. Thanks to John Sirois for the tip.
Browse files Browse the repository at this point in the history
  • Loading branch information
Dmitriy Ryaboy committed Apr 30, 2010
1 parent 129f3e4 commit b8dae6e
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 15 deletions.
Expand Up @@ -3,6 +3,7 @@
import java.io.IOException;
import java.io.InputStream;

import com.google.common.base.Function;
import com.google.protobuf.Message;
import com.twitter.elephantbird.mapreduce.io.ProtobufWritable;
import com.twitter.elephantbird.util.Protobufs;
Expand Down Expand Up @@ -30,9 +31,11 @@ public class LzoProtobufB64LineRecordReader<M extends Message, W extends Protob
private final W value_;
private final TypeRef<M> typeRef_;
private final Base64 base64_ = new Base64();
private final Function<byte[], M> protoConverter_;

/**
 * Creates a record reader for the given protobuf message type.
 *
 * @param typeRef reference to the concrete message type; its raw class is used to build the
 *     bytes-to-message converter once, so the converter can be reused for every record
 * @param protobufWritable the writable instance kept as this reader's value
 */
public LzoProtobufB64LineRecordReader(TypeRef<M> typeRef, W protobufWritable) {
  typeRef_ = typeRef;
  protoConverter_ = Protobufs.getProtoConverter(typeRef_.getRawClass());
  // The message carries this class's own name so log lines are attributed to the right reader.
  LOG.info("LzoProtobufB64LineRecordReader, type args are " + typeRef_.getRawClass());
  value_ = protobufWritable;
}
Expand Down Expand Up @@ -79,7 +82,7 @@ public boolean nextKeyValue() throws IOException, InterruptedException {
}
pos_ = getLzoFilePos();
byte[] lineBytes = line_.toString().getBytes("UTF-8");
M protoValue = Protobufs.<M>parseFrom(typeRef_.getRawClass(), base64_.decode(lineBytes));
M protoValue = protoConverter_.apply(base64_.decode(lineBytes));
if (protoValue == null) {
continue;
}
Expand Down
Expand Up @@ -3,6 +3,7 @@
import java.io.IOException;
import java.io.InputStream;

import com.google.common.base.Function;
import com.google.protobuf.Message;
import com.twitter.data.proto.BlockStorage.SerializedBlock;
import com.twitter.elephantbird.util.Protobufs;
Expand Down Expand Up @@ -34,17 +35,17 @@
public class ProtobufBlockReader<M extends Message> {
private static final Logger LOG = LoggerFactory.getLogger(ProtobufBlockReader.class);

private InputStream in_;
private StreamSearcher searcher_;
private TypeRef<M> typeRef_;
private final InputStream in_;
private final StreamSearcher searcher_;
private final Function<byte[], M> protoConverter_;
private SerializedBlock curBlock_;
private int numLeftToReadThisBlock_ = 0;
private boolean readNewBlocks_ = true;

/**
 * Creates a block reader for messages of the given type.
 *
 * @param in input stream kept by this reader
 * @param typeRef reference to the message type; its raw class is logged and used to build the
 *     bytes-to-message converter
 */
public ProtobufBlockReader(InputStream in, TypeRef<M> typeRef) {
LOG.info("ProtobufReader, my typeClass is " + typeRef.getRawClass());
in_ = in;
typeRef_ = typeRef;
// Build the converter once here so it can be reused for each blob that is read.
protoConverter_ = Protobufs.getProtoConverter(typeRef.getRawClass());
// The searcher is constructed with the known-good position marker constant.
searcher_ = new StreamSearcher(Protobufs.KNOWN_GOOD_POSITION_MARKER);
}

Expand All @@ -61,7 +62,7 @@ public boolean readProtobuf(ProtobufWritable<M> message) throws IOException {

int blobIndex = curBlock_.getProtoBlobsCount() - numLeftToReadThisBlock_;
byte[] blob = curBlock_.getProtoBlobs(blobIndex).toByteArray();
message.set(Protobufs.<M>parseFrom(typeRef_.getRawClass(), blob));
message.set(protoConverter_.apply(blob));
numLeftToReadThisBlock_--;
return true;
}
Expand Down
Expand Up @@ -4,6 +4,7 @@
import java.io.DataOutput;
import java.io.IOException;

import com.google.common.base.Function;
import com.google.protobuf.Message;
import com.twitter.elephantbird.util.Protobufs;
import com.twitter.elephantbird.util.TypeRef;
Expand All @@ -23,15 +24,15 @@ public class ProtobufWritable<M extends Message> implements Writable {
private static final Logger LOG = LoggerFactory.getLogger(ProtobufWritable.class);

private M message_;
private TypeRef<M> typeRef_;

private final Function<byte[], M> protoConverter_;
/**
 * Creates a writable with no initial message by delegating to the two-argument constructor.
 *
 * @param typeRef reference to the message type this writable holds
 */
public ProtobufWritable(TypeRef<M> typeRef) {
this(null, typeRef);
}

/**
 * Creates a writable holding the given message.
 *
 * @param message the initial message; the single-argument constructor passes null here
 * @param typeRef reference to the message type; its raw class is used to build the
 *     bytes-to-message converter and is included in the debug log line
 */
public ProtobufWritable(M message, TypeRef<M> typeRef) {
message_ = message;
typeRef_ = typeRef;
// Build the converter once here so readFields can reuse it for every deserialization.
protoConverter_ = Protobufs.getProtoConverter(typeRef.getRawClass());
LOG.debug("ProtobufWritable, typeClass is " + typeRef.getRawClass() + " and message is " + message_);
}

Expand Down Expand Up @@ -64,7 +65,7 @@ public void readFields(DataInput in) throws IOException {
if (size > 0) {
byte[] messageBytes = new byte[size];
in.readFully(messageBytes, 0, size);
message_ = Protobufs.<M>parseFrom(typeRef_.getRawClass(), messageBytes);
message_ = protoConverter_.apply(messageBytes);
}
}
}
Expand Up @@ -3,6 +3,7 @@
import java.io.IOException;
import java.nio.charset.Charset;

import com.google.common.base.Function;
import com.google.protobuf.Message;
import com.twitter.elephantbird.pig.util.ProtobufToPig;
import com.twitter.elephantbird.util.Protobufs;
Expand All @@ -26,6 +27,7 @@ public abstract class LzoProtobufB64LinePigLoader<M extends Message> extends Lzo
private static final Logger LOG = LoggerFactory.getLogger(LzoProtobufB64LinePigLoader.class);

private TypeRef<M> typeRef_ = null;
private Function<byte[], M> protoConverter_ = null;
private final Base64 base64_ = new Base64();
private final ProtobufToPig protoToPig_ = new ProtobufToPig();

Expand All @@ -45,6 +47,7 @@ public LzoProtobufB64LinePigLoader() {
*/
public void setTypeRef(TypeRef<M> typeRef) {
  // Remember the type reference first, then derive the reusable converter from its raw class.
  this.typeRef_ = typeRef;
  this.protoConverter_ = Protobufs.getProtoConverter(typeRef.getRawClass());
}

public void skipToNextSyncPoint(boolean atFirstRecord) throws IOException {
Expand Down Expand Up @@ -72,7 +75,7 @@ public Tuple getNext() throws IOException {
Tuple t = null;
while ((line = is_.readLine(UTF8, RECORD_DELIMITER)) != null) {
incrCounter(LzoProtobufB64LinePigLoaderCounts.LinesRead, 1L);
M protoValue = Protobufs.parseFrom(typeRef_.getRawClass(), base64_.decode(line.getBytes("UTF-8")));
M protoValue = protoConverter_.apply(base64_.decode(line.getBytes("UTF-8")));
if (protoValue != null) {
t = protoToPig_.toTuple(protoValue);
incrCounter(LzoProtobufB64LinePigLoaderCounts.ProtobufsRead, 1L);
Expand Down
Expand Up @@ -7,6 +7,7 @@
import org.apache.pig.data.Tuple;
import org.apache.pig.impl.logicalLayer.schema.Schema;

import com.google.common.base.Function;
import com.google.protobuf.Message;
import com.twitter.elephantbird.pig.util.ProtobufToPig;
import com.twitter.elephantbird.util.Protobufs;
Expand All @@ -20,6 +21,7 @@
*/
public abstract class ProtobufBytesToTuple<M extends Message> extends EvalFunc<Tuple> {
private TypeRef<M> typeRef_ = null;
private Function<byte[], M> protoConverter_ = null;
private final ProtobufToPig protoToPig_ = new ProtobufToPig();

/**
Expand All @@ -29,14 +31,15 @@ public abstract class ProtobufBytesToTuple<M extends Message> extends EvalFunc<T
*/
public void setTypeRef(TypeRef<M> typeRef) {
  // Store the type reference, then build the bytes-to-message converter that exec() applies.
  this.typeRef_ = typeRef;
  this.protoConverter_ = Protobufs.getProtoConverter(typeRef.getRawClass());
}

@Override
public Tuple exec(Tuple input) throws IOException {
if (input == null || input.size() < 1) return null;
try {
DataByteArray bytes = (DataByteArray) input.get(0);
M value_ = Protobufs.parseFrom(typeRef_.getRawClass(), bytes.get());
M value_ = protoConverter_.apply(bytes.get());
return protoToPig_.toTuple(value_);
} catch (IOException e) {
return null;
Expand Down
23 changes: 23 additions & 0 deletions src/java/com/twitter/elephantbird/util/Protobufs.java
Expand Up @@ -115,6 +115,7 @@ public String apply(FieldDescriptor f) {
}


@SuppressWarnings("unchecked")
public static <M extends Message> M parseFrom(Class<M> protoClass, byte[] messageBytes) {
try {
Method parseFrom = protoClass.getMethod("parseFrom", new Class[] { byte[].class });
Expand All @@ -141,6 +142,28 @@ public static DynamicMessage parseDynamicFrom(Class<? extends Message> protoClas

return null;
}

/**
 * Creates a Function to repeatedly convert byte arrays into Messages. Using such a function
 * is more efficient than the static <code>parseFrom</code> method, since it avoids some of the
 * reflection overhead of the static function.
 *
 * <p>The returned function logs an error and yields <code>null</code> when the bytes cannot be
 * parsed, or when the parsed message is missing required fields.
 *
 * @param protoClass the concrete message class to build
 * @return a function that parses serialized bytes into an instance of <code>protoClass</code>,
 *     or <code>null</code> for invalid input
 */
public static <M extends Message> Function<byte[], M> getProtoConverter(final Class<M> protoClass) {
  return new Function<byte[], M>() {
    // Prototype builder obtained once; every call works on a clone of it.
    private final Message.Builder protoBuilder = Protobufs.getMessageBuilder(protoClass);

    @SuppressWarnings("unchecked")
    @Override
    public M apply(byte[] bytes) {
      try {
        Message.Builder builder = protoBuilder.clone().mergeFrom(bytes);
        // build() throws an unchecked UninitializedMessageException when required fields are
        // missing; check first so such records are reported as null like other invalid input.
        if (!builder.isInitialized()) {
          LOG.error("Uninitialized Protocol Buffer (missing required fields) building "
              + protoClass.getName());
          return null;
        }
        return (M) builder.build();
      } catch (InvalidProtocolBufferException e) {
        LOG.error("Invalid Protocol Buffer exception building " + protoClass.getName(), e);
        return null;
      }
    }
  };
}

public static Message instantiateFromClassName(String canonicalClassName) {
Class<? extends Message> protoClass = getInnerProtobufClass(canonicalClassName);
Expand Down
29 changes: 26 additions & 3 deletions src/test/com/twitter/elephantbird/util/TestProtobufs.java
@@ -1,17 +1,40 @@
package com.twitter.elephantbird.util;

import com.google.protobuf.Message;
import com.twitter.data.proto.tutorial.AddressBookProtos.Person;
import static org.junit.Assert.assertEquals;

import org.junit.Test;

import static org.junit.Assert.assertEquals;
import com.google.common.base.Function;
import com.google.protobuf.Message;
import com.twitter.data.proto.tutorial.AddressBookProtos.AddressBook;
import com.twitter.data.proto.tutorial.AddressBookProtos.Person;
import com.twitter.elephantbird.pig.piggybank.Fixtures;

/** Unit tests for {@link Protobufs}. */
public class TestProtobufs {

  // Shared fixture: an address book message and its serialized bytes.
  private static final AddressBook ab_ = Fixtures.buildAddressBookProto();
  private static final byte[] abBytes_ = ab_.toByteArray();

  /** Resolving a canonical inner-class name yields the generated message class. */
  @Test
  public void testGetInnerProtobufClass() {
    String canonicalClassName = "com.twitter.data.proto.tutorial.AddressBookProtos.Person";
    Class<? extends Message> klass = Protobufs.getInnerProtobufClass(canonicalClassName);
    // Expected value goes first so a failure message reports the values the right way round.
    assertEquals(Person.class, klass);
  }

  /** Parsing the serialized fixture as a dynamic message yields a message equal to the fixture. */
  @Test
  public void testDynamicParsing() {
    assertEquals(ab_, Protobufs.parseDynamicFrom(AddressBook.class, abBytes_));
  }

  /** Parsing the serialized fixture through the static helper yields the fixture. */
  @Test
  public void testStaticParsing() {
    assertEquals(ab_, Protobufs.parseFrom(AddressBook.class, abBytes_));
  }

  /** Parsing the serialized fixture through a reusable converter yields the fixture. */
  @Test
  public void testConverterParsing() {
    Function<byte[], AddressBook> protoConverter = Protobufs.getProtoConverter(AddressBook.class);
    assertEquals(ab_, protoConverter.apply(abBytes_));
  }
}

0 comments on commit b8dae6e

Please sign in to comment.