Skip to content

Commit

Permalink
Fix issue where empty partitions caused incorrect behavior (#119)
Browse files Browse the repository at this point in the history
* Fix issue where empty partitions caused incorrect behavior

* Formatting

* Fix empty partition bug

* Revert unnecessary change

* Minor correctness fix
  • Loading branch information
devin-petersohn authored and osalpekar committed Oct 8, 2018
1 parent 07f3cbc commit c6a8080
Showing 1 changed file with 50 additions and 6 deletions.
56 changes: 50 additions & 6 deletions modin/data_management/partitioning/partition_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,30 @@ def __init__(self, partitions):
# Partition class is the class to use for storing each partition. It must
# extend the `RemotePartition` class.
_partition_class = None
# Whether or not we have already filtered out the empty partitions.
# `_set_partitions` resets this to False whenever a new partition array is
# assigned; `_get_partitions` sets it to True once the filtering has run, so
# the filtering happens at most once per assignment.
_filtered_empties = False

def _get_partitions(self):
    """Return the 2D partition array with empty blocks filtered out.

    On the first access after the partitions have been (re)assigned, every
    block whose row length or column width is zero is dropped, rows that end
    up with no blocks are removed, and the cached lengths/widths are pruned
    through `_remove_empty_blocks`. Subsequent accesses return the cached,
    already-filtered array.

    Returns:
        The filtered array stored in `self._partitions_cache`.
    """
    if not self._filtered_empties:
        # Read the cached dimensions once instead of once per block.
        lengths = self.block_lengths
        widths = self.block_widths
        # A block is empty when EITHER of its dimensions is zero, so it is
        # kept only when both are non-zero. This matches what
        # `_remove_empty_blocks` strips from the length/width caches and
        # keeps the array and the caches aligned.
        filtered_rows = [
            [block for j, block in enumerate(row) if widths[j] != 0]
            for i, row in enumerate(self._partitions_cache)
            if lengths[i] != 0
        ]
        # Rows left without any block are discarded so the resulting array
        # is never ragged.
        self._partitions_cache = np.array(
            [row for row in filtered_rows if row]
        )
        self._remove_empty_blocks()
        self._filtered_empties = True
    return self._partitions_cache

def _set_partitions(self, new_partitions):
    """Store a new partition array and mark it as not yet filtered.

    Args:
        new_partitions: The value to keep in `self._partitions_cache`.
    """
    self._partitions_cache = new_partitions
    # The new array has not been checked for empty blocks yet.
    self._filtered_empties = False

# Reading `partitions` goes through `_get_partitions` (filters empty blocks on
# the first access); assigning to it goes through `_set_partitions` (stores
# the new array and resets the filtered flag).
partitions = property(_get_partitions, _set_partitions)

def preprocess_func(self, map_func):
"""Preprocess a function to be applied to `RemotePartition` objects.
Expand Down Expand Up @@ -102,7 +126,11 @@ def block_lengths(self):
# The first column will have the correct lengths. We have an
# invariant that requires that all blocks be the same length in a
# row of blocks.
self._lengths_cache = [obj.length().get() for obj in self.partitions.T[0]]
self._lengths_cache = (
[obj.length().get() for obj in self._partitions_cache.T[0]]
if len(self._partitions_cache.T) > 0
else []
)
return self._lengths_cache

# Widths of the blocks
Expand All @@ -119,9 +147,21 @@ def block_widths(self):
# The first row will have the correct widths. We have an
# invariant that requires that all blocks be the same width in a
# column of blocks.
self._widths_cache = [obj.width().get() for obj in self.partitions[0]]
self._widths_cache = (
[obj.width().get() for obj in self._partitions_cache[0]]
if len(self._partitions_cache) > 0
else []
)
return self._widths_cache

def _remove_empty_blocks(self):
    """Drop zero entries from the cached block widths and block lengths.

    A cache that is still `None` (not computed yet) is left untouched.
    """
    for cache_name in ("_widths_cache", "_lengths_cache"):
        if getattr(self, cache_name) is not None:
            non_empty = [size for size in getattr(self, cache_name) if size != 0]
            setattr(self, cache_name, non_empty)

@property
def shape(self) -> Tuple[int, int]:
    """Return the summed block lengths and summed block widths as a pair.

    Returns:
        A `(total_length, total_width)` tuple of plain Python ints.
    """
    total_length = int(np.sum(self.block_lengths))
    total_width = int(np.sum(self.block_widths))
    return total_length, total_width
Expand Down Expand Up @@ -963,8 +1003,10 @@ def block_lengths(self):
# The first column will have the correct lengths. We have an
# invariant that requires that all blocks be the same length in a
# row of blocks.
self._lengths_cache = ray.get(
[obj.length().oid for obj in self.partitions.T[0]]
self._lengths_cache = (
ray.get([obj.length().oid for obj in self._partitions_cache.T[0]])
if len(self._partitions_cache.T) > 0
else []
)
return self._lengths_cache

Expand All @@ -982,8 +1024,10 @@ def block_widths(self):
# The first row will have the correct widths. We have an
# invariant that requires that all blocks be the same width in a
# column of blocks.
self._widths_cache = ray.get(
[obj.width().oid for obj in self.partitions[0]]
self._widths_cache = (
ray.get([obj.width().oid for obj in self._partitions_cache[0]])
if len(self._partitions_cache) > 0
else []
)
return self._widths_cache

Expand Down

0 comments on commit c6a8080

Please sign in to comment.