From 3a3e5a4da20bfcd75f8b6a6869b240af8feccc12 Mon Sep 17 00:00:00 2001
From: Simon Brandeis <33657802+SBrandeis@users.noreply.github.com>
Date: Tue, 4 May 2021 15:43:12 +0200
Subject: [PATCH] Add rename_columns method (#2312)

---
 src/datasets/arrow_dataset.py | 51 +++++++++++++++++++++++++++++++++++
 tests/test_arrow_dataset.py   | 28 +++++++++++++++++++
 2 files changed, 79 insertions(+)

diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index f0e7fa93bb8..7f563b8f9d6 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -1122,6 +1122,57 @@ def rename(columns):
         dataset._fingerprint = new_fingerprint
         return dataset
 
+    @fingerprint_transform(inplace=False)
+    def rename_columns(self, column_mapping: Dict[str, str], new_fingerprint) -> "Dataset":
+        """
+        Rename several columns in the dataset, and move the features associated to the original columns under
+        the new column names.
+
+        Args:
+            column_mapping (:obj:`Dict[str, str]`): A mapping of columns to rename to their new names
+
+        Returns:
+            :class:`Dataset`: A copy of the dataset with renamed columns
+        """
+        dataset = copy.deepcopy(self)
+
+        extra_columns = set(column_mapping.keys()) - set(dataset.column_names)
+        if extra_columns:
+            raise ValueError(
+                f"Original column names {extra_columns} not in the dataset. "
+                f"Current columns in the dataset: {dataset._data.column_names}"
+            )
+
+        number_of_duplicates_in_new_columns = len(column_mapping.values()) - len(set(column_mapping.values()))
+        if number_of_duplicates_in_new_columns != 0:
+            raise ValueError(
+                "New column names must all be different, but this column mapping "
+                f"has {number_of_duplicates_in_new_columns} duplicates"
+            )
+
+        empty_new_columns = [new_col for new_col in column_mapping.values() if not new_col]
+        if empty_new_columns:
+            raise ValueError(f"New column names {empty_new_columns} are empty.")
+
+        def rename(columns):
+            return [column_mapping[col] if col in column_mapping else col for col in columns]
+
+        new_column_names = rename(self._data.column_names)
+        if self._format_columns is not None:
+            dataset._format_columns = rename(self._format_columns)
+
+        dataset._info.features = Features(
+            {
+                column_mapping[col] if col in column_mapping else col: feature
+                for col, feature in (self._info.features or {}).items()
+            }
+        )
+
+        dataset._data = dataset._data.rename_columns(new_column_names)
+        dataset._data = update_metadata_with_features(dataset._data, dataset.features)
+        dataset._fingerprint = new_fingerprint
+        return dataset
+
     def __len__(self):
         """Number of rows in the dataset."""
         return self.num_rows
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index 5be9565642b..fc8d897712d 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -510,6 +510,34 @@ def test_rename_column(self, in_memory):
                     self.assertNotEqual(new_dset._fingerprint, fingerprint)
                     assert_arrow_metadata_are_synced_with_dataset_features(new_dset)
 
+    def test_rename_columns(self, in_memory):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
+                fingerprint = dset._fingerprint
+                with dset.rename_columns({"col_1": "new_name"}) as new_dset:
+                    self.assertEqual(new_dset.num_columns, 3)
+                    self.assertListEqual(list(new_dset.column_names), ["new_name", "col_2", "col_3"])
+                    self.assertListEqual(list(dset.column_names), ["col_1", "col_2", "col_3"])
+                    self.assertNotEqual(new_dset._fingerprint, fingerprint)
+
+                with dset.rename_columns({"col_1": "new_name", "col_2": "new_name2"}) as new_dset:
+                    self.assertEqual(new_dset.num_columns, 3)
+                    self.assertListEqual(list(new_dset.column_names), ["new_name", "new_name2", "col_3"])
+                    self.assertListEqual(list(dset.column_names), ["col_1", "col_2", "col_3"])
+                    self.assertNotEqual(new_dset._fingerprint, fingerprint)
+
+                # Original column not in dataset
+                with self.assertRaises(ValueError):
+                    dset.rename_columns({"not_there": "new_name"})
+
+                # Empty new name
+                with self.assertRaises(ValueError):
+                    dset.rename_columns({"col_1": ""})
+
+                # Duplicates
+                with self.assertRaises(ValueError):
+                    dset.rename_columns({"col_1": "new_name", "col_2": "new_name"})
+
     def test_concatenate(self, in_memory):
         data1, data2, data3 = {"id": [0, 1, 2]}, {"id": [3, 4, 5]}, {"id": [6, 7]}
         info1 = DatasetInfo(description="Dataset1")