
Commit

Efficiency improvements for collection serialization/deserialization and dataset splitting.
faustomorales committed Aug 26, 2022
1 parent cd0db51 commit 01579ab
Showing 3 changed files with 66 additions and 21 deletions.
52 changes: 40 additions & 12 deletions mira/core/scene.py
@@ -4,6 +4,7 @@
 import os
 import io
+import glob
 import json
 import typing
 import logging
@@ -28,7 +29,7 @@
 
 Dimensions = typing.NamedTuple("Dimensions", [("width", int), ("height", int)])
 
-
+# pylint: disable=too-many-public-methods
 class Scene:
     """A single annotated image.
@@ -164,6 +165,12 @@ def image_bytes(self) -> bytes:
         """Get the image as a PNG encoded to bytes."""
         return utils.image2bytes(self.image)
 
+    @classmethod
+    def load(cls, filepath: str):
+        """Load a scene from a filepath."""
+        with open(filepath, "rb") as f:
+            return cls.fromString(f.read())
+
     @classmethod
     def from_qsl(
         cls,
@@ -319,10 +326,10 @@ def dimensions(self) -> Dimensions:
             self._dimensions = dimensions
         return self._dimensions
 
-    def toString(self):
+    def toString(self, extension=".png"):
         """Serialize scene to string."""
         return mps.Scene(
-            image=cv2.imencode(".png", self.image)[1].tobytes(),
+            image=cv2.imencode(extension, self.image)[1].tobytes(),
             categories=mps.Categories(categories=[c.name for c in self.categories]),
             metadata=json.dumps(self.metadata or {}),
             masks=[
@@ -740,6 +747,7 @@ def compute_iou(self, other: "Scene"):
         return iou
 
 
+# pylint: disable=too-many-public-methods
 class SceneCollection:
     """A collection of scenes.
@@ -951,12 +959,12 @@ def sample(self, n, replace=True) -> "SceneCollection":
         selected = np.random.choice(len(self.scenes), n, replace=replace)
         return self.assign(scenes=[self.scenes[i] for i in selected])
 
-    def save(self, filename: str):
+    def save(self, filename: str, **kwargs):
         """Save the scene collection to a tarball."""
         with tarfile.open(filename, mode="w") as tar:
             for idx, scene in enumerate(tqdm.tqdm(self.scenes)):
                 with tempfile.NamedTemporaryFile() as temp:
-                    temp.write(scene.toString())
+                    temp.write(scene.toString(**kwargs))
                     temp.flush()
                     tar.add(name=temp.name, arcname=str(idx))

@@ -989,10 +997,31 @@ def save_placeholder(
                     tar.add(name=temp.name, arcname=str(idx))
 
     @classmethod
-    def load(cls, filename: str, directory: str = None):
+    def load_from_directory(cls, directory: str):
+        """Load a dataset that was already extracted to a directory."""
+        return cls(
+            scenes=[
+                Scene.load(f).assign(image=f + ".png")
+                for f in tqdm.tqdm(
+                    sorted(
+                        [
+                            f
+                            for f in glob.glob(os.path.join(directory, "*"))
+                            if not os.path.splitext(f)[1]
+                        ],
+                        key=lambda p: int(os.path.basename(p)),
+                    )
+                )
+            ]
+        )
+
+    @classmethod
+    def load(cls, filename: str, directory: str = None, force=False):
         """Load scene collection from a tarball. If a directory
         is provided, images will be saved into that directory
         rather than retained in memory."""
+        if directory and os.path.isdir(directory) and not force:
+            return cls.load_from_directory(directory)
         if directory:
             os.makedirs(directory, exist_ok=True)
         scenes = []
@@ -1006,13 +1035,12 @@ def load(cls, filename: str, directory: str = None):
             else:
                 label_filepath = os.path.join(directory, str(idx))
                 image_filepath = label_filepath + ".png"
-                if os.path.isfile(label_filepath) and os.path.isfile(
-                    image_filepath
+                if (
+                    os.path.isfile(label_filepath)
+                    and os.path.isfile(image_filepath)
+                    and not force
                 ):
-                    with open(label_filepath, "rb") as f:
-                        scene = Scene.fromString(f.read()).assign(
-                            image=image_filepath
-                        )
+                    scene = Scene.load(label_filepath).assign(image=image_filepath)
                 else:
                     scene = Scene.fromString(data.read())
                     cv2.imwrite(
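
For context, here is a minimal usage sketch of the serialization changes above, assuming an existing `SceneCollection` named `collection`; the file and directory names are hypothetical:

    # Encode images inside the tarball as JPEG instead of the default PNG;
    # `extension` is forwarded to Scene.toString via **kwargs.
    collection.save("dataset.tar", extension=".jpg")

    # First load extracts labels and images into ./cache on disk.
    loaded = SceneCollection.load("dataset.tar", directory="cache")

    # Later loads hit the new fast path: if ./cache exists and force is not
    # set, the tarball is skipped and scenes are read straight from disk.
    loaded = SceneCollection.load("dataset.tar", directory="cache")
    reloaded = SceneCollection.load("dataset.tar", directory="cache", force=True)

Note that extracted images are always written as .png in the cache directory, so the `extension` argument only affects how images are stored inside the tarball itself.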
33 changes: 25 additions & 8 deletions mira/core/utils.py
@@ -278,6 +278,15 @@ def compute_coverage(boxesA: np.ndarray, boxesB: np.ndarray) -> np.ndarray:
     return coverageA
 
 
+def groupby_unsorted(seq, key=lambda x: x):
+    """groupby for unsorted inputs, taken from https://code.activestate.com/recipes/580800-groupby-for-unsorted-input/#c1"""
+    indexes = collections.defaultdict(list)
+    for i, elem in enumerate(seq):
+        indexes[key(elem)].append(i)
+    for k, idxs in indexes.items():
+        yield k, (seq[i] for i in idxs)
+
+
 def split(
     items: typing.List[typing.Any],
     sizes: typing.List[float],
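
The new `groupby_unsorted` helper backs the reworked `split` in the next hunk. A quick illustrative sketch of its behavior (the example values are mine, not from the commit):

    # Unlike itertools.groupby, no pre-sorting is required; keys appear in
    # first-seen order and each group preserves the original element order.
    for parity, values in groupby_unsorted([3, 1, 4, 1, 5, 9], key=lambda x: x % 2):
        print(parity, list(values))
    # 1 [3, 1, 1, 5, 9]
    # 0 [4]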
@@ -366,17 +375,25 @@ def split(
         items
     ), "stratify must be the same length as the collection."
     rng = np.random.default_rng(seed=random_state)
-    unique = collections.Counter(group)
-    hashes = [
-        hash(tuple(set(s for s, g in zip(stratify, group) if g == u))) for u in unique
+    grouped = [
+        {**dict(zip(["idxs", "stratifiers"], zip(*grouper))), "group": g}
+        for g, grouper in groupby_unsorted(
+            list(zip(range(len(stratify)), stratify)),
+            key=lambda v: typing.cast(typing.Sequence[typing.Hashable], group)[v[0]],
+        )
     ]
-    totals = collections.Counter(hashes)
-    for ht, t in totals.items():
+    hashes = {
+        h: list(g)
+        for h, g in groupby_unsorted(
+            grouped, key=lambda g: hash(tuple(set(g["stratifiers"])))
+        )
+    }
+    for subgroups in hashes.values():
         for a, u in zip(
-            rng.choice(len(sizes), size=t, p=sizes),
-            [u for h, u in zip(hashes, unique) if h == ht],
+            rng.choice(len(sizes), size=len(subgroups), p=sizes),
+            subgroups,
         ):
-            splits[a].extend(i for i, g in zip(items, group) if g == u)
+            splits[a].extend(items[idx] for idx in u["idxs"])
     return splits


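A hedged sketch of how the reworked `split` gets used (argument names are taken from the diff; the exact signature and defaults are assumptions):

    # Split 1,000 items 80/20 such that items sharing a group land in the
    # same split, while the mix of stratify labels stays roughly balanced.
    items = list(range(1000))
    group = [i // 10 for i in items]    # e.g., ten frames per source video
    stratify = [i % 3 for i in items]   # e.g., a class label per item
    train, val = split(
        items,
        sizes=[0.8, 0.2],
        random_state=42,
        stratify=stratify,
        group=group,
    )

Bucketing groups by their hashed stratifier sets in a single pass, rather than rescanning all items once per unique group as the old code did, is what makes the new implementation faster on collections with many groups.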
2 changes: 1 addition & 1 deletion tests/test_core.py
@@ -146,7 +146,7 @@ def test_split():
             abs(size - len(split) / len(items)) < 0.05
             for split, size in zip(splits, sizes)
         ]
-    )
+    ), f"Sizes were {splits}"
 
     # Roughly achieve stratification
     assert all(
