Skip to content
Permalink
Browse files

Fix bug where model wasn't in training mode every epoch (#17)

  • Loading branch information...
achyudh committed May 2, 2019
1 parent 97a3d2d commit e52eff453d4c8330b1a6ae5778ed7726cdac5c1e
Showing with 1 addition and 2 deletions.
  1. +1 −2 common/trainers/bert_trainer.py
@@ -42,6 +42,7 @@ def __init__(self, model, optimizer, processor, args):

def train_epoch(self, train_dataloader):
for step, batch in enumerate(tqdm(train_dataloader, desc="Training")):
self.model.train()
batch = tuple(t.to(self.args.device) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch
logits = self.model(input_ids, segment_ids, input_mask)
@@ -92,8 +93,6 @@ def train(self):

train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=self.args.batch_size)

self.model.train()

for epoch in trange(int(self.args.epochs), desc="Epoch"):
self.train_epoch(train_dataloader)
dev_evaluator = BertEvaluator(self.model, self.processor, self.args, split='dev')

0 comments on commit e52eff4

Please sign in to comment.
You can’t perform that action at this time.