Skip to content

Commit

Permalink
Check if descriptions are available before loading; add save() and lo…
Browse files Browse the repository at this point in the history
…ad() convenience methods
  • Loading branch information
venkatesh-sivaraman committed Jan 4, 2022
1 parent 7b97634 commit 570594d
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 26 deletions.
34 changes: 31 additions & 3 deletions emblaze/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,8 +532,8 @@ def to_json(self, compressed=True, num_neighbors=None):
"frameLabels": [emb.label or "Frame {}".format(i) for i, emb in enumerate(self.embeddings)]
}

@staticmethod
def from_json(data, metric='euclidean'):
@classmethod
def from_json(cls, data, metric='euclidean'):
"""
Builds an EmbeddingSet from a JSON object. The provided object should
contain a "data" field containing frames, and optionally a "frameLabels"
Expand All @@ -542,4 +542,32 @@ def from_json(data, metric='euclidean'):
assert "data" in data, "JSON object must contain a 'data' field"
labels = data.get("frameLabels", [None for _ in range(len(data["data"]))])
embs = [Embedding.from_json(frame, label=label, metric=metric) for frame, label in zip(data["data"], labels)]
return EmbeddingSet(embs, align=False)
return cls(embs, align=False)

def save(self, file_path_or_buffer):
    """
    Save this EmbeddingSet object to the given file path or file-like object
    (in JSON format).

    Args:
        file_path_or_buffer: Either a path naming the file to create or
            overwrite (a string or any os.PathLike object, such as a
            pathlib.Path), or an already-open text-mode file-like object
            with a write() method. A file object passed in is written to
            but not closed.
    """
    import os  # local import: only needed to recognize path-like arguments

    # Serialize before touching the destination so that a failure in
    # to_json() does not leave behind a truncated, empty file.
    payload = self.to_json()
    if isinstance(file_path_or_buffer, (str, os.PathLike)):
        # File path
        with open(file_path_or_buffer, 'w', encoding='utf-8') as file:
            json.dump(payload, file)
    else:
        # File object
        json.dump(payload, file_path_or_buffer)

@classmethod
def load(cls, file_path_or_buffer, metric='euclidean'):
    """
    Load the EmbeddingSet from the given file path or file-like object
    containing JSON data.

    Args:
        file_path_or_buffer: Either a path to a JSON file (a string or any
            os.PathLike object, such as a pathlib.Path), or an already-open
            file-like object with a read() method. A file object passed in
            is read from but not closed.
        metric: Passed through unchanged to from_json().

    Returns:
        Whatever cls.from_json() builds from the parsed JSON.
    """
    import os  # local import: only needed to recognize path-like arguments

    if isinstance(file_path_or_buffer, (str, os.PathLike)):
        # File path. JSON text is UTF-8, so do not rely on the platform's
        # default encoding when decoding it.
        with open(file_path_or_buffer, 'r', encoding='utf-8') as file:
            data = json.load(file)
    else:
        # File object
        data = json.load(file_path_or_buffer)
    return cls.from_json(data, metric=metric)

96 changes: 73 additions & 23 deletions emblaze/thumbnails.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,52 @@ def to_json(self):
"format": self.format
}

@classmethod
def from_json(cls, data, ids=None):
    """
    Build a Thumbnails object from a JSON object by delegating to the
    subclass that matches the "format" entry of data.

    Args:
        data: JSON object; its "format" value selects the subclass.
        ids: Optional subset of point IDs to load. It is forwarded for the
            "spritesheet" and "text_descriptions" formats; the
            "spritesheet_and_text" branch does not pass it along.

    Raises:
        ValueError: If the declared format is not one of the three
            supported values.
    """
    fmt = data["format"]

    if fmt == "spritesheet_and_text":
        return CombinedThumbnails.from_json(data)
    if fmt == "spritesheet":
        return ImageThumbnails.from_json(data, ids=ids)
    if fmt == "text_descriptions":
        return TextThumbnails.from_json(data, ids=ids)

    # No known subclass handles this format
    raise ValueError("Unsupported data format '{}'".format(fmt))

def get_ids(self):
    """
    Return a numpy array of the IDs used in this thumbnails object.

    This base implementation has no data to report and always raises;
    subclasses are expected to override it.

    Raises:
        NotImplementedError: Always, naming the concrete class that is
            missing the override.
    """
    raise NotImplementedError(
        "{} must implement get_ids()".format(type(self).__name__)
    )

def save(self, file_path_or_buffer):
    """
    Save this Thumbnails object to the given file path or file-like object
    (in JSON format).

    Args:
        file_path_or_buffer: Either a path naming the file to create or
            overwrite (a string or any os.PathLike object, such as a
            pathlib.Path), or an already-open text-mode file-like object
            with a write() method. A file object passed in is written to
            but not closed.
    """
    import os  # local import: only needed to recognize path-like arguments

    # Serialize before touching the destination so that a failure in
    # to_json() does not leave behind a truncated, empty file.
    payload = self.to_json()
    if isinstance(file_path_or_buffer, (str, os.PathLike)):
        # File path
        with open(file_path_or_buffer, 'w', encoding='utf-8') as file:
            json.dump(payload, file)
    else:
        # File object
        json.dump(payload, file_path_or_buffer)

@classmethod
def load(cls, file_path_or_buffer, ids=None):
    """
    Load the appropriate Thumbnails subclass from the given file path or
    file-like object containing JSON data.

    Args:
        file_path_or_buffer: Either a path to a JSON file (a string or any
            os.PathLike object, such as a pathlib.Path), or an already-open
            file-like object with a read() method. A file object passed in
            is read from but not closed.
        ids: Passed through unchanged to from_json().

    Returns:
        Whatever cls.from_json() builds from the parsed JSON.
    """
    import os  # local import: only needed to recognize path-like arguments

    if isinstance(file_path_or_buffer, (str, os.PathLike)):
        # File path. JSON text is UTF-8, so do not rely on the platform's
        # default encoding when decoding it.
        with open(file_path_or_buffer, 'r', encoding='utf-8') as file:
            data = json.load(file)
    else:
        # File object
        data = json.load(file_path_or_buffer)
    return cls.from_json(data, ids=ids)

class TextThumbnails(Thumbnails):
"""
Expand All @@ -44,6 +87,10 @@ def __init__(self, names, descriptions=None, ids=None):
integers.
"""
super().__init__("text_descriptions")
if ids is not None:
assert len(names) == len(ids), "Length mismatch: got {} names for {} IDs".format(len(names), len(ids))
if descriptions is not None:
assert len(descriptions) == len(ids), "Length mismatch: got {} descriptions for {} IDs".format(len(descriptions), len(ids))
self.data = ColumnarData({
Field.NAME: names
}, ids)
Expand Down Expand Up @@ -71,18 +118,20 @@ def to_json(self):
result = super().to_json()
names = self.data.field(Field.NAME)
descriptions = self.data.field(Field.DESCRIPTION)
def _make_json_item(i, id_val):
item = {"id": id_val, "name": str(names[i])}
if descriptions is not None and descriptions[i]:
item["description"] = str(descriptions[i])
return item
result["items"] = {
id_val: {
"id": id_val,
"name": str(names[i]),
"description": str(descriptions[i]) if descriptions is not None else "",
"frames": {} # Not implemented
} for i, id_val in enumerate(self.data.ids) if names[i] or descriptions[i]
id_val: _make_json_item(i, id_val)
for i, id_val in enumerate(self.data.ids)
if names[i] or (descriptions is not None and descriptions[i])
}
return standardize_json(result)

@staticmethod
def from_json(data, ids=None):
@classmethod
def from_json(cls, data, ids=None):
"""
Builds a TextThumbnails object from a JSON object. The provided object should
have an "items" key with a dictionary mapping ID values to text thumbnail
Expand All @@ -99,7 +148,7 @@ def from_json(data, ids=None):
ids = sorted(ids)
names = [items[id_val]["name"] for id_val in ids]
descriptions = [items[id_val].get("description", "") for id_val in ids]
return TextThumbnails(names, descriptions, ids)
return cls(names, descriptions, ids)

def __getitem__(self, ids):
"""
Expand Down Expand Up @@ -239,8 +288,8 @@ def to_json(self):

return standardize_json(result)

@staticmethod
def from_json(data, ids=None):
@classmethod
def from_json(cls, data, ids=None):
"""
Builds an ImageThumbnails object from a JSON object. The provided object should
have a "spritesheets" object that defines PIXI spritesheets, and an
Expand All @@ -250,7 +299,7 @@ def from_json(data, ids=None):
assert "spritesheets" in data, "JSON object must contain a 'spritesheets' field"
spritesheets = data["spritesheets"]
if ids is None:
ids = ImageThumbnails._get_spritesheet_ids(spritesheets)
ids = cls._get_spritesheet_ids(spritesheets)

names = None
descriptions = None
Expand All @@ -259,14 +308,14 @@ def from_json(data, ids=None):
names = [items[str(id_val)]["name"] for id_val in ids]
descriptions = [items[str(id_val)].get("description", "") for id_val in ids]

return ImageThumbnails(None,
spritesheets=spritesheets,
ids=ids,
names=names,
descriptions=descriptions)
return cls(None,
spritesheets=spritesheets,
ids=ids,
names=names,
descriptions=descriptions)

@staticmethod
def _get_spritesheet_ids(spritesheets):
@classmethod
def _get_spritesheet_ids(cls, spritesheets):
ids = sorted([k for sp in spritesheets.values() for k in sp["spec"]["frames"].keys()])
try:
ids = [int(id_val) for id_val in ids]
Expand Down Expand Up @@ -454,7 +503,8 @@ class CombinedThumbnails(Thumbnails):
"""
def __init__(self, thumbnail_objects):
has_images = any(t.format == "spritesheet" for t in thumbnail_objects)
super().__init__("spritesheet" if has_images else "text_descriptions")
has_texts = any(t.format == "text_descriptions" for t in thumbnail_objects)
super().__init__(("spritesheet_and_text" if has_images else "text_descriptions") if has_texts else "spritesheet")

# Merge the list of representations together
self.ids = np.array(sorted(set.union(*(set(t.get_ids().tolist()) for t in thumbnail_objects))))
Expand Down Expand Up @@ -510,16 +560,16 @@ def to_json(self):

return result

@staticmethod
def from_json(data):
@classmethod
def from_json(cls, data, ids=None):
thumbnails = []
if "spritesheets" in data:
thumbnails.append(ImageThumbnails.from_json({"spritesheets": data["spritesheets"]}))

if "items" in data:
thumbnails.append(TextThumbnails.from_json({"items": data["items"]}))

return CombinedThumbnails(thumbnails)
return cls(thumbnails)

def image(self, ids=None):
"""
Expand Down

0 comments on commit 570594d

Please sign in to comment.