Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 28 additions & 11 deletions src/transformers/commands/add_new_model_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import transformers.models.auto as auto_module
from transformers.models.auto.configuration_auto import model_type_to_module_name

from ..file_utils import is_flax_available, is_tf_available, is_torch_available
from ..utils import logging
from . import BaseTransformersCLICommand

Expand Down Expand Up @@ -501,7 +502,7 @@ def filter_framework_files(
`List[Union[str, os.PathLike]]`: The list of filtered files.
"""
if frameworks is None:
return files
frameworks = get_default_frameworks()

framework_to_file = {}
others = []
Expand Down Expand Up @@ -598,6 +599,20 @@ def find_base_model_checkpoint(
return ""


def get_default_frameworks():
"""
Returns the list of frameworks (PyTorch, TensorFlow, Flax) that are installed in the environment.
"""
frameworks = []
if is_torch_available():
frameworks.append("pt")
if is_tf_available():
frameworks.append("tf")
if is_flax_available():
frameworks.append("flax")
return frameworks


_re_model_mapping = re.compile("MODEL_([A-Z_]*)MAPPING_NAMES")


Expand All @@ -616,17 +631,19 @@ def retrieve_model_classes(model_type: str, frameworks: Optional[List[str]] = No
that framework as values.
"""
if frameworks is None:
frameworks = ["pt", "tf", "flax"]
frameworks = get_default_frameworks()

modules = {
"pt": auto_module.modeling_auto,
"tf": auto_module.modeling_tf_auto,
"flax": auto_module.modeling_flax_auto,
"pt": auto_module.modeling_auto if is_torch_available() else None,
"tf": auto_module.modeling_tf_auto if is_tf_available() else None,
"flax": auto_module.modeling_flax_auto if is_flax_available() else None,
}

model_classes = {}
for framework in frameworks:
new_model_classes = []
if modules[framework] is None:
raise ValueError(f"You selected {framework} in the frameworks, but it is not installed.")
model_mappings = [attr for attr in dir(modules[framework]) if _re_model_mapping.search(attr) is not None]
for model_mapping_name in model_mappings:
model_mapping = getattr(modules[framework], model_mapping_name)
Expand Down Expand Up @@ -683,9 +700,9 @@ def retrieve_info_for_model(model_type, frameworks: Optional[List[str]] = None):
available_frameworks.append("pt")

if frameworks is None:
frameworks = available_frameworks.copy()
else:
frameworks = [f for f in frameworks if f in available_frameworks]
frameworks = get_default_frameworks()

frameworks = [f for f in frameworks if f in available_frameworks]

model_classes = retrieve_model_classes(model_type, frameworks=frameworks)

Expand Down Expand Up @@ -738,7 +755,7 @@ def clean_frameworks_in_init(
Whether or not to keep the preprocessing (tokenizer, feature extractor, processor) imports in the init.
"""
if frameworks is None:
frameworks = ["pt", "tf", "flax"]
frameworks = get_default_frameworks()

names = {"pt": "torch"}
to_remove = [names.get(f, f) for f in ["pt", "tf", "flax"] if f not in frameworks]
Expand Down Expand Up @@ -1040,7 +1057,7 @@ def duplicate_doc_file(
content = f.read()

if frameworks is None:
frameworks = ["pt", "tf", "flax"]
frameworks = get_default_frameworks()
if dest_file is None:
dest_file = Path(doc_file).parent / f"{new_model_patterns.model_type}.mdx"

Expand Down Expand Up @@ -1302,7 +1319,7 @@ def __init__(self, config_file=None, path_to_repo=None, *args):
self.old_model_type = config["old_model_type"]
self.model_patterns = ModelPatterns(**config["new_model_patterns"])
self.add_copied_from = config.get("add_copied_from", True)
self.frameworks = config.get("frameworks", ["pt", "tf", "flax"])
self.frameworks = config.get("frameworks", get_default_frameworks())
self.old_checkpoint = config.get("old_checkpoint", None)
else:
(
Expand Down