Skip to content

Commit

Permalink
Use new merge method in active learning.
Browse files Browse the repository at this point in the history
Also, save all the new predictions together in one h5 file.
  • Loading branch information
ntabris committed Oct 7, 2019
1 parent edd389a commit 86aad64
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 29 deletions.
40 changes: 12 additions & 28 deletions sleap/gui/active.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import cattr

from datetime import datetime
from functools import reduce
from pkg_resources import Requirement, resource_filename
from typing import Dict, List, Optional, Tuple
Expand Down Expand Up @@ -699,28 +700,6 @@ def find_saved_jobs(
return jobs


def add_frames_from_json(labels: Labels, new_labels_json: str) -> int:
"""Merges new predictions (given as json string) into dataset.
Args:
labels: The dataset to which we're adding the predictions.
new_labels_json: A JSON string which can be deserialized into `Labels`.
Returns:
Number of labeled frames with new predictions.
"""
# Deserialize the new frames, matching to the existing videos/skeletons if possible
new_lfs = Labels.from_json(new_labels_json, match_to=labels).labeled_frames

# Remove any frames without instances
new_lfs = list(filter(lambda lf: len(lf.instances), new_lfs))

# Now add them to labels and merge labeled frames with same video/frame_idx
labels.extend_from(new_lfs)
labels.merge_matching_frames()

return len(new_lfs)


def run_active_learning_pipeline(
labels_filename: str,
labels: Labels,
Expand Down Expand Up @@ -879,8 +858,8 @@ def run_active_inference(
# from multiprocessing import Pool

# total_new_lf_count = 0
# timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
# inference_output_path = os.path.join(save_dir, f"{timestamp}.inference.h5")
timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
inference_output_path = os.path.join(save_dir, f"{timestamp}.inference.h5")

# Create Predictor from the results of training
# pool = Pool(processes=1)
Expand Down Expand Up @@ -942,10 +921,15 @@ def run_active_inference(
# Remove any frames without instances
new_lfs = list(filter(lambda lf: len(lf.instances), new_lfs))

# Now add them to labels and merge labeled frames with same video/frame_idx
# labels.extend_from(new_lfs)
labels.extend_from(new_lfs, unify=True)
labels.merge_matching_frames()
# Create and save dataset with predictions
new_labels = Labels(new_lfs)
Labels.save_file(new_labels, inference_output_path)

# Merge predictions into current labels dataset
_, _, new_conflicts = Labels.complex_merge_between(labels, new_labels)

# new predictions should replace old ones
Labels.finish_complex_merge(labels, new_conflicts)

# close message window
if gui:
Expand Down
3 changes: 2 additions & 1 deletion tests/gui/test_active.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import pytest

from sleap.skeleton import Skeleton
from sleap.instance import Instance, Point, LabeledFrame, PredictedInstance
Expand All @@ -9,7 +10,6 @@
ActiveLearningDialog,
make_default_training_jobs,
find_saved_jobs,
add_frames_from_json,
)


Expand Down Expand Up @@ -80,6 +80,7 @@ def test_find_saved_jobs():
assert os.path.basename(paths[1]) == "default_confmaps.json"


@pytest.mark.skip(reason="for old merging method")
def test_add_frames_from_json():
vid_a = Video.from_filename("foo.mp4")
vid_b = Video.from_filename("bar.mp4")
Expand Down

0 comments on commit 86aad64

Please sign in to comment.