Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/transformers/commands/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,14 @@ class ChatArguments:
def __post_init__(self):
"""Only used for BC `torch_dtype` argument."""
# In this case only the BC torch_dtype was given
if self.torch_dtype is not None and self.dtype == "auto":
self.dtype = self.torch_dtype
if self.torch_dtype is not None:
if self.dtype is None:
self.dtype = self.torch_dtype
elif self.torch_dtype != self.dtype:
raise ValueError(
f"`torch_dtype` {self.torch_dtype} and `dtype` {self.dtype} have different values. `torch_dtype` is deprecated and "
"will be removed in 4.59.0, please set `dtype` instead."
)


def chat_command_factory(args: Namespace):
Expand Down
10 changes: 8 additions & 2 deletions src/transformers/commands/serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,8 +457,14 @@ class ServeArguments:
def __post_init__(self):
"""Only used for BC `torch_dtype` argument."""
# In this case only the BC torch_dtype was given
if self.torch_dtype is not None and self.dtype == "auto":
self.dtype = self.torch_dtype
if self.torch_dtype is not None:
if self.dtype is None:
self.dtype = self.torch_dtype
elif self.torch_dtype != self.dtype:
raise ValueError(
f"`torch_dtype` {self.torch_dtype} and `dtype` {self.dtype} have different values. `torch_dtype` is deprecated and "
"will be removed in 4.59.0, please set `dtype` instead."
)


class ServeCommand(BaseTransformersCLICommand):
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/pipelines/keypoint_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def __call__(
def preprocess(self, images, timeout=None):
images = [load_image(image, timeout=timeout) for image in images]
model_inputs = self.image_processor(images=images, return_tensors=self.framework)
model_inputs = model_inputs.to(self.torch_dtype)
model_inputs = model_inputs.to(self.dtype)
target_sizes = [image.size for image in images]
preprocess_outputs = {"model_inputs": model_inputs, "target_sizes": target_sizes}
return preprocess_outputs
Expand Down