Mapping gets stuck at 99% #6077

Open · Laurent2916 opened this issue Jul 26, 2023 · 4 comments

@Laurent2916 (Contributor) commented Jul 26, 2023

Describe the bug

Hi!

I'm currently working with a large (~150 GB) unnormalized dataset at work.
The dataset is available on a read-only filesystem internally, and I use a loading script to retrieve it.

I want to normalize the features of the dataset, which means computing the mean and standard deviation of each feature over the entire dataset. I cannot load the entire dataset into RAM as it is too big, so following this discussion on the Hugging Face discourse, I am using a first map operation to compute the metrics and a second map operation to apply them to the dataset.

The problem lies in the second mapping: it gets stuck at ~99%. Checking what the process does (using htop and strace), it seems to be performing a lot of I/O operations, and I'm not sure why.

Obviously, I could always normalize the dataset externally and then load it using a loading script. However, since the internal dataset is updated fairly frequently, using the library to perform normalization automatically would make it much easier for me.

Steps to reproduce the bug

I'm able to reproduce the problem using the following scripts:

# random_data.py

import datasets
import torch

_VERSION = "1.0.0"


class RandomDataset(datasets.GeneratorBasedBuilder):
    def _info(self):
        return datasets.DatasetInfo(
            version=_VERSION,
            supervised_keys=None,
            features=datasets.Features(
                {
                    "positions": datasets.Array2D(
                        shape=(30000, 3),
                        dtype="float32",
                    ),
                    "normals": datasets.Array2D(
                        shape=(30000, 3),
                        dtype="float32",
                    ),
                    "features": datasets.Array2D(
                        shape=(30000, 6),
                        dtype="float32",
                    ),
                    "scalars": datasets.Sequence(
                        feature=datasets.Value("float32"),
                        length=20,
                    ),
                },
            ),
        )

    def _split_generators(self, dl_manager):
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,  # type: ignore
                gen_kwargs={"nb_samples": 1000},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,  # type: ignore
                gen_kwargs={"nb_samples": 100},
            ),
        ]

    def _generate_examples(self, nb_samples: int):
        for idx in range(nb_samples):
            yield idx, {
                "positions": torch.randn(30000, 3),
                "normals": torch.randn(30000, 3),
                "features": torch.randn(30000, 6),
                "scalars": torch.randn(20),
            }

# main.py

import datasets
import torch


def apply_mean_std(
    dataset: datasets.Dataset,
    means: dict[str, torch.Tensor],
    stds: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
    """Normalize the dataset using the mean and standard deviation of each feature.

    Args:
        dataset (`Dataset`): A huggingface dataset.
        means (`dict[str, Tensor]`): A dictionary containing the mean of each feature.
        stds (`dict[str, Tensor]`): A dictionary containing the standard deviation of each feature.

    Returns:
        dict: A dictionary containing the normalized dataset.
    """
    result = {}

    for key in means.keys():
        # extract data from dataset
        data: torch.Tensor = dataset[key]  # type: ignore

        # extract mean and std from dict
        mean = means[key]  # type: ignore
        std = stds[key]  # type: ignore

        # normalize data
        normalized_data = (data - mean) / std

        result[key] = normalized_data

    return result


# get dataset
ds = datasets.load_dataset(
    path="random_data.py",
    split="train",
).with_format("torch")

# compute mean (along last axis)
means = {key: torch.zeros(ds[key][0].shape[-1]) for key in ds.column_names}
means_sq = {key: torch.zeros(ds[key][0].shape[-1]) for key in ds.column_names}

for batch in ds.iter(batch_size=8):
    for key in ds.column_names:
        data = batch[key]
        batch_size = data.shape[0]
        data = data.reshape(-1, data.shape[-1])
        means[key] += data.mean(dim=0) / len(ds) * batch_size
        means_sq[key] += (data**2).mean(dim=0) / len(ds) * batch_size

# compute std (along last axis)
stds = {key: torch.sqrt(means_sq[key] - means[key] ** 2) for key in ds.column_names}

# normalize each feature of the dataset
ds_normalized = ds.map(
    desc="Applying mean/std",  # type: ignore
    function=apply_mean_std,
    batched=False,
    fn_kwargs={
        "means": means,
        "stds": stds,
    },
)

Expected behavior

Using the previous scripts, the ds_normalized mapping completes in ~5 minutes, but any subsequent use of ds_normalized is extremely slow; for example, reapplying apply_mean_std to ds_normalized takes forever. This is very strange; I'm sure I must be missing something, but I would still expect this to be faster.

Environment info

  • datasets version: 2.13.1
  • Platform: Linux-3.10.0-1160.66.1.el7.x86_64-x86_64-with-glibc2.17
  • Python version: 3.10.12
  • Huggingface_hub version: 0.15.1
  • PyArrow version: 12.0.0
  • Pandas version: 2.0.2
@mariosasko (Collaborator)

The MAX_MAP_BATCH_SIZE = 1_000_000_000 hack is bad, as it loads the entire dataset into RAM when performing .map. Instead, it's best to use .iter(batch_size) to iterate over the data batches and compute the mean for each column. (The standard deviation can be computed in another pass.)

Also, these arrays are big, so it makes sense to reduce batch_size/writer_batch_size to avoid RAM issues and slow I/O.
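
For illustration, a minimal sketch of that first pass, reusing ds and the column layout from the repro script above (the batch size of 8 is arbitrary):

import torch

# one pass over the dataset: only one batch is materialized in RAM at a time
counts = {key: 0 for key in ds.column_names}
sums = {key: torch.zeros(ds[key][0].shape[-1]) for key in ds.column_names}
for batch in ds.iter(batch_size=8):
    for key in ds.column_names:
        # flatten to (rows, feature_dim) so rows can be summed across batches
        data = batch[key].reshape(-1, batch[key].shape[-1])
        sums[key] += data.sum(dim=0)
        counts[key] += data.shape[0]
means = {key: sums[key] / counts[key] for key in ds.column_names}

A second, identical pass can then accumulate ((data - means[key]) ** 2).sum(dim=0) per column to obtain the variance, and take its square root for the standard deviation.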

@Laurent2916 (Contributor, Author)

Hi @mariosasko!

I agree, it's an ugly hack, but it was convenient since the resulting mean_std could be cached by the library. For my large dataset (which doesn't fit in RAM), I'm actually using something similar to what you suggested. I got rid of the first mapping in the above scripts and replaced it with an iterator, but the issue with the second mapping still persists.

@mariosasko (Collaborator)

Have you tried to reduce batch_size/writer_batch_size in the 2nd .map? Also, can you interrupt the process when it gets stuck and share the error stack trace?
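
For example, a minimal sketch of the 2nd .map with a smaller writer_batch_size, reusing apply_mean_std from the repro script (the value of 100 is arbitrary):

ds_normalized = ds.map(
    desc="Applying mean/std",
    function=apply_mean_std,
    batched=False,
    writer_batch_size=100,  # default is 1000; smaller values flush rows to disk more often
    fn_kwargs={
        "means": means,
        "stds": stds,
    },
)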

@Laurent2916 (Contributor, Author)

I think batch_size/writer_batch_size is already at its lowest in the 2nd .map, since batched=False implies batch_size=1, and len(ds) = 1000 = writer_batch_size (the default).

Here are also a few stack traces from when I interrupted the process:

Stack trace 1:
(pyg)[d623204@rosetta-bigviz01 stage-laurent-f]$ python src/random_scripts/uses_random_data.py 
Found cached dataset random_data (/local_scratch/lfainsin/.cache/huggingface/datasets/random_data/default/0.0.0/444e214e1d0e6298cfd3f2368323ec37073dc1439f618e19395b1f421c69b066)
Applying mean/std:  97%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████     | 967/1000 [00:01<00:00, 534.87 examples/s]
Traceback (most recent call last):
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 179, in __arrow_array__
    storage = to_pyarrow_listarray(data, pa_type)
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 1466, in to_pyarrow_listarray
    return pa.array(data, pa_type.storage_dtype)
  File "pyarrow/array.pxi", line 320, in pyarrow.lib.array
  File "pyarrow/array.pxi", line 39, in pyarrow.lib._sequence_to_array
  File "pyarrow/error.pxi", line 144, in pyarrow.lib.pyarrow_internal_check_status
  File "pyarrow/error.pxi", line 123, in pyarrow.lib.check_status
pyarrow.lib.ArrowTypeError: Could not convert tensor([[-1.0273, -0.8037, -0.6860],
        [-0.5034, -1.2685, -0.0558],
        [-1.0908, -1.1820, -0.3178],
        ...,
        [-0.8171,  0.1781, -0.5903],
        [ 0.4370,  1.9305,  0.5899],
        [-0.1426,  0.9053, -1.7559]]) with type Tensor: was not a sequence or recognized null for conversion to list type

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 3449, in _map_single
    writer.write(example)
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 490, in write
    self.write_examples_on_file()
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 448, in write_examples_on_file
    self.write_batch(batch_examples=batch_examples)
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 553, in write_batch
    arrays.append(pa.array(typed_sequence))
  File "pyarrow/array.pxi", line 236, in pyarrow.lib.array
  File "pyarrow/array.pxi", line 110, in pyarrow.lib._handle_arrow_array_protocol
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 223, in __arrow_array__
    return pa.array(cast_to_python_objects(data, only_1d_for_numpy=True))
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 446, in cast_to_python_objects
    return _cast_to_python_objects(
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 407, in _cast_to_python_objects
    [
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 408, in <listcomp>
    _cast_to_python_objects(
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 319, in _cast_to_python_objects
    [
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 320, in <listcomp>
    _cast_to_python_objects(
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 263, in _cast_to_python_objects
    def _cast_to_python_objects(obj: Any, only_1d_for_numpy: bool, optimize_list_casting: bool) -> Tuple[Any, bool]:
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 179, in __arrow_array__
    storage = to_pyarrow_listarray(data, pa_type)
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 1466, in to_pyarrow_listarray
    return pa.array(data, pa_type.storage_dtype)
  File "pyarrow/array.pxi", line 320, in pyarrow.lib.array
  File "pyarrow/array.pxi", line 39, in pyarrow.lib._sequence_to_array
  File "pyarrow/error.pxi", line 144, in pyarrow.lib.pyarrow_internal_check_status
  File "pyarrow/error.pxi", line 123, in pyarrow.lib.check_status
pyarrow.lib.ArrowTypeError: Could not convert tensor([[-1.0273, -0.8037, -0.6860],
        [-0.5034, -1.2685, -0.0558],
        [-1.0908, -1.1820, -0.3178],
        ...,
        [-0.8171,  0.1781, -0.5903],
        [ 0.4370,  1.9305,  0.5899],
        [-0.1426,  0.9053, -1.7559]]) with type Tensor: was not a sequence or recognized null for conversion to list type

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/gpfs_new/data/users/lfainsin/stage-laurent-f/src/random_scripts/uses_random_data.py", line 62, in <module>
    ds_normalized = ds.map(
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 580, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 545, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 3087, in map
    for rank, done, content in Dataset._map_single(**dataset_kwargs):
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 3492, in _map_single
    writer.finalize()
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 584, in finalize
    self.write_examples_on_file()
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 448, in write_examples_on_file
    self.write_batch(batch_examples=batch_examples)
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 553, in write_batch
    arrays.append(pa.array(typed_sequence))
  File "pyarrow/array.pxi", line 236, in pyarrow.lib.array
  File "pyarrow/array.pxi", line 110, in pyarrow.lib._handle_arrow_array_protocol
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 223, in __arrow_array__
    return pa.array(cast_to_python_objects(data, only_1d_for_numpy=True))
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 446, in cast_to_python_objects
    return _cast_to_python_objects(
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 407, in _cast_to_python_objects
    [
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 408, in <listcomp>
    _cast_to_python_objects(
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 319, in _cast_to_python_objects
    [
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 319, in <listcomp>
    [
KeyboardInterrupt
Stack trace 2:
(pyg)[d623204@rosetta-bigviz01 stage-laurent-f]$ python src/random_scripts/uses_random_data.py 
Found cached dataset random_data (/local_scratch/lfainsin/.cache/huggingface/datasets/random_data/default/0.0.0/444e214e1d0e6298cfd3f2368323ec37073dc1439f618e19395b1f421c69b066)
Applying mean/std:  99%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 988/1000 [00:20<00:00, 526.19 examples/s]
Applying mean/std: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊| 999/1000 [00:21<00:00,  9.66 examples/s]
Traceback (most recent call last):
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 179, in __arrow_array__
    storage = to_pyarrow_listarray(data, pa_type)
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 1466, in to_pyarrow_listarray
    return pa.array(data, pa_type.storage_dtype)
  File "pyarrow/array.pxi", line 320, in pyarrow.lib.array
  File "pyarrow/array.pxi", line 39, in pyarrow.lib._sequence_to_array
  File "pyarrow/error.pxi", line 144, in pyarrow.lib.pyarrow_internal_check_status
  File "pyarrow/error.pxi", line 123, in pyarrow.lib.check_status
pyarrow.lib.ArrowTypeError: Could not convert tensor([[-1.0273, -0.8037, -0.6860],
        [-0.5034, -1.2685, -0.0558],
        [-1.0908, -1.1820, -0.3178],
        ...,
        [-0.8171,  0.1781, -0.5903],
        [ 0.4370,  1.9305,  0.5899],
        [-0.1426,  0.9053, -1.7559]]) with type Tensor: was not a sequence or recognized null for conversion to list type

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 3449, in _map_single
    writer.write(example)
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 490, in write
    self.write_examples_on_file()
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 448, in write_examples_on_file
    self.write_batch(batch_examples=batch_examples)
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 553, in write_batch
    arrays.append(pa.array(typed_sequence))
  File "pyarrow/array.pxi", line 236, in pyarrow.lib.array
  File "pyarrow/array.pxi", line 110, in pyarrow.lib._handle_arrow_array_protocol
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 223, in __arrow_array__
    return pa.array(cast_to_python_objects(data, only_1d_for_numpy=True))
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 446, in cast_to_python_objects
    return _cast_to_python_objects(
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 407, in _cast_to_python_objects
    [
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 408, in <listcomp>
    _cast_to_python_objects(
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 319, in _cast_to_python_objects
    [
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 320, in <listcomp>
    _cast_to_python_objects(
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 263, in _cast_to_python_objects
    def _cast_to_python_objects(obj: Any, only_1d_for_numpy: bool, optimize_list_casting: bool) -> Tuple[Any, bool]:
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 179, in __arrow_array__
    storage = to_pyarrow_listarray(data, pa_type)
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 1466, in to_pyarrow_listarray
    return pa.array(data, pa_type.storage_dtype)
  File "pyarrow/array.pxi", line 320, in pyarrow.lib.array
  File "pyarrow/array.pxi", line 39, in pyarrow.lib._sequence_to_array
  File "pyarrow/error.pxi", line 144, in pyarrow.lib.pyarrow_internal_check_status
  File "pyarrow/error.pxi", line 123, in pyarrow.lib.check_status
pyarrow.lib.ArrowTypeError: Could not convert tensor([[-1.0273, -0.8037, -0.6860],
        [-0.5034, -1.2685, -0.0558],
        [-1.0908, -1.1820, -0.3178],
        ...,
        [-0.8171,  0.1781, -0.5903],
        [ 0.4370,  1.9305,  0.5899],
        [-0.1426,  0.9053, -1.7559]]) with type Tensor: was not a sequence or recognized null for conversion to list type

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/gpfs_new/data/users/lfainsin/stage-laurent-f/src/random_scripts/uses_random_data.py", line 62, in <module>
    ds_normalized = ds.map(
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 580, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 545, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 3087, in map
    for rank, done, content in Dataset._map_single(**dataset_kwargs):
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 3492, in _map_single
    writer.finalize()
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 584, in finalize
    self.write_examples_on_file()
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 448, in write_examples_on_file
    self.write_batch(batch_examples=batch_examples)
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 553, in write_batch
    arrays.append(pa.array(typed_sequence))
  File "pyarrow/array.pxi", line 236, in pyarrow.lib.array
  File "pyarrow/array.pxi", line 110, in pyarrow.lib._handle_arrow_array_protocol
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 223, in __arrow_array__
    return pa.array(cast_to_python_objects(data, only_1d_for_numpy=True))
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 446, in cast_to_python_objects
    return _cast_to_python_objects(
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 407, in _cast_to_python_objects
    [
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 408, in <listcomp>
    _cast_to_python_objects(
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 319, in _cast_to_python_objects
    [
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 320, in <listcomp>
    _cast_to_python_objects(
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 291, in _cast_to_python_objects
    if config.JAX_AVAILABLE and "jax" in sys.modules:
KeyboardInterrupt
Stack trace 3:
(pyg)[d623204@rosetta-bigviz01 stage-laurent-f]$ python src/random_scripts/uses_random_data.py 
Found cached dataset random_data (/local_scratch/lfainsin/.cache/huggingface/datasets/random_data/default/0.0.0/444e214e1d0e6298cfd3f2368323ec37073dc1439f618e19395b1f421c69b066)
Applying mean/std:  99%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎ | 989/1000 [00:01<00:00, 504.80 examples/s]
Traceback (most recent call last):
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 179, in __arrow_array__
    storage = to_pyarrow_listarray(data, pa_type)
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 1466, in to_pyarrow_listarray
    return pa.array(data, pa_type.storage_dtype)
  File "pyarrow/array.pxi", line 320, in pyarrow.lib.array
  File "pyarrow/array.pxi", line 39, in pyarrow.lib._sequence_to_array
  File "pyarrow/error.pxi", line 144, in pyarrow.lib.pyarrow_internal_check_status
  File "pyarrow/error.pxi", line 123, in pyarrow.lib.check_status
pyarrow.lib.ArrowTypeError: Could not convert tensor([[-1.0273, -0.8037, -0.6860],
        [-0.5034, -1.2685, -0.0558],
        [-1.0908, -1.1820, -0.3178],
        ...,
        [-0.8171,  0.1781, -0.5903],
        [ 0.4370,  1.9305,  0.5899],
        [-0.1426,  0.9053, -1.7559]]) with type Tensor: was not a sequence or recognized null for conversion to list type

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 3449, in _map_single
    writer.write(example)
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 490, in write
    self.write_examples_on_file()
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 448, in write_examples_on_file
    self.write_batch(batch_examples=batch_examples)
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 553, in write_batch
    arrays.append(pa.array(typed_sequence))
  File "pyarrow/array.pxi", line 236, in pyarrow.lib.array
  File "pyarrow/array.pxi", line 110, in pyarrow.lib._handle_arrow_array_protocol
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 223, in __arrow_array__
    return pa.array(cast_to_python_objects(data, only_1d_for_numpy=True))
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 446, in cast_to_python_objects
    return _cast_to_python_objects(
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 407, in _cast_to_python_objects
    [
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 408, in <listcomp>
    _cast_to_python_objects(
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 319, in _cast_to_python_objects
    [
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 320, in <listcomp>
    _cast_to_python_objects(
KeyboardInterrupt

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 179, in __arrow_array__
    storage = to_pyarrow_listarray(data, pa_type)
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 1466, in to_pyarrow_listarray
    return pa.array(data, pa_type.storage_dtype)
  File "pyarrow/array.pxi", line 320, in pyarrow.lib.array
  File "pyarrow/array.pxi", line 39, in pyarrow.lib._sequence_to_array
  File "pyarrow/error.pxi", line 144, in pyarrow.lib.pyarrow_internal_check_status
  File "pyarrow/error.pxi", line 123, in pyarrow.lib.check_status
pyarrow.lib.ArrowTypeError: Could not convert tensor([[-1.0273, -0.8037, -0.6860],
        [-0.5034, -1.2685, -0.0558],
        [-1.0908, -1.1820, -0.3178],
        ...,
        [-0.8171,  0.1781, -0.5903],
        [ 0.4370,  1.9305,  0.5899],
        [-0.1426,  0.9053, -1.7559]]) with type Tensor: was not a sequence or recognized null for conversion to list type

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/gpfs_new/data/users/lfainsin/stage-laurent-f/src/random_scripts/uses_random_data.py", line 62, in <module>
    ds_normalized = ds.map(
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 580, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 545, in wrapper
    out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 3087, in map
    for rank, done, content in Dataset._map_single(**dataset_kwargs):
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_dataset.py", line 3492, in _map_single
    writer.finalize()
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 584, in finalize
    self.write_examples_on_file()
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 448, in write_examples_on_file
    self.write_batch(batch_examples=batch_examples)
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 553, in write_batch
    arrays.append(pa.array(typed_sequence))
  File "pyarrow/array.pxi", line 236, in pyarrow.lib.array
  File "pyarrow/array.pxi", line 110, in pyarrow.lib._handle_arrow_array_protocol
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/arrow_writer.py", line 223, in __arrow_array__
    return pa.array(cast_to_python_objects(data, only_1d_for_numpy=True))
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 446, in cast_to_python_objects
    return _cast_to_python_objects(
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 407, in _cast_to_python_objects
    [
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 408, in <listcomp>
    _cast_to_python_objects(
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 319, in _cast_to_python_objects
    [
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 320, in <listcomp>
    _cast_to_python_objects(
  File "/local_scratch/lfainsin/.conda/envs/pyg/lib/python3.10/site-packages/datasets/features/features.py", line 298, in _cast_to_python_objects
    if obj.ndim == 0:
KeyboardInterrupt
