From 40f9a56c8f26be3895d1be2cc2e2bec39e4b9456 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 16 Dec 2024 21:28:28 +0800 Subject: [PATCH] fix dataset --- swift/llm/dataset/preprocessor/core.py | 13 ++++++++----- swift/llm/dataset/preprocessor/extra.py | 7 ++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/swift/llm/dataset/preprocessor/core.py b/swift/llm/dataset/preprocessor/core.py index ea8d220789..24c6fa73a6 100644 --- a/swift/llm/dataset/preprocessor/core.py +++ b/swift/llm/dataset/preprocessor/core.py @@ -36,6 +36,7 @@ def __init__(self, random_state: Union[np.random.RandomState, int, None] = None, traceback_limit: int = 10) -> None: self.columns_mapping = columns_mapping or {} + self.origin_columns_mapping = self.columns_mapping.copy() # Higher priority and raise Error images_keys = ['images', 'image'] audios_keys = ['audios', 'audio'] videos_keys = ['videos', 'video'] @@ -179,13 +180,14 @@ def batched_preprocess(self, batched_row: Dict[str, Any], *, strict: bool) -> Di return res - @staticmethod - def safe_rename_columns(dataset: DATASET_TYPE, columns_mapping: Dict[str, Any]) -> DATASET_TYPE: + def _rename_columns(self, dataset: DATASET_TYPE) -> DATASET_TYPE: dataset = get_features_dataset(dataset) + dataset = dataset.rename_columns(self.origin_columns_mapping) + columns_keys = {k.lower(): k for k in dataset.features.keys()} # lower -> lower/upper safe_columns_mapping = { columns_keys[k.lower()]: v - for k, v in columns_mapping.items() if k.lower() in columns_keys + for k, v in self.columns_mapping.items() if k.lower() in columns_keys } counter = Counter(safe_columns_mapping.values()) @@ -251,7 +253,7 @@ def __call__( if self.dataset_sample is not None: dataset = sample_dataset(dataset, self.dataset_sample, self.random_state) - dataset = self.safe_rename_columns(dataset, self.columns_mapping) + dataset = self._rename_columns(dataset) dataset = self.prepare_dataset(dataset) dataset = self._cast_pil_image(dataset) map_kwargs = {} @@ -474,6 +476,7 @@ def __call__( strict: bool = False, load_from_cache_file: bool = False, ) -> DATASET_TYPE: - dataset = RowPreprocessor.safe_rename_columns(dataset, self.columns_mapping) + dataset = get_features_dataset(dataset) + dataset = dataset.rename_columns(self.columns_mapping) preprocessor = self._get_preprocessor(dataset) return preprocessor(dataset, num_proc=num_proc, load_from_cache_file=load_from_cache_file, strict=strict) diff --git a/swift/llm/dataset/preprocessor/extra.py b/swift/llm/dataset/preprocessor/extra.py index 7c27c08e05..06f64d104f 100644 --- a/swift/llm/dataset/preprocessor/extra.py +++ b/swift/llm/dataset/preprocessor/extra.py @@ -66,11 +66,8 @@ def __init__(self, super().__init__(columns_mapping=columns_mapping, **kwargs) def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]: - row = super().preprocess(row) - messages = row['messages'] - query_message = messages[-2] - query_message['content'] = self.prompt.replace(self.query_tag, query_message['content']) - return row + row['query'] = self.prompt.replace(self.query_tag, row['query']) + return super().preprocess(row) class ClsPreprocessor(ResponsePreprocessor):