Skip to content

Commit

Permalink
Merge 18844c6 into 10923bd
Browse files Browse the repository at this point in the history
  • Loading branch information
davidt0x committed Oct 3, 2019
2 parents 10923bd + 18844c6 commit c50a7af
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 5 deletions.
45 changes: 40 additions & 5 deletions sleap/io/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1617,6 +1617,12 @@ def append_unique(old, new):
}
instance_type_to_idx = {Instance: 0, PredictedInstance: 1}

# Each instance we create will have and index in the dataset, keep track of
# these so we can quickly add from_predicted links on a second pass.
instance_to_idx = {}
instances_with_from_predicted = []
instances_from_predicted = []

# If we are appending, we need look inside to see what frame, instance, and point
# ids we need to start from. This gives us offsets to use.
if append and "points" in f:
Expand All @@ -1633,9 +1639,7 @@ def append_unique(old, new):
point_id = 0
pred_point_id = 0
instance_id = 0
frame_id = 0
all_from_predicted = []
from_predicted_id = 0

for frame_id, label in enumerate(labels):
frames[frame_id] = (
frame_id + frame_id_offset,
Expand All @@ -1645,6 +1649,11 @@ def append_unique(old, new):
instance_id + instance_id_offset + len(label.instances),
)
for instance in label.instances:

# Add this instance to our lookup structure we will need for from_predicted
# links
instance_to_idx[instance] = instance_id

parray = instance.get_points_array(copy=False, full=True)
instance_type = type(instance)

Expand All @@ -1659,8 +1668,8 @@ def append_unique(old, new):
# Keep track of any from_predicted instance links, we will insert the
# correct instance_id in the dataset after we are done.
if instance.from_predicted:
all_from_predicted.append(instance.from_predicted)
from_predicted_id = from_predicted_id + 1
instances_with_from_predicted.append(instance_id)
instances_from_predicted.append(instance.from_predicted)

# Copy all the data
instances[instance_id] = (
Expand Down Expand Up @@ -1688,6 +1697,21 @@ def append_unique(old, new):

instance_id = instance_id + 1

# Add from_predicted links
for instance_id, from_predicted in zip(
instances_with_from_predicted, instances_from_predicted
):
try:
instances[instance_id]["from_predicted"] = instance_to_idx[
from_predicted
]
except KeyError:
# If we haven't encountered the from_predicted instance yet then don't save the link.
# It’s possible for a user to create a regular instance from a predicted instance and then
# delete all predicted instances from the file, but in this case I don’t think there’s any reason
# to remember which predicted instance the regular instance came from.
pass

# We pre-allocated our points array with max possible size considering the max
# skeleton size, drop any unused points.
points = points[0:point_id]
Expand Down Expand Up @@ -1785,6 +1809,10 @@ def load_hdf5(
tracks = labels.tracks.copy()
tracks.extend([None])

# A dict to keep track of instances that have a from_predicted link. The key is the
# instance and the value is the index of the instance.
from_predicted_lookup = {}

# Create the instances
instances = []
for i in instances_dset:
Expand All @@ -1806,6 +1834,13 @@ def load_hdf5(
)
instances.append(instance)

if i["from_predicted"] != -1:
from_predicted_lookup[instance] = i["from_predicted"]

# Make a second pass to add any from_predicted links
for instance, from_predicted_idx in from_predicted_lookup.items():
instance.from_predicted = instances[from_predicted_idx]

# Create the labeled frames
frames = [
LabeledFrame(
Expand Down
21 changes: 21 additions & 0 deletions tests/io/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,3 +692,24 @@ def test_labels_append_hdf5(multi_skel_vid_labels, tmpdir):
loaded_labels = Labels.load_hdf5(filename=filename)

_check_labels_match(labels, loaded_labels)


def test_hdf5_from_predicted(multi_skel_vid_labels, tmpdir):
labels = multi_skel_vid_labels
filename = os.path.join(tmpdir, "test.h5")

# Add some predicted instances to create from_predicted links
for frame_num, frame in enumerate(labels):
if frame_num % 20 == 0:
frame.instances[0].from_predicted = PredictedInstance.from_instance(
frame.instances[0], float(frame_num)
)
frame.instances.append(frame.instances[0].from_predicted)

# Save and load, compare the results
Labels.save_hdf5(filename=filename, labels=labels)
loaded_labels = Labels.load_hdf5(filename=filename)

for frame_num, frame in enumerate(loaded_labels):
if frame_num % 20 == 0:
assert frame.instances[0].from_predicted.score == float(frame_num)

0 comments on commit c50a7af

Please sign in to comment.