Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
dlyldxwl committed Mar 19, 2019
1 parent 092c8e3 commit 95656f0
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion train.py
Expand Up @@ -143,7 +143,12 @@ def weights_init(m):
priors = priorbox.forward()
if args.cuda:
priors = priors.cuda()


def updateBN(s=0.0001):
for m in net.modules():
if isinstance(m,torch.nn.BatchNorm2d):
m.weight.grad.detach().add_(s*torch.sign(m.weight.detach()))

def train():
net.train()
# loss counters
Expand Down Expand Up @@ -220,6 +225,8 @@ def train():
loss_l, loss_c = criterion(out, priors, targets)
loss = loss_l + loss_c
loss.backward()
# if epoch > args.warm_epoch:
# updateBN()
optimizer.step()
t1 = time.time()
loc_loss += loss_l.item()
Expand Down

0 comments on commit 95656f0

Please sign in to comment.