
prevent Spark from overwriting directory silently and leaving dirty directory
commit ec490e8a61898ed36478e62abf2f04ae470aa55a 1 parent aace2c0
@CodingCat authored
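
In short, after this change saveAsTextFile / saveAsNewAPIHadoopFile fail fast when the output path already exists, instead of silently overwriting it or leaving a partially written directory behind. A minimal usage sketch of the new behavior (mirroring the FileSuite tests added below; the app name and RDD contents are placeholders):

    import com.google.common.io.Files
    import org.apache.hadoop.mapred.FileAlreadyExistsException
    import org.apache.spark.SparkContext

    val sc = new SparkContext("local", "example")
    val tempdir = Files.createTempDir()            // directory already exists on disk
    val rdd = sc.parallelize(Seq((1, "a"), (2, "b")))
    try {
      rdd.saveAsTextFile(tempdir.getPath)          // output-spec check now rejects this
    } catch {
      case e: FileAlreadyExistsException =>
        println("refused to overwrite: " + e.getMessage)
    }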
core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala (19 changed lines)
@@ -34,14 +34,11 @@ import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat}
-import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
-import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob}
-import org.apache.hadoop.mapreduce.{RecordWriter => NewRecordWriter}
+import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, Job => NewAPIHadoopJob, RecordWriter => NewRecordWriter, JobContext, SparkHadoopMapReduceUtil}
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}
// SparkHadoopWriter and SparkHadoopMapReduceUtil are actually source files defined in Spark.
import org.apache.hadoop.mapred.SparkHadoopWriter
-import org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark._
import org.apache.spark.Partitioner.defaultPartitioner
@@ -604,8 +601,12 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
val job = new NewAPIHadoopJob(conf)
job.setOutputKeyClass(keyClass)
job.setOutputValueClass(valueClass)
+
val wrappedConf = new SerializableWritable(job.getConfiguration)
- NewFileOutputFormat.setOutputPath(job, new Path(path))
+ val outpath = new Path(path)
+ NewFileOutputFormat.setOutputPath(job, outpath)
+ val jobFormat = outputFormatClass.newInstance
+ jobFormat.checkOutputSpecs(new JobContext(wrappedConf.value, job.getJobID))
val formatter = new SimpleDateFormat("yyyyMMddHHmm")
val jobtrackerID = formatter.format(new Date())
val stageId = self.id
@@ -633,7 +634,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
committer.commitTask(hadoopContext)
return 1
}
- val jobFormat = outputFormatClass.newInstance
+
/* apparently we need a TaskAttemptID to construct an OutputCommitter;
* however we're only going to use this local OutputCommitter for
* setupJob/commitJob, so we just use a dummy "map" task.
@@ -642,7 +643,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId)
val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext)
jobCommitter.setupJob(jobTaskContext)
- val count = self.context.runJob(self, writeShard _).sum
+ self.context.runJob(self, writeShard _).sum
jobCommitter.commitJob(jobTaskContext)
}
@@ -712,6 +713,10 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " +
valueClass.getSimpleName + ")")
+ val path = new Path(conf.get("mapred.output.dir"))
+ val fs = path.getFileSystem(conf)
+ conf.getOutputFormat.checkOutputSpecs(fs, conf)
+
val writer = new SparkHadoopWriter(conf)
writer.preSetup()
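
The fix delegates the directory test to Hadoop's own output-spec validation rather than adding a bespoke check: for the old mapred API, conf.getOutputFormat.checkOutputSpecs(fs, conf) throws FileAlreadyExistsException when mapred.output.dir already points at an existing path, and the new-API jobFormat.checkOutputSpecs(...) behaves the same way for FileOutputFormat subclasses. A rough, standalone approximation of the check being delegated to (illustrative only, not the actual Hadoop source):

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.Path
    import org.apache.hadoop.mapred.FileAlreadyExistsException

    // Approximation of what FileOutputFormat.checkOutputSpecs enforces:
    def checkOutputDir(conf: Configuration, outDir: Path): Unit = {
      val fs = outDir.getFileSystem(conf)
      if (fs.exists(outDir)) {
        throw new FileAlreadyExistsException(
          "Output directory " + outDir + " already exists")
      }
    }

Because the check runs in the driver before runJob is invoked, no tasks are launched and no partial output is written when the destination already exists.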
core/src/test/scala/org/apache/spark/FileSuite.scala (22 changed lines)
@@ -24,6 +24,7 @@ import scala.io.Source
import com.google.common.io.Files
import org.apache.hadoop.io._
import org.apache.hadoop.io.compress.DefaultCodec
+import org.apache.hadoop.mapred.FileAlreadyExistsException
import org.scalatest.FunSuite
import org.apache.spark.SparkContext._
@@ -208,4 +209,25 @@ class FileSuite extends FunSuite with LocalSparkContext {
assert(rdd.count() === 3)
assert(rdd.count() === 3)
}
+
+ test ("prevent user from overwriting the empty directory") {
+ sc = new SparkContext("local", "test")
+ val tempdir = Files.createTempDir()
+ var randomRDD = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 1)
+ intercept[FileAlreadyExistsException] {
+ randomRDD.saveAsTextFile(tempdir.getPath)
+ }
+ }
+
+ test ("prevent user from overwriting the non-empty directory") {
+ sc = new SparkContext("local", "test")
+ val tempdir = Files.createTempDir()
+ var randomRDD = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 1)
+ randomRDD.saveAsTextFile(tempdir.getPath + "/output")
+ assert(new File(tempdir.getPath + "/output/part-00000").exists() === true)
+ randomRDD = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 1)
+ intercept[FileAlreadyExistsException] {
+ randomRDD.saveAsTextFile(tempdir.getPath + "/output")
+ }
+ }
}