@@ -122,7 +122,7 @@ def EVAL(exprs, *args):
122122 return processed [0 ] if isinstance (exprs , str ) else processed
123123
124124
125- def parallel (item ):
125+ def parallel (item , m ):
126126 """
127127 Run a test in parallel. Readapted from:
128128
@@ -131,47 +131,44 @@ def parallel(item):
131131 mpi_exec = 'mpiexec'
132132 mpi_distro = sniff_mpi_distro (mpi_exec )
133133
134- marker = item .get_closest_marker ("parallel" )
135- mode = as_tuple (marker .kwargs .get ("mode" , 2 ))
136- for m in mode :
137- # Parse the `mode`
138- if isinstance (m , int ):
139- nprocs = m
140- scheme = 'basic'
141- else :
142- if len (m ) == 2 :
143- nprocs , scheme = m
144- else :
145- raise ValueError ("Can't run test: unexpected mode `%s`" % m )
146-
147- pyversion = sys .executable
148- # Only spew tracebacks on rank 0.
149- # Run xfailing tests to ensure that errors are reported to calling process
150- if item .cls is not None :
151- testname = "%s::%s::%s" % (item .fspath , item .cls .__name__ , item .name )
152- else :
153- testname = "%s::%s" % (item .fspath , item .name )
154- args = ["-n" , "1" , pyversion , "-m" , "pytest" , "--runxfail" , "-s" ,
155- "-q" , testname ]
156- if nprocs > 1 :
157- args .extend ([":" , "-n" , "%d" % (nprocs - 1 ), pyversion , "-m" , "pytest" ,
158- "--runxfail" , "--tb=no" , "-q" , testname ])
159- # OpenMPI requires an explicit flag for oversubscription. We need it as some
160- # of the MPI tests will spawn lots of processes
161- if mpi_distro == 'OpenMPI' :
162- call = [mpi_exec , '--oversubscribe' , '--timeout' , '300' ] + args
134+ # Parse the `mode`
135+ if isinstance (m , int ):
136+ nprocs = m
137+ scheme = 'basic'
138+ else :
139+ if len (m ) == 2 :
140+ nprocs , scheme = m
163141 else :
164- call = [ mpi_exec ] + args
142+ raise ValueError ( "Can't run test: unexpected mode `%s`" % m )
165143
166- # Tell the MPI ranks that they are running a parallel test
167- os .environ ['DEVITO_MPI' ] = scheme
168- try :
169- check_call (call )
170- return True
171- except :
172- return False
173- finally :
174- os .environ ['DEVITO_MPI' ] = '0'
144+ pyversion = sys .executable
145+ # Only spew tracebacks on rank 0.
146+ # Run xfailing tests to ensure that errors are reported to calling process
147+ if item .cls is not None :
148+ testname = "%s::%s::%s" % (item .fspath , item .cls .__name__ , item .name )
149+ else :
150+ testname = "%s::%s" % (item .fspath , item .name )
151+ args = ["-n" , "1" , pyversion , "-m" , "pytest" , "--runxfail" , "-q" , testname ]
152+ if nprocs > 1 :
153+ args .extend ([":" , "-n" , "%d" % (nprocs - 1 ), pyversion , "-m" , "pytest" ,
154+ "--runxfail" , "--tb=no" , "-q" , testname ])
155+ # OpenMPI requires an explicit flag for oversubscription. We need it as some
156+ # of the MPI tests will spawn lots of processes
157+ if mpi_distro == 'OpenMPI' :
158+ call = [mpi_exec , '--oversubscribe' , '--timeout' , '300' ] + args
159+ else :
160+ call = [mpi_exec ] + args
161+
162+ # Tell the MPI ranks that they are running a parallel test
163+ os .environ ['DEVITO_MPI' ] = scheme
164+ try :
165+ check_call (call )
166+ res = True
167+ except :
168+ res = False
169+ finally :
170+ os .environ ['DEVITO_MPI' ] = '0'
171+ return res
175172
176173
177174def pytest_configure (config ):
@@ -182,55 +179,45 @@ def pytest_configure(config):
182179 )
183180
184181
185- def pytest_runtest_setup (item ):
186- partest = os .environ .get ('DEVITO_MPI' , 0 )
187- try :
188- partest = int (partest )
189- except ValueError :
190- pass
191- if item .get_closest_marker ("parallel" ):
192- if MPI is None :
193- pytest .skip ("mpi4py/MPI not installed" )
194- else :
195- # Blow away function arg in "master" process, to ensure
196- # this test isn't run on only one process
197- dummy_test = lambda * args , ** kwargs : True
198- # For pytest <7
199- if item .cls is not None :
200- attr = item .originalname or item .name
201- setattr (item .cls , attr , dummy_test )
202- else :
203- item .obj = dummy_test
204- # For pytest >= 7
205- setattr (item , '_obj' , dummy_test )
182+ def pytest_generate_tests (metafunc ):
183+ # Process custom parallel marker as a parametrize to avoid
184+ # running a single test for all modes
185+ if 'mode' in metafunc .fixturenames :
186+ markers = metafunc .definition .iter_markers ()
187+ for marker in markers :
188+ if marker .name == 'parallel' :
189+ mode = list (as_tuple (marker .kwargs .get ('mode' , 2 )))
190+ metafunc .parametrize ("mode" , mode )
206191
207192
193+ @pytest .hookimpl (tryfirst = True , hookwrapper = True )
208194def pytest_runtest_call (item ):
209195 partest = os .environ .get ('DEVITO_MPI' , 0 )
210196 try :
211197 partest = int (partest )
212198 except ValueError :
213199 pass
200+
214201 if item .get_closest_marker ("parallel" ) and not partest :
215202 # Spawn parallel processes to run test
216- passed = parallel (item )
217- if not passed :
218- pytest .fail (f"{ item } failed in parallel execution " )
203+ outcome = parallel (item , item . funcargs [ 'mode' ] )
204+ if outcome :
205+ pytest .skip (f"{ item } success in parallel" )
219206 else :
220- pytest .skip (f"{ item } t passed in parallel execution" )
207+ pytest .fail (f"{ item } failed in parallel" )
208+ else :
209+ outcome = yield
221210
222211
223212@pytest .hookimpl (tryfirst = True , hookwrapper = True )
224213def pytest_runtest_makereport (item , call ):
225214 outcome = yield
226215 result = outcome .get_result ()
227-
228216 partest = os .environ .get ('DEVITO_MPI' , 0 )
229217 try :
230218 partest = int (partest )
231219 except ValueError :
232220 pass
233-
234221 if item .get_closest_marker ("parallel" ) and not partest :
235222 if call .when == 'call' and result .outcome == 'skipped' :
236223 result .outcome = 'passed'
0 commit comments