Better handle rows that break across splits, and other small related fixes (#400)

This attempts to address #398.
See also #399.

The change is, I believe, explained in the comments below.
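For intuition: like LineRecordReader, the reader treats a split as the byte range [start, end) and owns every record whose first byte falls before end. It reads past end when necessary to finish the last such record, and it never starts a record at or after end, since that record belongs to the next split. A minimal sketch of this convention, using hypothetical fixed-length records in place of the actual start/end-tag scanning:

  // Hypothetical stand-in: fixed-length records instead of XML tag scanning.
  // A record belongs to the split containing its first byte; that split reads
  // past `end` if needed to emit the record in full.
  def recordsForSplit(data: String, recordLen: Int, start: Int, end: Int): Seq[String] =
    (0 until data.length by recordLen)
      .filter(off => off >= start && off < end) // record *starts* inside this split
      .map(off => data.substring(off, math.min(data.length, off + recordLen)))

  val data = "AAAABBBBCCCCDDDD"
  // The record at offset 8 straddles the boundary at byte 10: the first split
  // emits it whole, and the second split must not emit a fragment of it.
  assert(recordsForSplit(data, 4, 0, 10) == Seq("AAAA", "BBBB", "CCCC"))
  assert(recordsForSplit(data, 4, 10, 16) == Seq("DDDD"))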
srowen authored and HyukjinKwon committed Aug 5, 2019
1 parent 41d0d17 commit 8bc9621
Showing 9 changed files with 3,751 additions and 7 deletions.
5 changes: 5 additions & 0 deletions NOTICE
@@ -0,0 +1,5 @@
+Apache Commons IO
+Copyright 2002-2019 The Apache Software Foundation
+
+This product includes software developed at
+The Apache Software Foundation (http://www.apache.org/).
1 change: 1 addition & 0 deletions build.sbt
@@ -20,6 +20,7 @@ sparkComponents := Seq("core", "sql")
autoScalaLibrary := false

libraryDependencies ++= Seq(
"commons-io" % "commons-io" % "2.6",
"org.slf4j" % "slf4j-api" % "1.7.25" % Provided,
"org.scalatest" %% "scalatest" % "3.0.3" % Test,
"com.novocode" % "junit-interface" % "0.11" % Test,
60 changes: 53 additions & 7 deletions src/main/scala/com/databricks/spark/xml/XmlInputFormat.scala
@@ -16,8 +16,10 @@
package com.databricks.spark.xml

import java.io.{IOException, InputStream, InputStreamReader, Reader}
+import java.nio.ByteBuffer
import java.nio.charset.Charset

+import org.apache.commons.io.input.CountingInputStream
import org.apache.hadoop.fs.Seekable
import org.apache.hadoop.io.compress._
import org.apache.hadoop.io.{LongWritable, Text}
@@ -47,7 +49,9 @@ object XmlInputFormat {

/**
* XMLRecordReader class to read through a given xml document to output xml blocks as records
-* as specified by the start tag and end tag
+* as specified by the start tag and end tag.
+*
+* This implementation is ultimately loosely based on LineRecordReader in Hadoop.
*/
private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {

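For context, this reader is normally driven through Hadoop's new-API input format machinery. A usage sketch, assuming the START_TAG_KEY and END_TAG_KEY configuration-key constants exposed by the XmlInputFormat companion object and a hypothetical <row> record element:

  import org.apache.hadoop.conf.Configuration
  import org.apache.hadoop.io.{LongWritable, Text}
  import org.apache.spark.SparkContext

  // Sketch only: the key-constant names are assumed from the XmlInputFormat object.
  def xmlRecords(sc: SparkContext, path: String) = {
    val conf = new Configuration(sc.hadoopConfiguration)
    conf.set(XmlInputFormat.START_TAG_KEY, "<row>")
    conf.set(XmlInputFormat.END_TAG_KEY, "</row>")
    sc.newAPIHadoopFile(path, classOf[XmlInputFormat], classOf[LongWritable], classOf[Text], conf)
      .map { case (_, text) => text.toString } // each value is one <row>...</row> block
  }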
@@ -60,6 +64,8 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
private var end: Long = _
private var reader: Reader = _
private var filePosition: Seekable = _
+private var countingIn: CountingInputStream = _
+private var readerByteBuffer: ByteBuffer = _
private var decompressor: Decompressor = _
private var buffer = new StringBuilder()

@@ -106,11 +112,45 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
filePosition = fsin
}
} else {
-in = fsin
-filePosition = fsin
-filePosition.seek(start)
+fsin.seek(start)
+countingIn = new CountingInputStream(fsin)
+in = countingIn
+// don't use filePosition in this case. We have to count bytes read manually
}

reader = new InputStreamReader(in, charset)

+if (codec == null) {
+// Hack: in the uncompressed case (see more below), we must know how much the
+// InputStreamReader has buffered but not processed
+// to accurately assess how many bytes have been processed
+val sdField = reader.getClass.getDeclaredField("sd")
+sdField.setAccessible(true)
+val sd = sdField.get(reader)
+val bbField = sd.getClass.getDeclaredField("bb")
+bbField.setAccessible(true)
+readerByteBuffer = bbField.get(sd).asInstanceOf[ByteBuffer]
+}
}

+/**
+* Tries to determine how many bytes of the underlying split have been read. There are two
+* distinct cases.
+*
+* For compressed input, it attempts to read the current position read in the compressed input
+* stream. This logic is copied from LineRecordReader, essentially.
+*
+* For uncompressed input, it counts the number of bytes read directly from the split. It
+* further compensates for the fact that the intervening InputStreamReader buffers input and
+* accounts for data it has read but not yet returned.
+*/
+private def getFilePosition(): Long = {
+// filePosition != null when input is compressed
+if (filePosition != null) {
+filePosition.getPos
+} else {
+start + countingIn.getByteCount - readerByteBuffer.remaining()
+}
+}

override def nextKeyValue: Boolean = {
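Note that the reflection above reaches into JDK internals: InputStreamReader's private sd field (a sun.nio.cs.StreamDecoder) and the decoder's private bb byte buffer. A standalone illustration of what is being measured — bytes already pulled from the stream but not yet returned as characters (verified against JDK 8; newer JDKs may need opened modules):

  import java.io.{ByteArrayInputStream, InputStreamReader}
  import java.nio.ByteBuffer
  import java.nio.charset.StandardCharsets

  // Illustration only: peeks at JDK-private state and may break across versions.
  val bytes = "hello world".getBytes(StandardCharsets.UTF_8)
  val reader = new InputStreamReader(new ByteArrayInputStream(bytes), StandardCharsets.UTF_8)
  reader.read() // decode one char; the decoder buffers more bytes than it returns

  val sdField = reader.getClass.getDeclaredField("sd")
  sdField.setAccessible(true)
  val sd = sdField.get(reader)
  val bbField = sd.getClass.getDeclaredField("bb")
  bbField.setAccessible(true)
  val bb = bbField.get(sd).asInstanceOf[ByteBuffer]

  // These buffered-but-undecoded bytes are exactly what getFilePosition
  // subtracts from the raw byte count.
  println(s"buffered but undecoded: ${bb.remaining()} bytes")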
@@ -132,7 +172,7 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
try {
buffer.append(currentStartTag)
if (readUntilEndElement(currentStartTag.endsWith(">"))) {
-key.set(filePosition.getPos)
+key.set(getFilePosition())
value.set(buffer.toString())
return true
}
@@ -148,7 +188,7 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
var i = 0
while (true) {
val cOrEOF = reader.read()
-if (cOrEOF == -1 || (i == 0 && filePosition.getPos > end)) {
+if (cOrEOF == -1 || (i == 0 && getFilePosition() > end)) {
// End of file or end of split.
return false
}
@@ -265,7 +305,13 @@ private[xml] class XmlRecordReader extends RecordReader[LongWritable, Text] {
false
}

-override def getProgress: Float = (filePosition.getPos - start) / (end - start).toFloat
+override def getProgress: Float = {
+if (start == end) {
+0.0f
+} else {
+math.min(1.0f, (getFilePosition() - start) / (end - start).toFloat)
+}
+}

override def getCurrentKey: LongWritable = currentKey
