From c28f520ec2e77c6a5f7139b5131182024eddd1be Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Fri, 26 Sep 2014 13:56:50 -0700
Subject: [PATCH] support updateStateByKey

---
 python/pyspark/streaming/dstream.py           | 30 +++++++++----
 python/pyspark/streaming/tests.py             | 19 ++++++++
 python/pyspark/streaming/util.py              | 11 ++---
 .../streaming/api/python/PythonDStream.scala  | 44 ++++++++++++++++---
 4 files changed, 83 insertions(+), 21 deletions(-)

diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
index 38bb54f25eaa2..27e1400b8ba0b 100644
--- a/python/pyspark/streaming/dstream.py
+++ b/python/pyspark/streaming/dstream.py
@@ -366,8 +366,9 @@ def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration,
                              numPartitions=None):
         reduced = self.reduceByKey(func)
 
-        def reduceFunc(a, t):
-            return a.reduceByKey(func, numPartitions)
+        def reduceFunc(a, b, t):
+            b = b.reduceByKey(func, numPartitions)
+            return a.union(b).reduceByKey(func, numPartitions) if a else b
 
         def invReduceFunc(a, b, t):
             b = b.reduceByKey(func, numPartitions)
@@ -378,19 +379,30 @@ def invReduceFunc(a, b, t):
             windowDuration = Seconds(windowDuration)
         if not isinstance(slideDuration, Duration):
             slideDuration = Seconds(slideDuration)
-        serializer = reduced._jrdd_deserializer
-        jreduceFunc = RDDFunction(self.ctx, reduceFunc, reduced._jrdd_deserializer)
+        jreduceFunc = RDDFunction2(self.ctx, reduceFunc, reduced._jrdd_deserializer)
         jinvReduceFunc = RDDFunction2(self.ctx, invReduceFunc, reduced._jrdd_deserializer)
         dstream = self.ctx._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(),
                                                              jreduceFunc, jinvReduceFunc,
                                                              windowDuration._jduration,
                                                              slideDuration._jduration)
-        return DStream(dstream.asJavaDStream(), self._ssc, serializer)
+        return DStream(dstream.asJavaDStream(), self._ssc, self.ctx.serializer)
+
+    def updateStateByKey(self, updateFunc, numPartitions=None):
+        """
+        :param updateFunc: func(iterator of (k, vs, s)) -> iterator of (k, s), where vs are the new values for key k in this batch and s is its previous state (None for a new key)
+        """
+        def reduceFunc(a, b, t):
+            if a is None:
+                g = b.groupByKey(numPartitions).map(lambda (k, vs): (k, list(vs), None))
+            else:
+                g = a.cogroup(b).map(lambda (k, (va, vb)):
+                                     (k, list(vb), list(va)[0] if len(va) else None))
+            return g.mapPartitions(lambda x: updateFunc(x) or [])
 
-    def updateStateByKey(self, updateFunc):
-        # FIXME: convert updateFunc to java JFunction2
-        jFunc = updateFunc
-        return self._jdstream.updateStateByKey(jFunc)
+        jreduceFunc = RDDFunction2(self.ctx, reduceFunc,
+                                   self.ctx.serializer, self._jrdd_deserializer)
+        dstream = self.ctx._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc)
+        return DStream(dstream.asJavaDStream(), self._ssc, self.ctx.serializer)
 
 
 class TransformedDStream(DStream):
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index aa20b7efbee46..755ea224e56da 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -294,6 +294,25 @@ def func(dstream):
                     [('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]]
         self._test_func(input, func, expected)
 
+    def test_update_state_by_key(self):
+
+        def updater(it):
+            for k, vs, s in it:
+                if not s:
+                    s = vs
+                else:
+                    s.extend(vs)
+                yield (k, s)
+
+        input = [[('k', i)] for i in range(5)]
+
+        def func(dstream):
+            return dstream.updateStateByKey(updater)
+
+        expected = [[0], [0, 1], [0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]
+        expected = [[('k', v)] for v in expected]
+        self._test_func(input, func, expected)
+
 
 class TestStreamingContext(unittest.TestCase):
     def setUp(self):
diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
index 4051732f25302..fdbd01ec1766d 100644
--- a/python/pyspark/streaming/util.py
+++ b/python/pyspark/streaming/util.py
@@ -50,15 +50,16 @@ class RDDFunction2(object):
     This class is for py4j callback. This class is related with
     org.apache.spark.streaming.api.python.PythonRDDFunction2.
     """
-    def __init__(self, ctx, func, jrdd_deserializer):
+    def __init__(self, ctx, func, jrdd_deserializer, jrdd_deserializer2=None):
         self.ctx = ctx
         self.func = func
-        self.deserializer = jrdd_deserializer
+        self.jrdd_deserializer = jrdd_deserializer
+        self.jrdd_deserializer2 = jrdd_deserializer2 or jrdd_deserializer
 
     def call(self, jrdd, jrdd2, milliseconds):
         try:
-            rdd = RDD(jrdd, self.ctx, self.deserializer) if jrdd else None
-            other = RDD(jrdd2, self.ctx, self.deserializer) if jrdd2 else None
+            rdd = RDD(jrdd, self.ctx, self.jrdd_deserializer) if jrdd else None
+            other = RDD(jrdd2, self.ctx, self.jrdd_deserializer2) if jrdd2 else None
             r = self.func(rdd, other, milliseconds)
             if r:
                 return r._jrdd
@@ -67,7 +68,7 @@ def call(self, jrdd, jrdd2, milliseconds):
             traceback.print_exc()
 
     def __repr__(self):
-        return "RDDFunction(%s, %s)" % (str(self.deserializer), str(self.func))
+        return "RDDFunction2(%s)" % (str(self.func))
 
     class Java:
         implements = ['org.apache.spark.streaming.api.python.PythonRDDFunction2']
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
index 689c04fa49135..b904e273eb438 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -118,7 +118,7 @@ private[spark] class PythonTransformed2DStream (parent: DStream[_], parent2: DSt
 private[spark]
 class PythonReducedWindowedDStream(
     parent: DStream[Array[Byte]],
-    reduceFunc: PythonRDDFunction,
+    reduceFunc: PythonRDDFunction2,
     invReduceFunc: PythonRDDFunction2,
     _windowDuration: Duration,
     _slideDuration: Duration
@@ -149,10 +149,6 @@ class PythonReducedWindowedDStream(
   override def parentRememberDuration: Duration = rememberDuration + windowDuration
 
   override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
-    None
-    val reduceF = reduceFunc
-    val invReduceF = invReduceFunc
-
     val currentTime = validTime
     val currentWindow = new Interval(currentTime - windowDuration + parent.slideDuration,
       currentTime)
@@ -196,7 +192,7 @@
         parent.slice(previousWindow.endTime, currentWindow.endTime - parent.slideDuration)
 
       if (newRDDs.size > 0) {
-        Some(reduceFunc.call(JavaRDD.fromRDD(ssc.sc.union(newRDDs).union(subbed)), validTime.milliseconds))
+        Some(reduceFunc.call(JavaRDD.fromRDD(subbed), JavaRDD.fromRDD(ssc.sc.union(newRDDs)), validTime.milliseconds))
       } else {
         Some(subbed)
       }
@@ -205,7 +201,7 @@
       val currentRDDs =
         parent.slice(currentWindow.beginTime, currentWindow.endTime - parent.slideDuration)
       if (currentRDDs.size > 0) {
-        Some(reduceFunc.call(JavaRDD.fromRDD(ssc.sc.union(currentRDDs)), validTime.milliseconds))
+        Some(reduceFunc.call(null, JavaRDD.fromRDD(ssc.sc.union(currentRDDs)), validTime.milliseconds))
       } else {
         None
       }
@@ -216,6 +212,40 @@
 }
 
 
+/**
+ * Similar to StateDStream
+ */
+private[spark]
+class PythonStateDStream(
+    parent: DStream[Array[Byte]],
+    reduceFunc: PythonRDDFunction2
+  ) extends DStream[Array[Byte]](parent.ssc) {
+
+  super.persist(StorageLevel.MEMORY_ONLY)
+
+  override def dependencies = List(parent)
+
+  override def slideDuration: Duration = parent.slideDuration
+
+  override val mustCheckpoint = true
+
+  override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
+    val lastState = getOrCompute(validTime - slideDuration)
+    val newRDD = parent.getOrCompute(validTime)
+    if (newRDD.isDefined) {
+      if (lastState.isDefined) {
+        Some(reduceFunc.call(JavaRDD.fromRDD(lastState.get), JavaRDD.fromRDD(newRDD.get), validTime.milliseconds))
+      } else {
+        Some(reduceFunc.call(null, JavaRDD.fromRDD(newRDD.get), validTime.milliseconds))
+      }
+    } else {
+      lastState
+    }
+  }
+
+  val asJavaDStream = JavaDStream.fromDStream(this)
+}
+
 /**
  * This is used for foreachRDD() in Python
  */
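
Usage note: with this patch, the updateFunc passed to updateStateByKey consumes an
iterator of (key, new_values, previous_state) tuples and yields (key, new_state)
pairs; previous_state is None the first time a key appears, and PythonStateDStream
sets mustCheckpoint, so the stream has to be checkpointed. Below is a minimal
stateful word count sketched against this API; the socket source on localhost:9999,
the checkpoint path, and the availability of StreamingContext.checkpoint,
socketTextStream and pprint on this branch are illustrative assumptions, not part
of the patch:

    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext

    sc = SparkContext(appName="StatefulWordCount")
    ssc = StreamingContext(sc, 1)      # 1-second batches (assumes int-seconds ctor)
    ssc.checkpoint("/tmp/checkpoint")  # required: PythonStateDStream sets mustCheckpoint

    def updater(it):
        # it is an iterator of (key, new_values, previous_state);
        # previous_state is None the first time a key is seen
        for key, values, state in it:
            yield (key, sum(values) + (state or 0))

    counts = (ssc.socketTextStream("localhost", 9999)
                 .flatMap(lambda line: line.split(" "))
                 .map(lambda word: (word, 1))
                 .updateStateByKey(updater))
    counts.pprint()

    ssc.start()
    ssc.awaitTermination()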