[SPARK-23092][SQL] Migrate MemoryStream to DataSourceV2 APIs
This PR migrates the MemoryStream to DataSourceV2 APIs.

One additional change is in the keys reported in StreamingQueryProgress.durationMs: "getOffset" and "getBatch" are replaced with "setOffsetRange" and "getEndOffset", since tracking these makes more sense for the v2 API. Unit tests are changed accordingly.
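For context, the renamed timing keys surface through StreamingQueryProgress.durationMs. Below is a minimal, hedged sketch of inspecting them from user code; the local-mode SparkSession setup and the query name "demo" are illustrative assumptions, not part of this patch.

```scala
import scala.collection.JavaConverters._

import org.apache.spark.sql.{SQLContext, SparkSession}
import org.apache.spark.sql.execution.streaming.MemoryStream

val spark = SparkSession.builder().master("local[2]").appName("durationMs-demo").getOrCreate()
import spark.implicits._
implicit val sqlContext: SQLContext = spark.sqlContext

val input = MemoryStream[Int]
val query = input.toDS().writeStream.format("memory").queryName("demo").start()

input.addData(1, 2, 3)
query.processAllAvailable()

// With this change, the per-trigger timings should report "setOffsetRange" and
// "getEndOffset" for v2 sources instead of the old "getOffset"/"getBatch" keys.
query.lastProgress.durationMs.asScala.foreach { case (k, v) => println(s"$k -> $v ms") }

query.stop()
spark.stop()
```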

Existing unit tests, plus a few updated unit tests.

Author: Tathagata Das <tathagata.das1565@gmail.com>
Author: Burak Yavuz <brkyvz@gmail.com>

Closes apache#20445 from tdas/SPARK-23092.

Ref: LIHADOOP-48531

RB=1832973
G=superfriends-reviewers
R=latang,yezhou,zolin,mshen,fli
A=
tdas authored and otterc committed Oct 11, 2019
1 parent 575dac8 commit 0a07ea0
Showing 9 changed files with 171 additions and 134 deletions.
@@ -17,10 +17,12 @@

package org.apache.spark.sql.execution.streaming

import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2}

/**
* A simple offset for sources that produce a single linear stream of data.
*/
case class LongOffset(offset: Long) extends Offset {
case class LongOffset(offset: Long) extends OffsetV2 {

override val json = offset.toString

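Since LongOffset now extends the v2 Offset, its JSON form is what MicroBatchReader.deserializeOffset gets back. A tiny sketch of the round trip (illustrative only):

```scala
import org.apache.spark.sql.execution.streaming.LongOffset

val off = LongOffset(5)
// json is just the underlying Long rendered as a string ...
assert(off.json == "5")
// ... so deserialization is a plain Long parse, as MemoryStream.deserializeOffset does below.
assert(LongOffset(off.json.toLong) == off)
```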
@@ -269,16 +269,17 @@ class MicroBatchExecution(
}
case s: MicroBatchReader =>
updateStatusMessage(s"Getting offsets from $s")
reportTimeTaken("getOffset") {
// Once v1 streaming source execution is gone, we can refactor this away.
// For now, we set the range here to get the source to infer the available end offset,
// get that offset, and then set the range again when we later execute.
s.setOffsetRange(
toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))),
Optional.empty())

(s, Some(s.getEndOffset))
reportTimeTaken("setOffsetRange") {
// Once v1 streaming source execution is gone, we can refactor this away.
// For now, we set the range here to get the source to infer the available end offset,
// get that offset, and then set the range again when we later execute.
s.setOffsetRange(
toJava(availableOffsets.get(s).map(off => s.deserializeOffset(off.json))),
Optional.empty())
}

val currentOffset = reportTimeTaken("getEndOffset") { s.getEndOffset() }
(s, Option(currentOffset))
}.toMap
availableOffsets ++= latestOffsets.filter { case (_, o) => o.nonEmpty }.mapValues(_.get)

@@ -400,10 +401,14 @@ class MicroBatchExecution(
case (reader: MicroBatchReader, available)
if committedOffsets.get(reader).map(_ != available).getOrElse(true) =>
val current = committedOffsets.get(reader).map(off => reader.deserializeOffset(off.json))
val availableV2: OffsetV2 = available match {
case v1: SerializedOffset => reader.deserializeOffset(v1.json)
case v2: OffsetV2 => v2
}
reader.setOffsetRange(
toJava(current),
Optional.of(available.asInstanceOf[OffsetV2]))
logDebug(s"Retrieving data from $reader: $current -> $available")
Optional.of(availableV2))
logDebug(s"Retrieving data from $reader: $current -> $availableV2")
Some(reader ->
new StreamingDataSourceV2Relation(reader.readSchema().toAttributes, reader))
case _ => None
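The two-phase pattern above (set the range with an empty end so the source infers its latest offset, then read it back) can be summarized with a hedged sketch against the MicroBatchReader API; the `latestOffset` helper is illustrative, not part of the patch.

```scala
import java.util.Optional

import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}

// Ask a v2 source for its currently available end offset, mirroring the
// "setOffsetRange" / "getEndOffset" timings reported in durationMs.
def latestOffset(reader: MicroBatchReader, knownStart: Option[OffsetV2]): Option[OffsetV2] = {
  // Leave `end` empty so the source infers the latest available offset.
  reader.setOffsetRange(Optional.ofNullable(knownStart.orNull), Optional.empty[OffsetV2]())
  // Read back whatever end offset the source settled on; null means no data yet.
  Option(reader.getEndOffset)
}
```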
@@ -17,21 +17,23 @@

package org.apache.spark.sql.execution.streaming

import java.{util => ju}
import java.util.Optional
import java.util.concurrent.atomic.AtomicInteger
import javax.annotation.concurrent.GuardedBy

import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, ListBuffer}
import scala.util.control.NonFatal

import org.apache.spark.internal.Logging
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, Statistics}
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.sources.v2.reader.{DataReader, DataReaderFactory, SupportsScanUnsafeRow}
import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchReader, Offset => OffsetV2}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils
@@ -51,30 +53,35 @@ object MemoryStream {
* available.
*/
case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
extends Source with Logging {
extends MicroBatchReader with SupportsScanUnsafeRow with Logging {
protected val encoder = encoderFor[A]
protected val logicalPlan = StreamingExecutionRelation(this, sqlContext.sparkSession)
private val attributes = encoder.schema.toAttributes
protected val logicalPlan = StreamingExecutionRelation(this, attributes)(sqlContext.sparkSession)
protected val output = logicalPlan.output

/**
* All batches from `lastCommittedOffset + 1` to `currentOffset`, inclusive.
* Stored in a ListBuffer to facilitate removing committed batches.
*/
@GuardedBy("this")
protected val batches = new ListBuffer[Dataset[A]]
protected val batches = new ListBuffer[Array[UnsafeRow]]

@GuardedBy("this")
protected var currentOffset: LongOffset = new LongOffset(-1)

@GuardedBy("this")
private var startOffset = new LongOffset(-1)

@GuardedBy("this")
private var endOffset = new LongOffset(-1)

/**
* Last offset that was discarded, or -1 if no commits have occurred. Note that the value
* -1 is used in calculations below and isn't just an arbitrary constant.
*/
@GuardedBy("this")
protected var lastOffsetCommitted : LongOffset = new LongOffset(-1)

def schema: StructType = encoder.schema

def toDS(): Dataset[A] = {
Dataset(sqlContext.sparkSession, logicalPlan)
}
@@ -88,72 +95,69 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext)
}

def addData(data: TraversableOnce[A]): Offset = {
val encoded = data.toVector.map(d => encoder.toRow(d).copy())
val plan = new LocalRelation(schema.toAttributes, encoded, isStreaming = true)
val ds = Dataset[A](sqlContext.sparkSession, plan)
logDebug(s"Adding ds: $ds")
val objects = data.toSeq
val rows = objects.iterator.map(d => encoder.toRow(d).copy().asInstanceOf[UnsafeRow]).toArray
logDebug(s"Adding: $objects")
this.synchronized {
currentOffset = currentOffset + 1
batches += ds
batches += rows
currentOffset
}
}

override def toString: String = s"MemoryStream[${Utils.truncatedString(output, ",")}]"

override def getOffset: Option[Offset] = synchronized {
if (currentOffset.offset == -1) {
None
} else {
Some(currentOffset)
override def setOffsetRange(start: Optional[OffsetV2], end: Optional[OffsetV2]): Unit = {
synchronized {
startOffset = start.orElse(LongOffset(-1)).asInstanceOf[LongOffset]
endOffset = end.orElse(currentOffset).asInstanceOf[LongOffset]
}
}

override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
// Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal)
val startOrdinal =
start.flatMap(LongOffset.convert).getOrElse(LongOffset(-1)).offset.toInt + 1
val endOrdinal = LongOffset.convert(end).getOrElse(LongOffset(-1)).offset.toInt + 1

// Internal buffer only holds the batches after lastCommittedOffset.
val newBlocks = synchronized {
val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1
val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1
assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd")
batches.slice(sliceStart, sliceEnd)
}
override def readSchema(): StructType = encoder.schema

if (newBlocks.isEmpty) {
return sqlContext.internalCreateDataFrame(
sqlContext.sparkContext.emptyRDD, schema, isStreaming = true)
}
override def deserializeOffset(json: String): OffsetV2 = LongOffset(json.toLong)

override def getStartOffset: OffsetV2 = synchronized {
if (startOffset.offset == -1) null else startOffset
}

logDebug(generateDebugString(newBlocks, startOrdinal, endOrdinal))
override def getEndOffset: OffsetV2 = synchronized {
if (endOffset.offset == -1) null else endOffset
}

newBlocks
.map(_.toDF())
.reduceOption(_ union _)
.getOrElse {
sys.error("No data selected!")
override def createUnsafeRowReaderFactories(): ju.List[DataReaderFactory[UnsafeRow]] = {
synchronized {
// Compute the internal batch numbers to fetch: [startOrdinal, endOrdinal)
val startOrdinal = startOffset.offset.toInt + 1
val endOrdinal = endOffset.offset.toInt + 1

// Internal buffer only holds the batches after lastCommittedOffset.
val newBlocks = synchronized {
val sliceStart = startOrdinal - lastOffsetCommitted.offset.toInt - 1
val sliceEnd = endOrdinal - lastOffsetCommitted.offset.toInt - 1
assert(sliceStart <= sliceEnd, s"sliceStart: $sliceStart sliceEnd: $sliceEnd")
batches.slice(sliceStart, sliceEnd)
}

logDebug(generateDebugString(newBlocks.flatten, startOrdinal, endOrdinal))

newBlocks.map { block =>
new MemoryStreamDataReaderFactory(block).asInstanceOf[DataReaderFactory[UnsafeRow]]
}.asJava
}
}

private def generateDebugString(
blocks: TraversableOnce[Dataset[A]],
rows: Seq[UnsafeRow],
startOrdinal: Int,
endOrdinal: Int): String = {
val originalUnsupportedCheck =
sqlContext.getConf("spark.sql.streaming.unsupportedOperationCheck")
try {
sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", "false")
s"MemoryBatch [$startOrdinal, $endOrdinal]: " +
s"${blocks.flatMap(_.collect()).mkString(", ")}"
} finally {
sqlContext.setConf("spark.sql.streaming.unsupportedOperationCheck", originalUnsupportedCheck)
}
val fromRow = encoder.resolveAndBind().fromRow _
s"MemoryBatch [$startOrdinal, $endOrdinal]: " +
s"${rows.map(row => fromRow(row)).mkString(", ")}"
}

override def commit(end: Offset): Unit = synchronized {
override def commit(end: OffsetV2): Unit = synchronized {
def check(newOffset: LongOffset): Unit = {
val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt

@@ -176,11 +180,33 @@

def reset(): Unit = synchronized {
batches.clear()
startOffset = LongOffset(-1)
endOffset = LongOffset(-1)
currentOffset = new LongOffset(-1)
lastOffsetCommitted = new LongOffset(-1)
}
}


class MemoryStreamDataReaderFactory(records: Array[UnsafeRow])
extends DataReaderFactory[UnsafeRow] {
override def createDataReader(): DataReader[UnsafeRow] = {
new DataReader[UnsafeRow] {
private var currentIndex = -1

override def next(): Boolean = {
// Return true as long as the new index is in the array.
currentIndex += 1
currentIndex < records.length
}

override def get(): UnsafeRow = records(currentIndex)

override def close(): Unit = {}
}
}
}
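
For orientation, a hedged sketch of how one of these per-partition readers is driven; the `drain` helper and its counting are illustrative, not part of the patch, and `records` stands for one block sliced out of `batches` by createUnsafeRowReaderFactories (MemoryStreamDataReaderFactory is assumed to be in scope from the same package).

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.sources.v2.reader.DataReader

// Consume a single partition's rows through the DataReader contract.
def drain(records: Array[UnsafeRow]): Int = {
  val reader: DataReader[UnsafeRow] = new MemoryStreamDataReaderFactory(records).createDataReader()
  var count = 0
  try {
    while (reader.next()) {   // next() advances the cursor; false once the array is exhausted
      val row = reader.get()  // UnsafeRow at the current position (process as needed)
      count += 1
    }
  } finally {
    reader.close()
  }
  count
}
```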

/**
* A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit
* tests and does not provide durability.
@@ -151,7 +151,7 @@ case class RateStreamBatchTask(vals: Seq[(Long, Long)]) extends DataReaderFactor
}

class RateStreamBatchReader(vals: Seq[(Long, Long)]) extends DataReader[Row] {
var currentIndex = -1
private var currentIndex = -1

override def next(): Boolean = {
// Return true as long as the new index is in the seq.
@@ -46,49 +46,34 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf
.foreach(new TestForeachWriter())
.start()

// -- batch 0 ---------------------------------------
input.addData(1, 2, 3, 4)
query.processAllAvailable()
def verifyOutput(expectedVersion: Int, expectedData: Seq[Int]): Unit = {
import ForeachSinkSuite._

var expectedEventsForPartition0 = Seq(
ForeachSinkSuite.Open(partition = 0, version = 0),
ForeachSinkSuite.Process(value = 2),
ForeachSinkSuite.Process(value = 3),
ForeachSinkSuite.Close(None)
)
var expectedEventsForPartition1 = Seq(
ForeachSinkSuite.Open(partition = 1, version = 0),
ForeachSinkSuite.Process(value = 1),
ForeachSinkSuite.Process(value = 4),
ForeachSinkSuite.Close(None)
)
val events = ForeachSinkSuite.allEvents()
assert(events.size === 2) // one seq of events for each of the 2 partitions

var allEvents = ForeachSinkSuite.allEvents()
assert(allEvents.size === 2)
assert(allEvents.toSet === Set(expectedEventsForPartition0, expectedEventsForPartition1))
// Verify both seq of events have an Open event as the first event
assert(events.map(_.head).toSet === Set(0, 1).map(p => Open(p, expectedVersion)))

// Verify all the Process event correspond to the expected data
val allProcessEvents = events.flatMap(_.filter(_.isInstanceOf[Process[_]]))
assert(allProcessEvents.toSet === expectedData.map { data => Process(data) }.toSet)

// Verify both seq of events have a Close event as the last event
assert(events.map(_.last).toSet === Set(Close(None), Close(None)))
}

// -- batch 0 ---------------------------------------
ForeachSinkSuite.clear()
input.addData(1, 2, 3, 4)
query.processAllAvailable()
verifyOutput(expectedVersion = 0, expectedData = 1 to 4)

// -- batch 1 ---------------------------------------
ForeachSinkSuite.clear()
input.addData(5, 6, 7, 8)
query.processAllAvailable()

expectedEventsForPartition0 = Seq(
ForeachSinkSuite.Open(partition = 0, version = 1),
ForeachSinkSuite.Process(value = 5),
ForeachSinkSuite.Process(value = 7),
ForeachSinkSuite.Close(None)
)
expectedEventsForPartition1 = Seq(
ForeachSinkSuite.Open(partition = 1, version = 1),
ForeachSinkSuite.Process(value = 6),
ForeachSinkSuite.Process(value = 8),
ForeachSinkSuite.Close(None)
)

allEvents = ForeachSinkSuite.allEvents()
assert(allEvents.size === 2)
assert(allEvents.toSet === Set(expectedEventsForPartition0, expectedEventsForPartition1))
verifyOutput(expectedVersion = 1, expectedData = 5 to 8)

query.stop()
}
@@ -492,16 +492,16 @@ class StreamSuite extends StreamTest {

val explainWithoutExtended = q.explainInternal(false)
// `extended = false` only displays the physical plan.
assert("LocalRelation".r.findAllMatchIn(explainWithoutExtended).size === 0)
assert("LocalTableScan".r.findAllMatchIn(explainWithoutExtended).size === 1)
assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithoutExtended).size === 0)
assert("DataSourceV2Scan".r.findAllMatchIn(explainWithoutExtended).size === 1)
// Use "StateStoreRestore" to verify that it does output a streaming physical plan
assert(explainWithoutExtended.contains("StateStoreRestore"))

val explainWithExtended = q.explainInternal(true)
// `extended = true` displays 3 logical plans (Parsed/Optimized/Optimized) and 1 physical
// plan.
assert("LocalRelation".r.findAllMatchIn(explainWithExtended).size === 3)
assert("LocalTableScan".r.findAllMatchIn(explainWithExtended).size === 1)
assert("StreamingDataSourceV2Relation".r.findAllMatchIn(explainWithExtended).size === 3)
assert("DataSourceV2Scan".r.findAllMatchIn(explainWithExtended).size === 1)
// Use "StateStoreRestore" to verify that it does output a streaming physical plan
assert(explainWithExtended.contains("StateStoreRestore"))
} finally {
@@ -120,7 +120,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData {
override def toString: String = s"AddData to $source: ${data.mkString(",")}"

override def addData(query: Option[StreamExecution]): (Source, Offset) = {
override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = {
(source, source.addData(data))
}
}