Skip to content

Commit

Permalink
Merge pull request #203 from murthylab/viz
Browse files Browse the repository at this point in the history
Fixes for pip install.
  • Loading branch information
ntabris committed Oct 2, 2019
2 parents 4cb1a07 + bb66140 commit 10923bd
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 15 deletions.
38 changes: 29 additions & 9 deletions sleap/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,7 +1021,7 @@ def loadVideo(self, video: Video, video_idx: int = None):
def openSkeleton(self):
"""Shows gui for loading saved skeleton into project."""
filters = ["JSON skeleton (*.json)", "HDF5 skeleton (*.h5 *.hdf5)"]
filename, selected_filter = QFileDialog.getOpenFileName(
filename, selected_filter = openFileDialog(
self,
dir=None,
caption="Open skeleton...",
Expand Down Expand Up @@ -1050,7 +1050,7 @@ def saveSkeleton(self):
"""Shows gui for saving skeleton from project."""
default_name = "skeleton.json"
filters = ["JSON skeleton (*.json)", "HDF5 skeleton (*.h5 *.hdf5)"]
filename, selected_filter = QFileDialog.getSaveFileName(
filename, selected_filter = saveFileDialog(
self,
caption="Save As...",
dir=default_name,
Expand Down Expand Up @@ -1280,7 +1280,7 @@ def visualizeOutputs(self):
models_dir = os.path.join(os.path.dirname(self.filename), "models/")

# Show dialog
filename, selected_filter = QFileDialog.getOpenFileName(
filename, selected_filter = openFileDialog(
self,
dir=models_dir,
caption="Import model outputs...",
Expand Down Expand Up @@ -1495,7 +1495,7 @@ def clearFrameNegativeAnchors(self):
def importPredictions(self):
"""Starts gui for importing another dataset into currently one."""
filters = ["HDF5 dataset (*.h5 *.hdf5)", "JSON labels (*.json *.json.zip)"]
filenames, selected_filter = QFileDialog.getOpenFileNames(
filenames, selected_filter = openFileDialogs(
self,
dir=None,
caption="Import labeled data...",
Expand Down Expand Up @@ -1859,12 +1859,12 @@ def openProject(self, first_open: bool = False):
"DeepLabCut csv (*.csv)",
]

filename, selected_filter = QFileDialog.getOpenFileName(
filename, selected_filter = openFileDialog(
self,
dir=None,
caption="Import labeled data...",
filter=";;".join(filters),
options=self._file_dialog_options,
# options=self._file_dialog_options,
)

if len(filename) == 0:
Expand Down Expand Up @@ -1896,7 +1896,7 @@ def saveProjectAs(self):
"JSON labels (*.json)",
"Compressed JSON (*.zip)",
]
filename, selected_filter = QFileDialog.getSaveFileName(
filename, selected_filter = saveFileDialog(
self,
caption="Save As...",
dir=default_name,
Expand Down Expand Up @@ -2000,7 +2000,7 @@ def exportLabeledClip(self):
if not okay:
return

filename, _ = QFileDialog.getSaveFileName(
filename, _ = saveFileDialog(
self,
caption="Save Video As...",
dir=self.filename + ".avi",
Expand All @@ -2023,7 +2023,7 @@ def exportLabeledClip(self):
def exportLabeledFrames(self):
"""Gui for exporting the training dataset of labels/frame images."""
filters = ["HDF5 dataset (*.h5)", "Compressed JSON dataset (*.json *.json.zip)"]
filename, _ = QFileDialog.getSaveFileName(
filename, _ = saveFileDialog(
self,
caption="Save Labeled Frames As...",
dir=self.filename + ".h5",
Expand Down Expand Up @@ -2173,6 +2173,26 @@ def openKeyRef(self):
ShortcutDialog().exec_()


def openFileDialog(*args, **kwargs):
"""Wrapper for openFileDialog.
Passes along everything except empty "options" arg.
"""
if "options" in kwargs and not kwargs["options"]:
del kwargs["options"]
return QFileDialog.getOpenFileName(*args, **kwargs)


def saveFileDialog(*args, **kwargs):
"""Wrapper for saveFileDialog.
Passes along everything except empty "options" arg.
"""
if "options" in kwargs and not kwargs["options"]:
del kwargs["options"]
return QFileDialog.getSaveFileName(*args, **kwargs)


def main(*args, **kwargs):
"""Starts new instance of app."""
app = QApplication([])
Expand Down
Empty file added sleap/gui/overlays/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion sleap/gui/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def load_video(self, video: Video, initial_frame=0, plot=True):
# self.seekbar.setTickInterval(1)
self.seekbar.setValue(self.frame_idx)
self.seekbar.setMinimum(0)
self.seekbar.setMaximum(self.video.frames - 1)
self.seekbar.setMaximum(self.video.last_frame_idx)
self.seekbar.setEnabled(True)

if plot:
Expand Down
Empty file added sleap/info/__init__.py
Empty file.
1 change: 1 addition & 0 deletions sleap/io/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1923,6 +1923,7 @@ def save_frame_data_hdf5(
Args:
output_path: Path to HDF5 file.
format: The image format to use for the data. Defaults to png.
all_labels: Include any labeled frames, not just the frames
we'll use for training (i.e., those with Instances).
Expand Down
46 changes: 41 additions & 5 deletions sleap/io/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,28 +140,42 @@ def frames(self):
def channels(self):
"""See :class:`Video`."""
if "channels" in self.__dataset_h5.attrs:
return self.__dataset_h5.attrs["channels"]
return int(self.__dataset_h5.attrs["channels"])
return self.__dataset_h5.shape[self.__channel_idx]

@property
def width(self):
"""See :class:`Video`."""
if "width" in self.__dataset_h5.attrs:
return self.__dataset_h5.attrs["width"]
return int(self.__dataset_h5.attrs["width"])
return self.__dataset_h5.shape[self.__width_idx]

@property
def height(self):
"""See :class:`Video`."""
if "height" in self.__dataset_h5.attrs:
return self.__dataset_h5.attrs["height"]
return int(self.__dataset_h5.attrs["height"])
return self.__dataset_h5.shape[self.__height_idx]

@property
def dtype(self):
"""See :class:`Video`."""
return self.__dataset_h5.dtype

@property
def last_frame_idx(self) -> int:
"""
The idx number of the last frame.
Overrides method of base :class:`Video` class for videos with
select frames indexed by number from original video, since the last
frame index here will not match the number of frames in video.
"""
if self.__original_to_current_frame_idx:
last_key = sorted(self.__original_to_current_frame_idx.keys())[-1]
return last_key
return self.frames - 1

def get_frame(self, idx) -> np.ndarray:
"""
Get a frame from the underlying HDF5 video data.
Expand Down Expand Up @@ -514,6 +528,19 @@ def dtype(self):
"""See :class:`Video`."""
return self.__img.dtype

@property
def last_frame_idx(self) -> int:
"""
The idx number of the last frame.
Overrides method of base :class:`Video` class for videos with
select frames indexed by number from original video, since the last
frame index here will not match the number of frames in video.
"""
if self.index_by_original:
return self.__store.frame_max
return self.frames - 1

def get_frame(self, frame_number: int) -> np.ndarray:
"""
Get a frame from the underlying ImgStore video data.
Expand Down Expand Up @@ -637,6 +664,15 @@ def num_frames(self) -> int:
"""
return self.frames

@property
def last_frame_idx(self) -> int:
"""
The idx number of the last frame. Usually `numframes - 1`.
"""
if hasattr(self.backend, "last_frame_idx"):
return self.backend.last_frame_idx
return self.frames - 1

@property
def shape(self) -> Tuple[int, int, int, int]:
""" Returns (frame count, height, width, channels)."""
Expand Down Expand Up @@ -882,7 +918,7 @@ def to_imgstore(
format,
mode="w",
basedir=path,
imgshape=(self.shape[1], self.shape[2], self.shape[3]),
imgshape=(self.height, self.width, self.channels),
chunksize=1000,
)

Expand All @@ -899,7 +935,7 @@ def to_imgstore(
# since we can't save an empty imgstore.
if len(frame_numbers) == 0:
store.add_image(
np.zeros((self.shape[1], self.shape[2], self.shape[3])), 0, time.time()
np.zeros((self.height, self.width, self.channels)), 0, time.time()
)

store.close()
Expand Down
30 changes: 30 additions & 0 deletions tests/io/test_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def test_mp4_get_shape(small_robot_mp4_vid):
assert small_robot_mp4_vid.shape == (166, 320, 560, 3)


def test_mp4_fps(small_robot_mp4_vid):
assert small_robot_mp4_vid.fps == 30.0


def test_mp4_len(small_robot_mp4_vid):
assert len(small_robot_mp4_vid) == 166

Expand Down Expand Up @@ -154,6 +158,8 @@ def test_imgstore_indexing(small_robot_mp4_vid, tmpdir):
frames = imgstore_vid.get_frames([0, 1, 2])
assert frames.shape == (3, 320, 560, 3)

assert imgstore_vid.last_frame_idx == len(frame_indices) - 1

with pytest.raises(ValueError):
imgstore_vid.get_frames(frame_indices)

Expand All @@ -164,10 +170,30 @@ def test_imgstore_indexing(small_robot_mp4_vid, tmpdir):
frames = imgstore_vid.get_frames(frame_indices)
assert frames.shape == (3, 320, 560, 3)

assert imgstore_vid.last_frame_idx == max(frame_indices)

with pytest.raises(ValueError):
imgstore_vid.get_frames([0, 1, 2])


def test_imgstore_deferred_loading(small_robot_mp4_vid, tmpdir):
path = os.path.join(tmpdir, "test_imgstore")
frame_indices = [20, 40, 15]
vid = small_robot_mp4_vid.to_imgstore(path, frame_numbers=frame_indices)

# This is actually testing that the __img will be loaded when needed,
# since we use __img to get dtype.
assert vid.dtype == np.dtype("uint8")


def test_imgstore_single_channel(centered_pair_vid, tmpdir):
path = os.path.join(tmpdir, "test_imgstore")
frame_indices = [20, 40, 15]
vid = centered_pair_vid.to_imgstore(path, frame_numbers=frame_indices)

assert vid.channels == 1


def test_empty_hdf5_video(small_robot_mp4_vid, tmpdir):
path = os.path.join(tmpdir, "test_to_hdf5")
hdf5_vid = small_robot_mp4_vid.to_hdf5(path, "testvid", frame_numbers=[])
Expand Down Expand Up @@ -217,6 +243,8 @@ def test_hdf5_indexing(small_robot_mp4_vid, tmpdir):
frames = hdf5_vid.get_frames([0, 1, 2])
assert frames.shape == (3, 320, 560, 3)

assert hdf5_vid.last_frame_idx == len(frame_indices) - 1

with pytest.raises(ValueError):
hdf5_vid.get_frames(frame_indices)

Expand All @@ -232,5 +260,7 @@ def test_hdf5_indexing(small_robot_mp4_vid, tmpdir):
frames = hdf5_vid2.get_frames(frame_indices)
assert frames.shape == (3, 320, 560, 3)

assert hdf5_vid2.last_frame_idx == max(frame_indices)

with pytest.raises(ValueError):
hdf5_vid2.get_frames([0, 1, 2])

0 comments on commit 10923bd

Please sign in to comment.