From 31ff95d4bdc58ff1355b6dabf7d66a4652af4641 Mon Sep 17 00:00:00 2001
From: Quentin Lhoest
Date: Tue, 29 Sep 2020 14:44:15 +0200
Subject: [PATCH] fix column order issue in cast

---
 src/datasets/arrow_dataset.py | 9 +++++----
 tests/test_arrow_dataset.py   | 3 ++-
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index 03dc3620429..b08ceb0f41a 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -600,18 +600,19 @@ def cast_(self, features: Features):
 
         Args:
             features (:class:`datasets.Features`): New features to cast the dataset to.
-                The name and order of the fields in the features must match the current column names.
+                The name of the fields in the features must match the current column names.
                 The type of the data must also be convertible from one type to the other.
                 For non-trivial conversion, e.g. string <-> ClassLabel you should use :func:`map` to update the Dataset.
         """
-        if list(features) != self._data.column_names:
+        if sorted(features) != sorted(self._data.column_names):
             raise ValueError(
-                f"The columns in features ({list(features)}) must be identical and in the same order "
+                f"The columns in features ({list(features)}) must be identical "
                 f"as the columns in the dataset: {self._data.column_names}"
             )
 
         self._info.features = features
-        schema = pa.schema(features.type)
+        type = features.type
+        schema = pa.schema({col_name: type[col_name].type for col_name in self._data.column_names})
         self._data = self._data.cast(schema)
 
     @fingerprint(inplace=True)
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index fe55cca7573..b2d8f7059ef 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -49,7 +49,7 @@ def reduce_ex(self):
 
         datasets.arrow_dataset.logger.__reduce_ex__ = reduce_ex
 
-    def _create_dummy_dataset(self, in_memory: bool, tmp_dir: str, multiple_columns=False):
+    def _create_dummy_dataset(self, in_memory: bool, tmp_dir: str, multiple_columns=False) -> Dataset:
         if multiple_columns:
             data = {"col_1": [3, 2, 1, 0], "col_2": ["a", "b", "c", "d"]}
             dset = Dataset.from_dict(data)
@@ -289,6 +289,7 @@ def test_cast_(self, in_memory):
             dset = self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True)
             features = dset.features
             features["col_1"] = Value("float64")
+            features = Features({k: features[k] for k in list(features)[::-1]})
             fingerprint = dset._fingerprint
             dset.cast_(features)
             self.assertEqual(dset.num_columns, 2)