4 changes: 2 additions & 2 deletions swift/llm/model/register.py
@@ -92,7 +92,7 @@ def get_matched_model_group(self, model_name: str) -> Optional[ModelGroup]:
for key in ['ms_model_id', 'hf_model_id', 'model_path']:
value = getattr(model, key)

- if isinstance(value, str) and model_name == value.rsplit('/', 1)[-1]:
+ if isinstance(value, str) and model_name == value.rsplit('/', 1)[-1].lower():
return model_group

def check_requires(self, model_info=None):
@@ -480,7 +480,7 @@ def get_all_models() -> List[str]:


def get_matched_model_meta(model_id_or_path: str) -> Optional[ModelMeta]:
- model_name = get_model_name(model_id_or_path)
+ model_name = get_model_name(model_id_or_path).lower()
for model_type, model_meta in MODEL_MAPPING.items():
model_group = ModelMeta.get_matched_model_group(model_meta, model_name)
if model_group is not None:
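For illustration only, the combined effect of the two `.lower()` calls is that matching on the last path component of a model id becomes case-insensitive on both sides of the comparison. A minimal standalone sketch, not the actual ModelMeta/ModelGroup code path:

from typing import List, Optional

def match_model_name(model_id_or_path: str, candidate_ids: List[str]) -> Optional[str]:
    # Lowercase the incoming name, as get_matched_model_meta now does.
    model_name = model_id_or_path.rsplit('/', 1)[-1].lower()
    for candidate in candidate_ids:
        # Lowercase the registered id's last path component as well,
        # as get_matched_model_group now does.
        if model_name == candidate.rsplit('/', 1)[-1].lower():
            return candidate
    return None

# Both spellings now resolve to the same registered id.
print(match_model_name('qwen/QWEN2-7B-INSTRUCT', ['Qwen/Qwen2-7B-Instruct']))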
4 changes: 2 additions & 2 deletions swift/megatron/utils/convert.py
@@ -252,8 +252,7 @@ def convert_hf2mcore(args: ExportArguments) -> None:

def convert_mcore2hf(args: ExportArguments) -> None:
from swift.megatron import prepare_mcore_model, adapter_state_dict_context
- hf_model, template = prepare_model_template(
-     args, load_model=args.to_hf, patch_offload=not args.test_convert_precision)
+ _, template = prepare_model_template(args, load_model=False)
Contributor review comment (severity: medium):
This change refactors the logic to delay loading the hf_model until it's needed inside the if args.to_hf: block, which is a good practice for memory efficiency. However, this introduces a second call to prepare_model_template on line 297. While this improves code structure, it's worth ensuring that calling prepare_model_template twice doesn't introduce a significant performance overhead. If the function is lightweight when load_model=False, this is fine. Otherwise, it might be better to refactor prepare_model_template to separate template creation from model loading more explicitly.
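As a sketch only, one way to make that separation explicit; the helper names are hypothetical (not part of this PR) and the keyword arguments are assumed from the calls visible in this diff:

# prepare_model_template is the existing helper already used in convert.py above.

def prepare_template_only(args):
    # Build only the template/processor; no model weights are loaded.
    _, template = prepare_model_template(args, load_model=False)
    return template

def load_hf_model(args):
    # Load the HF model lazily, e.g. only inside the `if args.to_hf:` branch.
    return prepare_model_template(
        args, patch_offload=not args.test_convert_precision)[0]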

processor = template.processor

megatron_model_meta = get_megatron_model_meta(args.model_type)
@@ -295,6 +294,7 @@ def convert_mcore2hf(args: ExportArguments) -> None:
mg_model = peft_model.merge_and_unload()
logger.info('Megatron model created successfully.')
if args.to_hf:
+ hf_model = prepare_model_template(args, patch_offload=not args.test_convert_precision)[0]
megatron_model_meta.convert_mcore2hf(hf_model, mg_model)
if args.test_convert_precision:
test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype)
4 changes: 2 additions & 2 deletions swift/trainers/rlhf_arguments.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
- from typing import List
+ from typing import List, Optional

from trl import CPOConfig as HfCPOConfig
from trl import DPOConfig as HfDPOConfig
@@ -15,7 +15,7 @@

@dataclass
class DPOConfig(SwiftArgumentsMixin, HfDPOConfig):
- pass
+ ld_alpha: Optional[float] = None # compat trl==0.15
Contributor review comment (severity: medium):
The comment # compat trl==0.15 is a bit confusing as there doesn't seem to be a 0.15 version for the trl library. This could make future maintenance more difficult. The ld_alpha parameter was introduced in trl version 0.7.1. To improve clarity, consider updating the comment to specify the correct version range, for example: # compat trl < 0.7.1.

Suggested change:
- ld_alpha: Optional[float] = None # compat trl==0.15
+ ld_alpha: Optional[float] = None # compat trl < 0.7.1



@dataclass
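If the goal is to tie the compat field to a concrete version boundary rather than a comment, one option is a runtime check against the installed trl. A minimal sketch; the 0.18 boundary is taken from the require_version guard added in dpo_trainer.py below, everything else is an assumption:

from importlib.metadata import version
from packaging.version import Version

def trl_has_ld_alpha() -> bool:
    # Mirrors the require_version('trl>=0.18', ...) check added in dpo_trainer.py.
    return Version(version('trl')) >= Version('0.18.0')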
3 changes: 3 additions & 0 deletions swift/trainers/rlhf_trainer/dpo_trainer.py
@@ -8,6 +8,7 @@
from accelerate.utils import gather_object
from peft import PeftModel
from transformers import PreTrainedModel
from transformers.utils.versions import require_version
from trl import DPOTrainer as HFDPOTrainer
from trl.trainer.dpo_config import DPOConfig
from trl.trainer.utils import RunningMoments, selective_log_softmax
@@ -71,6 +72,8 @@ def __init__(self,

if 'bco_pair' in loss_types:
self.running = RunningMoments(self.accelerator)
+ if self.args.ld_alpha is not None:
+     require_version('trl>=0.18', '`ld_alpha` requires that "trl>=0.18".')
if self.template.packing:
self.accelerator.gather_for_metrics = new_gather_function
