/
mot.py
105 lines (87 loc) · 4.55 KB
/
mot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.decode import mot_decode
from models.losses import FocalLoss
from models.losses import RegL1Loss, RegLoss, NormRegL1Loss, RegWeightedL1Loss
from models.utils import _sigmoid, _tranpose_and_gather_feat
from utils.post_process import ctdet_post_process
from .base_trainer import BaseTrainer
class MotLoss(torch.nn.Module):
    """Combined detection and identity-classification loss.

    The detection part is a weighted sum of a heat-map loss, a box-size
    ('wh') loss and an optional offset ('reg') loss.  The identity part is a
    cross-entropy over ``opt.nID`` classes computed from the 'id' head.  The
    two parts are balanced by the learnable scalars ``s_det`` and ``s_id``::

        loss = 0.5 * (exp(-s_det) * det_loss + exp(-s_id) * id_loss
                      + s_det + s_id)

    Fields read from ``opt``: ``mse_loss``, ``reg_loss``, ``dense_wh``,
    ``norm_wh``, ``cat_spec_wh``, ``reid_dim``, ``nID``, ``num_stacks``,
    ``hm_weight``, ``wh_weight``, ``off_weight``, ``reg_offset`` and
    ``id_weight``.
    """

    def __init__(self, opt):
        super(MotLoss, self).__init__()
        # Heat-map criterion: plain MSE or the project's FocalLoss.
        self.crit = torch.nn.MSELoss() if opt.mse_loss else FocalLoss()
        # Sparse regression criterion; None when opt.reg_loss is neither
        # 'l1' nor 'sl1'.
        self.crit_reg = RegL1Loss() if opt.reg_loss == 'l1' else \
            RegLoss() if opt.reg_loss == 'sl1' else None
        # NOTE: forward() only uses crit_wh on the dense_wh path; the sparse
        # path goes through crit_reg.
        self.crit_wh = torch.nn.L1Loss(reduction='sum') if opt.dense_wh else \
            NormRegL1Loss() if opt.norm_wh else \
            RegWeightedL1Loss() if opt.cat_spec_wh else self.crit_reg
        self.opt = opt
        self.emb_dim = opt.reid_dim
        self.nID = opt.nID
        # Linear classifier mapping an embedding to nID identity logits.
        self.classifier = nn.Linear(self.emb_dim, self.nID)
        # Targets equal to -1 are excluded from the identity loss.
        self.IDLoss = nn.CrossEntropyLoss(ignore_index=-1)
        # Fixed scale applied to the normalised embeddings before the
        # classifier.  math.log raises ValueError when nID <= 1.
        self.emb_scale = math.sqrt(2) * math.log(self.nID - 1)
        # Learnable balancing scalars for the detection / identity terms.
        self.s_det = nn.Parameter(-1.85 * torch.ones(1))
        self.s_id = nn.Parameter(-1.05 * torch.ones(1))

    def forward(self, outputs, batch):
        """Compute the total loss and the individual loss terms.

        Args:
            outputs: indexable collection with one dict per stack; each dict
                provides 'hm' and, depending on ``opt``, 'wh', 'reg', 'id'.
                When ``opt.mse_loss`` is false, ``output['hm']`` is replaced
                in place by ``_sigmoid(output['hm'])``.
            batch: dict with the targets ('hm', and depending on ``opt``:
                'dense_wh', 'dense_wh_mask', 'wh', 'reg', 'reg_mask', 'ind',
                'ids').

        Returns:
            ``(loss, loss_stats)`` where ``loss_stats`` maps 'loss',
            'hm_loss', 'wh_loss', 'off_loss' and 'id_loss' to their values.
            A term that was never computed stays the Python int ``0``.
        """
        opt = self.opt
        hm_loss, wh_loss, off_loss, id_loss = 0, 0, 0, 0
        for s in range(opt.num_stacks):
            output = outputs[s]
            if not opt.mse_loss:
                output['hm'] = _sigmoid(output['hm'])

            hm_loss += self.crit(output['hm'], batch['hm']) / opt.num_stacks

            if opt.wh_weight > 0:
                if opt.dense_wh:
                    # Normalise the summed L1 by the number of active mask
                    # entries; the epsilon avoids division by zero.
                    mask_weight = batch['dense_wh_mask'].sum() + 1e-4
                    wh_loss += (
                        self.crit_wh(output['wh'] * batch['dense_wh_mask'],
                                     batch['dense_wh'] * batch['dense_wh_mask']) /
                        mask_weight) / opt.num_stacks
                else:
                    wh_loss += self.crit_reg(
                        output['wh'], batch['reg_mask'],
                        batch['ind'], batch['wh']) / opt.num_stacks

            if opt.reg_offset and opt.off_weight > 0:
                off_loss += self.crit_reg(output['reg'], batch['reg_mask'],
                                          batch['ind'], batch['reg']) / opt.num_stacks

            # opt.id_weight only switches the identity branch on or off; the
            # actual weighting is done by s_id below.
            if opt.id_weight > 0:
                valid = batch['reg_mask'] > 0
                id_head = _tranpose_and_gather_feat(output['id'], batch['ind'])
                id_head = id_head[valid].contiguous()
                id_head = self.emb_scale * F.normalize(id_head)
                id_target = batch['ids'][valid]
                id_output = self.classifier(id_head).contiguous()
                # NOTE(review): unlike the other terms this one is not
                # divided by opt.num_stacks; kept as-is to preserve the loss
                # scale of existing training runs.
                if (id_target != self.IDLoss.ignore_index).any():
                    id_loss += self.IDLoss(id_output, id_target)
                else:
                    # No usable identity target (no selected entries, or all
                    # ignored): mean-reduced cross-entropy would average over
                    # zero elements and yield NaN.  Add a zero that is still
                    # attached to the autograd graph instead.
                    id_loss += id_output.sum() * 0.0

        det_loss = opt.hm_weight * hm_loss + opt.wh_weight * wh_loss + opt.off_weight * off_loss
        loss = torch.exp(-self.s_det) * det_loss + torch.exp(-self.s_id) * id_loss + (self.s_det + self.s_id)
        loss *= 0.5

        loss_stats = {'loss': loss, 'hm_loss': hm_loss,
                      'wh_loss': wh_loss, 'off_loss': off_loss, 'id_loss': id_loss}
        return loss, loss_stats
class MotTrainer(BaseTrainer):
    """Trainer specialisation that plugs MotLoss into BaseTrainer."""

    def __init__(self, opt, model, optimizer=None):
        """Forward all arguments unchanged to BaseTrainer."""
        super(MotTrainer, self).__init__(opt, model, optimizer=optimizer)

    def _get_losses(self, opt):
        """Return ``(loss_names, loss_module)`` for this trainer.

        ``loss_names`` lists the keys of the stats dict produced by MotLoss.
        """
        criterion = MotLoss(opt)
        stat_names = ['loss', 'hm_loss', 'wh_loss', 'off_loss', 'id_loss']
        return stat_names, criterion

    def save_result(self, output, batch, results):
        """Decode detections from ``output`` and store them in ``results``.

        The entry is keyed by the first 'img_id' found in ``batch['meta']``
        and holds the first element returned by ``ctdet_post_process``.
        """
        # The offset head is only passed on when the option enables it.
        offset = None
        if self.opt.reg_offset:
            offset = output['reg']

        decoded = mot_decode(
            output['hm'],
            output['wh'],
            reg=offset,
            cat_spec_wh=self.opt.cat_spec_wh,
            K=self.opt.K,
        )
        last_dim = decoded.shape[2]
        dets = decoded.detach().cpu().numpy().reshape(1, -1, last_dim)

        meta = batch['meta']
        meta_c = meta['c'].cpu().numpy()
        meta_s = meta['s'].cpu().numpy()
        hm_shape = output['hm'].shape
        processed = ctdet_post_process(
            dets.copy(), meta_c, meta_s,
            hm_shape[2], hm_shape[3], hm_shape[1])

        image_key = meta['img_id'].cpu().numpy()[0]
        results[image_key] = processed[0]