diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 1b6ad249bc5c9..80490fecd978b 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -40,7 +40,7 @@ import org.apache.mesos.MesosNativeLibrary
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
-import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, ByteInputFormat, FixedLengthBinaryInputFormat}
+import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat}
 import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
 import org.apache.spark.rdd._
 import org.apache.spark.scheduler._
@@ -510,27 +510,6 @@ class SparkContext(config: SparkConf) extends Logging {
       minPartitions).setName(path)
   }
 
-  /**
-   * Get an RDD for a Hadoop-readable dataset as byte-streams for each file
-   * (useful for binary data)
-   *
-   * @param minPartitions A suggestion value of the minimal splitting number for input data.
-   *
-   * @note Small files are preferred, large file is also allowable, but may cause bad performance.
-   */
-  def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
-  RDD[(String, Array[Byte])] = {
-    val job = new NewHadoopJob(hadoopConfiguration)
-    NewFileInputFormat.addInputPath(job, new Path(path))
-    val updateConf = job.getConfiguration
-    new BinaryFileRDD(
-      this,
-      classOf[ByteInputFormat],
-      classOf[String],
-      classOf[Array[Byte]],
-      updateConf,
-      minPartitions).setName(path)
-  }
 
   /**
    * Get an RDD for a Hadoop-readable dataset as PortableDataStream for each file
@@ -543,7 +522,7 @@ class SparkContext(config: SparkConf) extends Logging {
    * @note Small files are preferred, large file is also allowable, but may cause bad performance.
    */
   @DeveloperApi
-  def dataStreamFiles(path: String, minPartitions: Int = defaultMinPartitions):
+  def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
   RDD[(String, PortableDataStream)] = {
     val job = new NewHadoopJob(hadoopConfiguration)
     NewFileInputFormat.addInputPath(job, new Path(path))
@@ -563,10 +542,17 @@ class SparkContext(config: SparkConf) extends Logging {
    * bytes per record is constant (see FixedLengthBinaryInputFormat)
    *
    * @param path Directory to the input data files
+   * @param recordLength The length at which to split the records
    * @return An RDD of data with values, RDD[(Array[Byte])]
    */
-  def binaryRecords(path: String): RDD[Array[Byte]] = {
-    val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path)
+  def binaryRecords(path: String, recordLength: Int,
+      conf: Configuration = hadoopConfiguration): RDD[Array[Byte]] = {
+    conf.setInt("recordLength",recordLength)
+    val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path,
+      classOf[FixedLengthBinaryInputFormat],
+      classOf[LongWritable],
+      classOf[BytesWritable],
+      conf=conf)
     val data = br.map{ case (k, v) => v.getBytes}
     data
   }
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index dcaaccb437168..8c869f7f70f65 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -36,7 +36,7 @@ import org.apache.hadoop.mapred.{InputFormat, JobConf}
 import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
 
 import org.apache.spark._
-import org.apache.spark.SparkContext.{DoubleAccumulatorParam, IntAccumulatorParam}
+import org.apache.spark.SparkContext._
 import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.{EmptyRDD, RDD}
@@ -256,8 +256,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
    *
    * @param minPartitions A suggestion value of the minimal splitting number for input data.
    */
-  def dataStreamFiles(path: String, minPartitions: Int = defaultMinPartitions):
-  JavaPairRDD[String,PortableDataStream] = new JavaPairRDD(sc.dataStreamFiles(path,minPartitions))
+  def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
+  JavaPairRDD[String,PortableDataStream] = new JavaPairRDD(sc.binaryFiles(path,minPartitions))
 
   /**
    * Read a directory of files as DataInputStream from HDFS,
@@ -288,8 +288,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
    *
    * @param minPartitions A suggestion value of the minimal splitting number for input data.
    */
-  def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
-  JavaPairRDD[String, Array[Byte]] = new JavaPairRDD(sc.binaryFiles(path,minPartitions))
+  def binaryArrays(path: String, minPartitions: Int = defaultMinPartitions):
+  JavaPairRDD[String, Array[Byte]] = new JavaPairRDD(sc.binaryFiles(path,minPartitions).mapValues(_.toArray()))
 
   /**
    * Load data from a flat binary file, assuming each record is a set of numbers
@@ -299,8 +299,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
    * @param path Directory to the input data files
    * @return An RDD of data with values, JavaRDD[(Array[Byte])]
    */
-  def binaryRecords(path: String): JavaRDD[Array[Byte]] = {
-    new JavaRDD(sc.binaryRecords(path))
+  def binaryRecords(path: String,recordLength: Int): JavaRDD[Array[Byte]] = {
+    new JavaRDD(sc.binaryRecords(path,recordLength))
   }
 
   /** Get an RDD for a Hadoop SequenceFile with given key and value types.
diff --git a/core/src/main/scala/org/apache/spark/input/RawFileInput.scala b/core/src/main/scala/org/apache/spark/input/RawFileInput.scala
index d820e0efcea4a..0ff5127e5df0e 100644
--- a/core/src/main/scala/org/apache/spark/input/RawFileInput.scala
+++ b/core/src/main/scala/org/apache/spark/input/RawFileInput.scala
@@ -46,7 +46,7 @@ abstract class StreamFileInputFormat[T]
       if (file.isDir) 0L else file.getLen
     }.sum
 
-    val maxSplitSize = Math.ceil(totalLen*1.0/files.length).toLong
+    val maxSplitSize = Math.ceil(totalLen * 1.0 / files.length).toLong
     super.setMaxSplitSize(maxSplitSize)
   }
 
@@ -61,8 +61,10 @@
  */
 class PortableDataStream(split: CombineFileSplit, context: TaskAttemptContext, index: Integer)
   extends Serializable {
-
+  // transient forces file to be reopened after being moved (serialization)
+  @transient private var fileIn: FSDataInputStream = null.asInstanceOf[FSDataInputStream]
+  @transient private var isOpen = false
 
   /**
    * Calculate the path name independently of opening the file
@@ -76,13 +78,25 @@ class PortableDataStream(split: CombineFileSplit, context: TaskAttemptContext, i
    * create a new DataInputStream from the split and context
    */
   def open(): FSDataInputStream = {
-    val pathp = split.getPath(index)
-    val fs = pathp.getFileSystem(context.getConfiguration)
-    fileIn = fs.open(pathp)
-    isOpen=true
+    if (!isOpen) {
+      val pathp = split.getPath(index)
+      val fs = pathp.getFileSystem(context.getConfiguration)
+      fileIn = fs.open(pathp)
+      isOpen=true
+    }
     fileIn
   }
 
+  /**
+   * Read the file as a byte array
+   */
+  def toArray(): Array[Byte] = {
+    open()
+    val innerBuffer = ByteStreams.toByteArray(fileIn)
+    close()
+    innerBuffer
+  }
+
   /**
    * close the file (if it is already open)
    */
@@ -131,7 +145,7 @@ abstract class StreamBasedRecordReader[T](
 
   override def nextKeyValue = {
     if (!processed) {
-      val fileIn = new PortableDataStream(split,context,index)
+      val fileIn = new PortableDataStream(split, context, index)
       value = parseStream(fileIn)
       fileIn.close() // if it has not been open yet, close does nothing
       key = fileIn.getPath
@@ -157,7 +171,7 @@ private[spark] class StreamRecordReader(
     split: CombineFileSplit,
     context: TaskAttemptContext,
     index: Integer)
-  extends StreamBasedRecordReader[PortableDataStream](split,context,index) {
+  extends StreamBasedRecordReader[PortableDataStream](split, context, index) {
 
   def parseStream(inStream: PortableDataStream): PortableDataStream = inStream
 }
@@ -170,7 +184,7 @@ private[spark] class StreamInputFormat extends StreamFileInputFormat[PortableDat
   override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext)=
   {
     new CombineFileRecordReader[String,PortableDataStream](
-      split.asInstanceOf[CombineFileSplit],taContext,classOf[StreamRecordReader]
+      split.asInstanceOf[CombineFileSplit], taContext, classOf[StreamRecordReader]
     )
   }
 }
@@ -193,29 +207,4 @@ abstract class BinaryRecordReader[T](
     parseByteArray(innerBuffer)
   }
   def parseByteArray(inArray: Array[Byte]): T
-}
-
-
-
-private[spark] class ByteRecordReader(
-    split: CombineFileSplit,
-    context: TaskAttemptContext,
-    index: Integer)
-  extends BinaryRecordReader[Array[Byte]](split,context,index) {
-
-  override def parseByteArray(inArray: Array[Byte]) = inArray
-}
-
-/**
- * A class for reading the file using the BinaryRecordReader (as Byte array)
- */
-private[spark] class ByteInputFormat extends StreamFileInputFormat[Array[Byte]] {
-  override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext)=
-  {
-    new CombineFileRecordReader[String,Array[Byte]](
-      split.asInstanceOf[CombineFileSplit],taContext,classOf[ByteRecordReader]
-    )
-  }
-}
-
-
+}
\ No newline at end of file
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 96865d7b736b1..b4dc8de323c1d 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -23,6 +23,7 @@
 import java.net.URI;
 import java.util.*;
 
+import org.apache.spark.input.PortableDataStream;
 import scala.Tuple2;
 import scala.Tuple3;
 import scala.Tuple4;
@@ -852,12 +853,68 @@ public void binaryFiles() throws Exception {
 
     FileChannel channel1 = fos1.getChannel();
     ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1);
     channel1.write(bbuf);
+    channel1.close();
+    JavaPairRDD<String, PortableDataStream> readRDD = sc.binaryFiles(tempDirName,3);
+    List<Tuple2<String, PortableDataStream>> result = readRDD.collect();
+    for (Tuple2<String, PortableDataStream> res : result) {
+      Assert.assertArrayEquals(content1, res._2().toArray());
+    }
+  }
+
+  @Test
+  public void binaryFilesCaching() throws Exception {
+    // Reusing the wholeText files example
+    byte[] content1 = "spark is easy to use.\n".getBytes("utf-8");
+
+
+    String tempDirName = tempDir.getAbsolutePath();
+    File file1 = new File(tempDirName + "/part-00000");
+
+    FileOutputStream fos1 = new FileOutputStream(file1);
+
+    FileChannel channel1 = fos1.getChannel();
+    ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1);
+    channel1.write(bbuf);
+    channel1.close();
+
+    JavaPairRDD<String, PortableDataStream> readRDD = sc.binaryFiles(tempDirName,3).cache();
+    readRDD.foreach(new VoidFunction<Tuple2<String, PortableDataStream>>() {
+      @Override
+      public void call(Tuple2<String, PortableDataStream> stringPortableDataStreamTuple2) throws Exception {
+        stringPortableDataStreamTuple2._2().getPath();
+        stringPortableDataStreamTuple2._2().toArray(); // force the file to read
+      }
+    });
+
+    List<Tuple2<String, PortableDataStream>> result = readRDD.collect();
+    for (Tuple2<String, PortableDataStream> res : result) {
+      Assert.assertArrayEquals(content1, res._2().toArray());
+    }
+  }
+  @Test
+  public void binaryRecords() throws Exception {
+    // Reusing the wholeText files example
+    byte[] content1 = "spark isn't always easy to use.\n".getBytes("utf-8");
+    int numOfCopies = 10;
+    String tempDirName = tempDir.getAbsolutePath();
+    File file1 = new File(tempDirName + "/part-00000");
+
+    FileOutputStream fos1 = new FileOutputStream(file1);
+
+    FileChannel channel1 = fos1.getChannel();
+
+    for (int i=0;i<numOfCopies;i++) {
+      ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1);
+      channel1.write(bbuf);
+    }
+    channel1.close();
 
-    JavaPairRDD<String, byte[]> readRDD = sc.binaryFiles(tempDirName,3);
-    List<Tuple2<String, byte[]>> result = readRDD.collect();
-    for (Tuple2<String, byte[]> res : result) {
-      Assert.assertArrayEquals(content1, res._2());
+    JavaRDD<byte[]> readRDD = sc.binaryRecords(tempDirName,content1.length);
+    Assert.assertEquals(numOfCopies,readRDD.count());
+    List<byte[]> result = readRDD.collect();
+    for (byte[] res : result) {
+      Assert.assertArrayEquals(content1, res);
     }
   }
 
diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala
index 2bb6d809cc174..d2913c7d2d479 100644
--- a/core/src/test/scala/org/apache/spark/FileSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark
 
 import java.io.{File, FileWriter}
 
+import org.apache.spark.input.PortableDataStream
+
 import scala.io.Source
 
 import com.google.common.io.Files
@@ -240,35 +242,69 @@ class FileSuite extends FunSuite with LocalSparkContext {
     file.close()
 
     val inRdd = sc.binaryFiles(outFileName)
-    val (infile: String, indata: Array[Byte]) = inRdd.first
+    val (infile: String, indata: PortableDataStream) = inRdd.first
 
     // Try reading the output back as an object file
     assert(infile === outFileName)
-    assert(indata === testOutput)
+    assert(indata.toArray === testOutput)
+  }
+
+  test("portabledatastream caching tests") {
+    sc = new SparkContext("local", "test")
+    val outFile = new File(tempDir, "record-bytestream-00000.bin")
+    val outFileName = outFile.getAbsolutePath()
+
+    // create file
+    val testOutput = Array[Byte](1,2,3,4,5,6)
+    val bbuf = java.nio.ByteBuffer.wrap(testOutput)
+    // write data to file
+    val file = new java.io.FileOutputStream(outFile)
+    val channel = file.getChannel
+    channel.write(bbuf)
+    channel.close()
+    file.close()
+
+    val inRdd = sc.binaryFiles(outFileName).cache()
+    inRdd.foreach{
+      curData: (String, PortableDataStream) =>
+        curData._2.toArray() // force the file to read
+    }
+    val mappedRdd = inRdd.map{
+      curData: (String, PortableDataStream) =>
+        (curData._2.getPath(),curData._2)
+    }
+    val (infile: String, indata: PortableDataStream) = mappedRdd.first
+
+    // Try reading the output back as an object file
+
+    assert(indata.toArray === testOutput)
   }
 
   test("fixed record length binary file as byte array") {
     // a fixed length of 6 bytes
     sc = new SparkContext("local", "test")
+
     val outFile = new File(tempDir, "record-bytestream-00000.bin")
     val outFileName = outFile.getAbsolutePath()
 
     // create file
     val testOutput = Array[Byte](1,2,3,4,5,6)
     val testOutputCopies = 10
-    val bbuf = java.nio.ByteBuffer.wrap(testOutput)
+
     // write data to file
     val file = new java.io.FileOutputStream(outFile)
     val channel = file.getChannel
-    for(i <- 1 to testOutputCopies) channel.write(bbuf)
+    for(i <- 1 to testOutputCopies) {
+      val bbuf = java.nio.ByteBuffer.wrap(testOutput)
+      channel.write(bbuf)
+    }
     channel.close()
     file.close()
 
-    sc.hadoopConfiguration.setInt("recordLength",testOutput.length)
-    val inRdd = sc.binaryRecords(outFileName)
+    val inRdd = sc.binaryRecords(outFileName, testOutput.length)
     // make sure there are enough elements
-    assert(inRdd.count== testOutputCopies)
+    assert(inRdd.count == testOutputCopies)
     // now just compare the first one
     val indata: Array[Byte] = inRdd.first
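For reference, a minimal usage sketch of the two public APIs this patch settles on, sc.binaryFiles and sc.binaryRecords. This is not part of the diff; the application name, local master setting, input paths, and record length below are illustrative placeholders only.

import org.apache.spark.{SparkConf, SparkContext}

object BinaryInputExample {
  def main(args: Array[String]): Unit = {
    // Placeholder app name and local master; a real deployment would differ.
    val conf = new SparkConf().setAppName("binary-input-example").setMaster("local[2]")
    val sc = new SparkContext(conf)

    // binaryFiles: one (path, PortableDataStream) pair per file; the underlying
    // stream is only opened when toArray()/open() is called, so the RDD stays
    // serializable and can be cached before the bytes are materialized.
    val files = sc.binaryFiles("/tmp/binary-input", minPartitions = 4)
    val sizes = files.map { case (path, stream) => (path, stream.toArray().length) }
    sizes.collect().foreach { case (path, len) => println(s"$path -> $len bytes") }

    // binaryRecords: a flat file of fixed-length records, one Array[Byte] per record.
    val records = sc.binaryRecords("/tmp/binary-records.bin", recordLength = 6)
    println(s"record count = ${records.count()}")

    sc.stop()
  }
}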