Add new Hadoop API reading support.

commit 66f05f383e66d24326199101bafa5adb4504519c (parent 02d43e6)
Authored by @woggle
core/src/main/scala/spark/NewHadoopRDD.scala (88 lines added)
@@ -0,0 +1,88 @@
+package spark
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.io.Writable
+import org.apache.hadoop.mapreduce.InputFormat
+import org.apache.hadoop.mapreduce.InputSplit
+import org.apache.hadoop.mapreduce.JobContext
+import org.apache.hadoop.mapreduce.JobID
+import org.apache.hadoop.mapreduce.RecordReader
+import org.apache.hadoop.mapreduce.TaskAttemptContext
+import org.apache.hadoop.mapreduce.TaskAttemptID
+
+import java.util.Date
+import java.text.SimpleDateFormat
+
+class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable)
+extends Split {
+ val serializableHadoopSplit = new SerializableWritable(rawSplit)
+
+ override def hashCode(): Int = (41 * (41 + rddId) + index).toInt
+}
+
+class NewHadoopRDD[K, V](
+ sc: SparkContext,
+ inputFormatClass: Class[_ <: InputFormat[K, V]],
+ keyClass: Class[K], valueClass: Class[V],
+ @transient conf: Configuration)
+extends RDD[(K, V)](sc) {
+ private val serializableConf = new SerializableWritable(conf)
+
+ private val jobtrackerId: String = {
+ val formatter = new SimpleDateFormat("yyyyMMddHHmm")
+ formatter.format(new Date())
+ }
+
+ @transient private val jobId = new JobID(jobtrackerId, id)
+
+ @transient private val splits_ : Array[Split] = {
+ val inputFormat = inputFormatClass.newInstance
+ val jobContext = new JobContext(serializableConf.value, jobId)
+ val rawSplits = inputFormat.getSplits(jobContext).toArray
+ val result = new Array[Split](rawSplits.size)
+ for (i <- 0 until rawSplits.size)
+ result(i) = new NewHadoopSplit(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
+ result
+ }
+
+ override def splits = splits_
+
+ override def compute(theSplit: Split) = new Iterator[(K, V)] {
+ val split = theSplit.asInstanceOf[NewHadoopSplit]
+ val conf = serializableConf.value
+ val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0)
+ val context = new TaskAttemptContext(serializableConf.value, attemptId)
+ val format = inputFormatClass.newInstance
+ val reader = format.createRecordReader(split.serializableHadoopSplit.value, context)
+ reader.initialize(split.serializableHadoopSplit.value, context)
+
+ var havePair = false
+ var finished = false
+
+ override def hasNext: Boolean = {
+ if (!finished && !havePair) {
+ finished = !reader.nextKeyValue
+ havePair = !finished
+ if (finished) {
+ reader.close
+ }
+ }
+ !finished
+ }
+
+ override def next: (K, V) = {
+ if (!hasNext) {
+ throw new java.util.NoSuchElementException("End of stream")
+ }
+ havePair = false
+ return (reader.getCurrentKey, reader.getCurrentValue)
+ }
+ }
+
+ override def preferredLocations(split: Split) = {
+ val theSplit = split.asInstanceOf[NewHadoopSplit]
+ theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost")
+ }
+
+ override val dependencies: List[Dependency[_]] = Nil
+}
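For reference, an RDD of this class can also be constructed directly against the constructor above. The sketch below is illustrative only: `sc`, the key/value types, and the use of SequenceFileInputFormat are assumptions, and the newAPIHadoopFile helpers added to SparkContext (next file) are the intended entry point.

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.io.{IntWritable, Text}
    import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat

    // Assumes `sc` is an existing SparkContext; types here are illustrative.
    val conf = new Configuration()
    val rdd = new NewHadoopRDD(
      sc,
      classOf[SequenceFileInputFormat[IntWritable, Text]],
      classOf[IntWritable],
      classOf[Text],
      conf)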
core/src/main/scala/spark/SparkContext.scala (29 lines changed)
@@ -6,6 +6,8 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.actors.remote.RemoteActor
import scala.collection.mutable.ArrayBuffer
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.SequenceFileInputFormat
import org.apache.hadoop.io.Writable
@@ -22,6 +24,10 @@ import org.apache.hadoop.mapred.FileInputFormat
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.TextInputFormat
+import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
+import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
+import org.apache.hadoop.mapreduce.{Job => NewHadoopJob}
+
import spark.broadcast._
class SparkContext(
@@ -123,6 +129,29 @@ extends Logging {
(implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] =
hadoopFile[K, V, F](path, defaultMinSplits)
+ /** Get an RDD for a Hadoop file with an arbitrary new API InputFormat. */
+ def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](path: String)
+ (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] = {
+ val job = new NewHadoopJob
+ NewFileInputFormat.addInputPath(job, new Path(path))
+ val conf = job.getConfiguration
+ newAPIHadoopFile(path,
+ fm.erasure.asInstanceOf[Class[F]],
+ km.erasure.asInstanceOf[Class[K]],
+ vm.erasure.asInstanceOf[Class[V]],
+ conf)
+ }
+
+ /** Get an RDD for a given Hadoop file with an arbitrary new API InputFormat and extra
+ * configuration options to pass to the input format.
+ */
+ def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](path: String,
+ fClass: Class[F],
+ kClass: Class[K],
+ vClass: Class[V],
+ conf: Configuration): RDD[(K, V)] =
+ new NewHadoopRDD(this, fClass, kClass, vClass, conf)
+
/** Get an RDD for a Hadoop SequenceFile with given key and value types */
def sequenceFile[K, V](path: String,
keyClass: Class[K],
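A minimal usage sketch of the two new helpers, assuming `sc` is an existing SparkContext; the HDFS path, key/value types, and the configuration key set below are hypothetical examples, not part of this commit.

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.io.{IntWritable, Text}
    import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat

    // Simple form: classes are derived from the type parameters, the Configuration is built internally.
    val pairs = sc.newAPIHadoopFile[IntWritable, Text, SequenceFileInputFormat[IntWritable, Text]]("hdfs:///tmp/pairs")

    // Explicit form: pass the classes plus a Configuration carrying extra options for the input format.
    val conf = new Configuration()
    conf.set("io.file.buffer.size", "65536")  // example option; any key the input format reads
    val pairs2 = sc.newAPIHadoopFile("hdfs:///tmp/pairs",
      classOf[SequenceFileInputFormat[IntWritable, Text]],
      classOf[IntWritable],
      classOf[Text],
      conf)

    // Hadoop record readers may reuse Writable objects, so convert values before collecting them.
    pairs.map { case (k, v) => (k.get, v.toString) }.collect()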
core/src/test/scala/spark/FileSuite.scala (13 lines changed)
@@ -128,4 +128,17 @@ class FileSuite extends FunSuite {
assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
sc.stop()
}
+
+ test("read SequenceFile using new Hadoop API") {
+ import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat
+ val sc = new SparkContext("local", "test")
+ val tempDir = Files.createTempDir()
+ val outputDir = new File(tempDir, "output").getAbsolutePath
+ val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x)))
+ nums.saveAsSequenceFile(outputDir)
+ val output =
+ sc.newAPIHadoopFile[IntWritable, Text, SequenceFileInputFormat[IntWritable, Text]](outputDir)
+ assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
+ sc.stop()
+ }
}