Skip to content

Commit

Permalink
Add split-apply-combine.
Browse files Browse the repository at this point in the history
  • Loading branch information
faustomorales committed Feb 10, 2023
1 parent 3aac309 commit 4c2ad0f
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 12 deletions.
12 changes: 7 additions & 5 deletions mira/core/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def __call__(
bboxes: typing.List[typing.Tuple[int, int, int, int]],
keypoints: typing.List[typing.Tuple[int, int]],
bbox_indices: typing.List[int],
keypoint_indices: typing.List[typing.Tuple[int, int]],
keypoint_indices: typing.List[
typing.Union[typing.Tuple[int, int], typing.Tuple[None, None]]
],
) -> AugmentedResult:
pass

Expand Down Expand Up @@ -317,7 +319,7 @@ class CoarseDropout(A.CoarseDropout):
bounding boxes and keypoints.
"""

def apply(
def apply( # type: ignore[override]
self,
image: np.ndarray,
fill_value: typing.Union[int, float],
Expand All @@ -326,13 +328,13 @@ def apply(
) -> np.ndarray:
holes_cnt = params["holes_cnt"]
image = image.copy()
image = A.functional.cutout(image, holes, fill_value)
image = A.functional.cutout(image, holes, fill_value) # type: ignore[attr-defined]
cv2.drawContours(
image, holes_cnt, contourIdx=-1, color=fill_value, thickness=-1
)
return image

def apply_to_mask(
def apply_to_mask( # type: ignore[override]
self,
mask: np.ndarray,
mask_fill_value: typing.Union[int, float],
Expand All @@ -343,7 +345,7 @@ def apply_to_mask(
return mask
holes_cnt = params["holes_cnt"]
mask = mask.copy()
mask = A.functional.cutout(mask, holes, mask_fill_value)
mask = A.functional.cutout(mask, holes, mask_fill_value) # type: ignore[attr-defined]
cv2.drawContours(
mask, holes_cnt, contourIdx=-1, color=mask_fill_value, thickness=-1
)
Expand Down
10 changes: 5 additions & 5 deletions mira/core/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,13 +706,13 @@ def to_subcrops(self, max_size: int) -> typing.List["Scene"]:
captured.extend(ann_inc)
subcrops.append(
self.assign(
image=image[y1:y2, x1:x2],
image=image[y1:y2, x1:x2], # type: ignore[misc]
annotations=[
a.assign(
x1=a.x1 - x1,
y1=a.y1 - y1,
x2=a.x2 - x1,
y2=a.y2 - y1,
x1=a.x1 - x1, # type: ignore[operator]
y1=a.y1 - y1, # type: ignore[operator]
x2=a.x2 - x1, # type: ignore[operator]
y2=a.y2 - y1, # type: ignore[operator]
)
for a in ann_inc
],
Expand Down
82 changes: 80 additions & 2 deletions mira/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import json
import typing
import logging
import operator
import itertools
import collections

import requests
Expand Down Expand Up @@ -426,10 +428,13 @@ def split(
return splits


FlattenItem = typing.TypeVar("FlattenItem")


def flatten(
    t: typing.Iterable[typing.Iterable[FlattenItem]],
) -> typing.List[FlattenItem]:
    """Flatten exactly one level of nesting into a new list.

    Standard utility adapted from
    https://stackoverflow.com/questions/952914/how-to-make-a-flat-list-out-of-a-list-of-lists.

    Args:
        t: An iterable whose elements are themselves iterables (lists,
            tuples, generators, ...). It is consumed once, so lazy inputs
            such as ``map`` objects are accepted.

    Returns:
        A list holding the elements of every inner iterable, in order.
        Only the outermost level of nesting is removed.
    """
    return list(itertools.chain.from_iterable(t))


def find_largest_unique_contours(contours, threshold=1, method="iou"):
Expand Down Expand Up @@ -576,3 +581,76 @@ def load_json(filepath: str):
"""Load JSON from file"""
with open(filepath, "r", encoding="utf8") as f:
return json.loads(f.read())


SplitApplyRecombineItem = typing.TypeVar("SplitApplyRecombineItem")
SplitApplyRecombineOutput = typing.TypeVar("SplitApplyRecombineOutput")
SplitApplyRecombineIndexKey = typing.NamedTuple(
    "SplitApplyRecombineIndexKey", [("key", str), ("index", int)]
)


def split_apply_combine(
    items: typing.List[SplitApplyRecombineItem],
    key: typing.Callable[
        [SplitApplyRecombineItem],
        str,
    ],
    func: typing.Callable[
        [typing.List[SplitApplyRecombineItem]],
        typing.List[SplitApplyRecombineOutput],
    ],
) -> typing.List[SplitApplyRecombineOutput]:
    """Apply a batch function per group and return outputs in input order.

    This is an implementation of the split-apply-combine pattern:

    Split: every item is paired with its position and its group key
        (``key(item)``); the pairs are sorted by ``(key, position)`` and
        grouped by key.
    Apply: ``func`` is called once per group with the list of items in that
        group (in their original relative order). It receives only the
        items, not the key.
    Combine: each output is matched with the position of the item that
        produced it, and the outputs are returned ordered by that position.

    Args:
        items: The items to process.
        key: Maps an item to its group key. Keys must be mutually
            comparable because they are sorted.
        func: Maps a list of items sharing a key to a list of outputs, one
            output per item and in the same order.

    Returns:
        A list with one output per input item, where element ``i`` is the
        output computed for ``items[i]``.

    Raises:
        ValueError: If ``func`` returns a different number of outputs than
            the number of items it was given. Without this check the
            mismatch would be silently truncated and the result would no
            longer line up with ``items``.
    """
    # Split: sorting (key, index) pairs makes equal keys adjacent, which
    # itertools.groupby requires, and keeps items in input order per group.
    index_keys = sorted(
        SplitApplyRecombineIndexKey(key(item), index)
        for index, item in enumerate(items)
    )

    indexed_outputs: typing.List[typing.Tuple[int, SplitApplyRecombineOutput]] = []
    for group_key, group in itertools.groupby(
        index_keys, key=operator.attrgetter("key")
    ):
        indices = [entry.index for entry in group]
        # Apply.
        outputs = list(func([items[index] for index in indices]))
        if len(outputs) != len(indices):
            raise ValueError(
                f"func returned {len(outputs)} outputs for {len(indices)} "
                f"items in group {group_key!r}; expected one output per item."
            )
        indexed_outputs.extend(zip(indices, outputs))

    # Combine: restore the original input order. Indices are unique, so the
    # outputs themselves never need to be comparable.
    indexed_outputs.sort(key=operator.itemgetter(0))
    return [output for _, output in indexed_outputs]
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ exclude = [
'mira/thirdparty',
'mira/core/protos'
]
no_implicit_optional = false
ignore_missing_imports = true
check_untyped_defs = true

Expand Down
15 changes: 15 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,18 @@ def test_safe_crop():
assert variety == 1
# Make sure it works on empty images.
scene.assign(annotations=[]).augment(augmenter)


def test_split_apply_combine():
    """split_apply_combine groups by key and restores the input order."""

    def process(arr):
        verify = [i % 2 == 0 for i in arr]
        # Each call must receive a homogeneous group: all even or all odd.
        assert all(verify) or not any(verify)
        if verify[0]:
            return arr
        # Odd values are mapped to the even number just below them.
        return [i - 1 for i in arr]

    processed = mc.utils.split_apply_combine(
        list(range(10)), key=lambda x: str(x % 2 == 0), func=process
    )
    # Compare whole lists: a zip()-based comparison would pass vacuously if
    # the result were empty or shorter than expected.
    assert processed == [0, 0, 2, 2, 4, 4, 6, 6, 8, 8]

0 comments on commit 4c2ad0f

Please sign in to comment.