[SPARK-27870][SQL][PYSPARK] Flush batch timely for pandas UDF (for improving pandas UDFs pipeline)

## What changes were proposed in this pull request?

Flush batch timely for pandas UDF.

This could improve performance when multiple pandas UDF plans are pipelined.

When batches are flushed in time, downstream pandas UDFs get pipelined as soon as possible, and the pipeline helps hide the downstream UDFs' computation time. For example:

When the first UDF starts computing on batch-3, the second pipelined UDF can start computing on batch-2, and the third pipelined UDF can start computing on batch-1.

If we do not flush each batch in time, the downstream UDFs' pipeline will lag behind too much, which may increase the total processing time.

I add a flush at two places:
* Where the JVM process feeds data into the python worker: on the JVM side, flush after writing each batch.
* Where the JVM process reads data from the python worker's output: on the python worker side, flush after writing each batch.

Without flushing, the default buffer size on both sides is 65536 bytes. Especially in the ML case, where we want real-time prediction, we make the batch size very small; the buffer is then far too large for each batch, which causes the downstream pandas UDF pipeline to lag behind too much.
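The effect is easy to see with Python's standard buffered I/O. Below is a minimal sketch (not code from this PR; the 65536 value simply mirrors the default buffer size mentioned above, and the stream objects are illustrative):

```
import io

raw = io.BytesIO()
buffered = io.BufferedWriter(raw, buffer_size=65536)

buffered.write(b"x" * 100)   # a "small batch": only 100 bytes
print(len(raw.getvalue()))   # 0 -- the data is still stuck in the 64 KiB buffer

buffered.flush()             # an explicit flush pushes it to the underlying stream
print(len(raw.getvalue()))   # 100 -- now a downstream reader could see it
```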

### Note
* This only applies to pandas scalar UDFs.
* Do not flush after every batch. The minimum interval between two flushes is 0.1 seconds, which avoids flushing too frequently when the batch size is small. It works like:
```
last_flush_time = time.time()
for batch in iterator:
    writer.write_batch(batch)
    flush_time = time.time()
    if self.flush_timely and (flush_time - last_flush_time > 0.1):
        stream.flush()
        last_flush_time = flush_time
```
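
For reference, here is a self-contained sketch of the same time-throttled flush pattern; the `write_batches` helper, its parameters, and the final flush are illustrative, not part of the PR's API:

```
import time

def write_batches(batches, write_batch, flush, min_flush_interval=0.1):
    """Write all batches, flushing at most once per min_flush_interval seconds."""
    last_flush_time = time.time()
    for batch in batches:
        write_batch(batch)
        now = time.time()
        if now - last_flush_time > min_flush_interval:
            flush()
            last_flush_time = now
    # Flush whatever is left so the tail of the stream is not stuck in the buffer.
    flush()
```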

## How was this patch tested?

### Benchmark to make sure the flush does not cause a performance regression
#### Test code:
```
import time

# Assumes a SparkSession named `spark` is already available (e.g. in a PySpark shell).
# Note: functions.sum shadows the builtin sum here; we want the SQL aggregate.
from pyspark.sql.functions import col, pandas_udf, sum, PandasUDFType

numRows = ...
batchSize = ...

spark.conf.set('spark.sql.execution.arrow.maxRecordsPerBatch', str(batchSize))
df = spark.range(1, numRows + 1, numPartitions=1).select(col('id').alias('a'))

@pandas_udf("int", PandasUDFType.SCALAR)
def fp1(x):
    return x + 10

beg_time = time.time()
result = df.select(sum(fp1('a'))).head()
print("result: " + str(result[0]))
print("consume time: " + str(time.time() - beg_time))
```
#### Test Result:

 params        | Consume time (Before) | Consume time (After)
------------ | ----------------------- | ----------------------
numRows=100000000, batchSize=10000 | 23.43s | 24.64s
numRows=100000000, batchSize=1000 | 36.73s | 34.50s
numRows=10000000, batchSize=100 | 35.67s | 32.64s
numRows=1000000, batchSize=10 | 33.60s | 32.11s
numRows=100000, batchSize=1 | 33.36s | 31.82s

### Benchmark pipelined pandas UDF
#### Test code:
```
import time

# Assumes a SparkSession named `spark` is already available (e.g. in a PySpark shell).
from pyspark.sql.functions import col, pandas_udf, sum, PandasUDFType

spark.conf.set('spark.sql.execution.arrow.maxRecordsPerBatch', '1')
df = spark.range(1, 31, numPartitions=1).select(col('id').alias('a'))

@pandas_udf("int", PandasUDFType.SCALAR)
def fp1(x):
    print("run fp1")
    time.sleep(1)
    return x + 100

@pandas_udf("int", PandasUDFType.SCALAR)
def fp2(x, y):
    print("run fp2")
    time.sleep(1)
    return x + y

beg_time = time.time()
result = df.select(sum(fp2(fp1('a'), col('a')))).head()
print("result: " + str(result[0]))
print("consume time: " + str(time.time() - beg_time))
```
#### Test Result:

**Before**: consume time: 63.57s
**After**: consume time: 32.43s

With 30 single-row batches and two chained UDFs that each sleep one second, a non-pipelined run needs roughly 30 × 2 s ≈ 60 s, while a pipelined run overlaps the two UDFs and needs roughly (30 + 1) × 1 s ≈ 31 s, which is consistent with the measured times.

**So the PR improves performance by letting the downstream UDF get pipelined early.**

Please review https://spark.apache.org/contributing.html before opening a pull request.

Closes apache#24734 from WeichenXu123/improve_pandas_udf_pipeline.

Lead-authored-by: WeichenXu <weichen.xu@databricks.com>
Co-authored-by: Xiangrui Meng <meng@databricks.com>
Signed-off-by: gatorsmile <gatorsmile@gmail.com>
2 people authored and emanuelebardelli committed Jun 15, 2019
1 parent e4f2b0c commit cfd5b21
Showing 4 changed files with 41 additions and 9 deletions.
18 changes: 16 additions & 2 deletions python/pyspark/serializers.py
```
@@ -58,6 +58,7 @@
 import collections
 import zlib
 import itertools
+import time
 
 if sys.version < '3':
     import cPickle as pickle
```
```
@@ -230,11 +231,19 @@ class ArrowStreamSerializer(Serializer):
     def dump_stream(self, iterator, stream):
         import pyarrow as pa
         writer = None
+        last_flush_time = time.time()
         try:
             for batch in iterator:
                 if writer is None:
                     writer = pa.RecordBatchStreamWriter(stream, batch.schema)
                 writer.write_batch(batch)
+                current_time = time.time()
+                # If it takes time to compute each input batch but per-batch data is very small,
+                # the data might stay in the buffer for long and downstream reader cannot read it.
+                # We want to flush timely in this case.
+                if current_time - last_flush_time > 0.1:
+                    stream.flush()
+                    last_flush_time = current_time
         finally:
             if writer is not None:
                 writer.close()
```
```
@@ -872,11 +881,16 @@ def write(self, bytes):
                 byte_pos = new_byte_pos
             self.current_pos = 0
 
-    def close(self):
-        # if there is anything left in the buffer, write it out first
+    def flush(self):
         if self.current_pos > 0:
             write_int(self.current_pos, self.wrapped)
             self.wrapped.write(self.buffer[:self.current_pos])
             self.current_pos = 0
+        self.wrapped.flush()
+
+    def close(self):
+        # If there is anything left in the buffer, write it out first.
+        self.flush()
         # -1 length indicates to the receiving end that we're done.
         write_int(-1, self.wrapped)
         self.wrapped.close()
```
3 changes: 3 additions & 0 deletions python/pyspark/testing/utils.py
```
@@ -99,6 +99,9 @@ def __init__(self):
     def write(self, b):
         self.buffer += b
 
+    def flush(self):
+        pass
+
     def close(self):
         pass
 
```
10 changes: 10 additions & 0 deletions python/pyspark/tests/test_serializers.py
```
@@ -225,6 +225,16 @@ def test_chunked_stream(self):
         # ends with a -1
         self.assertEqual(dest.buffer[-4:], write_int(-1))
 
+    def test_chunked_stream_flush(self):
+        wrapped = ByteArrayOutput()
+        stream = serializers.ChunkedStream(wrapped, 10)
+        stream.write(bytearray([0]))
+        self.assertEqual(len(wrapped.buffer), 0, "small write should be buffered")
+        stream.flush()
+        # Expect buffer size 4 bytes + buffer data 1 byte.
+        self.assertEqual(len(wrapped.buffer), 5, "flush should work")
+        stream.close()
+
 
 if __name__ == "__main__":
     from pyspark.tests.test_serializers import *
```
```
@@ -78,16 +78,21 @@ class ArrowPythonRunner(
          val arrowWriter = ArrowWriter.create(root)
          val writer = new ArrowStreamWriter(root, null, dataOut)
          writer.start()
 
-          while (inputIterator.hasNext) {
-            val nextBatch = inputIterator.next()
-
-            while (nextBatch.hasNext) {
-              arrowWriter.write(nextBatch.next())
+          var lastFlushTime = System.currentTimeMillis()
+          inputIterator.foreach { batch =>
+            batch.foreach { row =>
+              arrowWriter.write(row)
            }
 
            arrowWriter.finish()
            writer.writeBatch()
+            val currentTime = System.currentTimeMillis()
+            // If it takes time to compute each input batch but per-batch data is very small,
+            // the data might stay in the buffer for long and downstream reader cannot read it.
+            // We want to flush timely in this case.
+            if (currentTime - lastFlushTime > 100) {
+              dataOut.flush()
+              lastFlushTime = currentTime
+            }
            arrowWriter.reset()
          }
          // end writes footer to the output stream and doesn't clean any resources.
```
