diff --git a/src/transformers/utils/auto_docstring.py b/src/transformers/utils/auto_docstring.py index c259d2035573..a9d9a8cba788 100644 --- a/src/transformers/utils/auto_docstring.py +++ b/src/transformers/utils/auto_docstring.py @@ -1098,7 +1098,7 @@ def contains_type(type_hint, target_type) -> tuple[bool, Optional[object]]: if args == (): try: return issubclass(type_hint, target_type), type_hint - except Exception as _: + except Exception: return issubclass(type(type_hint), target_type), type_hint found_type_tuple = [contains_type(arg, target_type)[0] for arg in args] found_type = any(found_type_tuple) @@ -1112,6 +1112,8 @@ def get_model_name(obj): Get the model name from the file path of the object. """ path = inspect.getsourcefile(obj) + if path is None: + return None if path.split(os.path.sep)[-3] != "models": return None file_name = path.split(os.path.sep)[-1] @@ -1783,9 +1785,10 @@ def auto_class_docstring(cls, custom_intro=None, custom_args=None, checkpoint=No is_dataclass = False docstring_init = "" + docstring_args = "" if "PreTrainedModel" in (x.__name__ for x in cls.__mro__): docstring_init = auto_method_docstring( - cls.__init__, parent_class=cls, custom_args=custom_args + cls.__init__, parent_class=cls, custom_args=custom_args, checkpoint=checkpoint ).__doc__.replace("Args:", "Parameters:") elif "ModelOutput" in (x.__name__ for x in cls.__mro__): # We have a data class @@ -1797,6 +1800,7 @@ def auto_class_docstring(cls, custom_intro=None, custom_args=None, checkpoint=No cls.__init__, parent_class=cls, custom_args=custom_args, + checkpoint=checkpoint, source_args_dict=get_args_doc_from_source(ModelOutputArgs), ).__doc__ indent_level = get_indent_level(cls) @@ -1836,7 +1840,7 @@ def auto_class_docstring(cls, custom_intro=None, custom_args=None, checkpoint=No docstring += docstring_args if docstring_args else "\nArgs:\n" source_args_dict = get_args_doc_from_source(ModelOutputArgs) doc_class = cls.__doc__ if cls.__doc__ else "" - documented_kwargs, _ = parse_docstring(doc_class) + documented_kwargs = parse_docstring(doc_class)[0] for param_name, param_type_annotation in cls.__annotations__.items(): param_type = str(param_type_annotation) optional = False