diff --git a/docs/source/Instruction/Supported-models-and-datasets.md b/docs/source/Instruction/Supported-models-and-datasets.md
index d953dc4e84..1b4e35058f 100644
--- a/docs/source/Instruction/Supported-models-and-datasets.md
+++ b/docs/source/Instruction/Supported-models-and-datasets.md
@@ -1024,6 +1024,7 @@
 |[google/gemma-3n-E4B-it](https://modelscope.cn/models/google/gemma-3n-E4B-it)|gemma3n|gemma3n|transformers>=4.53.1|✘|-|[google/gemma-3n-E4B-it](https://huggingface.co/google/gemma-3n-E4B-it)|
 |[mistralai/Mistral-Small-3.1-24B-Base-2503](https://modelscope.cn/models/mistralai/Mistral-Small-3.1-24B-Base-2503)|mistral_2503|mistral_2503|transformers>=4.49|✘|-|[mistralai/Mistral-Small-3.1-24B-Base-2503](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Base-2503)|
 |[mistralai/Mistral-Small-3.1-24B-Instruct-2503](https://modelscope.cn/models/mistralai/Mistral-Small-3.1-24B-Instruct-2503)|mistral_2503|mistral_2503|transformers>=4.49|✘|-|[mistralai/Mistral-Small-3.1-24B-Instruct-2503](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)|
+|[mistralai/Mistral-Small-3.2-24B-Instruct-2506](https://modelscope.cn/models/mistralai/Mistral-Small-3.2-24B-Instruct-2506)|mistral_2506|mistral_2506|transformers>=4.49|✘|-|[mistralai/Mistral-Small-3.2-24B-Instruct-2506](https://huggingface.co/mistralai/Mistral-Small-3.2-24B-Instruct-2506)|
 |[PaddlePaddle/PaddleOCR-VL](https://modelscope.cn/models/PaddlePaddle/PaddleOCR-VL)|paddle_ocr|paddle_ocr|-|✘|-|[PaddlePaddle/PaddleOCR-VL](https://huggingface.co/PaddlePaddle/PaddleOCR-VL)|
 |[JinaAI/jina-reranker-m0](https://modelscope.cn/models/JinaAI/jina-reranker-m0)|jina_reranker_m0|jina_reranker_m0|-|✘|reranker, vision|[JinaAI/jina-reranker-m0](https://huggingface.co/JinaAI/jina-reranker-m0)|
diff --git a/docs/source_en/Instruction/Supported-models-and-datasets.md b/docs/source_en/Instruction/Supported-models-and-datasets.md
index 13bbdb6ac9..808cec5a85 100644
--- a/docs/source_en/Instruction/Supported-models-and-datasets.md
+++ b/docs/source_en/Instruction/Supported-models-and-datasets.md
@@ -1024,6 +1024,7 @@ The table below introduces the models integrated with ms-swift:
 |[google/gemma-3n-E4B-it](https://modelscope.cn/models/google/gemma-3n-E4B-it)|gemma3n|gemma3n|transformers>=4.53.1|✘|-|[google/gemma-3n-E4B-it](https://huggingface.co/google/gemma-3n-E4B-it)|
 |[mistralai/Mistral-Small-3.1-24B-Base-2503](https://modelscope.cn/models/mistralai/Mistral-Small-3.1-24B-Base-2503)|mistral_2503|mistral_2503|transformers>=4.49|✘|-|[mistralai/Mistral-Small-3.1-24B-Base-2503](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Base-2503)|
 |[mistralai/Mistral-Small-3.1-24B-Instruct-2503](https://modelscope.cn/models/mistralai/Mistral-Small-3.1-24B-Instruct-2503)|mistral_2503|mistral_2503|transformers>=4.49|✘|-|[mistralai/Mistral-Small-3.1-24B-Instruct-2503](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)|
+|[mistralai/Mistral-Small-3.2-24B-Instruct-2506](https://modelscope.cn/models/mistralai/Mistral-Small-3.2-24B-Instruct-2506)|mistral_2506|mistral_2506|transformers>=4.49|✘|-|[mistralai/Mistral-Small-3.2-24B-Instruct-2506](https://huggingface.co/mistralai/Mistral-Small-3.2-24B-Instruct-2506)|
 |[PaddlePaddle/PaddleOCR-VL](https://modelscope.cn/models/PaddlePaddle/PaddleOCR-VL)|paddle_ocr|paddle_ocr|-|✘|-|[PaddlePaddle/PaddleOCR-VL](https://huggingface.co/PaddlePaddle/PaddleOCR-VL)|
 |[JinaAI/jina-reranker-m0](https://modelscope.cn/models/JinaAI/jina-reranker-m0)|jina_reranker_m0|jina_reranker_m0|-|✘|reranker, vision|[JinaAI/jina-reranker-m0](https://huggingface.co/JinaAI/jina-reranker-m0)|
diff --git a/swift/llm/model/constant.py b/swift/llm/model/constant.py
index 72c96fe343..dd252e0fd8 100644
--- a/swift/llm/model/constant.py
+++ b/swift/llm/model/constant.py
@@ -274,6 +274,7 @@ class MLLMModelType:
     gemma3_vision = 'gemma3_vision'
     gemma3n = 'gemma3n'
     mistral_2503 = 'mistral_2503'
+    mistral_2506 = 'mistral_2506'
 
     paddle_ocr = 'paddle_ocr'
diff --git a/swift/llm/model/model/mistral.py b/swift/llm/model/model/mistral.py
index 6d23b23b41..ceaa6e6ed8 100644
--- a/swift/llm/model/model/mistral.py
+++ b/swift/llm/model/model/mistral.py
@@ -1,8 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-
 from typing import Any, Dict
 
-from transformers import AutoTokenizer
+from transformers import AutoProcessor, AutoTokenizer
 
 from swift.llm import TemplateType
 from ..constant import LLMModelType, MLLMModelType
@@ -130,12 +129,7 @@ def get_model_tokenizer_mistral_2503(model_dir: str,
                                      model_kwargs: Dict[str, Any],
                                      load_model: bool = True,
                                      **kwargs):
-    try:
-        from transformers import Mistral3ForConditionalGeneration
-    except ImportError:
-        raise ImportError('Please install Mistral3ForConditionalGeneration by running '
-                          '`pip install git+https://github.com/huggingface/transformers@v4.49.0-Mistral-3`')
-
+    from transformers import Mistral3ForConditionalGeneration
     kwargs['automodel_class'] = kwargs['automodel_class'] or Mistral3ForConditionalGeneration
     model, processor = get_model_tokenizer_multimodal(model_dir, model_info, model_kwargs, load_model, **kwargs)
 
@@ -184,4 +178,35 @@ def get_model_tokenizer_devstral_2505(model_dir: str,
         architectures=['Mistral3ForConditionalGeneration'],
         model_arch=ModelArch.llava_hf,
         requires=['transformers>=4.49'],
-    ), )
+    ))
+
+
+def get_model_tokenizer_mistral_2506(model_dir: str,
+                                     model_info: ModelInfo,
+                                     model_kwargs: Dict[str, Any],
+                                     load_model: bool = True,
+                                     **kwargs):
+    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer  # noqa: F401 (checks mistral_common)
+    from transformers import Mistral3ForConditionalGeneration
+    tokenizer_dir = safe_snapshot_download('mistralai/Mistral-Small-3.1-24B-Instruct-2503', download_model=False)
+    processor = AutoProcessor.from_pretrained(tokenizer_dir)  # tokenizer/processor reused from the 2503 release
+    kwargs['automodel_class'] = kwargs['automodel_class'] or Mistral3ForConditionalGeneration
+    kwargs['tokenizer'] = processor.tokenizer
+    model, _ = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)
+    return model, processor
+
+
+register_model(
+    ModelMeta(
+        MLLMModelType.mistral_2506,
+        [
+            ModelGroup([
+                Model('mistralai/Mistral-Small-3.2-24B-Instruct-2506', 'mistralai/Mistral-Small-3.2-24B-Instruct-2506'),
+            ]),
+        ],
+        TemplateType.mistral_2506,
+        get_model_tokenizer_mistral_2506,
+        architectures=['Mistral3ForConditionalGeneration'],
+        model_arch=ModelArch.llava_hf,
+        requires=['transformers>=4.49'],
+    ))
diff --git a/swift/llm/model/model_arch.py b/swift/llm/model/model_arch.py
index 5c8e2f4051..fed4ff8055 100644
--- a/swift/llm/model/model_arch.py
+++ b/swift/llm/model/model_arch.py
@@ -81,7 +81,6 @@ class MLLMModelArch:
     megrez_omni = 'megrez_omni'
     valley = 'valley'
     gemma3n = 'gemma3n'
-    mistral_2503 = 'mistral_2503'
 
     keye_vl = 'keye_vl'
     midashenglm = 'midashenglm'
diff --git a/swift/llm/template/constant.py b/swift/llm/template/constant.py
index 4a073f5442..82810452ad 100644
--- a/swift/llm/template/constant.py
+++ b/swift/llm/template/constant.py
@@ -229,6 +229,7 @@ class MLLMTemplateType:
     gemma3_vision = 'gemma3_vision'
     gemma3n = 'gemma3n'
     mistral_2503 = 'mistral_2503'
+    mistral_2506 = 'mistral_2506'
 
     paddle_ocr = 'paddle_ocr'
diff --git a/swift/llm/template/template/llm.py b/swift/llm/template/template/llm.py
index 67d9a1514e..dbd9aa2850 100644
--- a/swift/llm/template/template/llm.py
+++ b/swift/llm/template/template/llm.py
@@ -119,29 +119,6 @@ def _preprocess_inputs(self, inputs: StdTemplateInputs) -> None:
         chat_sep=['</s>[INST] '],
         suffix=['</s>']))
 
-today = datetime.now().strftime('%Y-%m-%d')
-
-mistral_2501_system = (
-    'You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup '
-    'headquartered in Paris.\n'
-    f'Your knowledge base was last updated on 2023-10-01. The current date is {today}.\n\n'
-    "When you're not sure about some information, you say that you don't have the information and don't "
-    'make up anything.\n'
-    "If the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer "
-    'the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. '
-    '"What are some good restaurants around me?" => "Where are you?" or "When is the next flight to Tokyo" => "'
-    'Where do you travel from?")')
-
-register_template(
-    TemplateMeta(
-        LLMTemplateType.mistral_2501,
-        prefix=['<s>'],
-        prompt=['[INST]{{QUERY}}[/INST]'],
-        chat_sep=['</s>'],
-        suffix=['</s>'],
-        system_prefix=['<s>[SYSTEM_PROMPT]{{SYSTEM}}[/SYSTEM_PROMPT]'],
-        default_system=mistral_2501_system))
-
 register_template(
     TemplateMeta(
         LLMTemplateType.xverse,
diff --git a/swift/llm/template/template/mistral.py b/swift/llm/template/template/mistral.py
index 679599a6d8..ec2870df8f 100644
--- a/swift/llm/template/template/mistral.py
+++ b/swift/llm/template/template/mistral.py
@@ -1,14 +1,39 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+from dataclasses import dataclass, field
+from datetime import datetime, timedelta
 from typing import Any, Dict, List, Literal, Optional
 
-import torch
-
 from ..base import Template
-from ..constant import MLLMTemplateType
+from ..constant import LLMTemplateType, MLLMTemplateType
 from ..register import TemplateMeta, register_template
 from ..template_inputs import StdTemplateInputs
-from ..utils import Context, findall
-from .llm import mistral_2501_system
+from ..utils import Context, Prompt, findall
+
+today = datetime.now().strftime('%Y-%m-%d')
+
+mistral_2501_system = (
+    'You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup '
+    'headquartered in Paris.\n'
+    f'Your knowledge base was last updated on 2023-10-01. The current date is {today}.\n\n'
+    "When you're not sure about some information, you say that you don't have the information and don't "
+    'make up anything.\n'
+    "If the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer "
+    'the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. '
+    '"What are some good restaurants around me?" => "Where are you?" or "When is the next flight to Tokyo" => "'
+    'Where do you travel from?")')
+
+
+@dataclass
+class Mistral3TemplateMeta(TemplateMeta):
+    prefix: Prompt = field(default_factory=lambda: ['<s>'])
+    prompt: Prompt = field(default_factory=lambda: ['[INST]{{QUERY}}[/INST]'])
+    chat_sep: Optional[Prompt] = field(default_factory=lambda: ['</s>'])
+    suffix: Prompt = field(default_factory=lambda: ['</s>'])
+    system_prefix: Optional[Prompt] = field(default_factory=lambda: ['<s>[SYSTEM_PROMPT]{{SYSTEM}}[/SYSTEM_PROMPT]'])
+
+
+register_template(Mistral3TemplateMeta(LLMTemplateType.mistral_2501, default_system=mistral_2501_system))
 
 
 class Mistral2503Template(Template):
@@ -28,15 +53,16 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
         labels = encoded['labels']
         loss_scale = encoded.get('loss_scale', None)
         idx_list = findall(input_ids, self.image_token)
+        patch_size = processor.patch_size * processor.spatial_merge_size  # effective patch stride after merging
         if idx_list:
-            image_inputs = processor.image_processor(images, patch_size=processor.patch_size, return_tensors='pt')
+            image_inputs = processor.image_processor(images, patch_size=patch_size, return_tensors='pt')
             encoded['pixel_values'] = image_inputs['pixel_values'].to(self.model_info.torch_dtype)
             encoded['image_sizes'] = image_sizes = image_inputs['image_sizes']
 
             def _get_new_tokens(i):
                 height, width = image_sizes[i]
-                num_height_tokens = height // (processor.patch_size * processor.spatial_merge_size)
-                num_width_tokens = width // (processor.patch_size * processor.spatial_merge_size)
+                num_height_tokens = height // patch_size
+                num_width_tokens = width // patch_size
                 replace_tokens = [[processor.image_token] * num_width_tokens + [processor.image_break_token]
                                   ] * num_height_tokens
                 # Flatten list
@@ -52,15 +78,8 @@ def _get_new_tokens(i):
 
 
 register_template(
-    TemplateMeta(
-        MLLMTemplateType.mistral_2503,
-        prefix=['<s>'],
-        prompt=['[INST]{{QUERY}}[/INST]'],
-        chat_sep=['</s>'],
-        suffix=['</s>'],
-        system_prefix=['<s>[SYSTEM_PROMPT]{{SYSTEM}}[/SYSTEM_PROMPT]'],
-        default_system=mistral_2501_system,
-        template_cls=Mistral2503Template))
+    Mistral3TemplateMeta(
+        MLLMTemplateType.mistral_2503, default_system=mistral_2501_system, template_cls=Mistral2503Template))
 
 devstral_small_2505_system = (  # from https://huggingface.co/mistralai/Devstral-Small-2505/blob/main/SYSTEM_PROMPT.txt
     'You are Devstral, a helpful agentic model trained by Mistral AI and using the OpenHands scaffold. '
@@ -122,12 +141,27 @@ def _get_new_tokens(i):
     'executing a plan from the user, please don\'t try to directly work around it. Instead, propose a new '
     'plan and confirm with the user before proceeding.\n')
 
+register_template(Mistral3TemplateMeta('devstral', default_system=devstral_small_2505_system))
+
+
+class Mistral2506Template(Mistral2503Template):
+
+    def _get_mistral_system(self):
+        from swift.llm import get_model_name
+        model_dir = self.model_info.model_dir
+        model_name = get_model_name(model_dir)
+        file_path = os.path.join(model_dir, 'SYSTEM_PROMPT.txt')  # prompt template shipped with the model
+        with open(file_path, 'r', encoding='utf-8') as file:
+            system_prompt = file.read()
+        today = datetime.today().strftime('%Y-%m-%d')
+        yesterday = (datetime.today() - timedelta(days=1)).strftime('%Y-%m-%d')
+        return system_prompt.format(name=model_name, today=today, yesterday=yesterday)
+
+    def _swift_encode(self, inputs: StdTemplateInputs):
+        if inputs.system is None:
+            inputs.system = self._get_mistral_system()  # default to the model's own system prompt
+        return super()._swift_encode(inputs)
+
+
 register_template(
-    TemplateMeta(
-        'devstral',
-        prefix=['<s>'],
-        prompt=['[INST]{{QUERY}}[/INST]'],  # the user query
-        chat_sep=['</s>'],
-        suffix=['</s>'],
-        system_prefix=['<s>[SYSTEM_PROMPT]{{SYSTEM}}[/SYSTEM_PROMPT]'],  # the system prompt
-        default_system=devstral_small_2505_system))
+    Mistral3TemplateMeta(MLLMTemplateType.mistral_2506, default_system=None, template_cls=Mistral2506Template))
diff --git a/tests/test_align/test_template/test_vision.py b/tests/test_align/test_template/test_vision.py
index 46f8f494f1..ab5d72baa2 100644
--- a/tests/test_align/test_template/test_vision.py
+++ b/tests/test_align/test_template/test_vision.py
@@ -1092,6 +1092,15 @@ def test_ernie_vl_thinking():
     assert response == '<think>\n\n</think>\n\n' + response2
 
 
+def test_mistral_2506():
+    pt_engine = PtEngine('mistralai/Mistral-Small-3.2-24B-Instruct-2506')
+    response = _infer_model(pt_engine, messages=[{'role': 'user', 'content': 'describe the image.'}])
+    assert response[:200] == (
+        'The image features a close-up of a kitten with striking blue eyes. The kitten has a soft, '
+        'fluffy coat with a mix of white, gray, and brown fur. Its fur pattern includes distinct '
+        'stripes, particularly ')
+
+
 if __name__ == '__main__':
     from swift.llm import PtEngine, RequestConfig
     from swift.utils import get_logger, seed_everything
@@ -1168,4 +1177,5 @@ def test_ernie_vl_thinking():
     # test_llava_onevision1_5()
    # test_paddle_ocr()
     # test_ernie_vl()
-    test_ernie_vl_thinking()
+    # test_ernie_vl_thinking()
+    test_mistral_2506()
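
With the patch applied, the checkpoint resolves through the newly registered 'mistral_2506' model and template types. A minimal inference sketch using ms-swift's public PtEngine API, mirroring test_mistral_2506 above; the image URL is a placeholder assumption, and InferRequest/RequestConfig usage follows ms-swift's standard interface:

    from swift.llm import InferRequest, PtEngine, RequestConfig

    # The model id maps to model_type 'mistral_2506' and template 'mistral_2506'
    # via the registrations added in this patch.
    engine = PtEngine('mistralai/Mistral-Small-3.2-24B-Instruct-2506')
    # '<image>' marks where the image is spliced into the prompt; the URL is a placeholder.
    request = InferRequest(
        messages=[{'role': 'user', 'content': '<image>describe the image.'}],
        images=['https://example.com/kitten.png'])
    response = engine.infer([request], RequestConfig(max_tokens=128, temperature=0))
    print(response[0].choices[0].message.content)

Because Mistral2506Template leaves inputs.system as None by default, the engine fills in the SYSTEM_PROMPT.txt shipped with the model at encode time; passing an explicit system message overrides it.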