Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

[Feature] Add Ner Suffix feature #1123

Open
wants to merge 1 commit into
base: v0.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 53 additions & 19 deletions scripts/sequence_labeling/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def remove_docstart_sentence(sentences):
return ret


def bert_tokenize_sentence(sentence, bert_tokenizer):
def bert_tokenize_sentence(sentence, bert_tokenizer, tagging_first_token):
"""Apply BERT tokenizer on a tagged sentence to break words into sub-words.
This function assumes input tags are following IOBES, and outputs IOBES tags.

Expand All @@ -141,6 +141,9 @@ def bert_tokenize_sentence(sentence, bert_tokenizer):
List of tagged words
bert_tokenizer: nlp.data.BertTokenizer
BERT tokenizer
tagging_first_token: bool, optional (default: True)
By default, only the first token of a word is going to be tagged.
If ``tagging_first_token`` is set to False, then the last token of a word is going to be tagged.

Returns
-------
Expand All @@ -151,14 +154,20 @@ def bert_tokenize_sentence(sentence, bert_tokenizer):
# break a word into sub-word tokens
sub_token_texts = bert_tokenizer(token.text)
# only the first token of a word is going to be tagged
ret.append(TaggedToken(text=sub_token_texts[0], tag=token.tag))
ret += [TaggedToken(text=sub_token_text, tag=NULL_TAG)
for sub_token_text in sub_token_texts[1:]]
if tagging_first_token:
ret.append(TaggedToken(text=sub_token_texts[0], tag=token.tag))
ret += [TaggedToken(text=sub_token_text, tag=NULL_TAG)
for sub_token_text in sub_token_texts[1:]]
# only the last token of a word is going to be tagged
else:
ret += [TaggedToken(text=sub_token_text, tag=NULL_TAG)
for sub_token_text in sub_token_texts[:-1]]
ret.append(TaggedToken(text=sub_token_texts[-1], tag=token.tag))

return ret


def load_segment(file_path, bert_tokenizer):
def load_segment(file_path, bert_tokenizer, tagging_first_token):
"""Load CoNLL format NER datafile with BIO-scheme tags.

Tagging scheme is converted into BIOES, and words are tokenized into wordpieces
Expand All @@ -169,6 +178,10 @@ def load_segment(file_path, bert_tokenizer):
file_path: str
Path of the file
bert_tokenizer: nlp.data.BERTTokenizer
tagging_first_token: bool, optional (default: True)
By default, only the first token of a word is going to be tagged.
If ``tagging_first_token`` is set to False, then the last token of a word is going to be tagged.


Returns
-------
Expand All @@ -177,7 +190,7 @@ def load_segment(file_path, bert_tokenizer):
logging.info('Loading sentences in %s...', file_path)
bio2_sentences = remove_docstart_sentence(read_bio_as_bio2(file_path))
bioes_sentences = [bio_bioes(sentence) for sentence in bio2_sentences]
subword_sentences = [bert_tokenize_sentence(sentence, bert_tokenizer)
subword_sentences = [bert_tokenize_sentence(sentence, bert_tokenizer, tagging_first_token)
for sentence in bioes_sentences]

logging.info('load %s, its max seq len: %d',
Expand All @@ -203,19 +216,23 @@ class BERTTaggingDataset:
Length of the input sequence to BERT.
is_cased: bool
Whether to use cased model.
tagging_first_token: bool, optional (default: True)
By default, only the first token of a word is going to be tagged.
If ``tagging_first_token`` is set to False, then the last token of a word is going to be tagged.

"""

def __init__(self, text_vocab, train_path, dev_path, test_path, seq_len, is_cased,
tag_vocab=None):
tag_vocab=None, tagging_first_token=True):
self.text_vocab = text_vocab
self.seq_len = seq_len

self.bert_tokenizer = nlp.data.BERTTokenizer(vocab=text_vocab, lower=not is_cased)

train_sentences = [] if train_path is None else load_segment(train_path,
self.bert_tokenizer)
dev_sentences = [] if dev_path is None else load_segment(dev_path, self.bert_tokenizer)
test_sentences = [] if test_path is None else load_segment(test_path, self.bert_tokenizer)
self.bert_tokenizer, tagging_first_token)
dev_sentences = [] if dev_path is None else load_segment(dev_path, self.bert_tokenizer, tagging_first_token)
test_sentences = [] if test_path is None else load_segment(test_path, self.bert_tokenizer, tagging_first_token)
all_sentences = train_sentences + dev_sentences + test_sentences

if tag_vocab is None:
Expand Down Expand Up @@ -318,7 +335,7 @@ def num_tag_types(self):


def convert_arrays_to_text(text_vocab, tag_vocab,
np_text_ids, np_true_tags, np_pred_tags, np_valid_length):
np_text_ids, np_true_tags, np_pred_tags, np_valid_length, tagging_first_token=True):
"""Convert numpy array data into text

Parameters
Expand All @@ -327,6 +344,10 @@ def convert_arrays_to_text(text_vocab, tag_vocab,
np_true_tags: tag_ids (batch_size, seq_len)
np_pred_tags: tag_ids (batch_size, seq_len)
np.array: valid_length (batch_size,) the number of tokens until [SEP] token
tagging_first_token: bool, optional (default: True)
By default, only the first token of a word is going to be tagged.
If ``tagging_first_token`` is set to False, then the last token of a word is going to be tagged.


Returns
-------
Expand All @@ -337,19 +358,32 @@ def convert_arrays_to_text(text_vocab, tag_vocab,
for sample_index in range(np_valid_length.shape[0]):
sample_len = np_valid_length[sample_index]
entries = []
tmptext = ""
for i in range(1, sample_len - 1):
token_text = text_vocab.idx_to_token[np_text_ids[sample_index, i]]
true_tag = tag_vocab.idx_to_token[int(np_true_tags[sample_index, i])]
pred_tag = tag_vocab.idx_to_token[int(np_pred_tags[sample_index, i])]
# we don't need to predict on NULL tags
if true_tag == NULL_TAG:
last_entry = entries[-1]
entries[-1] = PredictedToken(text=last_entry.text + token_text,
true_tag=last_entry.true_tag,
pred_tag=last_entry.pred_tag)
if tagging_first_token:
if true_tag == NULL_TAG:
last_entry = entries[-1]
entries[-1] = PredictedToken(text=last_entry.text + token_text,
true_tag=last_entry.true_tag,
pred_tag=last_entry.pred_tag)
else:
entries.append(PredictedToken(text=token_text,
true_tag=true_tag, pred_tag=pred_tag))
else:
entries.append(PredictedToken(text=token_text,
true_tag=true_tag, pred_tag=pred_tag))

if true_tag == NULL_TAG:
tmptext += token_text
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

better name it as tmp_text. Or what about partial_text?

else:
if len(tmptext) > 0:
text = tmptext + token_text
entries.append(PredictedToken(text=text,
true_tag=true_tag, pred_tag=pred_tag))
tmptext = ''
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can both cases be merged here? For example, if len(tmptext) == 0, you can still have text = tmptext + token_text which is equivalent to token_text.

entries.append(PredictedToken(text=token_text,
true_tag=true_tag, pred_tag=pred_tag))
predictions.append(entries)
return predictions
7 changes: 5 additions & 2 deletions scripts/sequence_labeling/finetune_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ def parse_args():
help='Learning rate for optimization')
arg_parser.add_argument('--warmup-ratio', type=float, default=0.1,
help='Warmup ratio for learning rate scheduling')
arg_parser.add_argument('--tagging-first-token', type=str2bool, default=True,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about parser.add_argument('--tag-last-token', action='store_true'). It seems simpler to call finetune_bert.py --tag-last-token than finetune_bert.py --tagging-first-token=False.

In either case please update the test case in scripts/tests/ to run invoke the finetune_bert.py with both options. You can parametrize the test following for example Haibin's recent PR: https://github.com/dmlc/gluon-nlp/pull/1121/files#diff-fa82d34d543ff657c2fe09553bd0fa34R234

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I will update it.

help='Choose to use the first or last piece of the word')
args = arg_parser.parse_args()
return args

Expand All @@ -95,7 +97,7 @@ def main(config):
config.dropout_prob)

dataset = BERTTaggingDataset(text_vocab, config.train_path, config.dev_path, config.test_path,
config.seq_len, config.cased)
config.seq_len, config.cased,tagging_first_token=config.tagging_first_token)

train_data_loader = dataset.get_train_data_loader(config.batch_size)
dev_data_loader = dataset.get_dev_data_loader(config.batch_size)
Expand Down Expand Up @@ -178,7 +180,8 @@ def evaluate(data_loader):
np_true_tags = tag_ids.asnumpy()

predictions += convert_arrays_to_text(text_vocab, dataset.tag_vocab, np_text_ids,
np_true_tags, np_pred_tags, np_valid_length)
np_true_tags, np_pred_tags, np_valid_length,
tagging_first_token=config.tagging_first_token)

all_true_tags = [[entry.true_tag for entry in entries] for entries in predictions]
all_pred_tags = [[entry.pred_tag for entry in entries] for entries in predictions]
Expand Down
10 changes: 7 additions & 3 deletions scripts/sequence_labeling/predict_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import os

import mxnet as mx
from ner_utils import get_bert_model, get_context
from ner_utils import get_bert_model, get_context,str2bool
from ner_utils import load_metadata
from data import BERTTaggingDataset, convert_arrays_to_text
from model import BERTTagger
Expand Down Expand Up @@ -68,6 +68,10 @@ def parse_args():
help='Number (index) of GPU to run on, e.g. 0. '
'If not specified, CPU context is used.')
arg_parser.add_argument('--batch-size', type=int, default=32, help='Batch size for training')
arg_parser.add_argument('--tagging-first-token', type=str2bool, default=True,
help='Choose to tag first word piece or the last word piece')


args = arg_parser.parse_args()
return args

Expand All @@ -81,7 +85,7 @@ def main(config):
train_config.dropout_prob)

dataset = BERTTaggingDataset(text_vocab, None, None, config.test_path,
config.seq_len, train_config.cased, tag_vocab=tag_vocab)
config.seq_len, train_config.cased, tag_vocab=tag_vocab,tagging_first_token=config.tagging_first_token)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pls add white space after the comma.


test_data_loader = dataset.get_test_data_loader(config.batch_size)

Expand Down Expand Up @@ -112,7 +116,7 @@ def evaluate(data_loader):
np_true_tags = tag_ids.asnumpy()

predictions += convert_arrays_to_text(text_vocab, dataset.tag_vocab, np_text_ids,
np_true_tags, np_pred_tags, np_valid_length)
np_true_tags, np_pred_tags, np_valid_length,tagging_first_token=config.tagging_first_token)

all_true_tags = [[entry.true_tag for entry in entries] for entries in predictions]
all_pred_tags = [[entry.pred_tag for entry in entries] for entries in predictions]
Expand Down