diff --git a/src/transformers/commands/add_new_model_like.py b/src/transformers/commands/add_new_model_like.py
index 7ff3883d7ad3..c7eafd7d87f7 100644
--- a/src/transformers/commands/add_new_model_like.py
+++ b/src/transformers/commands/add_new_model_like.py
@@ -25,6 +25,7 @@
 import transformers.models.auto as auto_module
 from transformers.models.auto.configuration_auto import model_type_to_module_name
 
+from ..file_utils import is_flax_available, is_tf_available, is_torch_available
 from ..utils import logging
 from . import BaseTransformersCLICommand
 
@@ -501,7 +502,7 @@ def filter_framework_files(
         `List[Union[str, os.PathLike]]`: The list of filtered files.
     """
     if frameworks is None:
-        return files
+        frameworks = get_default_frameworks()
 
     framework_to_file = {}
     others = []
@@ -598,6 +599,20 @@ def find_base_model_checkpoint(
     return ""
 
 
+def get_default_frameworks():
+    """
+    Returns the list of frameworks (PyTorch, TensorFlow, Flax) that are installed in the environment.
+    """
+    frameworks = []
+    if is_torch_available():
+        frameworks.append("pt")
+    if is_tf_available():
+        frameworks.append("tf")
+    if is_flax_available():
+        frameworks.append("flax")
+    return frameworks
+
+
 _re_model_mapping = re.compile("MODEL_([A-Z_]*)MAPPING_NAMES")
 
 
@@ -616,17 +631,19 @@ def retrieve_model_classes(model_type: str, frameworks: Optional[List[str]] = No
         that framework as values.
     """
     if frameworks is None:
-        frameworks = ["pt", "tf", "flax"]
+        frameworks = get_default_frameworks()
 
     modules = {
-        "pt": auto_module.modeling_auto,
-        "tf": auto_module.modeling_tf_auto,
-        "flax": auto_module.modeling_flax_auto,
+        "pt": auto_module.modeling_auto if is_torch_available() else None,
+        "tf": auto_module.modeling_tf_auto if is_tf_available() else None,
+        "flax": auto_module.modeling_flax_auto if is_flax_available() else None,
     }
 
     model_classes = {}
     for framework in frameworks:
         new_model_classes = []
+        if modules[framework] is None:
+            raise ValueError(f"You selected {framework} in the frameworks, but it is not installed.")
         model_mappings = [attr for attr in dir(modules[framework]) if _re_model_mapping.search(attr) is not None]
         for model_mapping_name in model_mappings:
             model_mapping = getattr(modules[framework], model_mapping_name)
@@ -683,9 +700,9 @@ def retrieve_info_for_model(model_type, frameworks: Optional[List[str]] = None):
         available_frameworks.append("pt")
 
     if frameworks is None:
-        frameworks = available_frameworks.copy()
-    else:
-        frameworks = [f for f in frameworks if f in available_frameworks]
+        frameworks = get_default_frameworks()
+
+    frameworks = [f for f in frameworks if f in available_frameworks]
 
     model_classes = retrieve_model_classes(model_type, frameworks=frameworks)
 
@@ -738,7 +755,7 @@ def clean_frameworks_in_init(
             Whether or not to keep the preprocessing (tokenizer, feature extractor, processor) imports in the init.
     """
     if frameworks is None:
-        frameworks = ["pt", "tf", "flax"]
+        frameworks = get_default_frameworks()
 
     names = {"pt": "torch"}
     to_remove = [names.get(f, f) for f in ["pt", "tf", "flax"] if f not in frameworks]
@@ -1040,7 +1057,7 @@ def duplicate_doc_file(
         content = f.read()
 
     if frameworks is None:
-        frameworks = ["pt", "tf", "flax"]
+        frameworks = get_default_frameworks()
     if dest_file is None:
         dest_file = Path(doc_file).parent / f"{new_model_patterns.model_type}.mdx"
 
@@ -1302,7 +1319,7 @@ def __init__(self, config_file=None, path_to_repo=None, *args):
             self.old_model_type = config["old_model_type"]
             self.model_patterns = ModelPatterns(**config["new_model_patterns"])
             self.add_copied_from = config.get("add_copied_from", True)
-            self.frameworks = config.get("frameworks", ["pt", "tf", "flax"])
+            self.frameworks = config.get("frameworks", get_default_frameworks())
             self.old_checkpoint = config.get("old_checkpoint", None)
         else:
             (