Skip to content

Commit

Permalink
Merge 81be672 into e44a79f
Browse files Browse the repository at this point in the history
  • Loading branch information
ntabris committed Oct 9, 2019
2 parents e44a79f + 81be672 commit 8933cf9
Show file tree
Hide file tree
Showing 7 changed files with 195 additions and 124 deletions.
18 changes: 18 additions & 0 deletions sleap/config/active.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,24 @@ learning:

inference:

- name: conf_job
label: Node (confmap) Training Profile
type: list
default: a
options: a,b,c

- name: paf_job
label: Edge (paf) Training Profile
type: list
default: a
options: a,b,c

- name: centroid_job
label: Centroid Training Profile
type: list
default: a
options: a,b,c

- name: _predict_frames
label: Predict On
type: list
Expand Down
87 changes: 44 additions & 43 deletions sleap/gui/active.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import cattr

from datetime import datetime
from functools import reduce
from pkg_resources import Requirement, resource_filename
from typing import Dict, List, Optional, Tuple
Expand All @@ -19,6 +20,9 @@
from PySide2 import QtWidgets, QtCore


SELECT_FILE_OPTION = "Select a training profile file..."


class ActiveLearningDialog(QtWidgets.QDialog):
"""Active learning dialog.
Expand Down Expand Up @@ -49,6 +53,10 @@ def __init__(
self.labels_filename = labels_filename
self.labels = labels
self.mode = mode
self._job_filter = None

if self.mode == "inference":
self._job_filter = lambda job: job.is_trained

print(f"Number of frames to train on: {len(labels.user_labeled_frames)}")

Expand Down Expand Up @@ -162,6 +170,13 @@ def _rebuild_job_options(self):
# list default profiles
find_saved_jobs(profile_dir, self.job_options)

# Apply any filters
if self._job_filter:
for model_type, jobs_list in self.job_options.items():
self.job_options[model_type] = [
(path, job) for (path, job) in jobs_list if self._job_filter(job)
]

def _update_job_menus(self, init: bool = False):
"""Updates the menus with training profile options.
Expand All @@ -176,9 +191,11 @@ def _update_job_menus(self, init: bool = False):
if model_type not in self.job_options:
self.job_options[model_type] = []
if init:
field.currentIndexChanged.connect(
lambda idx, mt=model_type: self._update_from_selected_job(mt, idx)
)

def menu_action(idx, mt=model_type, field=field):
self._update_from_selected_job(mt, idx, field)

field.currentIndexChanged.connect(menu_action)
else:
# block signals so we can update combobox without overwriting
# any user data with the defaults from the profile
Expand Down Expand Up @@ -365,6 +382,9 @@ def _get_current_training_jobs(self) -> Dict[ModelOutputType, TrainingJob]:
for model_type in self._get_model_types_to_use():
job, _ = self._get_current_job(model_type)

if job is None:
continue

if job.model.output_type != ModelOutputType.CENTROIDS:
# update training job from params in form
trainer = job.trainer
Expand Down Expand Up @@ -499,8 +519,9 @@ def _option_list_from_jobs(self, model_type: ModelOutputType):
"""Returns list of menu options for given model type."""
jobs = self.job_options[model_type]
option_list = [name for (name, job) in jobs]
option_list.append("")
option_list.append("---")
option_list.append("Select a training profile file...")
option_list.append(SELECT_FILE_OPTION)
return option_list

def _add_job_file(self, model_type):
Expand Down Expand Up @@ -548,9 +569,10 @@ def _add_job_file_to_list(self, filename: str, model_type: ModelOutputType):
text=f"Profile selected is for training {str(file_model_type)} instead of {str(model_type)}."
).exec_()

def _update_from_selected_job(self, model_type: ModelOutputType, idx: int):
def _update_from_selected_job(self, model_type: ModelOutputType, idx: int, field):
"""Updates dialog settings after user selects a training profile."""
jobs = self.job_options[model_type]
field_text = field.currentText()
if idx == -1:
return
if idx < len(jobs):
Expand All @@ -569,17 +591,13 @@ def _update_from_selected_job(self, model_type: ModelOutputType, idx: int):
self.form_widget.set_form_data(training_params)

# is the model already trained?
has_trained = False
final_model_filename = job.final_model_filename
if final_model_filename is not None:
if os.path.exists(os.path.join(job.save_dir, final_model_filename)):
has_trained = True
is_trained = job.is_trained
field_name = f"_use_trained_{str(model_type)}"
# update "use trained" checkbox
self.form_widget.fields[field_name].setEnabled(has_trained)
self.form_widget[field_name] = has_trained
else:
# last item is "select file..."
# update "use trained" checkbox if present
if field_name in self.form_widget.fields:
self.form_widget.fields[field_name].setEnabled(is_trained)
self.form_widget[field_name] = is_trained
elif field_text == SELECT_FILE_OPTION:
self._add_job_file(model_type)


Expand Down Expand Up @@ -682,28 +700,6 @@ def find_saved_jobs(
return jobs


def add_frames_from_json(labels: Labels, new_labels_json: str) -> int:
"""Merges new predictions (given as json string) into dataset.
Args:
labels: The dataset to which we're adding the predictions.
new_labels_json: A JSON string which can be deserialized into `Labels`.
Returns:
Number of labeled frames with new predictions.
"""
# Deserialize the new frames, matching to the existing videos/skeletons if possible
new_lfs = Labels.from_json(new_labels_json, match_to=labels).labeled_frames

# Remove any frames without instances
new_lfs = list(filter(lambda lf: len(lf.instances), new_lfs))

# Now add them to labels and merge labeled frames with same video/frame_idx
labels.extend_from(new_lfs)
labels.merge_matching_frames()

return len(new_lfs)


def run_active_learning_pipeline(
labels_filename: str,
labels: Labels,
Expand Down Expand Up @@ -862,8 +858,8 @@ def run_active_inference(
# from multiprocessing import Pool

# total_new_lf_count = 0
# timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
# inference_output_path = os.path.join(save_dir, f"{timestamp}.inference.h5")
timestamp = datetime.now().strftime("%y%m%d_%H%M%S")
inference_output_path = os.path.join(save_dir, f"{timestamp}.inference.h5")

# Create Predictor from the results of training
# pool = Pool(processes=1)
Expand Down Expand Up @@ -925,10 +921,15 @@ def run_active_inference(
# Remove any frames without instances
new_lfs = list(filter(lambda lf: len(lf.instances), new_lfs))

# Now add them to labels and merge labeled frames with same video/frame_idx
# labels.extend_from(new_lfs)
labels.extend_from(new_lfs, unify=True)
labels.merge_matching_frames()
# Create and save dataset with predictions
new_labels = Labels(new_lfs)
Labels.save_file(new_labels, inference_output_path)

# Merge predictions into current labels dataset
_, _, new_conflicts = Labels.complex_merge_between(labels, new_labels)

# new predictions should replace old ones
Labels.finish_complex_merge(labels, new_conflicts)

# close message window
if gui:
Expand Down
15 changes: 13 additions & 2 deletions sleap/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1495,7 +1495,7 @@ def clearFrameNegativeAnchors(self):
def importPredictions(self):
"""Starts gui for importing another dataset into currently one."""
filters = ["HDF5 dataset (*.h5 *.hdf5)", "JSON labels (*.json *.json.zip)"]
filenames, selected_filter = openFileDialogs(
filenames, selected_filter = openFileDialog(
self,
dir=None,
caption="Import labeled data...",
Expand Down Expand Up @@ -1545,7 +1545,7 @@ def doubleClickInstance(self, instance: Instance):
).boundingRect()

for node in self.skeleton.nodes:
if node.name not in instance.node_names or instance[node].isnan():
if node not in instance.nodes or instance[node].isnan():
# pick random points within currently zoomed view
x = (
in_view_rect.x()
Expand Down Expand Up @@ -2055,6 +2055,17 @@ def _plot_if_next(self, frame_iterator: Iterator) -> bool:
self.plotFrame(next_lf.frame_idx)
return True

def previousLabeledFrameIndex(self):
cur_idx = self.player.frame_idx
frames = self.labels.frames(self.video, from_frame_idx=cur_idx, reverse=True)

try:
next_idx = next(frames).frame_idx
except:
return

return next_idx

def previousLabeledFrame(self):
"""Goes to labeled frame prior to current frame."""
frames = self.labels.frames(
Expand Down

0 comments on commit 8933cf9

Please sign in to comment.