Commit

Merge pull request #208 from murthylab/bug/tf_upsampling
Bug/tf upsampling
ntabris committed Oct 4, 2019
2 parents e44a79f + b5156a8 commit f201513
Showing 5 changed files with 117 additions and 59 deletions.
18 changes: 18 additions & 0 deletions sleap/config/active.yaml
@@ -145,6 +145,24 @@ learning:

inference:

- name: conf_job
label: Node (confmap) Training Profile
type: list
default: a
options: a,b,c

- name: paf_job
label: Edge (paf) Training Profile
type: list
default: a
options: a,b,c

- name: centroid_job
label: Centroid Training Profile
type: list
default: a
options: a,b,c

- name: _predict_frames
label: Predict On
type: list
47 changes: 32 additions & 15 deletions sleap/gui/active.py
@@ -19,6 +19,9 @@
from PySide2 import QtWidgets, QtCore


SELECT_FILE_OPTION = "Select a training profile file..."


class ActiveLearningDialog(QtWidgets.QDialog):
"""Active learning dialog.
@@ -49,6 +52,10 @@ def __init__(
self.labels_filename = labels_filename
self.labels = labels
self.mode = mode
self._job_filter = None

if self.mode == "inference":
self._job_filter = lambda job: job.is_trained

print(f"Number of frames to train on: {len(labels.user_labeled_frames)}")

@@ -162,6 +169,13 @@ def _rebuild_job_options(self):
# list default profiles
find_saved_jobs(profile_dir, self.job_options)

# Apply any filters
if self._job_filter:
for model_type, jobs_list in self.job_options.items():
self.job_options[model_type] = [
(path, job) for (path, job) in jobs_list if self._job_filter(job)
]

def _update_job_menus(self, init: bool = False):
"""Updates the menus with training profile options.
@@ -176,9 +190,11 @@ def _update_job_menus(self, init: bool = False):
if model_type not in self.job_options:
self.job_options[model_type] = []
if init:
field.currentIndexChanged.connect(
lambda idx, mt=model_type: self._update_from_selected_job(mt, idx)
)

def menu_action(idx, mt=model_type, field=field):
self._update_from_selected_job(mt, idx, field)

field.currentIndexChanged.connect(menu_action)
else:
# block signals so we can update combobox without overwriting
# any user data with the defaults from the profile
@@ -365,6 +381,9 @@ def _get_current_training_jobs(self) -> Dict[ModelOutputType, TrainingJob]:
for model_type in self._get_model_types_to_use():
job, _ = self._get_current_job(model_type)

if job is None:
continue

if job.model.output_type != ModelOutputType.CENTROIDS:
# update training job from params in form
trainer = job.trainer
@@ -499,8 +518,9 @@ def _option_list_from_jobs(self, model_type: ModelOutputType):
"""Returns list of menu options for given model type."""
jobs = self.job_options[model_type]
option_list = [name for (name, job) in jobs]
option_list.append("")
option_list.append("---")
option_list.append("Select a training profile file...")
option_list.append(SELECT_FILE_OPTION)
return option_list

def _add_job_file(self, model_type):
@@ -548,9 +568,10 @@ def _add_job_file_to_list(self, filename: str, model_type: ModelOutputType):
text=f"Profile selected is for training {str(file_model_type)} instead of {str(model_type)}."
).exec_()

def _update_from_selected_job(self, model_type: ModelOutputType, idx: int):
def _update_from_selected_job(self, model_type: ModelOutputType, idx: int, field):
"""Updates dialog settings after user selects a training profile."""
jobs = self.job_options[model_type]
field_text = field.currentText()
if idx == -1:
return
if idx < len(jobs):
@@ -569,17 +590,13 @@ def _update_from_selected_job(self, model_type: ModelOutputType, idx: int):
self.form_widget.set_form_data(training_params)

# is the model already trained?
has_trained = False
final_model_filename = job.final_model_filename
if final_model_filename is not None:
if os.path.exists(os.path.join(job.save_dir, final_model_filename)):
has_trained = True
is_trained = job.is_trained
field_name = f"_use_trained_{str(model_type)}"
# update "use trained" checkbox
self.form_widget.fields[field_name].setEnabled(has_trained)
self.form_widget[field_name] = has_trained
else:
# last item is "select file..."
# update "use trained" checkbox if present
if field_name in self.form_widget.fields:
self.form_widget.fields[field_name].setEnabled(is_trained)
self.form_widget[field_name] = is_trained
elif field_text == SELECT_FILE_OPTION:
self._add_job_file(model_type)


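The menu_action helper above binds mt=model_type and field=field as default arguments so that each combobox callback keeps the values from its own loop iteration rather than the loop's final values. A minimal, self-contained illustration of why that binding matters (generic Python, not SLEAP code):

callbacks_wrong = []
callbacks_right = []

for model_type in ["confmaps", "pafs", "centroids"]:
    # Without a default argument, every callback reads the loop variable's final value.
    callbacks_wrong.append(lambda: model_type)
    # Binding the loop variable as a default freezes its current value per callback.
    callbacks_right.append(lambda mt=model_type: mt)

print([f() for f in callbacks_wrong])  # ['centroids', 'centroids', 'centroids']
print([f() for f in callbacks_right])  # ['confmaps', 'pafs', 'centroids']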
90 changes: 46 additions & 44 deletions sleap/nn/peakfinding_tf.py
@@ -45,6 +45,49 @@ def impeaksnms_tf(I, min_thresh=0.3):

return inds, peak_vals

def upsample_peaks(unrolled_confmaps, peaks, h, w, channel_sample_ind, upsample_factor, win_size):
offset = (win_size - 1) / 2

# Get the boxes coordinates centered on the peaks, normalized to image
# coordinates
box_ind = tf.squeeze(tf.cast(channel_sample_ind, tf.int32))
top_left = (
tf.cast(peaks[:, 1:3], tf.float32)
+ tf.constant([-offset, -offset], dtype="float32")
) / (h - 1.0)
bottom_right = (
tf.cast(peaks[:, 1:3], tf.float32)
+ tf.constant([offset, offset], dtype="float32")
) / (w - 1.0)
boxes = tf.concat([top_left, bottom_right], axis=1)

small_windows = tf.image.crop_and_resize(
unrolled_confmaps, boxes, box_ind, crop_size=[win_size, win_size]
)

# Upsample cropped windows
windows = tf.image.resize_bicubic(
small_windows, [upsample_factor * win_size, upsample_factor * win_size]
)

windows = tf.squeeze(windows)

# Find global maximum of each window
windows_peaks = find_maxima_tf(windows) # [row_ind, col_ind] ==> (nc, 2)

# Adjust back to resolution before upsampling
windows_peaks = tf.cast(windows_peaks, tf.float32) / tf.cast(
upsample_factor, tf.float32
)

# Convert to offsets relative to the original peaks (center of cropped windows)
windows_offsets = windows_peaks - tf.cast(offset, tf.float32) # (nc, 2)
windows_offsets = tf.pad(
windows_offsets, [[0, 0], [1, 1]], mode="CONSTANT", constant_values=0
) # (nc, 4)

# Apply offsets
return tf.cast(peaks, tf.float32) + windows_offsets

def find_peaks_tf(
confmaps,
@@ -68,54 +111,13 @@ def find_peaks_tf(
sample_ind = tf.floordiv(channel_sample_ind, c)

peaks = tf.concat([sample_ind, y, x, channel_ind], axis=1) # (nc, 4)

# If prediction was run at low resolution and the peaks need to be upsampled
# to a higher resolution, compute sub-pixel accurate peaks from these
# approximate peaks and return the upsampled sub-pixel peaks.
if upsample_factor > 1:

offset = (win_size - 1) / 2

# Get the boxes coordinates centered on the peaks, normalized to image
# coordinates
box_ind = tf.squeeze(tf.cast(channel_sample_ind, tf.int32))
top_left = (
tf.cast(peaks[:, 1:3], tf.float32)
+ tf.constant([-offset, -offset], dtype="float32")
) / (h - 1.0)
bottom_right = (
tf.cast(peaks[:, 1:3], tf.float32)
+ tf.constant([offset, offset], dtype="float32")
) / (w - 1.0)
boxes = tf.concat([top_left, bottom_right], axis=1)

small_windows = tf.image.crop_and_resize(
unrolled_confmaps, boxes, box_ind, crop_size=[win_size, win_size]
)

# Upsample cropped windows
windows = tf.image.resize_bicubic(
small_windows, [upsample_factor * win_size, upsample_factor * win_size]
)

windows = tf.squeeze(windows)

# Find global maximum of each window
windows_peaks = find_maxima_tf(windows) # [row_ind, col_ind] ==> (nc, 2)

# Adjust back to resolution before upsampling
windows_peaks = tf.cast(windows_peaks, tf.float32) / tf.cast(
upsample_factor, tf.float32
)

# Convert to offsets relative to the original peaks (center of cropped windows)
windows_offsets = windows_peaks - tf.cast(offset, tf.float32) # (nc, 2)
windows_offsets = tf.pad(
windows_offsets, [[0, 0], [1, 1]], mode="CONSTANT", constant_values=0
) # (nc, 4)

# Apply offsets
peaks = tf.cast(peaks, tf.float32) + windows_offsets
peaks = tf.cond(tf.less(tf.shape(peaks)[0], 1),
    lambda: tf.cast(peaks, tf.float32),
    lambda: upsample_peaks(unrolled_confmaps, peaks, h, w, channel_sample_ind, upsample_factor, win_size))

return peaks, peak_vals

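The refactor above moves the sub-pixel refinement into upsample_peaks and guards it with tf.cond so the refinement is skipped when no peaks are present (the crop-and-resize and argmax would otherwise run on empty tensors). The arithmetic: crop a win_size x win_size window centered on each integer peak, upsample it by upsample_factor, take the argmax of the upsampled window, and convert that index back into a fractional offset from the window center. A rough NumPy sketch of the same idea for a single 2-D confidence map (refine_peak, its interpolation choices, and the example values are illustrative, not SLEAP's API; it assumes SciPy is available and the peak lies away from the image border):

import numpy as np
from scipy.ndimage import map_coordinates

def refine_peak(confmap, peak_rc, win_size=7, upsample_factor=4):
    """Refine one integer (row, col) peak on a 2-D confidence map to sub-pixel accuracy."""
    offset = (win_size - 1) / 2
    r, c = peak_rc

    # Crop a win_size x win_size window centered on the peak
    # (assumes the peak is at least `offset` pixels away from the border).
    r0, c0 = int(r - offset), int(c - offset)
    window = confmap[r0:r0 + win_size, c0:c0 + win_size]

    # Sample the window on a grid upsample_factor times denser, so upsampled
    # index j corresponds to window coordinate j / upsample_factor.
    # (Coordinates past the window edge read as zero padding.)
    fine = np.arange(win_size * upsample_factor) / upsample_factor
    rows, cols = np.meshgrid(fine, fine, indexing="ij")
    upsampled = map_coordinates(window, [rows, cols], order=3)

    # Argmax of the upsampled window, converted back to the original resolution
    # and expressed as an offset from the window center (the integer peak).
    max_r, max_c = np.unravel_index(np.argmax(upsampled), upsampled.shape)
    dr = max_r / upsample_factor - offset
    dc = max_c / upsample_factor - offset
    return r + dr, c + dc

# Example: a Gaussian bump whose true center (10.3, 20.7) lies between pixels.
yy, xx = np.mgrid[0:32, 0:32]
confmap = np.exp(-((yy - 10.3) ** 2 + (xx - 20.7) ** 2) / 4.0)
print(refine_peak(confmap, (10, 21)))  # roughly (10.25, 20.75), closer than the integer peak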
8 changes: 8 additions & 0 deletions sleap/nn/training.py
@@ -625,6 +625,14 @@ class TrainingJob:
newest_model_filename: Union[str, None] = None
final_model_filename: Union[str, None] = None

@property
def is_trained(self):
if self.final_model_filename is not None:
path = os.path.join(self.save_dir, self.final_model_filename)
if os.path.exists(path):
return True
return False

@staticmethod
def save_json(training_job: "TrainingJob", filename: str):
"""
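The new is_trained property backs the inference-mode filter added in sleap/gui/active.py above (self._job_filter = lambda job: job.is_trained). A small runnable sketch of the same check with stand-in objects (the SimpleNamespace records and paths are hypothetical, mimicking only the two fields the property reads):

import os
from types import SimpleNamespace

# Hypothetical stand-ins for TrainingJob records: only save_dir and
# final_model_filename matter for the is_trained check.
jobs = [
    SimpleNamespace(save_dir="models/confmaps", final_model_filename="final_model.h5"),
    SimpleNamespace(save_dir="models/pafs", final_model_filename=None),
]

def is_trained(job) -> bool:
    """True when the job's final model file exists on disk (mirrors TrainingJob.is_trained)."""
    if job.final_model_filename is None:
        return False
    return os.path.exists(os.path.join(job.save_dir, job.final_model_filename))

# Inference mode keeps only trained jobs, as the dialog's _job_filter does.
trained = [job for job in jobs if is_trained(job)]
print(len(trained))  # 0 unless models/confmaps/final_model.h5 exists on disk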
13 changes: 13 additions & 0 deletions tests/gui/test_active.py
@@ -31,6 +31,19 @@ def test_active_gui(qtbot, centered_pair_labels):
assert ModelOutputType.PART_AFFINITY_FIELD not in jobs


def test_inference_gui(qtbot, centered_pair_labels):
win = ActiveLearningDialog(
labels_filename="foo.json", labels=centered_pair_labels, mode="inference"
)
win.show()
qtbot.addWidget(win)

# There aren't any trained models, so there should be no options shown for
# inference
jobs = win._get_current_training_jobs()
assert len(jobs) == 0


def test_make_default_training_jobs():
jobs = make_default_training_jobs()

