[Inference] Add inference support and examples for BERT (#145)
* add bert files
* add support to hdf5 bert weight loading
* add bert_pb and export script
* fix relative path of util
* fix compilation issues
* export complete position_embedding
* add ls_bert example
* update debug codes for bert
* fix bert export weight match issue; add layernorm moving codes
* fix style
* clean comments
* update example as BERT for classification
* output classification prediction
* add support for attention mask

Co-authored-by: Yang Wei <godweiyang@gmail.com>
1 parent 00d58f0, commit a57966a
Showing 13 changed files with 2,112 additions and 0 deletions.
@@ -0,0 +1,179 @@
import os
import h5py
import numpy as np
from collections import OrderedDict
from transformers import BertModel
from utils import fill_hdf5_layer

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


"""
For the mapping dictionary: the key is the name of the proto parameter, and the
value is a matching rule. The rule is split by "&&" into parts, each of which is
either a matching path for a tensor name or an expression. The sub-patterns of a
path are separated by spaces, and an expression starts with "expression_".
Expressions operate on each matched tensor separately, and multiple expressions
are supported. The tensors gathered from the matching paths (after the
expressions are applied) are finally concatenated on axis = -1.
(An illustrative resolution sketch follows this script.)
"""
enc_layer_mapping_dict = OrderedDict(
    {
        # BERT is post_layernorm
        # NOTE: add an additional "final" at the beginning for some weights
        # to distinguish them from "attention output *"
        "multihead_norm_scale": "attention output LayerNorm weight",
        "multihead_norm_bias": "attention output LayerNorm bias",
        "multihead_project_kernel_qkv": "attention self query weight&&attention self key weight&&attention self value weight&&expression_.transpose(0, 1)",
        "multihead_project_bias_qkv": "attention self query bias&&attention self key bias&&attention self value bias",
        "multihead_project_kernel_output": "attention output dense weight&&expression_.transpose(0, 1)",
        "multihead_project_bias_output": "attention output dense bias",
        "ffn_norm_scale": "final output LayerNorm weight",
        "ffn_norm_bias": "final output LayerNorm bias",
        "ffn_first_kernel": "intermediate dense weight&&expression_.transpose(0, 1)",
        "ffn_first_bias": "intermediate dense bias",
        "ffn_second_kernel": "final output dense weight&&expression_.transpose(0, 1)",
        "ffn_second_bias": "final output dense bias",
    }
)

src_emb_mapping_dict = OrderedDict(
    {
        "norm_scale": "embeddings LayerNorm weight",
        "norm_bias": "embeddings LayerNorm bias",
        "position_embedding": "embeddings position_embeddings weight",
        # manually process token_embedding due to "token_type_embeddings"
        # "token_embedding": "embeddings word_embeddings weight",
    }
)


def extract_bert_weights(
    output_file,
    model_dir,
    head_num,
    pad_id=0,
    max_step=50,
):
    # load var names
    encoder_state_dict = BertModel.from_pretrained(model_dir).state_dict()

    # Insert an additional "final" into some weight names to prevent ambiguous matches
    def _insert_final(key):
        l = key.split(".")
        l.insert(3, "final")
        return ".".join(l)

    encoder_state_dict = OrderedDict(
        [
            (_insert_final(k), v)
            if len(k.split(".")) > 3 and k.split(".")[3] == "output"
            else (k, v)
            for k, v in encoder_state_dict.items()
        ]
    )

    enc_var_name_list = list(encoder_state_dict.keys())

    # initialize output file
    output_file += ".hdf5"
    print("Saving model to hdf5...")
    print("Writing to {0}".format(output_file))
    hdf5_file = h5py.File(output_file, "w")

    # fill each encoder layer's params
    enc_tensor_names = {}
    for name in enc_var_name_list:
        name_split = name.split(".")
        if len(name_split) <= 2 or not name_split[2].isdigit():
            continue
        layer_id = int(name_split[2])
        enc_tensor_names.setdefault(layer_id, []).append(name)

    # fill encoder_stack
    for layer_id in sorted(enc_tensor_names.keys()):
        fill_hdf5_layer(
            enc_tensor_names[layer_id],
            encoder_state_dict,
            hdf5_file,
            f"encoder_stack/{layer_id}/",
            enc_layer_mapping_dict,
        )

    # fill src_embedding - except for position embedding
    fill_hdf5_layer(
        enc_var_name_list,
        encoder_state_dict,
        hdf5_file,
        "src_embedding/",
        src_emb_mapping_dict,
    )

    # handle token_embedding for BERT: fold the first token_type embedding
    # into the word embeddings
    token_embedding = (
        encoder_state_dict["embeddings.word_embeddings.weight"]
        + encoder_state_dict["embeddings.token_type_embeddings.weight"][0]
    )
    print(f"processed token_embedding, shape: {token_embedding.shape}")
    token_embedding = token_embedding.flatten().tolist()
    hdf5_file.create_dataset(
        "src_embedding/token_embedding", data=token_embedding, dtype="f4"
    )

    # save number of layers metadata
    hdf5_file.create_dataset(
        "model_conf/n_encoder_stack", data=len(enc_tensor_names), dtype="i4"
    )
    # fill in model_conf
    hdf5_file.create_dataset("model_conf/head_num", data=head_num, dtype="i4")
    hdf5_file.create_dataset("model_conf/src_padding_id", data=pad_id, dtype="i4")
    hdf5_file.create_dataset("model_conf/is_post_ln", data=True, dtype="?")
    hdf5_file.create_dataset("model_conf/use_gelu", data=True, dtype="?")

    # Move layernorm weights to match the layernorm implementation in lightseq:
    # every LayerNorm is shifted one slot forward, i.e. the embedding LayerNorm
    # moves into layer 0's multihead norm slot, each attention-output LayerNorm
    # moves into the same layer's ffn norm slot, each ffn-output LayerNorm moves
    # into the next layer's multihead norm slot, and the last ffn-output
    # LayerNorm ends up under src_embedding.
    tmp_scale, tmp_bias = (
        hdf5_file["src_embedding/norm_scale"][()],
        hdf5_file["src_embedding/norm_bias"][()],
    )
    for layer_id in sorted(enc_tensor_names.keys()):
        new_tmp_scale = hdf5_file[f"encoder_stack/{layer_id}/multihead_norm_scale"][()]
        new_tmp_bias = hdf5_file[f"encoder_stack/{layer_id}/multihead_norm_bias"][()]
        hdf5_file[f"encoder_stack/{layer_id}/multihead_norm_scale"][()] = tmp_scale
        hdf5_file[f"encoder_stack/{layer_id}/multihead_norm_bias"][()] = tmp_bias
        tmp_scale, tmp_bias = new_tmp_scale, new_tmp_bias

        new_tmp_scale = hdf5_file[f"encoder_stack/{layer_id}/ffn_norm_scale"][()]
        new_tmp_bias = hdf5_file[f"encoder_stack/{layer_id}/ffn_norm_bias"][()]
        hdf5_file[f"encoder_stack/{layer_id}/ffn_norm_scale"][()] = tmp_scale
        hdf5_file[f"encoder_stack/{layer_id}/ffn_norm_bias"][()] = tmp_bias
        tmp_scale, tmp_bias = new_tmp_scale, new_tmp_bias
    hdf5_file["src_embedding/norm_scale"][()] = tmp_scale
    hdf5_file["src_embedding/norm_bias"][()] = tmp_bias

    hdf5_file.close()
    # read in again to double check
    hdf5_file = h5py.File(output_file, "r")

    def _print_pair(key, value):
        if key == "sampling_method":
            value = "".join(map(chr, value[()]))
        else:
            value = value[()]
        print(f"{key}: {value}")

    list(map(lambda x: _print_pair(*x), hdf5_file["model_conf"].items()))


if __name__ == "__main__":
    output_lightseq_model_name = "lightseq_bert_base_uncased"
    input_huggingface_bert_model = "bert-base-uncased"
    head_number = 12

    pad_id = 0
    max_step = 50
    extract_bert_weights(
        output_lightseq_model_name,
        input_huggingface_bert_model,
        head_num=head_number,
        pad_id=pad_id,
        max_step=max_step,
    )
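The docstring at the top of the export script only sketches the mapping-dictionary convention. The snippet below is a minimal, illustrative resolver for a single mapping value. It is not the actual fill_hdf5_layer from utils (whose source is not shown in this commit view); the helper name resolve_mapping_value and the exact matching semantics are assumptions made for this example.

# Illustrative sketch only, NOT the real fill_hdf5_layer from utils:
# it shows how one mapping value could be resolved against a layer's state
# dict, following the "&&" / "expression_" convention from the docstring.
import torch


def resolve_mapping_value(value, state_dict):
    """Resolve a rule like
    "attention self query weight&&attention self key weight&&expression_.transpose(0, 1)"
    against a dict of {dotted tensor name: tensor} for one encoder layer."""
    tensors, expressions = [], []
    for rule in value.split("&&"):
        if rule.startswith("expression_"):
            # e.g. "expression_.transpose(0, 1)" -> ".transpose(0, 1)"
            expressions.append(rule[len("expression_"):])
            continue
        # every space-separated sub-pattern must occur in the dotted name
        patterns = rule.split(" ")
        matched = [
            t
            for name, t in state_dict.items()
            if all(p in name.split(".") for p in patterns)
        ]
        assert len(matched) == 1, f"ambiguous or missing match for rule: {rule}"
        tensors.append(matched[0])
    # expressions operate on each matched tensor separately (illustrative eval)
    for expr in expressions:
        tensors = [eval("t" + expr, {"t": t}) for t in tensors]
    # tensors from multiple matching paths are concatenated on the last axis
    return torch.cat(tensors, dim=-1) if len(tensors) > 1 else tensors[0]

For example, resolve_mapping_value(enc_layer_mapping_dict["multihead_project_bias_qkv"], layer_state_dict) would gather the query, key, and value biases of one layer and concatenate them into a single vector.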
@@ -0,0 +1,111 @@
import time
import argparse

import torch
import lightseq.inference as lsi
from transformers import BertTokenizer, BertForSequenceClassification


def ls_bert(model, inputs, attn_mask):
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    ls_output = model.infer(inputs, attn_mask)
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    return ls_output, end_time - start_time


def hf_bert(model, inputs, attn_mask):
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    hf_output = model(inputs.to("cuda:0"), attention_mask=attn_mask.to("cuda:0"))
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    return hf_output, end_time - start_time


def ls_generate(model, inputs_id, attn_mask):
    print("=========lightseq=========")
    print("lightseq generating...")
    ls_output, ls_time = ls_bert(model, inputs_id, attn_mask)
    print(f"lightseq time: {ls_time}s")
    print("lightseq results (class predictions):")
    print(ls_output.argmax(axis=1).detach().cpu().numpy())


def hf_generate(model, inputs_id, attn_mask):
    print("=========huggingface=========")
    print("huggingface generating...")
    hf_output, hf_time = hf_bert(model, inputs_id, attn_mask)
    print(f"huggingface time: {hf_time}s")
    print("huggingface results (class predictions):")
    print(hf_output.logits.argmax(axis=1).detach().cpu().numpy())


def warmup(tokenizer, ls_model, hf_model, sentences):
    inputs = tokenizer(sentences, return_tensors="pt", padding=True)
    inputs_id = inputs["input_ids"]
    attn_mask = inputs["attention_mask"]

    ls_generate(ls_model, inputs_id, attn_mask)
    hf_generate(hf_model, inputs_id, attn_mask)


class LightseqBertClassification:
    def __init__(self, ls_weight_path, hf_model):
        self.ls_bert = lsi.Bert(ls_weight_path, 128)
        self.pooler = hf_model.bert.pooler
        self.classifier = hf_model.classifier

    def infer(self, inputs, attn_mask):
        last_hidden_states = self.ls_bert.infer(inputs, attn_mask)
        last_hidden_states = torch.Tensor(last_hidden_states).float()
        pooled_output = self.pooler(last_hidden_states.to("cuda:0"))
        logits = self.classifier(pooled_output)
        return logits


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--user_input", action="store_true")
    args = parser.parse_args()

    print("initializing bert tokenizer...")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    print("creating huggingface model...")
    hf_model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
    hf_model.to("cuda:0")

    print("creating lightseq model...")
    ls_model = LightseqBertClassification("lightseq_bert_base_uncased.hdf5", hf_model)

    sentences = [
        "Hello, my dog is cute",
        "Hey, how are you",
        "This is a test",
        "Testing the model again",
    ]

    print("====================START warmup====================")
    warmup(tokenizer, ls_model, hf_model, sentences)
    print("====================END warmup====================")

    while True:
        if args.user_input:
            sentences = [input("input a sentence to classify:\n")]

        print("tokenizing the sentences...")
        inputs = tokenizer(sentences, return_tensors="pt", padding=True)
        inputs_id = inputs["input_ids"]
        attn_mask = inputs["attention_mask"]

        ls_generate(ls_model, inputs_id, attn_mask)
        hf_generate(hf_model, inputs_id, attn_mask)

        if not args.user_input:
            break


if __name__ == "__main__":
    main()
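Beyond the timing comparison above, it can be useful to check that the LightSeq path and the HuggingFace path produce close logits. The snippet below is a hedged sanity-check sketch, not part of the original example; it reuses the tokenizer, ls_model, and hf_model objects exactly as they are built in main(), and the tolerance value is an assumption (the two backends will not match bit-for-bit).

# Hedged sanity-check sketch (not part of the original example): compare
# LightSeq and HuggingFace classification logits on the same batch.
def check_outputs(tokenizer, ls_model, hf_model, sentences, atol=5e-2):
    inputs = tokenizer(sentences, return_tensors="pt", padding=True)
    # LightSeq encoder + HuggingFace pooler/classifier (see LightseqBertClassification)
    ls_logits = ls_model.infer(inputs["input_ids"], inputs["attention_mask"])
    # pure HuggingFace forward pass on GPU
    hf_logits = hf_model(
        inputs["input_ids"].to("cuda:0"),
        attention_mask=inputs["attention_mask"].to("cuda:0"),
    ).logits
    max_diff = (ls_logits.detach().cpu() - hf_logits.detach().cpu()).abs().max().item()
    print(f"max |lightseq - huggingface| logit difference: {max_diff}")
    return max_diff < atol

With the objects from main(), calling check_outputs(tokenizer, ls_model, hf_model, sentences) after the warmup would print the maximum per-logit difference between the two backends; both paths share the same pooler and classifier weights, so only the encoder implementations differ.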