Skip to content

Commit

Permalink
Fix warm up and progbar in BiaffineDependencyParser https://bbs.hankc…
Browse files Browse the repository at this point in the history
  • Loading branch information
hankcs committed Dec 25, 2020
1 parent 56c44d3 commit 27f5ded
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 6 deletions.
5 changes: 4 additions & 1 deletion hanlp/common/component.py
Expand Up @@ -5,6 +5,7 @@
import logging
import math
import os
import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, List

Expand Down Expand Up @@ -95,7 +96,9 @@ def evaluate(self, input_path: str, save_dir=None, output=False, batch_size=128,
samples = size_of_dataset(tst_data)
num_batches = math.ceil(samples / batch_size)
if warm_up:
self.model.predict_on_batch(list(tst_data.take(1))[0])
for x, y in tst_data:
self.model.predict_on_batch(x)
break
if output:
assert save_dir, 'Must pass save_dir in order to output'
if isinstance(output, bool):
Expand Down
8 changes: 4 additions & 4 deletions hanlp/components/parsers/biaffine_parser.py
Expand Up @@ -256,14 +256,14 @@ def evaluate_dataset(self, tst_data, callbacks, output, num_batches):
params = {'verbose': 1, 'epochs': 1, 'metrics': ['loss'] + self.config.metrics, 'steps': steps_per_epoch}
for c in callbacks:
c.set_params(params)
c.on_train_begin() # otherwise AttributeError: 'ProgbarLogger' object has no attribute 'verbose'
c.on_epoch_begin(0)
c.on_test_begin()
c.on_epoch_end(0)
logs = {}
if output:
output = open(output, 'w', encoding='utf-8')
for idx, ((words, feats), (arcs, rels)) in enumerate(iter(tst_data)):
for c in callbacks:
c.on_batch_begin(idx, logs)
c.on_test_batch_begin(idx, logs)
arc_scores, rel_scores, loss, mask, arc_preds, rel_preds = self.evaluate_batch(words, feats, arcs, rels,
arc_loss, rel_loss, metric)
if output:
Expand All @@ -273,7 +273,7 @@ def evaluate_dataset(self, tst_data, callbacks, output, num_batches):
logs['loss'] = loss
logs.update(metric.to_dict())
for c in callbacks:
c.on_batch_end(idx, logs)
c.on_test_batch_end(idx, logs)
for c in callbacks:
c.on_epoch_end(0)
c.on_test_end()
Expand Down
2 changes: 1 addition & 1 deletion hanlp/version.py
Expand Up @@ -2,4 +2,4 @@
# Author: hankcs
# Date: 2019-12-28 19:26

__version__ = '2.0.0-alpha.68'
__version__ = '2.0.0-alpha.69'

1 comment on commit 27f5ded

@hanlpbot
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This commit has been mentioned on 蝴蝶效应. There might be relevant details there:

https://bbs.hankcs.com/t/topic/3131/2

Please sign in to comment.