Fix column order issue in cast #684

Merged: 1 commit, Sep 29, 2020
9 changes: 5 additions & 4 deletions src/datasets/arrow_dataset.py
@@ -600,18 +600,19 @@ def cast_(self, features: Features):
 
         Args:
             features (:class:`datasets.Features`): New features to cast the dataset to.
-                The name and order of the fields in the features must match the current column names.
+                The name of the fields in the features must match the current column names.
                 The type of the data must also be convertible from one type to the other.
                 For non-trivial conversion, e.g. string <-> ClassLabel you should use :func:`map` to update the Dataset.
         """
-        if list(features) != self._data.column_names:
+        if sorted(features) != sorted(self._data.column_names):
             raise ValueError(
-                f"The columns in features ({list(features)}) must be identical and in the same order "
+                f"The columns in features ({list(features)}) must be identical "
                 f"as the columns in the dataset: {self._data.column_names}"
             )
 
         self._info.features = features
-        schema = pa.schema(features.type)
+        type = features.type
+        schema = pa.schema({col_name: type[col_name].type for col_name in self._data.column_names})
         self._data = self._data.cast(schema)
 
     @fingerprint(inplace=True)
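For context, here is a minimal usage sketch (not part of the PR) of the behavior this change enables: `cast_` now only requires that the feature names match the dataset's column names as a set, not in the same order. The column names and values mirror the dummy dataset used in the tests below, and the snippet assumes a `datasets` release from this era, where `cast_` is still the in-place method.

```python
# Illustrative sketch, not code from the PR itself.
from datasets import Dataset, Features, Value

dset = Dataset.from_dict({"col_1": [3, 2, 1, 0], "col_2": ["a", "b", "c", "d"]})

# Keys listed in the reverse of the dataset's column order (col_1, col_2).
features = Features({"col_2": Value("string"), "col_1": Value("float64")})

# Before this fix, cast_ raised ValueError because the order differed;
# now it only checks that the sorted column names are identical.
dset.cast_(features)
print(dset.column_names)  # ['col_1', 'col_2'] -- the table keeps its own column order
```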
3 changes: 2 additions & 1 deletion tests/test_arrow_dataset.py
@@ -49,7 +49,7 @@ def reduce_ex(self):
 
         datasets.arrow_dataset.logger.__reduce_ex__ = reduce_ex
 
-    def _create_dummy_dataset(self, in_memory: bool, tmp_dir: str, multiple_columns=False):
+    def _create_dummy_dataset(self, in_memory: bool, tmp_dir: str, multiple_columns=False) -> Dataset:
         if multiple_columns:
             data = {"col_1": [3, 2, 1, 0], "col_2": ["a", "b", "c", "d"]}
             dset = Dataset.from_dict(data)
@@ -289,6 +289,7 @@ def test_cast_(self, in_memory):
             dset = self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True)
             features = dset.features
             features["col_1"] = Value("float64")
+            features = Features({k: features[k] for k in list(features)[::-1]})
             fingerprint = dset._fingerprint
             dset.cast_(features)
             self.assertEqual(dset.num_columns, 2)
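As a complement to the test, a standalone sketch of the core idea behind the patched schema construction: keying the target pyarrow schema by the table's existing column order lets `Table.cast` succeed even when the `Features` mapping lists the columns in a different order. The table contents are illustrative; the schema-building line mirrors the one added in this PR.

```python
import pyarrow as pa
from datasets import Features, Value

# Same dummy columns as the test above, with the features in reversed order.
table = pa.table({"col_1": [3, 2, 1, 0], "col_2": ["a", "b", "c", "d"]})
features = Features({"col_2": Value("string"), "col_1": Value("float64")})

# Mirrors the patched line: build the schema in the table's own column order.
feature_type = features.type
schema = pa.schema({name: feature_type[name].type for name in table.column_names})

print(table.cast(schema).schema)  # col_1: double, col_2: string
```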