Skip to content

Commit

Permalink
Joblib parallel valid input for convlutions.
Browse files Browse the repository at this point in the history
  • Loading branch information
jason-neal committed Sep 28, 2018
1 parent 0e7c50b commit 741e59f
Showing 1 changed file with 23 additions and 10 deletions.
33 changes: 23 additions & 10 deletions eniric/broaden.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def rotational_convolution(
*,
epsilon: float = 0.6,
normalize: bool = True,
num_procs: Optional[Union[int, pool.Pool]] = None,
num_procs: Optional[Union[int, pool.Pool, joblib.parallel.Parallel]] = None,
verbose: bool = True,
) -> ndarray:
"""Perform Rotational convolution.
Expand All @@ -59,7 +59,7 @@ def rotational_convolution(
Number of processes to use with multiprocess.
If None it is assigned to 1 less then total number of cores.
If num_procs = 0 or 1, then multiprocess is not used.
Can also be a multiprocess.pool.Pool instance.
Can also be a Joblib.parallel.Parallel or multiprocess.pool.Pool instance.
verbose: bool
Show the tqdm progress bar (default = True).
Expand Down Expand Up @@ -130,13 +130,19 @@ def element_rot_convolution(single_wav: float) -> float:
convolved_flux = np.empty_like(wavelength) # Memory assignment
for ii, single_wav in enumerate(tqdm_wav):
convolved_flux[ii] = element_rot_convolution(single_wav)
elif isinstance(num_procs, joblib.parallel.Parallel):
convolved_flux = np.array(
num_procs(delayed(element_rot_convolution)(wav) for wav in tqdm_wav)
)
else:
try:
# Assume num_procs was a multiprocess.pool.Pool
convolved_flux = np.array(num_procs.map(element_rot_convolution, tqdm_wav, chunksize=1023))
convolved_flux = np.array(
num_procs.map(element_rot_convolution, tqdm_wav, chunksize=1023)
)
except AttributeError:
raise TypeError(
"num_proc must be an int or a multiprocess Pool. Not '{}'".format(
"num_proc must be an int, joblib Parallel or a multiprocess Pool. Not '{}'".format(
type(num_procs)
)
)
Expand All @@ -153,7 +159,7 @@ def resolution_convolution(
*,
fwhm_lim: float = 5.0,
normalize: bool = True,
num_procs: Optional[Union[int, pool.Pool]] = None,
num_procs: Optional[Union[int, pool.Pool, joblib.parallel.Parallel]] = None,
verbose: bool = True,
) -> ndarray:
"""Perform Resolution convolution.
Expand All @@ -176,7 +182,7 @@ def resolution_convolution(
Number of processes to use with multiprocess.
If None it is assigned to 1 less then total number of cores.
If num_procs = 0 or 1, then multiprocess is not used.
Can also be a multiprocess.pool.Pool instance.
Can also be a joblib.parallel.Parallel or multiprocess.pool.Pool instance.
verbose: bool
Show the tqdm progress bar (default = True).
Expand Down Expand Up @@ -240,13 +246,20 @@ def element_res_convolution(single_wav: float) -> float:
convolved_flux = np.empty_like(wavelength) # Memory assignment
for jj, single_wav in enumerate(tqdm_wav):
convolved_flux[jj] = element_res_convolution(single_wav)

elif isinstance(num_procs, joblib.parallel.Parallel):
convolved_flux = np.array(
num_procs(delayed(element_res_convolution)(wav) for wav in tqdm_wav)
)
else:
# Assume num_procs was a multiprocess.pool.Pool
try:
convolved_flux = np.array(num_procs.map(element_res_convolution, tqdm_wav, chunksize=1023))
convolved_flux = np.array(
num_procs.map(element_res_convolution, tqdm_wav, chunksize=1023)
)
except AttributeError:
raise TypeError(
"num_proc must be an int or a multiprocess Pool. Not '{}'".format(
"num_proc must be an int or a multiprocess Pool or joblib.Parallel. Not '{}'".format(
type(num_procs)
)
)
Expand All @@ -263,7 +276,7 @@ def convolution(
*,
epsilon: float = 0.6,
fwhm_lim: float = 5.0,
num_procs: Optional[Union[int, pool.Pool]] = None,
num_procs: Optional[Union[int, pool.Pool, joblib.parallel.Parallel]] = None,
normalize: bool = True,
verbose: bool = True,
):
Expand Down Expand Up @@ -293,7 +306,7 @@ def convolution(
Number of processes to use with multiprocess.
If None it is assigned to 1 less then total number of cores.
If num_procs = 0 or 1, then multiprocess is not used.
Can also be a multiprocess.pool.Pool instance.
Can also be a joblib.parallel.Parallel or multiprocess.pool.Pool instance.
verbose: bool
Show the twdm progress bars (default = True).
Expand Down

0 comments on commit 741e59f

Please sign in to comment.