[SPARK-21985][PYSPARK] PairDeserializer is broken for double-zipped RDDs
## What changes were proposed in this pull request?
Fixes a bug introduced in apache#16121

In `PairDeserializer`, convert each batch of keys and values to a list (if it does not already have `__len__`) so that we can check that the two batches are the same size. Normally they are already lists, so this should have no performance impact, but the conversion is needed when `zip` is applied repeatedly.
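To see why the guard is needed, here is a minimal standalone sketch (plain Python, not Spark's actual serializer code; the `pair_batches` helper is hypothetical): a batch that is itself a generator, as produced by a nested pair deserializer, has no `__len__`, so it must be materialized before the size check.

```python
def pair_batches(key_batches, val_batches):
    """Yield zipped (key, value) batches, mimicking PairDeserializer's loop."""
    for key_batch, val_batch in zip(key_batches, val_batches):
        # A batch coming from a nested pair deserializer is itself a
        # generator, which has no __len__; materialize it first.
        key_batch = key_batch if hasattr(key_batch, '__len__') else list(key_batch)
        val_batch = val_batch if hasattr(val_batch, '__len__') else list(val_batch)
        if len(key_batch) != len(val_batch):
            raise ValueError("mismatched batch sizes: (%d, %d)"
                             % (len(key_batch), len(val_batch)))
        yield list(zip(key_batch, val_batch))

# Plain list batches work, as before:
flat = list(pair_batches([[1, 2]], [['a', 'b']]))
# Iterator batches (the double-zip case) now work too, instead of
# failing on len():
gen = list(pair_batches((iter(b) for b in [[1, 2]]),
                        (iter(b) for b in [['a', 'b']])))
```

Both calls produce the same zipped batches, which is exactly the behavior the patch restores for double-zipped RDDs.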

## How was this patch tested?

Additional unit test

Author: Andrew Ray <ray.andrew@gmail.com>

Closes apache#19226 from aray/SPARK-21985.
aray authored and HyukjinKwon committed Sep 17, 2017
1 parent f407302, commit 6adf67d
Showing 2 changed files with 17 additions and 1 deletion.
python/pyspark/serializers.py (5 additions, 1 deletion)
```diff
@@ -97,7 +97,7 @@ def load_stream(self, stream):

     def _load_stream_without_unbatching(self, stream):
         """
-        Return an iterator of deserialized batches (lists) of objects from the input stream.
+        Return an iterator of deserialized batches (iterable) of objects from the input stream.
         if the serializer does not operate on batches the default implementation returns an
         iterator of single element lists.
         """
@@ -343,6 +343,10 @@ def _load_stream_without_unbatching(self, stream):
         key_batch_stream = self.key_ser._load_stream_without_unbatching(stream)
         val_batch_stream = self.val_ser._load_stream_without_unbatching(stream)
         for (key_batch, val_batch) in zip(key_batch_stream, val_batch_stream):
+            # For double-zipped RDDs, the batches can be iterators from another
+            # PairDeserializer, instead of lists. Convert them to lists if needed.
+            key_batch = key_batch if hasattr(key_batch, '__len__') else list(key_batch)
+            val_batch = val_batch if hasattr(val_batch, '__len__') else list(val_batch)
             if len(key_batch) != len(val_batch):
                 raise ValueError("Can not deserialize PairRDD with different number of items"
                                  " in batches: (%d, %d)" % (len(key_batch), len(val_batch)))
```
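The `hasattr(..., '__len__')` check in the patched loop matters because generators cannot report a length at all. A tiny standalone sketch (plain Python, no Spark required) shows the failure mode:

```python
# A generator, like the batch a nested PairDeserializer emits, has no
# __len__, so len() raises TypeError instead of returning a size.
batch = (x * x for x in range(3))
has_len = hasattr(batch, '__len__')
try:
    len(batch)
    raised = False
except TypeError:
    raised = True
```

Lists, by contrast, implement `__len__`, which is why the conversion is skipped for them and the common case pays no extra cost.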
python/pyspark/tests.py (12 additions, 0 deletions)
```diff
@@ -644,6 +644,18 @@ def test_cartesian_chaining(self):
             set([(x, (y, y)) for x in range(10) for y in range(10)])
         )

+    def test_zip_chaining(self):
+        # Tests for SPARK-21985
+        rdd = self.sc.parallelize('abc', 2)
+        self.assertSetEqual(
+            set(rdd.zip(rdd).zip(rdd).collect()),
+            set([((x, x), x) for x in 'abc'])
+        )
+        self.assertSetEqual(
+            set(rdd.zip(rdd.zip(rdd)).collect()),
+            set([(x, (x, x)) for x in 'abc'])
+        )
+
     def test_deleting_input_files(self):
         # Regression test for SPARK-1025
         tempFile = tempfile.NamedTemporaryFile(delete=False)
```
