From 87d3a5c72b03284c471bbdb2f492094df68d4f8a Mon Sep 17 00:00:00 2001 From: Yohei Tamura Date: Tue, 25 Aug 2020 17:57:08 +0900 Subject: [PATCH] Add typing.overload for convert_ids_tokens (#6637) * add overload for type checker * black --- src/transformers/tokenization_utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index 3121980c0d5bbe..36bec61d01f332 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -20,7 +20,7 @@ import logging import re import unicodedata -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, overload from .file_utils import add_end_docstrings from .tokenization_utils_base import ( @@ -657,6 +657,14 @@ def get_special_tokens_mask( """ return [0] * ((len(token_ids_1) if token_ids_1 else 0) + len(token_ids_0)) + @overload + def convert_ids_to_tokens(self, ids: int, skip_special_tokens: bool = False) -> str: + ... + + @overload + def convert_ids_to_tokens(self, ids: List[int], skip_special_tokens: bool = False) -> List[str]: + ... + def convert_ids_to_tokens( self, ids: Union[int, List[int]], skip_special_tokens: bool = False ) -> Union[str, List[str]]: