From f74f56078e360aac3d20b9a4a28c3d4fe9c13c14 Mon Sep 17 00:00:00 2001 From: lvyufeng Date: Mon, 14 Apr 2025 17:15:02 +0800 Subject: [PATCH 1/3] temp commit --- mindnlp/__init__.py | 1 + mindnlp/transformers.py | 2 ++ 2 files changed, 3 insertions(+) create mode 100644 mindnlp/transformers.py diff --git a/mindnlp/__init__.py b/mindnlp/__init__.py index 27e734e48..f543cdb72 100644 --- a/mindnlp/__init__.py +++ b/mindnlp/__init__.py @@ -20,6 +20,7 @@ import sys import platform from packaging import version +import importlib # huggingface env if os.environ.get('HF_ENDPOINT', None) is None: diff --git a/mindnlp/transformers.py b/mindnlp/transformers.py new file mode 100644 index 000000000..f99ea50ca --- /dev/null +++ b/mindnlp/transformers.py @@ -0,0 +1,2 @@ +import transformers +from transformers import * From 27130bdd8b1bbfd1ea193446b6b0a7b5c49567e2 Mon Sep 17 00:00:00 2001 From: lvyufeng Date: Mon, 7 Jul 2025 15:20:55 +0800 Subject: [PATCH 2/3] use new core to adapt huggingface library --- mindnlp/__init__.py | 1 - mindnlp/transformers.py | 2 -- 2 files changed, 3 deletions(-) delete mode 100644 mindnlp/transformers.py diff --git a/mindnlp/__init__.py b/mindnlp/__init__.py index f543cdb72..27e734e48 100644 --- a/mindnlp/__init__.py +++ b/mindnlp/__init__.py @@ -20,7 +20,6 @@ import sys import platform from packaging import version -import importlib # huggingface env if os.environ.get('HF_ENDPOINT', None) is None: diff --git a/mindnlp/transformers.py b/mindnlp/transformers.py deleted file mode 100644 index f99ea50ca..000000000 --- a/mindnlp/transformers.py +++ /dev/null @@ -1,2 +0,0 @@ -import transformers -from transformers import * From 97680c2e967cfdd5b41479b3922fead449f8a70a Mon Sep 17 00:00:00 2001 From: lvyufeng Date: Mon, 7 Jul 2025 18:03:54 +0800 Subject: [PATCH 3/3] use lazymodule to load transformers submodules --- mindnlp/core/nn/functional.py | 7 +- mindnlp/transformers/__init__.py | 4449 ++++++++++++++++- mindnlp/transformers/models/__init__.py | 2 - mindnlp/transformers/models/auto.py | 16 - .../transformers/models/auto_bk/__init__.py | 21 - .../models/auto_bk/auto_factory.py | 772 --- .../models/auto_bk/modeling_auto.py | 448 -- mindnlp/transformers/pipelines.py | 0 mindnlp/utils/testing_utils.py | 2102 ++++++++ 9 files changed, 6553 insertions(+), 1264 deletions(-) delete mode 100644 mindnlp/transformers/models/__init__.py delete mode 100644 mindnlp/transformers/models/auto.py delete mode 100644 mindnlp/transformers/models/auto_bk/__init__.py delete mode 100644 mindnlp/transformers/models/auto_bk/auto_factory.py delete mode 100644 mindnlp/transformers/models/auto_bk/modeling_auto.py delete mode 100644 mindnlp/transformers/pipelines.py create mode 100644 mindnlp/utils/testing_utils.py diff --git a/mindnlp/core/nn/functional.py b/mindnlp/core/nn/functional.py index a2d180c18..e7db4975d 100644 --- a/mindnlp/core/nn/functional.py +++ b/mindnlp/core/nn/functional.py @@ -26,10 +26,9 @@ def relu(input, inplace=False): return execute('relu', input) def tanh(input, inplace=False): - if inplace: - execute('inplace_tanh', input) - return input - return execute('tanh', input) + if use_pyboost(): + return mint.nn.functional.tanh(input) + return ops.tanh(input) def sigmoid(input): diff --git a/mindnlp/transformers/__init__.py b/mindnlp/transformers/__init__.py index aed4fa323..e02e58ecb 100644 --- a/mindnlp/transformers/__init__.py +++ b/mindnlp/transformers/__init__.py @@ -1 +1,4448 @@ -from .models import * +import sys +import transformers +from transformers.utils import OptionalDependencyNotAvailable, _LazyModule +from transformers.utils.import_utils import * + + +# Base objects, independent of any specific backend +_import_structure = { + # "agents": [ + # "Agent", + # "CodeAgent", + # "HfApiEngine", + # "ManagedAgent", + # "PipelineTool", + # "ReactAgent", + # "ReactCodeAgent", + # "ReactJsonAgent", + # "Tool", + # "Toolbox", + # "ToolCollection", + # "TransformersEngine", + # "launch_gradio_demo", + # "load_tool", + # "stream_to_gradio", + # "tool", + # ], + "audio_utils": [], + "commands": [], + "configuration_utils": ["PretrainedConfig"], + "convert_slow_tokenizers_checkpoints_to_fast": [], + "data": [ + "DataProcessor", + "InputExample", + "InputFeatures", + "SingleSentenceClassificationProcessor", + "SquadExample", + "SquadFeatures", + "SquadV1Processor", + "SquadV2Processor", + "glue_compute_metrics", + "glue_convert_examples_to_features", + "glue_output_modes", + "glue_processors", + "glue_tasks_num_labels", + "squad_convert_examples_to_features", + "xnli_compute_metrics", + "xnli_output_modes", + "xnli_processors", + "xnli_tasks_num_labels", + ], + "data.data_collator": [ + "DataCollator", + "DataCollatorForLanguageModeling", + "DataCollatorForMultipleChoice", + "DataCollatorForPermutationLanguageModeling", + "DataCollatorForSeq2Seq", + "DataCollatorForSOP", + "DataCollatorForTokenClassification", + "DataCollatorForWholeWordMask", + "DataCollatorWithFlattening", + "DataCollatorWithPadding", + "DefaultDataCollator", + "default_data_collator", + ], + "data.metrics": [], + "data.processors": [], + "debug_utils": [], + "dependency_versions_check": [], + "dependency_versions_table": [], + "dynamic_module_utils": [], + "feature_extraction_sequence_utils": ["SequenceFeatureExtractor"], + "feature_extraction_utils": ["BatchFeature", "FeatureExtractionMixin"], + "file_utils": [], + "generation": [ + "AsyncTextIteratorStreamer", + "CompileConfig", + "GenerationConfig", + "TextIteratorStreamer", + "TextStreamer", + "WatermarkingConfig", + ], + "hf_argparser": ["HfArgumentParser"], + "hyperparameter_search": [], + "image_transforms": [], + "loss": [], + "modelcard": ["ModelCard"], + # Models + "models": [], + "models.albert": ["AlbertConfig"], + "models.align": [ + "AlignConfig", + "AlignProcessor", + "AlignTextConfig", + "AlignVisionConfig", + ], + "models.altclip": [ + "AltCLIPConfig", + "AltCLIPProcessor", + "AltCLIPTextConfig", + "AltCLIPVisionConfig", + ], + "models.aria": [ + "AriaConfig", + "AriaProcessor", + "AriaTextConfig", + ], + "models.audio_spectrogram_transformer": [ + "ASTConfig", + "ASTFeatureExtractor", + ], + "models.auto": [ + "CONFIG_MAPPING", + "FEATURE_EXTRACTOR_MAPPING", + "IMAGE_PROCESSOR_MAPPING", + "MODEL_NAMES_MAPPING", + "PROCESSOR_MAPPING", + "TOKENIZER_MAPPING", + "AutoConfig", + "AutoFeatureExtractor", + "AutoImageProcessor", + "AutoProcessor", + "AutoTokenizer", + ], + "models.autoformer": ["AutoformerConfig"], + "models.aya_vision": ["AyaVisionConfig", "AyaVisionProcessor"], + "models.bamba": ["BambaConfig"], + "models.bark": [ + "BarkCoarseConfig", + "BarkConfig", + "BarkFineConfig", + "BarkProcessor", + "BarkSemanticConfig", + ], + "models.bart": ["BartConfig", "BartTokenizer"], + "models.barthez": [], + "models.bartpho": [], + "models.beit": ["BeitConfig"], + "models.bert": [ + "BasicTokenizer", + "BertConfig", + "BertTokenizer", + "WordpieceTokenizer", + ], + "models.bert_generation": ["BertGenerationConfig"], + "models.bert_japanese": [ + "BertJapaneseTokenizer", + "CharacterTokenizer", + "MecabTokenizer", + ], + "models.bertweet": ["BertweetTokenizer"], + "models.big_bird": ["BigBirdConfig"], + "models.bigbird_pegasus": ["BigBirdPegasusConfig"], + "models.biogpt": [ + "BioGptConfig", + "BioGptTokenizer", + ], + "models.bit": ["BitConfig"], + "models.blenderbot": [ + "BlenderbotConfig", + "BlenderbotTokenizer", + ], + "models.blenderbot_small": [ + "BlenderbotSmallConfig", + "BlenderbotSmallTokenizer", + ], + "models.blip": [ + "BlipConfig", + "BlipProcessor", + "BlipTextConfig", + "BlipVisionConfig", + ], + "models.blip_2": [ + "Blip2Config", + "Blip2Processor", + "Blip2QFormerConfig", + "Blip2VisionConfig", + ], + "models.bloom": ["BloomConfig"], + "models.bridgetower": [ + "BridgeTowerConfig", + "BridgeTowerProcessor", + "BridgeTowerTextConfig", + "BridgeTowerVisionConfig", + ], + "models.bros": [ + "BrosConfig", + "BrosProcessor", + ], + "models.byt5": ["ByT5Tokenizer"], + "models.camembert": ["CamembertConfig"], + "models.canine": [ + "CanineConfig", + "CanineTokenizer", + ], + "models.chameleon": [ + "ChameleonConfig", + "ChameleonProcessor", + "ChameleonVQVAEConfig", + ], + "models.chinese_clip": [ + "ChineseCLIPConfig", + "ChineseCLIPProcessor", + "ChineseCLIPTextConfig", + "ChineseCLIPVisionConfig", + ], + "models.clap": [ + "ClapAudioConfig", + "ClapConfig", + "ClapProcessor", + "ClapTextConfig", + ], + "models.clip": [ + "CLIPConfig", + "CLIPProcessor", + "CLIPTextConfig", + "CLIPTokenizer", + "CLIPVisionConfig", + ], + "models.clipseg": [ + "CLIPSegConfig", + "CLIPSegProcessor", + "CLIPSegTextConfig", + "CLIPSegVisionConfig", + ], + "models.clvp": [ + "ClvpConfig", + "ClvpDecoderConfig", + "ClvpEncoderConfig", + "ClvpFeatureExtractor", + "ClvpProcessor", + "ClvpTokenizer", + ], + "models.code_llama": [], + "models.codegen": [ + "CodeGenConfig", + "CodeGenTokenizer", + ], + "models.cohere": ["CohereConfig"], + "models.cohere2": ["Cohere2Config"], + "models.colpali": [ + "ColPaliConfig", + "ColPaliProcessor", + ], + "models.conditional_detr": ["ConditionalDetrConfig"], + "models.convbert": [ + "ConvBertConfig", + "ConvBertTokenizer", + ], + "models.convnext": ["ConvNextConfig"], + "models.convnextv2": ["ConvNextV2Config"], + "models.cpm": [], + "models.cpmant": [ + "CpmAntConfig", + "CpmAntTokenizer", + ], + "models.ctrl": [ + "CTRLConfig", + "CTRLTokenizer", + ], + "models.cvt": ["CvtConfig"], + "models.dab_detr": ["DabDetrConfig"], + "models.dac": ["DacConfig", "DacFeatureExtractor"], + "models.data2vec": [ + "Data2VecAudioConfig", + "Data2VecTextConfig", + "Data2VecVisionConfig", + ], + "models.dbrx": ["DbrxConfig"], + "models.deberta": [ + "DebertaConfig", + "DebertaTokenizer", + ], + "models.deberta_v2": ["DebertaV2Config"], + "models.decision_transformer": ["DecisionTransformerConfig"], + "models.deepseek_v3": ["DeepseekV3Config"], + "models.deformable_detr": ["DeformableDetrConfig"], + "models.deit": ["DeiTConfig"], + "models.deprecated": [], + "models.deprecated.bort": [], + "models.deprecated.deta": ["DetaConfig"], + "models.deprecated.efficientformer": ["EfficientFormerConfig"], + "models.deprecated.ernie_m": ["ErnieMConfig"], + "models.deprecated.gptsan_japanese": [ + "GPTSanJapaneseConfig", + "GPTSanJapaneseTokenizer", + ], + "models.deprecated.graphormer": ["GraphormerConfig"], + "models.deprecated.jukebox": [ + "JukeboxConfig", + "JukeboxPriorConfig", + "JukeboxTokenizer", + "JukeboxVQVAEConfig", + ], + "models.deprecated.mctct": [ + "MCTCTConfig", + "MCTCTFeatureExtractor", + "MCTCTProcessor", + ], + "models.deprecated.mega": ["MegaConfig"], + "models.deprecated.mmbt": ["MMBTConfig"], + "models.deprecated.nat": ["NatConfig"], + "models.deprecated.nezha": ["NezhaConfig"], + "models.deprecated.open_llama": ["OpenLlamaConfig"], + "models.deprecated.qdqbert": ["QDQBertConfig"], + "models.deprecated.realm": [ + "RealmConfig", + "RealmTokenizer", + ], + "models.deprecated.retribert": [ + "RetriBertConfig", + "RetriBertTokenizer", + ], + "models.deprecated.speech_to_text_2": [ + "Speech2Text2Config", + "Speech2Text2Processor", + "Speech2Text2Tokenizer", + ], + "models.deprecated.tapex": ["TapexTokenizer"], + "models.deprecated.trajectory_transformer": ["TrajectoryTransformerConfig"], + "models.deprecated.transfo_xl": [ + "TransfoXLConfig", + "TransfoXLCorpus", + "TransfoXLTokenizer", + ], + "models.deprecated.tvlt": [ + "TvltConfig", + "TvltFeatureExtractor", + "TvltProcessor", + ], + "models.deprecated.van": ["VanConfig"], + "models.deprecated.vit_hybrid": ["ViTHybridConfig"], + "models.deprecated.xlm_prophetnet": ["XLMProphetNetConfig"], + "models.depth_anything": ["DepthAnythingConfig"], + "models.depth_pro": ["DepthProConfig"], + "models.detr": ["DetrConfig"], + "models.dialogpt": [], + "models.diffllama": ["DiffLlamaConfig"], + "models.dinat": ["DinatConfig"], + "models.dinov2": ["Dinov2Config"], + "models.dinov2_with_registers": ["Dinov2WithRegistersConfig"], + "models.distilbert": [ + "DistilBertConfig", + "DistilBertTokenizer", + ], + "models.dit": [], + "models.donut": [ + "DonutProcessor", + "DonutSwinConfig", + ], + "models.dpr": [ + "DPRConfig", + "DPRContextEncoderTokenizer", + "DPRQuestionEncoderTokenizer", + "DPRReaderOutput", + "DPRReaderTokenizer", + ], + "models.dpt": ["DPTConfig"], + "models.efficientnet": ["EfficientNetConfig"], + "models.electra": [ + "ElectraConfig", + "ElectraTokenizer", + ], + "models.emu3": [ + "Emu3Config", + "Emu3Processor", + "Emu3TextConfig", + "Emu3VQVAEConfig", + ], + "models.encodec": [ + "EncodecConfig", + "EncodecFeatureExtractor", + ], + "models.encoder_decoder": ["EncoderDecoderConfig"], + "models.ernie": ["ErnieConfig"], + "models.esm": ["EsmConfig", "EsmTokenizer"], + "models.falcon": ["FalconConfig"], + "models.falcon_mamba": ["FalconMambaConfig"], + "models.fastspeech2_conformer": [ + "FastSpeech2ConformerConfig", + "FastSpeech2ConformerHifiGanConfig", + "FastSpeech2ConformerTokenizer", + "FastSpeech2ConformerWithHifiGanConfig", + ], + "models.flaubert": ["FlaubertConfig", "FlaubertTokenizer"], + "models.flava": [ + "FlavaConfig", + "FlavaImageCodebookConfig", + "FlavaImageConfig", + "FlavaMultimodalConfig", + "FlavaTextConfig", + ], + "models.fnet": ["FNetConfig"], + "models.focalnet": ["FocalNetConfig"], + "models.fsmt": [ + "FSMTConfig", + "FSMTTokenizer", + ], + "models.funnel": [ + "FunnelConfig", + "FunnelTokenizer", + ], + "models.fuyu": ["FuyuConfig"], + "models.gemma": ["GemmaConfig"], + "models.gemma2": ["Gemma2Config"], + "models.gemma3": ["Gemma3Config", "Gemma3Processor", "Gemma3TextConfig"], + "models.git": [ + "GitConfig", + "GitProcessor", + "GitVisionConfig", + ], + "models.glm": ["GlmConfig"], + "models.glpn": ["GLPNConfig"], + "models.got_ocr2": [ + "GotOcr2Config", + "GotOcr2Processor", + "GotOcr2VisionConfig", + ], + "models.gpt2": [ + "GPT2Config", + "GPT2Tokenizer", + ], + "models.gpt_bigcode": ["GPTBigCodeConfig"], + "models.gpt_neo": ["GPTNeoConfig"], + "models.gpt_neox": ["GPTNeoXConfig"], + "models.gpt_neox_japanese": ["GPTNeoXJapaneseConfig"], + "models.gpt_sw3": [], + "models.gptj": ["GPTJConfig"], + "models.granite": ["GraniteConfig"], + "models.granitemoe": ["GraniteMoeConfig"], + "models.granitemoeshared": ["GraniteMoeSharedConfig"], + "models.grounding_dino": [ + "GroundingDinoConfig", + "GroundingDinoProcessor", + ], + "models.groupvit": [ + "GroupViTConfig", + "GroupViTTextConfig", + "GroupViTVisionConfig", + ], + "models.helium": ["HeliumConfig"], + "models.herbert": ["HerbertTokenizer"], + "models.hiera": ["HieraConfig"], + "models.hubert": ["HubertConfig"], + "models.ibert": ["IBertConfig"], + "models.idefics": ["IdeficsConfig"], + "models.idefics2": ["Idefics2Config"], + "models.idefics3": ["Idefics3Config"], + "models.ijepa": ["IJepaConfig"], + "models.imagegpt": ["ImageGPTConfig"], + "models.informer": ["InformerConfig"], + "models.instructblip": [ + "InstructBlipConfig", + "InstructBlipProcessor", + "InstructBlipQFormerConfig", + "InstructBlipVisionConfig", + ], + "models.instructblipvideo": [ + "InstructBlipVideoConfig", + "InstructBlipVideoProcessor", + "InstructBlipVideoQFormerConfig", + "InstructBlipVideoVisionConfig", + ], + "models.jamba": ["JambaConfig"], + "models.jetmoe": ["JetMoeConfig"], + "models.kosmos2": [ + "Kosmos2Config", + "Kosmos2Processor", + ], + "models.layoutlm": [ + "LayoutLMConfig", + "LayoutLMTokenizer", + ], + "models.layoutlmv2": [ + "LayoutLMv2Config", + "LayoutLMv2FeatureExtractor", + "LayoutLMv2ImageProcessor", + "LayoutLMv2Processor", + "LayoutLMv2Tokenizer", + ], + "models.layoutlmv3": [ + "LayoutLMv3Config", + "LayoutLMv3FeatureExtractor", + "LayoutLMv3ImageProcessor", + "LayoutLMv3Processor", + "LayoutLMv3Tokenizer", + ], + "models.layoutxlm": ["LayoutXLMProcessor"], + "models.led": ["LEDConfig", "LEDTokenizer"], + "models.levit": ["LevitConfig"], + "models.lilt": ["LiltConfig"], + "models.llama": ["LlamaConfig"], + "models.llama4": [ + "Llama4Config", + "Llama4Processor", + "Llama4TextConfig", + "Llama4VisionConfig", + ], + "models.llava": [ + "LlavaConfig", + "LlavaProcessor", + ], + "models.llava_next": [ + "LlavaNextConfig", + "LlavaNextProcessor", + ], + "models.llava_next_video": [ + "LlavaNextVideoConfig", + "LlavaNextVideoProcessor", + ], + "models.llava_onevision": ["LlavaOnevisionConfig", "LlavaOnevisionProcessor"], + "models.longformer": [ + "LongformerConfig", + "LongformerTokenizer", + ], + "models.longt5": ["LongT5Config"], + "models.luke": [ + "LukeConfig", + "LukeTokenizer", + ], + "models.lxmert": [ + "LxmertConfig", + "LxmertTokenizer", + ], + "models.m2m_100": ["M2M100Config"], + "models.mamba": ["MambaConfig"], + "models.mamba2": ["Mamba2Config"], + "models.marian": ["MarianConfig"], + "models.markuplm": [ + "MarkupLMConfig", + "MarkupLMFeatureExtractor", + "MarkupLMProcessor", + "MarkupLMTokenizer", + ], + "models.mask2former": ["Mask2FormerConfig"], + "models.maskformer": [ + "MaskFormerConfig", + "MaskFormerSwinConfig", + ], + "models.mbart": ["MBartConfig"], + "models.mbart50": [], + "models.megatron_bert": ["MegatronBertConfig"], + "models.megatron_gpt2": [], + "models.mgp_str": [ + "MgpstrConfig", + "MgpstrProcessor", + "MgpstrTokenizer", + ], + "models.mimi": ["MimiConfig"], + "models.mistral": ["MistralConfig"], + "models.mistral3": ["Mistral3Config"], + "models.mixtral": ["MixtralConfig"], + "models.mllama": [ + "MllamaConfig", + "MllamaProcessor", + ], + "models.mluke": [], + "models.mobilebert": [ + "MobileBertConfig", + "MobileBertTokenizer", + ], + "models.mobilenet_v1": ["MobileNetV1Config"], + "models.mobilenet_v2": ["MobileNetV2Config"], + "models.mobilevit": ["MobileViTConfig"], + "models.mobilevitv2": ["MobileViTV2Config"], + "models.modernbert": ["ModernBertConfig"], + "models.moonshine": ["MoonshineConfig"], + "models.moshi": [ + "MoshiConfig", + "MoshiDepthConfig", + ], + "models.mpnet": [ + "MPNetConfig", + "MPNetTokenizer", + ], + "models.mpt": ["MptConfig"], + "models.mra": ["MraConfig"], + "models.mt5": ["MT5Config"], + "models.musicgen": [ + "MusicgenConfig", + "MusicgenDecoderConfig", + ], + "models.musicgen_melody": [ + "MusicgenMelodyConfig", + "MusicgenMelodyDecoderConfig", + ], + "models.mvp": ["MvpConfig", "MvpTokenizer"], + "models.myt5": ["MyT5Tokenizer"], + "models.nemotron": ["NemotronConfig"], + "models.nllb": [], + "models.nllb_moe": ["NllbMoeConfig"], + "models.nougat": ["NougatProcessor"], + "models.nystromformer": ["NystromformerConfig"], + "models.olmo": ["OlmoConfig"], + "models.olmo2": ["Olmo2Config"], + "models.olmoe": ["OlmoeConfig"], + "models.omdet_turbo": [ + "OmDetTurboConfig", + "OmDetTurboProcessor", + ], + "models.oneformer": [ + "OneFormerConfig", + "OneFormerProcessor", + ], + "models.openai": [ + "OpenAIGPTConfig", + "OpenAIGPTTokenizer", + ], + "models.opt": ["OPTConfig"], + "models.owlv2": [ + "Owlv2Config", + "Owlv2Processor", + "Owlv2TextConfig", + "Owlv2VisionConfig", + ], + "models.owlvit": [ + "OwlViTConfig", + "OwlViTProcessor", + "OwlViTTextConfig", + "OwlViTVisionConfig", + ], + "models.paligemma": ["PaliGemmaConfig"], + "models.patchtsmixer": ["PatchTSMixerConfig"], + "models.patchtst": ["PatchTSTConfig"], + "models.pegasus": [ + "PegasusConfig", + "PegasusTokenizer", + ], + "models.pegasus_x": ["PegasusXConfig"], + "models.perceiver": [ + "PerceiverConfig", + "PerceiverTokenizer", + ], + "models.persimmon": ["PersimmonConfig"], + "models.phi": ["PhiConfig"], + "models.phi3": ["Phi3Config"], + "models.phi4_multimodal": [ + "Phi4MultimodalAudioConfig", + "Phi4MultimodalConfig", + "Phi4MultimodalFeatureExtractor", + "Phi4MultimodalProcessor", + "Phi4MultimodalVisionConfig", + ], + "models.phimoe": ["PhimoeConfig"], + "models.phobert": ["PhobertTokenizer"], + "models.pix2struct": [ + "Pix2StructConfig", + "Pix2StructProcessor", + "Pix2StructTextConfig", + "Pix2StructVisionConfig", + ], + "models.pixtral": ["PixtralProcessor", "PixtralVisionConfig"], + "models.plbart": ["PLBartConfig"], + "models.poolformer": ["PoolFormerConfig"], + "models.pop2piano": ["Pop2PianoConfig"], + "models.prompt_depth_anything": ["PromptDepthAnythingConfig"], + "models.prophetnet": [ + "ProphetNetConfig", + "ProphetNetTokenizer", + ], + "models.pvt": ["PvtConfig"], + "models.pvt_v2": ["PvtV2Config"], + "models.qwen2": [ + "Qwen2Config", + "Qwen2Tokenizer", + ], + "models.qwen2_5_vl": [ + "Qwen2_5_VLConfig", + "Qwen2_5_VLProcessor", + ], + "models.qwen2_audio": [ + "Qwen2AudioConfig", + "Qwen2AudioEncoderConfig", + "Qwen2AudioProcessor", + ], + "models.qwen2_moe": ["Qwen2MoeConfig"], + "models.qwen2_vl": [ + "Qwen2VLConfig", + "Qwen2VLProcessor", + ], + "models.qwen3": ["Qwen3Config"], + "models.qwen3_moe": ["Qwen3MoeConfig"], + "models.rag": ["RagConfig", "RagRetriever", "RagTokenizer"], + "models.recurrent_gemma": ["RecurrentGemmaConfig"], + "models.reformer": ["ReformerConfig"], + "models.regnet": ["RegNetConfig"], + "models.rembert": ["RemBertConfig"], + "models.resnet": ["ResNetConfig"], + "models.roberta": [ + "RobertaConfig", + "RobertaTokenizer", + ], + "models.roberta_prelayernorm": ["RobertaPreLayerNormConfig"], + "models.roc_bert": [ + "RoCBertConfig", + "RoCBertTokenizer", + ], + "models.roformer": [ + "RoFormerConfig", + "RoFormerTokenizer", + ], + "models.rt_detr": ["RTDetrConfig", "RTDetrResNetConfig"], + "models.rt_detr_v2": ["RTDetrV2Config"], + "models.rwkv": ["RwkvConfig"], + "models.sam": [ + "SamConfig", + "SamMaskDecoderConfig", + "SamProcessor", + "SamPromptEncoderConfig", + "SamVisionConfig", + ], + "models.seamless_m4t": [ + "SeamlessM4TConfig", + "SeamlessM4TFeatureExtractor", + "SeamlessM4TProcessor", + ], + "models.seamless_m4t_v2": ["SeamlessM4Tv2Config"], + "models.segformer": ["SegformerConfig"], + "models.seggpt": ["SegGptConfig"], + "models.sew": ["SEWConfig"], + "models.sew_d": ["SEWDConfig"], + "models.shieldgemma2": [ + "ShieldGemma2Config", + "ShieldGemma2Processor", + ], + "models.siglip": [ + "SiglipConfig", + "SiglipProcessor", + "SiglipTextConfig", + "SiglipVisionConfig", + ], + "models.siglip2": [ + "Siglip2Config", + "Siglip2Processor", + "Siglip2TextConfig", + "Siglip2VisionConfig", + ], + "models.smolvlm": ["SmolVLMConfig"], + "models.speech_encoder_decoder": ["SpeechEncoderDecoderConfig"], + "models.speech_to_text": [ + "Speech2TextConfig", + "Speech2TextFeatureExtractor", + "Speech2TextProcessor", + ], + "models.speecht5": [ + "SpeechT5Config", + "SpeechT5FeatureExtractor", + "SpeechT5HifiGanConfig", + "SpeechT5Processor", + ], + "models.splinter": [ + "SplinterConfig", + "SplinterTokenizer", + ], + "models.squeezebert": [ + "SqueezeBertConfig", + "SqueezeBertTokenizer", + ], + "models.stablelm": ["StableLmConfig"], + "models.starcoder2": ["Starcoder2Config"], + "models.superglue": ["SuperGlueConfig"], + "models.superpoint": ["SuperPointConfig"], + "models.swiftformer": ["SwiftFormerConfig"], + "models.swin": ["SwinConfig"], + "models.swin2sr": ["Swin2SRConfig"], + "models.swinv2": ["Swinv2Config"], + "models.switch_transformers": ["SwitchTransformersConfig"], + "models.t5": ["T5Config"], + "models.table_transformer": ["TableTransformerConfig"], + "models.tapas": [ + "TapasConfig", + "TapasTokenizer", + ], + "models.textnet": ["TextNetConfig"], + "models.time_series_transformer": ["TimeSeriesTransformerConfig"], + "models.timesformer": ["TimesformerConfig"], + "models.timm_backbone": ["TimmBackboneConfig"], + "models.timm_wrapper": ["TimmWrapperConfig"], + "models.trocr": [ + "TrOCRConfig", + "TrOCRProcessor", + ], + "models.tvp": [ + "TvpConfig", + "TvpProcessor", + ], + "models.udop": [ + "UdopConfig", + "UdopProcessor", + ], + "models.umt5": ["UMT5Config"], + "models.unispeech": ["UniSpeechConfig"], + "models.unispeech_sat": ["UniSpeechSatConfig"], + "models.univnet": [ + "UnivNetConfig", + "UnivNetFeatureExtractor", + ], + "models.upernet": ["UperNetConfig"], + "models.video_llava": ["VideoLlavaConfig"], + "models.videomae": ["VideoMAEConfig"], + "models.vilt": [ + "ViltConfig", + "ViltFeatureExtractor", + "ViltImageProcessor", + "ViltProcessor", + ], + "models.vipllava": ["VipLlavaConfig"], + "models.vision_encoder_decoder": ["VisionEncoderDecoderConfig"], + "models.vision_text_dual_encoder": [ + "VisionTextDualEncoderConfig", + "VisionTextDualEncoderProcessor", + ], + "models.visual_bert": ["VisualBertConfig"], + "models.vit": ["ViTConfig"], + "models.vit_mae": ["ViTMAEConfig"], + "models.vit_msn": ["ViTMSNConfig"], + "models.vitdet": ["VitDetConfig"], + "models.vitmatte": ["VitMatteConfig"], + "models.vitpose": ["VitPoseConfig"], + "models.vitpose_backbone": ["VitPoseBackboneConfig"], + "models.vits": [ + "VitsConfig", + "VitsTokenizer", + ], + "models.vivit": ["VivitConfig"], + "models.wav2vec2": [ + "Wav2Vec2Config", + "Wav2Vec2CTCTokenizer", + "Wav2Vec2FeatureExtractor", + "Wav2Vec2Processor", + "Wav2Vec2Tokenizer", + ], + "models.wav2vec2_bert": [ + "Wav2Vec2BertConfig", + "Wav2Vec2BertProcessor", + ], + "models.wav2vec2_conformer": ["Wav2Vec2ConformerConfig"], + "models.wav2vec2_phoneme": ["Wav2Vec2PhonemeCTCTokenizer"], + "models.wav2vec2_with_lm": ["Wav2Vec2ProcessorWithLM"], + "models.wavlm": ["WavLMConfig"], + "models.whisper": [ + "WhisperConfig", + "WhisperFeatureExtractor", + "WhisperProcessor", + "WhisperTokenizer", + ], + "models.x_clip": [ + "XCLIPConfig", + "XCLIPProcessor", + "XCLIPTextConfig", + "XCLIPVisionConfig", + ], + "models.xglm": ["XGLMConfig"], + "models.xlm": ["XLMConfig", "XLMTokenizer"], + "models.xlm_roberta": ["XLMRobertaConfig"], + "models.xlm_roberta_xl": ["XLMRobertaXLConfig"], + "models.xlnet": ["XLNetConfig"], + "models.xmod": ["XmodConfig"], + "models.yolos": ["YolosConfig"], + "models.yoso": ["YosoConfig"], + "models.zamba": ["ZambaConfig"], + "models.zamba2": ["Zamba2Config"], + "models.zoedepth": ["ZoeDepthConfig"], + "onnx": [], + "pipelines": [ + "AudioClassificationPipeline", + "AutomaticSpeechRecognitionPipeline", + "CsvPipelineDataFormat", + "DepthEstimationPipeline", + "DocumentQuestionAnsweringPipeline", + "FeatureExtractionPipeline", + "FillMaskPipeline", + "ImageClassificationPipeline", + "ImageFeatureExtractionPipeline", + "ImageSegmentationPipeline", + "ImageTextToTextPipeline", + "ImageToImagePipeline", + "ImageToTextPipeline", + "JsonPipelineDataFormat", + "MaskGenerationPipeline", + "NerPipeline", + "ObjectDetectionPipeline", + "PipedPipelineDataFormat", + "Pipeline", + "PipelineDataFormat", + "QuestionAnsweringPipeline", + "SummarizationPipeline", + "TableQuestionAnsweringPipeline", + "Text2TextGenerationPipeline", + "TextClassificationPipeline", + "TextGenerationPipeline", + "TextToAudioPipeline", + "TokenClassificationPipeline", + "TranslationPipeline", + "VideoClassificationPipeline", + "VisualQuestionAnsweringPipeline", + "ZeroShotAudioClassificationPipeline", + "ZeroShotClassificationPipeline", + "ZeroShotImageClassificationPipeline", + "ZeroShotObjectDetectionPipeline", + "pipeline", + ], + "processing_utils": ["ProcessorMixin"], + "quantizers": [], + "testing_utils": [], + "tokenization_utils": ["PreTrainedTokenizer"], + "tokenization_utils_base": [ + "AddedToken", + "BatchEncoding", + "CharSpan", + "PreTrainedTokenizerBase", + "SpecialTokensMixin", + "TokenSpan", + ], + "utils": [ + "CONFIG_NAME", + "MODEL_CARD_NAME", + "PYTORCH_PRETRAINED_BERT_CACHE", + "PYTORCH_TRANSFORMERS_CACHE", + "SPIECE_UNDERLINE", + "TRANSFORMERS_CACHE", + "WEIGHTS_NAME", + "TensorType", + "add_end_docstrings", + "add_start_docstrings", + "is_apex_available", + "is_av_available", + "is_bitsandbytes_available", + "is_datasets_available", + "is_faiss_available", + "is_flax_available", + "is_keras_nlp_available", + "is_phonemizer_available", + "is_psutil_available", + "is_py3nvml_available", + "is_pyctcdecode_available", + "is_sacremoses_available", + "is_safetensors_available", + "is_scipy_available", + "is_sentencepiece_available", + "is_sklearn_available", + "is_speech_available", + "is_tensorflow_text_available", + "is_timm_available", + "is_tokenizers_available", + "is_torch_available", + "is_torch_hpu_available", + "is_torch_mlu_available", + "is_torch_musa_available", + "is_torch_neuroncore_available", + "is_torch_npu_available", + "is_torchvision_available", + "is_torch_xla_available", + "is_torch_xpu_available", + "is_vision_available", + "logging", + ], + "utils.quantization_config": [ + "AqlmConfig", + "AwqConfig", + "BitNetConfig", + "BitsAndBytesConfig", + "CompressedTensorsConfig", + "EetqConfig", + "FbgemmFp8Config", + "FineGrainedFP8Config", + "GPTQConfig", + "HiggsConfig", + "HqqConfig", + "QuantoConfig", + "QuarkConfig", + "SpQRConfig", + "TorchAoConfig", + "VptqConfig", + ], +} + +# sentencepiece-backed objects +try: + if not is_sentencepiece_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from transformers.utils import dummy_sentencepiece_objects + + _import_structure["utils.dummy_sentencepiece_objects"] = [ + name for name in dir(dummy_sentencepiece_objects) if not name.startswith("_") + ] +else: + _import_structure["models.albert"].append("AlbertTokenizer") + _import_structure["models.barthez"].append("BarthezTokenizer") + _import_structure["models.bartpho"].append("BartphoTokenizer") + _import_structure["models.bert_generation"].append("BertGenerationTokenizer") + _import_structure["models.big_bird"].append("BigBirdTokenizer") + _import_structure["models.camembert"].append("CamembertTokenizer") + _import_structure["models.code_llama"].append("CodeLlamaTokenizer") + _import_structure["models.cpm"].append("CpmTokenizer") + _import_structure["models.deberta_v2"].append("DebertaV2Tokenizer") + _import_structure["models.deprecated.ernie_m"].append("ErnieMTokenizer") + _import_structure["models.deprecated.xlm_prophetnet"].append("XLMProphetNetTokenizer") + _import_structure["models.fnet"].append("FNetTokenizer") + _import_structure["models.gemma"].append("GemmaTokenizer") + _import_structure["models.gpt_sw3"].append("GPTSw3Tokenizer") + _import_structure["models.layoutxlm"].append("LayoutXLMTokenizer") + _import_structure["models.llama"].append("LlamaTokenizer") + _import_structure["models.m2m_100"].append("M2M100Tokenizer") + _import_structure["models.marian"].append("MarianTokenizer") + _import_structure["models.mbart"].append("MBartTokenizer") + _import_structure["models.mbart50"].append("MBart50Tokenizer") + _import_structure["models.mluke"].append("MLukeTokenizer") + _import_structure["models.mt5"].append("MT5Tokenizer") + _import_structure["models.nllb"].append("NllbTokenizer") + _import_structure["models.pegasus"].append("PegasusTokenizer") + _import_structure["models.plbart"].append("PLBartTokenizer") + _import_structure["models.reformer"].append("ReformerTokenizer") + _import_structure["models.rembert"].append("RemBertTokenizer") + _import_structure["models.seamless_m4t"].append("SeamlessM4TTokenizer") + _import_structure["models.siglip"].append("SiglipTokenizer") + _import_structure["models.speech_to_text"].append("Speech2TextTokenizer") + _import_structure["models.speecht5"].append("SpeechT5Tokenizer") + _import_structure["models.t5"].append("T5Tokenizer") + _import_structure["models.udop"].append("UdopTokenizer") + _import_structure["models.xglm"].append("XGLMTokenizer") + _import_structure["models.xlm_roberta"].append("XLMRobertaTokenizer") + _import_structure["models.xlnet"].append("XLNetTokenizer") + +# tokenizers-backed objects +try: + if not is_tokenizers_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from transformers.utils import dummy_tokenizers_objects + + _import_structure["utils.dummy_tokenizers_objects"] = [ + name for name in dir(dummy_tokenizers_objects) if not name.startswith("_") + ] +else: + # Fast tokenizers structure + _import_structure["models.albert"].append("AlbertTokenizerFast") + _import_structure["models.bart"].append("BartTokenizerFast") + _import_structure["models.barthez"].append("BarthezTokenizerFast") + _import_structure["models.bert"].append("BertTokenizerFast") + _import_structure["models.big_bird"].append("BigBirdTokenizerFast") + _import_structure["models.blenderbot"].append("BlenderbotTokenizerFast") + _import_structure["models.blenderbot_small"].append("BlenderbotSmallTokenizerFast") + _import_structure["models.bloom"].append("BloomTokenizerFast") + _import_structure["models.camembert"].append("CamembertTokenizerFast") + _import_structure["models.clip"].append("CLIPTokenizerFast") + _import_structure["models.code_llama"].append("CodeLlamaTokenizerFast") + _import_structure["models.codegen"].append("CodeGenTokenizerFast") + _import_structure["models.cohere"].append("CohereTokenizerFast") + _import_structure["models.convbert"].append("ConvBertTokenizerFast") + _import_structure["models.cpm"].append("CpmTokenizerFast") + _import_structure["models.deberta"].append("DebertaTokenizerFast") + _import_structure["models.deberta_v2"].append("DebertaV2TokenizerFast") + _import_structure["models.deprecated.realm"].append("RealmTokenizerFast") + _import_structure["models.deprecated.retribert"].append("RetriBertTokenizerFast") + _import_structure["models.distilbert"].append("DistilBertTokenizerFast") + _import_structure["models.dpr"].extend( + [ + "DPRContextEncoderTokenizerFast", + "DPRQuestionEncoderTokenizerFast", + "DPRReaderTokenizerFast", + ] + ) + _import_structure["models.electra"].append("ElectraTokenizerFast") + _import_structure["models.fnet"].append("FNetTokenizerFast") + _import_structure["models.funnel"].append("FunnelTokenizerFast") + _import_structure["models.gemma"].append("GemmaTokenizerFast") + _import_structure["models.gpt2"].append("GPT2TokenizerFast") + _import_structure["models.gpt_neox"].append("GPTNeoXTokenizerFast") + _import_structure["models.gpt_neox_japanese"].append("GPTNeoXJapaneseTokenizer") + _import_structure["models.herbert"].append("HerbertTokenizerFast") + _import_structure["models.layoutlm"].append("LayoutLMTokenizerFast") + _import_structure["models.layoutlmv2"].append("LayoutLMv2TokenizerFast") + _import_structure["models.layoutlmv3"].append("LayoutLMv3TokenizerFast") + _import_structure["models.layoutxlm"].append("LayoutXLMTokenizerFast") + _import_structure["models.led"].append("LEDTokenizerFast") + _import_structure["models.llama"].append("LlamaTokenizerFast") + _import_structure["models.longformer"].append("LongformerTokenizerFast") + _import_structure["models.lxmert"].append("LxmertTokenizerFast") + _import_structure["models.markuplm"].append("MarkupLMTokenizerFast") + _import_structure["models.mbart"].append("MBartTokenizerFast") + _import_structure["models.mbart50"].append("MBart50TokenizerFast") + _import_structure["models.mobilebert"].append("MobileBertTokenizerFast") + _import_structure["models.mpnet"].append("MPNetTokenizerFast") + _import_structure["models.mt5"].append("MT5TokenizerFast") + _import_structure["models.mvp"].append("MvpTokenizerFast") + _import_structure["models.nllb"].append("NllbTokenizerFast") + _import_structure["models.nougat"].append("NougatTokenizerFast") + _import_structure["models.openai"].append("OpenAIGPTTokenizerFast") + _import_structure["models.pegasus"].append("PegasusTokenizerFast") + _import_structure["models.qwen2"].append("Qwen2TokenizerFast") + _import_structure["models.reformer"].append("ReformerTokenizerFast") + _import_structure["models.rembert"].append("RemBertTokenizerFast") + _import_structure["models.roberta"].append("RobertaTokenizerFast") + _import_structure["models.roformer"].append("RoFormerTokenizerFast") + _import_structure["models.seamless_m4t"].append("SeamlessM4TTokenizerFast") + _import_structure["models.splinter"].append("SplinterTokenizerFast") + _import_structure["models.squeezebert"].append("SqueezeBertTokenizerFast") + _import_structure["models.t5"].append("T5TokenizerFast") + _import_structure["models.udop"].append("UdopTokenizerFast") + _import_structure["models.whisper"].append("WhisperTokenizerFast") + _import_structure["models.xglm"].append("XGLMTokenizerFast") + _import_structure["models.xlm_roberta"].append("XLMRobertaTokenizerFast") + _import_structure["models.xlnet"].append("XLNetTokenizerFast") + _import_structure["tokenization_utils_fast"] = ["PreTrainedTokenizerFast"] + + +try: + if not (is_sentencepiece_available() and is_tokenizers_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from transformers.utils import dummy_sentencepiece_and_tokenizers_objects + + _import_structure["utils.dummy_sentencepiece_and_tokenizers_objects"] = [ + name for name in dir(dummy_sentencepiece_and_tokenizers_objects) if not name.startswith("_") + ] +else: + _import_structure["convert_slow_tokenizer"] = [ + "SLOW_TO_FAST_CONVERTERS", + "convert_slow_tokenizer", + ] + +# Vision-specific objects +try: + if not is_vision_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from transformers.utils import dummy_vision_objects + + _import_structure["utils.dummy_vision_objects"] = [ + name for name in dir(dummy_vision_objects) if not name.startswith("_") + ] +else: + _import_structure["image_processing_base"] = ["ImageProcessingMixin"] + _import_structure["image_processing_utils"] = ["BaseImageProcessor"] + _import_structure["image_utils"] = ["ImageFeatureExtractionMixin"] + _import_structure["models.aria"].extend(["AriaImageProcessor"]) + _import_structure["models.beit"].extend(["BeitFeatureExtractor", "BeitImageProcessor"]) + _import_structure["models.bit"].extend(["BitImageProcessor"]) + _import_structure["models.blip"].extend(["BlipImageProcessor"]) + _import_structure["models.bridgetower"].append("BridgeTowerImageProcessor") + _import_structure["models.chameleon"].append("ChameleonImageProcessor") + _import_structure["models.chinese_clip"].extend(["ChineseCLIPFeatureExtractor", "ChineseCLIPImageProcessor"]) + _import_structure["models.clip"].extend(["CLIPFeatureExtractor", "CLIPImageProcessor"]) + _import_structure["models.conditional_detr"].extend( + ["ConditionalDetrFeatureExtractor", "ConditionalDetrImageProcessor"] + ) + _import_structure["models.convnext"].extend(["ConvNextFeatureExtractor", "ConvNextImageProcessor"]) + _import_structure["models.deformable_detr"].extend( + ["DeformableDetrFeatureExtractor", "DeformableDetrImageProcessor"] + ) + _import_structure["models.deit"].extend(["DeiTFeatureExtractor", "DeiTImageProcessor"]) + _import_structure["models.deprecated.deta"].append("DetaImageProcessor") + _import_structure["models.deprecated.efficientformer"].append("EfficientFormerImageProcessor") + _import_structure["models.deprecated.tvlt"].append("TvltImageProcessor") + _import_structure["models.deprecated.vit_hybrid"].extend(["ViTHybridImageProcessor"]) + _import_structure["models.depth_pro"].extend(["DepthProImageProcessor", "DepthProImageProcessorFast"]) + _import_structure["models.detr"].extend(["DetrFeatureExtractor", "DetrImageProcessor"]) + _import_structure["models.donut"].extend(["DonutFeatureExtractor", "DonutImageProcessor"]) + _import_structure["models.dpt"].extend(["DPTFeatureExtractor", "DPTImageProcessor"]) + _import_structure["models.efficientnet"].append("EfficientNetImageProcessor") + _import_structure["models.emu3"].append("Emu3ImageProcessor") + _import_structure["models.flava"].extend(["FlavaFeatureExtractor", "FlavaImageProcessor", "FlavaProcessor"]) + _import_structure["models.fuyu"].extend(["FuyuImageProcessor", "FuyuProcessor"]) + _import_structure["models.gemma3"].append("Gemma3ImageProcessor") + _import_structure["models.glpn"].extend(["GLPNFeatureExtractor", "GLPNImageProcessor"]) + _import_structure["models.got_ocr2"].extend(["GotOcr2ImageProcessor"]) + _import_structure["models.grounding_dino"].extend(["GroundingDinoImageProcessor"]) + _import_structure["models.idefics"].extend(["IdeficsImageProcessor"]) + _import_structure["models.idefics2"].extend(["Idefics2ImageProcessor"]) + _import_structure["models.idefics3"].extend(["Idefics3ImageProcessor"]) + _import_structure["models.imagegpt"].extend(["ImageGPTFeatureExtractor", "ImageGPTImageProcessor"]) + _import_structure["models.instructblipvideo"].extend(["InstructBlipVideoImageProcessor"]) + _import_structure["models.layoutlmv2"].extend(["LayoutLMv2FeatureExtractor", "LayoutLMv2ImageProcessor"]) + _import_structure["models.layoutlmv3"].extend(["LayoutLMv3FeatureExtractor", "LayoutLMv3ImageProcessor"]) + _import_structure["models.levit"].extend(["LevitFeatureExtractor", "LevitImageProcessor"]) + _import_structure["models.llava"].append("LlavaImageProcessor") + _import_structure["models.llava_next"].append("LlavaNextImageProcessor") + _import_structure["models.llava_next_video"].append("LlavaNextVideoImageProcessor") + _import_structure["models.llava_onevision"].extend( + ["LlavaOnevisionImageProcessor", "LlavaOnevisionVideoProcessor"] + ) + _import_structure["models.mask2former"].append("Mask2FormerImageProcessor") + _import_structure["models.maskformer"].extend(["MaskFormerFeatureExtractor", "MaskFormerImageProcessor"]) + _import_structure["models.mllama"].extend(["MllamaImageProcessor"]) + _import_structure["models.mobilenet_v1"].extend(["MobileNetV1FeatureExtractor", "MobileNetV1ImageProcessor"]) + _import_structure["models.mobilenet_v2"].extend(["MobileNetV2FeatureExtractor", "MobileNetV2ImageProcessor"]) + _import_structure["models.mobilevit"].extend(["MobileViTFeatureExtractor", "MobileViTImageProcessor"]) + _import_structure["models.nougat"].append("NougatImageProcessor") + _import_structure["models.oneformer"].extend(["OneFormerImageProcessor"]) + _import_structure["models.owlv2"].append("Owlv2ImageProcessor") + _import_structure["models.owlvit"].extend(["OwlViTFeatureExtractor", "OwlViTImageProcessor"]) + _import_structure["models.perceiver"].extend(["PerceiverFeatureExtractor", "PerceiverImageProcessor"]) + _import_structure["models.pix2struct"].extend(["Pix2StructImageProcessor"]) + _import_structure["models.pixtral"].append("PixtralImageProcessor") + _import_structure["models.poolformer"].extend(["PoolFormerFeatureExtractor", "PoolFormerImageProcessor"]) + _import_structure["models.prompt_depth_anything"].extend(["PromptDepthAnythingImageProcessor"]) + _import_structure["models.pvt"].extend(["PvtImageProcessor"]) + _import_structure["models.qwen2_vl"].extend(["Qwen2VLImageProcessor"]) + _import_structure["models.rt_detr"].extend(["RTDetrImageProcessor"]) + _import_structure["models.sam"].extend(["SamImageProcessor"]) + _import_structure["models.segformer"].extend(["SegformerFeatureExtractor", "SegformerImageProcessor"]) + _import_structure["models.seggpt"].extend(["SegGptImageProcessor"]) + _import_structure["models.siglip"].append("SiglipImageProcessor") + _import_structure["models.siglip2"].append("Siglip2ImageProcessor") + _import_structure["models.smolvlm"].extend(["SmolVLMImageProcessor"]) + _import_structure["models.superglue"].extend(["SuperGlueImageProcessor"]) + _import_structure["models.superpoint"].extend(["SuperPointImageProcessor"]) + _import_structure["models.swin2sr"].append("Swin2SRImageProcessor") + _import_structure["models.textnet"].extend(["TextNetImageProcessor"]) + _import_structure["models.tvp"].append("TvpImageProcessor") + _import_structure["models.video_llava"].append("VideoLlavaImageProcessor") + _import_structure["models.videomae"].extend(["VideoMAEFeatureExtractor", "VideoMAEImageProcessor"]) + _import_structure["models.vilt"].extend(["ViltFeatureExtractor", "ViltImageProcessor", "ViltProcessor"]) + _import_structure["models.vit"].extend(["ViTFeatureExtractor", "ViTImageProcessor"]) + _import_structure["models.vitmatte"].append("VitMatteImageProcessor") + _import_structure["models.vitpose"].append("VitPoseImageProcessor") + _import_structure["models.vivit"].append("VivitImageProcessor") + _import_structure["models.yolos"].extend(["YolosFeatureExtractor", "YolosImageProcessor"]) + _import_structure["models.zoedepth"].append("ZoeDepthImageProcessor") + +try: + if not is_torchvision_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from transformers.utils import dummy_torchvision_objects + + _import_structure["utils.dummy_torchvision_objects"] = [ + name for name in dir(dummy_torchvision_objects) if not name.startswith("_") + ] +else: + _import_structure["image_processing_utils_fast"] = ["BaseImageProcessorFast"] + _import_structure["models.blip"].append("BlipImageProcessorFast") + _import_structure["models.clip"].append("CLIPImageProcessorFast") + _import_structure["models.convnext"].append("ConvNextImageProcessorFast") + _import_structure["models.deformable_detr"].append("DeformableDetrImageProcessorFast") + _import_structure["models.deit"].append("DeiTImageProcessorFast") + _import_structure["models.depth_pro"].append("DepthProImageProcessorFast") + _import_structure["models.detr"].append("DetrImageProcessorFast") + _import_structure["models.gemma3"].append("Gemma3ImageProcessorFast") + _import_structure["models.got_ocr2"].append("GotOcr2ImageProcessorFast") + _import_structure["models.llama4"].append("Llama4ImageProcessorFast") + _import_structure["models.llava"].append("LlavaImageProcessorFast") + _import_structure["models.llava_next"].append("LlavaNextImageProcessorFast") + _import_structure["models.llava_onevision"].append("LlavaOnevisionImageProcessorFast") + _import_structure["models.phi4_multimodal"].append("Phi4MultimodalImageProcessorFast") + _import_structure["models.pixtral"].append("PixtralImageProcessorFast") + _import_structure["models.qwen2_vl"].append("Qwen2VLImageProcessorFast") + _import_structure["models.rt_detr"].append("RTDetrImageProcessorFast") + _import_structure["models.siglip"].append("SiglipImageProcessorFast") + _import_structure["models.siglip2"].append("Siglip2ImageProcessorFast") + _import_structure["models.vit"].append("ViTImageProcessorFast") + +try: + if not (is_torchvision_available() and is_timm_available()): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from transformers.utils import dummy_timm_and_torchvision_objects + + _import_structure["utils.dummy_timm_and_torchvision_objects"] = [ + name for name in dir(dummy_timm_and_torchvision_objects) if not name.startswith("_") + ] +else: + _import_structure["models.timm_wrapper"].extend(["TimmWrapperImageProcessor"]) + +# PyTorch-backed objects +try: + if not is_torch_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from transformers.utils import dummy_pt_objects + + _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] +else: + _import_structure["model_debugging_utils"] = [ + "model_addition_debugger", + "model_addition_debugger_context", + ] + _import_structure["activations"] = [] + _import_structure["cache_utils"] = [ + "Cache", + "CacheConfig", + "DynamicCache", + "EncoderDecoderCache", + "HQQQuantizedCache", + "HybridCache", + "MambaCache", + "OffloadedCache", + "OffloadedStaticCache", + "QuantizedCache", + "QuantizedCacheConfig", + "QuantoQuantizedCache", + "SinkCache", + "SlidingWindowCache", + "StaticCache", + ] + _import_structure["data.datasets"] = [ + "GlueDataset", + "GlueDataTrainingArguments", + "LineByLineTextDataset", + "LineByLineWithRefDataset", + "LineByLineWithSOPTextDataset", + "SquadDataset", + "SquadDataTrainingArguments", + "TextDataset", + "TextDatasetForNextSentencePrediction", + ] + _import_structure["generation"].extend( + [ + "AlternatingCodebooksLogitsProcessor", + "BayesianDetectorConfig", + "BayesianDetectorModel", + "BeamScorer", + "BeamSearchScorer", + "ClassifierFreeGuidanceLogitsProcessor", + "ConstrainedBeamSearchScorer", + "Constraint", + "ConstraintListState", + "DisjunctiveConstraint", + "EncoderNoRepeatNGramLogitsProcessor", + "EncoderRepetitionPenaltyLogitsProcessor", + "EosTokenCriteria", + "EpsilonLogitsWarper", + "EtaLogitsWarper", + "ExponentialDecayLengthPenalty", + "ForcedBOSTokenLogitsProcessor", + "ForcedEOSTokenLogitsProcessor", + "GenerationMixin", + "HammingDiversityLogitsProcessor", + "InfNanRemoveLogitsProcessor", + "LogitNormalization", + "LogitsProcessor", + "LogitsProcessorList", + "MaxLengthCriteria", + "MaxTimeCriteria", + "MinLengthLogitsProcessor", + "MinNewTokensLengthLogitsProcessor", + "MinPLogitsWarper", + "NoBadWordsLogitsProcessor", + "NoRepeatNGramLogitsProcessor", + "PhrasalConstraint", + "PrefixConstrainedLogitsProcessor", + "RepetitionPenaltyLogitsProcessor", + "SequenceBiasLogitsProcessor", + "StoppingCriteria", + "StoppingCriteriaList", + "StopStringCriteria", + "SuppressTokensAtBeginLogitsProcessor", + "SuppressTokensLogitsProcessor", + "SynthIDTextWatermarkDetector", + "SynthIDTextWatermarkingConfig", + "SynthIDTextWatermarkLogitsProcessor", + "TemperatureLogitsWarper", + "TopKLogitsWarper", + "TopPLogitsWarper", + "TypicalLogitsWarper", + "UnbatchedClassifierFreeGuidanceLogitsProcessor", + "WatermarkDetector", + "WatermarkLogitsProcessor", + "WhisperTimeStampLogitsProcessor", + ] + ) + + _import_structure["modeling_flash_attention_utils"] = [] + _import_structure["modeling_outputs"] = [] + _import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS", "dynamic_rope_update"] + _import_structure["modeling_utils"] = ["PreTrainedModel", "AttentionInterface"] + + # PyTorch models structure + + _import_structure["models.albert"].extend( + [ + "AlbertForMaskedLM", + "AlbertForMultipleChoice", + "AlbertForPreTraining", + "AlbertForQuestionAnswering", + "AlbertForSequenceClassification", + "AlbertForTokenClassification", + "AlbertModel", + "AlbertPreTrainedModel", + ] + ) + + _import_structure["models.align"].extend( + [ + "AlignModel", + "AlignPreTrainedModel", + "AlignTextModel", + "AlignVisionModel", + ] + ) + _import_structure["models.altclip"].extend( + [ + "AltCLIPModel", + "AltCLIPPreTrainedModel", + "AltCLIPTextModel", + "AltCLIPVisionModel", + ] + ) + _import_structure["models.aria"].extend( + [ + "AriaForConditionalGeneration", + "AriaPreTrainedModel", + "AriaTextForCausalLM", + "AriaTextModel", + "AriaTextPreTrainedModel", + ] + ) + _import_structure["models.audio_spectrogram_transformer"].extend( + [ + "ASTForAudioClassification", + "ASTModel", + "ASTPreTrainedModel", + ] + ) + _import_structure["models.auto"].extend( + [ + "MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING", + "MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING", + "MODEL_FOR_AUDIO_XVECTOR_MAPPING", + "MODEL_FOR_BACKBONE_MAPPING", + "MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING", + "MODEL_FOR_CAUSAL_LM_MAPPING", + "MODEL_FOR_CTC_MAPPING", + "MODEL_FOR_DEPTH_ESTIMATION_MAPPING", + "MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING", + "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", + "MODEL_FOR_IMAGE_MAPPING", + "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING", + "MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING", + "MODEL_FOR_IMAGE_TO_IMAGE_MAPPING", + "MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING", + "MODEL_FOR_KEYPOINT_DETECTION_MAPPING", + "MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING", + "MODEL_FOR_MASKED_LM_MAPPING", + "MODEL_FOR_MASK_GENERATION_MAPPING", + "MODEL_FOR_MULTIPLE_CHOICE_MAPPING", + "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", + "MODEL_FOR_OBJECT_DETECTION_MAPPING", + "MODEL_FOR_PRETRAINING_MAPPING", + "MODEL_FOR_QUESTION_ANSWERING_MAPPING", + "MODEL_FOR_RETRIEVAL_MAPPING", + "MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING", + "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", + "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", + "MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING", + "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING", + "MODEL_FOR_TEXT_ENCODING_MAPPING", + "MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING", + "MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING", + "MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING", + "MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING", + "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", + "MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING", + "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING", + "MODEL_FOR_VISION_2_SEQ_MAPPING", + "MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING", + "MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING", + "MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING", + "MODEL_MAPPING", + "MODEL_WITH_LM_HEAD_MAPPING", + "AutoBackbone", + "AutoModel", + "AutoModelForAudioClassification", + "AutoModelForAudioFrameClassification", + "AutoModelForAudioXVector", + "AutoModelForCausalLM", + "AutoModelForCTC", + "AutoModelForDepthEstimation", + "AutoModelForDocumentQuestionAnswering", + "AutoModelForImageClassification", + "AutoModelForImageSegmentation", + "AutoModelForImageTextToText", + "AutoModelForImageToImage", + "AutoModelForInstanceSegmentation", + "AutoModelForKeypointDetection", + "AutoModelForMaskedImageModeling", + "AutoModelForMaskedLM", + "AutoModelForMaskGeneration", + "AutoModelForMultipleChoice", + "AutoModelForNextSentencePrediction", + "AutoModelForObjectDetection", + "AutoModelForPreTraining", + "AutoModelForQuestionAnswering", + "AutoModelForSemanticSegmentation", + "AutoModelForSeq2SeqLM", + "AutoModelForSequenceClassification", + "AutoModelForSpeechSeq2Seq", + "AutoModelForTableQuestionAnswering", + "AutoModelForTextEncoding", + "AutoModelForTextToSpectrogram", + "AutoModelForTextToWaveform", + "AutoModelForTokenClassification", + "AutoModelForUniversalSegmentation", + "AutoModelForVideoClassification", + "AutoModelForVision2Seq", + "AutoModelForVisualQuestionAnswering", + "AutoModelForZeroShotImageClassification", + "AutoModelForZeroShotObjectDetection", + "AutoModelWithLMHead", + ] + ) + _import_structure["models.autoformer"].extend( + [ + "AutoformerForPrediction", + "AutoformerModel", + "AutoformerPreTrainedModel", + ] + ) + _import_structure["models.aya_vision"].extend(["AyaVisionForConditionalGeneration", "AyaVisionPreTrainedModel"]) + _import_structure["models.bamba"].extend( + [ + "BambaForCausalLM", + "BambaModel", + "BambaPreTrainedModel", + ] + ) + _import_structure["models.bark"].extend( + [ + "BarkCausalModel", + "BarkCoarseModel", + "BarkFineModel", + "BarkModel", + "BarkPreTrainedModel", + "BarkSemanticModel", + ] + ) + _import_structure["models.bart"].extend( + [ + "BartForCausalLM", + "BartForConditionalGeneration", + "BartForQuestionAnswering", + "BartForSequenceClassification", + "BartModel", + "BartPretrainedModel", + "BartPreTrainedModel", + "PretrainedBartModel", + ] + ) + _import_structure["models.beit"].extend( + [ + "BeitBackbone", + "BeitForImageClassification", + "BeitForMaskedImageModeling", + "BeitForSemanticSegmentation", + "BeitModel", + "BeitPreTrainedModel", + ] + ) + _import_structure["models.bert"].extend( + [ + "BertForMaskedLM", + "BertForMultipleChoice", + "BertForNextSentencePrediction", + "BertForPreTraining", + "BertForQuestionAnswering", + "BertForSequenceClassification", + "BertForTokenClassification", + "BertLMHeadModel", + "BertModel", + "BertPreTrainedModel", + ] + ) + _import_structure["models.bert_generation"].extend( + [ + "BertGenerationDecoder", + "BertGenerationEncoder", + "BertGenerationPreTrainedModel", + ] + ) + _import_structure["models.big_bird"].extend( + [ + "BigBirdForCausalLM", + "BigBirdForMaskedLM", + "BigBirdForMultipleChoice", + "BigBirdForPreTraining", + "BigBirdForQuestionAnswering", + "BigBirdForSequenceClassification", + "BigBirdForTokenClassification", + "BigBirdModel", + "BigBirdPreTrainedModel", + ] + ) + _import_structure["models.bigbird_pegasus"].extend( + [ + "BigBirdPegasusForCausalLM", + "BigBirdPegasusForConditionalGeneration", + "BigBirdPegasusForQuestionAnswering", + "BigBirdPegasusForSequenceClassification", + "BigBirdPegasusModel", + "BigBirdPegasusPreTrainedModel", + ] + ) + _import_structure["models.biogpt"].extend( + [ + "BioGptForCausalLM", + "BioGptForSequenceClassification", + "BioGptForTokenClassification", + "BioGptModel", + "BioGptPreTrainedModel", + ] + ) + _import_structure["models.bit"].extend( + [ + "BitBackbone", + "BitForImageClassification", + "BitModel", + "BitPreTrainedModel", + ] + ) + _import_structure["models.blenderbot"].extend( + [ + "BlenderbotForCausalLM", + "BlenderbotForConditionalGeneration", + "BlenderbotModel", + "BlenderbotPreTrainedModel", + ] + ) + _import_structure["models.blenderbot_small"].extend( + [ + "BlenderbotSmallForCausalLM", + "BlenderbotSmallForConditionalGeneration", + "BlenderbotSmallModel", + "BlenderbotSmallPreTrainedModel", + ] + ) + _import_structure["models.blip"].extend( + [ + "BlipForConditionalGeneration", + "BlipForImageTextRetrieval", + "BlipForQuestionAnswering", + "BlipModel", + "BlipPreTrainedModel", + "BlipTextModel", + "BlipVisionModel", + ] + ) + _import_structure["models.blip_2"].extend( + [ + "Blip2ForConditionalGeneration", + "Blip2ForImageTextRetrieval", + "Blip2Model", + "Blip2PreTrainedModel", + "Blip2QFormerModel", + "Blip2TextModelWithProjection", + "Blip2VisionModel", + "Blip2VisionModelWithProjection", + ] + ) + _import_structure["models.bloom"].extend( + [ + "BloomForCausalLM", + "BloomForQuestionAnswering", + "BloomForSequenceClassification", + "BloomForTokenClassification", + "BloomModel", + "BloomPreTrainedModel", + ] + ) + _import_structure["models.bridgetower"].extend( + [ + "BridgeTowerForContrastiveLearning", + "BridgeTowerForImageAndTextRetrieval", + "BridgeTowerForMaskedLM", + "BridgeTowerModel", + "BridgeTowerPreTrainedModel", + ] + ) + _import_structure["models.bros"].extend( + [ + "BrosForTokenClassification", + "BrosModel", + "BrosPreTrainedModel", + "BrosProcessor", + "BrosSpadeEEForTokenClassification", + "BrosSpadeELForTokenClassification", + ] + ) + _import_structure["models.camembert"].extend( + [ + "CamembertForCausalLM", + "CamembertForMaskedLM", + "CamembertForMultipleChoice", + "CamembertForQuestionAnswering", + "CamembertForSequenceClassification", + "CamembertForTokenClassification", + "CamembertModel", + "CamembertPreTrainedModel", + ] + ) + _import_structure["models.canine"].extend( + [ + "CanineForMultipleChoice", + "CanineForQuestionAnswering", + "CanineForSequenceClassification", + "CanineForTokenClassification", + "CanineModel", + "CaninePreTrainedModel", + ] + ) + _import_structure["models.chameleon"].extend( + [ + "ChameleonForConditionalGeneration", + "ChameleonModel", + "ChameleonPreTrainedModel", + "ChameleonProcessor", + "ChameleonVQVAE", + ] + ) + _import_structure["models.chinese_clip"].extend( + [ + "ChineseCLIPModel", + "ChineseCLIPPreTrainedModel", + "ChineseCLIPTextModel", + "ChineseCLIPVisionModel", + ] + ) + _import_structure["models.clap"].extend( + [ + "ClapAudioModel", + "ClapAudioModelWithProjection", + "ClapFeatureExtractor", + "ClapModel", + "ClapPreTrainedModel", + "ClapTextModel", + "ClapTextModelWithProjection", + ] + ) + _import_structure["models.clip"].extend( + [ + "CLIPForImageClassification", + "CLIPModel", + "CLIPPreTrainedModel", + "CLIPTextModel", + "CLIPTextModelWithProjection", + "CLIPVisionModel", + "CLIPVisionModelWithProjection", + ] + ) + _import_structure["models.clipseg"].extend( + [ + "CLIPSegForImageSegmentation", + "CLIPSegModel", + "CLIPSegPreTrainedModel", + "CLIPSegTextModel", + "CLIPSegVisionModel", + ] + ) + _import_structure["models.clvp"].extend( + [ + "ClvpDecoder", + "ClvpEncoder", + "ClvpForCausalLM", + "ClvpModel", + "ClvpModelForConditionalGeneration", + "ClvpPreTrainedModel", + ] + ) + _import_structure["models.codegen"].extend( + [ + "CodeGenForCausalLM", + "CodeGenModel", + "CodeGenPreTrainedModel", + ] + ) + _import_structure["models.cohere"].extend(["CohereForCausalLM", "CohereModel", "CoherePreTrainedModel"]) + _import_structure["models.cohere2"].extend(["Cohere2ForCausalLM", "Cohere2Model", "Cohere2PreTrainedModel"]) + _import_structure["models.colpali"].extend( + [ + "ColPaliForRetrieval", + "ColPaliPreTrainedModel", + ] + ) + _import_structure["models.conditional_detr"].extend( + [ + "ConditionalDetrForObjectDetection", + "ConditionalDetrForSegmentation", + "ConditionalDetrModel", + "ConditionalDetrPreTrainedModel", + ] + ) + _import_structure["models.convbert"].extend( + [ + "ConvBertForMaskedLM", + "ConvBertForMultipleChoice", + "ConvBertForQuestionAnswering", + "ConvBertForSequenceClassification", + "ConvBertForTokenClassification", + "ConvBertModel", + "ConvBertPreTrainedModel", + ] + ) + _import_structure["models.convnext"].extend( + [ + "ConvNextBackbone", + "ConvNextForImageClassification", + "ConvNextModel", + "ConvNextPreTrainedModel", + ] + ) + _import_structure["models.convnextv2"].extend( + [ + "ConvNextV2Backbone", + "ConvNextV2ForImageClassification", + "ConvNextV2Model", + "ConvNextV2PreTrainedModel", + ] + ) + _import_structure["models.cpmant"].extend( + [ + "CpmAntForCausalLM", + "CpmAntModel", + "CpmAntPreTrainedModel", + ] + ) + _import_structure["models.ctrl"].extend( + [ + "CTRLForSequenceClassification", + "CTRLLMHeadModel", + "CTRLModel", + "CTRLPreTrainedModel", + ] + ) + _import_structure["models.cvt"].extend( + [ + "CvtForImageClassification", + "CvtModel", + "CvtPreTrainedModel", + ] + ) + _import_structure["models.dab_detr"].extend( + [ + "DabDetrForObjectDetection", + "DabDetrModel", + "DabDetrPreTrainedModel", + ] + ) + _import_structure["models.dac"].extend( + [ + "DacModel", + "DacPreTrainedModel", + ] + ) + _import_structure["models.data2vec"].extend( + [ + "Data2VecAudioForAudioFrameClassification", + "Data2VecAudioForCTC", + "Data2VecAudioForSequenceClassification", + "Data2VecAudioForXVector", + "Data2VecAudioModel", + "Data2VecAudioPreTrainedModel", + "Data2VecTextForCausalLM", + "Data2VecTextForMaskedLM", + "Data2VecTextForMultipleChoice", + "Data2VecTextForQuestionAnswering", + "Data2VecTextForSequenceClassification", + "Data2VecTextForTokenClassification", + "Data2VecTextModel", + "Data2VecTextPreTrainedModel", + "Data2VecVisionForImageClassification", + "Data2VecVisionForSemanticSegmentation", + "Data2VecVisionModel", + "Data2VecVisionPreTrainedModel", + ] + ) + _import_structure["models.dbrx"].extend( + [ + "DbrxForCausalLM", + "DbrxModel", + "DbrxPreTrainedModel", + ] + ) + _import_structure["models.deberta"].extend( + [ + "DebertaForMaskedLM", + "DebertaForQuestionAnswering", + "DebertaForSequenceClassification", + "DebertaForTokenClassification", + "DebertaModel", + "DebertaPreTrainedModel", + ] + ) + _import_structure["models.deberta_v2"].extend( + [ + "DebertaV2ForMaskedLM", + "DebertaV2ForMultipleChoice", + "DebertaV2ForQuestionAnswering", + "DebertaV2ForSequenceClassification", + "DebertaV2ForTokenClassification", + "DebertaV2Model", + "DebertaV2PreTrainedModel", + ] + ) + _import_structure["models.decision_transformer"].extend( + [ + "DecisionTransformerGPT2Model", + "DecisionTransformerGPT2PreTrainedModel", + "DecisionTransformerModel", + "DecisionTransformerPreTrainedModel", + ] + ) + _import_structure["models.deepseek_v3"].extend( + [ + "DeepseekV3ForCausalLM", + "DeepseekV3Model", + "DeepseekV3PreTrainedModel", + ] + ) + _import_structure["models.deformable_detr"].extend( + [ + "DeformableDetrForObjectDetection", + "DeformableDetrModel", + "DeformableDetrPreTrainedModel", + ] + ) + _import_structure["models.deit"].extend( + [ + "DeiTForImageClassification", + "DeiTForImageClassificationWithTeacher", + "DeiTForMaskedImageModeling", + "DeiTModel", + "DeiTPreTrainedModel", + ] + ) + _import_structure["models.deprecated.deta"].extend( + [ + "DetaForObjectDetection", + "DetaModel", + "DetaPreTrainedModel", + ] + ) + _import_structure["models.deprecated.efficientformer"].extend( + [ + "EfficientFormerForImageClassification", + "EfficientFormerForImageClassificationWithTeacher", + "EfficientFormerModel", + "EfficientFormerPreTrainedModel", + ] + ) + _import_structure["models.deprecated.ernie_m"].extend( + [ + "ErnieMForInformationExtraction", + "ErnieMForMultipleChoice", + "ErnieMForQuestionAnswering", + "ErnieMForSequenceClassification", + "ErnieMForTokenClassification", + "ErnieMModel", + "ErnieMPreTrainedModel", + ] + ) + _import_structure["models.deprecated.gptsan_japanese"].extend( + [ + "GPTSanJapaneseForConditionalGeneration", + "GPTSanJapaneseModel", + "GPTSanJapanesePreTrainedModel", + ] + ) + _import_structure["models.deprecated.graphormer"].extend( + [ + "GraphormerForGraphClassification", + "GraphormerModel", + "GraphormerPreTrainedModel", + ] + ) + _import_structure["models.deprecated.jukebox"].extend( + [ + "JukeboxModel", + "JukeboxPreTrainedModel", + "JukeboxPrior", + "JukeboxVQVAE", + ] + ) + _import_structure["models.deprecated.mctct"].extend( + [ + "MCTCTForCTC", + "MCTCTModel", + "MCTCTPreTrainedModel", + ] + ) + _import_structure["models.deprecated.mega"].extend( + [ + "MegaForCausalLM", + "MegaForMaskedLM", + "MegaForMultipleChoice", + "MegaForQuestionAnswering", + "MegaForSequenceClassification", + "MegaForTokenClassification", + "MegaModel", + "MegaPreTrainedModel", + ] + ) + _import_structure["models.deprecated.mmbt"].extend(["MMBTForClassification", "MMBTModel", "ModalEmbeddings"]) + _import_structure["models.deprecated.nat"].extend( + [ + "NatBackbone", + "NatForImageClassification", + "NatModel", + "NatPreTrainedModel", + ] + ) + _import_structure["models.deprecated.nezha"].extend( + [ + "NezhaForMaskedLM", + "NezhaForMultipleChoice", + "NezhaForNextSentencePrediction", + "NezhaForPreTraining", + "NezhaForQuestionAnswering", + "NezhaForSequenceClassification", + "NezhaForTokenClassification", + "NezhaModel", + "NezhaPreTrainedModel", + ] + ) + _import_structure["models.deprecated.open_llama"].extend( + [ + "OpenLlamaForCausalLM", + "OpenLlamaForSequenceClassification", + "OpenLlamaModel", + "OpenLlamaPreTrainedModel", + ] + ) + _import_structure["models.deprecated.qdqbert"].extend( + [ + "QDQBertForMaskedLM", + "QDQBertForMultipleChoice", + "QDQBertForNextSentencePrediction", + "QDQBertForQuestionAnswering", + "QDQBertForSequenceClassification", + "QDQBertForTokenClassification", + "QDQBertLMHeadModel", + "QDQBertModel", + "QDQBertPreTrainedModel", + ] + ) + _import_structure["models.deprecated.realm"].extend( + [ + "RealmEmbedder", + "RealmForOpenQA", + "RealmKnowledgeAugEncoder", + "RealmPreTrainedModel", + "RealmReader", + "RealmRetriever", + "RealmScorer", + ] + ) + _import_structure["models.deprecated.retribert"].extend( + [ + "RetriBertModel", + "RetriBertPreTrainedModel", + ] + ) + _import_structure["models.deprecated.speech_to_text_2"].extend( + ["Speech2Text2ForCausalLM", "Speech2Text2PreTrainedModel"] + ) + _import_structure["models.deprecated.trajectory_transformer"].extend( + [ + "TrajectoryTransformerModel", + "TrajectoryTransformerPreTrainedModel", + ] + ) + _import_structure["models.deprecated.transfo_xl"].extend( + [ + "AdaptiveEmbedding", + "TransfoXLForSequenceClassification", + "TransfoXLLMHeadModel", + "TransfoXLModel", + "TransfoXLPreTrainedModel", + ] + ) + _import_structure["models.deprecated.tvlt"].extend( + [ + "TvltForAudioVisualClassification", + "TvltForPreTraining", + "TvltModel", + "TvltPreTrainedModel", + ] + ) + _import_structure["models.deprecated.van"].extend( + [ + "VanForImageClassification", + "VanModel", + "VanPreTrainedModel", + ] + ) + _import_structure["models.deprecated.vit_hybrid"].extend( + [ + "ViTHybridForImageClassification", + "ViTHybridModel", + "ViTHybridPreTrainedModel", + ] + ) + _import_structure["models.deprecated.xlm_prophetnet"].extend( + [ + "XLMProphetNetDecoder", + "XLMProphetNetEncoder", + "XLMProphetNetForCausalLM", + "XLMProphetNetForConditionalGeneration", + "XLMProphetNetModel", + "XLMProphetNetPreTrainedModel", + ] + ) + _import_structure["models.depth_anything"].extend( + [ + "DepthAnythingForDepthEstimation", + "DepthAnythingPreTrainedModel", + ] + ) + _import_structure["models.depth_pro"].extend( + [ + "DepthProForDepthEstimation", + "DepthProModel", + "DepthProPreTrainedModel", + ] + ) + _import_structure["models.detr"].extend( + [ + "DetrForObjectDetection", + "DetrForSegmentation", + "DetrModel", + "DetrPreTrainedModel", + ] + ) + _import_structure["models.diffllama"].extend( + [ + "DiffLlamaForCausalLM", + "DiffLlamaForQuestionAnswering", + "DiffLlamaForSequenceClassification", + "DiffLlamaForTokenClassification", + "DiffLlamaModel", + "DiffLlamaPreTrainedModel", + ] + ) + _import_structure["models.dinat"].extend( + [ + "DinatBackbone", + "DinatForImageClassification", + "DinatModel", + "DinatPreTrainedModel", + ] + ) + _import_structure["models.dinov2"].extend( + [ + "Dinov2Backbone", + "Dinov2ForImageClassification", + "Dinov2Model", + "Dinov2PreTrainedModel", + ] + ) + _import_structure["models.dinov2_with_registers"].extend( + [ + "Dinov2WithRegistersBackbone", + "Dinov2WithRegistersForImageClassification", + "Dinov2WithRegistersModel", + "Dinov2WithRegistersPreTrainedModel", + ] + ) + _import_structure["models.distilbert"].extend( + [ + "DistilBertForMaskedLM", + "DistilBertForMultipleChoice", + "DistilBertForQuestionAnswering", + "DistilBertForSequenceClassification", + "DistilBertForTokenClassification", + "DistilBertModel", + "DistilBertPreTrainedModel", + ] + ) + _import_structure["models.donut"].extend( + [ + "DonutSwinModel", + "DonutSwinPreTrainedModel", + ] + ) + _import_structure["models.dpr"].extend( + [ + "DPRContextEncoder", + "DPRPretrainedContextEncoder", + "DPRPreTrainedModel", + "DPRPretrainedQuestionEncoder", + "DPRPretrainedReader", + "DPRQuestionEncoder", + "DPRReader", + ] + ) + _import_structure["models.dpt"].extend( + [ + "DPTForDepthEstimation", + "DPTForSemanticSegmentation", + "DPTModel", + "DPTPreTrainedModel", + ] + ) + _import_structure["models.efficientnet"].extend( + [ + "EfficientNetForImageClassification", + "EfficientNetModel", + "EfficientNetPreTrainedModel", + ] + ) + _import_structure["models.electra"].extend( + [ + "ElectraForCausalLM", + "ElectraForMaskedLM", + "ElectraForMultipleChoice", + "ElectraForPreTraining", + "ElectraForQuestionAnswering", + "ElectraForSequenceClassification", + "ElectraForTokenClassification", + "ElectraModel", + "ElectraPreTrainedModel", + ] + ) + _import_structure["models.emu3"].extend( + [ + "Emu3ForCausalLM", + "Emu3ForConditionalGeneration", + "Emu3PreTrainedModel", + "Emu3TextModel", + "Emu3VQVAE", + ] + ) + _import_structure["models.encodec"].extend( + [ + "EncodecModel", + "EncodecPreTrainedModel", + ] + ) + _import_structure["models.encoder_decoder"].append("EncoderDecoderModel") + _import_structure["models.ernie"].extend( + [ + "ErnieForCausalLM", + "ErnieForMaskedLM", + "ErnieForMultipleChoice", + "ErnieForNextSentencePrediction", + "ErnieForPreTraining", + "ErnieForQuestionAnswering", + "ErnieForSequenceClassification", + "ErnieForTokenClassification", + "ErnieModel", + "ErniePreTrainedModel", + ] + ) + _import_structure["models.esm"].extend( + [ + "EsmFoldPreTrainedModel", + "EsmForMaskedLM", + "EsmForProteinFolding", + "EsmForSequenceClassification", + "EsmForTokenClassification", + "EsmModel", + "EsmPreTrainedModel", + ] + ) + _import_structure["models.falcon"].extend( + [ + "FalconForCausalLM", + "FalconForQuestionAnswering", + "FalconForSequenceClassification", + "FalconForTokenClassification", + "FalconModel", + "FalconPreTrainedModel", + ] + ) + _import_structure["models.falcon_mamba"].extend( + [ + "FalconMambaForCausalLM", + "FalconMambaModel", + "FalconMambaPreTrainedModel", + ] + ) + _import_structure["models.fastspeech2_conformer"].extend( + [ + "FastSpeech2ConformerHifiGan", + "FastSpeech2ConformerModel", + "FastSpeech2ConformerPreTrainedModel", + "FastSpeech2ConformerWithHifiGan", + ] + ) + _import_structure["models.flaubert"].extend( + [ + "FlaubertForMultipleChoice", + "FlaubertForQuestionAnswering", + "FlaubertForQuestionAnsweringSimple", + "FlaubertForSequenceClassification", + "FlaubertForTokenClassification", + "FlaubertModel", + "FlaubertPreTrainedModel", + "FlaubertWithLMHeadModel", + ] + ) + _import_structure["models.flava"].extend( + [ + "FlavaForPreTraining", + "FlavaImageCodebook", + "FlavaImageModel", + "FlavaModel", + "FlavaMultimodalModel", + "FlavaPreTrainedModel", + "FlavaTextModel", + ] + ) + _import_structure["models.fnet"].extend( + [ + "FNetForMaskedLM", + "FNetForMultipleChoice", + "FNetForNextSentencePrediction", + "FNetForPreTraining", + "FNetForQuestionAnswering", + "FNetForSequenceClassification", + "FNetForTokenClassification", + "FNetModel", + "FNetPreTrainedModel", + ] + ) + _import_structure["models.focalnet"].extend( + [ + "FocalNetBackbone", + "FocalNetForImageClassification", + "FocalNetForMaskedImageModeling", + "FocalNetModel", + "FocalNetPreTrainedModel", + ] + ) + _import_structure["models.fsmt"].extend(["FSMTForConditionalGeneration", "FSMTModel", "PretrainedFSMTModel"]) + _import_structure["models.funnel"].extend( + [ + "FunnelBaseModel", + "FunnelForMaskedLM", + "FunnelForMultipleChoice", + "FunnelForPreTraining", + "FunnelForQuestionAnswering", + "FunnelForSequenceClassification", + "FunnelForTokenClassification", + "FunnelModel", + "FunnelPreTrainedModel", + ] + ) + _import_structure["models.fuyu"].extend(["FuyuForCausalLM", "FuyuPreTrainedModel"]) + _import_structure["models.gemma"].extend( + [ + "GemmaForCausalLM", + "GemmaForSequenceClassification", + "GemmaForTokenClassification", + "GemmaModel", + "GemmaPreTrainedModel", + ] + ) + _import_structure["models.gemma2"].extend( + [ + "Gemma2ForCausalLM", + "Gemma2ForSequenceClassification", + "Gemma2ForTokenClassification", + "Gemma2Model", + "Gemma2PreTrainedModel", + ] + ) + _import_structure["models.gemma3"].extend( + [ + "Gemma3ForCausalLM", + "Gemma3ForConditionalGeneration", + "Gemma3PreTrainedModel", + "Gemma3TextModel", + ] + ) + _import_structure["models.git"].extend( + [ + "GitForCausalLM", + "GitModel", + "GitPreTrainedModel", + "GitVisionModel", + ] + ) + _import_structure["models.glm"].extend( + [ + "GlmForCausalLM", + "GlmForSequenceClassification", + "GlmForTokenClassification", + "GlmModel", + "GlmPreTrainedModel", + ] + ) + _import_structure["models.llama4"].extend( + [ + "Llama4ForCausalLM", + "Llama4ForConditionalGeneration", + "Llama4TextModel", + "Llama4VisionModel", + "Llama4PreTrainedModel", + ] + ) + _import_structure["models.glpn"].extend( + [ + "GLPNForDepthEstimation", + "GLPNModel", + "GLPNPreTrainedModel", + ] + ) + _import_structure["models.got_ocr2"].extend( + [ + "GotOcr2ForConditionalGeneration", + "GotOcr2PreTrainedModel", + ] + ) + _import_structure["models.gpt2"].extend( + [ + "GPT2DoubleHeadsModel", + "GPT2ForQuestionAnswering", + "GPT2ForSequenceClassification", + "GPT2ForTokenClassification", + "GPT2LMHeadModel", + "GPT2Model", + "GPT2PreTrainedModel", + ] + ) + _import_structure["models.gpt_bigcode"].extend( + [ + "GPTBigCodeForCausalLM", + "GPTBigCodeForSequenceClassification", + "GPTBigCodeForTokenClassification", + "GPTBigCodeModel", + "GPTBigCodePreTrainedModel", + ] + ) + _import_structure["models.gpt_neo"].extend( + [ + "GPTNeoForCausalLM", + "GPTNeoForQuestionAnswering", + "GPTNeoForSequenceClassification", + "GPTNeoForTokenClassification", + "GPTNeoModel", + "GPTNeoPreTrainedModel", + ] + ) + _import_structure["models.gpt_neox"].extend( + [ + "GPTNeoXForCausalLM", + "GPTNeoXForQuestionAnswering", + "GPTNeoXForSequenceClassification", + "GPTNeoXForTokenClassification", + "GPTNeoXModel", + "GPTNeoXPreTrainedModel", + ] + ) + _import_structure["models.gpt_neox_japanese"].extend( + [ + "GPTNeoXJapaneseForCausalLM", + "GPTNeoXJapaneseModel", + "GPTNeoXJapanesePreTrainedModel", + ] + ) + _import_structure["models.gptj"].extend( + [ + "GPTJForCausalLM", + "GPTJForQuestionAnswering", + "GPTJForSequenceClassification", + "GPTJModel", + "GPTJPreTrainedModel", + ] + ) + _import_structure["models.granite"].extend( + [ + "GraniteForCausalLM", + "GraniteModel", + "GranitePreTrainedModel", + ] + ) + _import_structure["models.granitemoe"].extend( + [ + "GraniteMoeForCausalLM", + "GraniteMoeModel", + "GraniteMoePreTrainedModel", + ] + ) + + _import_structure["models.granitemoeshared"].extend( + [ + "GraniteMoeSharedForCausalLM", + "GraniteMoeSharedModel", + "GraniteMoeSharedPreTrainedModel", + ] + ) + _import_structure["models.grounding_dino"].extend( + [ + "GroundingDinoForObjectDetection", + "GroundingDinoModel", + "GroundingDinoPreTrainedModel", + ] + ) + _import_structure["models.groupvit"].extend( + [ + "GroupViTModel", + "GroupViTPreTrainedModel", + "GroupViTTextModel", + "GroupViTVisionModel", + ] + ) + _import_structure["models.helium"].extend( + [ + "HeliumForCausalLM", + "HeliumForSequenceClassification", + "HeliumForTokenClassification", + "HeliumModel", + "HeliumPreTrainedModel", + ] + ) + _import_structure["models.hiera"].extend( + [ + "HieraBackbone", + "HieraForImageClassification", + "HieraForPreTraining", + "HieraModel", + "HieraPreTrainedModel", + ] + ) + _import_structure["models.hubert"].extend( + [ + "HubertForCTC", + "HubertForSequenceClassification", + "HubertModel", + "HubertPreTrainedModel", + ] + ) + _import_structure["models.ibert"].extend( + [ + "IBertForMaskedLM", + "IBertForMultipleChoice", + "IBertForQuestionAnswering", + "IBertForSequenceClassification", + "IBertForTokenClassification", + "IBertModel", + "IBertPreTrainedModel", + ] + ) + _import_structure["models.idefics"].extend( + [ + "IdeficsForVisionText2Text", + "IdeficsModel", + "IdeficsPreTrainedModel", + "IdeficsProcessor", + ] + ) + _import_structure["models.idefics2"].extend( + [ + "Idefics2ForConditionalGeneration", + "Idefics2Model", + "Idefics2PreTrainedModel", + "Idefics2Processor", + ] + ) + _import_structure["models.idefics3"].extend( + [ + "Idefics3ForConditionalGeneration", + "Idefics3Model", + "Idefics3PreTrainedModel", + "Idefics3Processor", + "Idefics3VisionConfig", + "Idefics3VisionTransformer", + ] + ) + _import_structure["models.ijepa"].extend( + [ + "IJepaForImageClassification", + "IJepaModel", + "IJepaPreTrainedModel", + ] + ) + _import_structure["models.imagegpt"].extend( + [ + "ImageGPTForCausalImageModeling", + "ImageGPTForImageClassification", + "ImageGPTModel", + "ImageGPTPreTrainedModel", + ] + ) + _import_structure["models.informer"].extend( + [ + "InformerForPrediction", + "InformerModel", + "InformerPreTrainedModel", + ] + ) + _import_structure["models.instructblip"].extend( + [ + "InstructBlipForConditionalGeneration", + "InstructBlipPreTrainedModel", + "InstructBlipQFormerModel", + "InstructBlipVisionModel", + ] + ) + _import_structure["models.instructblipvideo"].extend( + [ + "InstructBlipVideoForConditionalGeneration", + "InstructBlipVideoPreTrainedModel", + "InstructBlipVideoQFormerModel", + "InstructBlipVideoVisionModel", + ] + ) + _import_structure["models.jamba"].extend( + [ + "JambaForCausalLM", + "JambaForSequenceClassification", + "JambaModel", + "JambaPreTrainedModel", + ] + ) + _import_structure["models.jetmoe"].extend( + [ + "JetMoeForCausalLM", + "JetMoeForSequenceClassification", + "JetMoeModel", + "JetMoePreTrainedModel", + ] + ) + _import_structure["models.kosmos2"].extend( + [ + "Kosmos2ForConditionalGeneration", + "Kosmos2Model", + "Kosmos2PreTrainedModel", + ] + ) + _import_structure["models.layoutlm"].extend( + [ + "LayoutLMForMaskedLM", + "LayoutLMForQuestionAnswering", + "LayoutLMForSequenceClassification", + "LayoutLMForTokenClassification", + "LayoutLMModel", + "LayoutLMPreTrainedModel", + ] + ) + _import_structure["models.layoutlmv2"].extend( + [ + "LayoutLMv2ForQuestionAnswering", + "LayoutLMv2ForSequenceClassification", + "LayoutLMv2ForTokenClassification", + "LayoutLMv2Model", + "LayoutLMv2PreTrainedModel", + ] + ) + _import_structure["models.layoutlmv3"].extend( + [ + "LayoutLMv3ForQuestionAnswering", + "LayoutLMv3ForSequenceClassification", + "LayoutLMv3ForTokenClassification", + "LayoutLMv3Model", + "LayoutLMv3PreTrainedModel", + ] + ) + _import_structure["models.led"].extend( + [ + "LEDForConditionalGeneration", + "LEDForQuestionAnswering", + "LEDForSequenceClassification", + "LEDModel", + "LEDPreTrainedModel", + ] + ) + _import_structure["models.levit"].extend( + [ + "LevitForImageClassification", + "LevitForImageClassificationWithTeacher", + "LevitModel", + "LevitPreTrainedModel", + ] + ) + _import_structure["models.lilt"].extend( + [ + "LiltForQuestionAnswering", + "LiltForSequenceClassification", + "LiltForTokenClassification", + "LiltModel", + "LiltPreTrainedModel", + ] + ) + _import_structure["models.llama"].extend( + [ + "LlamaForCausalLM", + "LlamaForQuestionAnswering", + "LlamaForSequenceClassification", + "LlamaForTokenClassification", + "LlamaModel", + "LlamaPreTrainedModel", + ] + ) + _import_structure["models.llava"].extend( + [ + "LlavaForConditionalGeneration", + "LlavaPreTrainedModel", + ] + ) + _import_structure["models.llava_next"].extend( + [ + "LlavaNextForConditionalGeneration", + "LlavaNextPreTrainedModel", + ] + ) + _import_structure["models.phi4_multimodal"].extend( + [ + "Phi4MultimodalForCausalLM", + "Phi4MultimodalPreTrainedModel", + "Phi4MultimodalAudioModel", + "Phi4MultimodalAudioPreTrainedModel", + "Phi4MultimodalModel", + "Phi4MultimodalVisionModel", + "Phi4MultimodalVisionPreTrainedModel", + ] + ) + _import_structure["models.llava_next_video"].extend( + [ + "LlavaNextVideoForConditionalGeneration", + "LlavaNextVideoPreTrainedModel", + ] + ) + _import_structure["models.llava_onevision"].extend( + [ + "LlavaOnevisionForConditionalGeneration", + "LlavaOnevisionPreTrainedModel", + ] + ) + _import_structure["models.longformer"].extend( + [ + "LongformerForMaskedLM", + "LongformerForMultipleChoice", + "LongformerForQuestionAnswering", + "LongformerForSequenceClassification", + "LongformerForTokenClassification", + "LongformerModel", + "LongformerPreTrainedModel", + ] + ) + _import_structure["models.longt5"].extend( + [ + "LongT5EncoderModel", + "LongT5ForConditionalGeneration", + "LongT5Model", + "LongT5PreTrainedModel", + ] + ) + _import_structure["models.luke"].extend( + [ + "LukeForEntityClassification", + "LukeForEntityPairClassification", + "LukeForEntitySpanClassification", + "LukeForMaskedLM", + "LukeForMultipleChoice", + "LukeForQuestionAnswering", + "LukeForSequenceClassification", + "LukeForTokenClassification", + "LukeModel", + "LukePreTrainedModel", + ] + ) + _import_structure["models.lxmert"].extend( + [ + "LxmertEncoder", + "LxmertForPreTraining", + "LxmertForQuestionAnswering", + "LxmertModel", + "LxmertPreTrainedModel", + "LxmertVisualFeatureEncoder", + ] + ) + _import_structure["models.m2m_100"].extend( + [ + "M2M100ForConditionalGeneration", + "M2M100Model", + "M2M100PreTrainedModel", + ] + ) + _import_structure["models.mamba"].extend( + [ + "MambaForCausalLM", + "MambaModel", + "MambaPreTrainedModel", + ] + ) + _import_structure["models.mamba2"].extend( + [ + "Mamba2ForCausalLM", + "Mamba2Model", + "Mamba2PreTrainedModel", + ] + ) + _import_structure["models.marian"].extend( + ["MarianForCausalLM", "MarianModel", "MarianMTModel", "MarianPreTrainedModel"] + ) + _import_structure["models.markuplm"].extend( + [ + "MarkupLMForQuestionAnswering", + "MarkupLMForSequenceClassification", + "MarkupLMForTokenClassification", + "MarkupLMModel", + "MarkupLMPreTrainedModel", + ] + ) + _import_structure["models.mask2former"].extend( + [ + "Mask2FormerForUniversalSegmentation", + "Mask2FormerModel", + "Mask2FormerPreTrainedModel", + ] + ) + _import_structure["models.maskformer"].extend( + [ + "MaskFormerForInstanceSegmentation", + "MaskFormerModel", + "MaskFormerPreTrainedModel", + "MaskFormerSwinBackbone", + ] + ) + _import_structure["models.mbart"].extend( + [ + "MBartForCausalLM", + "MBartForConditionalGeneration", + "MBartForQuestionAnswering", + "MBartForSequenceClassification", + "MBartModel", + "MBartPreTrainedModel", + ] + ) + _import_structure["models.megatron_bert"].extend( + [ + "MegatronBertForCausalLM", + "MegatronBertForMaskedLM", + "MegatronBertForMultipleChoice", + "MegatronBertForNextSentencePrediction", + "MegatronBertForPreTraining", + "MegatronBertForQuestionAnswering", + "MegatronBertForSequenceClassification", + "MegatronBertForTokenClassification", + "MegatronBertModel", + "MegatronBertPreTrainedModel", + ] + ) + _import_structure["models.mgp_str"].extend( + [ + "MgpstrForSceneTextRecognition", + "MgpstrModel", + "MgpstrPreTrainedModel", + ] + ) + _import_structure["models.mimi"].extend( + [ + "MimiModel", + "MimiPreTrainedModel", + ] + ) + _import_structure["models.mistral"].extend( + [ + "MistralForCausalLM", + "MistralForQuestionAnswering", + "MistralForSequenceClassification", + "MistralForTokenClassification", + "MistralModel", + "MistralPreTrainedModel", + ] + ) + _import_structure["models.mistral3"].extend( + [ + "Mistral3ForConditionalGeneration", + "Mistral3PreTrainedModel", + ] + ) + _import_structure["models.mixtral"].extend( + [ + "MixtralForCausalLM", + "MixtralForQuestionAnswering", + "MixtralForSequenceClassification", + "MixtralForTokenClassification", + "MixtralModel", + "MixtralPreTrainedModel", + ] + ) + _import_structure["models.mllama"].extend( + [ + "MllamaForCausalLM", + "MllamaForConditionalGeneration", + "MllamaPreTrainedModel", + "MllamaProcessor", + "MllamaTextModel", + "MllamaVisionModel", + ] + ) + _import_structure["models.mobilebert"].extend( + [ + "MobileBertForMaskedLM", + "MobileBertForMultipleChoice", + "MobileBertForNextSentencePrediction", + "MobileBertForPreTraining", + "MobileBertForQuestionAnswering", + "MobileBertForSequenceClassification", + "MobileBertForTokenClassification", + "MobileBertModel", + "MobileBertPreTrainedModel", + ] + ) + _import_structure["models.mobilenet_v1"].extend( + [ + "MobileNetV1ForImageClassification", + "MobileNetV1Model", + "MobileNetV1PreTrainedModel", + ] + ) + _import_structure["models.mobilenet_v2"].extend( + [ + "MobileNetV2ForImageClassification", + "MobileNetV2ForSemanticSegmentation", + "MobileNetV2Model", + "MobileNetV2PreTrainedModel", + ] + ) + _import_structure["models.mobilevit"].extend( + [ + "MobileViTForImageClassification", + "MobileViTForSemanticSegmentation", + "MobileViTModel", + "MobileViTPreTrainedModel", + ] + ) + _import_structure["models.mobilevitv2"].extend( + [ + "MobileViTV2ForImageClassification", + "MobileViTV2ForSemanticSegmentation", + "MobileViTV2Model", + "MobileViTV2PreTrainedModel", + ] + ) + _import_structure["models.modernbert"].extend( + [ + "ModernBertForMaskedLM", + "ModernBertForQuestionAnswering", + "ModernBertForSequenceClassification", + "ModernBertForTokenClassification", + "ModernBertModel", + "ModernBertPreTrainedModel", + ] + ) + _import_structure["models.moonshine"].extend( + [ + "MoonshineForConditionalGeneration", + "MoonshineModel", + "MoonshinePreTrainedModel", + ] + ) + _import_structure["models.moshi"].extend( + [ + "MoshiForCausalLM", + "MoshiForConditionalGeneration", + "MoshiModel", + "MoshiPreTrainedModel", + ] + ) + _import_structure["models.mpnet"].extend( + [ + "MPNetForMaskedLM", + "MPNetForMultipleChoice", + "MPNetForQuestionAnswering", + "MPNetForSequenceClassification", + "MPNetForTokenClassification", + "MPNetModel", + "MPNetPreTrainedModel", + ] + ) + _import_structure["models.mpt"].extend( + [ + "MptForCausalLM", + "MptForQuestionAnswering", + "MptForSequenceClassification", + "MptForTokenClassification", + "MptModel", + "MptPreTrainedModel", + ] + ) + _import_structure["models.mra"].extend( + [ + "MraForMaskedLM", + "MraForMultipleChoice", + "MraForQuestionAnswering", + "MraForSequenceClassification", + "MraForTokenClassification", + "MraModel", + "MraPreTrainedModel", + ] + ) + _import_structure["models.mt5"].extend( + [ + "MT5EncoderModel", + "MT5ForConditionalGeneration", + "MT5ForQuestionAnswering", + "MT5ForSequenceClassification", + "MT5ForTokenClassification", + "MT5Model", + "MT5PreTrainedModel", + ] + ) + _import_structure["models.musicgen"].extend( + [ + "MusicgenForCausalLM", + "MusicgenForConditionalGeneration", + "MusicgenModel", + "MusicgenPreTrainedModel", + "MusicgenProcessor", + ] + ) + _import_structure["models.musicgen_melody"].extend( + [ + "MusicgenMelodyForCausalLM", + "MusicgenMelodyForConditionalGeneration", + "MusicgenMelodyModel", + "MusicgenMelodyPreTrainedModel", + ] + ) + _import_structure["models.mvp"].extend( + [ + "MvpForCausalLM", + "MvpForConditionalGeneration", + "MvpForQuestionAnswering", + "MvpForSequenceClassification", + "MvpModel", + "MvpPreTrainedModel", + ] + ) + _import_structure["models.nemotron"].extend( + [ + "NemotronForCausalLM", + "NemotronForQuestionAnswering", + "NemotronForSequenceClassification", + "NemotronForTokenClassification", + "NemotronModel", + "NemotronPreTrainedModel", + ] + ) + _import_structure["models.nllb_moe"].extend( + [ + "NllbMoeForConditionalGeneration", + "NllbMoeModel", + "NllbMoePreTrainedModel", + "NllbMoeSparseMLP", + "NllbMoeTop2Router", + ] + ) + _import_structure["models.nystromformer"].extend( + [ + "NystromformerForMaskedLM", + "NystromformerForMultipleChoice", + "NystromformerForQuestionAnswering", + "NystromformerForSequenceClassification", + "NystromformerForTokenClassification", + "NystromformerModel", + "NystromformerPreTrainedModel", + ] + ) + _import_structure["models.olmo"].extend( + [ + "OlmoForCausalLM", + "OlmoModel", + "OlmoPreTrainedModel", + ] + ) + _import_structure["models.olmo2"].extend( + [ + "Olmo2ForCausalLM", + "Olmo2Model", + "Olmo2PreTrainedModel", + ] + ) + _import_structure["models.olmoe"].extend( + [ + "OlmoeForCausalLM", + "OlmoeModel", + "OlmoePreTrainedModel", + ] + ) + _import_structure["models.omdet_turbo"].extend( + [ + "OmDetTurboForObjectDetection", + "OmDetTurboPreTrainedModel", + ] + ) + _import_structure["models.oneformer"].extend( + [ + "OneFormerForUniversalSegmentation", + "OneFormerModel", + "OneFormerPreTrainedModel", + ] + ) + _import_structure["models.openai"].extend( + [ + "OpenAIGPTDoubleHeadsModel", + "OpenAIGPTForSequenceClassification", + "OpenAIGPTLMHeadModel", + "OpenAIGPTModel", + "OpenAIGPTPreTrainedModel", + ] + ) + _import_structure["models.opt"].extend( + [ + "OPTForCausalLM", + "OPTForQuestionAnswering", + "OPTForSequenceClassification", + "OPTModel", + "OPTPreTrainedModel", + ] + ) + _import_structure["models.owlv2"].extend( + [ + "Owlv2ForObjectDetection", + "Owlv2Model", + "Owlv2PreTrainedModel", + "Owlv2TextModel", + "Owlv2VisionModel", + ] + ) + _import_structure["models.owlvit"].extend( + [ + "OwlViTForObjectDetection", + "OwlViTModel", + "OwlViTPreTrainedModel", + "OwlViTTextModel", + "OwlViTVisionModel", + ] + ) + _import_structure["models.paligemma"].extend( + [ + "PaliGemmaForConditionalGeneration", + "PaliGemmaPreTrainedModel", + "PaliGemmaProcessor", + ] + ) + _import_structure["models.patchtsmixer"].extend( + [ + "PatchTSMixerForPrediction", + "PatchTSMixerForPretraining", + "PatchTSMixerForRegression", + "PatchTSMixerForTimeSeriesClassification", + "PatchTSMixerModel", + "PatchTSMixerPreTrainedModel", + ] + ) + _import_structure["models.patchtst"].extend( + [ + "PatchTSTForClassification", + "PatchTSTForPrediction", + "PatchTSTForPretraining", + "PatchTSTForRegression", + "PatchTSTModel", + "PatchTSTPreTrainedModel", + ] + ) + _import_structure["models.pegasus"].extend( + [ + "PegasusForCausalLM", + "PegasusForConditionalGeneration", + "PegasusModel", + "PegasusPreTrainedModel", + ] + ) + _import_structure["models.pegasus_x"].extend( + [ + "PegasusXForConditionalGeneration", + "PegasusXModel", + "PegasusXPreTrainedModel", + ] + ) + _import_structure["models.perceiver"].extend( + [ + "PerceiverForImageClassificationConvProcessing", + "PerceiverForImageClassificationFourier", + "PerceiverForImageClassificationLearned", + "PerceiverForMaskedLM", + "PerceiverForMultimodalAutoencoding", + "PerceiverForOpticalFlow", + "PerceiverForSequenceClassification", + "PerceiverModel", + "PerceiverPreTrainedModel", + ] + ) + _import_structure["models.persimmon"].extend( + [ + "PersimmonForCausalLM", + "PersimmonForSequenceClassification", + "PersimmonForTokenClassification", + "PersimmonModel", + "PersimmonPreTrainedModel", + ] + ) + _import_structure["models.phi"].extend( + [ + "PhiForCausalLM", + "PhiForSequenceClassification", + "PhiForTokenClassification", + "PhiModel", + "PhiPreTrainedModel", + ] + ) + _import_structure["models.phi3"].extend( + [ + "Phi3ForCausalLM", + "Phi3ForSequenceClassification", + "Phi3ForTokenClassification", + "Phi3Model", + "Phi3PreTrainedModel", + ] + ) + _import_structure["models.phimoe"].extend( + [ + "PhimoeForCausalLM", + "PhimoeForSequenceClassification", + "PhimoeModel", + "PhimoePreTrainedModel", + ] + ) + _import_structure["models.pix2struct"].extend( + [ + "Pix2StructForConditionalGeneration", + "Pix2StructPreTrainedModel", + "Pix2StructTextModel", + "Pix2StructVisionModel", + ] + ) + _import_structure["models.pixtral"].extend(["PixtralPreTrainedModel", "PixtralVisionModel"]) + _import_structure["models.plbart"].extend( + [ + "PLBartForCausalLM", + "PLBartForConditionalGeneration", + "PLBartForSequenceClassification", + "PLBartModel", + "PLBartPreTrainedModel", + ] + ) + _import_structure["models.poolformer"].extend( + [ + "PoolFormerForImageClassification", + "PoolFormerModel", + "PoolFormerPreTrainedModel", + ] + ) + _import_structure["models.pop2piano"].extend( + [ + "Pop2PianoForConditionalGeneration", + "Pop2PianoPreTrainedModel", + ] + ) + _import_structure["models.prompt_depth_anything"].extend( + [ + "PromptDepthAnythingForDepthEstimation", + "PromptDepthAnythingPreTrainedModel", + ] + ) + _import_structure["models.prophetnet"].extend( + [ + "ProphetNetDecoder", + "ProphetNetEncoder", + "ProphetNetForCausalLM", + "ProphetNetForConditionalGeneration", + "ProphetNetModel", + "ProphetNetPreTrainedModel", + ] + ) + _import_structure["models.pvt"].extend( + [ + "PvtForImageClassification", + "PvtModel", + "PvtPreTrainedModel", + ] + ) + _import_structure["models.pvt_v2"].extend( + [ + "PvtV2Backbone", + "PvtV2ForImageClassification", + "PvtV2Model", + "PvtV2PreTrainedModel", + ] + ) + _import_structure["models.qwen2"].extend( + [ + "Qwen2ForCausalLM", + "Qwen2ForQuestionAnswering", + "Qwen2ForSequenceClassification", + "Qwen2ForTokenClassification", + "Qwen2Model", + "Qwen2PreTrainedModel", + ] + ) + _import_structure["models.qwen2_5_vl"].extend( + [ + "Qwen2_5_VLForConditionalGeneration", + "Qwen2_5_VLModel", + "Qwen2_5_VLPreTrainedModel", + ] + ) + _import_structure["models.qwen2_audio"].extend( + [ + "Qwen2AudioEncoder", + "Qwen2AudioForConditionalGeneration", + "Qwen2AudioPreTrainedModel", + ] + ) + _import_structure["models.qwen2_moe"].extend( + [ + "Qwen2MoeForCausalLM", + "Qwen2MoeForQuestionAnswering", + "Qwen2MoeForSequenceClassification", + "Qwen2MoeForTokenClassification", + "Qwen2MoeModel", + "Qwen2MoePreTrainedModel", + ] + ) + _import_structure["models.qwen2_vl"].extend( + [ + "Qwen2VLForConditionalGeneration", + "Qwen2VLModel", + "Qwen2VLPreTrainedModel", + ] + ) + _import_structure["models.qwen3"].extend( + [ + "Qwen3ForCausalLM", + "Qwen3ForQuestionAnswering", + "Qwen3ForSequenceClassification", + "Qwen3ForTokenClassification", + "Qwen3Model", + "Qwen3PreTrainedModel", + ] + ) + _import_structure["models.qwen3_moe"].extend( + [ + "Qwen3MoeForCausalLM", + "Qwen3MoeForQuestionAnswering", + "Qwen3MoeForSequenceClassification", + "Qwen3MoeForTokenClassification", + "Qwen3MoeModel", + "Qwen3MoePreTrainedModel", + ] + ) + _import_structure["models.rag"].extend( + [ + "RagModel", + "RagPreTrainedModel", + "RagSequenceForGeneration", + "RagTokenForGeneration", + ] + ) + _import_structure["models.recurrent_gemma"].extend( + [ + "RecurrentGemmaForCausalLM", + "RecurrentGemmaModel", + "RecurrentGemmaPreTrainedModel", + ] + ) + _import_structure["models.reformer"].extend( + [ + "ReformerForMaskedLM", + "ReformerForQuestionAnswering", + "ReformerForSequenceClassification", + "ReformerModel", + "ReformerModelWithLMHead", + "ReformerPreTrainedModel", + ] + ) + _import_structure["models.regnet"].extend( + [ + "RegNetForImageClassification", + "RegNetModel", + "RegNetPreTrainedModel", + ] + ) + _import_structure["models.rembert"].extend( + [ + "RemBertForCausalLM", + "RemBertForMaskedLM", + "RemBertForMultipleChoice", + "RemBertForQuestionAnswering", + "RemBertForSequenceClassification", + "RemBertForTokenClassification", + "RemBertModel", + "RemBertPreTrainedModel", + ] + ) + _import_structure["models.resnet"].extend( + [ + "ResNetBackbone", + "ResNetForImageClassification", + "ResNetModel", + "ResNetPreTrainedModel", + ] + ) + _import_structure["models.roberta"].extend( + [ + "RobertaForCausalLM", + "RobertaForMaskedLM", + "RobertaForMultipleChoice", + "RobertaForQuestionAnswering", + "RobertaForSequenceClassification", + "RobertaForTokenClassification", + "RobertaModel", + "RobertaPreTrainedModel", + ] + ) + _import_structure["models.roberta_prelayernorm"].extend( + [ + "RobertaPreLayerNormForCausalLM", + "RobertaPreLayerNormForMaskedLM", + "RobertaPreLayerNormForMultipleChoice", + "RobertaPreLayerNormForQuestionAnswering", + "RobertaPreLayerNormForSequenceClassification", + "RobertaPreLayerNormForTokenClassification", + "RobertaPreLayerNormModel", + "RobertaPreLayerNormPreTrainedModel", + ] + ) + _import_structure["models.roc_bert"].extend( + [ + "RoCBertForCausalLM", + "RoCBertForMaskedLM", + "RoCBertForMultipleChoice", + "RoCBertForPreTraining", + "RoCBertForQuestionAnswering", + "RoCBertForSequenceClassification", + "RoCBertForTokenClassification", + "RoCBertModel", + "RoCBertPreTrainedModel", + ] + ) + _import_structure["models.roformer"].extend( + [ + "RoFormerForCausalLM", + "RoFormerForMaskedLM", + "RoFormerForMultipleChoice", + "RoFormerForQuestionAnswering", + "RoFormerForSequenceClassification", + "RoFormerForTokenClassification", + "RoFormerModel", + "RoFormerPreTrainedModel", + ] + ) + _import_structure["models.rt_detr"].extend( + [ + "RTDetrForObjectDetection", + "RTDetrModel", + "RTDetrPreTrainedModel", + "RTDetrResNetBackbone", + "RTDetrResNetPreTrainedModel", + ] + ) + _import_structure["models.rt_detr_v2"].extend( + ["RTDetrV2ForObjectDetection", "RTDetrV2Model", "RTDetrV2PreTrainedModel"] + ) + _import_structure["models.rwkv"].extend( + [ + "RwkvForCausalLM", + "RwkvModel", + "RwkvPreTrainedModel", + ] + ) + _import_structure["models.sam"].extend( + [ + "SamModel", + "SamPreTrainedModel", + "SamVisionModel", + ] + ) + _import_structure["models.seamless_m4t"].extend( + [ + "SeamlessM4TCodeHifiGan", + "SeamlessM4TForSpeechToSpeech", + "SeamlessM4TForSpeechToText", + "SeamlessM4TForTextToSpeech", + "SeamlessM4TForTextToText", + "SeamlessM4THifiGan", + "SeamlessM4TModel", + "SeamlessM4TPreTrainedModel", + "SeamlessM4TTextToUnitForConditionalGeneration", + "SeamlessM4TTextToUnitModel", + ] + ) + _import_structure["models.seamless_m4t_v2"].extend( + [ + "SeamlessM4Tv2ForSpeechToSpeech", + "SeamlessM4Tv2ForSpeechToText", + "SeamlessM4Tv2ForTextToSpeech", + "SeamlessM4Tv2ForTextToText", + "SeamlessM4Tv2Model", + "SeamlessM4Tv2PreTrainedModel", + ] + ) + _import_structure["models.segformer"].extend( + [ + "SegformerDecodeHead", + "SegformerForImageClassification", + "SegformerForSemanticSegmentation", + "SegformerModel", + "SegformerPreTrainedModel", + ] + ) + _import_structure["models.seggpt"].extend( + [ + "SegGptForImageSegmentation", + "SegGptModel", + "SegGptPreTrainedModel", + ] + ) + _import_structure["models.sew"].extend( + [ + "SEWForCTC", + "SEWForSequenceClassification", + "SEWModel", + "SEWPreTrainedModel", + ] + ) + _import_structure["models.sew_d"].extend( + [ + "SEWDForCTC", + "SEWDForSequenceClassification", + "SEWDModel", + "SEWDPreTrainedModel", + ] + ) + _import_structure["models.shieldgemma2"].append("ShieldGemma2ForImageClassification") + _import_structure["models.siglip"].extend( + [ + "SiglipForImageClassification", + "SiglipModel", + "SiglipPreTrainedModel", + "SiglipTextModel", + "SiglipVisionModel", + ] + ) + _import_structure["models.siglip2"].extend( + [ + "Siglip2ForImageClassification", + "Siglip2Model", + "Siglip2PreTrainedModel", + "Siglip2TextModel", + "Siglip2VisionModel", + ] + ) + _import_structure["models.smolvlm"].extend( + [ + "SmolVLMForConditionalGeneration", + "SmolVLMModel", + "SmolVLMPreTrainedModel", + "SmolVLMProcessor", + "SmolVLMVisionConfig", + "SmolVLMVisionTransformer", + ] + ) + _import_structure["models.speech_encoder_decoder"].extend(["SpeechEncoderDecoderModel"]) + _import_structure["models.speech_to_text"].extend( + [ + "Speech2TextForConditionalGeneration", + "Speech2TextModel", + "Speech2TextPreTrainedModel", + ] + ) + _import_structure["models.speecht5"].extend( + [ + "SpeechT5ForSpeechToSpeech", + "SpeechT5ForSpeechToText", + "SpeechT5ForTextToSpeech", + "SpeechT5HifiGan", + "SpeechT5Model", + "SpeechT5PreTrainedModel", + ] + ) + _import_structure["models.splinter"].extend( + [ + "SplinterForPreTraining", + "SplinterForQuestionAnswering", + "SplinterModel", + "SplinterPreTrainedModel", + ] + ) + _import_structure["models.squeezebert"].extend( + [ + "SqueezeBertForMaskedLM", + "SqueezeBertForMultipleChoice", + "SqueezeBertForQuestionAnswering", + "SqueezeBertForSequenceClassification", + "SqueezeBertForTokenClassification", + "SqueezeBertModel", + "SqueezeBertPreTrainedModel", + ] + ) + _import_structure["models.stablelm"].extend( + [ + "StableLmForCausalLM", + "StableLmForSequenceClassification", + "StableLmForTokenClassification", + "StableLmModel", + "StableLmPreTrainedModel", + ] + ) + _import_structure["models.starcoder2"].extend( + [ + "Starcoder2ForCausalLM", + "Starcoder2ForSequenceClassification", + "Starcoder2ForTokenClassification", + "Starcoder2Model", + "Starcoder2PreTrainedModel", + ] + ) + _import_structure["models.superglue"].extend( + [ + "SuperGlueForKeypointMatching", + "SuperGluePreTrainedModel", + ] + ) + _import_structure["models.superpoint"].extend( + [ + "SuperPointForKeypointDetection", + "SuperPointPreTrainedModel", + ] + ) + _import_structure["models.swiftformer"].extend( + [ + "SwiftFormerForImageClassification", + "SwiftFormerModel", + "SwiftFormerPreTrainedModel", + ] + ) + _import_structure["models.swin"].extend( + [ + "SwinBackbone", + "SwinForImageClassification", + "SwinForMaskedImageModeling", + "SwinModel", + "SwinPreTrainedModel", + ] + ) + _import_structure["models.swin2sr"].extend( + [ + "Swin2SRForImageSuperResolution", + "Swin2SRModel", + "Swin2SRPreTrainedModel", + ] + ) + _import_structure["models.swinv2"].extend( + [ + "Swinv2Backbone", + "Swinv2ForImageClassification", + "Swinv2ForMaskedImageModeling", + "Swinv2Model", + "Swinv2PreTrainedModel", + ] + ) + _import_structure["models.switch_transformers"].extend( + [ + "SwitchTransformersEncoderModel", + "SwitchTransformersForConditionalGeneration", + "SwitchTransformersModel", + "SwitchTransformersPreTrainedModel", + "SwitchTransformersSparseMLP", + "SwitchTransformersTop1Router", + ] + ) + _import_structure["models.t5"].extend( + [ + "T5EncoderModel", + "T5ForConditionalGeneration", + "T5ForQuestionAnswering", + "T5ForSequenceClassification", + "T5ForTokenClassification", + "T5Model", + "T5PreTrainedModel", + ] + ) + _import_structure["models.table_transformer"].extend( + [ + "TableTransformerForObjectDetection", + "TableTransformerModel", + "TableTransformerPreTrainedModel", + ] + ) + _import_structure["models.tapas"].extend( + [ + "TapasForMaskedLM", + "TapasForQuestionAnswering", + "TapasForSequenceClassification", + "TapasModel", + "TapasPreTrainedModel", + ] + ) + _import_structure["models.textnet"].extend( + [ + "TextNetBackbone", + "TextNetForImageClassification", + "TextNetModel", + "TextNetPreTrainedModel", + ] + ) + _import_structure["models.time_series_transformer"].extend( + [ + "TimeSeriesTransformerForPrediction", + "TimeSeriesTransformerModel", + "TimeSeriesTransformerPreTrainedModel", + ] + ) + _import_structure["models.timesformer"].extend( + [ + "TimesformerForVideoClassification", + "TimesformerModel", + "TimesformerPreTrainedModel", + ] + ) + _import_structure["models.timm_backbone"].extend(["TimmBackbone"]) + _import_structure["models.timm_wrapper"].extend( + ["TimmWrapperForImageClassification", "TimmWrapperModel", "TimmWrapperPreTrainedModel"] + ) + _import_structure["models.trocr"].extend( + [ + "TrOCRForCausalLM", + "TrOCRPreTrainedModel", + ] + ) + _import_structure["models.tvp"].extend( + [ + "TvpForVideoGrounding", + "TvpModel", + "TvpPreTrainedModel", + ] + ) + _import_structure["models.udop"].extend( + [ + "UdopEncoderModel", + "UdopForConditionalGeneration", + "UdopModel", + "UdopPreTrainedModel", + ], + ) + _import_structure["models.umt5"].extend( + [ + "UMT5EncoderModel", + "UMT5ForConditionalGeneration", + "UMT5ForQuestionAnswering", + "UMT5ForSequenceClassification", + "UMT5ForTokenClassification", + "UMT5Model", + "UMT5PreTrainedModel", + ] + ) + _import_structure["models.unispeech"].extend( + [ + "UniSpeechForCTC", + "UniSpeechForPreTraining", + "UniSpeechForSequenceClassification", + "UniSpeechModel", + "UniSpeechPreTrainedModel", + ] + ) + _import_structure["models.unispeech_sat"].extend( + [ + "UniSpeechSatForAudioFrameClassification", + "UniSpeechSatForCTC", + "UniSpeechSatForPreTraining", + "UniSpeechSatForSequenceClassification", + "UniSpeechSatForXVector", + "UniSpeechSatModel", + "UniSpeechSatPreTrainedModel", + ] + ) + _import_structure["models.univnet"].extend( + [ + "UnivNetModel", + ] + ) + _import_structure["models.upernet"].extend( + [ + "UperNetForSemanticSegmentation", + "UperNetPreTrainedModel", + ] + ) + _import_structure["models.video_llava"].extend( + [ + "VideoLlavaForConditionalGeneration", + "VideoLlavaPreTrainedModel", + "VideoLlavaProcessor", + ] + ) + _import_structure["models.videomae"].extend( + [ + "VideoMAEForPreTraining", + "VideoMAEForVideoClassification", + "VideoMAEModel", + "VideoMAEPreTrainedModel", + ] + ) + _import_structure["models.vilt"].extend( + [ + "ViltForImageAndTextRetrieval", + "ViltForImagesAndTextClassification", + "ViltForMaskedLM", + "ViltForQuestionAnswering", + "ViltForTokenClassification", + "ViltModel", + "ViltPreTrainedModel", + ] + ) + _import_structure["models.vipllava"].extend( + [ + "VipLlavaForConditionalGeneration", + "VipLlavaPreTrainedModel", + ] + ) + _import_structure["models.vision_encoder_decoder"].extend(["VisionEncoderDecoderModel"]) + _import_structure["models.vision_text_dual_encoder"].extend(["VisionTextDualEncoderModel"]) + _import_structure["models.visual_bert"].extend( + [ + "VisualBertForMultipleChoice", + "VisualBertForPreTraining", + "VisualBertForQuestionAnswering", + "VisualBertForRegionToPhraseAlignment", + "VisualBertForVisualReasoning", + "VisualBertModel", + "VisualBertPreTrainedModel", + ] + ) + _import_structure["models.vit"].extend( + [ + "ViTForImageClassification", + "ViTForMaskedImageModeling", + "ViTModel", + "ViTPreTrainedModel", + ] + ) + _import_structure["models.vit_mae"].extend( + [ + "ViTMAEForPreTraining", + "ViTMAEModel", + "ViTMAEPreTrainedModel", + ] + ) + _import_structure["models.vit_msn"].extend( + [ + "ViTMSNForImageClassification", + "ViTMSNModel", + "ViTMSNPreTrainedModel", + ] + ) + _import_structure["models.vitdet"].extend( + [ + "VitDetBackbone", + "VitDetModel", + "VitDetPreTrainedModel", + ] + ) + _import_structure["models.vitmatte"].extend( + [ + "VitMatteForImageMatting", + "VitMattePreTrainedModel", + ] + ) + _import_structure["models.vitpose"].extend( + [ + "VitPoseForPoseEstimation", + "VitPosePreTrainedModel", + ] + ) + _import_structure["models.vitpose_backbone"].extend( + [ + "VitPoseBackbone", + "VitPoseBackbonePreTrainedModel", + ] + ) + _import_structure["models.vits"].extend( + [ + "VitsModel", + "VitsPreTrainedModel", + ] + ) + _import_structure["models.vivit"].extend( + [ + "VivitForVideoClassification", + "VivitModel", + "VivitPreTrainedModel", + ] + ) + _import_structure["models.wav2vec2"].extend( + [ + "Wav2Vec2ForAudioFrameClassification", + "Wav2Vec2ForCTC", + "Wav2Vec2ForMaskedLM", + "Wav2Vec2ForPreTraining", + "Wav2Vec2ForSequenceClassification", + "Wav2Vec2ForXVector", + "Wav2Vec2Model", + "Wav2Vec2PreTrainedModel", + ] + ) + _import_structure["models.wav2vec2_bert"].extend( + [ + "Wav2Vec2BertForAudioFrameClassification", + "Wav2Vec2BertForCTC", + "Wav2Vec2BertForSequenceClassification", + "Wav2Vec2BertForXVector", + "Wav2Vec2BertModel", + "Wav2Vec2BertPreTrainedModel", + ] + ) + _import_structure["models.wav2vec2_conformer"].extend( + [ + "Wav2Vec2ConformerForAudioFrameClassification", + "Wav2Vec2ConformerForCTC", + "Wav2Vec2ConformerForPreTraining", + "Wav2Vec2ConformerForSequenceClassification", + "Wav2Vec2ConformerForXVector", + "Wav2Vec2ConformerModel", + "Wav2Vec2ConformerPreTrainedModel", + ] + ) + _import_structure["models.wavlm"].extend( + [ + "WavLMForAudioFrameClassification", + "WavLMForCTC", + "WavLMForSequenceClassification", + "WavLMForXVector", + "WavLMModel", + "WavLMPreTrainedModel", + ] + ) + _import_structure["models.whisper"].extend( + [ + "WhisperForAudioClassification", + "WhisperForCausalLM", + "WhisperForConditionalGeneration", + "WhisperModel", + "WhisperPreTrainedModel", + ] + ) + _import_structure["models.x_clip"].extend( + [ + "XCLIPModel", + "XCLIPPreTrainedModel", + "XCLIPTextModel", + "XCLIPVisionModel", + ] + ) + _import_structure["models.xglm"].extend( + [ + "XGLMForCausalLM", + "XGLMModel", + "XGLMPreTrainedModel", + ] + ) + _import_structure["models.xlm"].extend( + [ + "XLMForMultipleChoice", + "XLMForQuestionAnswering", + "XLMForQuestionAnsweringSimple", + "XLMForSequenceClassification", + "XLMForTokenClassification", + "XLMModel", + "XLMPreTrainedModel", + "XLMWithLMHeadModel", + ] + ) + _import_structure["models.xlm_roberta"].extend( + [ + "XLMRobertaForCausalLM", + "XLMRobertaForMaskedLM", + "XLMRobertaForMultipleChoice", + "XLMRobertaForQuestionAnswering", + "XLMRobertaForSequenceClassification", + "XLMRobertaForTokenClassification", + "XLMRobertaModel", + "XLMRobertaPreTrainedModel", + ] + ) + _import_structure["models.xlm_roberta_xl"].extend( + [ + "XLMRobertaXLForCausalLM", + "XLMRobertaXLForMaskedLM", + "XLMRobertaXLForMultipleChoice", + "XLMRobertaXLForQuestionAnswering", + "XLMRobertaXLForSequenceClassification", + "XLMRobertaXLForTokenClassification", + "XLMRobertaXLModel", + "XLMRobertaXLPreTrainedModel", + ] + ) + _import_structure["models.xlnet"].extend( + [ + "XLNetForMultipleChoice", + "XLNetForQuestionAnswering", + "XLNetForQuestionAnsweringSimple", + "XLNetForSequenceClassification", + "XLNetForTokenClassification", + "XLNetLMHeadModel", + "XLNetModel", + "XLNetPreTrainedModel", + ] + ) + _import_structure["models.xmod"].extend( + [ + "XmodForCausalLM", + "XmodForMaskedLM", + "XmodForMultipleChoice", + "XmodForQuestionAnswering", + "XmodForSequenceClassification", + "XmodForTokenClassification", + "XmodModel", + "XmodPreTrainedModel", + ] + ) + _import_structure["models.yolos"].extend( + [ + "YolosForObjectDetection", + "YolosModel", + "YolosPreTrainedModel", + ] + ) + _import_structure["models.yoso"].extend( + [ + "YosoForMaskedLM", + "YosoForMultipleChoice", + "YosoForQuestionAnswering", + "YosoForSequenceClassification", + "YosoForTokenClassification", + "YosoModel", + "YosoPreTrainedModel", + ] + ) + _import_structure["models.zamba"].extend( + [ + "ZambaForCausalLM", + "ZambaForSequenceClassification", + "ZambaModel", + "ZambaPreTrainedModel", + ] + ) + _import_structure["models.zamba2"].extend( + [ + "Zamba2ForCausalLM", + "Zamba2ForSequenceClassification", + "Zamba2Model", + "Zamba2PreTrainedModel", + ] + ) + _import_structure["models.zoedepth"].extend( + [ + "ZoeDepthForDepthEstimation", + "ZoeDepthPreTrainedModel", + ] + ) + _import_structure["optimization"] = [ + "Adafactor", + "get_constant_schedule", + "get_constant_schedule_with_warmup", + "get_cosine_schedule_with_warmup", + "get_cosine_with_hard_restarts_schedule_with_warmup", + "get_inverse_sqrt_schedule", + "get_linear_schedule_with_warmup", + "get_polynomial_decay_schedule_with_warmup", + "get_scheduler", + "get_wsd_schedule", + ] + _import_structure["pytorch_utils"] = [ + "Conv1D", + "apply_chunking_to_forward", + "prune_layer", + ] + _import_structure["sagemaker"] = [] + _import_structure["time_series_utils"] = [] + +try: + if not ( + is_librosa_available() + and is_essentia_available() + and is_scipy_available() + and is_torch_available() + and is_pretty_midi_available() + ): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from transformers.utils import ( + dummy_essentia_and_librosa_and_pretty_midi_and_scipy_and_torch_objects, + ) + + _import_structure["utils.dummy_essentia_and_librosa_and_pretty_midi_and_scipy_and_torch_objects"] = [ + name + for name in dir(dummy_essentia_and_librosa_and_pretty_midi_and_scipy_and_torch_objects) + if not name.startswith("_") + ] +else: + _import_structure["models.pop2piano"].append("Pop2PianoFeatureExtractor") + _import_structure["models.pop2piano"].append("Pop2PianoTokenizer") + _import_structure["models.pop2piano"].append("Pop2PianoProcessor") + +try: + if not is_torchaudio_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from transformers.utils import ( + dummy_torchaudio_objects, + ) + + _import_structure["utils.dummy_torchaudio_objects"] = [ + name for name in dir(dummy_torchaudio_objects) if not name.startswith("_") + ] +else: + _import_structure["models.musicgen_melody"].append("MusicgenMelodyFeatureExtractor") + _import_structure["models.musicgen_melody"].append("MusicgenMelodyProcessor") + + +# FLAX-backed objects +try: + if not is_flax_available(): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from transformers.utils import dummy_flax_objects + + _import_structure["utils.dummy_flax_objects"] = [ + name for name in dir(dummy_flax_objects) if not name.startswith("_") + ] +else: + _import_structure["generation"].extend( + [ + "FlaxForcedBOSTokenLogitsProcessor", + "FlaxForcedEOSTokenLogitsProcessor", + "FlaxForceTokensLogitsProcessor", + "FlaxGenerationMixin", + "FlaxLogitsProcessor", + "FlaxLogitsProcessorList", + "FlaxLogitsWarper", + "FlaxMinLengthLogitsProcessor", + "FlaxTemperatureLogitsWarper", + "FlaxSuppressTokensAtBeginLogitsProcessor", + "FlaxSuppressTokensLogitsProcessor", + "FlaxTopKLogitsWarper", + "FlaxTopPLogitsWarper", + "FlaxWhisperTimeStampLogitsProcessor", + ] + ) + _import_structure["modeling_flax_outputs"] = [] + _import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"] + _import_structure["models.albert"].extend( + [ + "FlaxAlbertForMaskedLM", + "FlaxAlbertForMultipleChoice", + "FlaxAlbertForPreTraining", + "FlaxAlbertForQuestionAnswering", + "FlaxAlbertForSequenceClassification", + "FlaxAlbertForTokenClassification", + "FlaxAlbertModel", + "FlaxAlbertPreTrainedModel", + ] + ) + _import_structure["models.auto"].extend( + [ + "FLAX_MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING", + "FLAX_MODEL_FOR_CAUSAL_LM_MAPPING", + "FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", + "FLAX_MODEL_FOR_MASKED_LM_MAPPING", + "FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING", + "FLAX_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", + "FLAX_MODEL_FOR_PRETRAINING_MAPPING", + "FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING", + "FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", + "FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", + "FLAX_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING", + "FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", + "FLAX_MODEL_FOR_VISION_2_SEQ_MAPPING", + "FLAX_MODEL_MAPPING", + "FlaxAutoModel", + "FlaxAutoModelForCausalLM", + "FlaxAutoModelForImageClassification", + "FlaxAutoModelForMaskedLM", + "FlaxAutoModelForMultipleChoice", + "FlaxAutoModelForNextSentencePrediction", + "FlaxAutoModelForPreTraining", + "FlaxAutoModelForQuestionAnswering", + "FlaxAutoModelForSeq2SeqLM", + "FlaxAutoModelForSequenceClassification", + "FlaxAutoModelForSpeechSeq2Seq", + "FlaxAutoModelForTokenClassification", + "FlaxAutoModelForVision2Seq", + ] + ) + + # Flax models structure + + _import_structure["models.bart"].extend( + [ + "FlaxBartDecoderPreTrainedModel", + "FlaxBartForCausalLM", + "FlaxBartForConditionalGeneration", + "FlaxBartForQuestionAnswering", + "FlaxBartForSequenceClassification", + "FlaxBartModel", + "FlaxBartPreTrainedModel", + ] + ) + _import_structure["models.beit"].extend( + [ + "FlaxBeitForImageClassification", + "FlaxBeitForMaskedImageModeling", + "FlaxBeitModel", + "FlaxBeitPreTrainedModel", + ] + ) + + _import_structure["models.bert"].extend( + [ + "FlaxBertForCausalLM", + "FlaxBertForMaskedLM", + "FlaxBertForMultipleChoice", + "FlaxBertForNextSentencePrediction", + "FlaxBertForPreTraining", + "FlaxBertForQuestionAnswering", + "FlaxBertForSequenceClassification", + "FlaxBertForTokenClassification", + "FlaxBertModel", + "FlaxBertPreTrainedModel", + ] + ) + _import_structure["models.big_bird"].extend( + [ + "FlaxBigBirdForCausalLM", + "FlaxBigBirdForMaskedLM", + "FlaxBigBirdForMultipleChoice", + "FlaxBigBirdForPreTraining", + "FlaxBigBirdForQuestionAnswering", + "FlaxBigBirdForSequenceClassification", + "FlaxBigBirdForTokenClassification", + "FlaxBigBirdModel", + "FlaxBigBirdPreTrainedModel", + ] + ) + _import_structure["models.blenderbot"].extend( + [ + "FlaxBlenderbotForConditionalGeneration", + "FlaxBlenderbotModel", + "FlaxBlenderbotPreTrainedModel", + ] + ) + _import_structure["models.blenderbot_small"].extend( + [ + "FlaxBlenderbotSmallForConditionalGeneration", + "FlaxBlenderbotSmallModel", + "FlaxBlenderbotSmallPreTrainedModel", + ] + ) + _import_structure["models.bloom"].extend( + [ + "FlaxBloomForCausalLM", + "FlaxBloomModel", + "FlaxBloomPreTrainedModel", + ] + ) + _import_structure["models.clip"].extend( + [ + "FlaxCLIPModel", + "FlaxCLIPPreTrainedModel", + "FlaxCLIPTextModel", + "FlaxCLIPTextPreTrainedModel", + "FlaxCLIPTextModelWithProjection", + "FlaxCLIPVisionModel", + "FlaxCLIPVisionPreTrainedModel", + ] + ) + _import_structure["models.dinov2"].extend( + [ + "FlaxDinov2Model", + "FlaxDinov2ForImageClassification", + "FlaxDinov2PreTrainedModel", + ] + ) + _import_structure["models.distilbert"].extend( + [ + "FlaxDistilBertForMaskedLM", + "FlaxDistilBertForMultipleChoice", + "FlaxDistilBertForQuestionAnswering", + "FlaxDistilBertForSequenceClassification", + "FlaxDistilBertForTokenClassification", + "FlaxDistilBertModel", + "FlaxDistilBertPreTrainedModel", + ] + ) + _import_structure["models.electra"].extend( + [ + "FlaxElectraForCausalLM", + "FlaxElectraForMaskedLM", + "FlaxElectraForMultipleChoice", + "FlaxElectraForPreTraining", + "FlaxElectraForQuestionAnswering", + "FlaxElectraForSequenceClassification", + "FlaxElectraForTokenClassification", + "FlaxElectraModel", + "FlaxElectraPreTrainedModel", + ] + ) + _import_structure["models.encoder_decoder"].append("FlaxEncoderDecoderModel") + _import_structure["models.gpt2"].extend(["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"]) + _import_structure["models.gpt_neo"].extend( + ["FlaxGPTNeoForCausalLM", "FlaxGPTNeoModel", "FlaxGPTNeoPreTrainedModel"] + ) + _import_structure["models.gptj"].extend(["FlaxGPTJForCausalLM", "FlaxGPTJModel", "FlaxGPTJPreTrainedModel"]) + _import_structure["models.llama"].extend(["FlaxLlamaForCausalLM", "FlaxLlamaModel", "FlaxLlamaPreTrainedModel"]) + _import_structure["models.gemma"].extend(["FlaxGemmaForCausalLM", "FlaxGemmaModel", "FlaxGemmaPreTrainedModel"]) + _import_structure["models.longt5"].extend( + [ + "FlaxLongT5ForConditionalGeneration", + "FlaxLongT5Model", + "FlaxLongT5PreTrainedModel", + ] + ) + _import_structure["models.marian"].extend( + [ + "FlaxMarianModel", + "FlaxMarianMTModel", + "FlaxMarianPreTrainedModel", + ] + ) + _import_structure["models.mbart"].extend( + [ + "FlaxMBartForConditionalGeneration", + "FlaxMBartForQuestionAnswering", + "FlaxMBartForSequenceClassification", + "FlaxMBartModel", + "FlaxMBartPreTrainedModel", + ] + ) + _import_structure["models.mistral"].extend( + [ + "FlaxMistralForCausalLM", + "FlaxMistralModel", + "FlaxMistralPreTrainedModel", + ] + ) + _import_structure["models.mt5"].extend(["FlaxMT5EncoderModel", "FlaxMT5ForConditionalGeneration", "FlaxMT5Model"]) + _import_structure["models.opt"].extend( + [ + "FlaxOPTForCausalLM", + "FlaxOPTModel", + "FlaxOPTPreTrainedModel", + ] + ) + _import_structure["models.pegasus"].extend( + [ + "FlaxPegasusForConditionalGeneration", + "FlaxPegasusModel", + "FlaxPegasusPreTrainedModel", + ] + ) + _import_structure["models.regnet"].extend( + [ + "FlaxRegNetForImageClassification", + "FlaxRegNetModel", + "FlaxRegNetPreTrainedModel", + ] + ) + _import_structure["models.resnet"].extend( + [ + "FlaxResNetForImageClassification", + "FlaxResNetModel", + "FlaxResNetPreTrainedModel", + ] + ) + _import_structure["models.roberta"].extend( + [ + "FlaxRobertaForCausalLM", + "FlaxRobertaForMaskedLM", + "FlaxRobertaForMultipleChoice", + "FlaxRobertaForQuestionAnswering", + "FlaxRobertaForSequenceClassification", + "FlaxRobertaForTokenClassification", + "FlaxRobertaModel", + "FlaxRobertaPreTrainedModel", + ] + ) + _import_structure["models.roberta_prelayernorm"].extend( + [ + "FlaxRobertaPreLayerNormForCausalLM", + "FlaxRobertaPreLayerNormForMaskedLM", + "FlaxRobertaPreLayerNormForMultipleChoice", + "FlaxRobertaPreLayerNormForQuestionAnswering", + "FlaxRobertaPreLayerNormForSequenceClassification", + "FlaxRobertaPreLayerNormForTokenClassification", + "FlaxRobertaPreLayerNormModel", + "FlaxRobertaPreLayerNormPreTrainedModel", + ] + ) + _import_structure["models.roformer"].extend( + [ + "FlaxRoFormerForMaskedLM", + "FlaxRoFormerForMultipleChoice", + "FlaxRoFormerForQuestionAnswering", + "FlaxRoFormerForSequenceClassification", + "FlaxRoFormerForTokenClassification", + "FlaxRoFormerModel", + "FlaxRoFormerPreTrainedModel", + ] + ) + _import_structure["models.speech_encoder_decoder"].append("FlaxSpeechEncoderDecoderModel") + _import_structure["models.t5"].extend( + [ + "FlaxT5EncoderModel", + "FlaxT5ForConditionalGeneration", + "FlaxT5Model", + "FlaxT5PreTrainedModel", + ] + ) + _import_structure["models.vision_encoder_decoder"].append("FlaxVisionEncoderDecoderModel") + _import_structure["models.vision_text_dual_encoder"].extend(["FlaxVisionTextDualEncoderModel"]) + _import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"]) + _import_structure["models.wav2vec2"].extend( + [ + "FlaxWav2Vec2ForCTC", + "FlaxWav2Vec2ForPreTraining", + "FlaxWav2Vec2Model", + "FlaxWav2Vec2PreTrainedModel", + ] + ) + _import_structure["models.whisper"].extend( + [ + "FlaxWhisperForConditionalGeneration", + "FlaxWhisperModel", + "FlaxWhisperPreTrainedModel", + "FlaxWhisperForAudioClassification", + ] + ) + _import_structure["models.xglm"].extend( + [ + "FlaxXGLMForCausalLM", + "FlaxXGLMModel", + "FlaxXGLMPreTrainedModel", + ] + ) + _import_structure["models.xlm_roberta"].extend( + [ + "FlaxXLMRobertaForMaskedLM", + "FlaxXLMRobertaForMultipleChoice", + "FlaxXLMRobertaForQuestionAnswering", + "FlaxXLMRobertaForSequenceClassification", + "FlaxXLMRobertaForTokenClassification", + "FlaxXLMRobertaModel", + "FlaxXLMRobertaForCausalLM", + "FlaxXLMRobertaPreTrainedModel", + ] + ) + +sys.modules[__name__] = _LazyModule( + 'transformers', + transformers.__file__, + _import_structure, + module_spec=__spec__, + extra_objects={"__version__": transformers.__version__}, +) diff --git a/mindnlp/transformers/models/__init__.py b/mindnlp/transformers/models/__init__.py deleted file mode 100644 index 079dc6b93..000000000 --- a/mindnlp/transformers/models/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from . import auto -from .auto import * diff --git a/mindnlp/transformers/models/auto.py b/mindnlp/transformers/models/auto.py deleted file mode 100644 index 500610755..000000000 --- a/mindnlp/transformers/models/auto.py +++ /dev/null @@ -1,16 +0,0 @@ - -from transformers.models.auto import modeling_auto -from transformers.models.auto import configuration_auto -from transformers.models.auto import feature_extraction_auto -from transformers.models.auto import image_processing_auto -from transformers.models.auto import processing_auto -from transformers.models.auto import tokenization_auto -from transformers.models.auto import auto_factory - -from transformers.models.auto.modeling_auto import * -from transformers.models.auto.configuration_auto import * -from transformers.models.auto.feature_extraction_auto import * -from transformers.models.auto.image_processing_auto import * -from transformers.models.auto.processing_auto import * -from transformers.models.auto.tokenization_auto import * -from transformers.models.auto.auto_factory import * diff --git a/mindnlp/transformers/models/auto_bk/__init__.py b/mindnlp/transformers/models/auto_bk/__init__.py deleted file mode 100644 index b481626c3..000000000 --- a/mindnlp/transformers/models/auto_bk/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -from transformers.models import auto -from transformers.models.auto import configuration_auto -from transformers.models.auto import feature_extraction_auto -from transformers.models.auto import image_processing_auto -from transformers.models.auto import processing_auto -from transformers.models.auto import tokenization_auto - -from transformers.models.auto.configuration_auto import * -from transformers.models.auto.feature_extraction_auto import * -from transformers.models.auto.image_processing_auto import * -from transformers.models.auto.processing_auto import * -from transformers.models.auto.tokenization_auto import * - -from . import modeling_auto - -from .auto_factory import * -from .modeling_auto import * - -__all__ = [] -__all__.extend(auto.__all__) -__all__.extend(modeling_auto.__all__) diff --git a/mindnlp/transformers/models/auto_bk/auto_factory.py b/mindnlp/transformers/models/auto_bk/auto_factory.py deleted file mode 100644 index 4ec822205..000000000 --- a/mindnlp/transformers/models/auto_bk/auto_factory.py +++ /dev/null @@ -1,772 +0,0 @@ -# coding=utf-8 -# Copyright 2021 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Factory function to build auto-model classes.""" - -import copy -import importlib -import json -import os -import warnings -from typing import Any, TypeVar, Union - -from transformers.configuration_utils import PretrainedConfig -from transformers.dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code -from transformers.utils import ( - CONFIG_NAME, - cached_file, - copy_func, - extract_commit_hash, - find_adapter_config_file, - is_peft_available, - is_torch_available, - logging, - requires_backends, -) -from transformers.models.auto.configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings - - -if is_torch_available(): - from transformers.generation import GenerationMixin - - -logger = logging.get_logger(__name__) - -_T = TypeVar("_T") -# Tokenizers will depend on packages installed, too much variance and there are no common base or Protocol -_LazyAutoMappingValue = tuple[Union[type[Any], None], Union[type[Any], None]] - -CLASS_DOCSTRING = """ - This is a generic model class that will be instantiated as one of the model classes of the library when created - with the [`~BaseAutoModelClass.from_pretrained`] class method or the [`~BaseAutoModelClass.from_config`] class - method. - - This class cannot be instantiated directly using `__init__()` (throws an error). -""" - -FROM_CONFIG_DOCSTRING = """ - Instantiates one of the model classes of the library from a configuration. - - Note: - Loading a model from its configuration file does **not** load the model weights. It only affects the - model's configuration. Use [`~BaseAutoModelClass.from_pretrained`] to load the model weights. - - Args: - config ([`PretrainedConfig`]): - The model class to instantiate is selected based on the configuration class: - - List options - attn_implementation (`str`, *optional*): - The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), or `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation. - - Examples: - - ```python - >>> from transformers import AutoConfig, BaseAutoModelClass - - >>> # Download configuration from huggingface.co and cache. - >>> config = AutoConfig.from_pretrained("checkpoint_placeholder") - >>> model = BaseAutoModelClass.from_config(config) - ``` -""" - -FROM_PRETRAINED_TORCH_DOCSTRING = """ - Instantiate one of the model classes of the library from a pretrained model. - - The model class to instantiate is selected based on the `model_type` property of the config object (either - passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by - falling back to using pattern matching on `pretrained_model_name_or_path`: - - List options - - The model is set in evaluation mode by default using `model.eval()` (so for instance, dropout modules are - deactivated). To train the model, you should first set it back in training mode with `model.train()` - - Args: - pretrained_model_name_or_path (`str` or `os.PathLike`): - Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - - A path to a *directory* containing model weights saved using - [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In - this case, `from_tf` should be set to `True` and a configuration object should be provided as - `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a - PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards. - model_args (additional positional arguments, *optional*): - Will be passed along to the underlying model `__init__()` method. - config ([`PretrainedConfig`], *optional*): - Configuration for the model to use instead of an automatically loaded configuration. Configuration can - be automatically loaded when: - - - The model is a model provided by the library (loaded with the *model id* string of a pretrained - model). - - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the - save directory. - - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a - configuration JSON file named *config.json* is found in the directory. - state_dict (*dict[str, torch.Tensor]*, *optional*): - A state dictionary to use instead of a state dictionary loaded from saved weights file. - - This option can be used if you want to create a model from a pretrained configuration but load your own - weights. In this case though, you should check if using [`~PreTrainedModel.save_pretrained`] and - [`~PreTrainedModel.from_pretrained`] is not a simpler option. - cache_dir (`str` or `os.PathLike`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - from_tf (`bool`, *optional*, defaults to `False`): - Load the model weights from a TensorFlow checkpoint save file (see docstring of - `pretrained_model_name_or_path` argument). - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download: - Deprecated and ignored. All downloads are now resumed by default when possible. - Will be removed in v5 of Transformers. - proxies (`dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - output_loading_info(`bool`, *optional*, defaults to `False`): - Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. - local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (e.g., not try downloading the model). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - trust_remote_code (`bool`, *optional*, defaults to `False`): - Whether or not to allow for custom models defined on the Hub in their own modeling files. This option - should only be set to `True` for repositories you trust and in which you have read the code, as it will - execute code present on the Hub on your local machine. - code_revision (`str`, *optional*, defaults to `"main"`): - The specific revision to use for the code on the Hub, if the code leaves in a different repository than - the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based - system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier - allowed by git. - kwargs (additional keyword arguments, *optional*): - Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., - `output_attentions=True`). Behaves differently depending on whether a `config` is provided or - automatically loaded: - - - If a configuration is provided with `config`, `**kwargs` will be directly passed to the - underlying model's `__init__` method (we assume all relevant updates to the configuration have - already been done) - - If a configuration is not provided, `kwargs` will be first passed to the configuration class - initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that - corresponds to a configuration attribute will be used to override said attribute with the - supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute - will be passed to the underlying model's `__init__` function. - - Examples: - - ```python - >>> from transformers import AutoConfig, BaseAutoModelClass - - >>> # Download model and configuration from huggingface.co and cache. - >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder") - - >>> # Update configuration during loading - >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True) - >>> model.config.output_attentions - True - - >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower) - >>> config = AutoConfig.from_pretrained("./tf_model/shortcut_placeholder_tf_model_config.json") - >>> model = BaseAutoModelClass.from_pretrained( - ... "./tf_model/shortcut_placeholder_tf_checkpoint.ckpt.index", from_tf=True, config=config - ... ) - ``` -""" - -FROM_PRETRAINED_TF_DOCSTRING = """ - Instantiate one of the model classes of the library from a pretrained model. - - The model class to instantiate is selected based on the `model_type` property of the config object (either - passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by - falling back to using pattern matching on `pretrained_model_name_or_path`: - - List options - - Args: - pretrained_model_name_or_path (`str` or `os.PathLike`): - Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - - A path to a *directory* containing model weights saved using - [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *PyTorch state_dict save file* (e.g, `./pt_model/pytorch_model.bin`). In this - case, `from_pt` should be set to `True` and a configuration object should be provided as `config` - argument. This loading path is slower than converting the PyTorch model in a TensorFlow model - using the provided conversion scripts and loading the TensorFlow model afterwards. - model_args (additional positional arguments, *optional*): - Will be passed along to the underlying model `__init__()` method. - config ([`PretrainedConfig`], *optional*): - Configuration for the model to use instead of an automatically loaded configuration. Configuration can - be automatically loaded when: - - - The model is a model provided by the library (loaded with the *model id* string of a pretrained - model). - - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the - save directory. - - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a - configuration JSON file named *config.json* is found in the directory. - cache_dir (`str` or `os.PathLike`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - from_pt (`bool`, *optional*, defaults to `False`): - Load the model weights from a PyTorch checkpoint save file (see docstring of - `pretrained_model_name_or_path` argument). - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download: - Deprecated and ignored. All downloads are now resumed by default when possible. - Will be removed in v5 of Transformers. - proxies (`dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - output_loading_info(`bool`, *optional*, defaults to `False`): - Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. - local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (e.g., not try downloading the model). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - trust_remote_code (`bool`, *optional*, defaults to `False`): - Whether or not to allow for custom models defined on the Hub in their own modeling files. This option - should only be set to `True` for repositories you trust and in which you have read the code, as it will - execute code present on the Hub on your local machine. - code_revision (`str`, *optional*, defaults to `"main"`): - The specific revision to use for the code on the Hub, if the code leaves in a different repository than - the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based - system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier - allowed by git. - kwargs (additional keyword arguments, *optional*): - Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., - `output_attentions=True`). Behaves differently depending on whether a `config` is provided or - automatically loaded: - - - If a configuration is provided with `config`, `**kwargs` will be directly passed to the - underlying model's `__init__` method (we assume all relevant updates to the configuration have - already been done) - - If a configuration is not provided, `kwargs` will be first passed to the configuration class - initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that - corresponds to a configuration attribute will be used to override said attribute with the - supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute - will be passed to the underlying model's `__init__` function. - - Examples: - - ```python - >>> from transformers import AutoConfig, BaseAutoModelClass - - >>> # Download model and configuration from huggingface.co and cache. - >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder") - - >>> # Update configuration during loading - >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True) - >>> model.config.output_attentions - True - - >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower) - >>> config = AutoConfig.from_pretrained("./pt_model/shortcut_placeholder_pt_model_config.json") - >>> model = BaseAutoModelClass.from_pretrained( - ... "./pt_model/shortcut_placeholder_pytorch_model.bin", from_pt=True, config=config - ... ) - ``` -""" - -FROM_PRETRAINED_FLAX_DOCSTRING = """ - Instantiate one of the model classes of the library from a pretrained model. - - The model class to instantiate is selected based on the `model_type` property of the config object (either - passed as an argument or loaded from `pretrained_model_name_or_path` if possible), or when it's missing, by - falling back to using pattern matching on `pretrained_model_name_or_path`: - - List options - - Args: - pretrained_model_name_or_path (`str` or `os.PathLike`): - Can be either: - - - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. - - A path to a *directory* containing model weights saved using - [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - - A path or url to a *PyTorch state_dict save file* (e.g, `./pt_model/pytorch_model.bin`). In this - case, `from_pt` should be set to `True` and a configuration object should be provided as `config` - argument. This loading path is slower than converting the PyTorch model in a TensorFlow model - using the provided conversion scripts and loading the TensorFlow model afterwards. - model_args (additional positional arguments, *optional*): - Will be passed along to the underlying model `__init__()` method. - config ([`PretrainedConfig`], *optional*): - Configuration for the model to use instead of an automatically loaded configuration. Configuration can - be automatically loaded when: - - - The model is a model provided by the library (loaded with the *model id* string of a pretrained - model). - - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the - save directory. - - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a - configuration JSON file named *config.json* is found in the directory. - cache_dir (`str` or `os.PathLike`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - from_pt (`bool`, *optional*, defaults to `False`): - Load the model weights from a PyTorch checkpoint save file (see docstring of - `pretrained_model_name_or_path` argument). - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force the (re-)download of the model weights and configuration files, overriding the - cached versions if they exist. - resume_download: - Deprecated and ignored. All downloads are now resumed by default when possible. - Will be removed in v5 of Transformers. - proxies (`dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. - output_loading_info(`bool`, *optional*, defaults to `False`): - Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages. - local_files_only(`bool`, *optional*, defaults to `False`): - Whether or not to only look at local files (e.g., not try downloading the model). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - trust_remote_code (`bool`, *optional*, defaults to `False`): - Whether or not to allow for custom models defined on the Hub in their own modeling files. This option - should only be set to `True` for repositories you trust and in which you have read the code, as it will - execute code present on the Hub on your local machine. - code_revision (`str`, *optional*, defaults to `"main"`): - The specific revision to use for the code on the Hub, if the code leaves in a different repository than - the rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based - system for storing models and other artifacts on huggingface.co, so `revision` can be any identifier - allowed by git. - kwargs (additional keyword arguments, *optional*): - Can be used to update the configuration object (after it being loaded) and initiate the model (e.g., - `output_attentions=True`). Behaves differently depending on whether a `config` is provided or - automatically loaded: - - - If a configuration is provided with `config`, `**kwargs` will be directly passed to the - underlying model's `__init__` method (we assume all relevant updates to the configuration have - already been done) - - If a configuration is not provided, `kwargs` will be first passed to the configuration class - initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that - corresponds to a configuration attribute will be used to override said attribute with the - supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute - will be passed to the underlying model's `__init__` function. - - Examples: - - ```python - >>> from transformers import AutoConfig, BaseAutoModelClass - - >>> # Download model and configuration from huggingface.co and cache. - >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder") - - >>> # Update configuration during loading - >>> model = BaseAutoModelClass.from_pretrained("checkpoint_placeholder", output_attentions=True) - >>> model.config.output_attentions - True - - >>> # Loading from a PyTorch checkpoint file instead of a TensorFlow model (slower) - >>> config = AutoConfig.from_pretrained("./pt_model/shortcut_placeholder_pt_model_config.json") - >>> model = BaseAutoModelClass.from_pretrained( - ... "./pt_model/shortcut_placeholder_pytorch_model.bin", from_pt=True, config=config - ... ) - ``` -""" - - -def _get_model_class(config, model_mapping): - supported_models = model_mapping[type(config)] - if not isinstance(supported_models, (list, tuple)): - return supported_models - - name_to_model = {model.__name__: model for model in supported_models} - architectures = getattr(config, "architectures", []) - for arch in architectures: - if arch in name_to_model: - return name_to_model[arch] - elif f"TF{arch}" in name_to_model: - return name_to_model[f"TF{arch}"] - elif f"Flax{arch}" in name_to_model: - return name_to_model[f"Flax{arch}"] - - # If not architecture is set in the config or match the supported models, the first element of the tuple is the - # defaults. - return supported_models[0] - - -class _BaseAutoModelClass: - # Base class for auto models. - _model_mapping = None - - def __init__(self, *args, **kwargs) -> None: - raise OSError( - f"{self.__class__.__name__} is designed to be instantiated " - f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or " - f"`{self.__class__.__name__}.from_config(config)` methods." - ) - - @classmethod - def from_config(cls, config, **kwargs): - trust_remote_code = kwargs.pop("trust_remote_code", None) - has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map - has_local_code = type(config) in cls._model_mapping.keys() - if has_remote_code: - class_ref = config.auto_map[cls.__name__] - if "--" in class_ref: - upstream_repo = class_ref.split("--")[0] - else: - upstream_repo = None - trust_remote_code = resolve_trust_remote_code( - trust_remote_code, config._name_or_path, has_local_code, has_remote_code, upstream_repo=upstream_repo - ) - - if has_remote_code and trust_remote_code: - if "--" in class_ref: - repo_id, class_ref = class_ref.split("--") - else: - repo_id = config.name_or_path - model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs) - # This block handles the case where the user is loading a model with `trust_remote_code=True` - # but a library model exists with the same name. We don't want to override the autoclass - # mappings in this case, or all future loads of that model will be the remote code model. - if not has_local_code: - cls.register(config.__class__, model_class, exist_ok=True) - model_class.register_for_auto_class(auto_class=cls) - _ = kwargs.pop("code_revision", None) - model_class = add_generation_mixin_to_remote_model(model_class) - return model_class._from_config(config, **kwargs) - elif type(config) in cls._model_mapping.keys(): - model_class = _get_model_class(config, cls._model_mapping) - return model_class._from_config(config, **kwargs) - - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}." - ) - - @classmethod - def _prepare_config_for_auto_class(cls, config: PretrainedConfig) -> PretrainedConfig: - """Additional autoclass-specific config post-loading manipulation. May be overridden in subclasses.""" - return config - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike[str]], *model_args, **kwargs): - config = kwargs.pop("config", None) - trust_remote_code = kwargs.get("trust_remote_code", None) - kwargs["_from_auto"] = True - hub_kwargs_names = [ - "cache_dir", - "force_download", - "local_files_only", - "proxies", - "resume_download", - "revision", - "subfolder", - "use_auth_token", - "token", - ] - hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs} - code_revision = kwargs.pop("code_revision", None) - commit_hash = kwargs.pop("_commit_hash", None) - adapter_kwargs = kwargs.pop("adapter_kwargs", None) - - token = hub_kwargs.pop("token", None) - use_auth_token = hub_kwargs.pop("use_auth_token", None) - if use_auth_token is not None: - warnings.warn( - "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.", - FutureWarning, - ) - if token is not None: - raise ValueError( - "`token` and `use_auth_token` are both specified. Please set only the argument `token`." - ) - token = use_auth_token - - if token is not None: - hub_kwargs["token"] = token - - if commit_hash is None: - if not isinstance(config, PretrainedConfig): - # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible - resolved_config_file = cached_file( - pretrained_model_name_or_path, - CONFIG_NAME, - _raise_exceptions_for_gated_repo=False, - _raise_exceptions_for_missing_entries=False, - _raise_exceptions_for_connection_errors=False, - **hub_kwargs, - ) - commit_hash = extract_commit_hash(resolved_config_file, commit_hash) - else: - commit_hash = getattr(config, "_commit_hash", None) - - if is_peft_available(): - if adapter_kwargs is None: - adapter_kwargs = {} - if token is not None: - adapter_kwargs["token"] = token - - maybe_adapter_path = find_adapter_config_file( - pretrained_model_name_or_path, _commit_hash=commit_hash, **adapter_kwargs - ) - - if maybe_adapter_path is not None: - with open(maybe_adapter_path, "r", encoding="utf-8") as f: - adapter_config = json.load(f) - - adapter_kwargs["_adapter_model_path"] = pretrained_model_name_or_path - pretrained_model_name_or_path = adapter_config["base_model_name_or_path"] - - if not isinstance(config, PretrainedConfig): - kwargs_orig = copy.deepcopy(kwargs) - # ensure not to pollute the config object with torch_dtype="auto" - since it's - # meaningless in the context of the config object - torch.dtype values are acceptable - if kwargs.get("torch_dtype", None) == "auto": - _ = kwargs.pop("torch_dtype") - # to not overwrite the quantization_config if config has a quantization_config - if kwargs.get("quantization_config", None) is not None: - _ = kwargs.pop("quantization_config") - - config, kwargs = AutoConfig.from_pretrained( - pretrained_model_name_or_path, - return_unused_kwargs=True, - code_revision=code_revision, - _commit_hash=commit_hash, - **hub_kwargs, - **kwargs, - ) - - # if torch_dtype=auto was passed here, ensure to pass it on - if kwargs_orig.get("torch_dtype", None) == "auto": - kwargs["torch_dtype"] = "auto" - if kwargs_orig.get("quantization_config", None) is not None: - kwargs["quantization_config"] = kwargs_orig["quantization_config"] - - has_remote_code = hasattr(config, "auto_map") and cls.__name__ in config.auto_map - has_local_code = type(config) in cls._model_mapping.keys() - upstream_repo = None - if has_remote_code: - class_ref = config.auto_map[cls.__name__] - if "--" in class_ref: - upstream_repo = class_ref.split("--")[0] - trust_remote_code = resolve_trust_remote_code( - trust_remote_code, - pretrained_model_name_or_path, - has_local_code, - has_remote_code, - ) - kwargs["trust_remote_code"] = trust_remote_code - - # Set the adapter kwargs - kwargs["adapter_kwargs"] = adapter_kwargs - - if has_remote_code and trust_remote_code: - model_class = get_class_from_dynamic_module( - class_ref, pretrained_model_name_or_path, code_revision=code_revision, **hub_kwargs, **kwargs - ) - _ = hub_kwargs.pop("code_revision", None) - # This block handles the case where the user is loading a model with `trust_remote_code=True` - # but a library model exists with the same name. We don't want to override the autoclass - # mappings in this case, or all future loads of that model will be the remote code model. - if not has_local_code: - cls.register(config.__class__, model_class, exist_ok=True) - model_class.register_for_auto_class(auto_class=cls) - model_class = add_generation_mixin_to_remote_model(model_class) - return model_class.from_pretrained( - pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs - ) - elif type(config) in cls._model_mapping.keys(): - model_class = _get_model_class(config, cls._model_mapping) - if model_class.config_class == config.sub_configs.get("text_config", None): - config = config.get_text_config() - return model_class.from_pretrained( - pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs - ) - raise ValueError( - f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" - f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}." - ) - - @classmethod - def register(cls, config_class, model_class, exist_ok=False) -> None: - """ - Register a new model for this class. - - Args: - config_class ([`PretrainedConfig`]): - The configuration corresponding to the model to register. - model_class ([`PreTrainedModel`]): - The model to register. - """ - if hasattr(model_class, "config_class") and model_class.config_class.__name__ != config_class.__name__: - raise ValueError( - "The model class you are passing has a `config_class` attribute that is not consistent with the " - f"config class you passed (model has {model_class.config_class} and you passed {config_class}. Fix " - "one of those so they match!" - ) - cls._model_mapping.register(config_class, model_class, exist_ok=exist_ok) - - -class _BaseAutoBackboneClass(_BaseAutoModelClass): - # Base class for auto backbone models. - _model_mapping = None - - @classmethod - def _load_timm_backbone_from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - requires_backends(cls, ["vision", "timm"]) - from ...models.timm_backbone import TimmBackboneConfig - - config = kwargs.pop("config", TimmBackboneConfig()) - - if kwargs.get("out_features", None) is not None: - raise ValueError("Cannot specify `out_features` for timm backbones") - - if kwargs.get("output_loading_info", False): - raise ValueError("Cannot specify `output_loading_info=True` when loading from timm") - - num_channels = kwargs.pop("num_channels", config.num_channels) - features_only = kwargs.pop("features_only", config.features_only) - use_pretrained_backbone = kwargs.pop("use_pretrained_backbone", config.use_pretrained_backbone) - out_indices = kwargs.pop("out_indices", config.out_indices) - config = TimmBackboneConfig( - backbone=pretrained_model_name_or_path, - num_channels=num_channels, - features_only=features_only, - use_pretrained_backbone=use_pretrained_backbone, - out_indices=out_indices, - ) - return super().from_config(config, **kwargs) - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - use_timm_backbone = kwargs.pop("use_timm_backbone", False) - if use_timm_backbone: - return cls._load_timm_backbone_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - - return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - - -def insert_head_doc(docstring, head_doc: str = ""): - if len(head_doc) > 0: - return docstring.replace( - "one of the model classes of the library ", - f"one of the model classes of the library (with a {head_doc} head) ", - ) - return docstring.replace( - "one of the model classes of the library ", "one of the base model classes of the library " - ) - - -def auto_class_update(cls, checkpoint_for_example: str = "google-bert/bert-base-cased", head_doc: str = ""): - # Create a new class with the right name from the base class - model_mapping = cls._model_mapping - name = cls.__name__ - class_docstring = insert_head_doc(CLASS_DOCSTRING, head_doc=head_doc) - cls.__doc__ = class_docstring.replace("BaseAutoModelClass", name) - - # Now we need to copy and re-register `from_config` and `from_pretrained` as class methods otherwise we can't - # have a specific docstrings for them. - from_config = copy_func(_BaseAutoModelClass.from_config) - from_config_docstring = insert_head_doc(FROM_CONFIG_DOCSTRING, head_doc=head_doc) - from_config_docstring = from_config_docstring.replace("BaseAutoModelClass", name) - from_config_docstring = from_config_docstring.replace("checkpoint_placeholder", checkpoint_for_example) - from_config.__doc__ = from_config_docstring - from_config = replace_list_option_in_docstrings(model_mapping._model_mapping, use_model_types=False)(from_config) - cls.from_config = classmethod(from_config) - - if name.startswith("TF"): - from_pretrained_docstring = FROM_PRETRAINED_TF_DOCSTRING - elif name.startswith("Flax"): - from_pretrained_docstring = FROM_PRETRAINED_FLAX_DOCSTRING - else: - from_pretrained_docstring = FROM_PRETRAINED_TORCH_DOCSTRING - from_pretrained = copy_func(_BaseAutoModelClass.from_pretrained) - from_pretrained_docstring = insert_head_doc(from_pretrained_docstring, head_doc=head_doc) - from_pretrained_docstring = from_pretrained_docstring.replace("BaseAutoModelClass", name) - from_pretrained_docstring = from_pretrained_docstring.replace("checkpoint_placeholder", checkpoint_for_example) - shortcut = checkpoint_for_example.split("/")[-1].split("-")[0] - from_pretrained_docstring = from_pretrained_docstring.replace("shortcut_placeholder", shortcut) - from_pretrained.__doc__ = from_pretrained_docstring - from_pretrained = replace_list_option_in_docstrings(model_mapping._model_mapping)(from_pretrained) - cls.from_pretrained = classmethod(from_pretrained) - return cls - - -def get_values(model_mapping): - result = [] - for model in model_mapping.values(): - if isinstance(model, (list, tuple)): - result += list(model) - else: - result.append(model) - - return result - - -def getattribute_from_module(module, attr): - if attr is None: - return None - if isinstance(attr, tuple): - return tuple(getattribute_from_module(module, a) for a in attr) - if hasattr(module, attr): - return getattr(module, attr) - # Some of the mappings have entries model_type -> object of another model type. In that case we try to grab the - # object at the top level. - transformers_module = importlib.import_module("transformers") - - if module != transformers_module: - try: - return getattribute_from_module(transformers_module, attr) - except ValueError: - raise ValueError(f"Could not find {attr} neither in {module} nor in {transformers_module}!") - else: - raise ValueError(f"Could not find {attr} in {transformers_module}!") - - -def add_generation_mixin_to_remote_model(model_class): - """ - Adds `GenerationMixin` to the inheritance of `model_class`, if `model_class` is a PyTorch model. - - This function is used for backwards compatibility purposes: in v4.45, we've started a deprecation cycle to make - `PreTrainedModel` stop inheriting from `GenerationMixin`. Without this function, older models dynamically loaded - from the Hub may not have the `generate` method after we remove the inheritance. - """ - # 1. If it is not a PT model (i.e. doesn't inherit Module), do nothing - if "torch.nn.modules.module.Module" not in str(model_class.__mro__): - return model_class - - # 2. If it already **directly** inherits from GenerationMixin, do nothing - if "GenerationMixin" in str(model_class.__bases__): - return model_class - - # 3. Prior to v4.45, we could detect whether a model was `generate`-compatible if it had its own `generate` and/or - # `prepare_inputs_for_generation` method. - has_custom_generate_in_class = hasattr(model_class, "generate") and "GenerationMixin" not in str( - getattr(model_class, "generate") - ) - has_custom_prepare_inputs = hasattr(model_class, "prepare_inputs_for_generation") and "GenerationMixin" not in str( - getattr(model_class, "prepare_inputs_for_generation") - ) - if has_custom_generate_in_class or has_custom_prepare_inputs: - model_class_with_generation_mixin = type( - model_class.__name__, (model_class, GenerationMixin), {**model_class.__dict__} - ) - return model_class_with_generation_mixin - return model_class - -__all__ = ["get_values"] \ No newline at end of file diff --git a/mindnlp/transformers/models/auto_bk/modeling_auto.py b/mindnlp/transformers/models/auto_bk/modeling_auto.py deleted file mode 100644 index 2dce59859..000000000 --- a/mindnlp/transformers/models/auto_bk/modeling_auto.py +++ /dev/null @@ -1,448 +0,0 @@ -import warnings - -from transformers.models.auto.modeling_auto import ( - MODEL_FOR_MASK_GENERATION_MAPPING, - MODEL_FOR_KEYPOINT_DETECTION_MAPPING, - MODEL_FOR_TEXT_ENCODING_MAPPING, - MODEL_FOR_IMAGE_TO_IMAGE_MAPPING, - MODEL_MAPPING, - MODEL_FOR_PRETRAINING_MAPPING, - MODEL_WITH_LM_HEAD_MAPPING, - MODEL_FOR_CAUSAL_LM_MAPPING, - MODEL_FOR_MASKED_LM_MAPPING, - MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, - MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, - MODEL_FOR_QUESTION_ANSWERING_MAPPING, - MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING, - MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING, - MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING, - MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, - MODEL_FOR_MULTIPLE_CHOICE_MAPPING, - MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, - MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING, - MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING, - MODEL_FOR_IMAGE_SEGMENTATION_MAPPING, - MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING, - MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING, - MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING, - MODEL_FOR_OBJECT_DETECTION_MAPPING, - MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING, - MODEL_FOR_DEPTH_ESTIMATION_MAPPING, - MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING, - MODEL_FOR_VISION_2_SEQ_MAPPING, - MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING, - MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING, - MODEL_FOR_CTC_MAPPING, - MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING, - MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING, - MODEL_FOR_AUDIO_XVECTOR_MAPPING, - MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING, - MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING, - MODEL_FOR_BACKBONE_MAPPING, - MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING, - MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING, - MODEL_FOR_IMAGE_MAPPING, - MODEL_FOR_RETRIEVAL_MAPPING, - MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING, - MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING, -) -from .auto_factory import ( - _BaseAutoBackboneClass, - _BaseAutoModelClass, - auto_class_update, -) - -class AutoModelForMaskGeneration(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_MASK_GENERATION_MAPPING - - -class AutoModelForKeypointDetection(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_KEYPOINT_DETECTION_MAPPING - - -class AutoModelForTextEncoding(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_TEXT_ENCODING_MAPPING - - -class AutoModelForImageToImage(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_IMAGE_TO_IMAGE_MAPPING - - -class AutoModel(_BaseAutoModelClass): - _model_mapping = MODEL_MAPPING - - -AutoModel = auto_class_update(AutoModel) - - -class AutoModelForPreTraining(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_PRETRAINING_MAPPING - - -AutoModelForPreTraining = auto_class_update(AutoModelForPreTraining, head_doc="pretraining") - - -# Private on purpose, the public class will add the deprecation warnings. -class _AutoModelWithLMHead(_BaseAutoModelClass): - _model_mapping = MODEL_WITH_LM_HEAD_MAPPING - - -_AutoModelWithLMHead = auto_class_update(_AutoModelWithLMHead, head_doc="language modeling") - - -class AutoModelForCausalLM(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_CAUSAL_LM_MAPPING - - -AutoModelForCausalLM = auto_class_update(AutoModelForCausalLM, head_doc="causal language modeling") - - -class AutoModelForMaskedLM(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_MASKED_LM_MAPPING - - -AutoModelForMaskedLM = auto_class_update(AutoModelForMaskedLM, head_doc="masked language modeling") - - -class AutoModelForSeq2SeqLM(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING - - -AutoModelForSeq2SeqLM = auto_class_update( - AutoModelForSeq2SeqLM, - head_doc="sequence-to-sequence language modeling", - checkpoint_for_example="google-t5/t5-base", -) - - -class AutoModelForSequenceClassification(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING - - -AutoModelForSequenceClassification = auto_class_update( - AutoModelForSequenceClassification, head_doc="sequence classification" -) - - -class AutoModelForQuestionAnswering(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING - - -AutoModelForQuestionAnswering = auto_class_update(AutoModelForQuestionAnswering, head_doc="question answering") - - -class AutoModelForTableQuestionAnswering(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING - - -AutoModelForTableQuestionAnswering = auto_class_update( - AutoModelForTableQuestionAnswering, - head_doc="table question answering", - checkpoint_for_example="google/tapas-base-finetuned-wtq", -) - - -class AutoModelForVisualQuestionAnswering(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING - - -AutoModelForVisualQuestionAnswering = auto_class_update( - AutoModelForVisualQuestionAnswering, - head_doc="visual question answering", - checkpoint_for_example="dandelin/vilt-b32-finetuned-vqa", -) - - -class AutoModelForDocumentQuestionAnswering(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING - - -AutoModelForDocumentQuestionAnswering = auto_class_update( - AutoModelForDocumentQuestionAnswering, - head_doc="document question answering", - checkpoint_for_example='impira/layoutlm-document-qa", revision="52e01b3', -) - - -class AutoModelForTokenClassification(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING - - -AutoModelForTokenClassification = auto_class_update(AutoModelForTokenClassification, head_doc="token classification") - - -class AutoModelForMultipleChoice(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_MULTIPLE_CHOICE_MAPPING - - -AutoModelForMultipleChoice = auto_class_update(AutoModelForMultipleChoice, head_doc="multiple choice") - - -class AutoModelForNextSentencePrediction(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING - - -AutoModelForNextSentencePrediction = auto_class_update( - AutoModelForNextSentencePrediction, head_doc="next sentence prediction" -) - - -class AutoModelForImageClassification(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING - - -AutoModelForImageClassification = auto_class_update(AutoModelForImageClassification, head_doc="image classification") - - -class AutoModelForZeroShotImageClassification(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING - - -AutoModelForZeroShotImageClassification = auto_class_update( - AutoModelForZeroShotImageClassification, head_doc="zero-shot image classification" -) - - -class AutoModelForImageSegmentation(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_IMAGE_SEGMENTATION_MAPPING - - -AutoModelForImageSegmentation = auto_class_update(AutoModelForImageSegmentation, head_doc="image segmentation") - - -class AutoModelForSemanticSegmentation(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING - - -AutoModelForSemanticSegmentation = auto_class_update( - AutoModelForSemanticSegmentation, head_doc="semantic segmentation" -) - - -class AutoModelForUniversalSegmentation(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING - - -AutoModelForUniversalSegmentation = auto_class_update( - AutoModelForUniversalSegmentation, head_doc="universal image segmentation" -) - - -class AutoModelForInstanceSegmentation(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING - - -AutoModelForInstanceSegmentation = auto_class_update( - AutoModelForInstanceSegmentation, head_doc="instance segmentation" -) - - -class AutoModelForObjectDetection(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING - - -AutoModelForObjectDetection = auto_class_update(AutoModelForObjectDetection, head_doc="object detection") - - -class AutoModelForZeroShotObjectDetection(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING - - -AutoModelForZeroShotObjectDetection = auto_class_update( - AutoModelForZeroShotObjectDetection, head_doc="zero-shot object detection" -) - - -class AutoModelForDepthEstimation(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_DEPTH_ESTIMATION_MAPPING - - -AutoModelForDepthEstimation = auto_class_update(AutoModelForDepthEstimation, head_doc="depth estimation") - - -class AutoModelForVideoClassification(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING - - -AutoModelForVideoClassification = auto_class_update(AutoModelForVideoClassification, head_doc="video classification") - - -class AutoModelForVision2Seq(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_VISION_2_SEQ_MAPPING - - -AutoModelForVision2Seq = auto_class_update(AutoModelForVision2Seq, head_doc="vision-to-text modeling") - - -class AutoModelForImageTextToText(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING - - -AutoModelForImageTextToText = auto_class_update(AutoModelForImageTextToText, head_doc="image-text-to-text modeling") - - -class AutoModelForAudioClassification(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING - - -AutoModelForAudioClassification = auto_class_update(AutoModelForAudioClassification, head_doc="audio classification") - - -class AutoModelForCTC(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_CTC_MAPPING - - -AutoModelForCTC = auto_class_update(AutoModelForCTC, head_doc="connectionist temporal classification") - - -class AutoModelForSpeechSeq2Seq(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING - - -AutoModelForSpeechSeq2Seq = auto_class_update( - AutoModelForSpeechSeq2Seq, head_doc="sequence-to-sequence speech-to-text modeling" -) - - -class AutoModelForAudioFrameClassification(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING - - -AutoModelForAudioFrameClassification = auto_class_update( - AutoModelForAudioFrameClassification, head_doc="audio frame (token) classification" -) - - -class AutoModelForAudioXVector(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_AUDIO_XVECTOR_MAPPING - - -class AutoModelForTextToSpectrogram(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING - - -class AutoModelForTextToWaveform(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING - - -class AutoBackbone(_BaseAutoBackboneClass): - _model_mapping = MODEL_FOR_BACKBONE_MAPPING - - -AutoModelForAudioXVector = auto_class_update(AutoModelForAudioXVector, head_doc="audio retrieval via x-vector") - - -class AutoModelForMaskedImageModeling(_BaseAutoModelClass): - _model_mapping = MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING - - -AutoModelForMaskedImageModeling = auto_class_update(AutoModelForMaskedImageModeling, head_doc="masked image modeling") - - -class AutoModelWithLMHead(_AutoModelWithLMHead): - @classmethod - def from_config(cls, config): - warnings.warn( - "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use " - "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and " - "`AutoModelForSeq2SeqLM` for encoder-decoder models.", - FutureWarning, - ) - return super().from_config(config) - - @classmethod - def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): - warnings.warn( - "The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use " - "`AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and " - "`AutoModelForSeq2SeqLM` for encoder-decoder models.", - FutureWarning, - ) - return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) - - -__all__ = [ - "MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING", - "MODEL_FOR_AUDIO_FRAME_CLASSIFICATION_MAPPING", - "MODEL_FOR_AUDIO_XVECTOR_MAPPING", - "MODEL_FOR_BACKBONE_MAPPING", - "MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING", - "MODEL_FOR_CAUSAL_LM_MAPPING", - "MODEL_FOR_CTC_MAPPING", - "MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING", - "MODEL_FOR_DEPTH_ESTIMATION_MAPPING", - "MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING", - "MODEL_FOR_IMAGE_MAPPING", - "MODEL_FOR_IMAGE_SEGMENTATION_MAPPING", - "MODEL_FOR_IMAGE_TO_IMAGE_MAPPING", - "MODEL_FOR_KEYPOINT_DETECTION_MAPPING", - "MODEL_FOR_INSTANCE_SEGMENTATION_MAPPING", - "MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING", - "MODEL_FOR_MASKED_LM_MAPPING", - "MODEL_FOR_MASK_GENERATION_MAPPING", - "MODEL_FOR_MULTIPLE_CHOICE_MAPPING", - "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", - "MODEL_FOR_OBJECT_DETECTION_MAPPING", - "MODEL_FOR_PRETRAINING_MAPPING", - "MODEL_FOR_QUESTION_ANSWERING_MAPPING", - "MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING", - "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", - "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", - "MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING", - "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING", - "MODEL_FOR_TEXT_ENCODING_MAPPING", - "MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING", - "MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING", - "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", - "MODEL_FOR_UNIVERSAL_SEGMENTATION_MAPPING", - "MODEL_FOR_VIDEO_CLASSIFICATION_MAPPING", - "MODEL_FOR_VISION_2_SEQ_MAPPING", - "MODEL_FOR_RETRIEVAL_MAPPING", - "MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING", - "MODEL_FOR_VISUAL_QUESTION_ANSWERING_MAPPING", - "MODEL_MAPPING", - "MODEL_WITH_LM_HEAD_MAPPING", - "MODEL_FOR_ZERO_SHOT_IMAGE_CLASSIFICATION_MAPPING", - "MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING", - "MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING", - "MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING", - "AutoModel", - "AutoBackbone", - "AutoModelForAudioClassification", - "AutoModelForAudioFrameClassification", - "AutoModelForAudioXVector", - "AutoModelForCausalLM", - "AutoModelForCTC", - "AutoModelForDepthEstimation", - "AutoModelForImageClassification", - "AutoModelForImageSegmentation", - "AutoModelForImageToImage", - "AutoModelForInstanceSegmentation", - "AutoModelForKeypointDetection", - "AutoModelForMaskGeneration", - "AutoModelForTextEncoding", - "AutoModelForMaskedImageModeling", - "AutoModelForMaskedLM", - "AutoModelForMultipleChoice", - "AutoModelForNextSentencePrediction", - "AutoModelForObjectDetection", - "AutoModelForPreTraining", - "AutoModelForQuestionAnswering", - "AutoModelForSemanticSegmentation", - "AutoModelForSeq2SeqLM", - "AutoModelForSequenceClassification", - "AutoModelForSpeechSeq2Seq", - "AutoModelForTableQuestionAnswering", - "AutoModelForTextToSpectrogram", - "AutoModelForTextToWaveform", - "AutoModelForTokenClassification", - "AutoModelForUniversalSegmentation", - "AutoModelForVideoClassification", - "AutoModelForVision2Seq", - "AutoModelForVisualQuestionAnswering", - "AutoModelForDocumentQuestionAnswering", - "AutoModelWithLMHead", - "AutoModelForZeroShotImageClassification", - "AutoModelForZeroShotObjectDetection", - "AutoModelForImageTextToText", -] \ No newline at end of file diff --git a/mindnlp/transformers/pipelines.py b/mindnlp/transformers/pipelines.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/mindnlp/utils/testing_utils.py b/mindnlp/utils/testing_utils.py new file mode 100644 index 000000000..d819c9a1a --- /dev/null +++ b/mindnlp/utils/testing_utils.py @@ -0,0 +1,2102 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Utils for test cases.""" +import collections +import contextlib +import doctest +import functools +import inspect +import logging +import multiprocessing +import os +import re +import shlex +import shutil +import subprocess +import sys +import tempfile +import time +import unittest +import asyncio +from collections.abc import Mapping +from collections import defaultdict + +from io import StringIO +from pathlib import Path +from typing import Callable, Dict, Iterable, Iterator, List, Optional, Union +from unittest import mock +from unittest.mock import patch + +import urllib3 +import numpy as np + +import mindspore +from mindnlp.core.configs import SUPPORT_BF16 + +from transformers.utils.import_utils import ( + is_pytest_available, + is_mindspore_available, + is_essentia_available, + is_librosa_available, + is_pretty_midi_available, + is_scipy_available, + is_pyctcdecode_available, + is_safetensors_available, + is_sentencepiece_available, + is_soundfile_availble, + is_tokenizers_available, + is_pytesseract_available, + is_vision_available, + is_g2p_en_available, + is_levenshtein_available, + is_nltk_available, + is_ftfy_available +) +from transformers.utils.generic import strtobool + +if is_pytest_available(): + from _pytest.doctest import ( + Module, + _get_checker, + _get_continue_on_failure, + _get_runner, + _is_mocked, + _patch_unwrap_mock_aware, + get_optionflags, + import_path, + ) + from _pytest.config import create_terminal_writer + from _pytest.outcomes import skip + from pytest import DoctestItem +else: + Module = object + DoctestItem = object + +if is_mindspore_available(): + from mindspore import ops + + +DUMMY_UNKNOWN_IDENTIFIER = "julien-c/dummy-unknown" +SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy" + +def is_pipeline_test(test_case): + """ + Decorator marking a test as a pipeline test. If RUN_PIPELINE_TESTS is set to a falsy value, those tests will be + skipped. + """ + if not _run_pipeline_tests: + return unittest.skip("test is pipeline test")(test_case) + else: + try: + import pytest # We don't need a hard dependency on pytest in the main library + except ImportError: + return test_case + else: + return pytest.mark.is_pipeline_test()(test_case) + +def parse_flag_from_env(key, default=False): + """ + Parses a flag value from the environment variable. + + Args: + key (str): The name of the environment variable to retrieve the flag value from. + default (bool, optional): The default flag value to return if the environment variable is not set. Defaults to False. + + Returns: + bool: The parsed flag value. Returns the default value if the environment variable is not set or if its value cannot be parsed. + + Raises: + ValueError: If the environment variable value is set but cannot be parsed as a boolean ('yes' or 'no'). + + Note: + The flag value is retrieved from the environment variable specified by `key`. If the environment variable is not set, the default value is returned. If the environment variable value is set, it is +parsed as a boolean using the `strtobool` function from the `distutils.util` module. If the parsing fails, a `ValueError` is raised with a descriptive error message indicating that the value must be either +'yes' or 'no'. + """ + try: + value = os.environ[key] + except KeyError: + # KEY isn't set, default to `default`. + _value = default + else: + # KEY is set, convert it to True or False. + try: + _value = strtobool(value) + except ValueError as exc: + # More values are supported, but let's keep the message simple. + raise ValueError(f"If set, {key} must be yes or no.") from exc + return _value + +_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) +_run_too_slow_tests = parse_flag_from_env("RUN_TOO_SLOW", default=False) +_run_pipeline_tests = parse_flag_from_env("RUN_PIPELINE_TESTS", default=True) + +def slow(test_case): + """ + Decorator marking a test as slow. + + Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them. + + """ + return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) + +def tooslow(test_case): + """ + Decorator marking a test as too slow. + + Slow tests are skipped while they're in the process of being fixed. No test should stay tagged as "tooslow" as + these will not be tested by the CI. + + """ + return unittest.skipUnless(_run_too_slow_tests, "test is too slow")(test_case) + +def parse_int_from_env(key, default=None): + """Parses an integer value from the specified environment variable. + + Args: + key (str): The name of the environment variable to retrieve the integer value from. + default (int, optional): The default integer value to return if the environment variable is not set or cannot be converted to an integer. Defaults to None. + + Returns: + int or None: The integer value parsed from the environment variable or the default value if provided. Returns None if the environment variable is not set and no default value is specified. + + Raises: + ValueError: If the value retrieved from the environment variable cannot be converted to an integer. + """ + try: + value = os.environ[key] + except KeyError: + _value = default + else: + try: + _value = int(value) + except ValueError as exc: + raise ValueError(f"If set, {key} must be a int.") from exc + return _value + + +def require_ftfy(test_case): + """ + Decorator marking a test that requires ftfy. These tests are skipped when ftfy isn't installed. + """ + return unittest.skipUnless(is_ftfy_available(), "test requires ftfy")(test_case) + + +def require_levenshtein(test_case): + """ + Decorator marking a test that requires Levenshtein. + + These tests are skipped when Levenshtein isn't installed. + + """ + return unittest.skipUnless(is_levenshtein_available(), "test requires Levenshtein")(test_case) + + +def require_nltk(test_case): + """ + Decorator marking a test that requires NLTK. + + These tests are skipped when NLTK isn't installed. + + """ + return unittest.skipUnless(is_nltk_available(), "test requires NLTK")(test_case) + + +def require_vision(test_case): + """ + Decorator marking a test that requires the vision dependencies. These tests are skipped when torchaudio isn't + installed. + """ + return unittest.skipUnless(is_vision_available(), "test requires vision")(test_case) + +def require_tokenizers(test_case): + """ + Decorator marking a test that requires 🤗 Tokenizers. These tests are skipped when 🤗 Tokenizers isn't installed. + """ + return unittest.skipUnless(is_tokenizers_available(), "test requires tokenizers")(test_case) + +def require_sentencepiece(test_case): + """ + Decorator marking a test that requires SentencePiece. These tests are skipped when SentencePiece isn't installed. + """ + return unittest.skipUnless(is_sentencepiece_available(), "test requires SentencePiece")(test_case) + +def require_mindspore(test_case): + """ + Decorator marking a test that requires MindSpore. + + These tests are skipped when MindSpore isn't installed. + + """ + return unittest.skipUnless(is_mindspore_available(), "test requires MindSpore")(test_case) + +def require_bfloat16(test_case): + """require_bfloat16""" + return unittest.skipUnless(SUPPORT_BF16, "test need bfloat16")(test_case) + +def require_mindspore_gpu(test_case): + """Decorator marking a test that requires CUDA and MindSpore.""" + return unittest.skipUnless(mindspore.get_context('device_target') == "GPU", "test requires CUDA")(test_case) + +def require_mindspore_npu(test_case): + """Decorator marking a test that requires CANN and MindSpore.""" + return unittest.skipUnless(mindspore.get_context('device_target') == "Ascend", "test requires CANN")(test_case) + + +def require_librosa(test_case): + """ + Decorator marking a test that requires librosa + """ + return unittest.skipUnless(is_librosa_available(), "test requires librosa")(test_case) + +def require_essentia(test_case): + """ + Decorator marking a test that requires essentia + """ + return unittest.skipUnless(is_essentia_available(), "test requires essentia")(test_case) + +def require_pretty_midi(test_case): + """ + Decorator marking a test that requires pretty_midi + """ + return unittest.skipUnless(is_pretty_midi_available(), "test requires pretty_midi")(test_case) + +def require_scipy(test_case): + """ + Decorator marking a test that requires Scipy. These tests are skipped when SentencePiece isn't installed. + """ + return unittest.skipUnless(is_scipy_available(), "test requires Scipy")(test_case) + +def require_pyctcdecode(test_case): + """ + Decorator marking a test that requires pyctcdecode + """ + return unittest.skipUnless(is_pyctcdecode_available(), "test requires pyctcdecode")(test_case) + +def require_safetensors(test_case): + """ + Decorator marking a test that requires safetensors. These tests are skipped when safetensors isn't installed. + """ + return unittest.skipUnless(is_safetensors_available(), "test requires safetensors")(test_case) + +def require_pytesseract(test_case): + """ + Decorator marking a test that requires pytesseract + """ + return unittest.skipUnless(is_pytesseract_available(), "test requires pytesseract")(test_case) + +def require_g2p_en(test_case): + """ + Decorator marking a test that requires pytesseract + """ + return unittest.skipUnless(is_g2p_en_available(), "test requires g2p-en")(test_case) + + +def cmd_exists(cmd): + """ + Check if a command exists in the system PATH. + + Args: + cmd (str): The name of the command to check for existence in the system PATH. + + Returns: + None: Returns None if the command exists in the system PATH, otherwise returns False. + + Raises: + None. + """ + return shutil.which(cmd) is not None +# +# Helper functions for dealing with testing text outputs +# The original code came from: +# https://github.com/fastai/fastai/blob/master/tests/utils/text.py + + +# When any function contains print() calls that get overwritten, like progress bars, +# a special care needs to be applied, since under pytest -s captured output (capsys +# or contextlib.redirect_stdout) contains any temporary printed strings, followed by +# \r's. This helper function ensures that the buffer will contain the same output +# with and without -s in pytest, by turning: +# foo bar\r tar mar\r final message +# into: +# final message +# it can handle a single string or a multiline buffer +def apply_print_resets(buf): + """ + Apply print resets by removing any characters before the last carriage return in the given buffer. + + Args: + buf (str): The input buffer containing text data. + + Returns: + None. The function modifies the buffer in place. + + Raises: + None. + """ + return re.sub(r"^.*\r", "", buf, 0, re.M) + + +def assert_screenout(out, what): + """ + This function asserts the presence of a specified string within the provided output. + + Args: + out (str): The output string to be checked for the presence of the specified string. + what (str): The string to be searched for within the output. + + Returns: + None: This function does not return any value. + + Raises: + AssertionError: If the specified string 'what' is not found within the output string 'out'. + """ + out_pr = apply_print_resets(out).lower() + match_str = out_pr.find(what.lower()) + assert match_str != -1, f"expecting to find {what} in output: f{out_pr}" + + +class CaptureStd: + """ + Context manager to capture: + + - stdout: replay it, clean it up and make it available via `obj.out` + - stderr: replay it and make it available via `obj.err` + + Args: + out (`bool`, *optional*, defaults to `True`): Whether to capture stdout or not. + err (`bool`, *optional*, defaults to `True`): Whether to capture stderr or not. + replay (`bool`, *optional*, defaults to `True`): Whether to replay or not. + By default each captured stream gets replayed back on context's exit, so that one can see what the test was + doing. If this is a not wanted behavior and the captured data shouldn't be replayed, pass `replay=False` to + disable this feature. + + Examples: + + ```python + # to capture stdout only with auto-replay + with CaptureStdout() as cs: + print("Secret message") + assert "message" in cs.out + + # to capture stderr only with auto-replay + import sys + + with CaptureStderr() as cs: + print("Warning: ", file=sys.stderr) + assert "Warning" in cs.err + + # to capture both streams with auto-replay + with CaptureStd() as cs: + print("Secret message") + print("Warning: ", file=sys.stderr) + assert "message" in cs.out + assert "Warning" in cs.err + + # to capture just one of the streams, and not the other, with auto-replay + with CaptureStd(err=False) as cs: + print("Secret message") + assert "message" in cs.out + # but best use the stream-specific subclasses + + # to capture without auto-replay + with CaptureStd(replay=False) as cs: + print("Secret message") + assert "message" in cs.out + ```""" + def __init__(self, out=True, err=True, replay=True): + """Initialize a CaptureStd object. + + Args: + self (CaptureStd): The instance of the CaptureStd class. + out (bool): Flag indicating whether to capture stdout. Default is True. + err (bool): Flag indicating whether to capture stderr. Default is True. + replay (bool): Flag indicating whether to replay captured output. Default is True. + + Returns: + None + + Raises: + None + + This method initializes a CaptureStd object with the given parameters. The 'out' parameter determines whether to capture stdout, while the 'err' parameter determines whether to capture stderr. By +default, both 'out' and 'err' are set to True. If 'out' is True, a StringIO object is created to capture stdout. If 'out' is False, stdout is not captured and the 'out' attribute is set to 'not capturing +stdout'. The same logic applies to 'err' and stderr. + + The 'replay' parameter determines whether the captured output should be replayed. By default, 'replay' is set to True. + + Note: If 'out' or 'err' is set to True, but the CaptureStd context is not finished yet (i.e., __exit__ is not called), an error message is set to the corresponding attribute indicating that the context +was called too early. + """ + self.replay = replay + + if out: + self.out_buf = StringIO() + self.out = "error: CaptureStd context is unfinished yet, called too early" + else: + self.out_buf = None + self.out = "not capturing stdout" + + if err: + self.err_buf = StringIO() + self.err = "error: CaptureStd context is unfinished yet, called too early" + else: + self.err_buf = None + self.err = "not capturing stderr" + + def __enter__(self): + """ + The '__enter__' method is used as a context manager to redirect the standard output and standard error streams to the provided buffers. + + Args: + self: An instance of the 'CaptureStd' class. + + Returns: + None. This method does not return any value explicitly. + + Raises: + None. + """ + if self.out_buf: + self.out_old = sys.stdout + sys.stdout = self.out_buf + + if self.err_buf: + self.err_old = sys.stderr + sys.stderr = self.err_buf + + return self + + def __exit__(self, *exc): + """ + This method __exit__ is called automatically when exiting a 'with' block that uses the CaptureStd context manager. + + Args: + self: An instance of the CaptureStd class that represents the current context manager. It is used to access the attributes and buffers within the context manager. + + Returns: + None. The method does not explicitly return a value. + + Raises: + This method does not raise any exceptions explicitly. However, exceptions may be raised if there are errors during the execution of the code within the method. + """ + if self.out_buf: + sys.stdout = self.out_old + captured = self.out_buf.getvalue() + if self.replay: + sys.stdout.write(captured) + self.out = apply_print_resets(captured) + + if self.err_buf: + sys.stderr = self.err_old + captured = self.err_buf.getvalue() + if self.replay: + sys.stderr.write(captured) + self.err = captured + + def __repr__(self): + """ + Returns a string representation of the CaptureStd object. + + Args: + self: The instance of the CaptureStd class. + + Returns: + None. This method does not return any value. + + Raises: + None. + + Description: + The __repr__ method is called when the repr() function is used on an instance of the CaptureStd class. It generates a string representation of the object, which includes the captured stdout and +stderr outputs, if any. The generated string representation is returned by the method. + + This method checks if the 'out_buf' attribute of the CaptureStd object is not empty. If it is not empty, the captured stdout output is added to the message string. Similarly, if the 'err_buf' +attribute is not empty, the captured stderr output is added to the message string. The final message string is then returned by the method. + + Note that the stdout and stderr outputs are represented as 'stdout: ' and 'stderr: ' respectively in the message string. + + Example Usage: + capture = CaptureStd() + capture.capture_stdout('Hello, world!') + capture.capture_stderr('Oops, an error occurred.') + repr_str = repr(capture) + print(repr_str) + # Output: "stdout: Hello, world!\nstderr: Oops, an error occurred.\n" + """ + msg = "" + if self.out_buf: + msg += f"stdout: {self.out}\n" + if self.err_buf: + msg += f"stderr: {self.err}\n" + return msg + + +# in tests it's the best to capture only the stream that's wanted, otherwise +# it's easy to miss things, so unless you need to capture both streams, use the +# subclasses below (less typing). Or alternatively, configure `CaptureStd` to +# disable the stream you don't need to test. + + +class CaptureStdout(CaptureStd): + """Same as CaptureStd but captures only stdout""" + def __init__(self, replay=True): + """ + Initializes an instance of the CaptureStdout class. + + Args: + self: The instance of the class. + replay (bool): A boolean flag indicating whether the captured output should be replayed. + Defaults to True. If set to True, the captured output will be replayed. + If set to False, the captured output will not be replayed. + + Returns: + None. This method does not return any value. + + Raises: + No specific exceptions are raised by this method. + """ + super().__init__(err=False, replay=replay) + + +class CaptureStderr(CaptureStd): + """Same as CaptureStd but captures only stderr""" + def __init__(self, replay=True): + """ + Initializes an instance of the CaptureStderr class. + + Args: + self (CaptureStderr): The current object. + replay (bool): Indicates whether to replay the captured stderr output. Default is True. + + Returns: + None. This method does not return any value. + + Raises: + None. This method does not raise any exceptions. + """ + super().__init__(out=False, replay=replay) + + +class CaptureLogger: + """ + Context manager to capture `logging` streams + + Args: + logger: 'logging` logger object + + Returns: + The captured output is available via `self.out` + + Example: + + ```python + >>> from transformers import logging + >>> from transformers.testing_utils import CaptureLogger + + >>> msg = "Testing 1, 2, 3" + >>> logging.set_verbosity_info() + >>> logger = logging.get_logger("transformers.models.bart.tokenization_bart") + >>> with CaptureLogger(logger) as cl: + ... logger.info(msg) + >>> assert cl.out, msg + "\n" + ``` + """ + def __init__(self, logger): + """ + Initializes a new instance of the CaptureLogger class. + + Args: + self: The instance of the class. + logger: An object representing the logger to be used for capturing logs. It should be an instance of a logger class. + + Returns: + None. This method does not return any value. + + Raises: + None. This method does not raise any exceptions. + """ + self.logger = logger + self.io = StringIO() + self.sh = logging.StreamHandler(self.io) + self.out = "" + + def __enter__(self): + """ + This method is an implementation of the context manager protocol for the CaptureLogger class. + + Args: + self: An instance of the CaptureLogger class. It represents the current object that the method is being called upon. + + Returns: + None. The method does not explicitly return any value, but it adds a handler to the logger associated with the CaptureLogger instance. + + Raises: + This method does not raise any exceptions under normal circumstances. However, potential exceptions could be raised if there are issues with adding the handler to the logger, such as improper +configuration of the logging system. + """ + self.logger.addHandler(self.sh) + return self + + def __exit__(self, *exc): + """ + This method __exit__ is called automatically when exiting a 'with' block in the CaptureLogger class. + + Args: + self (CaptureLogger): An instance of the CaptureLogger class. It is used to access the logger and the captured output. + + Returns: + None. This method does not return any value. + + Raises: + This method does not raise any exceptions explicitly. However, exceptions may be raised internally if there are issues with removing the handler or getting the captured output. + """ + self.logger.removeHandler(self.sh) + self.out = self.io.getvalue() + + def __repr__(self): + """ + Return a string representation of the CaptureLogger object. + + Args: + self (CaptureLogger): The instance of the CaptureLogger class. + + Returns: + None: This method does not explicitly return any value, as it returns None. + + Raises: + None: This method does not raise any exceptions. + """ + return f"captured: {self.out}\n" + + +@contextlib.contextmanager +def LoggingLevel(level): + """ + This is a context manager to temporarily change transformers modules logging level to the desired value and have it + restored to the original setting at the end of the scope. + + Example: + + ```python + with LoggingLevel(logging.INFO): + AutoModel.from_pretrained("gpt2") # calls logger.info() several times + ``` + """ + orig_level = logging.get_verbosity() + try: + logging.set_verbosity(level) + yield + finally: + logging.set_verbosity(orig_level) + + +@contextlib.contextmanager +# adapted from https://stackoverflow.com/a/64789046/9201239 +def ExtendSysPath(path: Union[str, os.PathLike]) -> Iterator[None]: + """ + Temporary add given path to `sys.path`. + + Usage : + + ```python + with ExtendSysPath("/path/to/dir"): + mymodule = importlib.import_module("mymodule") + ``` + """ + path = os.fspath(path) + try: + sys.path.insert(0, path) + yield + finally: + sys.path.remove(path) + + +class TestCasePlus(unittest.TestCase): + """ + This class extends *unittest.TestCase* with additional features. + + Feature 1: A set of fully resolved important file and dir path accessors. + + In tests often we need to know where things are relative to the current test file, and it's not trivial since the + test could be invoked from more than one directory or could reside in sub-directories with different depths. This + class solves this problem by sorting out all the basic paths and provides easy accessors to them: + + - `pathlib` objects (all fully resolved): + + - `test_file_path` - the current test file path (=`__file__`) + - `test_file_dir` - the directory containing the current test file + - `tests_dir` - the directory of the `tests` test suite + - `examples_dir` - the directory of the `examples` test suite + - `repo_root_dir` - the directory of the repository + - `src_dir` - the directory of `src` (i.e. where the `transformers` sub-dir resides) + + - stringified paths---same as above but these return paths as strings, rather than `pathlib` objects: + + - `test_file_path_str` + - `test_file_dir_str` + - `tests_dir_str` + - `examples_dir_str` + - `repo_root_dir_str` + - `src_dir_str` + + Feature 2: Flexible auto-removable temporary dirs which are guaranteed to get removed at the end of test. + + 1. Create a unique temporary dir: + + ```python + def test_whatever(self): + tmp_dir = self.get_auto_remove_tmp_dir() + ``` + + `tmp_dir` will contain the path to the created temporary dir. It will be automatically removed at the end of the + test. + + + 2. Create a temporary dir of my choice, ensure it's empty before the test starts and don't + empty it after the test. + + ```python + def test_whatever(self): + tmp_dir = self.get_auto_remove_tmp_dir("./xxx") + ``` + + This is useful for debug when you want to monitor a specific directory and want to make sure the previous tests + didn't leave any data in there. + + 3. You can override the first two options by directly overriding the `before` and `after` args, leading to the + following behavior: + + `before=True`: the temporary dir will always be cleared at the beginning of the test. + + `before=False`: if the temporary dir already existed, any existing files will remain there. + + `after=True`: the temporary dir will always be deleted at the end of the test. + + `after=False`: the temporary dir will always be left intact at the end of the test. + + Note 1: In order to run the equivalent of `rm -r` safely, only subdirs of the project repository checkout are + allowed if an explicit `tmp_dir` is used, so that by mistake no `/tmp` or similar important part of the filesystem + will get nuked. i.e. please always pass paths that start with `./` + + Note 2: Each test can register multiple temporary dirs and they all will get auto-removed, unless requested + otherwise. + + Feature 3: Get a copy of the `os.environ` object that sets up `PYTHONPATH` specific to the current test suite. This + is useful for invoking external programs from the test suite - e.g. distributed training. + + + ```python + def test_whatever(self): + env = self.get_env() + ```""" + def setUp(self): + """ + Set up the necessary environment for the TestCasePlus class. + + Args: + self: The instance of the TestCasePlus class. + + Returns: + None. This method does not return any value. + + Raises: + ValueError: If the root directory of the repository cannot be determined from the test file path. + + Description: + This method is called before each test case to set up the required environment for the TestCasePlus class. It initializes various directories and paths based on the current test file's location. The +method performs the following steps: + + 1. Sets up a list to keep track of temporary directories that need to be cleaned up later. + 2. Retrieves the path of the test file using the inspect module. + 3. Resolves the absolute path of the test file. + 4. Determines the parent directory of the test file. + 5. Checks if there are 'src' and 'tests' directories in any of the parent directories up to three levels above the test file. + 6. If such directories are found, the loop breaks and the repository root directory is set as the temporary directory. + 7. If no valid temporary directory is found, a ValueError is raised indicating that the root directory of the repository could not be determined. + 8. Sets the paths for the 'tests', 'examples', and 'src' directories within the repository root directory. + + Note: + This method assumes a specific directory structure for the repository, where 'src' and 'tests' directories exist at an appropriate level above the test file. + + Example usage: + test_case = TestCasePlus() + test_case.setUp() + """ + # get_auto_remove_tmp_dir feature: + self.teardown_tmp_dirs = [] + + # figure out the resolved paths for repo_root, tests, examples, etc. + self._test_file_path = inspect.getfile(self.__class__) + path = Path(self._test_file_path).resolve() + self._test_file_dir = path.parents[0] + for up in [1, 2, 3]: + tmp_dir = path.parents[up] + if (tmp_dir / "src").is_dir() and (tmp_dir / "tests").is_dir(): + break + if tmp_dir: + self._repo_root_dir = tmp_dir + else: + raise ValueError(f"can't figure out the root of the repo from {self._test_file_path}") + self._tests_dir = self._repo_root_dir / "tests" + self._examples_dir = self._repo_root_dir / "examples" + self._src_dir = self._repo_root_dir / "src" + + @property + def test_file_path(self): + """ + Returns the test file path. + + Args: + self: An instance of the TestCasePlus class. + + Returns: + None. The method does not return any value. + + Raises: + This method does not raise any exceptions. + """ + return self._test_file_path + + @property + def test_file_path_str(self): + """ + Method to retrieve the string representation of the test file path. + + Args: + self: Instance of the TestCasePlus class. + - Type: object + - Purpose: Represents the current instance of the class. + - Restrictions: None + + Returns: + The method returns a string representing the test file path. + - Type: str + - Purpose: Provides the string representation of the test file path. + + Raises: + No exceptions are raised by this method. + """ + return str(self._test_file_path) + + @property + def test_file_dir(self): + """ + This method retrieves the directory path where test files are located. + + Args: + self: An instance of the TestCasePlus class. + This parameter refers to the current instance of the TestCasePlus class. + + Returns: + None. The method does not return any value explicitly but retrieves and returns the test file directory path. + + Raises: + This method does not raise any exceptions. + """ + return self._test_file_dir + + @property + def test_file_dir_str(self): + """ + Method test_file_dir_str in the class TestCasePlus. + + Args: + self: Represents the instance of the class. No additional parameters are required. + + Returns: + str: A string representation of the _test_file_dir attribute of the instance. + + Raises: + None. + """ + return str(self._test_file_dir) + + @property + def tests_dir(self): + """ + Method: tests_dir + + Description: + Returns the tests directory path used by the TestCasePlus class. + + Args: + - self (object): The instance of the TestCasePlus class. + + Returns: + - None: This method does not return any value explicitly. + + Raises: + - None + """ + return self._tests_dir + + @property + def tests_dir_str(self): + """ + Returns the tests directory as a string. + + Args: + self: An instance of the TestCasePlus class. + + Returns: + str: The tests directory path converted to a string. + + Raises: + None. + + This method returns the tests directory path as a string. The tests directory is obtained from the '_tests_dir' attribute of the TestCasePlus class. The returned string represents the absolute path of +the tests directory. + + Example usage: + >>> test_case = TestCasePlus() + >>> test_case.tests_dir_str() + '/path/to/tests/directory' + """ + return str(self._tests_dir) + + @property + def examples_dir(self): + """ + Method to get the examples directory path. + + Args: + self: The instance of the class. + + Returns: + None. The method returns the examples directory path. + + Raises: + This method does not raise any exceptions. + """ + return self._examples_dir + + @property + def examples_dir_str(self): + """ + Method examples_dir_str in the class TestCasePlus returns the string representation of the _examples_dir attribute. + + Args: + self: An instance of the TestCasePlus class. + Purpose: Represents the current instance of the class. + Restrictions: None. + + Returns: + str: A string representation of the _examples_dir attribute. + Purpose: Provides a human-readable string representation of the _examples_dir attribute. + + Raises: + None. + """ + return str(self._examples_dir) + + @property + def repo_root_dir(self): + """ + Method to retrieve the root directory of the repository. + + Args: + self (TestCasePlus): The instance of the TestCasePlus class. + This parameter is required to access the instance attributes and methods. + + Returns: + None. The method returns the value of the '_repo_root_dir' attribute of the instance. + + Raises: + This method does not raise any exceptions. + """ + return self._repo_root_dir + + @property + def repo_root_dir_str(self): + """ + Method to retrieve the repository root directory as a string. + + Args: + self: The instance of the class TestCasePlus. + This parameter is automatically passed and refers to the instance itself. + + Returns: + str: A string representing the repository root directory. + This method returns the repository root directory as a string. + + Raises: + None. + """ + return str(self._repo_root_dir) + + @property + def src_dir(self): + """ + Returns the source directory path for the TestCasePlus class. + + Args: + self (TestCasePlus): An instance of the TestCasePlus class. + + Returns: + None: The method does not return any value. + + Raises: + None: This method does not raise any exceptions. + """ + return self._src_dir + + @property + def src_dir_str(self): + """ + Method to retrieve the source directory path as a string representation. + + Args: + self: An instance of the TestCasePlus class. + This parameter refers to the current object instance. + It is used to access the source directory path stored in the _src_dir attribute. + + Returns: + None + This method returns the source directory path as a string. If the source directory path does not exist or is empty, None is returned. + + Raises: + None + This method does not raise any exceptions. + """ + return str(self._src_dir) + + def get_env(self): + """ + Return a copy of the `os.environ` object that sets up `PYTHONPATH` correctly, depending on the test suite it's + invoked from. This is useful for invoking external programs from the test suite - e.g. distributed training. + + It always inserts `./src` first, then `./tests` or `./examples` depending on the test suite type and finally + the preset `PYTHONPATH` if any (all full resolved paths). + + """ + env = os.environ.copy() + paths = [self.src_dir_str] + if "/examples" in self.test_file_dir_str: + paths.append(self.examples_dir_str) + else: + paths.append(self.tests_dir_str) + paths.append(env.get("PYTHONPATH", "")) + + env["PYTHONPATH"] = ":".join(paths) + return env + + def get_auto_remove_tmp_dir(self, tmp_dir=None, before=None, after=None): + """ + Args: + tmp_dir (`string`, *optional*): + if `None`: + + - a unique temporary path will be created + - sets `before=True` if `before` is `None` + - sets `after=True` if `after` is `None` + else: + + - `tmp_dir` will be created + - sets `before=True` if `before` is `None` + - sets `after=False` if `after` is `None` + before (`bool`, *optional*): + If `True` and the `tmp_dir` already exists, make sure to empty it right away if `False` and the + `tmp_dir` already exists, any existing files will remain there. + after (`bool`, *optional*): + If `True`, delete the `tmp_dir` at the end of the test if `False`, leave the `tmp_dir` and its contents + intact at the end of the test. + + Returns: + tmp_dir(`string`): either the same value as passed via *tmp_dir* or the path to the auto-selected tmp dir + """ + if tmp_dir is not None: + # defining the most likely desired behavior for when a custom path is provided. + # this most likely indicates the debug mode where we want an easily locatable dir that: + # 1. gets cleared out before the test (if it already exists) + # 2. is left intact after the test + if before is None: + before = True + if after is None: + after = False + + # using provided path + path = Path(tmp_dir).resolve() + + # to avoid nuking parts of the filesystem, only relative paths are allowed + if not tmp_dir.startswith("./"): + raise ValueError( + f"`tmp_dir` can only be a relative path, i.e. `./some/path`, but received `{tmp_dir}`" + ) + + # ensure the dir is empty to start with + if before is True and path.exists(): + shutil.rmtree(tmp_dir, ignore_errors=True) + + path.mkdir(parents=True, exist_ok=True) + + else: + # defining the most likely desired behavior for when a unique tmp path is auto generated + # (not a debug mode), here we require a unique tmp dir that: + # 1. is empty before the test (it will be empty in this situation anyway) + # 2. gets fully removed after the test + if before is None: + before = True + if after is None: + after = True + + # using unique tmp dir (always empty, regardless of `before`) + tmp_dir = tempfile.mkdtemp() + + if after is True: + # register for deletion + self.teardown_tmp_dirs.append(tmp_dir) + + return tmp_dir + + def python_one_liner_max_rss(self, one_liner_str): + """ + Runs the passed python one liner (just the code) and returns how much max cpu memory was used to run the + program. + + Args: + one_liner_str (`string`): + a python one liner code that gets passed to `python -c` + + Returns: + max cpu memory bytes used to run the program. This value is likely to vary slightly from run to run. + + Requirements: + this helper needs `/usr/bin/time` to be installed (`apt install time`) + + Example: + + ``` + one_liner_str = 'from transformers import AutoModel; AutoModel.from_pretrained("t5-large")' + max_rss = self.python_one_liner_max_rss(one_liner_str) + ``` + """ + if not cmd_exists("/usr/bin/time"): + raise ValueError("/usr/bin/time is required, install with `apt install time`") + + cmd = shlex.split(f"/usr/bin/time -f %M python -c '{one_liner_str}'") + with CaptureStd() as cs: + execute_subprocess_async(cmd, env=self.get_env()) + # returned data is in KB so convert to bytes + max_rss = int(cs.err.split("\n")[-2].replace("stderr: ", "")) * 1024 + return max_rss + + def tearDown(self): + """ + Tears down the test case by cleaning up temporary directories. + + Args: + self (TestCasePlus): The instance of the TestCasePlus class. + + Returns: + None: This method does not return any value. + + Raises: + None: This method does not raise any exceptions. + """ + # get_auto_remove_tmp_dir feature: remove registered temp dirs + for path in self.teardown_tmp_dirs: + shutil.rmtree(path, ignore_errors=True) + self.teardown_tmp_dirs = [] + + +def mockenv(**kwargs): + """ + this is a convenience wrapper, that allows this :: + + @mockenv(RUN_SLOW=True, USE_TF=False) def test_something(): + run_slow = os.getenv("RUN_SLOW", False) use_tf = os.getenv("USE_TF", False) + + """ + return mock.patch.dict(os.environ, kwargs) + + +# from https://stackoverflow.com/a/34333710/9201239 +@contextlib.contextmanager +def mockenv_context(*remove, **update): + """ + Temporarily updates the `os.environ` dictionary in-place. Similar to mockenv + + The `os.environ` dictionary is updated in-place so that the modification is sure to work in all situations. + + Args: + remove: Environment variables to remove. + update: Dictionary of environment variables and values to add/update. + """ + env = os.environ + update = update or {} + remove = remove or [] + + # List of environment variables being updated or removed. + stomped = (set(update.keys()) | set(remove)) & set(env.keys()) + # Environment variables and values to restore on exit. + update_after = {k: env[k] for k in stomped} + # Environment variables and values to remove on exit. + remove_after = frozenset(k for k in update if k not in env) + + try: + env.update(update) + for k in remove: + env.pop(k, None) + yield + finally: + env.update(update_after) + for k in remove_after: + env.pop(k) + + +# --- pytest conf functions --- # + +# to avoid multiple invocation from tests/conftest.py and examples/conftest.py - make sure it's called only once +pytest_opt_registered = {} + + +def pytest_addoption_shared(parser): + """ + This function is to be called from `conftest.py` via `pytest_addoption` wrapper that has to be defined there. + + It allows loading both `conftest.py` files at once without causing a failure due to adding the same `pytest` + option. + + """ + option = "--make-reports" + if option not in pytest_opt_registered: + parser.addoption( + option, + action="store", + default=False, + help="generate report files. The value of this option is used as a prefix to report names", + ) + pytest_opt_registered[option] = 1 + + +def pytest_terminal_summary_main(tr, ids): + """ + Generate multiple reports at the end of test suite run - each report goes into a dedicated file in the current + directory. The report files are prefixed with the test suite name. + + This function emulates --duration and -rA pytest arguments. + + This function is to be called from `conftest.py` via `pytest_terminal_summary` wrapper that has to be defined + there. + + Args: + - tr: `terminalreporter` passed from `conftest.py` + - ids: unique id like `tests` or `examples` that will be incorporated into the final reports filenames - this is + needed as some jobs have multiple runs of pytest, so we can't have them overwrite each other. + + NB: this functions taps into a private _pytest API and while unlikely, it could break should pytest do internal + changes - also it calls default internal methods of terminalreporter which can be hijacked by various `pytest-` + plugins and interfere. + + """ + if not ids: + ids = "tests" + + config = tr.config + orig_writer = config.get_terminal_writer() + orig_tbstyle = config.option.tbstyle + orig_reportchars = tr.reportchars + + dirs = f"reports/{ids}" + Path(dirs).mkdir(parents=True, exist_ok=True) + report_files = { + k: f"{dirs}/{k}.txt" + for k in [ + "durations", + "errors", + "failures_long", + "failures_short", + "failures_line", + "passes", + "stats", + "summary_short", + "warnings", + ] + } + + # custom durations report + # note: there is no need to call pytest --durations=XX to get this separate report + # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/runner.py#L66 + dlist = [] + for replist in tr.stats.values(): + for rep in replist: + if hasattr(rep, "duration"): + dlist.append(rep) + if dlist: + dlist.sort(key=lambda x: x.duration, reverse=True) + with open(report_files["durations"], "w") as f: + durations_min = 0.05 # sec + f.write("slowest durations\n") + for i, rep in enumerate(dlist): + if rep.duration < durations_min: + f.write(f"{len(dlist)-i} durations < {durations_min} secs were omitted") + break + f.write(f"{rep.duration:02.2f}s {rep.when:<8} {rep.nodeid}\n") + + def summary_failures_short(tr): + # expecting that the reports were --tb=long (default) so we chop them off here to the last frame + reports = tr.getreports("failed") + if not reports: + return + tr.write_sep("=", "FAILURES SHORT STACK") + for rep in reports: + msg = tr._getfailureheadline(rep) + tr.write_sep("_", msg, red=True, bold=True) + # chop off the optional leading extra frames, leaving only the last one + longrepr = re.sub(r".*_ _ _ (_ ){10,}_ _ ", "", rep.longreprtext, 0, re.M | re.S) + tr._tw.line(longrepr) + # note: not printing out any rep.sections to keep the report short + + # use ready-made report funcs, we are just hijacking the filehandle to log to a dedicated file each + # adapted from https://github.com/pytest-dev/pytest/blob/897f151e/src/_pytest/terminal.py#L814 + # note: some pytest plugins may interfere by hijacking the default `terminalreporter` (e.g. + # pytest-instafail does that) + + # report failures with line/short/long styles + config.option.tbstyle = "auto" # full tb + with open(report_files["failures_long"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_failures() + + # config.option.tbstyle = "short" # short tb + with open(report_files["failures_short"], "w") as f: + tr._tw = create_terminal_writer(config, f) + summary_failures_short(tr) + + config.option.tbstyle = "line" # one line per error + with open(report_files["failures_line"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_failures() + + with open(report_files["errors"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_errors() + + with open(report_files["warnings"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_warnings() # normal warnings + tr.summary_warnings() # final warnings + + tr.reportchars = "wPpsxXEf" # emulate -rA (used in summary_passes() and short_test_summary()) + + # Skip the `passes` report, as it starts to take more than 5 minutes, and sometimes it timeouts on CircleCI if it + # takes > 10 minutes (as this part doesn't generate any output on the terminal). + # (also, it seems there is no useful information in this report, and we rarely need to read it) + # with open(report_files["passes"], "w") as f: + # tr._tw = create_terminal_writer(config, f) + # tr.summary_passes() + + with open(report_files["summary_short"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.short_test_summary() + + with open(report_files["stats"], "w") as f: + tr._tw = create_terminal_writer(config, f) + tr.summary_stats() + + # restore: + tr._tw = orig_writer + tr.reportchars = orig_reportchars + config.option.tbstyle = orig_tbstyle + + +# --- distributed testing functions --- # + +# adapted from https://stackoverflow.com/a/59041913/9201239 +class _RunOutput: + + """ + Represents the output of a command execution, including the return code, standard output, and standard error. + + Attributes: + returncode (int): The return code of the executed command. + stdout (str): The standard output captured from the command execution. + stderr (str): The standard error captured from the command execution. + """ + def __init__(self, returncode, stdout, stderr): + """ + __init__(self, returncode, stdout, stderr) + + Initializes the _RunOutput class instance with the provided return code, standard output, and standard error. + + Args: + self (_RunOutput): The instance of the _RunOutput class. + returncode (int): The return code from the executed command. + stdout (str): The standard output generated by the executed command. + stderr (str): The standard error generated by the executed command. + + Returns: + None: This method does not return any value. + + Raises: + No specific exceptions are raised by this method. + """ + self.returncode = returncode + self.stdout = stdout + self.stderr = stderr + + +async def _read_stream(stream, callback): + """ + Docstring for _read_stream function: + + Args: + stream (stream): The input stream from which the function reads data. + callback (function): The callback function to be executed for each line read from the stream. + + Returns: + None. The function does not return any value. + + Raises: + No specific exceptions are raised by this function. + """ + while True: + line = await stream.readline() + if line: + callback(line) + else: + break + + +async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput: + """ + This function runs a subprocess and captures its standard output and error streams. + + Args: + - cmd (List[str]): A list of command and arguments to be executed. + - env (Optional[Dict[str, str]]): A dictionary of environment variables to be used for the subprocess. + - stdin (Optional[asyncio.subprocess.StreamReader]): A stream representing the standard input for the subprocess. + - timeout (Optional[float]): The maximum time in seconds to wait for the subprocess to complete. + - quiet (bool): If True, suppresses the output of the subprocess. + - echo (bool): If True, prints the command being executed. + + Returns: + _RunOutput: An object containing the return code of the subprocess, its standard output, and standard error. + + Raises: + - asyncio.TimeoutError: If the subprocess execution exceeds the specified timeout. + - OSError: If an OS-related error occurs during the subprocess execution. + - ValueError: If the provided command is invalid or the arguments are of the wrong type. + """ + if echo: + print("\nRunning: ", " ".join(cmd)) + + p = await asyncio.create_subprocess_exec( + cmd[0], + *cmd[1:], + stdin=stdin, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=env, + ) + + # note: there is a warning for a possible deadlock when using `wait` with huge amounts of data in the pipe + # https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait + # + # If it starts hanging, will need to switch to the following code. The problem is that no data + # will be seen until it's done and if it hangs for example there will be no debug info. + # out, err = await p.communicate() + # return _RunOutput(p.returncode, out, err) + + out = [] + err = [] + + def tee(line, sink, pipe, label=""): + line = line.decode("utf-8").rstrip() + sink.append(line) + if not quiet: + print(label, line, file=pipe) + + await asyncio.wait( + [ + _read_stream(p.stdout, lambda l: tee(l, out, sys.stdout, label="stdout:")), + _read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:")), + ], + timeout=timeout, + ) + return _RunOutput(await p.wait(), out, err) + + +def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput: + """ + Args: + cmd (List[str]): A list of strings representing the command and its arguments to be executed. + env (Optional[Dict[str, str]]): A dictionary of environment variables to be passed to the subprocess. + stdin (Optional[Union[str, bytes]]): The input to be passed to the subprocess. + timeout (int): The maximum time in seconds to wait for the subprocess to complete. + quiet (bool): If True, suppresses output from the subprocess. + echo (bool): If True, prints the subprocess output to the console. + + Returns: + _RunOutput: An object containing the output of the executed subprocess, including stdout, stderr, and returncode. + + Raises: + RuntimeError: If the subprocess fails with a non-zero return code or produces no output. + """ + loop = asyncio.get_event_loop() + result = loop.run_until_complete( + _stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo) + ) + + cmd_str = " ".join(cmd) + if result.returncode > 0: + stderr = "\n".join(result.stderr) + raise RuntimeError( + f"'{cmd_str}' failed with returncode {result.returncode}\n\n" + f"The combined stderr from workers follows:\n{stderr}" + ) + + # check that the subprocess actually did run and produced some output, should the test rely on + # the remote side to do the testing + if not result.stdout and not result.stderr: + raise RuntimeError(f"'{cmd_str}' produced no output.") + + return result + + +def pytest_xdist_worker_id(): + """ + Returns an int value of worker's numerical id under `pytest-xdist`'s concurrent workers `pytest -n N` regime, or 0 + if `-n 1` or `pytest-xdist` isn't being used. + """ + worker = os.environ.get("PYTEST_XDIST_WORKER", "gw0") + worker = re.sub(r"^gw", "", worker, 0, re.M) + return int(worker) + + +def get_torch_dist_unique_port(): + """ + Returns a port number that can be fed to `torch.distributed.launch`'s `--master_port` argument. + + Under `pytest-xdist` it adds a delta number based on a worker id so that concurrent tests don't try to use the same + port at once. + """ + port = 29500 + uniq_delta = pytest_xdist_worker_id() + return port + uniq_delta + + +def nested_simplify(obj, decimals=3): + """ + Simplifies an object by rounding float numbers, and downcasting tensors/numpy arrays to get simple equality test + within tests. + """ + if isinstance(obj, list): + return [nested_simplify(item, decimals) for item in obj] + if isinstance(obj, tuple): + return tuple(nested_simplify(item, decimals) for item in obj) + if isinstance(obj, np.ndarray): + return nested_simplify(obj.tolist()) + if isinstance(obj, Mapping): + return {nested_simplify(k, decimals): nested_simplify(v, decimals) for k, v in obj.items()} + if isinstance(obj, (str, int, np.int64)): + return obj + if obj is None: + return obj + if is_mindspore_available() and ops.is_tensor(obj): + return nested_simplify(obj.numpy().tolist()) + if isinstance(obj, float): + return round(obj, decimals) + if isinstance(obj, (np.int32, np.float32)): + return nested_simplify(obj.item(), decimals) + raise RuntimeError(f"Not supported: {type(obj)}") + + +def to_2tuple(x): + """ + Converts the input value to a 2-tuple. + + Args: + x: The value to be converted. It can be of any type. + + Returns: + A 2-tuple with the input value. If the input value is already an iterable, it is returned as is. + Otherwise, a 2-tuple is created with the input value repeated twice. + + Raises: + None. + + """ + if isinstance(x, collections.abc.Iterable): + return x + return (x, x) + + +# These utils relate to ensuring the right error message is received when running scripts +class SubprocessCallException(Exception): + """SubprocessCallException""" +def run_command(command: List[str], return_stdout=False): + """ + Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture + if an error occurred while running `command` + """ + try: + output = subprocess.check_output(command, stderr=subprocess.STDOUT) + if return_stdout: + if hasattr(output, "decode"): + output = output.decode("utf-8") + return output + except subprocess.CalledProcessError as e: + raise SubprocessCallException( + f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}" + ) from e + return None + +class RequestCounter: + """ + Helper class that will count all requests made online. + + Might not be robust if urllib3 changes its logging format but should be good enough for us. + + Usage: + ```py + with RequestCounter() as counter: + _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert") + assert counter["GET"] == 0 + assert counter["HEAD"] == 1 + assert counter.total_calls == 1 + ``` + """ + def __enter__(self): + """ + __enter__ + + Args: + self: The instance of the RequestCounter class. + + Returns: + None. This method does not explicitly return a value. + + Raises: + No specific exceptions are raised within this method. + """ + self._counter = defaultdict(int) + self.patcher = patch.object(urllib3.connectionpool.log, "debug", wraps=urllib3.connectionpool.log.debug) + self.mock = self.patcher.start() + return self + + def __exit__(self, *args, **kwargs) -> None: + """ + This method '__exit__' in the class 'RequestCounter' is called upon exiting a context manager. It updates the request counters based on the logged HTTP methods. + + Args: + - self: An instance of the 'RequestCounter' class. It represents the current instance of the class. + + Returns: + - None: This method does not return any value. + + Raises: + This method does not explicitly raise any exceptions. + """ + for call in self.mock.call_args_list: + log = call.args[0] % call.args[1:] + for method in ("HEAD", "GET", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"): + if method in log: + self._counter[method] += 1 + break + self.patcher.stop() + + def __getitem__(self, key: str) -> int: + """ + Retrieve the count associated with the specified key from the RequestCounter. + + Args: + self (RequestCounter): An instance of the RequestCounter class. + key (str): The key for which the count needs to be retrieved. It should be a string representing the identifier of the request. + + Returns: + int: The count associated with the specified key. This count indicates the number of times the request identified by the key has been made. + + Raises: + KeyError: If the specified key does not exist in the RequestCounter, a KeyError is raised indicating that the count for the key cannot be retrieved. + """ + return self._counter[key] + + @property + def total_calls(self) -> int: + """ + Method to calculate the total number of calls made to the RequestCounter instance. + + Args: + self (RequestCounter): The instance of the RequestCounter class. + This parameter is automatically passed when calling the method. + + Returns: + int: The total number of calls made to the RequestCounter instance. + It is the sum of all the values stored in the internal counter. + + Raises: + No specific exceptions are raised by this method. + """ + return sum(self._counter.values()) + +def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None): + """ + To decorate flaky tests. They will be retried on failures. + + Args: + max_attempts (`int`, *optional*, defaults to 5): + The maximum number of attempts to retry the flaky test. + wait_before_retry (`float`, *optional*): + If provided, will wait that number of seconds before retrying the test. + description (`str`, *optional*): + A string to describe the situation (what / where / why is flaky, link to GH issue/PR comments, errors, + etc.) + """ + def decorator(test_func_ref): + @functools.wraps(test_func_ref) + def wrapper(*args, **kwargs): + retry_count = 1 + + while retry_count < max_attempts: + try: + return test_func_ref(*args, **kwargs) + + except Exception as err: + print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr) + if wait_before_retry is not None: + time.sleep(wait_before_retry) + retry_count += 1 + + return test_func_ref(*args, **kwargs) + + return wrapper + + return decorator + + +def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None): + """ + To run a test in a subprocess. In particular, this can avoid (GPU) memory issue. + + Args: + test_case (`unittest.TestCase`): + The test that will run `target_func`. + target_func (`Callable`): + The function implementing the actual testing logic. + inputs (`dict`, *optional*, defaults to `None`): + The inputs that will be passed to `target_func` through an (input) queue. + timeout (`int`, *optional*, defaults to `None`): + The timeout (in seconds) that will be passed to the input and output queues. If not specified, the env. + variable `PYTEST_TIMEOUT` will be checked. If still `None`, its value will be set to `600`. + """ + if timeout is None: + timeout = int(os.environ.get("PYTEST_TIMEOUT", 600)) + + start_methohd = "spawn" + ctx = multiprocessing.get_context(start_methohd) + + input_queue = ctx.Queue(1) + output_queue = ctx.JoinableQueue(1) + + # We can't send `unittest.TestCase` to the child, otherwise we get issues regarding pickle. + input_queue.put(inputs, timeout=timeout) + + process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout)) + process.start() + # Kill the child process if we can't get outputs from it in time: otherwise, the hanging subprocess prevents + # the test to exit properly. + try: + results = output_queue.get(timeout=timeout) + output_queue.task_done() + except Exception as e: + process.terminate() + test_case.fail(e) + process.join(timeout=timeout) + + if results["error"] is not None: + test_case.fail(f'{results["error"]}') + + +# The following contains utils to run the documentation tests without having to overwrite any files. + +# The `preprocess_string` function adds `# doctest: +IGNORE_RESULT` markers on the fly anywhere a `load_dataset` call is +# made as a print would otherwise fail the corresonding line. + +# To skip cuda tests, make sure to call `SKIP_CUDA_DOCTEST=1 pytest --doctest-modules + + +def preprocess_string(string, skip_cuda_tests): + """Prepare a docstring or a `.md` file to be run by doctest. + + The argument `string` would be the whole file content if it is a `.md` file. For a python file, it would be one of + its docstring. In each case, it may contain multiple python code examples. If `skip_cuda_tests` is `True` and a + cuda stuff is detective (with a heuristic), this method will return an empty string so no doctest will be run for + `string`. + """ + codeblock_pattern = r"(```(?:python|py)\s*\n\s*>>> )((?:.*?\n)*?.*?```)" + codeblocks = re.split(re.compile(codeblock_pattern, flags=re.MULTILINE | re.DOTALL), string) + is_cuda_found = False + for i, codeblock in enumerate(codeblocks): + if "load_dataset(" in codeblock and "# doctest: +IGNORE_RESULT" not in codeblock: + codeblocks[i] = re.sub(r"(>>> .*load_dataset\(.*)", r"\1 # doctest: +IGNORE_RESULT", codeblock) + if ( + (">>>" in codeblock or "..." in codeblock) + and re.search(r"cuda|to\(0\)|device=0", codeblock) + and skip_cuda_tests + ): + is_cuda_found = True + break + + modified_string = "" + if not is_cuda_found: + modified_string = "".join(codeblocks) + + return modified_string + + +class HfDocTestParser(doctest.DocTestParser): + """ + Overwrites the DocTestParser from doctest to properly parse the codeblocks that are formatted with black. This + means that there are no extra lines at the end of our snippets. The `# doctest: +IGNORE_RESULT` marker is also + added anywhere a `load_dataset` call is made as a print would otherwise fail the corresponding line. + + Tests involving cuda are skipped base on a naive pattern that should be updated if it is not enough. + """ + # This regular expression is used to find doctest examples in a + # string. It defines three groups: `source` is the source code + # (including leading indentation and prompts); `indent` is the + # indentation of the first (PS1) line of the source code; and + # `want` is the expected output (including leading indentation). + # fmt: off + _EXAMPLE_RE = re.compile(r''' + # Source consists of a PS1 line followed by zero or more PS2 lines. + (?P + (?:^(?P [ ]*) >>> .*) # PS1 line + (?:\n [ ]* \.\.\. .*)*) # PS2 lines + \n? + # Want consists of any non-blank lines that do not start with PS1. + (?P (?:(?![ ]*$) # Not a blank line + (?![ ]*>>>) # Not a line starting with PS1 + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + (?:(?!```).)* # Match any character except '`' until a '```' is found (this is specific to HF because black removes the last line) + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + (?:\n|$) # Match a new line or end of string + )*) + ''', re.MULTILINE | re.VERBOSE + ) + # fmt: on + + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + skip_cuda_tests: bool = bool(os.environ.get("SKIP_CUDA_DOCTEST", False)) + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + + def parse(self, string, name=""): + """ + Overwrites the `parse` method to incorporate a skip for CUDA tests, and remove logs and dataset prints before + calling `super().parse` + """ + string = preprocess_string(string, self.skip_cuda_tests) + return super().parse(string, name) + + +class HfDoctestModule(Module): + """ + Overwrites the `DoctestModule` of the pytest package to make sure the HFDocTestParser is used when discovering + tests. + """ + def collect(self) -> Iterable[DoctestItem]: + """ + Collects doctests from the specified module. + + Args: + self (HfDoctestModule): The instance of the HfDoctestModule class. + + Returns: + Iterable[DoctestItem]: A collection of doctests represented as DoctestItem objects. + + Raises: + ImportError: If the module cannot be imported and the 'doctest_ignore_import_errors' configuration option is not set. + Skip: If the 'doctest_ignore_import_errors' configuration option is set and the module cannot be imported. + """ + class MockAwareDocTestFinder(doctest.DocTestFinder): + """A hackish doctest finder that overrides stdlib internals to fix a stdlib bug. + + https://github.com/pytest-dev/pytest/issues/3456 https://bugs.python.org/issue25532 + """ + + def _find_lineno(self, obj, source_lines): + """Doctest code does not take into account `@property`, this + is a hackish way to fix it. https://bugs.python.org/issue17446 + + Wrapped Doctests will need to be unwrapped so the correct line number is returned. This will be + reported upstream. #8796 + """ + if isinstance(obj, property): + obj = getattr(obj, "fget", obj) + + if hasattr(obj, "__wrapped__"): + # Get the main obj in case of it being wrapped + obj = inspect.unwrap(obj) + + # Type ignored because this is a private function. + return super()._find_lineno( # type:ignore[misc] + obj, + source_lines, + ) + + def _find(self, tests, obj, name, module, source_lines, globs, seen) -> None: + if _is_mocked(obj): + return + with _patch_unwrap_mock_aware(): + # Type ignored because this is a private function. + super()._find( # type:ignore[misc] + tests, obj, name, module, source_lines, globs, seen + ) + + if self.path.name == "conftest.py": + module = self.config.pluginmanager._importconftest( + self.path, + self.config.getoption("importmode"), + rootpath=self.config.rootpath, + ) + else: + try: + module = import_path( + self.path, + root=self.config.rootpath, + mode=self.config.getoption("importmode"), + ) + except ImportError: + if self.config.getvalue("doctest_ignore_import_errors"): + skip(f"unable to import module {self.path}") + else: + raise + + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + finder = MockAwareDocTestFinder(parser=HfDocTestParser()) + # !!!!!!!!!!! HF Specific !!!!!!!!!!! + optionflags = get_optionflags(self) + runner = _get_runner( + verbose=False, + optionflags=optionflags, + checker=_get_checker(), + continue_on_failure=_get_continue_on_failure(self.config), + ) + for test in finder.find(module, module.__name__): + if test.examples: # skip empty doctests and cuda + yield DoctestItem.from_parent(self, name=test.name, runner=runner, dtest=test) + + +def _device_agnostic_dispatch(device: str, dispatch_table: Dict[str, Callable], *args, **kwargs): + """ + Executes a device-agnostic dispatch based on the given device and dispatch table. + + Args: + device (str): The device for which the dispatch is performed. + dispatch_table (Dict[str, Callable]): A dictionary containing the dispatch functions for different devices. + + Returns: + None: Returns None if the dispatch function for the given device is None. + + Raises: + None: This function does not raise any exceptions. + """ + if device not in dispatch_table: + return dispatch_table["default"](*args, **kwargs) + + fn = dispatch_table[device] + + # Some device agnostic functions return values. Need to guard against `None` + # instead at user level. + if fn is None: + return None + return fn(*args, **kwargs) + +def get_tests_dir(append_path=None): + """ + Args: + append_path: optional path to append to the tests dir path + + Return: + The full path to the `tests` dir, so that the tests can be invoked from anywhere. Optionally `append_path` is + joined after the `tests` dir the former is provided. + + """ + # this function caller's __file__ + caller__file__ = inspect.stack()[1][1] + tests_dir = os.path.abspath(os.path.dirname(caller__file__)) + + while not tests_dir.endswith("tests"): + tests_dir = os.path.dirname(tests_dir) + + if append_path: + return os.path.join(tests_dir, append_path) + return tests_dir + +def check_json_file_has_correct_format(file_path): + ''' + Check if the provided JSON file has the correct format. + + Args: + file_path (str): The path to the JSON file to be checked. + + Returns: + None: This function does not return any value. + + Raises: + AssertionError: If the JSON file does not have the correct format as per the specified conditions. + FileNotFoundError: If the specified file_path does not exist. + UnicodeDecodeError: If the file cannot be decoded using the specified encoding. + ''' + with open(file_path, "r", encoding='utf-8') as f: + lines = f.readlines() + if len(lines) == 1: + # length can only be 1 if dict is empty + assert lines[0] == "{}" + else: + # otherwise make sure json has correct format (at least 3 lines) + assert len(lines) >= 3 + # each key one line, ident should be 2, min length is 3 + assert lines[0].strip() == "{" + for _ in lines[1:-1]: + left_indent = len(lines[1]) - len(lines[1].lstrip()) + assert left_indent == 2 + assert lines[-1].strip() == "}" + +_run_staging = parse_flag_from_env("MINDNLP_CO_STAGING", default=False) + +def is_staging_test(test_case): + """ + Decorator marking a test as a staging test. + + Those tests will run using the staging environment of huggingface.co instead of the real model hub. + """ + if not _run_staging: + return unittest.skip("test is staging test")(test_case) + else: + try: + import pytest # We don't need a hard dependency on pytest in the main library + except ImportError: + return test_case + else: + return pytest.mark.is_staging_test()(test_case) + + +def require_soundfile(test_case): + """ + Decorator marking a test that requires soundfile + + These tests are skipped when soundfile isn't installed. + + """ + return unittest.skipUnless(is_soundfile_availble(), "test requires soundfile")( + test_case + ) + +def backend_empty_cache(): + if hasattr(mindspore, 'hal'): + mindspore.hal.empty_cache() \ No newline at end of file