
Importing PyTorch reduces multiprocessing performance for map #5929

Closed
Maxscha opened this issue Jun 6, 2023 · 2 comments


Maxscha commented Jun 6, 2023

Describe the bug

I noticed that the performance of my dataset preprocessing with map(...,num_proc=32) decreases when PyTorch is imported.

Steps to reproduce the bug

I created two example scripts to reproduce this behavior:

import datasets
datasets.disable_caching()

from datasets import Dataset
import time

PROC = 32

if __name__ == "__main__":
    dataset = [True] * 10000000
    dataset = Dataset.from_dict({'train': dataset})

    start = time.time()
    dataset.map(lambda x: x, num_proc=PROC)
    end = time.time()
    print(end - start)

Takes around 4 seconds on my machine.

The same code, but with import torch added:

import datasets
datasets.disable_caching()

from datasets import Dataset
import time
import torch

PROC = 32

if __name__ == "__main__":
    dataset = [True] * 10000000
    dataset = Dataset.from_dict({'train': dataset})

    start = time.time()
    dataset.map(lambda x: x, num_proc=PROC)
    end = time.time()
    print(end - start)

takes around 22 seconds.

Expected behavior

I would expect the import of torch not to have such a significant effect on the performance of map using multiprocessing.

Environment info

  • datasets version: 2.12.0
  • Platform: Linux-5.15.0-56-generic-x86_64-with-glibc2.35
  • Python version: 3.11.3
  • Huggingface_hub version: 0.15.1
  • PyArrow version: 12.0.0
  • Pandas version: 2.0.2
  • torch: 2.0.1
mariosasko (Collaborator) commented
Hi! The times match when I run this code locally or on Colab.

Also, we use multiprocess, not multiprocessing, for parallelization, and torch's __init__.py (executed on import torch) slightly modifies the latter.
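
For reference, a rough way to isolate this outside of datasets is to time a multiprocess.Pool directly (multiprocess is the fork of multiprocessing that datasets uses). This is only an illustrative sketch, not code from the report, and the timings will depend on your machine:

import time
import multiprocess  # the multiprocessing fork used by `datasets`
# import torch  # uncomment to compare: on affected setups, worker start-up becomes much slower

def identity(x):
    return x

if __name__ == "__main__":
    start = time.time()
    with multiprocess.Pool(32) as pool:
        pool.map(identity, range(1_000_000))
    print(f"multiprocess.Pool took {time.time() - start:.2f}s")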


Maxscha commented Jun 16, 2023

Hey Mariosasko,

Thanks for looking into it. We did some further investigation after your comment and figured out that it only affects some hardware/software configurations with the PyTorch installation from conda-forge. Based on this, we found the following issue in PyTorch, pytorch/pytorch#102269, which has a quick fix for now.

Since it seems to be a deeper issue with forking processes, the distinction between multiprocess and multiprocessing didn't matter here.
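
As a rough illustration (a sketch of ours, not code from the linked issue), the fork overhead can be timed without datasets at all, using only the standard library with the fork start method on Linux; on the affected configurations the difference should show up here as well once torch is imported:

import time
import multiprocessing as mp
# import torch  # uncomment to compare the cost of forking with torch loaded

def noop():
    pass

if __name__ == "__main__":
    mp.set_start_method("fork")  # measure fork cost directly; "spawn" would re-import everything
    start = time.time()
    procs = [mp.Process(target=noop) for _ in range(32)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    print(f"Forked and joined 32 processes in {time.time() - start:.3f}s")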

Closing this, since the issue comes from PyTorch, not datasets.
