Use multiprocess from pathos for multiprocessing #656

Merged
merged 7 commits on Sep 28, 2020
4 changes: 3 additions & 1 deletion setup.py
@@ -79,7 +79,9 @@
# filesystem locks e.g. to prevent parallel downloads
"filelock",
# for fast hashing
"xxhash"
"xxhash",
# for better multiprocessing
"multiprocess"
]

BENCHMARKS_REQUIRE = [
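Context for the new dependency: `multiprocess` is pathos' fork of the standard library's `multiprocessing` that serializes tasks with `dill` rather than `pickle`, so closures and lambdas can be shipped to worker processes. A minimal sketch of the difference (assuming `multiprocess` is installed; not part of this diff):

```python
from multiprocess import Pool  # drop-in replacement for multiprocessing.Pool

if __name__ == "__main__":
    with Pool(2) as pool:
        # dill serializes the lambda; the stdlib pickler rejects it with
        # "Can't pickle <function <lambda> ...>".
        squares = pool.map(lambda x: x * x, range(10))
    print(squares)  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
```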
24 changes: 14 additions & 10 deletions src/datasets/arrow_dataset.py
@@ -17,6 +17,7 @@
""" Simple Dataset wrapping an Arrow Table."""

import contextlib
import copy
import json
import os
import pickle
@@ -27,13 +28,13 @@
from dataclasses import asdict
from functools import partial, wraps
from math import ceil, floor
from multiprocessing import Pool, RLock
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import pyarrow as pa
from multiprocess import Pool, RLock
from tqdm.auto import tqdm

from .arrow_reader import ArrowReader
@@ -389,6 +390,8 @@ def __getstate__(self):
state["_data"] = None
if self._indices_data_files:
state["_indices"] = None
logger.debug("Copying history")
state["_inplace_history"] = [{"transforms": list(h["transforms"])} for h in state["_inplace_history"]]
return state

def __setstate__(self, state):
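A plausible reading of the history copy added to `__getstate__`: the returned state must carry its own list objects, so anything rebuilt from that state (a pickled worker copy, or a shallow `copy.copy`) cannot mutate the original dataset's in-place transform history. A hedged illustration with a hypothetical `Recorder` class, not the library's code:

```python
import copy

class Recorder:
    def __init__(self):
        self._inplace_history = [{"transforms": []}]

    def __getstate__(self):
        state = self.__dict__.copy()
        # rebuild the nested lists so the state does not alias our own history
        state["_inplace_history"] = [
            {"transforms": list(h["transforms"])} for h in state["_inplace_history"]
        ]
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)

original = Recorder()
clone = copy.copy(original)  # copy.copy round-trips through __getstate__/__setstate__
clone._inplace_history[0]["transforms"].append("cast_column")
print(original._inplace_history)  # [{'transforms': []}] -- original untouched
```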
@@ -1684,15 +1687,18 @@ def _new_dataset_with_indices(
indices_pa_table = indices_f.read_all()

# Return new Dataset object
# don't forget to copy the objects
return Dataset(
self._data,
data_files=data_files,
info=self.info,
data_files=copy.deepcopy(data_files),
info=self.info.copy(),
split=self.split,
indices_table=indices_pa_table,
indices_data_files=indices_data_files,
indices_data_files=copy.deepcopy(indices_data_files),
fingerprint=fingerprint,
inplace_history=self._inplace_history, # in-place transforms have to be kept as we kept the same data_files
inplace_history=copy.deepcopy(
self._inplace_history
), # in-place transforms have to be kept as we kept the same data_files
)

@transmit_format
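The `copy.deepcopy` calls above keep the new indexed view from sharing mutable records with its parent dataset. A minimal sketch of the aliasing pitfall being guarded against (the file-record shape is hypothetical):

```python
import copy

parent_data_files = [{"filename": "dataset.arrow"}]

aliased = parent_data_files                     # same list object as the parent's
independent = copy.deepcopy(parent_data_files)  # what the new code does

aliased[0]["filename"] = "other.arrow"
print(parent_data_files[0]["filename"])  # 'other.arrow' -- the parent changed too
print(independent[0]["filename"])        # 'dataset.arrow' -- the deep copy is unaffected
```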
@@ -2486,8 +2492,8 @@ def concatenate_datasets(
# Concatenate tables

table = pa.concat_tables(dset._data for dset in dsets if len(dset._data) > 0)
data_files = [f for dset in dsets for f in dset._data_files]
inplace_history = [h for dset in dsets for h in dset._inplace_history]
data_files = [copy.deepcopy(f) for dset in dsets for f in dset._data_files]
inplace_history = [copy.deepcopy(h) for dset in dsets for h in dset._inplace_history]

def apply_offset_to_indices_table(table, offset):
if offset == 0:
@@ -2544,10 +2550,8 @@ def apply_offset_to_indices_table(table, offset):
indices_table = pa.concat_tables(indices_tables)
else:
indices_table = pa.Table.from_batches([], schema=pa.schema({"indices": pa.int64()}))
indices_data_files = None # can't reuse same files as an offset was applied
else:
indices_table = None
indices_data_files = None
if info is None:
info = DatasetInfo.from_merge([dset.info for dset in dsets])
fingerprint = update_fingerprint(
@@ -2559,7 +2563,7 @@ def apply_offset_to_indices_table(table, offset):
split=split,
data_files=data_files,
indices_table=indices_table,
indices_data_files=indices_data_files,
indices_data_files=None, # can't reuse same files as an offset was applied
fingerprint=fingerprint,
inplace_history=inplace_history,
)
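For context, `concatenate_datasets` is the public entry point touched above; a hedged usage sketch (column names are illustrative):

```python
from datasets import Dataset, concatenate_datasets

d1 = Dataset.from_dict({"id": [0, 1, 2]})
d2 = Dataset.from_dict({"id": [3, 4]})

combined = concatenate_datasets([d1, d2])
print(len(combined))   # 5
print(combined["id"])  # [0, 1, 2, 3, 4]
```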
2 changes: 1 addition & 1 deletion src/datasets/info.py
@@ -160,7 +160,7 @@ def _dump_license(self, file):

@classmethod
def from_merge(cls, dataset_infos: List["DatasetInfo"]):
dataset_infos = [dset_info for dset_info in dataset_infos if dset_info is not None]
dataset_infos = [dset_info.copy() for dset_info in dataset_infos if dset_info is not None]
description = "\n\n".join([info.description for info in dataset_infos])
citation = "\n\n".join([info.citation for info in dataset_infos])
homepage = "\n\n".join([info.homepage for info in dataset_infos])
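The `from_merge` change copies each source `DatasetInfo` before merging, so building the combined metadata cannot touch the originals. A hedged sketch of the call (field values are illustrative):

```python
from datasets import DatasetInfo

a = DatasetInfo(description="part A", citation="cite A")
b = DatasetInfo(description="part B", citation="cite B")

merged = DatasetInfo.from_merge([a, b])
print(merged.description)  # "part A\n\npart B" -- descriptions joined with blank lines
print(a.description)       # "part A" -- the source infos are left untouched
```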
19 changes: 17 additions & 2 deletions tests/test_arrow_dataset.py
@@ -532,7 +532,7 @@ def func(x, i):
del dset, dset_test

def test_map_multiprocessing(self, in_memory):
with tempfile.TemporaryDirectory() as tmp_dir:
with tempfile.TemporaryDirectory() as tmp_dir: # standard
dset = self._create_dummy_dataset(in_memory, tmp_dir)

self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
@@ -549,7 +549,7 @@ def test_map_multiprocessing(self, in_memory):
self.assertNotEqual(dset_test._fingerprint, fingerprint)
del dset, dset_test

with tempfile.TemporaryDirectory() as tmp_dir:
with tempfile.TemporaryDirectory() as tmp_dir: # with_indices
dset = self._create_dummy_dataset(in_memory, tmp_dir)
fingerprint = dset._fingerprint
dset_test = dset.map(picklable_map_function_with_indices, num_proc=3, with_indices=True)
@@ -564,6 +564,21 @@
self.assertNotEqual(dset_test._fingerprint, fingerprint)
del dset, dset_test

with tempfile.TemporaryDirectory() as tmp_dir: # lambda (requires multiprocess from pathos)
dset = self._create_dummy_dataset(in_memory, tmp_dir)
fingerprint = dset._fingerprint
dset_test = dset.map(lambda x: {"id": int(x["filename"].split("_")[-1])}, num_proc=2)
self.assertEqual(len(dset_test), 30)
self.assertDictEqual(dset.features, Features({"filename": Value("string")}))
self.assertDictEqual(
dset_test.features,
Features({"filename": Value("string"), "id": Value("int64")}),
)
self.assertEqual(len(dset_test._data_files), 0 if in_memory else 2)
self.assertListEqual(dset_test["id"], list(range(30)))
self.assertNotEqual(dset_test._fingerprint, fingerprint)
del dset, dset_test

def test_new_features(self, in_memory):
with tempfile.TemporaryDirectory() as tmp_dir:
dset = self._create_dummy_dataset(in_memory, tmp_dir)
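The new test above covers what end users gain from this change: `Dataset.map` with a lambda and `num_proc > 1`, which the stdlib pickler used to reject. A hedged sketch outside the test harness (filenames are illustrative):

```python
from datasets import Dataset

dset = Dataset.from_dict({"filename": [f"file_{i}" for i in range(30)]})
dset = dset.map(lambda x: {"id": int(x["filename"].split("_")[-1])}, num_proc=2)
print(dset["id"][:5])  # [0, 1, 2, 3, 4]
```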