Merged
12 changes: 6 additions & 6 deletions src/transformers/cli/add_new_model_like.py
@@ -19,7 +19,7 @@
from collections.abc import Callable
from datetime import date
from pathlib import Path
-from typing import Annotated, Any, Optional, Union
+from typing import Annotated, Any

import typer

@@ -95,7 +95,7 @@ def visit_SimpleStatementLine(self, node: cst.SimpleStatementLine):

def add_new_model_like(
repo_path: Annotated[
-Optional[str], typer.Argument(help="When not using an editable install, the path to the Transformers repo.")
+str | None, typer.Argument(help="When not using an editable install, the path to the Transformers repo.")
] = None,
):
"""
@@ -156,7 +156,7 @@ def __init__(self, lowercase_name: str):
self.processor_class = PROCESSOR_MAPPING_NAMES.get(self.lowercase_name, None)


-def add_content_to_file(file_name: Union[str, os.PathLike], new_content: str, add_after: str):
+def add_content_to_file(file_name: str | os.PathLike, new_content: str, add_after: str):
"""
A utility to add some content inside a given file.
@@ -614,9 +614,9 @@ def _add_new_model_like_internal(

def get_user_field(
question: str,
-default_value: Optional[str] = None,
-convert_to: Optional[Callable] = None,
-fallback_message: Optional[str] = None,
+default_value: str | None = None,
+convert_to: Callable | None = None,
+fallback_message: str | None = None,
) -> Any:
"""
A utility function that asks a question to the user to get an answer, potentially looping until it gets a valid
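Every hunk in this PR follows the same PEP 604 migration: `Optional[X]` becomes `X | None` and `Union[X, Y]` becomes `X | Y`. A minimal sketch of the runtime equivalence, assuming Python 3.10+ (where the `|` operator on types is available):

    import os
    from typing import Optional, Union

    # PEP 604 unions compare equal to the typing constructs they replace,
    # so the rewritten annotations are identical in meaning.
    assert str | None == Optional[str]
    assert str | os.PathLike == Union[str, os.PathLike]
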
18 changes: 9 additions & 9 deletions src/transformers/cli/chat.py
@@ -19,7 +19,7 @@
import string
import time
from collections.abc import AsyncIterator
-from typing import Annotated, Optional
+from typing import Annotated

import click
import typer
@@ -214,7 +214,7 @@ def __init__(
base_url: Annotated[str, typer.Argument(help="Base url to connect to (e.g. http://localhost:8000/v1).")],
model_id: Annotated[str, typer.Argument(help="ID of the model to use (e.g. 'HuggingFaceTB/SmolLM3-3B').")],
generate_flags: Annotated[
-Optional[list[str]],
+list[str] | None,
typer.Argument(
help=(
"Flags to pass to `generate`, using a space as a separator between flags. Accepts booleans, numbers, "
@@ -227,15 +227,15 @@
] = None,
# General settings
user: Annotated[
-Optional[str],
+str | None,
typer.Option(help="Username to display in chat interface. Defaults to the current user's name."),
] = None,
-system_prompt: Annotated[Optional[str], typer.Option(help="System prompt.")] = None,
+system_prompt: Annotated[str | None, typer.Option(help="System prompt.")] = None,
save_folder: Annotated[str, typer.Option(help="Folder to save chat history.")] = "./chat_history/",
-examples_path: Annotated[Optional[str], typer.Option(help="Path to a yaml file with examples.")] = None,
+examples_path: Annotated[str | None, typer.Option(help="Path to a yaml file with examples.")] = None,
# Generation settings
generation_config: Annotated[
-Optional[str],
+str | None,
typer.Option(
help="Path to a local generation config file or to a HuggingFace repo containing a `generation_config.json` file. Other generation settings passed as CLI arguments will be applied on top of this generation config."
),
@@ -455,7 +455,7 @@ async def _inner_run(self):
break


-def load_generation_config(generation_config: Optional[str]) -> GenerationConfig:
+def load_generation_config(generation_config: str | None) -> GenerationConfig:
if generation_config is None:
return GenerationConfig()

@@ -467,7 +467,7 @@ def load_generation_config(generation_config: Optional[str]) -> GenerationConfig
return GenerationConfig.from_pretrained(generation_config)


-def parse_generate_flags(generate_flags: Optional[list[str]]) -> dict:
+def parse_generate_flags(generate_flags: list[str] | None) -> dict:
"""Parses the generate flags from the user input into a dictionary of `generate` kwargs."""
if generate_flags is None or len(generate_flags) == 0:
return {}
@@ -521,7 +521,7 @@ def is_number(s: str) -> bool:
return processed_generate_flags


-def new_chat_history(system_prompt: Optional[str] = None) -> list[dict]:
+def new_chat_history(system_prompt: str | None = None) -> list[dict]:
"""Returns a new chat conversation."""
return [{"role": "system", "content": system_prompt}] if system_prompt else []

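Typer inspects these annotations at runtime to build the CLI, so the bar syntax only works on interpreters and Typer versions that accept PEP 604 unions; this PR implies both hold for the repo. A minimal sketch of the option pattern used throughout chat.py, with a hypothetical `greet` command that is not part of the diff:

    from typing import Annotated

    import typer

    app = typer.Typer()

    @app.command()
    def greet(
        # Same shape as the options above: an optional string spelled with
        # PEP 604 syntax inside Annotated, defaulting to None.
        name: Annotated[str | None, typer.Option(help="Name to greet.")] = None,
    ) -> None:
        print(f"Hello, {name or 'world'}!")

    if __name__ == "__main__":
        app()
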
4 changes: 2 additions & 2 deletions src/transformers/cli/download.py
@@ -11,14 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Annotated, Optional
+from typing import Annotated

import typer


def download(
model_id: Annotated[str, typer.Argument(help="The model ID to download")],
-cache_dir: Annotated[Optional[str], typer.Option(help="Directory where to save files.")] = None,
+cache_dir: Annotated[str | None, typer.Option(help="Directory where to save files.")] = None,
force_download: Annotated[
bool, typer.Option(help="If set, the files will be downloaded even if they are already cached locally.")
] = False,
14 changes: 7 additions & 7 deletions src/transformers/cli/run.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
-from typing import Annotated, Optional
+from typing import Annotated

import typer

@@ -28,27 +28,27 @@

def run(
task: Annotated[TaskEnum, typer.Argument(help="Task to run", case_sensitive=False)], # type: ignore
-input: Annotated[Optional[str], typer.Option(help="Path to the file to use for inference")] = None,
+input: Annotated[str | None, typer.Option(help="Path to the file to use for inference")] = None,
output: Annotated[
Optional[str], typer.Option(help="Path to the file that will be used post to write results.")
str | None, typer.Option(help="Path to the file that will be used post to write results.")
] = None,
model: Annotated[
-Optional[str],
+str | None,
typer.Option(
help="Name or path to the model to instantiate. If not provided, will use the default model for that task."
),
] = None,
config: Annotated[
-Optional[str],
+str | None,
typer.Option(
help="Name or path to the model's config to instantiate. If not provided, will use the model's one."
),
] = None,
tokenizer: Annotated[
Optional[str], typer.Option(help="Name of the tokenizer to use. If not provided, will use the model's one.")
str | None, typer.Option(help="Name of the tokenizer to use. If not provided, will use the model's one.")
] = None,
column: Annotated[
-Optional[str],
+str | None,
typer.Option(help="Name of the column to use as input. For multi columns input use 'column1,columns2'"),
] = None,
format: Annotated[FormatEnum, typer.Option(help="Input format to read from", case_sensitive=False)] = "pipe", # type: ignore
26 changes: 13 additions & 13 deletions src/transformers/cli/serve.py
@@ -305,7 +305,7 @@ def __init__(
self,
model: "PreTrainedModel",
timeout_seconds: int,
-processor: Optional[Union["ProcessorMixin", "PreTrainedTokenizerFast"]] = None,
+processor: Union["ProcessorMixin", "PreTrainedTokenizerFast"] | None = None,
):
self.model = model
self._name_or_path = str(model.name_or_path)
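One detail in the hunk above: the inner `Union["ProcessorMixin", "PreTrainedTokenizerFast"]` is kept rather than rewritten with a bar, because its members are string forward references and `str` objects do not implement `|` for building annotations at runtime; only the outer `Optional` collapses to `| None`. A minimal sketch of that failure mode, assuming stand-in classes for the real ones:

    from typing import Union

    class ProcessorMixin: ...
    class PreTrainedTokenizerFast: ...

    # Fine: a typing.Union object supports the bar operator, even when its
    # members are still unevaluated forward references.
    ok = Union["ProcessorMixin", "PreTrainedTokenizerFast"] | None

    # Broken: evaluating this raises
    # TypeError: unsupported operand type(s) for |: 'str' and 'str'
    # bad = "ProcessorMixin" | "PreTrainedTokenizerFast"
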
@@ -363,7 +363,7 @@ def __init__(
),
] = "auto",
dtype: Annotated[
-Optional[str],
+str | None,
typer.Option(
help="Override the default `torch.dtype` and load the model under this dtype. If `'auto'` is passed, the dtype will be automatically derived from the model's weights."
),
@@ -372,7 +372,7 @@
bool, typer.Option(help="Whether to trust remote code when loading a model.")
] = False,
attn_implementation: Annotated[
-Optional[str],
+str | None,
typer.Option(
help="Which attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`."
),
@@ -390,7 +390,7 @@
str, typer.Option(help="Logging level as a string. Example: 'info' or 'warning'.")
] = "info",
default_seed: Annotated[
-Optional[int], typer.Option(help="The default seed for torch, should be an integer.")
+int | None, typer.Option(help="The default seed for torch, should be an integer.")
] = None,
enable_cors: Annotated[
bool,
@@ -400,7 +400,7 @@
] = False,
input_validation: Annotated[bool, typer.Option(help="Whether to turn on strict input validation.")] = False,
force_model: Annotated[
-Optional[str],
+str | None,
typer.Option(
help="Name of the model to be forced on all requests. This is useful for testing Apps that don't allow changing models in the request."
),
@@ -445,7 +445,7 @@ def __init__(
# Internal state:
# 1. Tracks models in memory, to prevent reloading the model unnecessarily
self.loaded_models: dict[str, TimedModel] = {}
-self.running_continuous_batching_manager: Optional[ContinuousBatchingManager] = None
+self.running_continuous_batching_manager: ContinuousBatchingManager | None = None

# 2. preserves information about the last call and last KV cache, to determine whether we can reuse the KV
# cache and avoid re-running prefill
@@ -648,13 +648,13 @@ def validate_transcription_request(self, request: dict):
def build_chat_completion_chunk(
self,
request_id: str = "",
-content: Optional[int] = None,
-model: Optional[str] = None,
-role: Optional[str] = None,
-finish_reason: Optional[str] = None,
-tool_calls: Optional[list["ChoiceDeltaToolCall"]] = None,
-decode_stream: Optional[DecodeStream] = None,
-tokenizer: Optional[PreTrainedTokenizerFast] = None,
+content: int | None = None,
+model: str | None = None,
+role: str | None = None,
+finish_reason: str | None = None,
+tool_calls: list["ChoiceDeltaToolCall"] | None = None,
+decode_stream: DecodeStream | None = None,
+tokenizer: PreTrainedTokenizerFast | None = None,
) -> ChatCompletionChunk:
"""
Builds a chunk of a streaming OpenAI Chat Completion response.
4 changes: 2 additions & 2 deletions src/transformers/cli/system.py
@@ -22,7 +22,7 @@
import io
import os
import platform
-from typing import Annotated, Optional
+from typing import Annotated

import huggingface_hub
import typer
@@ -40,7 +40,7 @@

def env(
accelerate_config_file: Annotated[
-Optional[str],
+str | None,
typer.Argument(help="The accelerate config file to use for the default values in the launching script."),
] = None,
) -> None:
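A closing note on runtime requirements, not stated in the diff itself: because Typer evaluates these annotations when it builds the CLI, the bar syntax needs an interpreter where `|` works on types, i.e. Python 3.10+ (presumably the repo's minimum by the time of this change; stringifying the annotations via `from __future__ import annotations` would not help, since they are still evaluated during CLI construction). A quick sanity check:

    import sys

    # PEP 604 unions are built eagerly at import time, so modules using
    # them refuse to load on 3.9 and earlier.
    assert sys.version_info >= (3, 10), "str | None needs Python 3.10+"

    MaybeStr = str | None  # raises TypeError on older interpreters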