1 change: 1 addition & 0 deletions docs/source/Instruction/Supported-models-and-datasets.md
@@ -1024,6 +1024,7 @@
|[google/gemma-3n-E4B-it](https://modelscope.cn/models/google/gemma-3n-E4B-it)|gemma3n|gemma3n|transformers>=4.53.1|✘|-|[google/gemma-3n-E4B-it](https://huggingface.co/google/gemma-3n-E4B-it)|
|[mistralai/Mistral-Small-3.1-24B-Base-2503](https://modelscope.cn/models/mistralai/Mistral-Small-3.1-24B-Base-2503)|mistral_2503|mistral_2503|transformers>=4.49|✘|-|[mistralai/Mistral-Small-3.1-24B-Base-2503](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Base-2503)|
|[mistralai/Mistral-Small-3.1-24B-Instruct-2503](https://modelscope.cn/models/mistralai/Mistral-Small-3.1-24B-Instruct-2503)|mistral_2503|mistral_2503|transformers>=4.49|✘|-|[mistralai/Mistral-Small-3.1-24B-Instruct-2503](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)|
+|[mistralai/Mistral-Small-3.2-24B-Instruct-2506](https://modelscope.cn/models/mistralai/Mistral-Small-3.2-24B-Instruct-2506)|mistral_2506|mistral_2506|transformers>=4.49|✘|-|[mistralai/Mistral-Small-3.2-24B-Instruct-2506](https://huggingface.co/mistralai/Mistral-Small-3.2-24B-Instruct-2506)|
|[PaddlePaddle/PaddleOCR-VL](https://modelscope.cn/models/PaddlePaddle/PaddleOCR-VL)|paddle_ocr|paddle_ocr|-|✘|-|[PaddlePaddle/PaddleOCR-VL](https://huggingface.co/PaddlePaddle/PaddleOCR-VL)|
|[JinaAI/jina-reranker-m0](https://modelscope.cn/models/JinaAI/jina-reranker-m0)|jina_reranker_m0|jina_reranker_m0|-|✘|reranker, vision|[JinaAI/jina-reranker-m0](https://huggingface.co/JinaAI/jina-reranker-m0)|

@@ -1024,6 +1024,7 @@ The table below introduces the models integrated with ms-swift:
|[google/gemma-3n-E4B-it](https://modelscope.cn/models/google/gemma-3n-E4B-it)|gemma3n|gemma3n|transformers>=4.53.1|✘|-|[google/gemma-3n-E4B-it](https://huggingface.co/google/gemma-3n-E4B-it)|
|[mistralai/Mistral-Small-3.1-24B-Base-2503](https://modelscope.cn/models/mistralai/Mistral-Small-3.1-24B-Base-2503)|mistral_2503|mistral_2503|transformers>=4.49|✘|-|[mistralai/Mistral-Small-3.1-24B-Base-2503](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Base-2503)|
|[mistralai/Mistral-Small-3.1-24B-Instruct-2503](https://modelscope.cn/models/mistralai/Mistral-Small-3.1-24B-Instruct-2503)|mistral_2503|mistral_2503|transformers>=4.49|✘|-|[mistralai/Mistral-Small-3.1-24B-Instruct-2503](https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503)|
+|[mistralai/Mistral-Small-3.2-24B-Instruct-2506](https://modelscope.cn/models/mistralai/Mistral-Small-3.2-24B-Instruct-2506)|mistral_2506|mistral_2506|transformers>=4.49|✘|-|[mistralai/Mistral-Small-3.2-24B-Instruct-2506](https://huggingface.co/mistralai/Mistral-Small-3.2-24B-Instruct-2506)|
|[PaddlePaddle/PaddleOCR-VL](https://modelscope.cn/models/PaddlePaddle/PaddleOCR-VL)|paddle_ocr|paddle_ocr|-|✘|-|[PaddlePaddle/PaddleOCR-VL](https://huggingface.co/PaddlePaddle/PaddleOCR-VL)|
|[JinaAI/jina-reranker-m0](https://modelscope.cn/models/JinaAI/jina-reranker-m0)|jina_reranker_m0|jina_reranker_m0|-|✘|reranker, vision|[JinaAI/jina-reranker-m0](https://huggingface.co/JinaAI/jina-reranker-m0)|

1 change: 1 addition & 0 deletions swift/llm/model/constant.py
@@ -274,6 +274,7 @@ class MLLMModelType:
    gemma3_vision = 'gemma3_vision'
    gemma3n = 'gemma3n'
    mistral_2503 = 'mistral_2503'
+    mistral_2506 = 'mistral_2506'
    paddle_ocr = 'paddle_ocr'


43 changes: 34 additions & 9 deletions swift/llm/model/model/mistral.py
@@ -1,8 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from typing import Any, Dict

-from transformers import AutoTokenizer
+from transformers import AutoProcessor, AutoTokenizer

from swift.llm import TemplateType
from ..constant import LLMModelType, MLLMModelType
@@ -130,12 +129,7 @@ def get_model_tokenizer_mistral_2503(model_dir: str,
                                     model_kwargs: Dict[str, Any],
                                     load_model: bool = True,
                                     **kwargs):
-    try:
-        from transformers import Mistral3ForConditionalGeneration
-    except ImportError:
-        raise ImportError('Please install Mistral3ForConditionalGeneration by running '
-                          '`pip install git+https://github.com/huggingface/transformers@v4.49.0-Mistral-3`')
-
+    from transformers import Mistral3ForConditionalGeneration
    kwargs['automodel_class'] = kwargs['automodel_class'] or Mistral3ForConditionalGeneration
    model, processor = get_model_tokenizer_multimodal(model_dir, model_info, model_kwargs, load_model, **kwargs)

@@ -184,4 +178,35 @@ def get_model_tokenizer_devstral_2505(model_dir: str,
        architectures=['Mistral3ForConditionalGeneration'],
        model_arch=ModelArch.llava_hf,
        requires=['transformers>=4.49'],
-    ), )
+    ))


+def get_model_tokenizer_mistral_2506(model_dir: str,
+                                     model_info: ModelInfo,
+                                     model_kwargs: Dict[str, Any],
+                                     load_model: bool = True,
+                                     **kwargs):
+    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
Review comment (Contributor, severity: medium): The `MistralTokenizer` import above is unused in this function and should be removed to keep the code clean. A sketch of the cleaned-up loader follows the `register_model` call below.
+    from transformers import Mistral3ForConditionalGeneration
+    tokenizer_dir = safe_snapshot_download('mistralai/Mistral-Small-3.1-24B-Instruct-2503', download_model=False)
+    processor = AutoProcessor.from_pretrained(tokenizer_dir)
+    kwargs['automodel_class'] = kwargs['automodel_class'] or Mistral3ForConditionalGeneration
+    kwargs['tokenizer'] = processor.tokenizer
+    model, _ = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)
+    return model, processor


+register_model(
+    ModelMeta(
+        MLLMModelType.mistral_2506,
+        [
+            ModelGroup([
+                Model('mistralai/Mistral-Small-3.2-24B-Instruct-2506', 'mistralai/Mistral-Small-3.2-24B-Instruct-2506'),
+            ]),
+        ],
+        TemplateType.mistral_2506,
+        get_model_tokenizer_mistral_2506,
+        architectures=['Mistral3ForConditionalGeneration'],
+        model_arch=ModelArch.llava_hf,
+        requires=['transformers>=4.49'],
+    ))
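Following up on the review comment: a minimal sketch of the cleaned-up loader, with the unused `MistralTokenizer` import dropped. The import paths for the ms-swift helpers are assumptions for a self-contained snippet; inside `mistral.py` these names come from the module's existing imports. The 2506 checkpoint appears to ship only a mistral-common tokenizer, which would be why the HF processor is borrowed from the 2503 release.

```python
from typing import Any, Dict

from transformers import AutoProcessor

# Assumed import paths; within swift/llm/model/model/mistral.py these are
# already available via the module's relative imports.
from swift.llm import ModelInfo, safe_snapshot_download
from swift.llm.model.register import get_model_tokenizer_with_flash_attn


def get_model_tokenizer_mistral_2506(model_dir: str,
                                     model_info: ModelInfo,
                                     model_kwargs: Dict[str, Any],
                                     load_model: bool = True,
                                     **kwargs):
    # Deferred import: Mistral3ForConditionalGeneration only exists in
    # transformers>=4.49, matching the `requires` constraint in ModelMeta.
    from transformers import Mistral3ForConditionalGeneration

    # Borrow the HF processor from the 2503 release; only the tokenizer
    # files are needed, so the model weights are not downloaded.
    tokenizer_dir = safe_snapshot_download('mistralai/Mistral-Small-3.1-24B-Instruct-2503', download_model=False)
    processor = AutoProcessor.from_pretrained(tokenizer_dir)
    kwargs['automodel_class'] = kwargs['automodel_class'] or Mistral3ForConditionalGeneration
    kwargs['tokenizer'] = processor.tokenizer
    model, _ = get_model_tokenizer_with_flash_attn(model_dir, model_info, model_kwargs, load_model, **kwargs)
    return model, processor
```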
1 change: 0 additions & 1 deletion swift/llm/model/model_arch.py
@@ -81,7 +81,6 @@ class MLLMModelArch:
    megrez_omni = 'megrez_omni'
    valley = 'valley'
    gemma3n = 'gemma3n'
-    mistral_2503 = 'mistral_2503'
    keye_vl = 'keye_vl'

    midashenglm = 'midashenglm'
1 change: 1 addition & 0 deletions swift/llm/template/constant.py
@@ -229,6 +229,7 @@ class MLLMTemplateType:
    gemma3_vision = 'gemma3_vision'
    gemma3n = 'gemma3n'
    mistral_2503 = 'mistral_2503'
+    mistral_2506 = 'mistral_2506'
    paddle_ocr = 'paddle_ocr'


23 changes: 0 additions & 23 deletions swift/llm/template/template/llm.py
@@ -119,29 +119,6 @@ def _preprocess_inputs(self, inputs: StdTemplateInputs) -> None:
        chat_sep=['</s>[INST] '],
        suffix=['</s>']))

-today = datetime.now().strftime('%Y-%m-%d')
-
-mistral_2501_system = (
-    'You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup '
-    'headquartered in Paris.\n'
-    f'Your knowledge base was last updated on 2023-10-01. The current date is {today}.\n\n'
-    "When you're not sure about some information, you say that you don't have the information and don't "
-    'make up anything.\n'
-    "If the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer "
-    'the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. '
-    '"What are some good restaurants around me?" => "Where are you?" or "When is the next flight to Tokyo" => "'
-    'Where do you travel from?")')
-
-register_template(
-    TemplateMeta(
-        LLMTemplateType.mistral_2501,
-        prefix=['<s>'],
-        prompt=['[INST]{{QUERY}}[/INST]'],
-        chat_sep=['</s>'],
-        suffix=['</s>'],
-        system_prefix=['<s>[SYSTEM_PROMPT]{{SYSTEM}}[/SYSTEM_PROMPT]'],
-        default_system=mistral_2501_system))

register_template(
    TemplateMeta(
        LLMTemplateType.xverse,
84 changes: 59 additions & 25 deletions swift/llm/template/template/mistral.py
@@ -1,14 +1,39 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
+import os
+from dataclasses import dataclass, field
+from datetime import datetime, timedelta
from typing import Any, Dict, List, Literal, Optional

import torch

from ..base import Template
-from ..constant import MLLMTemplateType
+from ..constant import LLMTemplateType, MLLMTemplateType
from ..register import TemplateMeta, register_template
from ..template_inputs import StdTemplateInputs
-from ..utils import Context, findall
-from .llm import mistral_2501_system
+from ..utils import Context, Prompt, findall

+today = datetime.now().strftime('%Y-%m-%d')

+mistral_2501_system = (
+    'You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup '
+    'headquartered in Paris.\n'
+    f'Your knowledge base was last updated on 2023-10-01. The current date is {today}.\n\n'
Review comment on lines +13 to +18 (Contributor, severity: high): The `today` variable is defined at the module level, which means it is evaluated only once when the module is imported. If the application runs for more than a day, the date in the system prompt becomes stale, which can lead to incorrect behavior from the model. A similar issue was addressed for `Mistral2506Template` by generating the system prompt dynamically. Please apply the same fix to the `mistral_2501` and `mistral_2503` templates. This would involve:

1. Creating a `get_mistral_2501_system()` function that returns the system prompt with the current date.
2. Creating a `Mistral2501Template(Template)` class with a `_swift_encode` method that calls `get_mistral_2501_system()` to set the system prompt when none is provided.
3. Updating the registration for `mistral_2501` to use this new template class with `default_system=None`.
4. Changing `Mistral2503Template` to inherit from `Mistral2501Template`.
5. Updating the registration for `mistral_2503` to set `default_system=None`.

A sketch of this refactor follows the `mistral_2501` registration below.

"When you're not sure about some information, you say that you don't have the information and don't "
'make up anything.\n'
"If the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer "
'the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. '
'"What are some good restaurants around me?" => "Where are you?" or "When is the next flight to Tokyo" => "'
'Where do you travel from?")')


+@dataclass
+class Mistral3TemplateMeta(TemplateMeta):
+    prefix: Prompt = field(default_factory=lambda: ['<s>'])
+    prompt: Prompt = field(default_factory=lambda: ['[INST]{{QUERY}}[/INST]'])
+    chat_sep: Optional[Prompt] = field(default_factory=lambda: ['</s>'])
+    suffix: Prompt = field(default_factory=lambda: ['</s>'])
+    system_prefix: Optional[Prompt] = field(default_factory=lambda: ['<s>[SYSTEM_PROMPT]{{SYSTEM}}[/SYSTEM_PROMPT]'])


+register_template(Mistral3TemplateMeta(LLMTemplateType.mistral_2501, default_system=mistral_2501_system))
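As requested in the review comment above, a hypothetical sketch of that refactor, mirroring the `Mistral2506Template` pattern further down in this diff. The names `get_mistral_2501_system` and `Mistral2501Template` are the reviewer's proposal, not code in this PR, and the snippet assumes the module-level imports already present in `mistral.py` (`datetime`, `Template`, `StdTemplateInputs`, `register_template`, `LLMTemplateType`):

```python
def get_mistral_2501_system() -> str:
    # Re-evaluated on every call, so a long-running process never serves
    # a stale "current date" in the system prompt.
    today = datetime.now().strftime('%Y-%m-%d')
    return (
        'You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup '
        'headquartered in Paris.\n'
        f'Your knowledge base was last updated on 2023-10-01. The current date is {today}.\n\n'
        "When you're not sure about some information, you say that you don't have the information and don't "
        'make up anything.\n'
        "If the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer "
        'the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. '
        '"What are some good restaurants around me?" => "Where are you?" or "When is the next flight to Tokyo" => "'
        'Where do you travel from?")')


class Mistral2501Template(Template):

    def _swift_encode(self, inputs: StdTemplateInputs):
        # Inject the freshly dated system prompt only when the caller
        # supplied none, matching Mistral2506Template below.
        if inputs.system is None:
            inputs.system = get_mistral_2501_system()
        return super()._swift_encode(inputs)


# Mistral2503Template would then inherit from Mistral2501Template instead of
# Template, and both registrations would pass default_system=None:
register_template(Mistral3TemplateMeta(LLMTemplateType.mistral_2501,
                                       default_system=None,
                                       template_cls=Mistral2501Template))
```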


class Mistral2503Template(Template):
@@ -28,15 +53,16 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        labels = encoded['labels']
        loss_scale = encoded.get('loss_scale', None)
        idx_list = findall(input_ids, self.image_token)
+        patch_size = processor.patch_size * processor.spatial_merge_size
        if idx_list:
-            image_inputs = processor.image_processor(images, patch_size=processor.patch_size, return_tensors='pt')
+            image_inputs = processor.image_processor(images, patch_size=patch_size, return_tensors='pt')
            encoded['pixel_values'] = image_inputs['pixel_values'].to(self.model_info.torch_dtype)
            encoded['image_sizes'] = image_sizes = image_inputs['image_sizes']

            def _get_new_tokens(i):
                height, width = image_sizes[i]
-                num_height_tokens = height // (processor.patch_size * processor.spatial_merge_size)
-                num_width_tokens = width // (processor.patch_size * processor.spatial_merge_size)
+                num_height_tokens = height // patch_size
+                num_width_tokens = width // patch_size
                replace_tokens = [[processor.image_token] * num_width_tokens + [processor.image_break_token]
                                  ] * num_height_tokens
                # Flatten list
@@ -52,15 +78,8 @@ def _get_new_tokens(i):


register_template(
-    TemplateMeta(
-        MLLMTemplateType.mistral_2503,
-        prefix=['<s>'],
-        prompt=['[INST]{{QUERY}}[/INST]'],
-        chat_sep=['</s>'],
-        suffix=['</s>'],
-        system_prefix=['<s>[SYSTEM_PROMPT]{{SYSTEM}}[/SYSTEM_PROMPT]'],
-        default_system=mistral_2501_system,
-        template_cls=Mistral2503Template))
+    Mistral3TemplateMeta(
+        MLLMTemplateType.mistral_2503, default_system=mistral_2501_system, template_cls=Mistral2503Template))

devstral_small_2505_system = ( # from https://huggingface.co/mistralai/Devstral-Small-2505/blob/main/SYSTEM_PROMPT.txt
    'You are Devstral, a helpful agentic model trained by Mistral AI and using the OpenHands scaffold. '
@@ -122,12 +141,27 @@ def _get_new_tokens(i):
    'executing a plan from the user, please don\'t try to directly work around it. Instead, propose a new '
    'plan and confirm with the user before proceeding.\n</TROUBLESHOOTING>')

+register_template(Mistral3TemplateMeta('devstral', default_system=devstral_small_2505_system))


+class Mistral2506Template(Mistral2503Template):

+    def _get_mistral_system(self):
+        from swift.llm import get_model_name
+        model_dir = self.model_info.model_dir
+        model_name = get_model_name(model_dir)
+        file_path = os.path.join(model_dir, 'SYSTEM_PROMPT.txt')
+        with open(file_path, 'r') as file:
+            system_prompt = file.read()
+        today = datetime.today().strftime('%Y-%m-%d')
+        yesterday = (datetime.today() - timedelta(days=1)).strftime('%Y-%m-%d')
+        return system_prompt.format(name=model_name, today=today, yesterday=yesterday)

+    def _swift_encode(self, inputs: StdTemplateInputs):
+        if inputs.system is None:
+            inputs.system = self._get_mistral_system()
+        return super()._swift_encode(inputs)


register_template(
-    TemplateMeta(
-        'devstral',
-        prefix=['<s>'],
-        prompt=['[INST]{{QUERY}}[/INST]'],  # the user query
-        chat_sep=['</s>'],
-        suffix=['</s>'],
-        system_prefix=['<s>[SYSTEM_PROMPT]{{SYSTEM}}[/SYSTEM_PROMPT]'],  # the system prompt
-        default_system=devstral_small_2505_system))
+    Mistral3TemplateMeta(MLLMTemplateType.mistral_2506, default_system=None, template_cls=Mistral2506Template))
12 changes: 11 additions & 1 deletion tests/test_align/test_template/test_vision.py
@@ -1092,6 +1092,15 @@ def test_ernie_vl_thinking():
    assert response == '\n<think>\n' + response2


+def test_mistral_2506():
+    pt_engine = PtEngine('mistralai/Mistral-Small-3.2-24B-Instruct-2506')
+    response = _infer_model(pt_engine, messages=[{'role': 'user', 'content': 'describe the image.'}])
+    assert response[:200] == (
+        'The image features a close-up of a kitten with striking blue eyes. The kitten has a soft, '
+        'fluffy coat with a mix of white, gray, and brown fur. Its fur pattern includes distinct '
+        'stripes, particularly ')


if __name__ == '__main__':
    from swift.llm import PtEngine, RequestConfig
    from swift.utils import get_logger, seed_everything
@@ -1168,4 +1177,5 @@ def test_ernie_vl_thinking():
    # test_llava_onevision1_5()
    # test_paddle_ocr()
    # test_ernie_vl()
-    test_ernie_vl_thinking()
+    # test_ernie_vl_thinking()
+    test_mistral_2506()