[python] Fixes streaming token device mismatch bug (#822)
frankfliu committed Jun 10, 2023
1 parent e91eae9 commit eabaeab
Showing 3 changed files with 12 additions and 11 deletions.
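Most of these hunks fix one root cause: torch.Tensor.to() is not an in-place operation. It returns a new tensor on the target device and leaves the original where it was, so calling it without reassigning the result silently leaves the inputs on the CPU while the model sits on the GPU, which is the streaming device mismatch in the title. A minimal sketch of the failure mode (assumes a CUDA device is available; names are illustrative):

import torch

t = torch.ones(2, 3)
t.to("cuda:0")      # Tensor.to returns a moved copy; here it is discarded
print(t.device)     # still "cpu": the exact bug pattern this commit fixes
t = t.to("cuda:0")  # reassigning keeps the moved tensor
print(t.device)     # now "cuda:0"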
2 changes: 1 addition & 1 deletion engines/python/setup/djl_python/deepspeed.py
@@ -330,7 +330,7 @@ def inference(self, inputs: Input):
         if self.task == "text-generation":
             tokenized_inputs = self.tokenizer(
                 input_data, padding=True,
-                return_tensors="pt").to(torch.cuda.current_device())
+                return_tensors="pt").to(self.device)
             with torch.no_grad():
                 output_tokens = self.model.generate(
                     input_ids=tokenized_inputs.input_ids,
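Here the fix is a different one: torch.cuda.current_device() reads process-global CUDA state, while self.device is the device recorded when the model was loaded, so input batches always land where the model actually lives. A hedged sketch of that pattern (illustrative names, not the actual djl_python handler):

import torch

class Handler:
    def __init__(self, local_rank: int):
        # Pin the target device once at load time; unlike
        # torch.cuda.current_device(), it cannot drift between requests.
        self.device = torch.device(f"cuda:{local_rank}")

    def prepare(self, batch: dict) -> dict:
        # Reassign: Tensor.to() returns a copy on self.device.
        return {k: v.to(self.device) for k, v in batch.items()}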
2 changes: 1 addition & 1 deletion engines/python/setup/djl_python/huggingface.py
@@ -281,7 +281,7 @@ def wrapped_pipeline(inputs, *args, **kwargs):
         tokenizer = hf_pipeline.tokenizer
         input_tokens = tokenizer(inputs, padding=True, return_tensors="pt")
         if self.device:
-            input_tokens.to(self.device)
+            input_tokens = input_tokens.to(self.device)

         with torch.no_grad():
             output_tokens = model.generate(
19 changes: 10 additions & 9 deletions engines/python/setup/djl_python/streaming_utils.py
@@ -71,8 +71,8 @@ def use_hf_default_streamer(model, tokenizer, inputs, device, **kwargs):
         if not tokenizer.pad_token:
             tokenizer.pad_token = tokenizer.eos_token
         input_tokens = tokenizer(inputs, padding=True, return_tensors="pt")
-        if device:
-            input_tokens.to(device)
+        if device is not None:
+            input_tokens = input_tokens.to(device)

         streamer = HFStreamer(tokenizer, skip_special_token=True)
         generation_kwargs = dict(input_tokens, streamer=streamer, **kwargs)
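Note the guard change from "if device:" to "if device is not None:". If callers pass the device as an integer GPU index (an assumption about the call sites, but a common convention), index 0 is falsy, so the old guard would silently skip the move on the first GPU; the explicit None check only skips when no device was given at all:

device = 0                 # GPU index 0, a valid device specifier
if device:                 # falsy! the .to(device) call is skipped
    pass
if device is not None:     # correct: only None means "no device"
    pass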
@@ -117,8 +117,8 @@ def _hf_model_stream_generator(model, tokenizer, inputs, device, **kwargs):
             StreamingUtils.DEFAULT_MAX_NEW_TOKENS)
         tokenized_inputs = tokenizer(inputs, return_tensors="pt", padding=True)
         input_ids = tokenized_inputs["input_ids"]
-        if device:
-            input_ids.to(device)
+        if device is not None:
+            input_ids = input_ids.to(device)

         past_key_values = None
         decoding_method = StreamingUtils._get_decoding_method(**kwargs)
@@ -144,18 +144,19 @@ def _hf_model_stream_generator(model, tokenizer, inputs, device, **kwargs):
                 input_length] = 1 if is_pad_token_equal_to_eos_token else tokenized_inputs[
                     "attention_mask"]
             curr_length = input_length
-
-        if generic_model_class == "Seq2SeqLM":
+        elif generic_model_class == "Seq2SeqLM":
             attention_mask = tokenized_inputs["attention_mask"]
-            if device:
-                attention_mask.to(device)
-
             decoder_attention_mask = None
             encoder_last_hidden_state = None
             decoder_input_ids = torch.tensor(tokenizer.bos_token_id,
                                              device=device).repeat(
                                                  len(inputs)).view(-1, 1)
             all_decoder_input_ids = decoder_input_ids
+        else:
+            raise ValueError(f"Unsupported model class: {generic_model_class}")
+
+        if device is not None:
+            attention_mask = attention_mask.to(device)

         while True:
             if stop_generation:
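This last hunk also restructures the model-class dispatch: two independent if blocks become an if/elif/else chain, so an unrecognized generic_model_class now raises immediately instead of falling through with uninitialized state, and the attention_mask device move is hoisted out of the Seq2SeqLM branch so every supported class gets it. Reduced to a control-flow sketch (illustrative, not the full generator):

def prepare_masks(generic_model_class, tokenized_inputs, device):
    attention_mask = tokenized_inputs["attention_mask"]
    if generic_model_class == "CausalLM":
        ...  # causal-LM-specific setup elided
    elif generic_model_class == "Seq2SeqLM":
        ...  # seq2seq-specific setup elided
    else:
        # fail fast rather than streaming with missing state
        raise ValueError(f"Unsupported model class: {generic_model_class}")
    # shared step for every supported class; note the reassignment
    if device is not None:
        attention_mask = attention_mask.to(device)
    return attention_mask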
