Skip to content

Commit

Permalink
Merge pull request #11 from lsst/u/bvan/DM-27377_mp_context
Browse files Browse the repository at this point in the history
Use multiprocessing context with fork on all systems
  • Loading branch information
erykoff committed Dec 7, 2020
2 parents cef539a + 739ec38 commit 793f82f
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 25 deletions.
4 changes: 2 additions & 2 deletions fgcm/fgcmBrightObs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from .fgcmChisq import FgcmChisq

import multiprocessing
from multiprocessing import Pool

from .sharedNumpyMemManager import SharedNumpyMemManager as snmm

Expand Down Expand Up @@ -143,7 +142,8 @@ def brightestObsMeanMag(self,debug=False,computeSEDSlopes=False):
(nSections,time.time() - prepStartTime))

# make a pool
pool = Pool(processes=self.nCore)
mp_ctx = multiprocessing.get_context("fork")
pool = mp_ctx.Pool(processes=self.nCore)
pool.map(self._worker,workerList,chunksize=1)
pool.close()
pool.join()
Expand Down
6 changes: 3 additions & 3 deletions fgcm/fgcmChisq.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from .fgcmNumbaUtilities import numba_test, add_at_1d, add_at_2d, add_at_3d

import multiprocessing
from multiprocessing import Pool

from .sharedNumpyMemManager import SharedNumpyMemManager as snmm

Expand Down Expand Up @@ -325,8 +324,9 @@ def __call__(self,fitParams,fitterUnits=False,computeDerivatives=False,computeSE
else:
# regular multi-core

mp_ctx = multiprocessing.get_context('fork')
# make a dummy process to discover starting child number
proc = multiprocessing.Process()
proc = mp_ctx.Process()
workerIndex = proc._identity[0]+1
proc = None

Expand Down Expand Up @@ -366,7 +366,7 @@ def __call__(self,fitParams,fitterUnits=False,computeDerivatives=False,computeSE
self.fgcmLog.debug('Running chisq on %d cores' % (self.nCore))

# make a pool
pool = Pool(processes=self.nCore)
pool = mp_ctx.Pool(processes=self.nCore)
# Compute magnitudes
pool.map(self._magWorker, workerList, chunksize=1)

Expand Down
6 changes: 3 additions & 3 deletions fgcm/fgcmComputeStepUnits.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from .fgcmNumbaUtilities import numba_test, add_at_1d, add_at_2d, add_at_3d

import multiprocessing
from multiprocessing import Pool

from .sharedNumpyMemManager import SharedNumpyMemManager as snmm

Expand Down Expand Up @@ -112,7 +111,8 @@ def run(self, fitParams):
# going to have one or the other, and this doesn't care which is which
self.nSums += 2 * self.fgcmPars.nFitPars

proc = multiprocessing.Process()
mp_ctx = multiprocessing.get_context("fork")
proc = mp_ctx.Process()
workerIndex = proc._identity[0]+1
proc = None

Expand Down Expand Up @@ -141,7 +141,7 @@ def run(self, fitParams):
workerList.sort(key=lambda elt:elt[1].size, reverse=True)

# make a pool
pool = Pool(processes=self.nCore)
pool = mp_ctx.Pool(processes=self.nCore)
# Compute magnitudes
pool.map(self._stepWorker, workerList, chunksize=1)

Expand Down
10 changes: 0 additions & 10 deletions fgcm/fgcmFitCycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,6 @@

from .sharedNumpyMemManager import SharedNumpyMemManager as snmm

import multiprocessing
import platform


# Fix for multiprocessing/matplotlib but on python <= 3.7 and macos >= 10.15
if len(platform.mac_ver()[0]) > 0 and sys.version_info.major == 3 and sys.version_info.minor < 8:
parts = platform.mac_ver()[0].split('.')
if (int(parts[0]) > 10) or (int(parts[0]) == 10 and int(parts[0]) >= 15):
multiprocessing.set_start_method('forkserver')


class FgcmFitCycle(object):
"""
Expand Down
4 changes: 2 additions & 2 deletions fgcm/fgcmRetrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import matplotlib.pyplot as plt

import multiprocessing
from multiprocessing import Pool

from .sharedNumpyMemManager import SharedNumpyMemManager as snmm

Expand Down Expand Up @@ -149,7 +148,8 @@ def computeRetrievalIntegrals(self,debug=False):
uExpIndexList = np.array_split(uExpIndex,nSections)

# may want to sort by nObservations, but only if we pre-split
pool = Pool(processes=self.nCore)
mp_ctx = multiprocessing.get_context("fork")
pool = mp_ctx.Pool(processes=self.nCore)
pool.map(self._worker, uExpIndexList, chunksize=1)
pool.close()
pool.join()
Expand Down
4 changes: 2 additions & 2 deletions fgcm/fgcmSigmaCal.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from .fgcmUtilities import retrievalFlagDict

import multiprocessing
from multiprocessing import Pool

from .sharedNumpyMemManager import SharedNumpyMemManager as snmm

Expand Down Expand Up @@ -183,7 +182,8 @@ def run(self, applyGray=True):
for i, s in enumerate(sigmaCals):
self.sigmaCal = s

pool = Pool(processes=self.nCore)
mp_ctx = multiprocessing.get_context("fork")
pool = mp_ctx.Pool(processes=self.nCore)
pool.map(self._worker, workerList, chunksize=1)
pool.close()
pool.join()
Expand Down
6 changes: 3 additions & 3 deletions fgcm/fgcmZpsToApply.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from .fgcmNumbaUtilities import numba_test, add_at_1d, add_at_2d, add_at_3d

import multiprocessing
from multiprocessing import Pool

from .sharedNumpyMemManager import SharedNumpyMemManager as snmm

Expand Down Expand Up @@ -138,7 +137,8 @@ def applyZeropoints(self):

goodStarsSub, goodObs = self.fgcmStars.getGoodObsIndices(goodStars, expFlag=self.fgcmPars.expFlag)

proc = multiprocessing.Process()
mp_ctx = multiprocessing.get_context('fork')
proc = mp_ctx.Process()
workerIndex = proc._identity[0] + 1
proc = None

Expand All @@ -157,7 +157,7 @@ def applyZeropoints(self):

workerList.sort(key=lambda elt:elt[1].size, reverse=True)

pool = Pool(processes=self.nCore)
pool = mp_ctx.Pool(processes=self.nCore)
pool.map(self._worker, workerList, chunksize=1)

pool.close()
Expand Down

0 comments on commit 793f82f

Please sign in to comment.