Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/transformers/utils/auto_docstring.py
Original file line number Diff line number Diff line change
Expand Up @@ -1098,7 +1098,7 @@ def contains_type(type_hint, target_type) -> tuple[bool, Optional[object]]:
if args == ():
try:
return issubclass(type_hint, target_type), type_hint
except Exception as _:
except Exception:
return issubclass(type(type_hint), target_type), type_hint
found_type_tuple = [contains_type(arg, target_type)[0] for arg in args]
found_type = any(found_type_tuple)
Expand All @@ -1112,6 +1112,8 @@ def get_model_name(obj):
Get the model name from the file path of the object.
"""
path = inspect.getsourcefile(obj)
if path is None:
return None
if path.split(os.path.sep)[-3] != "models":
return None
file_name = path.split(os.path.sep)[-1]
Expand Down Expand Up @@ -1783,9 +1785,10 @@ def auto_class_docstring(cls, custom_intro=None, custom_args=None, checkpoint=No

is_dataclass = False
docstring_init = ""
docstring_args = ""
if "PreTrainedModel" in (x.__name__ for x in cls.__mro__):
docstring_init = auto_method_docstring(
cls.__init__, parent_class=cls, custom_args=custom_args
cls.__init__, parent_class=cls, custom_args=custom_args, checkpoint=checkpoint
).__doc__.replace("Args:", "Parameters:")
elif "ModelOutput" in (x.__name__ for x in cls.__mro__):
# We have a data class
Expand All @@ -1797,6 +1800,7 @@ def auto_class_docstring(cls, custom_intro=None, custom_args=None, checkpoint=No
cls.__init__,
parent_class=cls,
custom_args=custom_args,
checkpoint=checkpoint,
source_args_dict=get_args_doc_from_source(ModelOutputArgs),
).__doc__
indent_level = get_indent_level(cls)
Expand Down Expand Up @@ -1836,7 +1840,7 @@ def auto_class_docstring(cls, custom_intro=None, custom_args=None, checkpoint=No
docstring += docstring_args if docstring_args else "\nArgs:\n"
source_args_dict = get_args_doc_from_source(ModelOutputArgs)
doc_class = cls.__doc__ if cls.__doc__ else ""
documented_kwargs, _ = parse_docstring(doc_class)
documented_kwargs = parse_docstring(doc_class)[0]
for param_name, param_type_annotation in cls.__annotations__.items():
param_type = str(param_type_annotation)
optional = False
Expand Down