Skip to content

Commit

Permalink
Move task casting to builder
Browse files Browse the repository at this point in the history
This makes sense for two reasons: 1) to handle both Dataset and DatasetDict objects, and 2) to be closer to the post-processing logic per split.
  • Loading branch information
lewtun committed May 7, 2021
1 parent 3cf039d commit b2a02c5
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
10 changes: 10 additions & 0 deletions src/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ def __init__(
name: Optional[str] = None,
hash: Optional[str] = None,
features: Optional[Features] = None,
task=None,
**config_kwargs,
):
"""Constructs a DatasetBuilder.
Expand All @@ -226,6 +227,7 @@ def __init__(
# DatasetBuilder name
self.name: str = camelcase_to_snakecase(self.__class__.__name__)
self.hash: Optional[str] = hash
self.task = task

# Prepare config: DatasetConfig contains name, version and description but can be extended by each dataset
config_kwargs = {key: value for key, value in config_kwargs.items() if value is not None}
Expand Down Expand Up @@ -813,6 +815,14 @@ def _build_single_dataset(
)
else:
ds.info.features = self.info.post_processed.features
# Rename feature column names to match task schema
tasks = [template.task for template in self.info.task_templates]
if self.task not in tasks:
raise ValueError(f"Task {self.task} not found! Avaliable tasks: {tasks}")
else:
for template in self.info.task_templates:
if template.task == self.task:
ds = ds.rename_columns(template.column_mapping)

return ds

Expand Down
6 changes: 1 addition & 5 deletions src/datasets/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,7 @@ def load_dataset(
data_files=data_files,
hash=hash,
features=features,
task=task,
**config_kwargs,
)

Expand All @@ -757,11 +758,6 @@ def load_dataset(
keep_in_memory if keep_in_memory is not None else is_small_dataset(builder_instance.info.dataset_size)
)
ds = builder_instance.as_dataset(split=split, ignore_verifications=ignore_verifications, in_memory=keep_in_memory)
# Rename feature column names to match task schema
for template in builder_instance.info.task_templates:
if template.task == task:
for k, v in template.column_mapping.items():
ds = ds.rename_column(k, v)
if save_infos:
builder_instance._save_infos()

Expand Down

1 comment on commit b2a02c5

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Show benchmarks

PyArrow==1.0.0

Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.022485 / 0.011353 (0.011132) 0.015553 / 0.011008 (0.004545) 0.051301 / 0.038508 (0.012793) 0.041889 / 0.023109 (0.018780) 0.356700 / 0.275898 (0.080802) 0.386354 / 0.323480 (0.062874) 0.010219 / 0.007986 (0.002233) 0.005105 / 0.004328 (0.000776) 0.013400 / 0.004250 (0.009150) 0.052959 / 0.037052 (0.015906) 0.353154 / 0.258489 (0.094665) 0.397468 / 0.293841 (0.103627) 0.166174 / 0.128546 (0.037628) 0.127694 / 0.075646 (0.052047) 0.436180 / 0.419271 (0.016908) 0.592659 / 0.043533 (0.549127) 0.355829 / 0.255139 (0.100690) 0.383395 / 0.283200 (0.100195) 1.973573 / 0.141683 (1.831890) 1.875492 / 1.452155 (0.423337) 1.934293 / 1.492716 (0.441577)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.009570 / 0.018006 (-0.008436) 0.472212 / 0.000490 (0.471722) 0.002758 / 0.000200 (0.002558) 0.000141 / 0.000054 (0.000087)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.046183 / 0.037411 (0.008772) 0.027024 / 0.014526 (0.012498) 0.029399 / 0.176557 (-0.147158) 0.050842 / 0.737135 (-0.686293) 0.030714 / 0.296338 (-0.265624)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.420741 / 0.215209 (0.205531) 4.266474 / 2.077655 (2.188819) 2.185401 / 1.504120 (0.681281) 1.967962 / 1.541195 (0.426767) 1.976489 / 1.468490 (0.507998) 6.770009 / 4.584777 (2.185232) 5.891087 / 3.745712 (2.145375) 8.362482 / 5.269862 (3.092621) 7.426137 / 4.565676 (2.860460) 0.663482 / 0.424275 (0.239207) 0.010945 / 0.007607 (0.003338) 0.540805 / 0.226044 (0.314761) 5.457662 / 2.268929 (3.188733) 2.640544 / 55.444624 (-52.804080) 2.205948 / 6.876477 (-4.670529) 2.225252 / 2.142072 (0.083179) 6.671000 / 4.805227 (1.865772) 4.224427 / 6.500664 (-2.276237) 7.406085 / 0.075469 (7.330616)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 10.475837 / 1.841788 (8.634050) 13.562910 / 8.074308 (5.488602) 30.186923 / 10.191392 (19.995531) 0.910423 / 0.680424 (0.229999) 0.649890 / 0.534201 (0.115689) 0.770468 / 0.579283 (0.191185) 0.614059 / 0.434364 (0.179695) 0.698665 / 0.540337 (0.158327) 1.560936 / 1.386936 (0.174000)
PyArrow==latest
Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.023328 / 0.011353 (0.011975) 0.015890 / 0.011008 (0.004882) 0.052199 / 0.038508 (0.013691) 0.041754 / 0.023109 (0.018645) 0.337178 / 0.275898 (0.061280) 0.367824 / 0.323480 (0.044344) 0.012030 / 0.007986 (0.004044) 0.004973 / 0.004328 (0.000645) 0.013636 / 0.004250 (0.009386) 0.060780 / 0.037052 (0.023728) 0.341570 / 0.258489 (0.083081) 0.378245 / 0.293841 (0.084404) 0.157480 / 0.128546 (0.028934) 0.122379 / 0.075646 (0.046733) 0.434670 / 0.419271 (0.015398) 0.413580 / 0.043533 (0.370047) 0.346732 / 0.255139 (0.091593) 0.371743 / 0.283200 (0.088543) 1.618129 / 0.141683 (1.476446) 1.850333 / 1.452155 (0.398178) 1.896302 / 1.492716 (0.403586)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.009227 / 0.018006 (-0.008779) 0.472683 / 0.000490 (0.472193) 0.002945 / 0.000200 (0.002745) 0.000140 / 0.000054 (0.000086)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.043102 / 0.037411 (0.005690) 0.026689 / 0.014526 (0.012164) 0.030684 / 0.176557 (-0.145872) 0.060645 / 0.737135 (-0.676490) 0.031170 / 0.296338 (-0.265169)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.405830 / 0.215209 (0.190621) 4.048153 / 2.077655 (1.970499) 2.043084 / 1.504120 (0.538965) 1.815038 / 1.541195 (0.273843) 1.836481 / 1.468490 (0.367991) 6.467603 / 4.584777 (1.882826) 5.801189 / 3.745712 (2.055477) 8.063699 / 5.269862 (2.793838) 6.994333 / 4.565676 (2.428657) 0.647057 / 0.424275 (0.222782) 0.010574 / 0.007607 (0.002967) 0.528912 / 0.226044 (0.302868) 5.261144 / 2.268929 (2.992215) 2.536228 / 55.444624 (-52.908397) 2.160316 / 6.876477 (-4.716160) 2.161000 / 2.142072 (0.018928) 6.593261 / 4.805227 (1.788034) 4.315308 / 6.500664 (-2.185356) 5.963566 / 0.075469 (5.888097)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 10.629119 / 1.841788 (8.787331) 13.466627 / 8.074308 (5.392319) 30.115351 / 10.191392 (19.923959) 0.786121 / 0.680424 (0.105697) 0.600072 / 0.534201 (0.065871) 0.750124 / 0.579283 (0.170841) 0.557984 / 0.434364 (0.123620) 0.676782 / 0.540337 (0.136444) 1.511285 / 1.386936 (0.124349)

CML watermark

Please sign in to comment.