From 27f5ded76f9afed7da34ab7cba4fef0a66d523cc Mon Sep 17 00:00:00 2001 From: hankcs Date: Fri, 25 Dec 2020 13:13:07 -0500 Subject: [PATCH] Fix warm up and progbar in BiaffineDependencyParser https://bbs.hankcs.com/t/topic/3131 --- hanlp/common/component.py | 5 ++++- hanlp/components/parsers/biaffine_parser.py | 8 ++++---- hanlp/version.py | 2 +- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/hanlp/common/component.py b/hanlp/common/component.py index 44fd8408c..b81ab44a4 100644 --- a/hanlp/common/component.py +++ b/hanlp/common/component.py @@ -5,6 +5,7 @@ import logging import math import os +import warnings from abc import ABC, abstractmethod from typing import Any, Dict, Optional, List @@ -95,7 +96,9 @@ def evaluate(self, input_path: str, save_dir=None, output=False, batch_size=128, samples = size_of_dataset(tst_data) num_batches = math.ceil(samples / batch_size) if warm_up: - self.model.predict_on_batch(list(tst_data.take(1))[0]) + for x, y in tst_data: + self.model.predict_on_batch(x) + break if output: assert save_dir, 'Must pass save_dir in order to output' if isinstance(output, bool): diff --git a/hanlp/components/parsers/biaffine_parser.py b/hanlp/components/parsers/biaffine_parser.py index 837ac2db7..9f34612f5 100644 --- a/hanlp/components/parsers/biaffine_parser.py +++ b/hanlp/components/parsers/biaffine_parser.py @@ -256,14 +256,14 @@ def evaluate_dataset(self, tst_data, callbacks, output, num_batches): params = {'verbose': 1, 'epochs': 1, 'metrics': ['loss'] + self.config.metrics, 'steps': steps_per_epoch} for c in callbacks: c.set_params(params) - c.on_train_begin() # otherwise AttributeError: 'ProgbarLogger' object has no attribute 'verbose' - c.on_epoch_begin(0) + c.on_test_begin() + c.on_epoch_end(0) logs = {} if output: output = open(output, 'w', encoding='utf-8') for idx, ((words, feats), (arcs, rels)) in enumerate(iter(tst_data)): for c in callbacks: - c.on_batch_begin(idx, logs) + c.on_test_batch_begin(idx, logs) arc_scores, rel_scores, loss, mask, arc_preds, rel_preds = self.evaluate_batch(words, feats, arcs, rels, arc_loss, rel_loss, metric) if output: @@ -273,7 +273,7 @@ def evaluate_dataset(self, tst_data, callbacks, output, num_batches): logs['loss'] = loss logs.update(metric.to_dict()) for c in callbacks: - c.on_batch_end(idx, logs) + c.on_test_batch_end(idx, logs) for c in callbacks: c.on_epoch_end(0) c.on_test_end() diff --git a/hanlp/version.py b/hanlp/version.py index 8d6704aa1..083a83357 100644 --- a/hanlp/version.py +++ b/hanlp/version.py @@ -2,4 +2,4 @@ # Author: hankcs # Date: 2019-12-28 19:26 -__version__ = '2.0.0-alpha.68' +__version__ = '2.0.0-alpha.69'