diff --git a/swift/llm/dataset/preprocessor/core.py b/swift/llm/dataset/preprocessor/core.py
index 8fdcb9e491..238d999ea7 100644
--- a/swift/llm/dataset/preprocessor/core.py
+++ b/swift/llm/dataset/preprocessor/core.py
@@ -314,7 +314,6 @@ def __call__(
         dataset = sample_dataset(dataset, self.dataset_sample, True, self.random_state)
 
         map_kwargs = {'batched': True, 'batch_size': batch_size}
-        cache_file_name = None
         if isinstance(dataset, HfDataset):
             if not load_from_cache_file and is_dist() and not is_master():
                 load_from_cache_file = True
@@ -326,11 +325,11 @@ def __call__(
         dataset = RowPreprocessor.get_features_dataset(dataset)
         if 'solution' in dataset.features:
             with safe_ddp_context(None, True):
-                if not dataset.cache_files:
-                    cache_file_name = os.path.join(get_cache_dir(), 'datasets', 'map_cache',
-                                                   f'{dataset._fingerprint}.arrow')
-                dataset = dataset.map(
-                    lambda x: {'__#solution': x['solution']}, **map_kwargs, cache_file_name=cache_file_name)
+                if isinstance(dataset, HfDataset) and not dataset.cache_files:
+                    map_kwargs['cache_file_name'] = os.path.join(get_cache_dir(), 'datasets', 'map_cache',
+                                                                 f'{dataset._fingerprint}.arrow')
+                dataset = dataset.map(lambda x: {'__#solution': x['solution']}, **map_kwargs)
+                map_kwargs.pop('cache_file_name', None)
         dataset = self._rename_columns(dataset)
         dataset = self.prepare_dataset(dataset)
         dataset = self._cast_pil_image(dataset)
@@ -338,9 +337,9 @@ def __call__(
         ignore_max_length_error = True if isinstance(dataset, HfDataset) and num_proc > 1 else False
         with self._patch_arrow_writer(), safe_ddp_context(None, True):
             try:
-                if not dataset.cache_files:
-                    cache_file_name = os.path.join(get_cache_dir(), 'datasets', 'map_cache',
-                                                   f'{dataset._fingerprint}.arrow')
+                if isinstance(dataset, HfDataset) and not dataset.cache_files:
+                    map_kwargs['cache_file_name'] = os.path.join(get_cache_dir(), 'datasets', 'map_cache',
+                                                                 f'{dataset._fingerprint}.arrow')
                 dataset_mapped = dataset.map(
                     self.batched_preprocess,
                     fn_kwargs={
@@ -348,7 +347,6 @@ def __call__(
                         'ignore_max_length_error': ignore_max_length_error
                     },
                     remove_columns=list(dataset.features.keys()),
-                    cache_file_name=cache_file_name,
                     **map_kwargs)
             except NotImplementedError:
                 pass
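
For context, a minimal sketch (not the swift implementation) of the pattern this diff converges on: `cache_file_name` is only valid for map-style `datasets.Dataset` objects, so it is injected into the shared `map_kwargs` dict just for that case and removed again afterwards so a later `.map()` call cannot reuse a stale cache path. The helper name `map_with_optional_cache` and the `cache_dir` argument below are illustrative, not part of the patch.

```python
# Illustrative sketch only; helper name and cache_dir argument are hypothetical.
import os

from datasets import Dataset as HfDataset


def map_with_optional_cache(dataset, fn, cache_dir, **map_kwargs):
    # Only map-style HF datasets support an on-disk map cache; IterableDataset.map
    # does not accept `cache_file_name`, so the key must not end up in map_kwargs
    # for streaming datasets.
    if isinstance(dataset, HfDataset) and not dataset.cache_files:
        os.makedirs(cache_dir, exist_ok=True)
        map_kwargs['cache_file_name'] = os.path.join(cache_dir, f'{dataset._fingerprint}.arrow')
    try:
        return dataset.map(fn, **map_kwargs)
    finally:
        # Drop the key so a reused map_kwargs dict does not carry a stale path into
        # the next .map() call (the role of `map_kwargs.pop(...)` in the diff).
        map_kwargs.pop('cache_file_name', None)
```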