Skip to content

Commit

Permalink
Merge a31b4e3 into e77f0f9
Browse files Browse the repository at this point in the history
  • Loading branch information
ntabris committed Mar 27, 2020
2 parents e77f0f9 + a31b4e3 commit 1ebb6e8
Show file tree
Hide file tree
Showing 20 changed files with 425 additions and 95 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def get_requirements(require_name=None):
include_package_data=True,
entry_points={
"console_scripts": [
"sleap-convert=sleap.io.convert:main",
"sleap-label=sleap.gui.app:main",
"sleap-train=sleap.nn.training:main",
"sleap-track=sleap.nn.inference:main",
Expand Down
8 changes: 8 additions & 0 deletions sleap/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,14 @@ def add_submenu_choices(menu, title, options, key):
add_menu_item(fileMenu, "save", "Save", self.commands.saveProject)
add_menu_item(fileMenu, "save as", "Save As...", self.commands.saveProjectAs)

fileMenu.addSeparator()
add_menu_item(
fileMenu,
"export analysis",
"Export Analysis HDF5...",
self.commands.exportAnalysisFile,
)

fileMenu.addSeparator()
add_menu_item(fileMenu, "close", "Quit", self.close)

Expand Down
58 changes: 47 additions & 11 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,10 @@ def saveProjectAs(self):
"""Show gui to save project as a new file."""
self.execute(SaveProjectAs)

def exportAnalysisFile(self):
"""Shows gui for exporting analysis h5 file."""
self.execute(ExportAnalysisFile)

def exportLabeledClip(self):
"""Shows gui for exporting clip with visual annotations."""
self.execute(ExportLabeledClip)
Expand Down Expand Up @@ -460,9 +464,8 @@ def do_action(context: "CommandContext", params: dict):
@staticmethod
def ask(context: "CommandContext", params: dict) -> bool:
filters = [
"HDF5 dataset (*.h5 *.hdf5)",
"SLEAP HDF5 dataset (*.slp *.h5 *.hdf5)",
"JSON labels (*.json *.json.zip)",
"DeepLabCut csv (*.csv)",
]

filename, selected_filter = FileDialog.open(
Expand Down Expand Up @@ -605,9 +608,7 @@ class ImportDeepLabCut(AppCommand):
@staticmethod
def do_action(context: "CommandContext", params: dict):

labels = Labels.load_deeplabcut_csv(
filename=params["filename"]
)
labels = Labels.load_deeplabcut_csv(filename=params["filename"])

new_window = context.app.__class__()
new_window.showMaximized()
Expand Down Expand Up @@ -666,8 +667,8 @@ def ask(context: CommandContext, params: dict) -> bool:
default_name = str(p.with_name(f"{p.stem} copy{p.suffix}"))

filters = [
"HDF5 dataset (*.h5)",
"JSON labels (*.json)",
"SLEAP HDF5 dataset (*.slp)",
"SLEAP JSON dataset (*.json)",
"Compressed JSON (*.zip)",
]
filename, selected_filter = FileDialog.save(
Expand All @@ -684,6 +685,35 @@ def ask(context: CommandContext, params: dict) -> bool:
return True


class ExportAnalysisFile(AppCommand):
@classmethod
def do_action(cls, context: CommandContext, params: dict):
from sleap.info.write_tracking_h5 import main as write_analysis

write_analysis(
context.labels, output_path=params["output_path"], all_frames=True
)

@staticmethod
def ask(context: CommandContext, params: dict) -> bool:
default_name = context.state["filename"] or "untitled"
p = PurePath(default_name)
default_name = str(p.with_name(f"{p.stem}.analysis.h5"))

filename, selected_filter = FileDialog.save(
context.app,
caption="Export Analysis File...",
dir=default_name,
filter="SLEAP Analysis HDF5 (*.h5)",
)

if len(filename) == 0:
return False

params["output_path"] = filename
return True


class SaveProject(SaveProjectAs):
@classmethod
def ask(cls, context: CommandContext, params: dict) -> bool:
Expand Down Expand Up @@ -753,17 +783,20 @@ def do_action(context: CommandContext, params: dict):
Labels.save_file(
context.state["labels"],
params["filename"],
default_suffix="h5",
default_suffix="slp",
save_frame_data=True,
)

@staticmethod
def ask(context: CommandContext, params: dict) -> bool:
filters = ["HDF5 dataset (*.h5)", "Compressed JSON dataset (*.json *.json.zip)"]
filters = [
"SLEAP HDF5 dataset (*.slp *.h5)",
"Compressed JSON dataset (*.json *.json.zip)",
]
filename, _ = FileDialog.save(
context.app,
caption="Save Labeled Frames As...",
dir=context.state["filename"] + ".h5",
dir=context.state["filename"] + ".slp",
filters=";;".join(filters),
)
if len(filename) == 0:
Expand Down Expand Up @@ -1586,7 +1619,10 @@ class MergeProject(EditCommand):

@classmethod
def ask_and_do(cls, context: CommandContext, params: dict):
filters = ["HDF5 dataset (*.h5 *.hdf5)", "JSON labels (*.json *.json.zip)"]
filters = [
"SLEAP HDF5 dataset (*.slp *.h5 *.hdf5)",
"SLEAP JSON dataset (*.json *.json.zip)",
]

filenames, selected_filter = FileDialog.openMultiple(
context.app,
Expand Down
2 changes: 1 addition & 1 deletion sleap/gui/learning/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def predict_subprocess(

# Make path where we'll save predictions
output_path = ".".join(
(video.filename, datetime.now().strftime("%y%m%d_%H%M%S"), "predictions.h5",)
(video.filename, datetime.now().strftime("%y%m%d_%H%M%S"), "predictions.slp",)
)

for job_path in trained_job_paths:
Expand Down
65 changes: 25 additions & 40 deletions sleap/gui/overlays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,36 +6,35 @@
import numpy as np
from typing import Sequence, Union

import sleap
from sleap.io.video import Video
from sleap.gui.widgets.video import QtVideoPlayer
from sleap.nn.data.providers import VideoReader
from sleap.nn.inference import VisualPredictor


@attr.s(auto_attribs=True)
class ModelData:
inference_object: Union[
"sleap.nn.peak_finding.ConfmapPeakFinder", "sleap.nn.paf_grouping.PAFGrouper"
]
predictor: VisualPredictor
video: Video
do_rescale: bool = False
output_scale: float = 1.0
adjust_vals: bool = True

def __getitem__(self, i):
def __getitem__(self, i: int):
"""Data data for frame i from predictor."""
frame_img = self.video[i]

frame_result = self.inference_object.inference(
self.inference_object.preproc(frame_img)
).numpy()
# Get predictions for frame i
frame_result = self.predictor.predict(VideoReader(self.video, [i]))

# We just want the single image results
if type(i) != slice:
frame_result = frame_result[0]
# todo: support for pafs
frame_result = frame_result[0][self.predictor.confidence_maps_key_name()]

if self.adjust_vals:
frame_result = np.clip(frame_result, 0, 1)

# Determine output scale by comparing original image with model output
self.output_scale = self.video.height / frame_result.shape[0]

return frame_result


Expand All @@ -44,17 +43,20 @@ class DataOverlay:

data: Sequence = None
player: QtVideoPlayer = None
overlay_class: QtWidgets.QGraphicsObject = None
overlay_class: Union["ConfMapsPlot", "MultiQuiverPlot", None] = None

def add_to_scene(self, video, frame_idx):
if self.data is None:
return

if self.overlay_class is None:
return

img_data = self.data[frame_idx]
# print(img_data.shape, np.ptp(img_data))

self._add(
self.player.view.scene,
self.overlay_class(img_data, scale=1.0 / self.data.output_scale),
to=self.player.view.scene,
what=self.overlay_class(img_data, scale=self.data.output_scale),
)

def _add(
Expand All @@ -68,38 +70,21 @@ def _add(

@classmethod
def from_model(cls, filename, video, **kwargs):
from sleap.nn.model import ModelOutputType, InferenceModel
from sleap.nn import job

# Load the trained model
trained_job = job.TrainingJob.load_json(filename)
inference_model = InferenceModel.from_training_job(trained_job)
model_output_type = trained_job.model.output_type

if trained_job.model.output_type == ModelOutputType.PART_AFFINITY_FIELD:
from sleap.nn import paf_grouping

inference_object = paf_grouping.PAFGrouper(inference_model=inference_model)
else:
from sleap.nn import peak_finding

inference_object = peak_finding.ConfmapPeakFinder(
inference_model=inference_model
)

# Construct the ModelData object that runs inference
data_object = ModelData(
inference_object, video, output_scale=inference_model.output_scale
predictor=VisualPredictor.from_trained_models(filename), video=video
)

# Determine whether to use confmap or paf overlay
# todo: make this selectable by user for bottom up model w/ both outputs
from sleap.gui.overlays.confmaps import ConfMapsPlot
from sleap.gui.overlays.pafs import MultiQuiverPlot

if model_output_type == ModelOutputType.PART_AFFINITY_FIELD:
overlay_class = MultiQuiverPlot
else:
overlay_class = ConfMapsPlot
# todo: support for pafs
# if model_output_type == ModelOutputType.PART_AFFINITY_FIELD:
# overlay_class = MultiQuiverPlot
# else:
overlay_class = ConfMapsPlot

return cls(data=data_object, overlay_class=overlay_class, **kwargs)

Expand Down
Loading

0 comments on commit 1ebb6e8

Please sign in to comment.