fixed several scala-style issues, changed structure of binaryFiles, removed excessive classes, added new tests. The caching tests still have a serialization issue, but that should be easily fixed as well.
kmader committed Oct 1, 2014
1 parent 932a206 commit 238c83c
Showing 5 changed files with 146 additions and 78 deletions.
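The commit consolidates binary-file reading around PortableDataStream: binaryFiles now returns lazily opened streams rather than raw byte arrays, and binaryRecords takes the record length as an explicit argument. A minimal Scala sketch of the reworked API as it appears in this diff; the path, app name and master are illustrative only:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.SparkContext._          // pair-RDD implicits (Spark 1.x style)
    import org.apache.spark.input.PortableDataStream
    import org.apache.spark.rdd.RDD

    val sc = new SparkContext(new SparkConf().setAppName("binary-io").setMaster("local"))

    // each element is (file path, stream); nothing is read until the stream is used
    val files: RDD[(String, PortableDataStream)] = sc.binaryFiles("/tmp/binary-input")
    val bytesPerFile: RDD[(String, Array[Byte])] = files.mapValues(_.toArray())

    // fixed-length records: every element is one recordLength-sized Array[Byte]
    val records: RDD[Array[Byte]] = sc.binaryRecords("/tmp/fixed-records.bin", recordLength = 6)

toArray() replaces the old RDD[(String, Array[Byte])] behaviour and is what the reworked tests below call to force each file to actually be read.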
36 changes: 11 additions & 25 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -40,7 +40,7 @@ import org.apache.mesos.MesosNativeLibrary
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, ByteInputFormat, FixedLengthBinaryInputFormat}
import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat}
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
import org.apache.spark.scheduler._
@@ -510,27 +510,6 @@ class SparkContext(config: SparkConf) extends Logging {
minPartitions).setName(path)
}

/**
* Get an RDD for a Hadoop-readable dataset as byte-streams for each file
* (useful for binary data)
*
* @param minPartitions A suggestion value of the minimal splitting number for input data.
*
* @note Small files are preferred, large file is also allowable, but may cause bad performance.
*/
def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
RDD[(String, Array[Byte])] = {
val job = new NewHadoopJob(hadoopConfiguration)
NewFileInputFormat.addInputPath(job, new Path(path))
val updateConf = job.getConfiguration
new BinaryFileRDD(
this,
classOf[ByteInputFormat],
classOf[String],
classOf[Array[Byte]],
updateConf,
minPartitions).setName(path)
}

/**
* Get an RDD for a Hadoop-readable dataset as PortableDataStream for each file
@@ -543,7 +522,7 @@
* @note Small files are preferred, large file is also allowable, but may cause bad performance.
*/
@DeveloperApi
def dataStreamFiles(path: String, minPartitions: Int = defaultMinPartitions):
def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
RDD[(String, PortableDataStream)] = {
val job = new NewHadoopJob(hadoopConfiguration)
NewFileInputFormat.addInputPath(job, new Path(path))
@@ -563,10 +542,17 @@
* bytes per record is constant (see FixedLengthBinaryInputFormat)
*
* @param path Directory to the input data files
* @param recordLength The length at which to split the records
* @return An RDD of data with values, RDD[(Array[Byte])]
*/
def binaryRecords(path: String): RDD[Array[Byte]] = {
val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path)
def binaryRecords(path: String, recordLength: Int,
conf: Configuration = hadoopConfiguration): RDD[Array[Byte]] = {
conf.setInt("recordLength",recordLength)
val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path,
classOf[FixedLengthBinaryInputFormat],
classOf[LongWritable],
classOf[BytesWritable],
conf=conf)
val data = br.map{ case (k, v) => v.getBytes}
data
}
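For reference, the new binaryRecords body is a thin wrapper over newAPIHadoopFile with FixedLengthBinaryInputFormat; the sketch below spells that wiring out by hand. It assumes an existing SparkContext sc, and that FixedLengthBinaryInputFormat is visible from user code, which may not hold if the class ends up package-private:

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.io.{BytesWritable, LongWritable}
    import org.apache.spark.input.FixedLengthBinaryInputFormat

    // same configuration key that binaryRecords sets from its recordLength argument
    val conf = new Configuration(sc.hadoopConfiguration)
    conf.setInt("recordLength", 6)

    val raw = sc.newAPIHadoopFile(
      "/tmp/fixed-records.bin",
      classOf[FixedLengthBinaryInputFormat],
      classOf[LongWritable],
      classOf[BytesWritable],
      conf)
    val records = raw.map { case (_, bytes) => bytes.getBytes }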
14 changes: 7 additions & 7 deletions core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -36,7 +36,7 @@ import org.apache.hadoop.mapred.{InputFormat, JobConf}
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}

import org.apache.spark._
import org.apache.spark.SparkContext.{DoubleAccumulatorParam, IntAccumulatorParam}
import org.apache.spark.SparkContext._
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.{EmptyRDD, RDD}
@@ -256,8 +256,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
*
* @param minPartitions A suggestion value of the minimal splitting number for input data.
*/
def dataStreamFiles(path: String, minPartitions: Int = defaultMinPartitions):
JavaPairRDD[String,PortableDataStream] = new JavaPairRDD(sc.dataStreamFiles(path,minPartitions))
def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
JavaPairRDD[String,PortableDataStream] = new JavaPairRDD(sc.binaryFiles(path,minPartitions))

/**
* Read a directory of files as DataInputStream from HDFS,
@@ -288,8 +288,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
*
* @param minPartitions A suggestion value of the minimal splitting number for input data.
*/
def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
JavaPairRDD[String, Array[Byte]] = new JavaPairRDD(sc.binaryFiles(path,minPartitions))
def binaryArrays(path: String, minPartitions: Int = defaultMinPartitions):
JavaPairRDD[String, Array[Byte]] = new JavaPairRDD(sc.binaryFiles(path,minPartitions).mapValues(_.toArray()))

/**
* Load data from a flat binary file, assuming each record is a set of numbers
@@ -299,8 +299,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* @param path Directory to the input data files
* @return An RDD of data with values, JavaRDD[(Array[Byte])]
*/
def binaryRecords(path: String): JavaRDD[Array[Byte]] = {
new JavaRDD(sc.binaryRecords(path))
def binaryRecords(path: String,recordLength: Int): JavaRDD[Array[Byte]] = {
new JavaRDD(sc.binaryRecords(path,recordLength))
}

/** Get an RDD for a Hadoop SequenceFile with given key and value types.
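On the Java side the renaming mirrors the Scala API: binaryFiles now yields PortableDataStream values, and the old byte-array behaviour survives as binaryArrays, which simply materialises each stream. The same two calls expressed in Scala, assuming an existing SparkContext sc; path and partition count are made up:

    import org.apache.spark.SparkContext._   // mapValues on pair RDDs (Spark 1.x style)
    import org.apache.spark.input.PortableDataStream
    import org.apache.spark.rdd.RDD

    val streams: RDD[(String, PortableDataStream)] =
      sc.binaryFiles("/tmp/binary-input", minPartitions = 3)

    // what JavaSparkContext.binaryArrays does: read every stream eagerly into a byte array
    val arrays: RDD[(String, Array[Byte])] = streams.mapValues(_.toArray())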
59 changes: 24 additions & 35 deletions core/src/main/scala/org/apache/spark/input/RawFileInput.scala
@@ -46,7 +46,7 @@ abstract class StreamFileInputFormat[T]
if (file.isDir) 0L else file.getLen
}.sum

val maxSplitSize = Math.ceil(totalLen*1.0/files.length).toLong
val maxSplitSize = Math.ceil(totalLen * 1.0 / files.length).toLong
super.setMaxSplitSize(maxSplitSize)
}

@@ -61,8 +61,10 @@
*/
class PortableDataStream(split: CombineFileSplit, context: TaskAttemptContext, index: Integer)
extends Serializable {

// transient forces file to be reopened after being moved (serialization)
@transient
private var fileIn: FSDataInputStream = null.asInstanceOf[FSDataInputStream]
@transient
private var isOpen = false
/**
* Calculate the path name independently of opening the file
@@ -76,13 +78,25 @@ class PortableDataStream(split: CombineFileSplit, context: TaskAttemptContext, i
* create a new DataInputStream from the split and context
*/
def open(): FSDataInputStream = {
val pathp = split.getPath(index)
val fs = pathp.getFileSystem(context.getConfiguration)
fileIn = fs.open(pathp)
isOpen=true
if (!isOpen) {
val pathp = split.getPath(index)
val fs = pathp.getFileSystem(context.getConfiguration)
fileIn = fs.open(pathp)
isOpen=true
}
fileIn
}

/**
* Read the file as a byte array
*/
def toArray(): Array[Byte] = {
open()
val innerBuffer = ByteStreams.toByteArray(fileIn)
close()
innerBuffer
}

/**
* close the file (if it is already open)
*/
@@ -131,7 +145,7 @@ abstract class StreamBasedRecordReader[T](

override def nextKeyValue = {
if (!processed) {
val fileIn = new PortableDataStream(split,context,index)
val fileIn = new PortableDataStream(split, context, index)
value = parseStream(fileIn)
fileIn.close() // if it has not been open yet, close does nothing
key = fileIn.getPath
@@ -157,7 +171,7 @@ private[spark] class StreamRecordReader(
split: CombineFileSplit,
context: TaskAttemptContext,
index: Integer)
extends StreamBasedRecordReader[PortableDataStream](split,context,index) {
extends StreamBasedRecordReader[PortableDataStream](split, context, index) {

def parseStream(inStream: PortableDataStream): PortableDataStream = inStream
}
@@ -170,7 +184,7 @@ private[spark] class StreamInputFormat extends StreamFileInputFormat[PortableDat
override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext)=
{
new CombineFileRecordReader[String,PortableDataStream](
split.asInstanceOf[CombineFileSplit],taContext,classOf[StreamRecordReader]
split.asInstanceOf[CombineFileSplit], taContext, classOf[StreamRecordReader]
)
}
}
@@ -193,29 +207,4 @@ abstract class BinaryRecordReader[T](
parseByteArray(innerBuffer)
}
def parseByteArray(inArray: Array[Byte]): T
}



private[spark] class ByteRecordReader(
split: CombineFileSplit,
context: TaskAttemptContext,
index: Integer)
extends BinaryRecordReader[Array[Byte]](split,context,index) {

override def parseByteArray(inArray: Array[Byte]) = inArray
}

/**
* A class for reading the file using the BinaryRecordReader (as Byte array)
*/
private[spark] class ByteInputFormat extends StreamFileInputFormat[Array[Byte]] {
override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext)=
{
new CombineFileRecordReader[String,Array[Byte]](
split.asInstanceOf[CombineFileSplit],taContext,classOf[ByteRecordReader]
)
}
}


}
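The two @transient fields above are what make a cached or shuffled PortableDataStream usable after deserialization: the stream handle is dropped during serialization and lazily reopened on first use, and open() is now a no-op when the stream is already open. A sketch of driver-side use, assuming the imports and sc from the earlier sketches:

    val files = sc.binaryFiles("/tmp/binary-input")

    // whole-file reads: toArray() opens, reads and closes the stream internally
    val firstBytes: Array[Byte] = files.first()._2.toArray()

    // streaming access: manage the FSDataInputStream yourself
    val pds: PortableDataStream = files.first()._2
    val path = pds.getPath()                // computed without opening the file
    val in = pds.open()
    val firstByte = in.read()
    pds.close()                             // does nothing if open() was never called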
65 changes: 61 additions & 4 deletions core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -23,6 +23,7 @@
import java.net.URI;
import java.util.*;

import org.apache.spark.input.PortableDataStream;
import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple4;
@@ -852,12 +853,68 @@ public void binaryFiles() throws Exception {
FileChannel channel1 = fos1.getChannel();
ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1);
channel1.write(bbuf);
channel1.close();
JavaPairRDD<String, PortableDataStream> readRDD = sc.binaryFiles(tempDirName,3);
List<Tuple2<String, PortableDataStream>> result = readRDD.collect();
for (Tuple2<String, PortableDataStream> res : result) {
Assert.assertArrayEquals(content1, res._2().toArray());
}
}

@Test
public void binaryFilesCaching() throws Exception {
// Reusing the wholeText files example
byte[] content1 = "spark is easy to use.\n".getBytes("utf-8");


String tempDirName = tempDir.getAbsolutePath();
File file1 = new File(tempDirName + "/part-00000");

FileOutputStream fos1 = new FileOutputStream(file1);

FileChannel channel1 = fos1.getChannel();
ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1);
channel1.write(bbuf);
channel1.close();

JavaPairRDD<String, PortableDataStream> readRDD = sc.binaryFiles(tempDirName,3).cache();
readRDD.foreach(new VoidFunction<Tuple2<String,PortableDataStream>>() {
@Override
public void call(Tuple2<String, PortableDataStream> stringPortableDataStreamTuple2) throws Exception {
stringPortableDataStreamTuple2._2().getPath();
stringPortableDataStreamTuple2._2().toArray(); // force the file to read
}
});

List<Tuple2<String, PortableDataStream>> result = readRDD.collect();
for (Tuple2<String, PortableDataStream> res : result) {
Assert.assertArrayEquals(content1, res._2().toArray());
}
}

@Test
public void binaryRecords() throws Exception {
// Reusing the wholeText files example
byte[] content1 = "spark isn't always easy to use.\n".getBytes("utf-8");
int numOfCopies = 10;
String tempDirName = tempDir.getAbsolutePath();
File file1 = new File(tempDirName + "/part-00000");

FileOutputStream fos1 = new FileOutputStream(file1);

FileChannel channel1 = fos1.getChannel();

for (int i=0;i<numOfCopies;i++) {
ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1);
channel1.write(bbuf);
}
channel1.close();

JavaPairRDD<String, byte[]> readRDD = sc.binaryFiles(tempDirName,3);
List<Tuple2<String, byte[]>> result = readRDD.collect();
for (Tuple2<String, byte[]> res : result) {
Assert.assertArrayEquals(content1, res._2());
JavaRDD<byte[]> readRDD = sc.binaryRecords(tempDirName,content1.length);
Assert.assertEquals(numOfCopies,readRDD.count());
List<byte[]> result = readRDD.collect();
for (byte[] res : result) {
Assert.assertArrayEquals(content1, res);
}
}

50 changes: 43 additions & 7 deletions core/src/test/scala/org/apache/spark/FileSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark

import java.io.{File, FileWriter}

import org.apache.spark.input.PortableDataStream

import scala.io.Source

import com.google.common.io.Files
@@ -240,35 +242,69 @@ class FileSuite extends FunSuite with LocalSparkContext {
file.close()

val inRdd = sc.binaryFiles(outFileName)
val (infile: String, indata: Array[Byte]) = inRdd.first
val (infile: String, indata: PortableDataStream) = inRdd.first

// Try reading the output back as an object file
assert(infile === outFileName)
assert(indata === testOutput)
assert(indata.toArray === testOutput)
}

test("portabledatastream caching tests") {
sc = new SparkContext("local", "test")
val outFile = new File(tempDir, "record-bytestream-00000.bin")
val outFileName = outFile.getAbsolutePath()

// create file
val testOutput = Array[Byte](1,2,3,4,5,6)
val bbuf = java.nio.ByteBuffer.wrap(testOutput)
// write data to file
val file = new java.io.FileOutputStream(outFile)
val channel = file.getChannel
channel.write(bbuf)
channel.close()
file.close()

val inRdd = sc.binaryFiles(outFileName).cache()
inRdd.foreach{
curData: (String, PortableDataStream) =>
curData._2.toArray() // force the file to read
}
val mappedRdd = inRdd.map{
curData: (String, PortableDataStream) =>
(curData._2.getPath(),curData._2)
}
val (infile: String, indata: PortableDataStream) = mappedRdd.first

// Try reading the output back as an object file

assert(indata.toArray === testOutput)
}

test("fixed record length binary file as byte array") {
// a fixed length of 6 bytes

sc = new SparkContext("local", "test")

val outFile = new File(tempDir, "record-bytestream-00000.bin")
val outFileName = outFile.getAbsolutePath()

// create file
val testOutput = Array[Byte](1,2,3,4,5,6)
val testOutputCopies = 10
val bbuf = java.nio.ByteBuffer.wrap(testOutput)

// write data to file
val file = new java.io.FileOutputStream(outFile)
val channel = file.getChannel
for(i <- 1 to testOutputCopies) channel.write(bbuf)
for(i <- 1 to testOutputCopies) {
val bbuf = java.nio.ByteBuffer.wrap(testOutput)
channel.write(bbuf)
}
channel.close()
file.close()
sc.hadoopConfiguration.setInt("recordLength",testOutput.length)

val inRdd = sc.binaryRecords(outFileName)
val inRdd = sc.binaryRecords(outFileName, testOutput.length)
// make sure there are enough elements
assert(inRdd.count== testOutputCopies)
assert(inRdd.count == testOutputCopies)

// now just compare the first one
val indata: Array[Byte] = inRdd.first
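The fixed-record-length test also shows the calling convention changing: the record length no longer has to be pushed into hadoopConfiguration before the read. A sketch of the before/after, reusing sc, outFileName, testOutput and testOutputCopies from the test above:

    // before this commit the length travelled through the Hadoop configuration:
    //   sc.hadoopConfiguration.setInt("recordLength", testOutput.length)
    //   val records = sc.binaryRecords(outFileName)

    // after this commit it is an explicit argument:
    val records = sc.binaryRecords(outFileName, testOutput.length)
    assert(records.count() == testOutputCopies)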
