Mammoth change to get Avro reflection working for Scala collection types

1 parent 87995d1 commit c006a751ad8c0ee81574093bc6d31c028842287e @jwills committed Mar 28, 2012
@@ -50,6 +50,16 @@
static final String CLASS_PROP = "java-class";
static final String ELEMENT_PROP = "java-element-class";
+ static Class getClassProp(Schema schema, String prop) {
+ String name = schema.getProp(prop);
+ if (name == null) return null;
+ try {
+ return Class.forName(name);
+ } catch (ClassNotFoundException e) {
+ throw new AvroRuntimeException(e);
+ }
+ }
+
/**
* This method is the whole reason for this class to exist, so that I can
* hack around a problem where calling getSimpleName on a class that is
@@ -264,4 +274,16 @@ private Schema getAnnotatedUnion(Union union, Map<String,Schema> names) {
branches.add(createSchema(branch, names));
return Schema.createUnion(branches);
}
+
+ @Override
+ protected boolean isArray(Object datum) {
+ if (datum == null) return false;
+ return (datum instanceof Collection) || datum.getClass().isArray() ||
+ (datum instanceof scala.collection.Iterable);
+ }
+
+ @Override
+ protected boolean isMap(Object datum) {
+ return (datum instanceof java.util.Map) || (datum instanceof scala.collection.Map);
+ }
}
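
These two overrides are the heart of the data-model change: stock GenericData.isArray only tests datum instanceof java.util.Collection, so a Scala collection would otherwise fall through to record/union handling and fail. A minimal illustration of the type tests involved, assuming scala-library 2.9.x on the classpath (not part of the commit):

    import java.util.Arrays;
    import scala.collection.JavaConversions;

    public class ScalaInstanceOfDemo {
      public static void main(String[] args) {
        // asScalaBuffer wraps a java.util.List in a scala.collection.mutable.Buffer
        Object datum = JavaConversions.asScalaBuffer(Arrays.asList(1, 2, 3));
        System.out.println(datum instanceof java.util.Collection);      // false: Avro's default test misses it
        System.out.println(datum instanceof scala.collection.Iterable); // true: the test added above catches it
      }
    }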
@@ -14,14 +14,106 @@
*/
package com.cloudera.scrunch;
+import java.io.IOException;
+import java.lang.reflect.Array;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Map;
+
import org.apache.avro.Schema;
+import org.apache.avro.io.ResolvingDecoder;
import org.apache.avro.reflect.ReflectDatumReader;
+import scala.collection.JavaConversions;
+
/**
 * A {@link ReflectDatumReader} that materializes Avro arrays and maps as the
 * Scala collection types recorded in the schema's "java-class" property.
 */
public class ScalaSafeReflectDatumReader<T> extends ReflectDatumReader<T> {
+
public ScalaSafeReflectDatumReader(Schema schema) {
super(schema, schema, ScalaSafeReflectData.get());
}
+
+ @Override
+ protected Object readArray(Object old, Schema expected,
+ ResolvingDecoder in) throws IOException {
+ Schema expectedType = expected.getElementType();
+ long l = in.readArrayStart();
+ long base = 0;
+ if (l > 0) {
+ Object array = newArray(old, (int) l, expected);
+ do {
+ for (long i = 0; i < l; i++) {
+ addToArray(array, base + i, read(peekArray(array), expectedType, in));
+ }
+ base += l;
+ } while ((l = in.arrayNext()) > 0);
+ return scalaIterableCheck(array, expected);
+ } else {
+ return scalaIterableCheck(newArray(old, 0, expected), expected);
+ }
+ }
+
+ @Override
+ protected Object readMap(Object old, Schema expected,
+ ResolvingDecoder in) throws IOException {
+ return scalaMapCheck(super.readMap(old, expected, in), expected);
+ }
+
+ public static Object scalaMapCheck(Object map, Schema schema) {
+ Class mapClass = ScalaSafeReflectData.getClassProp(schema,
+ ScalaSafeReflectData.CLASS_PROP);
+ if (mapClass != null && scala.collection.Map.class.isAssignableFrom(mapClass)) {
+ return JavaConversions.mapAsScalaMap((Map) map);
+ }
+ return map;
+ }
+
+ public static Object scalaIterableCheck(Object array, Schema schema) {
+ Class collectionClass = ScalaSafeReflectData.getClassProp(schema,
+ ScalaSafeReflectData.CLASS_PROP);
+ if (collectionClass != null) {
+ if (scala.collection.Iterable.class.isAssignableFrom(collectionClass)) {
+ scala.collection.Iterable it = toIter(array);
+ if (scala.collection.immutable.List.class.isAssignableFrom(collectionClass)) {
+ return it.toList();
+ } else if (scala.collection.mutable.Buffer.class.isAssignableFrom(collectionClass)) {
+ return it.toBuffer();
+ } else if (scala.collection.immutable.Set.class.isAssignableFrom(collectionClass)) {
+ return it.toSet();
+ }
+ return it;
+ }
+ }
+ return array;
+ }
+
+ private static scala.collection.Iterable toIter(Object array) {
+ return JavaConversions.collectionAsScalaIterable((Collection) array);
+ }
+
+ @Override
+ @SuppressWarnings(value="unchecked")
+ protected Object newArray(Object old, int size, Schema schema) {
+ ScalaSafeReflectData data = ScalaSafeReflectData.get();
+ Class collectionClass = ScalaSafeReflectData.getClassProp(schema,
+ ScalaSafeReflectData.CLASS_PROP);
+ if (collectionClass != null) {
+ if (old instanceof Collection) {
+ ((Collection)old).clear();
+ return old;
+ }
+ if (scala.collection.Iterable.class.isAssignableFrom(collectionClass) ||
+ collectionClass.isAssignableFrom(ArrayList.class)) {
+ return new ArrayList();
+ }
+ return data.newInstance(collectionClass, schema);
+ }
+ Class elementClass = ScalaSafeReflectData.getClassProp(schema,
+ ScalaSafeReflectData.ELEMENT_PROP);
+ if (elementClass == null)
+ elementClass = data.getClass(schema.getElementType());
+ return Array.newInstance(elementClass, size);
+ }
}
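
The reader's conversions are driven entirely by the "java-class" property (CLASS_PROP above) that reflection records on array and map schemas. A hand-built sketch of that mechanism, assuming Avro 1.x's Schema API; in practice the schema would come from ScalaSafeReflectData rather than being assembled like this:

    import java.util.ArrayList;
    import java.util.Arrays;
    import org.apache.avro.Schema;
    import com.cloudera.scrunch.ScalaSafeReflectDatumReader;

    public class JavaClassPropDemo {
      public static void main(String[] args) {
        Schema arr = Schema.createArray(Schema.create(Schema.Type.INT));
        arr.addProp("java-class", "scala.collection.immutable.List");
        // Avro decodes into a java.util collection; scalaIterableCheck then
        // rebuilds the Scala type the schema names (toIter(...).toList() here).
        Object result = ScalaSafeReflectDatumReader.scalaIterableCheck(
            new ArrayList<Integer>(Arrays.asList(1, 2, 3)), arr);
        System.out.println(result);  // List(1, 2, 3), a scala.collection.immutable.List
      }
    }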
@@ -14,13 +14,52 @@
*/
package com.cloudera.scrunch;
+import java.util.Iterator;
+import java.util.Map;
+
import org.apache.avro.reflect.ReflectDatumWriter;
+import scala.collection.JavaConversions;
+
/**
 * A {@link ReflectDatumWriter} that can serialize Scala Iterables and Maps in
 * addition to their java.util counterparts.
 */
public class ScalaSafeReflectDatumWriter<T> extends ReflectDatumWriter<T> {
public ScalaSafeReflectDatumWriter() {
super(ScalaSafeReflectData.get());
}
+
+ @Override
+ protected long getArraySize(Object array) {
+ if (array instanceof scala.collection.Iterable) {
+ return ((scala.collection.Iterable) array).size();
+ }
+ return super.getArraySize(array);
+ }
+
+ @Override
+ protected Iterator<Object> getArrayElements(Object array) {
+ if (array instanceof scala.collection.Iterable) {
+ return JavaConversions.asJavaIterable((scala.collection.Iterable) array).iterator();
+ }
+ return super.getArrayElements(array);
+ }
+
+ @Override
+ protected int getMapSize(Object map) {
+ if (map instanceof scala.collection.Map) {
+ return ((scala.collection.Map) map).size();
+ }
+ return super.getMapSize(map);
+ }
+
+ /** Called by the default implementation of {@link #writeMap} to enumerate
+ * map elements. The default implementation is for {@link Map}.*/
+ @SuppressWarnings("unchecked")
+ protected Iterable<Map.Entry<Object,Object>> getMapEntries(Object map) {
+ if (map instanceof scala.collection.Map) {
+ return JavaConversions.mapAsJavaMap((scala.collection.Map) map).entrySet();
+ }
+ return super.getMapEntries(map);
+ }
}
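
With both halves in place, a round trip through Avro binary encoding should restore the declared Scala type. A hypothetical end-to-end sketch; EncoderFactory/DecoderFactory are the Avro 1.5+ entry points, and the array schema with its "java-class" tag is invented for the example:

    import java.io.ByteArrayOutputStream;
    import java.util.Arrays;
    import org.apache.avro.Schema;
    import org.apache.avro.io.BinaryEncoder;
    import org.apache.avro.io.Decoder;
    import org.apache.avro.io.DecoderFactory;
    import org.apache.avro.io.EncoderFactory;
    import com.cloudera.scrunch.ScalaSafeReflectDatumReader;
    import com.cloudera.scrunch.ScalaSafeReflectDatumWriter;
    import scala.collection.JavaConversions;

    public class ScalaRoundTripDemo {
      public static void main(String[] args) throws Exception {
        Schema schema = Schema.createArray(Schema.create(Schema.Type.INT));
        schema.addProp("java-class", "scala.collection.immutable.List");

        // Write a Scala List through the Scala-safe writer.
        ScalaSafeReflectDatumWriter<Object> writer = new ScalaSafeReflectDatumWriter<Object>();
        writer.setSchema(schema);
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        BinaryEncoder encoder = EncoderFactory.get().binaryEncoder(out, null);
        writer.write(JavaConversions.asScalaBuffer(Arrays.asList(1, 2, 3)).toList(), encoder);
        encoder.flush();

        // Read it back; readArray plus scalaIterableCheck restore the Scala List.
        ScalaSafeReflectDatumReader<Object> reader = new ScalaSafeReflectDatumReader<Object>(schema);
        Decoder decoder = DecoderFactory.get().binaryDecoder(out.toByteArray(), null);
        System.out.println(reader.read(null, decoder));  // List(1, 2, 3)
      }
    }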
@@ -27,6 +27,8 @@
import com.cloudera.crunch.io.impl.FileSourceImpl;
import com.cloudera.crunch.type.avro.AvroInputFormat;
import com.cloudera.crunch.type.avro.AvroType;
+import com.cloudera.crunch.type.avro.Avros;
+import com.cloudera.crunch.type.avro.ReflectDataFactory;
public class AvroFileSource<T> extends FileSourceImpl<T> implements ReadableSource<T> {
@@ -44,10 +46,11 @@ public String toString() {
@Override
public void configureSource(Job job, int inputId) throws IOException {
- super.configureSource(job, inputId);
-
- job.getConfiguration().setBoolean(AvroJob.INPUT_IS_REFLECT, !this.avroType.isSpecific());
- job.getConfiguration().set(AvroJob.INPUT_SCHEMA, avroType.getSchema().toString());
+ super.configureSource(job, inputId);
+ Configuration conf = job.getConfiguration();
+ conf.setBoolean(AvroJob.INPUT_IS_REFLECT, !this.avroType.isSpecific());
+ conf.set(AvroJob.INPUT_SCHEMA, avroType.getSchema().toString());
+ Avros.configureReflectDataFactory(conf);
}
@Override
@@ -26,6 +26,8 @@
import com.cloudera.crunch.type.PType;
import com.cloudera.crunch.type.avro.AvroOutputFormat;
import com.cloudera.crunch.type.avro.AvroType;
+import com.cloudera.crunch.type.avro.Avros;
+import com.cloudera.crunch.type.avro.ReflectDataFactory;
public class AvroFileTarget extends FileTargetImpl {
public AvroFileTarget(String path) {
@@ -66,7 +68,8 @@ public void configureForMapReduce(Job job, PType<?> ptype, Path outputPath,
conf.set(schemaParam, atype.getSchema().toString());
} else if (!outputSchema.equals(atype.getSchema().toString())) {
throw new IllegalStateException("Avro targets must use the same output schema");
- }
+ }
+ Avros.configureReflectDataFactory(conf);
configureForMapReduce(job, AvroWrapper.class, NullWritable.class,
outputPath, name);
}
@@ -85,10 +85,12 @@ public void configureShuffle(Job job, GroupingOptions options) {
options.configure(job);
}
+ Avros.configureReflectDataFactory(conf);
+
Collection<String> serializations =
job.getConfiguration().getStringCollection("io.serializations");
- if (!serializations.contains(AvroSerialization.class.getName())) {
- serializations.add(AvroSerialization.class.getName());
+ if (!serializations.contains(SafeAvroSerialization.class.getName())) {
+ serializations.add(SafeAvroSerialization.class.getName());
job.getConfiguration().setStrings("io.serializations",
serializations.toArray(new String[0]));
}
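
The io.serializations swap matters because Hadoop resolves shuffle (de)serializers through SerializationFactory, which walks that list in order and asks each Serialization whether it accepts the class. A sketch of the lookup, assuming stock Hadoop 0.20/1.x and that SafeAvroSerialization (not shown in this diff) accepts AvroWrapper:

    import org.apache.avro.mapred.AvroWrapper;
    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.io.serializer.Serialization;
    import org.apache.hadoop.io.serializer.SerializationFactory;

    public class SerializationLookupDemo {
      public static void main(String[] args) {
        Configuration conf = new Configuration();
        // configureShuffle(...) above appends SafeAvroSerialization to this key.
        System.out.println(conf.get("io.serializations"));
        SerializationFactory factory = new SerializationFactory(conf);
        // Returns null unless some registered Serialization accepts AvroWrapper.
        Serialization<AvroWrapper> s = factory.getSerialization(AvroWrapper.class);
        System.out.println(s);
      }
    }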
@@ -38,23 +38,22 @@
@Override
protected List<FileStatus> listStatus(JobContext job) throws IOException {
- List<FileStatus> result = new ArrayList<FileStatus>();
- for (FileStatus file : super.listStatus(job)) {
- if (file.getPath().getName().endsWith(org.apache.avro.mapred.AvroOutputFormat.EXT)) {
- result.add(file);
- }
+ List<FileStatus> result = new ArrayList<FileStatus>();
+ for (FileStatus file : super.listStatus(job)) {
+ if (file.getPath().getName().endsWith(org.apache.avro.mapred.AvroOutputFormat.EXT)) {
+ result.add(file);
}
- return result;
+ }
+ return result;
}
@Override
public RecordReader<AvroWrapper<T>, NullWritable> createRecordReader(InputSplit split,
- TaskAttemptContext context) throws IOException, InterruptedException {
- context.setStatus(split.toString());
-
- String jsonSchema = context.getConfiguration().get(AvroJob.INPUT_SCHEMA);
- Schema schema = new Schema.Parser().parse(jsonSchema);
- return new AvroRecordReader<T>(schema);
+ TaskAttemptContext context) throws IOException, InterruptedException {
+ context.setStatus(split.toString());
+ String jsonSchema = context.getConfiguration().get(AvroJob.INPUT_SCHEMA);
+ Schema schema = new Schema.Parser().parse(jsonSchema);
+ return new AvroRecordReader<T>(schema);
}
}
@@ -47,8 +47,8 @@
schema = AvroJob.getOutputSchema(context.getConfiguration());
}
- final DataFileWriter<T> WRITER =
- new DataFileWriter<T>(Avros.REFLECT_DATA_FACTORY.<T>getWriter());
+ ReflectDataFactory factory = Avros.getReflectDataFactory(conf);
+ final DataFileWriter<T> WRITER = new DataFileWriter<T>(factory.<T>getWriter());
Path path = getDefaultWorkFile(context,
org.apache.avro.mapred.AvroOutputFormat.EXT);
@@ -28,7 +28,6 @@
import org.apache.avro.mapred.AvroJob;
import org.apache.avro.mapred.AvroWrapper;
import org.apache.avro.mapred.FsInput;
-import org.apache.avro.reflect.ReflectDatumReader;
import org.apache.avro.specific.SpecificDatumReader;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.NullWritable;
@@ -59,7 +58,8 @@ public void initialize(InputSplit genericSplit, TaskAttemptContext context) thro
SeekableInput in = new FsInput(split.getPath(), conf);
DatumReader<T> datumReader = null;
if (context.getConfiguration().getBoolean(AvroJob.INPUT_IS_REFLECT, true)) {
- datumReader = Avros.REFLECT_DATA_FACTORY.getReader(schema);
+ ReflectDataFactory factory = Avros.getReflectDataFactory(conf);
+ datumReader = factory.getReader(schema);
} else {
datumReader = new SpecificDatumReader<T>(schema);
}
@@ -24,8 +24,9 @@
import org.apache.avro.Schema.Type;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericRecord;
-import org.apache.avro.reflect.ReflectData;
import org.apache.avro.util.Utf8;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.util.ReflectionUtils;
import com.cloudera.crunch.MapFn;
import com.cloudera.crunch.Pair;
@@ -54,6 +55,21 @@
 * The instance we use for generating reflected schemas. May be modified by clients (e.g., Scrunch).
*/
public static ReflectDataFactory REFLECT_DATA_FACTORY = new ReflectDataFactory();
+
+ /**
+ * The name of the configuration parameter that tracks which reflection factory to use.
+ */
+ private static final String REFLECT_DATA_FACTORY_CLASS = "crunch.reflectdatafactory";
+
+ public static void configureReflectDataFactory(Configuration conf) {
+ conf.setClass(REFLECT_DATA_FACTORY_CLASS, REFLECT_DATA_FACTORY.getClass(),
+ ReflectDataFactory.class);
+ }
+
+ public static ReflectDataFactory getReflectDataFactory(Configuration conf) {
+ return (ReflectDataFactory) ReflectionUtils.newInstance(
+ conf.getClass(REFLECT_DATA_FACTORY_CLASS, ReflectDataFactory.class), conf);
+ }
public static MapFn<CharSequence, String> UTF8_TO_STRING = new MapFn<CharSequence, String>() {
@Override
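
The factory plumbing added to Avros closes the loop: the client records its (possibly Scala-safe) factory class in the job configuration, and tasks reflectively re-instantiate it. A sketch of the intended round trip, using only the Hadoop Configuration API and the methods added above; Scrunch would presumably assign its own ReflectDataFactory subclass before configuring the job:

    import org.apache.hadoop.conf.Configuration;
    import com.cloudera.crunch.type.avro.Avros;
    import com.cloudera.crunch.type.avro.ReflectDataFactory;

    public class ReflectDataFactoryRoundTrip {
      public static void main(String[] args) {
        Configuration conf = new Configuration();
        // Client side: record the factory class under "crunch.reflectdatafactory".
        Avros.configureReflectDataFactory(conf);
        // Task side: rebuild the factory from the configuration via ReflectionUtils.
        ReflectDataFactory factory = Avros.getReflectDataFactory(conf);
        System.out.println(factory.getClass().getName());
      }
    }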