Commit
Fix TransformerTransform https://bbs.hankcs.com/t/topic/2822
hankcs committed Oct 23, 2020
1 parent ad98823 commit bcd7ec7
Showing 11 changed files with 20 additions and 19 deletions.
2 changes: 1 addition & 1 deletion hanlp/common/component.py
@@ -469,7 +469,7 @@ def predict(self, data: Any, batch_size=None, **kwargs):
def predict_batch(self, batch, inputs=None, **kwargs):
X = batch[0]
Y = self.model.predict_on_batch(X)
- for output in self.transform.Y_to_outputs(Y, X=X, inputs=inputs, **kwargs):
+ for output in self.transform.Y_to_outputs(Y, X=X, inputs=inputs, batch=batch, **kwargs):
yield output

@property
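The change above is the plumbing half of the fix: predict_batch now forwards the whole batch into Y_to_outputs, so a transform can read tensors other than the model input X (for example the label mask built by the data pipeline, used further down in TransformerTransform). A minimal, self-contained sketch of that call pattern, with DemoComponent and DemoTransform as invented stand-ins rather than HanLP classes:

class DemoTransform:
    def Y_to_outputs(self, Y, gold=False, inputs=None, X=None, batch=None):
        # batch is the full tuple produced by the data pipeline; X alone is no
        # longer assumed to carry everything the transform needs.
        return [repr(y) for y in Y]

class DemoComponent:
    def __init__(self, transform):
        self.transform = transform

    def predict_batch(self, batch, inputs=None, **kwargs):
        X = batch[0]
        Y = [x * 2 for x in X]  # stand-in for self.model.predict_on_batch(X)
        yield from self.transform.Y_to_outputs(Y, X=X, inputs=inputs, batch=batch, **kwargs)

print(list(DemoComponent(DemoTransform()).predict_batch(([1, 2, 3], None))))  # ['2', '4', '6']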
2 changes: 1 addition & 1 deletion hanlp/common/transform.py
@@ -237,7 +237,7 @@ def str_to_idx(self, X, Y) -> Tuple[Union[tf.Tensor, Tuple], tf.Tensor]:
def X_to_inputs(self, X: Union[tf.Tensor, Tuple[tf.Tensor]]) -> Iterable:
return [repr(x) for x in X]

- def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None) -> Iterable:
+ def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None, batch=None) -> Iterable:
return [repr(y) for y in Y]

def XY_to_inputs_outputs(self, X: Union[tf.Tensor, Tuple[tf.Tensor]],
2 changes: 1 addition & 1 deletion hanlp/components/classifiers/transformer_classifier.py
@@ -79,7 +79,7 @@ def x_to_idx(self, x) -> Union[tf.Tensor, Tuple]:
logger.fatal('map_x should always be set to True')
exit(1)

- def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None) -> Iterable:
+ def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None, batch=None) -> Iterable:
preds = tf.argmax(Y, axis=-1)
for y in preds:
yield self.label_vocab.idx_to_token[y]
4 changes: 2 additions & 2 deletions hanlp/components/ner.py
@@ -26,8 +26,8 @@ def predict_batch(self, batch, inputs=None):

class IOBES_Transform(Transform):

- def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None) -> Iterable:
-     for words, tags in zip(inputs, super().Y_to_outputs(Y, gold, inputs=inputs, X=X)):
+ def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None, batch=None) -> Iterable:
+     for words, tags in zip(inputs, super().Y_to_outputs(Y, gold, inputs=inputs, X=X, batch=batch)):
yield from iobes_to_span(words, tags)


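For the NER transform only the batch argument is threaded through; the decoded tag sequence is still converted to entity spans by iobes_to_span. The sketch below is an illustrative re-implementation of IOBES span decoding, not HanLP's own iobes_to_span (whose exact return format may differ); it assumes character-level inputs as in Chinese NER.

def iobes_spans(chars, tags):
    # Greedy IOBES decoding: S-X is a single-token entity, B-X ... E-X a multi-token
    # one, and O (or any malformed tag) closes an open entity.
    spans, start, label = [], None, None
    for i, tag in enumerate(tags):
        prefix, _, typ = tag.partition('-')
        if prefix == 'S':
            spans.append((chars[i], typ, i, i + 1))
            start, label = None, None
        elif prefix == 'B':
            start, label = i, typ
        elif prefix == 'E' and start is not None:
            spans.append((''.join(chars[start:i + 1]), label, start, i + 1))
            start, label = None, None
        elif prefix != 'I':
            start, label = None, None
    return spans

print(iobes_spans(list('纽约很大'), ['B-LOC', 'E-LOC', 'O', 'O']))
# [('纽约', 'LOC', 0, 2)]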
4 changes: 2 additions & 2 deletions hanlp/components/parsers/conll.py
@@ -404,7 +404,7 @@ def fit(self, trn_path: str, **kwargs) -> int:
self.form_vocab.add(token)
return num_samples

- def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None) -> Iterable:
+ def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None, batch=None) -> Iterable:
arc_preds, rel_preds, mask = Y
sents = []

@@ -497,7 +497,7 @@ def create_types_shapes_values(self) -> Tuple[Tuple, Tuple, Tuple]:
False, self.rel_vocab.safe_pad_token_idx)
return types, shapes, values

- def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None) -> Iterable:
+ def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None, batch=None) -> Iterable:
arc_preds, rel_preds, mask = Y
sents = []

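Only the signature changes in the parser transform; the decode path (not shown here) still unpacks arc_preds, rel_preds and mask from Y. As a rough, hedged picture of what that unpacking yields, the toy example below maps per-token head indices and relation ids to (head, relation) pairs while dropping padded positions; the tensors and rel_vocab are invented for illustration and are not the parser's real data or decode code.

import tensorflow as tf

rel_vocab = ['<pad>', 'nsubj', 'dobj', 'root']   # assumed toy relation vocabulary
arc_preds = tf.constant([[0, 2, 0]])              # predicted head index per token
rel_preds = tf.constant([[3, 1, 0]])              # predicted relation id per token
mask = tf.constant([[True, True, False]])         # False at padding positions

sents = []
for arcs, rels, m in zip(arc_preds.numpy(), rel_preds.numpy(), mask.numpy()):
    length = int(m.sum())                         # number of real tokens
    sents.append([(int(h), rel_vocab[r]) for h, r in zip(arcs[:length], rels[:length])])
print(sents)  # [[(0, 'root'), (2, 'nsubj')]]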
11 changes: 5 additions & 6 deletions hanlp/components/taggers/transformers/transformer_transform.py
@@ -79,7 +79,7 @@ def inputs_to_samples(self, inputs, gold=False):
if gold:
words, tags = sample
else:
- words, tags = sample, [self.tag_vocab.pad_token] * len(sample)
+ words, tags = sample, [self.tag_vocab.idx_to_token[1]] * len(sample)

input_ids, input_mask, segment_ids, label_ids = convert_examples_to_features(words, tags,
self.tag_vocab.token_to_idx,
@@ -115,14 +115,13 @@ def y_to_idx(self, y) -> tf.Tensor:
def input_is_single_sample(self, input: Union[List[str], List[List[str]]]) -> bool:
return isinstance(input[0], str)

- def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, X=None, inputs=None,
+ def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, X=None, inputs=None, batch=None,
**kwargs) -> Iterable:
- assert X is not None, 'Need the X to know actual length of Y'
- input_ids, input_mask, segment_ids = X
+ assert batch is not None, 'Need the batch to know actual length of Y'
+ label_mask = batch[1]

- mask = tf.reduce_all(tf.not_equal(tf.expand_dims(input_ids, axis=-1), self.special_token_ids), axis=-1)
Y = tf.argmax(Y, axis=-1)
- Y = Y[mask]
+ Y = Y[label_mask > 0]
tags = [self.tag_vocab.idx_to_token[tid] for tid in Y]
offset = 0
for words in inputs:
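This file carries the substance of the fix to TransformerTransform. Two things change: during plain prediction the placeholder tags attached to the input are now drawn from a real vocabulary entry (tag_vocab.idx_to_token[1]) rather than the pad token, and Y_to_outputs no longer tries to reconstruct the valid-token mask from input_ids and special_token_ids but reuses the label mask the data pipeline already stored in batch[1]. A small hedged sketch of that selection step with toy shapes and values (the zero positions in the mask are an assumption for illustration, e.g. special tokens and padding):

import tensorflow as tf

logits = tf.random.uniform([2, 5, 4])          # (batch, seq_len, num_tags), toy values
label_mask = tf.constant([[0, 1, 1, 0, 0],     # assumed 0 where no tag should be emitted
                          [0, 1, 1, 1, 0]])

pred_ids = tf.argmax(logits, axis=-1)          # (batch, seq_len)
flat_preds = tf.boolean_mask(pred_ids, label_mask > 0)   # same effect as Y[label_mask > 0]
print(flat_preds.shape)                        # (5,): 2 tags for sentence 1, 3 for sentence 2

The selected ids are then mapped through tag_vocab.idx_to_token, exactly as in the diff above.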
3 changes: 2 additions & 1 deletion hanlp/transform/txt.py
@@ -221,7 +221,8 @@ def inputs_to_samples(self, inputs, gold=False):
chars = CharTable.normalize_chars(chars)
yield chars, tags

- def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None) -> Iterable:
+ def Y_to_outputs(self, Y: Union[tf.Tensor, Tuple[tf.Tensor]], gold=False, inputs=None, X=None,
+                  batch=None) -> Iterable:
yield from self.Y_to_tokens(self.tag_vocab, Y, gold, inputs)

def Y_to_tokens(self, tag_vocab, Y, gold, inputs):
2 changes: 1 addition & 1 deletion hanlp/version.py
@@ -2,4 +2,4 @@
# Author: hankcs
# Date: 2019-12-28 19:26

- __version__ = '2.0.0-alpha.65'
+ __version__ = '2.0.0-alpha.66'
2 changes: 1 addition & 1 deletion tests/demo/zh/demo_cws.py
@@ -12,7 +12,7 @@
text = 'NLP统计模型没有加规则,聪明人知道自己加。英文、数字、自定义词典统统都是规则。'
print(tokenizer(text))

- dic = {'自定义': 'custom', '词典': 'dict', '聪明人': 'smart'}
+ dic = {'自定义词典': 'custom_dict', '聪明人': 'smart'}


def split_by_dic(text: str):
2 changes: 1 addition & 1 deletion tests/demo/zh/demo_cws_trie.py
@@ -10,7 +10,7 @@
print(tokenizer(text))

trie = Trie()
- trie.update({'自定义': 'custom', '词典': 'dict', '聪明人': 'smart'})
+ trie.update({'自定义词典': 'custom_dict', '聪明人': 'smart'})


def split_sents(text: str, trie: Trie):
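Both demos now register the merged entry '自定义词典' instead of the two fragments '自定义' and '词典', so one custom span covers the whole phrase. As a rough illustration of why the keys matter, the standalone sketch below (not the hidden bodies of split_by_dic or split_sents) runs a greedy longest-match scan and finds different spans depending on how the dictionary is keyed:

def find_longest(text, dic):
    # Greedy longest-match scan: at each position take the longest dictionary key
    # that matches; otherwise advance one character.
    i, matches = 0, []
    while i < len(text):
        best = None
        for word in dic:
            if text.startswith(word, i) and (best is None or len(word) > len(best)):
                best = word
        if best:
            matches.append((i, i + len(best), best, dic[best]))
            i += len(best)
        else:
            i += 1
    return matches

text = '英文、数字、自定义词典统统都是规则。'
print(find_longest(text, {'自定义词典': 'custom_dict', '聪明人': 'smart'}))
# [(6, 11, '自定义词典', 'custom_dict')]
print(find_longest(text, {'自定义': 'custom', '词典': 'dict'}))
# [(6, 9, '自定义', 'custom'), (9, 11, '词典', 'dict')]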
5 changes: 3 additions & 2 deletions tests/demo/zh/demo_pos.py
@@ -3,6 +3,7 @@
# Date: 2019-12-28 21:25
import hanlp
from hanlp.pretrained.pos import CTB9_POS_ALBERT_BASE

tagger = hanlp.load(CTB9_POS_ALBERT_BASE)
- print(tagger.predict(['我', '的', '希望', '是', '希望', '和平']))
- print(tagger.predict([['支持', '批处理'], ['速度', '更', '快']]))
+ print(tagger.predict(['我', '的', '希望', '是', '希望', '世界', '和平']))
+ print(tagger.predict([['支持', '批处理', '地', '预测'], ['速度', '更', '快']]))

1 comment on commit bcd7ec7

@hanlpbot
Collaborator


This commit has been mentioned on 蝴蝶效应. There might be relevant details there:

https://bbs.hankcs.com/t/topic/2822/5
