Fixing averaging of loss printed during training
swansonk14 committed Sep 21, 2020
1 parent 7452ee6 commit e9a28de
Showing 1 changed file with 3 additions and 3 deletions.
chemprop/train/train.py
@@ -40,7 +40,7 @@ def train(model: MoleculeModel,
     debug = logger.debug if logger is not None else print

     model.train()
-    loss_sum, iter_count = 0, 0
+    loss_sum = iter_count = 0

     for batch in tqdm(data_loader, total=len(data_loader), leave=False):
         # Prepare batch
@@ -67,7 +67,7 @@ def train(model: MoleculeModel,
         loss = loss.sum() / mask.sum()

         loss_sum += loss.item()
-        iter_count += len(batch)
+        iter_count += 1

         loss.backward()
         if args.grad_clip:
@@ -85,7 +85,7 @@ def train(model: MoleculeModel,
             pnorm = compute_pnorm(model)
             gnorm = compute_gnorm(model)
             loss_avg = loss_sum / iter_count
-            loss_sum, iter_count = 0, 0
+            loss_sum = iter_count = 0

             lrs_str = ', '.join(f'lr_{i} = {lr:.4e}' for i, lr in enumerate(lrs))
             debug(f'Loss = {loss_avg:.4e}, PNorm = {pnorm:.4f}, GNorm = {gnorm:.4f}, {lrs_str}')
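In context, `loss = loss.sum() / mask.sum()` already averages over the datapoints in a batch, so dividing the running `loss_sum` by the number of molecules seen (the old `iter_count += len(batch)`) under-reported the printed loss by roughly a factor of the batch size; counting batches instead makes `loss_avg` the mean of the per-batch losses. A minimal Python sketch of the arithmetic (not chemprop code; the numbers are illustrative):

    # Per-batch mean losses, as produced by loss.sum() / mask.sum()
    batch_losses = [0.9, 0.7, 0.5]
    batch_size = 50  # molecules per batch (illustrative)

    # Old behaviour: divide a sum of per-batch means by the number of molecules,
    # shrinking the reported value by roughly a factor of batch_size.
    old_avg = sum(batch_losses) / (batch_size * len(batch_losses))  # 0.014

    # New behaviour: divide by the number of batches, the correct denominator
    # for a sum of per-batch means.
    new_avg = sum(batch_losses) / len(batch_losses)  # 0.7

    print(f'old = {old_avg:.4f}, new = {new_avg:.4f}')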
