README.md (2 changes: 1 addition & 1 deletion)
@@ -120,7 +120,7 @@ Users can check the [documentation of SWIFT](docs/source/GetStarted/快速使用
- bluelm series: [bluelm-7b](https://modelscope.cn/models/vivo-ai/BlueLM-7B-Base/summary), [bluelm-7b-chat](https://modelscope.cn/models/vivo-ai/BlueLM-7B-Chat/summary), [bluelm-7b-32k](https://modelscope.cn/models/vivo-ai/BlueLM-7B-Base-32K/summary), [bluelm-7b-chat-32k](https://modelscope.cn/models/vivo-ai/BlueLM-7B-Chat-32K/summary)
- mistral series: [mistral-7b](https://modelscope.cn/models/AI-ModelScope/Mistral-7B-v0.1/summary), [mistral-7b-chat](https://modelscope.cn/models/AI-ModelScope/Mistral-7B-Instruct-v0.1/summary)
- yi series: [yi-6b](https://modelscope.cn/models/01ai/Yi-6B/summary), [yi-34b](https://modelscope.cn/models/01ai/Yi-34B/summary), [yi-34b-chat](https://modelscope.cn/models/01ai/Yi-34B-Chat/summary)
- - zephyr series: zephyr-7b-beta-chat(https://modelscope.cn/models/modelscope/zephyr-7b-beta/summary)
+ - zephyr series: [zephyr-7b-beta-chat](https://modelscope.cn/models/modelscope/zephyr-7b-beta/summary)
- ziya series: [ziya2-13b](https://modelscope.cn/models/Fengshenbang/Ziya2-13B-Base/summary), [ziya2-13b-chat](https://modelscope.cn/models/Fengshenbang/Ziya2-13B-Chat/summary)
- skywork series: [skywork-13b](https://modelscope.cn/models/skywork/Skywork-13B-base/summary), [skywork-13b-chat](https://modelscope.cn/models/skywork/Skywork-13B-chat/summary)
- other: [polylm-13b](https://modelscope.cn/models/damo/nlp_polylm_13b_text_generation/summary), [seqgpt-560m](https://modelscope.cn/models/damo/nlp_seqgpt-560m/summary)
README_CN.md (2 changes: 1 addition & 1 deletion)
@@ -118,7 +118,7 @@ SWIFT(Scalable lightWeight Infrastructure for Fine-Tuning)是一个可扩展
- bluelm 系列: [bluelm-7b](https://modelscope.cn/models/vivo-ai/BlueLM-7B-Base/summary), [bluelm-7b-chat](https://modelscope.cn/models/vivo-ai/BlueLM-7B-Chat/summary), [bluelm-7b-32k](https://modelscope.cn/models/vivo-ai/BlueLM-7B-Base-32K/summary), [bluelm-7b-chat-32k](https://modelscope.cn/models/vivo-ai/BlueLM-7B-Chat-32K/summary)
- mistral 系列: [mistral-7b](https://modelscope.cn/models/AI-ModelScope/Mistral-7B-v0.1/summary), [mistral-7b-chat](https://modelscope.cn/models/AI-ModelScope/Mistral-7B-Instruct-v0.1/summary)
- yi 系列: [yi-6b](https://modelscope.cn/models/01ai/Yi-6B/summary), [yi-34b](https://modelscope.cn/models/01ai/Yi-34B/summary), [yi-34b-chat](https://modelscope.cn/models/01ai/Yi-34B-Chat/summary)
- - zephyr 系列: zephyr-7b-beta-chat(https://modelscope.cn/models/modelscope/zephyr-7b-beta/summary)
+ - zephyr 系列: [zephyr-7b-beta-chat](https://modelscope.cn/models/modelscope/zephyr-7b-beta/summary)
- ziya 系列: [ziya2-13b](https://modelscope.cn/models/Fengshenbang/Ziya2-13B-Base/summary), [ziya2-13b-chat](https://modelscope.cn/models/Fengshenbang/Ziya2-13B-Chat/summary)
- skywork 系列: [skywork-13b](https://modelscope.cn/models/skywork/Skywork-13B-base/summary), [skywork-13b-chat](https://modelscope.cn/models/skywork/Skywork-13B-chat/summary)
- other: [polylm-13b](https://modelscope.cn/models/damo/nlp_polylm_13b_text_generation/summary), [seqgpt-560m](https://modelscope.cn/models/damo/nlp_seqgpt-560m/summary)
swift/llm/infer.py (5 changes: 3 additions & 2 deletions)
@@ -14,7 +14,8 @@
from swift.utils import (append_to_jsonl, get_logger, print_model_info,
                         read_multi_line, seed_everything, show_layers)
from .utils import (InferArguments, Template, get_dataset, get_model_tokenizer,
-                    get_template, inference, inference_stream)
+                    get_template, inference, inference_stream,
+                    set_generation_config)

logger = get_logger()

@@ -141,7 +142,7 @@ def prepare_model_template(
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id)
    logger.info(f'generation_config: {generation_config}')
-    model.generation_config = generation_config
+    set_generation_config(model, generation_config)
    return model, template


swift/llm/rome.py (4 changes: 2 additions & 2 deletions)
@@ -8,7 +8,7 @@
                         show_layers)
from ..tuners.rome import RomeConfig
from .utils import (RomeArguments, Template, get_dataset, get_model_tokenizer,
-                    get_template, inference)
+                    get_template, inference, set_generation_config)

logger = get_logger()

@@ -72,7 +72,7 @@ def rome_infer(args: RomeArguments) -> None:
    logger.info(f'generation_config: {generation_config}')
    if args.overwrite_generation_config:
        generation_config.save_pretrained(args.ckpt_dir)
-    model.generation_config = generation_config
+    set_generation_config(model, generation_config)

    # Inference
    if args.eval_human:
swift/llm/sft.py (12 changes: 9 additions & 3 deletions)
@@ -18,8 +18,9 @@
                         seed_everything, show_layers)
from .utils import (SftArguments, Template, add_self_cognition_dataset,
                    data_collate_fn, dataset_map, find_all_linear_for_lora,
-                    get_dataset, get_model_tokenizer, get_template,
-                    print_example, sort_by_max_length, stat_dataset)
+                    get_additional_saved_files, get_dataset,
+                    get_model_tokenizer, get_template, print_example,
+                    set_generation_config, sort_by_max_length, stat_dataset)

logger = get_logger()

@@ -182,11 +183,15 @@ def llm_sft(args: SftArguments) -> str:
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id)
    logger.info(f'generation_config: {generation_config}')
+    set_generation_config(model, generation_config)
    evaluation_strategy = IntervalStrategy.STEPS
    load_best_model_at_end = True
    if val_dataset is None:
        evaluation_strategy = IntervalStrategy.NO
        load_best_model_at_end = False
+    additional_saved_files = []
+    if args.sft_type == 'full':
+        additional_saved_files = get_additional_saved_files(args.model_type)
    training_args = Seq2SeqTrainingArguments(
        output_dir=args.output_dir,
        evaluation_strategy=evaluation_strategy,
@@ -230,7 +235,8 @@ def llm_sft(args: SftArguments) -> str:
        only_save_model=args.only_save_model,
        train_sampler_random=args.train_sampler_random,
        report_to=args.report_to,
-        deepspeed=args.deepspeed)
+        deepspeed=args.deepspeed,
+        additional_saved_files=additional_saved_files)

    if args.gradient_checkpointing:
        model.enable_input_require_grads()
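Taken together, the sft.py changes gate the extra files on the tuning mode: only a full-parameter run writes standalone checkpoints, so only then do model-specific assets need to travel with them (adapter-style checkpoints are loaded next to the original model directory). A condensed sketch of the flow, using only names that appear in this diff, with the surrounding setup elided:

```python
# Collect extra asset files only for full fine-tuning; adapter checkpoints
# can still find them in the original model directory.
additional_saved_files = []
if args.sft_type == 'full':
    additional_saved_files = get_additional_saved_files(args.model_type)

# Seq2SeqTrainingArguments here is SWIFT's subclass carrying
# SwiftArgumentsMixin, which is what accepts the new keyword.
training_args = Seq2SeqTrainingArguments(
    output_dir=args.output_dir,
    # ...other arguments as in the diff...
    additional_saved_files=additional_saved_files)
```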
swift/llm/utils/__init__.py (10 changes: 5 additions & 5 deletions)
@@ -5,9 +5,9 @@
                      get_dataset_from_repo, load_dataset_from_local,
                      load_ms_dataset, register_dataset)
from .model import (MODEL_MAPPING, GetModelTokenizerFunction, LoRATM,
-                    ModelType, get_default_lora_target_modules,
-                    get_default_template_type, get_model_tokenizer,
-                    get_model_tokenizer_from_repo,
+                    ModelType, get_additional_saved_files,
+                    get_default_lora_target_modules, get_default_template_type,
+                    get_model_tokenizer, get_model_tokenizer_from_repo,
                    get_model_tokenizer_from_sdk, register_model)
from .preprocess import (AlpacaPreprocessor, ClsPreprocessor,
                         ComposePreprocessor, ConversationsPreprocessor,
@@ -19,5 +19,5 @@
from .utils import (data_collate_fn, dataset_map, download_dataset,
                    find_all_linear_for_lora, history_to_messages, inference,
                    inference_stream, limit_history_length,
-                    messages_to_history, print_example, sort_by_max_length,
-                    stat_dataset)
+                    messages_to_history, print_example, set_generation_config,
+                    sort_by_max_length, stat_dataset)
swift/llm/utils/argument.py (12 changes: 8 additions & 4 deletions)
@@ -14,9 +14,8 @@
from swift.hub import HubApi, ModelScopeConfig
from swift.utils import (add_version_to_work_dir, broadcast_string,
                         get_dist_setting, is_dist, is_master)
-from .dataset import (DATASET_MAPPING, DatasetName, get_custom_dataset,
-                      register_dataset)
-from .model import (MODEL_MAPPING, ModelType, dtype_mapping,
+from .dataset import DATASET_MAPPING, get_custom_dataset, register_dataset
+from .model import (MODEL_MAPPING, dtype_mapping,
                    get_default_lora_target_modules, get_default_template_type)
from .template import TEMPLATE_MAPPING, TemplateType
@@ -431,8 +430,13 @@ def select_dtype(
    assert torch_dtype in {torch.float16, torch.bfloat16, torch.float32}
    if torch_dtype == torch.float16:
        if isinstance(args, SftArguments) and args.sft_type == 'full':
+            args.dtype = 'fp32'
            torch_dtype = torch.float32
-            logger.warning('Setting torch_dtype: torch.float32')
+            logger.warning(
+                'Fine-tuning with full parameters does not support fp16, and is prone to NaN. '
+                'We will use the fp32 & AMP approach, which consumes approximately twice the memory of bf16.'
+            )
+            logger.info(f'Setting torch_dtype: {torch_dtype}')
        fp16, bf16 = True, False
    elif torch_dtype == torch.bfloat16:
        support_bf16 = torch.cuda.is_bf16_supported()
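For readers unfamiliar with the "fp32 & AMP" combination named in the warning: the master weights stay in float32 while autocast runs the forward matmuls in half precision, which is why fp16 is still set to True on the line below. A minimal, self-contained PyTorch sketch of that combination (not SWIFT code; assumes a CUDA device):

```python
import torch

model = torch.nn.Linear(16, 16).cuda()  # weights stay fp32
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler()

x = torch.randn(4, 16, device='cuda')
with torch.cuda.amp.autocast():          # forward matmuls run in fp16
    loss = model(x).pow(2).mean()

scaler.scale(loss).backward()            # scale to avoid fp16 gradient underflow
scaler.step(optimizer)                   # unscale, then update in fp32
scaler.update()
```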
swift/llm/utils/model.py (22 changes: 22 additions & 0 deletions)
@@ -827,6 +827,13 @@ def get_model_tokenizer_qwen_vl(model_dir: str,
    ]
    get_qwen_function = kwargs.pop('get_qwen_function',
                                   get_model_tokenizer_qwen_chat)
+    tokenizer_config = get_tokenizer_config(model_dir)
+    class_ref = tokenizer_config['auto_map']['AutoTokenizer'][0]
+    tokenizer_cls = get_class_from_dynamic_module(class_ref, model_dir)
+    tokenizer_cls._auto_class = 'AutoTokenizer'
+    tokenizer_cls.IMAGE_ST = ()  # fix no attr `self.IMAGE_ST` bug
+    kwargs['tokenizer'] = tokenizer_cls.from_pretrained(
+        model_dir, trust_remote_code=True)
    model, tokenizer = get_qwen_function(model_dir, torch_dtype, model_kwargs,
                                         load_model, **kwargs)
    if model is not None:
@@ -870,6 +877,13 @@ def get_model_tokenizer_qwen_audio(model_dir: str,
                                   load_model: bool = True,
                                   **kwargs):
    get_qwen_function = kwargs.pop('get_qwen_function')
+    tokenizer_config = get_tokenizer_config(model_dir)
+    class_ref = tokenizer_config['auto_map']['AutoTokenizer'][0]
+    tokenizer_cls = get_class_from_dynamic_module(class_ref, model_dir)
+    tokenizer_cls._auto_class = 'AutoTokenizer'
+    tokenizer_cls.AUDIO_ST = ()  # fix no attr `self.AUDIO_ST` bug
+    kwargs['tokenizer'] = tokenizer_cls.from_pretrained(
+        model_dir, trust_remote_code=True)
    model, tokenizer = get_qwen_function(model_dir, torch_dtype, model_kwargs,
                                         load_model, **kwargs)
    if model is not None:
@@ -1148,6 +1162,14 @@ def get_model_tokenizer(
    return model, tokenizer


+def get_additional_saved_files(model_type: str) -> List[str]:
+    if 'qwen-vl' in model_type:
+        return ['SimSun.ttf']
+    elif 'qwen-audio' in model_type:
+        return ['mel_filters.npz']
+    return []
+
+
def get_default_template_type(model_type: str) -> Optional[str]:
    return MODEL_MAPPING[model_type].get('template')
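A quick illustration of the new helper. The substring match covers every qwen-vl and qwen-audio variant; the apparent motivation (not verifiable from this diff alone) is that the qwen-vl tokenizer needs the SimSun.ttf font to draw on images, while qwen-audio's preprocessing loads the mel_filters.npz filter bank:

```python
assert get_additional_saved_files('qwen-vl-chat') == ['SimSun.ttf']
assert get_additional_saved_files('qwen-audio-chat') == ['mel_filters.npz']
assert get_additional_saved_files('qwen-14b-chat') == []  # no extra assets
```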
swift/llm/utils/template.py (2 changes: 1 addition & 1 deletion)
@@ -1,6 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from copy import deepcopy
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union

from torch import Tensor
from transformers import PreTrainedTokenizerBase, StoppingCriteria
swift/llm/utils/utils.py (15 changes: 13 additions & 2 deletions)
@@ -28,8 +28,9 @@
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from tqdm.auto import tqdm
-from transformers import (PreTrainedModel, PreTrainedTokenizerBase,
-                          StoppingCriteriaList, TextStreamer, trainer)
+from transformers import (GenerationConfig, PreTrainedModel,
+                          PreTrainedTokenizerBase, StoppingCriteriaList,
+                          TextStreamer, trainer)

from swift.hub import ModelScopeConfig
from swift.utils import (get_dist_setting, get_logger, is_ddp_plus_mp, is_dist,
@@ -540,6 +541,16 @@ def messages_to_history(messages: Messages) -> Dict[str, Any]:
    }


+def set_generation_config(model: Module,
+                          generation_config: GenerationConfig) -> None:
+    if hasattr(model, 'generation_config'):
+        old_generation_config = model.generation_config
+        for k, v in old_generation_config.__dict__.items():
+            if k not in generation_config.__dict__:
+                setattr(generation_config, k, v)
+    model.generation_config = generation_config
+
+
# monkey patching
MsDataset.load = _msdataset_ddp_load
if is_ddp_plus_mp():
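Note the merge direction in set_generation_config: attributes already on the model's config survive unless the caller overrides them, so model-specific extras (Qwen's generation_config.json carries a chat_format key, for example) are not silently dropped. A standalone sketch of that merge, with chat_format used as an assumed example of such an extra:

```python
from transformers import GenerationConfig

old = GenerationConfig(max_new_tokens=512)
old.chat_format = 'chatml'  # model-specific extra, absent from the new config

new = GenerationConfig(max_new_tokens=1024, temperature=0.9)

# Same loop as set_generation_config: copy over anything the new config lacks.
for k, v in old.__dict__.items():
    if k not in new.__dict__:
        setattr(new, k, v)

assert new.chat_format == 'chatml'  # extra preserved
assert new.max_new_tokens == 1024   # caller's value wins
```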
swift/trainers/arguments.py (7 changes: 7 additions & 0 deletions)
@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from dataclasses import dataclass, field
+from typing import List, Optional

from transformers.training_args import TrainingArguments as HfTrainingArguments
from transformers.training_args_seq2seq import \
@@ -18,6 +19,12 @@ class SwiftArgumentsMixin:
            'choices':
            {'end', 'push_best', 'push_last', 'checkpoint', 'all_checkpoints'}
        })
+    additional_saved_files: Optional[List[str]] = None
+
+    def __post_init__(self):
+        if self.additional_saved_files is None:
+            self.additional_saved_files = []
+        super().__post_init__()


@dataclass
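The Optional-plus-__post_init__ dance above is the standard dataclass workaround for list-valued fields, since a literal empty-list default is rejected as a mutable default. A minimal sketch of the same idiom:

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Args:
    # `files: List[str] = []` would raise "ValueError: mutable default";
    # `field(default_factory=list)` is the other common fix.
    files: Optional[List[str]] = None

    def __post_init__(self):
        if self.files is None:
            self.files = []  # a fresh list per instance
```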
swift/trainers/mixin.py (8 changes: 8 additions & 0 deletions)
@@ -411,6 +411,14 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None):
        self.tokenizer.save_pretrained(output_dir)
        # training_args.bin
        torch.save(self.args, os.path.join(output_dir, 'training_args.bin'))
+        # additional files
+        additional_files = getattr(self.args, 'additional_saved_files', [])
+        if model_dir is not None:
+            for file in additional_files:
+                src_path = os.path.join(model_dir, file)
+                dst_path = os.path.join(output_dir, file)
+                if os.path.exists(src_path):
+                    shutil.copy(src_path, dst_path)

    def _save_checkpoint(self, model, trial, metrics=None):
        only_save_model = self.args.only_save_model
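The copy itself is plain standard library (model_dir is resolved earlier in _save, outside this hunk). A self-contained sketch of the same skip-if-missing pattern, with temporary directories standing in for model_dir and output_dir:

```python
import os
import shutil
import tempfile

model_dir = tempfile.mkdtemp()   # stand-in for the source model directory
output_dir = tempfile.mkdtemp()  # stand-in for the checkpoint directory
open(os.path.join(model_dir, 'SimSun.ttf'), 'wb').close()

for file in ['SimSun.ttf', 'mel_filters.npz']:
    src_path = os.path.join(model_dir, file)
    if os.path.exists(src_path):  # absent assets are skipped, not an error
        shutil.copy(src_path, os.path.join(output_dir, file))

print(os.listdir(output_dir))  # ['SimSun.ttf']
```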