
Commit

update readme and downstream tasks
Embedding committed Jun 22, 2019
1 parent f11b0df commit c041328
Showing 7 changed files with 69 additions and 30 deletions.
Binary file added .DS_Store
Binary file not shown.
42 changes: 35 additions & 7 deletions README.md
@@ -15,6 +15,7 @@ Table of Contents
* [Features](#features)
* [Requirements](#requirements)
* [Quickstart](#quickstart)
* [Datasets](#datasets)
* [Instructions](#instructions)
* [Scripts](#scripts)
* [Experiments](#experiments)
@@ -60,9 +61,9 @@ The book review corpus is obtained from the book review dataset. We remove labels and
The format of the classification dataset is as follows (label and instance are separated by \t):
```
label text_a
1 instance1
0 instance2
1 instance3
```
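A dataset in this format can be read with a few lines of Python. The sketch below is an illustration only (UER-py's scripts use their own loaders); note that the first line is the `label text_a` header and is skipped:
```
# Minimal sketch: read a classification TSV whose first line is the header.
# This is illustrative only, not UER-py's actual data loader.
def read_classification_tsv(path):
    examples = []
    with open(path, mode="r", encoding="utf-8") as f:
        for line_id, line in enumerate(f):
            if line_id == 0:  # skip the "label\ttext_a" header line
                continue
            fields = line.rstrip("\n").split("\t")
            if len(fields) != 2:
                continue  # skip malformed lines
            examples.append((int(fields[0]), fields[1]))
    return examples

# Example usage:
# examples = read_classification_tsv("datasets/book_review/train.tsv")
```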

We use Google's Chinese vocabulary file, which contains 21128 Chinese characters. The format of the vocabulary is as follows:
@@ -80,7 +81,7 @@ python3 preprocess.py --corpus_path corpora/book_review_bert.txt --vocab_path mo
```
Pre-processing is time-consuming. Using multiple processes can largely speed it up.
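As an illustration of the multi-process idea (not how preprocess.py is actually implemented), the corpus can be split into chunks that are tokenized in parallel:
```
# Illustrative sketch of multi-process pre-processing; preprocess.py's real
# pipeline differs. Each chunk of lines is tokenized in a separate process.
from multiprocessing import Pool

def tokenize_chunk(lines):
    # Stand-in for real tokenization: split each line into characters.
    return [list(line.strip()) for line in lines]

def parallel_preprocess(corpus_path, processes_num=8):
    with open(corpus_path, mode="r", encoding="utf-8") as f:
        lines = f.readlines()
    chunk_size = (len(lines) + processes_num - 1) // processes_num
    chunks = [lines[i:i + chunk_size] for i in range(0, len(lines), chunk_size)]
    with Pool(processes_num) as pool:
        results = pool.map(tokenize_chunk, chunks)
    return [tokens for chunk in results for tokens in chunk]

if __name__ == "__main__":
    sentences = parallel_preprocess("corpora/book_review_bert.txt", processes_num=8)
```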
Then we download [Google's pre-trained Chinese model](https://share.weiyun.com/5s9AsfQ), and put it into *models* folder.
We load Google's pre-trained model and train it further on the book review corpus. It is better to explicitly specify the model's encoder and target. Suppose we have a machine with 8 GPUs:
```
python3 pretrain.py --dataset_path dataset.pt --vocab_path models/google_vocab.txt --pretrained_model_path models/google_model.bin \
--output_model_path models/book_review_model.bin --world_size 8 --gpu_ranks 0 1 2 3 4 5 6 7 \
@@ -90,17 +91,44 @@ Finally, we do classification. We can use *google_model.bin*:
```
python3 classifier.py --pretrained_model_path models/google_model.bin --vocab_path models/google_vocab.txt \
--train_path datasets/book_review/train.tsv --dev_path datasets/book_review/dev.tsv --test_path datasets/book_review/test.tsv \
--epochs_num 3 --batch_size 32 --encoder bert
```
or use our [*book_review_model.bin*](https://share.weiyun.com/52BEFs2), which is the output of pretrain.py:
```
python3 classifier.py --pretrained_model_path models/book_review_model.bin --vocab_path models/google_vocab.txt \
--train_path datasets/book_review/train.tsv --dev_path datasets/book_review/dev.tsv --test_path datasets/book_review/test.tsv \
--epochs_num 3 --batch_size 32 --encoder bert
```
It turns out that the result of Google's model is 87.5 and the result of *book_review_model.bin* is 88.1. It is also noticeable that we don't need to specify the target at the fine-tuning stage: the pre-training target is replaced with a task-specific target.
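The target replacement can be sketched in plain PyTorch: the pre-trained embedding and encoder are reused, the old pre-training heads in the checkpoint are ignored (loading with `strict=False`, as classifier.py does), and a fresh classification layer is trained on top. The class and attribute names below are illustrative, not UER-py's actual API:
```
# Sketch of replacing the pre-training target with a task-specific target.
# Names and shapes are illustrative, not UER-py's actual implementation.
import torch
import torch.nn as nn

class FineTuneClassifier(nn.Module):
    def __init__(self, embedding, encoder, hidden_size, labels_num):
        super(FineTuneClassifier, self).__init__()
        self.embedding = embedding   # reused from pre-training
        self.encoder = encoder       # reused from pre-training
        self.output_layer = nn.Linear(hidden_size, labels_num)  # new, task-specific

    def forward(self, input_ids, seg_ids):
        emb = self.embedding(input_ids, seg_ids)
        hidden = self.encoder(emb, seg_ids)           # [batch, seq_length, hidden]
        return self.output_layer(hidden[:, 0, :])     # classify on the first token

# Loading with strict=False ignores checkpoint parameters that belong to the
# discarded pre-training target:
# model.load_state_dict(torch.load("models/book_review_model.bin"), strict=False)
```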

We could search for suitable pre-trained models in the [Chinese model zoo](#chinese_model_zoo) for further improvements. For example, we could download [a model pre-trained on the Amazon corpus (over 4 million reviews) with a BERT encoder and classification (CLS) target](https://share.weiyun.com/5XuxtFA). It achieves 88.5 accuracy on the book review dataset.

BERT is really slow. It would be great if we could speed up the model and still achieve comparable performance. We select a 2-layer LSTM encoder to substitute for the 12-layer Transformer encoder. We could download a model pre-trained with an LSTM encoder using language modeling (LM) and classification (CLS) targets:
```
python3 classifier.py --pretrained_model_path models/ --vocab_path models/google_vocab.txt \
--train_path datasets/book_review/train.tsv --dev_path datasets/book_review/dev.tsv --test_path datasets/book_review/test.tsv \
--epochs_num 3 --batch_size 64 --encoder lstm --pooling mean --config_path models/rnn_config.json --learning_rate 1e-3
```
We can achieve 87.0 accuracy on the test set, which is also a competitive result. Using an LSTM without pre-training achieves only 80.2 accuracy. In practice, the above model is around 10 times faster than BERT. See the Chinese model zoo section for more details about the above pre-trained LSTM model.
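To make the substitution concrete, here is a rough sketch of a 2-layer LSTM classifier with mean pooling over non-padding positions, the kind of model the command above fine-tunes. Layer names and sizes are assumptions for illustration, not the contents of rnn_config.json:
```
# Rough sketch of an LSTM classifier with mean pooling; sizes are illustrative.
import torch
import torch.nn as nn

class LstmClassifier(nn.Module):
    def __init__(self, vocab_size, emb_size=512, hidden_size=512, labels_num=2):
        super(LstmClassifier, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.encoder = nn.LSTM(emb_size, hidden_size, num_layers=2, batch_first=True)
        self.output_layer = nn.Linear(hidden_size, labels_num)

    def forward(self, input_ids, mask):
        emb = self.embedding(input_ids)
        hidden, _ = self.encoder(emb)                 # [batch, seq_length, hidden]
        mask = mask.unsqueeze(-1).float()
        # Mean pooling over non-padding positions (--pooling mean).
        pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return self.output_layer(pooled)
```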

Besides classification, UER-py also provides scripts for other downstream tasks. We could use tagger.py for sequence labeling:
```
python3 tagger.py --pretrained_model_path models/google_model.bin --vocab_path models/google_vocab.txt \
--train_path datasets/msra/train.tsv --dev_path datasets/msra/dev.tsv --test_path datasets/msra/test.tsv \
--epochs_num 5 --batch_size 8 --encoder bert
```
We could download [a model pre-trained on RenMinRiBao (news corpus)](https://share.weiyun.com/5HKnsxq) and fine-tune on it:
```
python3 tagger.py --pretrained_model_path models/rmrb_model.bin --vocab_path models/google_vocab.txt \
--train_path datasets/msra/train.tsv --dev_path datasets/msra/dev.tsv --test_path datasets/msra/test.tsv \
--epochs_num 5 --batch_size 8 --encoder bert
```
It turns out that the result of Google's model is 92.6 and the result of *rmrb_model.bin* is 94.4.
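Sequence labeling differs from classification mainly in where the output layer is applied: instead of one prediction per sentence, a prediction is made for every token. A minimal sketch of such a tagging head over encoder outputs (illustrative only, not tagger.py's implementation):
```
# Illustrative tagging head: one label per token instead of one per sentence.
import torch.nn as nn

class TaggingHead(nn.Module):
    def __init__(self, hidden_size, labels_num):
        super(TaggingHead, self).__init__()
        self.output_layer = nn.Linear(hidden_size, labels_num)

    def forward(self, encoder_output):
        # encoder_output: [batch_size, seq_length, hidden_size]
        # returns logits of shape [batch_size, seq_length, labels_num]
        return self.output_layer(encoder_output)
```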

<br/>

## Datasets
This project includes a range of Chinese datasets. Small-scale datasets can be downloaded at [datasets_zh.zip](). datasets_zh.zip contains 7 datasets: XNLI, LCQMC, MSRA-NER, ChnSentiCorp, and nlpcc-dbqa come from [Baidu ERNIE](https://github.com/PaddlePaddle/LARK/tree/develop/ERNIE), while Book review (from BNU) and Shopping are two sentence-level sentiment analysis datasets. Large-scale datasets can be found in [glyph's github project](https://github.com/zhangxiangxiao/glyph).

<br/>

11 changes: 6 additions & 5 deletions classifier.py
@@ -79,7 +79,7 @@ def main():
# Model options.
parser.add_argument("--batch_size", type=int, default=64,
help="Batch size.")
parser.add_argument("--seq_length", type=int, default=100,
parser.add_argument("--seq_length", type=int, default=128,
help="Sequence length.")
parser.add_argument("--encoder", choices=["bert", "lstm", "gru", \
"cnn", "gatedcnn", "attn", \
@@ -149,7 +149,6 @@ def main():
except:
pass
args.labels_num = len(labels_set)
print(columns)

# Load vocabulary.
vocab = Vocab()
@@ -164,7 +163,7 @@
# Load or initialize parameters.
if args.pretrained_model_path is not None:
# Initialize with pretrained model.
bert_model.load_state_dict(torch.load(args.pretrained_model_path), strict=True)
bert_model.load_state_dict(torch.load(args.pretrained_model_path), strict=False)
else:
# Initialize with normal distribution.
for n, p in list(bert_model.named_parameters()):
@@ -203,7 +202,9 @@ def batch_loader(batch_size, input_ids, label_ids, mask_ids):
def read_dataset(path):
dataset = []
with open(path, mode="r", encoding="utf-8") as f:
for line in f:
for line_id, line in enumerate(f):
if line_id == 0:
continue
try:
line = line.strip().split('\t')
if len(line) == 2:
@@ -238,7 +239,7 @@ def read_dataset(path):
tokens.append(0)
mask.append(0)
dataset.append((tokens, label, mask))
elif len(line) == 4: # For sentence pair input.
elif len(line) == 4: # For dbqa input.
qid=int(line[columns["qid"]])
label = int(line[columns["label"]])
text_a, text_b = line[columns["text_a"]], line[columns["text_b"]]
2 changes: 1 addition & 1 deletion cloze.py
@@ -177,7 +177,7 @@ def batch_loader(batch_size, input_ids, seg_ids, mask_positions, label_ids):
for j, p in enumerate(mask_positions_batch):
topn_tokens = (-prob[j][p]).argsort()[:args.topn]

sentence = "".join([vocab.i2w[token_id] for token_id in input_ids[j] if token_id != 0])
sentence = "".join([vocab.i2w[token_id] for token_id in input_ids_batch[j] if token_id != 0])
pred_tokens = " ".join(vocab.i2w[token_id] for token_id in topn_tokens)
label_token = vocab.i2w[label_ids_batch[j]]
f_output.write(sentence + '\n')
Binary file added corpora/.DS_Store
Binary file not shown.
44 changes: 27 additions & 17 deletions feature_extractor.py
@@ -7,11 +7,11 @@
import numpy as np
import torch.nn as nn

from bert.utils.vocab import Vocab
from bert.utils.constants import *
from bert.utils.config import load_hyperparam
from bert.utils.tokenizer import *
from bert.model_builder import build_model
from uer.utils.vocab import Vocab
from uer.utils.constants import *
from uer.utils.config import load_hyperparam
from uer.utils.tokenizer import *
from uer.model_builder import build_model


class SequenceEncoder(torch.nn.Module):
@@ -25,15 +25,6 @@ def __init__(self, bert_model):

def forward(self, input_ids, seg_ids):
emb = self.embedding(input_ids, seg_ids)
# seq_length = emb.size(1)
# #
# mask = (seg_ids>0).\
# unsqueeze(1).\
# repeat(1, seq_length, 1).\
# unsqueeze(1)

# mask = mask.float()
# mask = (1.0 - mask) * -10000.0
output = self.encoder(emb, seg_ids)
return output

@@ -50,11 +41,27 @@
help="Path of the vocabulary file.")
parser.add_argument("--output_path", required=True,
help="Path of the input file which is in npy format.")
# Subword options.
parser.add_argument("--subword_type", choices=["none", "char"], default="none",
help="Subword feature type.")
parser.add_argument("--sub_vocab_path", type=str, default="models/sub_vocab.txt",
help="Path of the subword vocabulary file.")
parser.add_argument("--subencoder", choices=["avg", "lstm", "gru", "cnn"], default="avg",
help="Subencoder type.")
parser.add_argument("--sub_layers_num", type=int, default=2, help="The number of subencoder layers.")

# Model options
parser.add_argument("--seq_length", type=int, default=100, help="Sequence length.")
parser.add_argument("--batch_size", type=int, default=8, help="Batch size.")
parser.add_argument("--config_path", default="./model.config", help="Model config file.")
parser.add_argument("--config_path", default="models/google_config.json", help="Model config file.")
parser.add_argument("--encoder", choices=["bert", "lstm", "gru", \
"cnn", "gatedcnn", "attn", \
"rcnn", "crnn", "gpt"], \
default="bert", help="Encoder type.")
parser.add_argument("--target", choices=["bert", "lm", "cls", "mlm", "nsp", "s2s"], default="bert",
help="The training target of the pretraining model.")

# Tokenizer options.
parser.add_argument("--tokenizer", choices=["char", "word", "space", "mixed"], default="char",
@@ -72,9 +79,10 @@ def forward(self, input_ids, seg_ids):
# Load vocabulary.
vocab = Vocab()
vocab.load(args.vocab_path)
args.vocab = vocab

# Build and load model.
bert_model = build_model(args, len(vocab))
bert_model = build_model(args)
pretrained_model = torch.load(args.model_path)
bert_model.load_state_dict(pretrained_model, strict=True)
seq_encoder = SequenceEncoder(bert_model)
@@ -91,7 +99,7 @@ def forward(self, input_ids, seg_ids):
if args.tokenizer == "mixed":
tokenizer = MixedTokenizer(vocab)
else:
tokenizer = globals()[args.tokenizer.capitalize() + "Tokenizer"]()
tokenizer = globals()[args.tokenizer.capitalize() + "Tokenizer"](args)

dataset = []
with open(args.input_path, mode="r", encoding="utf-8") as f:
@@ -127,6 +135,8 @@ def batch_loader(batch_size, input_ids, seg_ids):

sentence_vectors = []
for i, (input_ids_batch, seg_ids_batch) in enumerate(batch_loader(args.batch_size, input_ids, seg_ids)):
input_ids_batch = input_ids_batch.to(device)
seg_ids_batch = seg_ids_batch.to(device)
output = seq_encoder(input_ids_batch, seg_ids_batch)
output = output.cpu().data.numpy()
sentence_vectors.append(output[:,0,:])
Binary file added uer/.DS_Store
Binary file not shown.
