set map_input_file environment variable in PipedRDD

1 parent 3eb009f · commit 2ba805ee7af681dce75d203816222779ccff45f2 · @tgravescs committed Mar 6, 2014
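For context, Hadoop streaming exposes the name of the current input file to user scripts through the map.input.file / mapreduce.map.input.file job properties, which show up in the subprocess environment as map_input_file and mapreduce_map_input_file. This commit mirrors that behavior for commands launched through RDD.pipe. A minimal usage sketch, assuming an existing SparkContext sc; the input path is an illustrative placeholder and not part of the commit:

// Each partition of a file-backed HadoopRDD corresponds to one FileSplit, so the
// piped subprocess can now ask which file its partition's data came from.
val lines = sc.textFile("hdfs:///data/logs")   // hypothetical input path
// printenv prints the value of the variable that PipedRDD sets for each partition.
val sourceFiles = lines.pipe("printenv map_input_file").distinct()
sourceFiles.collect().foreach(println)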
core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala (16 lines changed)
@@ -26,8 +26,10 @@ import scala.collection.mutable.ArrayBuffer
import scala.io.Source
import scala.reflect.ClassTag
+import org.apache.hadoop.mapred.FileSplit
import org.apache.spark.{Partition, SparkEnv, TaskContext}
+
/**
* An RDD that pipes the contents of each parent partition through an external command
* (printing them one per line) and returns the output as a collection of strings.
@@ -59,6 +61,20 @@ class PipedRDD[T: ClassTag](
val currentEnvVars = pb.environment()
envVars.foreach { case (variable, value) => currentEnvVars.put(variable, value) }
+ // for compatibility with Hadoop which sets these env variables
+ // so the user code can access the input filename
+ if (split.isInstanceOf[HadoopPartition]) {
+ val hadoopSplit = split.asInstanceOf[HadoopPartition]
+
+ if (hadoopSplit.inputSplit.value.isInstanceOf[FileSplit]) {
+ val is: FileSplit = hadoopSplit.inputSplit.value.asInstanceOf[FileSplit]
+ // map.input.file is deprecated in favor of mapreduce.map.input.file but set both
+ // since it's not removed yet
+ currentEnvVars.put("map_input_file", is.getPath().toString())
+ currentEnvVars.put("mapreduce_map_input_file", is.getPath().toString())
+ }
+ }
+
val proc = pb.start()
val env = SparkEnv.get
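Aside, not part of the commit: the added block above can also be written as a nested pattern match over split and the wrapped InputSplit, avoiding the paired isInstanceOf/asInstanceOf calls. A sketch using the same names as the hunk above:

// Equivalent sketch of the added block as a pattern match; behavior is unchanged.
split match {
  case hadoopSplit: HadoopPartition =>
    hadoopSplit.inputSplit.value match {
      case fileSplit: FileSplit =>
        // Mirror Hadoop's map.input.file / mapreduce.map.input.file for the piped command.
        currentEnvVars.put("map_input_file", fileSplit.getPath().toString())
        currentEnvVars.put("mapreduce_map_input_file", fileSplit.getPath().toString())
      case _ => // split is not file-based; nothing to expose
    }
  case _ => // not a Hadoop partition; nothing to expose
}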
core/src/test/scala/org/apache/spark/PipedRDDSuite.scala (41 lines changed)
@@ -19,6 +19,14 @@ package org.apache.spark
import org.scalatest.FunSuite
+
+import org.apache.spark.rdd.{HadoopRDD, PipedRDD, HadoopPartition}
+import org.apache.hadoop.mapred.{JobConf, TextInputFormat, FileSplit}
+import org.apache.hadoop.fs.Path
+
+import scala.collection.Map
+import org.apache.hadoop.io.{Text, LongWritable}
+
class PipedRDDSuite extends FunSuite with SharedSparkContext {
test("basic pipe") {
@@ -89,4 +97,37 @@ class PipedRDDSuite extends FunSuite with SharedSparkContext {
}
}
+ test("test pipe exports map_input_file") {
+ testExportInputFile("map_input_file")
+ }
+
+ test("test pipe exports mapreduce_map_input_file") {
+ testExportInputFile("mapreduce_map_input_file")
+ }
+
+ def testExportInputFile(varName:String) {
+ val nums = new HadoopRDD(sc, new JobConf(), classOf[TextInputFormat], classOf[LongWritable],
+ classOf[Text], 2) {
+ override def getPartitions: Array[Partition] = Array(generateFakeHadoopPartition())
+ override val getDependencies = List[Dependency[_]]()
+ override def compute(theSplit: Partition, context: TaskContext) = {
+ new InterruptibleIterator[(LongWritable, Text)](context, Iterator((new LongWritable(1),
+ new Text("b"))))
+ }
+ }
+ val hadoopPart1 = generateFakeHadoopPartition()
+ val pipedRdd = new PipedRDD(nums, "printenv " + varName)
+ val tContext = new TaskContext(0, 0, 0, interrupted = false, runningLocally = false,
+ taskMetrics = null)
+ val rddIter = pipedRdd.compute(hadoopPart1, tContext)
+ val arr = rddIter.toArray
+ assert(arr(0) == "/some/path")
+ }
+
+ def generateFakeHadoopPartition(): HadoopPartition = {
+ val split = new FileSplit(new Path("/some/path"), 0, 1,
+ Array[String]("loc1", "loc2", "loc3", "loc4", "loc5"))
+ new HadoopPartition(sc.newRddId(), 1, split)
+ }
+
}
