Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions utils_cv/tracking/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
self.keypoints = None
self.mask_paths = None

# Init FairMOT opt object
# Init FairMOT opt object with all parameter settings
opt = opts()

# Read annotations
Expand All @@ -64,7 +64,7 @@ def __init__(
# Create FairMOT dataset object
transforms = T.Compose([T.ToTensor()])
self.train_data = JointDataset(
opt.opt,
opt,
self.root,
{name: self.fairmot_imlist_path},
(opt.input_w, opt.input_h),
Expand Down
66 changes: 28 additions & 38 deletions utils_cv/tracking/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __init__(
dataset: Optional[TrackingDataset] = None,
model_path: Optional[str] = None,
arch: str = "dla_34",
head_conv: int = None,
head_conv: int = -1,
) -> None:
"""
Initialize learner object.
Expand All @@ -142,10 +142,9 @@ def __init__(
"""
self.opt = opts()
self.opt.arch = arch
self.opt.head_conv = head_conv if head_conv else -1
self.opt.gpus = _get_gpu_str()
self.opt.set_head_conv(head_conv)
self.opt.set_gpus(_get_gpu_str())
self.opt.device = torch_device()

self.dataset = dataset
self.model = None
self._init_model(model_path)
Expand Down Expand Up @@ -183,40 +182,42 @@ def fit(
"""
if not self.dataset:
raise Exception("No dataset provided")
lr_step = str(lr_step)
if type(lr_step) is not list:
lr_step = [lr_step]
lr_step = [int(x) for x in lr_step]

opt_fit = deepcopy(self.opt) # copy opt to avoid bug
opt_fit.lr = lr
opt_fit.lr_step = lr_step
opt_fit.num_epochs = num_epochs
# update parameters
self.opt.lr = lr
self.opt.lr_step = lr_step
self.opt.num_epochs = num_epochs
opt = deepcopy(self.opt) #to avoid fairMOT over-writing opt

# update dataset options
opt_fit.update_dataset_info_and_set_heads(self.dataset.train_data)
opt.update_dataset_info_and_set_heads(self.dataset.train_data)

# initialize dataloader
train_loader = self.dataset.train_dl

self.model = create_model(
self.opt.arch, self.opt.heads, self.opt.head_conv
opt.arch, opt.heads, opt.head_conv
)
self.model = load_model(self.model, opt_fit.load_model)
self.optimizer = torch.optim.Adam(self.model.parameters(), opt_fit.lr)
self.model = load_model(self.model, opt.load_model)
self.optimizer = torch.optim.Adam(self.model.parameters(), opt.lr)
start_epoch = 0

Trainer = train_factory[opt_fit.task]
trainer = Trainer(opt_fit.opt, self.model, self.optimizer)
trainer.set_device(opt_fit.gpus, opt_fit.chunk_sizes, opt_fit.device)
Trainer = train_factory[opt.task]
trainer = Trainer(opt, self.model, self.optimizer)
trainer.set_device(opt.gpus, opt.chunk_sizes, opt.device)

# initialize loss vars
self.losses_dict = defaultdict(list)

# training loop
for epoch in range(
start_epoch + 1, start_epoch + opt_fit.num_epochs + 1
start_epoch + 1, start_epoch + opt.num_epochs + 1
):
print(
"=" * 5,
f" Epoch: {epoch}/{start_epoch + opt_fit.num_epochs} ",
f" Epoch: {epoch}/{start_epoch + opt.num_epochs} ",
"=" * 5,
)
self.epoch = epoch
Expand All @@ -226,8 +227,8 @@ def fit(
print(f"{k}:{v} min")
else:
print(f"{k}: {v}")
if epoch in opt_fit.lr_step:
lr = opt_fit.lr * (0.1 ** (opt_fit.lr_step.index(epoch) + 1))
if epoch in opt.lr_step:
lr = opt.lr * (0.1 ** (opt.lr_step.index(epoch) + 1))
for param_group in self.optimizer.param_groups:
param_group["lr"] = lr

Expand Down Expand Up @@ -369,8 +370,6 @@ def predict(
self,
im_or_video_path: str,
conf_thres: float = 0.6,
det_thres: float = 0.3,
nms_thres: float = 0.4,
track_buffer: int = 30,
min_box_area: float = 200,
frame_rate: int = 30,
Expand All @@ -382,8 +381,6 @@ def predict(
im_or_video_path: path to image(s) or video. Supports jpg, jpeg, png, tif formats for images.
Supports mp4, avi formats for video.
conf_thres: confidence thresh for tracking
det_thres: confidence thresh for detection
nms_thres: iou thresh for nms
track_buffer: tracking buffer
min_box_area: filter out tiny boxes
frame_rate: frame rate
Expand All @@ -392,20 +389,13 @@ def predict(

Implementation inspired from code found here: https://github.com/ifzhang/FairMOT/blob/master/src/track.py
"""
opt_pred = deepcopy(self.opt) # copy opt to avoid bug
opt_pred.conf_thres = conf_thres
opt_pred.det_thres = det_thres
opt_pred.nms_thres = nms_thres
opt_pred.track_buffer = track_buffer
opt_pred.min_box_area = min_box_area
self.opt.conf_thres = conf_thres
self.opt.track_buffer = track_buffer
self.opt.min_box_area = min_box_area
opt = deepcopy(self.opt) #to avoid fairMOT over-writing opt

# initialize tracker
if self.model:
tracker = JDETracker(
opt_pred.opt, frame_rate=frame_rate, model=self.model
)
else:
tracker = JDETracker(opt_pred.opt, frame_rate=frame_rate)
tracker = JDETracker(opt, frame_rate=frame_rate, model=self.model)

# initialize dataloader
dataloader = self._get_dataloader(im_or_video_path)
Expand All @@ -422,7 +412,7 @@ def predict(
tlbr = t.tlbr
tid = t.track_id
vertical = tlwh[2] / tlwh[3] > 1.6
if tlwh[2] * tlwh[3] > opt_pred.min_box_area and not vertical:
if tlwh[2] * tlwh[3] > opt.min_box_area and not vertical:
bb = TrackingBbox(
tlbr[0], tlbr[1], tlbr[2], tlbr[3], frame_id, tid
)
Expand Down
Loading