Skip to content

Commit

Permalink
Add rename_columns method
Browse files Browse the repository at this point in the history
  • Loading branch information
SBrandeis committed May 4, 2021
1 parent 097129d commit ab27d38
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 0 deletions.
51 changes: 51 additions & 0 deletions src/datasets/arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,57 @@ def rename(columns):
dataset._fingerprint = new_fingerprint
return dataset

@fingerprint_transform(inplace=False)
def rename_columns(self, column_mapping: Dict[str, str], new_fingerprint) -> "Dataset":
    """
    Rename several columns in the dataset, and move the features associated to the original columns under
    the new column names.

    Args:
        column_mapping (:obj:`Dict[str, str]`): A mapping of columns to rename to their new names.
        new_fingerprint (:obj:`str`): Fingerprint of the returned dataset
            (supplied automatically by the :func:`fingerprint_transform` decorator).

    Returns:
        :class:`Dataset`: A copy of the dataset with renamed columns.

    Raises:
        ValueError: If a column to rename is missing from the dataset, if several columns are mapped
            to the same new name, if a new name is empty, or if a new name collides with an existing
            column that is not being renamed.
    """
    dataset = copy.deepcopy(self)

    extra_columns = set(column_mapping.keys()) - set(dataset.column_names)
    if extra_columns:
        raise ValueError(
            f"Original column names {extra_columns} not in the dataset. "
            f"Current columns in the dataset: {dataset._data.column_names}"
        )

    number_of_duplicates_in_new_columns = len(column_mapping.values()) - len(set(column_mapping.values()))
    if number_of_duplicates_in_new_columns != 0:
        raise ValueError(
            "New column names must all be different, but this column mapping "
            f"has {number_of_duplicates_in_new_columns} duplicates"
        )

    empty_new_columns = [new_col for new_col in column_mapping.values() if not new_col]
    if empty_new_columns:
        raise ValueError(f"New column names {empty_new_columns} are empty.")

    # A new name must not collide with a column that keeps its name: the Features dict
    # below would silently collapse the duplicate key while the Arrow table keeps both
    # columns. Swaps are still allowed, since a colliding name that is itself a mapping
    # key is renamed away.
    colliding_new_columns = [
        new_col
        for new_col in column_mapping.values()
        if new_col in dataset.column_names and new_col not in column_mapping
    ]
    if colliding_new_columns:
        raise ValueError(
            f"New column names {colliding_new_columns} already exist in the dataset "
            "and are not part of the columns being renamed."
        )

    def rename(columns):
        # Map each column through column_mapping; columns not in the mapping keep their name.
        return [column_mapping.get(col, col) for col in columns]

    new_column_names = rename(self._data.column_names)
    if self._format_columns is not None:
        dataset._format_columns = rename(self._format_columns)

    dataset._info.features = Features(
        {
            column_mapping.get(col, col): feature
            for col, feature in (self._info.features or {}).items()
        }
    )

    dataset._data = dataset._data.rename_columns(new_column_names)
    # Sync the Arrow schema metadata with the RENAMED features (dataset.features), not
    # the original self.features — otherwise the metadata keeps the old column names.
    dataset._data = update_metadata_with_features(dataset._data, dataset.features)
    dataset._fingerprint = new_fingerprint
    return dataset

def __len__(self) -> int:
    """Number of rows in the dataset.

    Returns:
        :obj:`int`: The number of rows, so that ``len(dataset)`` delegates to :attr:`num_rows`.
    """
    return self.num_rows
Expand Down
28 changes: 28 additions & 0 deletions tests/test_arrow_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,34 @@ def test_rename_column(self, in_memory):
self.assertNotEqual(new_dset._fingerprint, fingerprint)
assert_arrow_metadata_are_synced_with_dataset_features(new_dset)

def test_rename_columns(self, in_memory):
    """Exercise Dataset.rename_columns: single and multi-column renames plus the error cases."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
            original_fingerprint = dset._fingerprint

            # Renaming one column returns a new dataset and leaves the source untouched.
            with dset.rename_columns({"col_1": "new_name"}) as renamed:
                self.assertEqual(renamed.num_columns, 3)
                self.assertListEqual(list(renamed.column_names), ["new_name", "col_2", "col_3"])
                self.assertListEqual(list(dset.column_names), ["col_1", "col_2", "col_3"])
                self.assertNotEqual(renamed._fingerprint, original_fingerprint)

            # Several columns can be renamed in a single call.
            with dset.rename_columns({"col_1": "new_name", "col_2": "new_name2"}) as renamed:
                self.assertEqual(renamed.num_columns, 3)
                self.assertListEqual(list(renamed.column_names), ["new_name", "new_name2", "col_3"])
                self.assertListEqual(list(dset.column_names), ["col_1", "col_2", "col_3"])
                self.assertNotEqual(renamed._fingerprint, original_fingerprint)

            # Renaming a column that does not exist is rejected.
            with self.assertRaises(ValueError):
                dset.rename_columns({"not_there": "new_name"})

            # An empty new column name is rejected.
            with self.assertRaises(ValueError):
                dset.rename_columns({"col_1": ""})

            # Mapping two columns to the same new name is rejected.
            with self.assertRaises(ValueError):
                dset.rename_columns({"col_1": "new_name", "col_2": "new_name"})

def test_concatenate(self, in_memory):
data1, data2, data3 = {"id": [0, 1, 2]}, {"id": [3, 4, 5]}, {"id": [6, 7]}
info1 = DatasetInfo(description="Dataset1")
Expand Down

1 comment on commit ab27d38

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Show benchmarks

PyArrow==1.0.0

Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.018517 / 0.011353 (0.007164) 0.012552 / 0.011008 (0.001544) 0.040150 / 0.038508 (0.001642) 0.033039 / 0.023109 (0.009930) 0.298039 / 0.275898 (0.022141) 0.325969 / 0.323480 (0.002489) 0.009287 / 0.007986 (0.001302) 0.004449 / 0.004328 (0.000121) 0.009568 / 0.004250 (0.005318) 0.043781 / 0.037052 (0.006728) 0.293566 / 0.258489 (0.035077) 0.331445 / 0.293841 (0.037604) 0.116916 / 0.128546 (-0.011631) 0.091307 / 0.075646 (0.015661) 0.349965 / 0.419271 (-0.069306) 0.338863 / 0.043533 (0.295330) 0.297259 / 0.255139 (0.042120) 0.314492 / 0.283200 (0.031292) 1.370780 / 0.141683 (1.229097) 1.388423 / 1.452155 (-0.063732) 1.442665 / 1.492716 (-0.050052)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.008045 / 0.018006 (-0.009962) 0.441779 / 0.000490 (0.441289) 0.001866 / 0.000200 (0.001666) 0.000056 / 0.000054 (0.000002)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.035494 / 0.037411 (-0.001917) 0.021092 / 0.014526 (0.006566) 0.026707 / 0.176557 (-0.149849) 0.042529 / 0.737135 (-0.694606) 0.028411 / 0.296338 (-0.267927)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.357758 / 0.215209 (0.142549) 3.601952 / 2.077655 (1.524298) 1.819113 / 1.504120 (0.314993) 1.639318 / 1.541195 (0.098124) 1.707005 / 1.468490 (0.238515) 5.060315 / 4.584777 (0.475538) 4.814571 / 3.745712 (1.068859) 7.232858 / 5.269862 (1.962997) 5.918046 / 4.565676 (1.352369) 0.545302 / 0.424275 (0.121027) 0.008305 / 0.007607 (0.000698) 0.407688 / 0.226044 (0.181644) 4.074179 / 2.268929 (1.805251) 1.962554 / 55.444624 (-53.482070) 1.664415 / 6.876477 (-5.212062) 1.736347 / 2.142072 (-0.405725) 5.442435 / 4.805227 (0.637207) 4.007882 / 6.500664 (-2.492782) 5.255392 / 0.075469 (5.179923)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 8.433606 / 1.841788 (6.591819) 10.942052 / 8.074308 (2.867743) 22.753179 / 10.191392 (12.561787) 0.622912 / 0.680424 (-0.057512) 0.458244 / 0.534201 (-0.075957) 0.564552 / 0.579283 (-0.014731) 0.457563 / 0.434364 (0.023199) 0.546832 / 0.540337 (0.006495) 1.172427 / 1.386936 (-0.214509)
PyArrow==latest
Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.018501 / 0.011353 (0.007148) 0.012315 / 0.011008 (0.001307) 0.040395 / 0.038508 (0.001887) 0.032868 / 0.023109 (0.009758) 0.254242 / 0.275898 (-0.021656) 0.288645 / 0.323480 (-0.034835) 0.009711 / 0.007986 (0.001726) 0.004370 / 0.004328 (0.000041) 0.009821 / 0.004250 (0.005570) 0.050189 / 0.037052 (0.013136) 0.251753 / 0.258489 (-0.006736) 0.288147 / 0.293841 (-0.005694) 0.121852 / 0.128546 (-0.006694) 0.090962 / 0.075646 (0.015316) 0.334527 / 0.419271 (-0.084745) 0.514981 / 0.043533 (0.471448) 0.250685 / 0.255139 (-0.004454) 0.275917 / 0.283200 (-0.007283) 3.263349 / 0.141683 (3.121666) 1.397861 / 1.452155 (-0.054293) 1.458965 / 1.492716 (-0.033751)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.007185 / 0.018006 (-0.010821) 0.431976 / 0.000490 (0.431487) 0.001819 / 0.000200 (0.001620) 0.000049 / 0.000054 (-0.000005)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.032984 / 0.037411 (-0.004427) 0.021964 / 0.014526 (0.007439) 0.028186 / 0.176557 (-0.148370) 0.042896 / 0.737135 (-0.694240) 0.028919 / 0.296338 (-0.267419)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.305585 / 0.215209 (0.090376) 3.042078 / 2.077655 (0.964423) 1.485254 / 1.504120 (-0.018865) 1.319445 / 1.541195 (-0.221750) 1.376789 / 1.468490 (-0.091701) 5.195005 / 4.584777 (0.610228) 4.418606 / 3.745712 (0.672894) 6.701337 / 5.269862 (1.431475) 5.296455 / 4.565676 (0.730778) 0.481381 / 0.424275 (0.057106) 0.008426 / 0.007607 (0.000818) 0.393575 / 0.226044 (0.167530) 3.917555 / 2.268929 (1.648626) 1.864383 / 55.444624 (-53.580241) 1.572706 / 6.876477 (-5.303771) 1.630590 / 2.142072 (-0.511483) 5.151431 / 4.805227 (0.346204) 3.819291 / 6.500664 (-2.681373) 5.164390 / 0.075469 (5.088921)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 8.185434 / 1.841788 (6.343646) 10.623409 / 8.074308 (2.549100) 22.824700 / 10.191392 (12.633308) 0.687633 / 0.680424 (0.007210) 0.454497 / 0.534201 (-0.079703) 0.551890 / 0.579283 (-0.027393) 0.429988 / 0.434364 (-0.004376) 0.493197 / 0.540337 (-0.047141) 1.152321 / 1.386936 (-0.234616)

CML watermark

Please sign in to comment.