swift/llm/dataset/preprocessor/core.py (18 changes: 8 additions & 10 deletions)
@@ -314,7 +314,6 @@ def __call__(
 dataset = sample_dataset(dataset, self.dataset_sample, True, self.random_state)

 map_kwargs = {'batched': True, 'batch_size': batch_size}
-cache_file_name = None
 if isinstance(dataset, HfDataset):
     if not load_from_cache_file and is_dist() and not is_master():
         load_from_cache_file = True
@@ -326,29 +325,28 @@ def __call__(
 dataset = RowPreprocessor.get_features_dataset(dataset)
 if 'solution' in dataset.features:
     with safe_ddp_context(None, True):
-        if not dataset.cache_files:
-            cache_file_name = os.path.join(get_cache_dir(), 'datasets', 'map_cache',
-                                           f'{dataset._fingerprint}.arrow')
-        dataset = dataset.map(
-            lambda x: {'__#solution': x['solution']}, **map_kwargs, cache_file_name=cache_file_name)
+        if isinstance(dataset, HfDataset) and not dataset.cache_files:
+            map_kwargs['cache_file_name'] = os.path.join(get_cache_dir(), 'datasets', 'map_cache',
+                                                         f'{dataset._fingerprint}.arrow')
Comment on lines +328 to +330 (Contributor, severity: medium):

This logic for setting cache_file_name is duplicated on lines 340-342. To improve code clarity and maintainability, you could extract this repeated logic into a helper function. This would centralize the caching logic, making it easier to manage and modify in the future.

For example, you could introduce a helper function:

def _set_cache_file_name_in_kwargs(dataset, kwargs):
    if isinstance(dataset, HfDataset) and not dataset.cache_files:
        kwargs['cache_file_name'] = os.path.join(
            get_cache_dir(), 'datasets', 'map_cache', f'{dataset._fingerprint}.arrow')

And then call it in both places:

_set_cache_file_name_in_kwargs(dataset, map_kwargs)
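
To make the suggestion concrete outside the repository, here is a minimal, self-contained sketch of the same pattern against the public datasets API. It is an illustration under stated assumptions, not the project's actual wiring: the helper name comes from the comment above, HfDataset is taken to be datasets.Dataset, and swift's get_cache_dir() is replaced by a temporary directory passed in as cache_root.

import os
import tempfile

from datasets import Dataset as HfDataset


def _set_cache_file_name_in_kwargs(dataset, kwargs, cache_root):
    # Only in-memory datasets (no backing cache files) need an explicit
    # cache_file_name; datasets loaded from disk already have a cache location.
    if isinstance(dataset, HfDataset) and not dataset.cache_files:
        kwargs['cache_file_name'] = os.path.join(cache_root, 'map_cache',
                                                 f'{dataset._fingerprint}.arrow')
        os.makedirs(os.path.dirname(kwargs['cache_file_name']), exist_ok=True)


if __name__ == '__main__':
    cache_root = tempfile.mkdtemp()
    dataset = HfDataset.from_dict({'solution': ['42', '3.14']})  # toy in-memory dataset
    map_kwargs = {'batched': True}

    # First call site: mirror the solution column, as the diff does.
    _set_cache_file_name_in_kwargs(dataset, map_kwargs, cache_root)
    dataset = dataset.map(lambda x: {'__#solution': x['solution']}, **map_kwargs)
    map_kwargs.pop('cache_file_name', None)  # don't carry a fingerprint-specific path forward

    print(dataset.column_names)  # ['solution', '__#solution']

The pop() mirrors the new code in this diff: map_kwargs is reused for the later batched_preprocess map(), so a cache path tied to the first dataset's fingerprint should not leak into that second call, which computes its own.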

+        dataset = dataset.map(lambda x: {'__#solution': x['solution']}, **map_kwargs)
+        map_kwargs.pop('cache_file_name', None)
 dataset = self._rename_columns(dataset)
 dataset = self.prepare_dataset(dataset)
 dataset = self._cast_pil_image(dataset)

 ignore_max_length_error = True if isinstance(dataset, HfDataset) and num_proc > 1 else False
 with self._patch_arrow_writer(), safe_ddp_context(None, True):
     try:
-        if not dataset.cache_files:
-            cache_file_name = os.path.join(get_cache_dir(), 'datasets', 'map_cache',
-                                           f'{dataset._fingerprint}.arrow')
+        if isinstance(dataset, HfDataset) and not dataset.cache_files:
+            map_kwargs['cache_file_name'] = os.path.join(get_cache_dir(), 'datasets', 'map_cache',
+                                                         f'{dataset._fingerprint}.arrow')
         dataset_mapped = dataset.map(
             self.batched_preprocess,
             fn_kwargs={
                 'strict': strict,
                 'ignore_max_length_error': ignore_max_length_error
             },
             remove_columns=list(dataset.features.keys()),
-            cache_file_name=cache_file_name,
             **map_kwargs)
     except NotImplementedError:
         pass