Merged
382 changes: 115 additions & 267 deletions scenarios/tracking/01_training_introduction.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion utils_cv/tracking/dataset.py
@@ -172,7 +172,7 @@ def _write_fairMOT_format(self) -> None:
        self.fairmot_imlist_path = osp.join(
            self.root, "{}.train".format(self.name)
        )
-        with open(self.fairmot_imlist_path, "a") as f:
+        with open(self.fairmot_imlist_path, "w") as f:
            for im_filename in sorted(self.im_filenames):
                f.write(osp.join(self.im_dir, im_filename) + "\n")

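A note on the mode change above: `_write_fairMOT_format` writes the `{name}.train` image list, and opening it with `"a"` would keep appending, presumably accumulating duplicate entries whenever the list is rebuilt in the same location; `"w"` truncates the file and rewrites it from scratch. A minimal sketch of the difference, using hypothetical file and image names rather than the repo's own paths:

```python
import os.path as osp
import tempfile

im_filenames = ["00001.jpg", "00002.jpg"]  # stand-in image names

with tempfile.TemporaryDirectory() as root:
    imlist_path = osp.join(root, "my_dataset.train")

    # Old behavior ("a"): building the list twice appends a second copy of every path.
    for _ in range(2):
        with open(imlist_path, "a") as f:
            for name in sorted(im_filenames):
                f.write(osp.join(root, "images", name) + "\n")
    with open(imlist_path) as f:
        print(len(f.readlines()))  # 4 -> every image listed twice

    # New behavior ("w"): the file is truncated first, so the list is written exactly once.
    with open(imlist_path, "w") as f:
        for name in sorted(im_filenames):
            f.write(osp.join(root, "images", name) + "\n")
    with open(imlist_path) as f:
        print(len(f.readlines()))  # 2 -> one entry per image
```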
20 changes: 10 additions & 10 deletions utils_cv/tracking/model.py
@@ -43,7 +43,7 @@ def _get_gpu_str():

def _get_frame(input_video: str, frame_id: int):
    video = cv2.VideoCapture()
-    video.open(input_video)
+    video.open(input_video)
    video.set(cv2.CAP_PROP_POS_FRAMES, frame_id)
    _, im = video.read()
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
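For context, `_get_frame` above grabs a single frame with OpenCV and converts it from BGR (OpenCV's channel order) to RGB before it is handed to PIL/IPython display elsewhere in the repo. A self-contained sketch of the same pattern with basic error handling added; the function name and the checks are illustrative, not part of the repo code:

```python
import cv2
import numpy as np


def get_frame_rgb(input_video: str, frame_id: int) -> np.ndarray:
    """Return frame `frame_id` of `input_video` as an RGB array (sketch, not repo code)."""
    video = cv2.VideoCapture()
    if not video.open(input_video):
        raise IOError(f"Could not open video: {input_video}")
    video.set(cv2.CAP_PROP_POS_FRAMES, frame_id)  # seek to the requested frame
    ok, im = video.read()
    video.release()
    if not ok:
        raise ValueError(f"Could not read frame {frame_id} from {input_video}")
    return cv2.cvtColor(im, cv2.COLOR_BGR2RGB)  # OpenCV reads BGR; convert before display
```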
@@ -178,7 +178,7 @@ def fit(

        Raise:
            Exception if dataset is undefined

        Implementation inspired from code found here: https://github.com/ifzhang/FairMOT/blob/master/src/train.py
        """
        if not self.dataset:
@@ -227,7 +227,7 @@ def fit(
print(f"{k}: {v}")
if epoch in opt_fit.lr_step:
lr = opt_fit.lr * (0.1 ** (opt_fit.lr_step.index(epoch) + 1))
for param_group in optimizer.param_groups:
for param_group in self.optimizer.param_groups:
param_group["lr"] = lr

# store losses in each epoch
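The one-line fix above matters because the loop previously referenced a bare `optimizer` name that does not appear to be defined in `fit`'s scope; the optimizer the tracker trains with is stored on the object, so the step decay has to be applied to `self.optimizer.param_groups`. A standalone sketch of the same schedule, where the toy model, base learning rate, and step epochs are illustrative values rather than the repo defaults:

```python
import torch

model = torch.nn.Linear(4, 2)        # toy model standing in for the tracking network
base_lr, lr_step = 1e-4, [20, 27]    # illustrative values
optimizer = torch.optim.Adam(model.parameters(), lr=base_lr)

for epoch in range(1, 31):
    if epoch in lr_step:
        # Drop LR to base_lr * 0.1 ** (index + 1): 1e-5 at epoch 20, 1e-6 at epoch 27.
        lr = base_lr * (0.1 ** (lr_step.index(epoch) + 1))
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr

print(optimizer.param_groups[0]["lr"])  # ~1e-06 after both drops
```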
@@ -237,11 +237,11 @@

    def plot_training_losses(self, figsize: Tuple[int, int] = (10, 5)) -> None:
        """
-        Plot training loss.
+        Plot training loss.

        Args:
            figsize (optional): width and height wanted for figure of training-loss plot

        """
        fig = plt.figure(figsize=figsize)
        ax1 = fig.add_subplot(1, 1, 1)
@@ -274,15 +274,15 @@ def evaluate(
        self, results: Dict[int, List[TrackingBbox]], gt_root_path: str
    ) -> str:

-        """
+        """
        Evaluate performance wrt MOTA, MOTP, track quality measures, global ID measures, and more,
        as computed by py-motmetrics on a single experiment. By default, use 'single_vid' as exp_name.

        Args:
-            results: prediction results from predict() function, i.e. Dict[int, List[TrackingBbox]]
+            results: prediction results from predict() function, i.e. Dict[int, List[TrackingBbox]]
            gt_root_path: path of dataset containing GT annotations in MOTchallenge format (xywh)
        Returns:
-            strsummary: str output by method in 'motmetrics' package, containing metrics scores
+            strsummary: str output by method in 'motmetrics' package, containing metrics scores
        """

        # Implementation inspired from code found here: https://github.com/ifzhang/FairMOT/blob/master/src/track.py
@@ -371,7 +371,7 @@ def predict(

        Args:
            im_or_video_path: path to image(s) or video. Supports jpg, jpeg, png, tif formats for images.
-                Supports mp4, avi formats for video.
+                Supports mp4, avi formats for video.
            conf_thres: confidence thresh for tracking
            det_thres: confidence thresh for detection
            nms_thres: iou thresh for nms
43 changes: 21 additions & 22 deletions utils_cv/tracking/plot.py
@@ -17,24 +17,24 @@


def plot_single_frame(
-    results: Dict[int, List[TrackingBbox]], input_video: str, frame_id: int
+    input_video: str,
+    frame_id: int,
+    results: Dict[int, List[TrackingBbox]] = None
) -> None:
-    """
-    Plot the bounding box and id on a wanted frame. Display as image to front end.
+    """
+    Plot the bounding box and id on a wanted frame. Display as image to front end.

    Args:
-        results: dictionary mapping frame id to a list of predicted TrackingBboxes
        input_video: path to the input video
        frame_id: frame_id for frame to show tracking result
+        results: dictionary mapping frame id to a list of predicted TrackingBboxes
    """

-    if results is None:  # if no tracking bboxes, only plot image
-        # Get frame from video
-        im = Image.fromarray(_get_frame(input_video, frame_id))
-        # Display image
-        IPython.display.display(im)
+    # Extract frame
+    im = _get_frame(input_video, frame_id)

-    else:
+    # Overlay results
+    if results:
        results = OrderedDict(sorted(results.items()))

        # Assign bbox color per id
@@ -43,27 +43,26 @@ def plot_single_frame(
        )
        color_map = assign_colors(unique_ids)

-        # Get frame from video
-        im = _get_frame(input_video, frame_id)
-
        # Extract tracking results for wanted frame, and draw bboxes+tracking id, display frame
        cur_tracks = results[frame_id]

        if len(cur_tracks) > 0:
            im = draw_boxes(im, cur_tracks, color_map)
-            im = Image.fromarray(im)
-            IPython.display.display(im)
+
+    # Display image
+    im = Image.fromarray(im)
+    IPython.display.display(im)


def play_video(
    results: Dict[int, List[TrackingBbox]], input_video: str
) -> None:
-    """
+    """
    Plot the predicted tracks on the input video. Displays to front-end as sequence of images stringed together in a video.

    Args:
        results: dictionary mapping frame id to a list of predicted TrackingBboxes
-        input_video: path to the input video
+        input_video: path to the input video
    """

    results = OrderedDict(sorted(results.items()))
@@ -98,7 +97,7 @@ def play_video(
def write_video(
    results: Dict[int, List[TrackingBbox]], input_video: str, output_video: str
) -> None:
-    """
+    """
    Plot the predicted tracks on the input video. Write the output to {output_path}.

    Args:
@@ -143,7 +142,7 @@ def draw_boxes(
    cur_tracks: List[TrackingBbox],
    color_map: Dict[int, Tuple[int, int, int]],
) -> np.ndarray:
-    """
+    """
    Overlay bbox and id labels onto the frame

    Args:
@@ -181,11 +180,11 @@ def draw_boxes(


def assign_colors(id_list: List[int],) -> Dict[int, Tuple[int, int, int]]:
-    """
+    """
    Produce corresponding unique color palettes for unique ids

    Args:
-        id_list: list of track ids
+        id_list: list of track ids
    """
    palette = (2 ** 11 - 1, 2 ** 15 - 1, 2 ** 20 - 1)

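With the reordered `plot_single_frame` signature above, `results` is now optional, so a raw frame can be inspected before any tracking has been run, and the same call overlays boxes once a results dict from the tracker's `predict()` is available. A quick, self-contained check of the no-results path; it first writes a tiny synthetic clip so there is something to read, and the file name is arbitrary:

```python
import cv2
import numpy as np
from utils_cv.tracking.plot import plot_single_frame

# Write a 10-frame synthetic .avi so the example has a video to open.
out = cv2.VideoWriter("tiny.avi", cv2.VideoWriter_fourcc(*"MJPG"), 5, (320, 240))
for i in range(10):
    out.write(np.full((240, 320, 3), 20 * i, dtype=np.uint8))  # flat gray frames
out.release()

# New default: no results passed, so only the raw frame is displayed (no boxes drawn).
plot_single_frame("tiny.avi", frame_id=3)
```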