-
-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Description
import multiprocessing as mp
import sys
from concurrent.futures import as_completed, ProcessPoolExecutor, ThreadPoolExecutor
from typing import (
    Callable,
    Dict,
    Generator,
    List,
    Optional,
)
def run_in_thread_pool(
    func: Callable,
    params: Optional[List[Dict]] = None,
) -> Generator:
    """
    Run tasks in batch on a thread pool and yield their results as a generator.

    Every operation performed by a task must be thread-safe, and ``func`` is
    called with keyword arguments only: each dict in ``params`` is unpacked
    as ``func(**kwargs)``. Results are yielded in completion order, not in
    submission order.

    Example::

        def task(seq, text):
            return (seq, self._embedding_func(text, engine=self.deployment))

        params = [{"seq": i, "text": text} for i, text in enumerate(texts)]
        result = list(run_in_thread_pool(func=task, params=params))

    Args:
        func: Callable invoked once per entry in ``params``.
        params: Keyword-argument dicts, one per task. ``None`` (the default)
            means no tasks are submitted.

    Yields:
        The return value of each task as soon as that task finishes.

    Raises:
        Exception: wrapping the first failing task's exception (the original
            is kept as ``__cause__``). Leaving the ``with`` block shuts the
            executor down with ``wait=True``, so already-submitted tasks are
            still waited for.
    """
    if params is None:
        # A fresh list per call instead of a shared mutable default.
        params = []
    with ThreadPoolExecutor() as pool:
        tasks = [pool.submit(func, **kwargs) for kwargs in params]
        for future in as_completed(tasks):
            # Keep the try narrow so exceptions thrown into the generator by
            # the consumer (gen.throw) are not mislabelled as task failures.
            try:
                result = future.result()
            except Exception as e:
                raise Exception("error in sub thread: {}".format(e)) from e
            yield result
def run_in_process_pool(
    func: Callable,
    params: Optional[List[Dict]] = None,
) -> Generator:
    """
    Run tasks in batch on a process pool and yield their results as a generator.

    ``func`` is called with keyword arguments only: each dict in ``params``
    is unpacked as ``func(**kwargs)``. Because the work runs in separate
    processes, ``func``, its arguments and its return value must all be
    picklable (e.g. ``func`` should be a module-level function, not a lambda
    or closure). Results are yielded in completion order, not in submission
    order.

    Args:
        func: Picklable callable invoked once per entry in ``params``.
        params: Keyword-argument dicts, one per task. ``None`` (the default)
            means no tasks are submitted.

    Yields:
        The return value of each task as soon as that task finishes.

    Raises:
        Exception: wrapping the first failing task's exception (the original
            is kept as ``__cause__``). Leaving the ``with`` block shuts the
            executor down with ``wait=True``, so already-submitted tasks are
            still waited for.
    """
    if params is None:
        # A fresh list per call instead of a shared mutable default.
        params = []
    max_workers = None
    if sys.platform.startswith("win"):
        # ProcessPoolExecutor raises ValueError for max_workers > 61 on
        # Windows, so cap the worker count below that limit.
        max_workers = min(mp.cpu_count(), 60)
    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        tasks = [pool.submit(func, **kwargs) for kwargs in params]
        for future in as_completed(tasks):
            # Keep the try narrow so exceptions thrown into the generator by
            # the consumer (gen.throw) are not mislabelled as task failures.
            try:
                result = future.result()
            except Exception as e:
                raise Exception("error in sub process: {}".format(e)) from e
            yield result