diff --git a/core/src/main/scala/spark/NewHadoopRDD.scala b/core/src/main/scala/spark/NewHadoopRDD.scala
new file mode 100644
index 0000000000..c40a39cbe0
--- /dev/null
+++ b/core/src/main/scala/spark/NewHadoopRDD.scala
@@ -0,0 +1,88 @@
+package spark
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.io.Writable
+import org.apache.hadoop.mapreduce.InputFormat
+import org.apache.hadoop.mapreduce.InputSplit
+import org.apache.hadoop.mapreduce.JobContext
+import org.apache.hadoop.mapreduce.JobID
+import org.apache.hadoop.mapreduce.RecordReader
+import org.apache.hadoop.mapreduce.TaskAttemptContext
+import org.apache.hadoop.mapreduce.TaskAttemptID
+
+import java.util.Date
+import java.text.SimpleDateFormat
+
+class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable)
+extends Split {
+  val serializableHadoopSplit = new SerializableWritable(rawSplit)
+
+  override def hashCode(): Int = (41 * (41 + rddId) + index).toInt
+}
+
+class NewHadoopRDD[K, V](
+    sc: SparkContext,
+    inputFormatClass: Class[_ <: InputFormat[K, V]],
+    keyClass: Class[K], valueClass: Class[V],
+    @transient conf: Configuration)
+extends RDD[(K, V)](sc) {
+  private val serializableConf = new SerializableWritable(conf)
+
+  private val jobtrackerId: String = {
+    val formatter = new SimpleDateFormat("yyyyMMddHHmm")
+    formatter.format(new Date())
+  }
+
+  @transient private val jobId = new JobID(jobtrackerId, id)
+
+  @transient private val splits_ : Array[Split] = {
+    val inputFormat = inputFormatClass.newInstance
+    val jobContext = new JobContext(serializableConf.value, jobId)
+    val rawSplits = inputFormat.getSplits(jobContext).toArray
+    val result = new Array[Split](rawSplits.size)
+    for (i <- 0 until rawSplits.size)
+      result(i) = new NewHadoopSplit(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
+    result
+  }
+
+  override def splits = splits_
+
+  override def compute(theSplit: Split) = new Iterator[(K, V)] {
+    val split = theSplit.asInstanceOf[NewHadoopSplit]
+    val conf = serializableConf.value
+    val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0)
+    val context = new TaskAttemptContext(serializableConf.value, attemptId)
+    val format = inputFormatClass.newInstance
+    val reader = format.createRecordReader(split.serializableHadoopSplit.value, context)
+    reader.initialize(split.serializableHadoopSplit.value, context)
+
+    var havePair = false
+    var finished = false
+
+    override def hasNext: Boolean = {
+      if (!finished && !havePair) {
+        finished = !reader.nextKeyValue
+        havePair = !finished
+        if (finished) {
+          reader.close
+        }
+      }
+      !finished
+    }
+
+    override def next: (K, V) = {
+      if (!hasNext) {
+        throw new java.util.NoSuchElementException("End of stream")
+      }
+      havePair = false
+      return (reader.getCurrentKey, reader.getCurrentValue)
+    }
+  }
+
+  override def preferredLocations(split: Split) = {
+    val theSplit = split.asInstanceOf[NewHadoopSplit]
+    theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost")
+  }
+
+  override val dependencies: List[Dependency[_]] = Nil
+}
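The compute method above wraps a new-API RecordReader in a Scala Iterator using one-record lookahead: hasNext advances the reader at most once, remembers whether a pair is buffered, and closes the reader as soon as it is exhausted, so repeated hasNext calls never skip records. A minimal standalone sketch of that pattern follows; it is not part of the patch, and the fetch function is a made-up stand-in for reader.nextKeyValue / getCurrentKey / getCurrentValue.

    // Sketch of the lookahead-iterator pattern used in NewHadoopRDD.compute.
    // `fetch` returns Some(next element) or None when the source is exhausted.
    class LookaheadIterator[A](fetch: () => Option[A]) extends Iterator[A] {
      private var buffered: Option[A] = None
      private var finished = false

      override def hasNext: Boolean = {
        if (!finished && buffered.isEmpty) {
          buffered = fetch()             // advance the underlying source at most once
          finished = buffered.isEmpty    // None means there is nothing left to read
        }
        !finished
      }

      override def next(): A = {
        if (!hasNext) throw new java.util.NoSuchElementException("End of stream")
        val result = buffered.get
        buffered = None                  // consume the buffered element
        result
      }
    }

    // Example: yields 1, 2, 3 and then stops.
    var i = 0
    val it = new LookaheadIterator[Int](() => { i += 1; if (i <= 3) Some(i) else None })
    println(it.toList)                   // List(1, 2, 3)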
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index f044da1e21..25b879ba96 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -6,6 +6,8 @@ import java.util.concurrent.atomic.AtomicInteger
 import scala.actors.remote.RemoteActor
 import scala.collection.mutable.ArrayBuffer
 
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.mapred.InputFormat
 import org.apache.hadoop.mapred.SequenceFileInputFormat
 import org.apache.hadoop.io.Writable
@@ -22,6 +24,10 @@ import org.apache.hadoop.mapred.FileInputFormat
 import org.apache.hadoop.mapred.JobConf
 import org.apache.hadoop.mapred.TextInputFormat
 
+import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
+import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
+import org.apache.hadoop.mapreduce.{Job => NewHadoopJob}
+
 import spark.broadcast._
 
 class SparkContext(
@@ -123,6 +129,29 @@ extends Logging {
     (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] =
     hadoopFile[K, V, F](path, defaultMinSplits)
 
+  /** Get an RDD for a Hadoop file with an arbitrary new API InputFormat. */
+  def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](path: String)
+      (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] = {
+    val job = new NewHadoopJob
+    NewFileInputFormat.addInputPath(job, new Path(path))
+    val conf = job.getConfiguration
+    newAPIHadoopFile(path,
+        fm.erasure.asInstanceOf[Class[F]],
+        km.erasure.asInstanceOf[Class[K]],
+        vm.erasure.asInstanceOf[Class[V]],
+        conf)
+  }
+
+  /** Get an RDD for a given Hadoop file with an arbitrary new API InputFormat and extra
+   * configuration options to pass to the input format.
+   */
+  def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](path: String,
+      fClass: Class[F],
+      kClass: Class[K],
+      vClass: Class[V],
+      conf: Configuration): RDD[(K, V)] =
+    new NewHadoopRDD(this, fClass, kClass, vClass, conf)
+
   /** Get an RDD for a Hadoop SequenceFile with given key and value types */
   def sequenceFile[K, V](path: String,
       keyClass: Class[K],
diff --git a/core/src/test/scala/spark/FileSuite.scala b/core/src/test/scala/spark/FileSuite.scala
index bb2d0c658b..b12014e6be 100644
--- a/core/src/test/scala/spark/FileSuite.scala
+++ b/core/src/test/scala/spark/FileSuite.scala
@@ -128,4 +128,17 @@ class FileSuite extends FunSuite {
     assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
     sc.stop()
   }
+
+  test("read SequenceFile using new Hadoop API") {
+    import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat
+    val sc = new SparkContext("local", "test")
+    val tempDir = Files.createTempDir()
+    val outputDir = new File(tempDir, "output").getAbsolutePath
+    val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x)))
+    nums.saveAsSequenceFile(outputDir)
+    val output =
+      sc.newAPIHadoopFile[IntWritable, Text, SequenceFileInputFormat[IntWritable, Text]](outputDir)
+    assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
+    sc.stop()
+  }
 }
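The test above exercises the new method end to end: it writes a SequenceFile with saveAsSequenceFile and reads it back through the new (org.apache.hadoop.mapreduce) API. For reference, a usage sketch of sc.newAPIHadoopFile outside a test; the object name and input path are illustrative only, and the path is assumed to hold a SequenceFile of IntWritable/Text pairs.

    import org.apache.hadoop.io.{IntWritable, Text}
    import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat

    import spark.SparkContext

    object NewApiExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "example")
        // "/tmp/seqfile-output" is a placeholder; it should point at a SequenceFile
        // of IntWritable/Text pairs, e.g. one written with saveAsSequenceFile.
        val pairs = sc.newAPIHadoopFile[IntWritable, Text,
          SequenceFileInputFormat[IntWritable, Text]]("/tmp/seqfile-output")
        // Copy values out of the Writable objects (which the record reader may reuse)
        // before collecting them to the driver.
        val strings = pairs.map { case (k, v) => (k.get, v.toString) }
        println(strings.collect().toList)
        sc.stop()
      }
    }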