Skip to content

Commit

Permalink
Fix prefetch num_workers < 0
Browse files Browse the repository at this point in the history
  • Loading branch information
nlgranger committed Nov 23, 2023
1 parent e2df6c4 commit 3fc440f
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions seqtools/evaluation.py
Expand Up @@ -43,7 +43,7 @@ class ProcessBacked(AsyncWorker):

def __init__(self, seq, num_workers=0, buffer_size=10, init_fn=None, shm_size=0):
if num_workers <= 0:
num_workers = multiprocessing.cpu_count() - num_workers
num_workers = len(os.sched_getaffinity(0)) + num_workers
if num_workers <= 0:
raise ValueError("at least one worker required")
if buffer_size < num_workers:
Expand Down Expand Up @@ -274,7 +274,7 @@ class ThreadBackend(AsyncWorker):

def __init__(self, seq, num_workers=0, init_fn=None):
if num_workers <= 0:
num_workers = multiprocessing.cpu_count() - num_workers
num_workers = len(os.sched_getaffinity(0)) + num_workers
if num_workers <= 0:
raise ValueError("at least one worker required")

Expand Down
2 changes: 1 addition & 1 deletion tests/test_evaluation.py
Expand Up @@ -90,7 +90,7 @@ def compare_random_objects(a, b):
@pytest.mark.parametrize("prefetch_kwargs", prefetch_kwargs_set)
def test_prefetch_random_objects(prefetch_kwargs):
seq = [build_random_object() for _ in range(1000)]
y = prefetch(seq, -1, **prefetch_kwargs)
y = prefetch(seq, -1, max_buffered=len(os.sched_getaffinity(0)), **prefetch_kwargs)

assert len(seq) == len(y)

Expand Down

0 comments on commit 3fc440f

Please sign in to comment.