Skip to content

Commit

Permalink
Add rename_columns method (#2312)
Browse files Browse the repository at this point in the history
  • Loading branch information
SBrandeis committed May 4, 2021
1 parent 097129d commit 3a3e5a4
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 0 deletions.
51 changes: 51 additions & 0 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,57 @@ def rename(columns):
dataset._fingerprint = new_fingerprint
return dataset

@fingerprint_transform(inplace=False)
def rename_columns(self, column_mapping: Dict[str, str], new_fingerprint) -> "Dataset":
    """
    Rename several columns in the dataset, and move the features associated to the original columns under
    the new column names.

    Args:
        column_mapping (:obj:`Dict[str, str]`): A mapping of columns to rename to their new names
        new_fingerprint (:obj:`str`): Fingerprint of the returned dataset
            (filled in by the :func:`fingerprint_transform` decorator).

    Returns:
        :class:`Dataset`: A copy of the dataset with renamed columns

    Raises:
        ValueError: If a key of ``column_mapping`` is not an existing column,
            if several columns are mapped to the same new name, or if a new name is empty.
    """
    # Out-of-place transform: never mutate ``self``.
    dataset = copy.deepcopy(self)

    # Every column to rename must already exist in the dataset.
    extra_columns = set(column_mapping.keys()) - set(dataset.column_names)
    if extra_columns:
        raise ValueError(
            f"Original column names {extra_columns} not in the dataset. "
            f"Current columns in the dataset: {dataset._data.column_names}"
        )

    # Two original columns must not be mapped to the same new name.
    # len(column_mapping) == number of keys == number of values before deduplication.
    number_of_duplicates_in_new_columns = len(column_mapping) - len(set(column_mapping.values()))
    if number_of_duplicates_in_new_columns != 0:
        raise ValueError(
            "New column names must all be different, but this column mapping "
            f"has {number_of_duplicates_in_new_columns} duplicates"
        )

    empty_new_columns = [new_col for new_col in column_mapping.values() if not new_col]
    if empty_new_columns:
        raise ValueError(f"New column names {empty_new_columns} are empty.")

    def rename(columns):
        # Rename mapped columns, keep the others as-is (order preserved).
        return [column_mapping.get(col, col) for col in columns]

    new_column_names = rename(self._data.column_names)
    if self._format_columns is not None:
        dataset._format_columns = rename(self._format_columns)

    # Move each feature under its new column name.
    dataset._info.features = Features(
        {
            column_mapping.get(col, col): feature
            for col, feature in (self._info.features or {}).items()
        }
    )

    dataset._data = dataset._data.rename_columns(new_column_names)
    # Fix: sync the Arrow schema metadata with the *renamed* features.
    # Using ``self.features`` here would write the old (pre-rename) column
    # names into the metadata of the renamed table.
    dataset._data = update_metadata_with_features(dataset._data, dataset.features)
    dataset._fingerprint = new_fingerprint
    return dataset

def __len__(self):
"""Number of rows in the dataset."""
return self.num_rows
Expand Down
28 changes: 28 additions & 0 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,34 @@ def test_rename_column(self, in_memory):
self.assertNotEqual(new_dset._fingerprint, fingerprint)
assert_arrow_metadata_are_synced_with_dataset_features(new_dset)

def test_rename_columns(self, in_memory):
    """Check that rename_columns renames columns out of place and rejects invalid mappings."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
            fingerprint = dset._fingerprint

            # Renaming a single column returns a new dataset; the original is untouched.
            with dset.rename_columns({"col_1": "new_name"}) as renamed_dset:
                self.assertEqual(renamed_dset.num_columns, 3)
                self.assertListEqual(list(renamed_dset.column_names), ["new_name", "col_2", "col_3"])
                self.assertListEqual(list(dset.column_names), ["col_1", "col_2", "col_3"])
                self.assertNotEqual(renamed_dset._fingerprint, fingerprint)

            # Several columns can be renamed in a single call.
            with dset.rename_columns({"col_1": "new_name", "col_2": "new_name2"}) as renamed_dset:
                self.assertEqual(renamed_dset.num_columns, 3)
                self.assertListEqual(list(renamed_dset.column_names), ["new_name", "new_name2", "col_3"])
                self.assertListEqual(list(dset.column_names), ["col_1", "col_2", "col_3"])
                self.assertNotEqual(renamed_dset._fingerprint, fingerprint)

            # Invalid mappings all raise ValueError:
            # original column not in dataset, empty new name, duplicated new name.
            for bad_mapping in (
                {"not_there": "new_name"},
                {"col_1": ""},
                {"col_1": "new_name", "col_2": "new_name"},
            ):
                with self.assertRaises(ValueError):
                    dset.rename_columns(bad_mapping)

def test_concatenate(self, in_memory):
data1, data2, data3 = {"id": [0, 1, 2]}, {"id": [3, 4, 5]}, {"id": [6, 7]}
info1 = DatasetInfo(description="Dataset1")
Expand Down

1 comment on commit 3a3e5a4

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Show benchmarks

PyArrow==1.0.0

Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.032067 / 0.011353 (0.020714) 0.020258 / 0.011008 (0.009250) 0.056225 / 0.038508 (0.017717) 0.040903 / 0.023109 (0.017794) 0.418275 / 0.275898 (0.142377) 0.510165 / 0.323480 (0.186685) 0.013455 / 0.007986 (0.005469) 0.006293 / 0.004328 (0.001965) 0.012611 / 0.004250 (0.008360) 0.061439 / 0.037052 (0.024387) 0.418267 / 0.258489 (0.159778) 0.484866 / 0.293841 (0.191025) 0.207516 / 0.128546 (0.078970) 0.152944 / 0.075646 (0.077297) 0.513789 / 0.419271 (0.094518) 0.725869 / 0.043533 (0.682336) 0.438428 / 0.255139 (0.183289) 0.486577 / 0.283200 (0.203377) 2.343049 / 0.141683 (2.201366) 2.042165 / 1.452155 (0.590010) 2.130017 / 1.492716 (0.637301)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.010297 / 0.018006 (-0.007709) 0.594056 / 0.000490 (0.593566) 0.000336 / 0.000200 (0.000136) 0.000098 / 0.000054 (0.000043)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.057781 / 0.037411 (0.020369) 0.030279 / 0.014526 (0.015753) 0.033086 / 0.176557 (-0.143471) 0.054116 / 0.737135 (-0.683019) 0.037184 / 0.296338 (-0.259154)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.548574 / 0.215209 (0.333365) 5.485787 / 2.077655 (3.408133) 2.444717 / 1.504120 (0.940597) 2.102051 / 1.541195 (0.560856) 2.128623 / 1.468490 (0.660133) 8.290683 / 4.584777 (3.705906) 7.255626 / 3.745712 (3.509914) 10.274102 / 5.269862 (5.004240) 9.080799 / 4.565676 (4.515123) 0.822332 / 0.424275 (0.398057) 0.011788 / 0.007607 (0.004181) 0.658865 / 0.226044 (0.432820) 6.811584 / 2.268929 (4.542655) 3.087480 / 55.444624 (-52.357144) 2.403076 / 6.876477 (-4.473401) 2.500972 / 2.142072 (0.358899) 8.341675 / 4.805227 (3.536447) 7.078717 / 6.500664 (0.578053) 8.773354 / 0.075469 (8.697885)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 13.051911 / 1.841788 (11.210123) 15.075130 / 8.074308 (7.000822) 43.197041 / 10.191392 (33.005649) 1.005627 / 0.680424 (0.325203) 0.660685 / 0.534201 (0.126484) 0.927009 / 0.579283 (0.347726) 0.736531 / 0.434364 (0.302167) 0.859509 / 0.540337 (0.319172) 1.756739 / 1.386936 (0.369803)
PyArrow==latest
Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.030623 / 0.011353 (0.019270) 0.019298 / 0.011008 (0.008290) 0.061032 / 0.038508 (0.022524) 0.051464 / 0.023109 (0.028355) 0.374172 / 0.275898 (0.098274) 0.421256 / 0.323480 (0.097776) 0.013863 / 0.007986 (0.005877) 0.006442 / 0.004328 (0.002113) 0.012828 / 0.004250 (0.008578) 0.065375 / 0.037052 (0.028323) 0.401732 / 0.258489 (0.143242) 0.423524 / 0.293841 (0.129683) 0.192425 / 0.128546 (0.063879) 0.153100 / 0.075646 (0.077454) 0.481836 / 0.419271 (0.062565) 0.488691 / 0.043533 (0.445158) 0.372744 / 0.255139 (0.117605) 0.416228 / 0.283200 (0.133028) 1.926585 / 0.141683 (1.784902) 2.072339 / 1.452155 (0.620185) 2.084269 / 1.492716 (0.591552)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.009786 / 0.018006 (-0.008220) 0.590207 / 0.000490 (0.589717) 0.000403 / 0.000200 (0.000203) 0.000065 / 0.000054 (0.000010)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.049635 / 0.037411 (0.012223) 0.030008 / 0.014526 (0.015482) 0.030519 / 0.176557 (-0.146038) 0.052520 / 0.737135 (-0.684615) 0.035876 / 0.296338 (-0.260462)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.525581 / 0.215209 (0.310372) 5.325394 / 2.077655 (3.247739) 2.384538 / 1.504120 (0.880419) 2.060514 / 1.541195 (0.519319) 2.036223 / 1.468490 (0.567733) 7.906484 / 4.584777 (3.321707) 7.071843 / 3.745712 (3.326130) 9.977009 / 5.269862 (4.707147) 8.809572 / 4.565676 (4.243896) 0.781536 / 0.424275 (0.357261) 0.012289 / 0.007607 (0.004682) 0.685966 / 0.226044 (0.459921) 6.830348 / 2.268929 (4.561419) 3.011607 / 55.444624 (-52.433018) 2.448931 / 6.876477 (-4.427545) 2.397251 / 2.142072 (0.255179) 8.207187 / 4.805227 (3.401960) 8.057550 / 6.500664 (1.556886) 8.840962 / 0.075469 (8.765493)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 13.196848 / 1.841788 (11.355060) 14.542436 / 8.074308 (6.468128) 42.814879 / 10.191392 (32.623487) 0.937634 / 0.680424 (0.257210) 0.668668 / 0.534201 (0.134467) 0.895705 / 0.579283 (0.316422) 0.739638 / 0.434364 (0.305274) 0.827470 / 0.540337 (0.287133) 1.736772 / 1.386936 (0.349836)

CML watermark

Please sign in to comment.