[Inference] Add inference support and examples for BERT (#145)
* add bert files

* add support for hdf5 bert weight loading

* add bert_pb and export script

* fix relative path of util

* fix compilation issues

* export complete position_embedding

* add ls_bert example

* update debug code for bert

* fix bert export weight match issue;
add layernorm moving code;

* fix style

* clean comments

* update example as BERT for classification

* output classification prediction

* add support for attention mask

* add support for attention mask

Co-authored-by: Yang Wei <godweiyang@gmail.com>
Xingyao Wang and godweiyang authored Aug 9, 2021
1 parent 00d58f0 commit a57966a
Showing 13 changed files with 2,112 additions and 0 deletions.
623 changes: 623 additions & 0 deletions examples/inference/python/bert_pb2.py

Large diffs are not rendered by default.

179 changes: 179 additions & 0 deletions examples/inference/python/hf_bert_export.py
@@ -0,0 +1,179 @@
import os
import h5py
import numpy as np
from collections import OrderedDict
from transformers import BertModel
from utils import fill_hdf5_layer

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


"""
For the mapping dictionary: key is the value of the proto parameter,
value is a powerful expression, each && split tensor name of the matching path or expression.
The sub-pattern of the path is separated by spaces, and the expression starts with a expression_.
You can operate separately on each tensor and support multiple expressions. Multiple matching paths
and the expression will finally be concatenated on axis = -1.
"""
enc_layer_mapping_dict = OrderedDict(
    {
        # BERT is post_layernorm
        # NOTE: an extra "final" is inserted into some weight names (see
        # _insert_final below) to distinguish the feed-forward output weights
        # from the "attention output *" weights
        "multihead_norm_scale": "attention output LayerNorm weight",
        "multihead_norm_bias": "attention output LayerNorm bias",
        "multihead_project_kernel_qkv": "attention self query weight&&attention self key weight&&attention self value weight&&expression_.transpose(0, 1)",
        "multihead_project_bias_qkv": "attention self query bias&&attention self key bias&&attention self value bias",
        "multihead_project_kernel_output": "attention output dense weight&&expression_.transpose(0, 1)",
        "multihead_project_bias_output": "attention output dense bias",
        "ffn_norm_scale": "final output LayerNorm weight",
        "ffn_norm_bias": "final output LayerNorm bias",
        "ffn_first_kernel": "intermediate dense weight&&expression_.transpose(0, 1)",
        "ffn_first_bias": "intermediate dense bias",
        "ffn_second_kernel": "final output dense weight&&expression_.transpose(0, 1)",
        "ffn_second_bias": "final output dense bias",
    }
)

src_emb_mapping_dict = OrderedDict(
    {
        "norm_scale": "embeddings LayerNorm weight",
        "norm_bias": "embeddings LayerNorm bias",
        "position_embedding": "embeddings position_embeddings weight",
        # manually process token_embedding due to "token_type_embeddings"
        # "token_embedding": "embeddings word_embeddings weight",
    }
)


def extract_bert_weights(
    output_file,
    model_dir,
    head_num,
    pad_id=0,
    max_step=50,
):
    # load var names
    encoder_state_dict = BertModel.from_pretrained(model_dir).state_dict()

    # Insert an additional "final" into some weight names to prevent
    # ambiguous matches:
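    # e.g. "encoder.layer.0.output.dense.weight" becomes
    # "encoder.layer.0.final.output.dense.weight", while
    # "encoder.layer.0.attention.output.dense.weight" is left unchanged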
    def _insert_final(key):
        l = key.split(".")
        l.insert(3, "final")
        return ".".join(l)

    encoder_state_dict = OrderedDict(
        [
            (_insert_final(k), v)
            if len(k.split(".")) > 3 and k.split(".")[3] == "output"
            else (k, v)
            for k, v in encoder_state_dict.items()
        ]
    )

    enc_var_name_list = list(encoder_state_dict.keys())

    # initialize output file
    output_file += ".hdf5"
    print("Saving model to hdf5...")
    print("Writing to {0}".format(output_file))
    hdf5_file = h5py.File(output_file, "w")

    # fill each encoder layer's params
    enc_tensor_names = {}
    for name in enc_var_name_list:
        name_split = name.split(".")
        if len(name_split) <= 2 or not name_split[2].isdigit():
            continue
        layer_id = int(name_split[2])
        enc_tensor_names.setdefault(layer_id, []).append(name)

    # fill encoder_stack
    for layer_id in sorted(enc_tensor_names.keys()):
        fill_hdf5_layer(
            enc_tensor_names[layer_id],
            encoder_state_dict,
            hdf5_file,
            f"encoder_stack/{layer_id}/",
            enc_layer_mapping_dict,
        )

    # fill src_embedding - except for token embedding
    fill_hdf5_layer(
        enc_var_name_list,
        encoder_state_dict,
        hdf5_file,
        "src_embedding/",
        src_emb_mapping_dict,
    )

    # handle token_embeddings for BERT
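    # lightseq's Bert.infer takes no token_type_ids, so the segment-0 row of
    # token_type_embeddings is folded into every word embedding here; this
    # assumes all inputs are single-segment (token_type_id == 0 throughout)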
    token_embedding = (
        encoder_state_dict["embeddings.word_embeddings.weight"]
        + encoder_state_dict["embeddings.token_type_embeddings.weight"][0]
    )
    print(f"processed token_embedding, shape: {token_embedding.shape}")
    token_embedding = token_embedding.flatten().tolist()
    hdf5_file.create_dataset(
        "src_embedding/token_embedding", data=token_embedding, dtype="f4"
    )

    # save number of layers metadata
    hdf5_file.create_dataset(
        "model_conf/n_encoder_stack", data=len(enc_tensor_names), dtype="i4"
    )
    # fill in model_conf
    hdf5_file.create_dataset("model_conf/head_num", data=head_num, dtype="i4")
    hdf5_file.create_dataset("model_conf/src_padding_id", data=pad_id, dtype="i4")
    hdf5_file.create_dataset("model_conf/is_post_ln", data=True, dtype="?")
    hdf5_file.create_dataset("model_conf/use_gelu", data=True, dtype="?")

    # Move layernorm weights to match the layernorm implementation in lightseq
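    # lightseq applies layernorm at the entrance of each sub-layer rather than
    # at its exit, so every norm is shifted forward one slot: the embedding
    # norm moves into layer 0's multihead slot, each attention norm into the
    # same layer's ffn slot, each ffn norm into the next layer's multihead
    # slot, and the last ffn norm wraps around into the src_embedding slot
    # (our reading of the rotation below)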
    tmp_scale, tmp_bias = (
        hdf5_file["src_embedding/norm_scale"][()],
        hdf5_file["src_embedding/norm_bias"][()],
    )
    for layer_id in sorted(enc_tensor_names.keys()):
        new_tmp_scale = hdf5_file[f"encoder_stack/{layer_id}/multihead_norm_scale"][()]
        new_tmp_bias = hdf5_file[f"encoder_stack/{layer_id}/multihead_norm_bias"][()]
        hdf5_file[f"encoder_stack/{layer_id}/multihead_norm_scale"][()] = tmp_scale
        hdf5_file[f"encoder_stack/{layer_id}/multihead_norm_bias"][()] = tmp_bias
        tmp_scale, tmp_bias = new_tmp_scale, new_tmp_bias

        new_tmp_scale = hdf5_file[f"encoder_stack/{layer_id}/ffn_norm_scale"][()]
        new_tmp_bias = hdf5_file[f"encoder_stack/{layer_id}/ffn_norm_bias"][()]
        hdf5_file[f"encoder_stack/{layer_id}/ffn_norm_scale"][()] = tmp_scale
        hdf5_file[f"encoder_stack/{layer_id}/ffn_norm_bias"][()] = tmp_bias
        tmp_scale, tmp_bias = new_tmp_scale, new_tmp_bias
    hdf5_file["src_embedding/norm_scale"][()] = tmp_scale
    hdf5_file["src_embedding/norm_bias"][()] = tmp_bias

    hdf5_file.close()
    # read the file back in to double-check what was written
    hdf5_file = h5py.File(output_file, "r")

    def _print_pair(key, value):
        if key == "sampling_method":
            value = "".join(map(chr, value[()]))
        else:
            value = value[()]
        print(f"{key}: {value}")

    list(map(lambda x: _print_pair(*x), hdf5_file["model_conf"].items()))


if __name__ == "__main__":
output_lightseq_model_name = "lightseq_bert_base_uncased"
input_huggingface_bert_model = "bert-base-uncased"
head_number = 12

pad_id = 0
max_step = 50
extract_bert_weights(
output_lightseq_model_name,
input_huggingface_bert_model,
head_num=head_number,
pad_id=pad_id,
max_step=max_step,
)
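
Note: fill_hdf5_layer is imported from utils and is not part of this diff. As a rough, hypothetical sketch of the rule format described in the module docstring above (not the actual helper; the name apply_rule and its signature are illustrative only), a single mapping rule could be interpreted like this:

import numpy as np

def apply_rule(rule, state_dict, tensor_names):
    # Hypothetical sketch, not the actual utils.fill_hdf5_layer.
    # Split the rule into matching paths and "expression_" segments.
    segments = rule.split("&&")
    paths = [s for s in segments if not s.startswith("expression_")]
    exprs = [s[len("expression_"):] for s in segments if s.startswith("expression_")]

    tensors = []
    for path in paths:
        # every space-separated sub-pattern must appear in the tensor name
        keys = path.split(" ")
        matched = [n for n in tensor_names if all(k in n.split(".") for k in keys)]
        assert len(matched) == 1, f"path {path!r} must match exactly one tensor"
        t = state_dict[matched[0]].numpy()
        for expr in exprs:
            # e.g. expr == ".transpose(0, 1)", applied to each tensor separately
            t = eval("t" + expr)
        tensors.append(t)

    # results of all matching paths are concatenated on the last axis
    return np.concatenate(tensors, axis=-1)

Under this reading, the "multihead_project_kernel_qkv" rule would transpose the query, key and value weights and concatenate them into a single (hidden, 3 * hidden) matrix.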
111 changes: 111 additions & 0 deletions examples/inference/python/ls_bert.py
@@ -0,0 +1,111 @@
import time
import argparse

import torch
import lightseq.inference as lsi
from transformers import BertTokenizer, BertForSequenceClassification


def ls_bert(model, inputs, attn_mask):
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    ls_output = model.infer(inputs, attn_mask)
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    return ls_output, end_time - start_time


def hf_bert(model, inputs, attn_mask):
    torch.cuda.synchronize()
    start_time = time.perf_counter()
    hf_output = model(inputs.to("cuda:0"), attention_mask=attn_mask.to("cuda:0"))
    torch.cuda.synchronize()
    end_time = time.perf_counter()
    return hf_output, end_time - start_time


def ls_generate(model, inputs_id, attn_mask):
    print("=========lightseq=========")
    print("lightseq generating...")
    ls_output, ls_time = ls_bert(model, inputs_id, attn_mask)
    print(f"lightseq time: {ls_time}s")
    print("lightseq results (class predictions):")
    print(ls_output.argmax(axis=1).detach().cpu().numpy())


def hf_generate(model, inputs_id, attn_mask):
    print("=========huggingface=========")
    print("huggingface generating...")
    hf_output, hf_time = hf_bert(model, inputs_id, attn_mask)
    print(f"huggingface time: {hf_time}s")
    print("huggingface results (class predictions):")
    print(hf_output.logits.argmax(axis=1).detach().cpu().numpy())


def warmup(tokenizer, ls_model, hf_model, sentences):
    inputs = tokenizer(sentences, return_tensors="pt", padding=True)
    inputs_id = inputs["input_ids"]
    attn_mask = inputs["attention_mask"]

    ls_generate(ls_model, inputs_id, attn_mask)
    hf_generate(hf_model, inputs_id, attn_mask)


class LightseqBertClassification:
    def __init__(self, ls_weight_path, hf_model):
        self.ls_bert = lsi.Bert(ls_weight_path, 128)
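        # the second constructor argument is taken here to be the maximum
        # batch size, matching lightseq's other inference model constructors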
        self.pooler = hf_model.bert.pooler
        self.classifier = hf_model.classifier

    def infer(self, inputs, attn_mask):
        last_hidden_states = self.ls_bert.infer(inputs, attn_mask)
        last_hidden_states = torch.Tensor(last_hidden_states).float()
        pooled_output = self.pooler(last_hidden_states.to("cuda:0"))
        logits = self.classifier(pooled_output)
        return logits


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--user_input", action="store_true")
    args = parser.parse_args()

    print("initializing bert tokenizer...")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    print("creating huggingface model...")
    hf_model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
    hf_model.to("cuda:0")

    print("creating lightseq model...")
    ls_model = LightseqBertClassification("lightseq_bert_base_uncased.hdf5", hf_model)

    sentences = [
        "Hello, my dog is cute",
        "Hey, how are you",
        "This is a test",
        "Testing the model again",
    ]

    print("====================START warmup====================")
    warmup(tokenizer, ls_model, hf_model, sentences)
    print("====================END warmup====================")

    while True:
        if args.user_input:
            sentences = [input("input a sentence to classify:\n")]

        print("tokenizing the sentences...")
        inputs = tokenizer(sentences, return_tensors="pt", padding=True)
        inputs_id = inputs["input_ids"]
        attn_mask = inputs["attention_mask"]

        ls_generate(ls_model, inputs_id, attn_mask)
        hf_generate(hf_model, inputs_id, attn_mask)

        if not args.user_input:
            break


if __name__ == "__main__":
    main()
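
To run the example end to end, first export the weights with python hf_bert_export.py (which writes lightseq_bert_base_uncased.hdf5 into the working directory), then run python ls_bert.py, or python ls_bert.py --user_input to classify your own sentences. Note that bert-base-uncased carries no fine-tuned classification head (BertForSequenceClassification initializes one randomly), so the predicted classes only serve to check that the lightseq and huggingface paths agree, not as meaningful labels.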
7 changes: 7 additions & 0 deletions lightseq/inference/model/CMakeLists.txt
@@ -16,3 +16,10 @@ target_link_libraries(gpt_model PUBLIC gpt_weight)
target_link_libraries(gpt_model PUBLIC CUDA::cublas_static CUDA::cublasLt_static)
target_include_directories(gpt_model INTERFACE ${CMAKE_CURRENT_SOURCE_DIR})
set_target_properties(gpt_model PROPERTIES CUDA_SEPARABLE_COMPILATION ON)

add_library(bert_model STATIC bert_encoder.cc.cu)
target_link_libraries(bert_model PUBLIC cuda_kernels)
target_link_libraries(bert_model PUBLIC bert_weight)
target_link_libraries(bert_model PUBLIC CUDA::cublas_static
                                        CUDA::cublasLt_static)
set_target_properties(bert_model PROPERTIES CUDA_SEPARABLE_COMPILATION ON)