diff --git a/openllm-python/src/openllm_cli/_factory.py b/openllm-python/src/openllm_cli/_factory.py index 40e06bf7e..fff432c6a 100644 --- a/openllm-python/src/openllm_cli/_factory.py +++ b/openllm-python/src/openllm_cli/_factory.py @@ -304,9 +304,9 @@ def machine_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[ def dtype_option(f: _AnyCallable | None = None, **attrs: t.Any) -> t.Callable[[FC], FC]: return cli_option( '--dtype', - type=click.Choice(['float16', 'float32', 'bfloat16']), + type=click.Choice(['float16', 'float32', 'bfloat16', 'auto']), envvar='TORCH_DTYPE', - default='float16', + default='auto', help='Optional dtype for casting tensors for running inference.', **attrs, )(f)