Skip to content

Commit

Permalink
Speed-up to track overlay.
Browse files Browse the repository at this point in the history
We were looping over each track for each frame, now we loop over
all instances for each frame for adding them to appropriate track
(which is much faster if there are lots of tracks for the video).
  • Loading branch information
ntabris committed Jan 13, 2020
1 parent c16a966 commit 174487b
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 30 deletions.
61 changes: 33 additions & 28 deletions sleap/gui/overlays/tracks.py
Expand Up @@ -9,10 +9,12 @@
import attr
import itertools

from typing import List
from typing import Iterable, List

from PySide2 import QtCore, QtGui

MAX_NODES_IN_TRAIL = 30


@attr.s(auto_attribs=True)
class TrackTrailOverlay:
Expand All @@ -37,42 +39,48 @@ class TrackTrailOverlay:
trail_length: int = 10
show: bool = False

def get_track_trails(self, frame_selection, track: Track):
def get_track_trails(self, frame_selection: Iterable["LabeledFrame"]):
"""Get data needed to draw track trail.
Args:
frame_selection: an interable with the :class:`LabeledFrame`
frame_selection: an iterable with the :class:`LabeledFrame`
objects to include in trail.
track: the :class:`Track` for which to get trail
Returns:
list of lists of (x, y) tuples
Dictionary keyed by track, value is list of lists of (x, y) tuples
i.e., for every node in instance, we get a list of positions
"""

all_trails = [[] for _ in range(len(self.labels.nodes))]
all_track_trails = dict()

nodes = self.labels.nodes
if len(nodes) > MAX_NODES_IN_TRAIL:
nodes = nodes[:MAX_NODES_IN_TRAIL]

for frame in frame_selection:
frame_idx = frame.frame_idx

inst_on_track = [instance for instance in frame if instance.track == track]
if inst_on_track:
# just use the first instance from this track in this frame
inst = inst_on_track[0]
# loop through all nodes
for node_i, node in enumerate(self.labels.nodes):
for inst in frame:
if inst.track is not None:
if inst.track not in all_track_trails:
all_track_trails[inst.track] = [[] for _ in range(len(nodes))]

# loop through all nodes
for node_i, node in enumerate(nodes):

if node in inst.nodes and inst[node].visible:
point = (inst[node].x, inst[node].y)

if node in inst.nodes and inst[node].visible:
point = (inst[node].x, inst[node].y)
elif len(all_trails[node_i]):
point = all_trails[node_i][-1]
else:
point = None
# Add last location of node so that we can easily
# calculate trail length (since we adjust opacity).
elif len(all_track_trails[inst.track][node_i]):
point = all_track_trails[inst.track][node_i][-1]
else:
point = None

if point is not None:
all_trails[node_i].append(point)
if point is not None:
all_track_trails[inst.track][node_i].append(point)

return all_trails
return all_track_trails

def get_frame_selection(self, video: Video, frame_idx: int):
"""
Expand Down Expand Up @@ -116,17 +124,14 @@ def add_to_scene(self, video: Video, frame_idx: int):
video: current video
frame_idx: index of the frame to which the trail is attached
"""
if not self.show:
if not self.show or self.trail_length == 0:
return

frame_selection = self.get_frame_selection(video, frame_idx)
tracks_in_frame = self.get_tracks_in_frame(
video, frame_idx, include_trails=True
)

for track in tracks_in_frame:
all_track_trails = self.get_track_trails(frame_selection)

trails = self.get_track_trails(frame_selection, track)
for track, trails in all_track_trails.items():

color = QtGui.QColor(*self.player.color_manager.get_track_color(track))
pen = QtGui.QPen()
Expand Down
7 changes: 5 additions & 2 deletions tests/gui/test_tracks.py
Expand Up @@ -15,10 +15,13 @@ def test_track_trails(centered_pair_predictions):
assert tracks[0].name == "1"
assert tracks[1].name == "2"

tracks_with_trails = trail_manager.get_tracks_in_frame(labels.videos[0], 27, include_trails=True)
tracks_with_trails = trail_manager.get_tracks_in_frame(
labels.videos[0], 27, include_trails=True
)
assert len(tracks_with_trails) == 13

trails = trail_manager.get_track_trails(frames, tracks[0])
all_trails = trail_manager.get_track_trails(frames)
trails = all_trails[tracks[0]]

assert len(trails) == 24

Expand Down

0 comments on commit 174487b

Please sign in to comment.