[model] support mistral 2506 #6624
Changes from all commits
@@ -1,14 +1,39 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Any, Dict, List, Literal, Optional

import torch

from ..base import Template
from ..constant import MLLMTemplateType
from ..constant import LLMTemplateType, MLLMTemplateType
from ..register import TemplateMeta, register_template
from ..template_inputs import StdTemplateInputs
from ..utils import Context, findall
from .llm import mistral_2501_system
from ..utils import Context, Prompt, findall

today = datetime.now().strftime('%Y-%m-%d')

mistral_2501_system = (
    'You are Mistral Small 3, a Large Language Model (LLM) created by Mistral AI, a French startup '
    'headquartered in Paris.\n'
    f'Your knowledge base was last updated on 2023-10-01. The current date is {today}.\n\n'
Comment on lines +13 to +18 (Contributor):
The … This would involve:
    "When you're not sure about some information, you say that you don't have the information and don't "
    'make up anything.\n'
    "If the user's question is not clear, ambiguous, or does not provide enough context for you to accurately answer "
    'the question, you do not try to answer it right away and you rather ask the user to clarify their request (e.g. '
    '"What are some good restaurants around me?" => "Where are you?" or "When is the next flight to Tokyo" => "'
    'Where do you travel from?")')


@dataclass
class Mistral3TemplateMeta(TemplateMeta):
    prefix: Prompt = field(default_factory=lambda: ['<s>'])
    prompt: Prompt = field(default_factory=lambda: ['[INST]{{QUERY}}[/INST]'])
    chat_sep: Optional[Prompt] = field(default_factory=lambda: ['</s>'])
    suffix: Prompt = field(default_factory=lambda: ['</s>'])
    system_prefix: Optional[Prompt] = field(default_factory=lambda: ['<s>[SYSTEM_PROMPT]{{SYSTEM}}[/SYSTEM_PROMPT]'])


register_template(Mistral3TemplateMeta(LLMTemplateType.mistral_2501, default_system=mistral_2501_system))

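As a reading aid, the following hand-assembled sketch shows roughly how the fields above combine for a single-turn conversation. The system text, query and answer are made up, and the real rendering is done by swift's Template machinery rather than by this snippet:

```python
# Hypothetical illustration: how prefix / system_prefix / prompt / suffix
# from Mistral3TemplateMeta compose into one rendered string.
system = 'You are Mistral Small 3 ...'       # placeholder system text
query = 'What is the capital of France?'     # placeholder user query
answer = 'Paris.'                            # placeholder assistant reply

# When a system message is present, system_prefix replaces the plain prefix.
rendered = (f'<s>[SYSTEM_PROMPT]{system}[/SYSTEM_PROMPT]'
            f'[INST]{query}[/INST]{answer}</s>')
print(rendered)
```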
class Mistral2503Template(Template):

@@ -28,15 +53,16 @@ def _encode(self, inputs: StdTemplateInputs) -> Dict[str, Any]:
        labels = encoded['labels']
        loss_scale = encoded.get('loss_scale', None)
        idx_list = findall(input_ids, self.image_token)
        patch_size = processor.patch_size * processor.spatial_merge_size
        if idx_list:
            image_inputs = processor.image_processor(images, patch_size=processor.patch_size, return_tensors='pt')
            image_inputs = processor.image_processor(images, patch_size=patch_size, return_tensors='pt')
            encoded['pixel_values'] = image_inputs['pixel_values'].to(self.model_info.torch_dtype)
            encoded['image_sizes'] = image_sizes = image_inputs['image_sizes']

            def _get_new_tokens(i):
                height, width = image_sizes[i]
                num_height_tokens = height // (processor.patch_size * processor.spatial_merge_size)
                num_width_tokens = width // (processor.patch_size * processor.spatial_merge_size)
                num_height_tokens = height // patch_size
                num_width_tokens = width // patch_size
                replace_tokens = [[processor.image_token] * num_width_tokens + [processor.image_break_token]
                                  ] * num_height_tokens
                # Flatten list

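The change hoists the effective patch size (processor.patch_size * processor.spatial_merge_size) into a local patch_size variable and reuses it both when calling the image processor and when counting placeholder tokens. A rough worked example of that token-count arithmetic follows; the numeric values and placeholder strings are illustrative assumptions, not values read from any real processor config:

```python
# Illustrative arithmetic only: patch_size, spatial_merge_size, the image size
# and the placeholder token strings are all hypothetical.
patch_size, spatial_merge_size = 16, 2
effective_patch = patch_size * spatial_merge_size   # pixels covered by one image token

height, width = 256, 512                            # processed image size in pixels
num_height_tokens = height // effective_patch       # 8 rows of image tokens
num_width_tokens = width // effective_patch         # 16 tokens per row

# One row = num_width_tokens image placeholders followed by a break token,
# mirroring the replace_tokens construction in _get_new_tokens.
row = ['[IMG]'] * num_width_tokens + ['[IMG_BREAK]']
grid = row * num_height_tokens                      # already flat, one entry per token
print(len(grid))                                    # 8 * (16 + 1) = 136 placeholders
```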
@@ -52,15 +78,8 @@ def _get_new_tokens(i): | |
|
|
||
|
|
||
| register_template( | ||
| TemplateMeta( | ||
| MLLMTemplateType.mistral_2503, | ||
| prefix=['<s>'], | ||
| prompt=['[INST]{{QUERY}}[/INST]'], | ||
| chat_sep=['</s>'], | ||
| suffix=['</s>'], | ||
| system_prefix=['<s>[SYSTEM_PROMPT]{{SYSTEM}}[/SYSTEM_PROMPT]'], | ||
| default_system=mistral_2501_system, | ||
| template_cls=Mistral2503Template)) | ||
| Mistral3TemplateMeta( | ||
| MLLMTemplateType.mistral_2503, default_system=mistral_2501_system, template_cls=Mistral2503Template)) | ||
|
|
||
| devstral_small_2505_system = ( # from https://huggingface.co/mistralai/Devstral-Small-2505/blob/main/SYSTEM_PROMPT.txt | ||
| 'You are Devstral, a helpful agentic model trained by Mistral AI and using the OpenHands scaffold. ' | ||
|
|
@@ -122,12 +141,27 @@ def _get_new_tokens(i): | |
| 'executing a plan from the user, please don\'t try to directly work around it. Instead, propose a new ' | ||
| 'plan and confirm with the user before proceeding.\n</TROUBLESHOOTING>') | ||
|
|
||
| register_template(Mistral3TemplateMeta('devstral', default_system=devstral_small_2505_system)) | ||
|
|
||
|
|
||
| class Mistral2506Template(Mistral2503Template): | ||
|
|
||
| def _get_mistral_system(self): | ||
| from swift.llm import get_model_name | ||
| model_dir = self.model_info.model_dir | ||
| model_name = get_model_name(model_dir) | ||
| file_path = os.path.join(model_dir, 'SYSTEM_PROMPT.txt') | ||
| with open(file_path, 'r') as file: | ||
| system_prompt = file.read() | ||
| today = datetime.today().strftime('%Y-%m-%d') | ||
| yesterday = (datetime.today() - timedelta(days=1)).strftime('%Y-%m-%d') | ||
| return system_prompt.format(name=model_name, today=today, yesterday=yesterday) | ||
|
|
||
| def _swift_encode(self, inputs: StdTemplateInputs): | ||
| if inputs.system is None: | ||
| inputs.system = self._get_mistral_system() | ||
| return super()._swift_encode(inputs) | ||
|
|
||
|
|
||
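Mistral2506Template defers its default system prompt to the SYSTEM_PROMPT.txt file shipped in the model directory and fills its {name}, {today} and {yesterday} placeholders at encode time. A minimal sketch of that substitution, assuming a made-up prompt file content and model name (the real file ships with the checkpoint and its wording may differ):

```python
from datetime import datetime, timedelta

# Hypothetical stand-in for the checkpoint's SYSTEM_PROMPT.txt content.
system_prompt = 'You are {name}. Today is {today}; yesterday was {yesterday}.'

today = datetime.today().strftime('%Y-%m-%d')
yesterday = (datetime.today() - timedelta(days=1)).strftime('%Y-%m-%d')

# Mirrors the str.format call in _get_mistral_system; the model name is invented.
print(system_prompt.format(name='Mistral-Small-2506', today=today, yesterday=yesterday))
```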
register_template(
    TemplateMeta(
        'devstral',
        prefix=['<s>'],
        prompt=['[INST]{{QUERY}}[/INST]'],  # the user query
        chat_sep=['</s>'],
        suffix=['</s>'],
        system_prefix=['<s>[SYSTEM_PROMPT]{{SYSTEM}}[/SYSTEM_PROMPT]'],  # the system prompt
        default_system=devstral_small_2505_system))
    Mistral3TemplateMeta(MLLMTemplateType.mistral_2506, default_system=None, template_cls=Mistral2506Template))
Review comment: The MistralTokenizer is imported but not used in this function. It should be removed to keep the code clean.