-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
track.py
200 lines (170 loc) · 7.67 KB
/
track.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
import argparse
from functools import partial
from pathlib import Path
import torch
from boxmot import TRACKERS
from boxmot.tracker_zoo import create_tracker
from boxmot.utils import ROOT, WEIGHTS
from boxmot.utils.checks import TestRequirements
from examples.detectors import get_yolo_inferer
# Make sure the mikel-brostrom fork of ultralytics is present before the
# `ultralytics` imports that follow. This must stay between the boxmot imports
# and the ultralytics imports. The trailing '# install' note suggests
# check_packages installs the package when it is missing.
# NOTE(review): confirm that install behavior in boxmot.utils.checks.
__tr = TestRequirements()
__tr.check_packages(('ultralytics @ git+https://github.com/mikel-brostrom/ultralytics.git', )) # install
from ultralytics import YOLO
from ultralytics.data.utils import VID_FORMATS
from ultralytics.utils.plotting import save_one_box
from examples.utils import write_mot_results
def on_predict_start(predictor, persist=False):
    """
    Build one boxmot tracker per batch slot when prediction starts.

    The tracker type comes from ``predictor.custom_args.tracking_method``. Its
    YAML config is read from ``ROOT/boxmot/configs/<method>.yaml``. The
    resulting list is stored on ``predictor.trackers``.

    Args:
        predictor (object): The predictor that receives the trackers. It must
            expose ``custom_args`` (with ``tracking_method``, ``reid_model``,
            ``half``, ``per_class``), ``dataset.bs`` and ``device``.
        persist (bool, optional): Accepted for callback-signature
            compatibility. This implementation does not consult it; trackers
            are always rebuilt. Defaults to False.
    """
    opts = predictor.custom_args
    method = opts.tracking_method
    assert method in TRACKERS, \
        f"'{method}' is not supported. Supported ones are {TRACKERS}"

    config_path = ROOT / 'boxmot' / 'configs' / (method + '.yaml')

    def _new_tracker():
        tracker = create_tracker(
            method,
            config_path,
            opts.reid_model,
            predictor.device,
            opts.half,
            opts.per_class
        )
        # Motion-only trackers have no `model` attribute, so there is nothing to warm up.
        if hasattr(tracker, 'model'):
            tracker.model.warmup()
        return tracker

    # Trackers are created (and warmed up) one at a time, in slot order.
    predictor.trackers = [_new_tracker() for _ in range(predictor.dataset.bs)]
def _mot_txt_path(save_dir, source, is_webcam, video_formats):
    """Return the path of the single MOT-format results file for ``source``.

    Args:
        save_dir (str | Path): Run output directory (``predictor.save_dir``).
        source (str): The ``--source`` value passed to the script.
        is_webcam (bool): Whether the predictor classified the source as a
            webcam/stream (``predictor.source_type.webcam``).
        video_formats (Iterable[str]): Lowercase video file extensions without
            the leading dot, e.g. ultralytics' ``VID_FORMATS`` (tuple or set).

    Returns:
        Path: ``save_dir / 'mot' / '<name>.txt'``.
    """
    src = Path(source)
    if is_webcam or src.suffix[1:].lower() in video_formats:
        # Streams and video files are named after the source itself. Only the
        # final path component is used: joining an absolute path with `/`
        # would otherwise discard save_dir and write next to the video.
        name = src.name
    else:
        # Image-sequence sources (e.g. MOT16/17/20 '<sequence>/img1') are
        # named after their parent folder, i.e. the sequence name.
        name = src.parent.name
    return Path(save_dir) / 'mot' / (name + '.txt')


@torch.no_grad()
def run(args):
    """Detect and track objects in ``args.source`` with YOLO and a boxmot tracker.

    Results are consumed frame by frame (``stream=True``). Depending on the
    CLI flags, MOT-format results (``--save-mot``) and per-track image crops
    (``--save-id-crops``) are written under the predictor's save directory.

    Args:
        args (argparse.Namespace): Options produced by ``parse_opt()``.
    """
    # Non-yolov8 detectors are loaded through a yolov8n placeholder. Its
    # predictor model is replaced further below.
    yolo = YOLO(
        args.yolo_model if 'yolov8' in str(args.yolo_model) else 'yolov8n.pt',
    )

    results = yolo.track(
        source=args.source,
        conf=args.conf,
        iou=args.iou,
        agnostic_nms=args.agnostic_nms,
        show=args.show,
        stream=True,
        device=args.device,
        show_conf=args.show_conf,
        save_txt=args.save_txt,
        show_labels=args.show_labels,
        save=args.save,
        verbose=args.verbose,
        exist_ok=args.exist_ok,
        project=args.project,
        name=args.name,
        classes=args.classes,
        imgsz=args.imgsz,
        vid_stride=args.vid_stride,
        line_width=args.line_width
    )

    # boxmot trackers are created by this callback when prediction starts.
    yolo.add_callback('on_predict_start', partial(on_predict_start, persist=True))

    if 'yolov8' not in str(args.yolo_model):
        # Replace the placeholder yolov8 model with the requested detector.
        inferer_cls = get_yolo_inferer(args.yolo_model)
        yolo.predictor.model = inferer_cls(
            model=args.yolo_model,
            device=yolo.predictor.device,
            args=yolo.predictor.args
        )

    # Store custom args in the predictor; on_predict_start reads them.
    yolo.predictor.custom_args = args

    mot_txt_path = None
    for frame_idx, r in enumerate(results):
        # Only results with 7 box columns are treated as tracked; those boxes
        # carry the track ids read via `d.id` below.
        if r.boxes.data.shape[1] != 7:
            continue

        if mot_txt_path is None:
            # save_dir and source_type are not expected to change between
            # frames, so the output path is resolved only once.
            mot_txt_path = _mot_txt_path(
                yolo.predictor.save_dir,
                args.source,
                yolo.predictor.source_type.webcam,
                VID_FORMATS,
            )
            yolo.predictor.mot_txt_path = mot_txt_path

        if args.save_mot:
            write_mot_results(
                mot_txt_path,
                r,
                frame_idx,
            )

        if args.save_id_crops:
            for d in r.boxes:
                save_one_box(
                    d.xyxy,
                    r.orig_img.copy(),
                    file=(
                        yolo.predictor.save_dir / 'crops' /
                        str(int(d.cls.item())) /
                        str(int(d.id.item())) / f'{frame_idx}.jpg'
                    ),
                    BGR=True
                )

    if args.save_mot:
        # If no frame was ever tracked, the path was never resolved. Report
        # that instead of failing with an AttributeError.
        if mot_txt_path is None:
            print('No tracks were produced; no MOT results saved')
        else:
            print(f'MOT results saved to {mot_txt_path}')
def parse_opt():
    """Build this script's command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace: Parsed options, consumed by ``run()`` and, through
        ``predictor.custom_args``, by ``on_predict_start()``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--yolo-model', type=Path, default=WEIGHTS / 'yolov8n',
                        help='yolo model path')
    parser.add_argument('--reid-model', type=Path, default=WEIGHTS / 'osnet_x0_25_msmt17.pt',
                        help='reid model path')
    parser.add_argument('--tracking-method', type=str, default='deepocsort',
                        help='deepocsort, botsort, strongsort, ocsort, bytetrack')
    parser.add_argument('--source', type=str, default='0',
                        help='file/dir/URL/glob, 0 for webcam')
    parser.add_argument('--imgsz', '--img', '--img-size', nargs='+', type=int, default=[640],
                        help='inference size h,w')
    parser.add_argument('--conf', type=float, default=0.5,
                        help='confidence threshold')
    parser.add_argument('--iou', type=float, default=0.7,
                        help='intersection over union (IoU) threshold for NMS')
    parser.add_argument('--device', default='',
                        help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
    parser.add_argument('--show', action='store_true',
                        help='display tracking video results')
    parser.add_argument('--save', action='store_true',
                        help='save video tracking results')
    # Class ids for COCO-trained weights (e.g. the default yolov8n):
    # 0 is person, 1 is bicycle, 2 is car ... 79 is toothbrush
    parser.add_argument('--classes', nargs='+', type=int,
                        help='filter by class: --classes 0, or --classes 0 2 3')
    parser.add_argument('--project', default=ROOT / 'runs' / 'track',
                        help='save results to project/name')
    parser.add_argument('--name', default='exp',
                        help='save results to project/name')
    parser.add_argument('--exist-ok', action='store_true',
                        help='existing project/name ok, do not increment')
    parser.add_argument('--half', action='store_true',
                        help='use FP16 half-precision inference')
    parser.add_argument('--vid-stride', type=int, default=1,
                        help='video frame-rate stride')
    # These two flags are store_false: passing them HIDES what they name.
    # The dest comes from the first long option, so it stays `show_labels` /
    # `show_conf`; the --hide-* aliases are the self-describing spelling.
    parser.add_argument('--show-labels', '--hide-labels', action='store_false',
                        help='hide labels and draw only bboxes (labels are shown unless this flag is passed)')
    parser.add_argument('--show-conf', '--hide-conf', action='store_false',
                        help='hide confidences (shown unless this flag is passed)')
    parser.add_argument('--save-txt', action='store_true',
                        help='save tracking results in a txt file')
    parser.add_argument('--save-id-crops', action='store_true',
                        help='save each crop to its respective id folder')
    parser.add_argument('--save-mot', action='store_true',
                        help='save tracking results in a single txt file')
    parser.add_argument('--line-width', default=None, type=int,
                        help='The line width of the bounding boxes. If None, it is scaled to the image size.')
    parser.add_argument('--per-class', default=False, action='store_true',
                        help='not mix up classes when tracking')
    # `verbose` defaults to True, so `--verbose` alone could never turn it
    # off; `--no-verbose` writes False to the same dest.
    parser.add_argument('--verbose', default=True, action='store_true',
                        help='print results per frame')
    parser.add_argument('--no-verbose', dest='verbose', default=True, action='store_false',
                        help='do not print results per frame')
    parser.add_argument('--agnostic-nms', default=False, action='store_true',
                        help='class-agnostic NMS')
    opt = parser.parse_args()
    return opt
if __name__ == "__main__":
    # Script entry point: parse the CLI options and hand them straight to run().
    run(parse_opt())