Add new Hadoop API reading support.
charlesreiss committed Dec 1, 2011
1 parent 02d43e6 commit 66f05f3
Showing 3 changed files with 130 additions and 0 deletions.
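
In short, this commit adds the ability to create RDDs from input formats written against Hadoop's newer org.apache.hadoop.mapreduce API, alongside the existing org.apache.hadoop.mapred support. A minimal sketch of the kind of call it enables (the path, and the choice of the new-API TextInputFormat with LongWritable keys and Text values, are assumptions for illustration, not part of the diff):

import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat

// sc is an existing SparkContext; the path and the TextInputFormat choice are
// placeholders for this sketch, not something this commit prescribes.
val lines = sc.newAPIHadoopFile[LongWritable, Text, TextInputFormat]("hdfs:///tmp/input")
val lengths = lines.map(pair => pair._2.toString.length)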
88 changes: 88 additions & 0 deletions core/src/main/scala/spark/NewHadoopRDD.scala
@@ -0,0 +1,88 @@
package spark

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce.InputFormat
import org.apache.hadoop.mapreduce.InputSplit
import org.apache.hadoop.mapreduce.JobContext
import org.apache.hadoop.mapreduce.JobID
import org.apache.hadoop.mapreduce.RecordReader
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.hadoop.mapreduce.TaskAttemptID

import java.util.Date
import java.text.SimpleDateFormat

/** A Split wrapping an InputSplit from the new (org.apache.hadoop.mapreduce) API. */
class NewHadoopSplit(rddId: Int, val index: Int, @transient rawSplit: InputSplit with Writable)
  extends Split {
  val serializableHadoopSplit = new SerializableWritable(rawSplit)

  override def hashCode(): Int = (41 * (41 + rddId) + index).toInt
}

/** An RDD that reads data from Hadoop storage systems through the new MapReduce API
  * (org.apache.hadoop.mapreduce). */
class NewHadoopRDD[K, V](
    sc: SparkContext,
    inputFormatClass: Class[_ <: InputFormat[K, V]],
    keyClass: Class[K], valueClass: Class[V],
    @transient conf: Configuration)
  extends RDD[(K, V)](sc) {
  private val serializableConf = new SerializableWritable(conf)

  private val jobtrackerId: String = {
    val formatter = new SimpleDateFormat("yyyyMMddHHmm")
    formatter.format(new Date())
  }

  @transient private val jobId = new JobID(jobtrackerId, id)

  @transient private val splits_ : Array[Split] = {
    val inputFormat = inputFormatClass.newInstance
    val jobContext = new JobContext(serializableConf.value, jobId)
    val rawSplits = inputFormat.getSplits(jobContext).toArray
    val result = new Array[Split](rawSplits.size)
    for (i <- 0 until rawSplits.size)
      result(i) = new NewHadoopSplit(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
    result
  }

  override def splits = splits_

  override def compute(theSplit: Split) = new Iterator[(K, V)] {
    val split = theSplit.asInstanceOf[NewHadoopSplit]
    val conf = serializableConf.value
    val attemptId = new TaskAttemptID(jobtrackerId, id, true, split.index, 0)
    val context = new TaskAttemptContext(serializableConf.value, attemptId)
    val format = inputFormatClass.newInstance
    val reader = format.createRecordReader(split.serializableHadoopSplit.value, context)
    reader.initialize(split.serializableHadoopSplit.value, context)

    // Read ahead: hasNext advances the reader once and caches whether a record is
    // available, so repeated hasNext calls do not lose data.
    var havePair = false
    var finished = false

    override def hasNext: Boolean = {
      if (!finished && !havePair) {
        finished = !reader.nextKeyValue
        havePair = !finished
        if (finished) {
          reader.close
        }
      }
      !finished
    }

    override def next: (K, V) = {
      if (!hasNext) {
        throw new java.util.NoSuchElementException("End of stream")
      }
      havePair = false
      return (reader.getCurrentKey, reader.getCurrentValue)
    }
  }

  override def preferredLocations(split: Split) = {
    val theSplit = split.asInstanceOf[NewHadoopSplit]
    theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost")
  }

  override val dependencies: List[Dependency[_]] = Nil
}
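
One design note on compute above: it wraps the new-API RecordReader in a read-ahead iterator, so hasNext advances the reader at most once per record and next only hands out the cached pair. A standalone sketch of the same pattern, assuming nothing beyond Hadoop's RecordReader interface (the helper name recordIterator is hypothetical, not part of the commit):

import org.apache.hadoop.mapreduce.RecordReader

// Wrap a cursor-style RecordReader in a Scala Iterator using read-ahead.
def recordIterator[K, V](reader: RecordReader[K, V]): Iterator[(K, V)] = new Iterator[(K, V)] {
  private var havePair = false   // a record has been read but not yet returned
  private var finished = false   // the reader has been exhausted and closed

  override def hasNext: Boolean = {
    if (!finished && !havePair) {
      finished = !reader.nextKeyValue()  // advance exactly once
      havePair = !finished
      if (finished) reader.close()       // release the stream at end of input
    }
    !finished
  }

  override def next(): (K, V) = {
    if (!hasNext) throw new java.util.NoSuchElementException("End of stream")
    havePair = false                     // mark the cached record as consumed
    (reader.getCurrentKey, reader.getCurrentValue)
  }
}

The caching matters because RecordReader is a cursor: calling nextKeyValue twice would silently skip a record.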
29 changes: 29 additions & 0 deletions core/src/main/scala/spark/SparkContext.scala
@@ -6,6 +6,8 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.actors.remote.RemoteActor
import scala.collection.mutable.ArrayBuffer

import org.apache.hadoop.fs.Path
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.SequenceFileInputFormat
import org.apache.hadoop.io.Writable
@@ -22,6 +24,10 @@ import org.apache.hadoop.mapred.FileInputFormat
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.TextInputFormat

import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.hadoop.mapreduce.{Job => NewHadoopJob}

import spark.broadcast._

class SparkContext(
@@ -123,6 +129,29 @@ extends Logging {
      (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] =
    hadoopFile[K, V, F](path, defaultMinSplits)

  /** Get an RDD for a Hadoop file with an arbitrary new API InputFormat. */
  def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](path: String)
      (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] = {
    val job = new NewHadoopJob
    NewFileInputFormat.addInputPath(job, new Path(path))
    val conf = job.getConfiguration
    newAPIHadoopFile(path,
        fm.erasure.asInstanceOf[Class[F]],
        km.erasure.asInstanceOf[Class[K]],
        vm.erasure.asInstanceOf[Class[V]],
        conf)
  }

  /** Get an RDD for a given Hadoop file with an arbitrary new API InputFormat and extra
    * configuration options to pass to the input format.
    */
  def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](path: String,
      fClass: Class[F],
      kClass: Class[K],
      vClass: Class[V],
      conf: Configuration): RDD[(K, V)] =
    new NewHadoopRDD(this, fClass, kClass, vClass, conf)

  /** Get an RDD for a Hadoop SequenceFile with given key and value types */
  def sequenceFile[K, V](path: String,
      keyClass: Class[K],
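
The second overload above takes an explicit Configuration so callers can pass extra options to the input format. A rough usage sketch, assuming the new-API TextInputFormat, a placeholder input path, and FileInputFormat.setMaxInputSplitSize as the example of an extra option (none of these are mandated by the commit; sc is an existing SparkContext):

import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{LongWritable, Text}
import org.apache.hadoop.mapreduce.{Job => NewHadoopJob}
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat, TextInputFormat}

// Build a Configuration that records the input path plus one extra option,
// then hand it to the explicit-class overload. Path and split size are placeholders.
val job = new NewHadoopJob
NewFileInputFormat.addInputPath(job, new Path("hdfs:///tmp/input"))
NewFileInputFormat.setMaxInputSplitSize(job, 64L * 1024 * 1024)  // cap each split at 64 MB
val records = sc.newAPIHadoopFile(
  "hdfs:///tmp/input",
  classOf[TextInputFormat],
  classOf[LongWritable],
  classOf[Text],
  job.getConfiguration)

Note that this overload passes the Configuration straight to NewHadoopRDD, so the input path must already be recorded in the Configuration, as addInputPath does here and as the simpler overload does internally.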
13 changes: 13 additions & 0 deletions core/src/test/scala/spark/FileSuite.scala
@@ -128,4 +128,17 @@ class FileSuite extends FunSuite {
    assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
    sc.stop()
  }

  test("read SequenceFile using new Hadoop API") {
    import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat
    val sc = new SparkContext("local", "test")
    val tempDir = Files.createTempDir()
    val outputDir = new File(tempDir, "output").getAbsolutePath
    val nums = sc.makeRDD(1 to 3).map(x => (new IntWritable(x), new Text("a" * x)))
    nums.saveAsSequenceFile(outputDir)
    val output =
      sc.newAPIHadoopFile[IntWritable, Text, SequenceFileInputFormat[IntWritable, Text]](outputDir)
    assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
    sc.stop()
  }
}
