I want to ask: is there something wrong in my code for running these frameworks? :pleading_face::pleading_face::pleading_face:
This is the result reported in the DeepSpeed-Inference paper, showing DeepSpeed consistently faster than FasterTransformer:
But this is my result:
My code:
I installed DeepSpeed with this command:
```
DS_BUILD_TRANSFORMER_INFERENCE=1 pip install deepspeed --no-cache-dir && pip cache purge
```
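To double-check that the transformer-inference kernels really were prebuilt (rather than silently falling back to JIT compilation on first use), I believe the `ds_report` utility that ships with DeepSpeed can be used; it lists which C++/CUDA ops are installed and compatible:

```
# Should show transformer_inference among the installed/compatible ops.
ds_report
```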
And this is my code to benchmark HuggingFace, DeepSpeed, and FasterTransformer:
```python
import json
import time

import torch
import transformers
import deepspeed

from examples.pytorch.gpt.utils.parallel_gpt import ParallelGPT
import examples.pytorch.gpt.utils.gpt_token_encoder as encoder


class __huggingface_generate_model__(object):
    def __init__(self, model_path, tokenizer_path, use_deepspeed=False) -> None:
        model = transformers.AutoModelForCausalLM.from_pretrained(model_path)
        if use_deepspeed:
            ds_engine = deepspeed.init_inference(model, mp_size=1, dtype=torch.half,
                                                 replace_with_kernel_inject=True)
            model = ds_engine.module
        else:
            model = model.half()
            model = model.to(device="cuda")
        # Warmup generate so one-time initialization is not timed later.
        model.generate(input_ids=torch.IntTensor([[128]]).to(device=model.device),
                       max_length=2, min_length=2, pad_token_id=0)
        self.tokenizer = transformers.PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
        self.model = model
        self.encode = self.tokenizer.encode
        self.decode = self.tokenizer.decode
        self.device = self.model.device

    def generate(self, input_ids: torch.IntTensor, beam_width, output_length, prompt_length):
        num_beams = beam_width
        num_return_sequences = beam_width
        max_length = output_length + prompt_length
        min_length = output_length + prompt_length
        input_ids = input_ids.to(device=self.device)
        return self.model.generate(input_ids=input_ids, do_sample=True, num_beams=num_beams,
                                   max_length=max_length, min_length=min_length,
                                   num_return_sequences=num_return_sequences, pad_token_id=0)


class __fastertransformer_generate_model__(object):
    def __init__(self, model_path, tokenizer_path, config_path) -> None:
        self.tokenizer = transformers.PreTrainedTokenizerFast(tokenizer_file=tokenizer_path)
        self.encode = self.tokenizer.encode
        self.decode = self.tokenizer.decode
        # Read the GPT-2 config and derive the FT model hyperparameters from it.
        with open(config_path) as config_file:
            config_dict = json.load(config_file)
        assert config_dict["n_embd"] % config_dict["n_head"] == 0
        head_num = config_dict["n_head"]
        size_per_head = config_dict["n_embd"] // config_dict["n_head"]
        vocab_size = config_dict["vocab_size"]
        start_id, end_id = 0, 0
        layer_num, max_seq_len = config_dict["n_layer"], config_dict["n_positions"]
        tensor_para_size, pipeline_para_size = 1, 1
        lib_path = "/workspace/FasterTransformerCustom/fastertransformer/a04b6ce/FasterTransformer/release/lib/libth_transformer.so"
        inference_data_type, weights_data_type = "fp16", "fp16"
        self.model = ParallelGPT(head_num, size_per_head, vocab_size, start_id, end_id,
                                 layer_num, max_seq_len, tensor_para_size, pipeline_para_size,
                                 lib_path=lib_path, inference_data_type=inference_data_type,
                                 weights_data_type=weights_data_type, shared_contexts_ratio=0)
        assert self.model.load(ckpt_path=model_path)
        self.device = self.model.device

    def generate(self, input_ids: torch.IntTensor, beam_width, output_length, prompt_length):
        input_ids = input_ids.to(device=self.device)
        with torch.no_grad():
            # Generate tokens.
            input_lengths = torch.IntTensor([prompt_length] * input_ids.size()[0])
            tokens_batch = self.model(start_ids=input_ids, start_lengths=input_lengths,
                                      output_len=output_length, beam_width=beam_width)
        return tokens_batch


if __name__ == "__main__":
    model1 = __huggingface_generate_model__(
        # a gpt2 model with 255m parameters
        use_deepspeed=True
    )
    model2 = __fastertransformer_generate_model__(
        # the same gpt2 model
    )

    text = "public Date getGmtCreate ( ) { return gmtCreate ; } public void setGmtCreate ( Date gmtCreate ) {"

    # HuggingFace + DeepSpeed
    input_ids = model1.encode(text)
    start_time = time.time()
    result = model1.generate(input_ids=torch.IntTensor([input_ids]), beam_width=4,
                             output_length=32, prompt_length=len(input_ids))
    latency = time.time() - start_time
    result = result.detach().cpu().tolist()
    for output_ids in result:
        print(model1.decode(output_ids[len(input_ids):]))
    print(latency)

    # FasterTransformer
    input_ids = model2.encode(text)
    start_time = time.time()
    result = model2.generate(input_ids=torch.IntTensor([input_ids]), beam_width=4,
                             output_length=32, prompt_length=len(input_ids))
    latency = time.time() - start_time
    result = result.detach().cpu().tolist()[0]
    for output_ids in result:
        print(model2.decode(output_ids[len(input_ids):]))
    print(latency)
```
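One caveat I am aware of: in HuggingFace, `generate` with `do_sample=True` and `num_beams > 1` performs beam *sampling*, while FasterTransformer's beam search is deterministic, so the decoding strategies are not identical. A hypothetical closer-to-FT variant on the HF side (reusing `model1` and `input_ids` from the script above) would just pass `do_sample=False`:

```python
# Hypothetical apples-to-apples variant: do_sample=False gives deterministic
# beam search, which should be closer to FT's decoding than beam sampling.
prompt_len = len(input_ids)
result = model1.model.generate(
    input_ids=torch.IntTensor([input_ids]).to(model1.device),
    do_sample=False,
    num_beams=4,
    num_return_sequences=4,
    max_length=prompt_len + 32,
    min_length=prompt_len + 32,
    pad_token_id=0,
)
```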
The code's output. We can see that FasterTransformer was faster than DeepSpeed (the differing generations are probably due to the different beam-search strategies noted above):
```
[2023-03-06 12:58:15,476] [INFO] [logging.py:75:log_dist] [Rank -1] DeepSpeed info: version=0.8.1, git-hash=unknown, git-branch=unknown
[2023-03-06 12:58:15,487] [WARNING] [config_utils.py:74:_process_deprecated_field] Config parameter mp_size is deprecated use tensor_parallel.tp_size instead
[2023-03-06 12:58:15,488] [INFO] [logging.py:75:log_dist] [Rank -1] quantize_bits = 8 mlp_extra_grouping = False, quantize_groups = 1
[2023-03-06 12:58:19,527] [INFO] [logging.py:75:log_dist] [Rank -1] DeepSpeed-Inference config: {'layer_id': 0, 'hidden_size': 1024, 'intermediate_size': 4096, 'heads': 16, 'num_hidden_layers': -1, 'fp16': True, 'pre_layer_norm': True, 'local_rank': -1, 'stochastic_mode': False, 'epsilon': 1e-05, 'mp_size': 1, 'q_int8': False, 'scale_attention': True, 'triangular_masking': True, 'local_attention': False, 'window_size': 1, 'rotary_dim': -1, 'rotate_half': False, 'rotate_every_two': True, 'return_tuple': True, 'mlp_after_attn': True, 'mlp_act_func_type': <ActivationFuncType.GELU: 1>, 'specialized_mode': False, 'training_mp_size': 1, 'bigscience_bloom': False, 'max_out_tokens': 1024, 'scale_attn_by_inverse_layer_idx': False, 'enable_qkv_quantization': False, 'use_mup': False, 'return_single_tuple': False}
------------------------------------------------------
Free memory : 21.991699 (GigaBytes)
Total memory: 23.650208 (GigaBytes)
Requested memory: 0.156250 (GigaBytes)
Setting maximum total tokens (input + output) to 1024
------------------------------------------------------
[WARNING] gemm_config.in is not found; using default GEMM algo
[FT][WARNING] Skip NCCL initialization since requested tensor/pipeline parallel sizes are equals to 1.
[FT][INFO] Device NVIDIA TITAN RTX
this. gmtCreate = gmtCreate ; } public Date getGmtModified ( ) { return gmtModified ; } public void setGmtModified ( Date gmtModified ) { this. gmtModified = gmtModified ; }
this. gmtCreate = gmtCreate ; } public Date getGmtModified ( ) { return gmtModified ; } public void setGmtModified ( Date gmtModified ) { this. gmtModified = gmtModified ; }
this. gmtCreate = gmtCreate ; } public Date getGmtModified ( ) { return gmtModified ; } public void setGmtModified ( Date gmtModified ) { this. gmtModified = gmtModified ; }
this. gmtCreate = gmtCreate ; } public Date getGmtModified ( ) { return gmtModified ; } public void setGmtModified ( Date gmtModified ) { this. gmtModified = gmtModified ; }
0.47239160537719727
this. gmtCreate = gmtCreate ; } public Date getGmtModified ( ) { return gmtModified ; } public void setGmtModified ( Date gmtModified ) { this. gmtModified = gmtModified ; }
this. gmtCreate = gmtCreate ; } public Date getGmtModify ( ) { return gmtModify ; } public void setGmtModify ( Date gmtModify ) { this. gmtModify = gmtModify
this. gmtCreate = gmtCreate ; } }<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
this. gmtCreate = gmtCreate ; } @ Override public String toString ( ) { return ToStringBuilder. reflectionToString ( this, ToStringStyle. SHORT_PREFIX_STYLE ) ; }
0.07379007339477539
```
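One more note on my measurement methodology: CUDA kernel launches are asynchronous, so timing a single `generate` call with `time.time()` and no explicit synchronization can misreport latency, and a single run is noisy. A minimal sketch of a more careful harness (`timed_generate` is a hypothetical helper I wrote for this issue, not part of either library; it assumes the wrapper objects defined above and a CUDA device):

```python
import time
import torch

def timed_generate(model, input_ids, n_warmup=3, n_runs=10, **gen_kwargs):
    # Warmup runs so one-time setup (kernel compilation, allocator growth)
    # is not counted in the measurement.
    for _ in range(n_warmup):
        model.generate(input_ids=input_ids, **gen_kwargs)
    torch.cuda.synchronize()  # drain queued GPU work before starting the clock
    start = time.time()
    for _ in range(n_runs):
        model.generate(input_ids=input_ids, **gen_kwargs)
    torch.cuda.synchronize()  # wait for the timed GPU work to finish
    return (time.time() - start) / n_runs

# Example usage with the wrappers above:
# avg = timed_generate(model1, torch.IntTensor([input_ids]), beam_width=4,
#                      output_length=32, prompt_length=len(input_ids))
```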