diff --git a/engines/python/setup/djl_python/deepspeed.py b/engines/python/setup/djl_python/deepspeed.py
index da6e68b32..e28d0a777 100644
--- a/engines/python/setup/djl_python/deepspeed.py
+++ b/engines/python/setup/djl_python/deepspeed.py
@@ -330,7 +330,7 @@ def inference(self, inputs: Input):
             if self.task == "text-generation":
                 tokenized_inputs = self.tokenizer(
                     input_data, padding=True,
-                    return_tensors="pt").to(torch.cuda.current_device())
+                    return_tensors="pt").to(self.device)
                 with torch.no_grad():
                     output_tokens = self.model.generate(
                         input_ids=tokenized_inputs.input_ids,
diff --git a/engines/python/setup/djl_python/huggingface.py b/engines/python/setup/djl_python/huggingface.py
index 3d307aab5..99f89db80 100644
--- a/engines/python/setup/djl_python/huggingface.py
+++ b/engines/python/setup/djl_python/huggingface.py
@@ -281,7 +281,7 @@ def wrapped_pipeline(inputs, *args, **kwargs):
             tokenizer = hf_pipeline.tokenizer
             input_tokens = tokenizer(inputs, padding=True, return_tensors="pt")
             if self.device:
-                input_tokens.to(self.device)
+                input_tokens = input_tokens.to(self.device)
 
             with torch.no_grad():
                 output_tokens = model.generate(
diff --git a/engines/python/setup/djl_python/streaming_utils.py b/engines/python/setup/djl_python/streaming_utils.py
index c5e71960f..87241dffb 100644
--- a/engines/python/setup/djl_python/streaming_utils.py
+++ b/engines/python/setup/djl_python/streaming_utils.py
@@ -71,8 +71,8 @@ def use_hf_default_streamer(model, tokenizer, inputs, device, **kwargs):
         if not tokenizer.pad_token:
             tokenizer.pad_token = tokenizer.eos_token
         input_tokens = tokenizer(inputs, padding=True, return_tensors="pt")
-        if device:
-            input_tokens.to(device)
+        if device is not None:
+            input_tokens = input_tokens.to(device)
 
         streamer = HFStreamer(tokenizer, skip_special_token=True)
         generation_kwargs = dict(input_tokens, streamer=streamer, **kwargs)
@@ -117,8 +117,8 @@ def _hf_model_stream_generator(model, tokenizer, inputs, device, **kwargs):
                                     StreamingUtils.DEFAULT_MAX_NEW_TOKENS)
         tokenized_inputs = tokenizer(inputs, return_tensors="pt", padding=True)
         input_ids = tokenized_inputs["input_ids"]
-        if device:
-            input_ids.to(device)
+        if device is not None:
+            input_ids = input_ids.to(device)
 
         past_key_values = None
         decoding_method = StreamingUtils._get_decoding_method(**kwargs)
@@ -144,18 +144,19 @@ def _hf_model_stream_generator(model, tokenizer, inputs, device, **kwargs):
                            input_length] = 1 if is_pad_token_equal_to_eos_token else tokenized_inputs[
                                "attention_mask"]
             curr_length = input_length
-
-        if generic_model_class == "Seq2SeqLM":
+        elif generic_model_class == "Seq2SeqLM":
             attention_mask = tokenized_inputs["attention_mask"]
-            if device:
-                attention_mask.to(device)
-
             decoder_attention_mask = None
             encoder_last_hidden_state = None
             decoder_input_ids = torch.tensor(tokenizer.bos_token_id,
                                              device=device).repeat(
                                                  len(inputs)).view(-1, 1)
             all_decoder_input_ids = decoder_input_ids
+        else:
+            raise ValueError(f"Unsupported model class: {generic_model_class}")
+
+        if device is not None:
+            attention_mask = attention_mask.to(device)
 
         while True:
             if stop_generation:
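
Note (not part of the patch): the reassignments above matter because torch.Tensor.to(device) is not an in-place operation; it returns a copy on the target device, so calls like input_ids.to(device) whose result is discarded leave generation running on the original CPU tensors. The switch from `if device:` to `if device is not None:` also keeps the move from being skipped when the device is the integer 0 (the first GPU), which Python treats as falsy. A minimal illustration, assuming only that PyTorch is installed:

    import torch

    # Pick a GPU if one is available, otherwise stay on CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    x = torch.zeros(2, 3)
    x.to(device)        # returns a moved copy; x itself is unchanged
    print(x.device)     # still cpu

    x = x.to(device)    # reassign so downstream code uses the moved tensor
    print(x.device)     # cuda:0 when a GPU is available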