From 5b24cfaa2f6339f7bbadb3079fb7b52a25c38347 Mon Sep 17 00:00:00 2001
From: "yuze.zyz"
Date: Fri, 29 Dec 2023 17:48:01 +0800
Subject: [PATCH] fix

---
Reviewer notes (this section sits below the three-dash line, so it is not
part of the commit message):

* dataset.py: the hh_rlhf registration now reads the 'harmless-base' subset
  for both the train and test splits instead of 'default'.
* llm_infer.py: prepare_checkpoint now casts string values that fully match
  int_regex / float_regex to int / float before storing them in kwargs.
* float_regex was revised from '\.*' to '\.?'. With '\.*' a value such as
  "1..5" satisfied re.fullmatch and was then handed to float(), which raises
  ValueError. With '\.?' at most one decimal point is accepted, so every
  string that passes the check is a valid float() argument.
* NOTE(review): line breaks and indentation in this patch were reconstructed
  from a whitespace-collapsed copy. Confirm the context-line indentation
  against the target tree (or apply with --ignore-whitespace).

 swift/llm/utils/dataset.py      |  3 ++-
 swift/ui/llm_infer/llm_infer.py | 10 ++++++++++
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/swift/llm/utils/dataset.py b/swift/llm/utils/dataset.py
index 6438c2bffe..b8c8dee611 100644
--- a/swift/llm/utils/dataset.py
+++ b/swift/llm/utils/dataset.py
@@ -577,7 +577,8 @@ def reorganize_row_simple(sample) -> Dict[str, str]:
 
 register_dataset(
     DatasetName.hh_rlhf,
-    'AI-ModelScope/hh-rlhf', [('default', 'train')], [('default', 'test')],
+    'AI-ModelScope/hh-rlhf', [('harmless-base', 'train')],
+    [('harmless-base', 'test')],
     process_hh_rlhf,
     get_dataset_from_repo,
     tags=['hfrl', 'dpo', 'pairwise'])
diff --git a/swift/ui/llm_infer/llm_infer.py b/swift/ui/llm_infer/llm_infer.py
index c6f16c8b77..76265b1b32 100644
--- a/swift/ui/llm_infer/llm_infer.py
+++ b/swift/ui/llm_infer/llm_infer.py
@@ -1,4 +1,5 @@
 import os
+import re
 from dataclasses import fields
 from typing import Type
 
@@ -19,6 +20,9 @@ class LLMInfer(BaseUI):
 
     sub_ui = [Model]
 
+    int_regex = r'^[-+]?[0-9]+$'
+    float_regex = r'[-+]?(?:\d*\.?\d+)'
+
     locale_dict = {
         'generate_alert': {
             'value': {
@@ -159,6 +163,12 @@ def prepare_checkpoint(cls, *args):
             compare_value_ui = str(value) if not isinstance(
                 value, (list, dict)) else value
             if key in infer_args and compare_value_ui != compare_value_arg and value:
+                if isinstance(value, str) and re.fullmatch(
+                        cls.int_regex, value):
+                    value = int(value)
+                elif isinstance(value, str) and re.fullmatch(
+                        cls.float_regex, value):
+                    value = float(value)
                 kwargs[key] = value if not isinstance(
                     value, list) else ' '.join(value)
                 kwargs_is_list[key] = isinstance(value, list)