diff --git a/easybuild/tools/toolchain/mpi.py b/easybuild/tools/toolchain/mpi.py index 2f501d2fec..35ecc28188 100644 --- a/easybuild/tools/toolchain/mpi.py +++ b/easybuild/tools/toolchain/mpi.py @@ -48,7 +48,7 @@ _log = fancylogger.getLogger('tools.toolchain.mpi', fname=False) -def get_mpi_cmd_template(mpi_family, params, mpi_version=None): +def get_mpi_cmd_template(mpi_family, params, mpi_version=None, oversubscribe=False): """ Return template for MPI command, for specified MPI family. @@ -123,6 +123,50 @@ def get_mpi_cmd_template(mpi_family, params, mpi_version=None): else: raise EasyBuildError("Don't know which template MPI command to use for MPI family '%s'", mpi_family) + if oversubscribe: + osub_cmd = '' + if mpi_family in [toolchain.OPENMPI]: + if mpi_version is None: + raise EasyBuildError("OpenMPI version unknown, can't determine how to handle oversubscription!") + if LooseVersion(mpi_version) < '5': + varname = 'OMPI_MCA_rmaps_base_oversubscribe' + varvalue = os.getenv(varname) + if varvalue and varvalue != '1': + _log.warning("Overwriting existing %s=%s with %s=1", varname, varvalue, varname) + osub_cmd = f'{varname}=1' + else: + varname = 'PRTE_MCA_rmaps_default_mapping_policy' + varvalue = os.getenv(varname) + + # This logic should account for: + # - var not set -> set to 'core:oversubscribe' + # - unit set to value without `:` eg package -> 'package:oversubscribe' + # - unit set to value with `:` eg ppr:4:numa -> 'ppr:4:numa:oversubscribe' + # - all of the above but with oversubscribe already in flags + flags = '' + if varvalue.startswith('ppr'): + _log.warning("Can't handle ppr mapping with oversubscription yet, overwriting unit with 'core'") + unit = 'core' + flags = 'oversubscribe' + else: + unit, flags = (varvalue.rsplit(':', maxsplit=1) + [''])[:2] + unit = unit or 'core' + flags = list(filter(None, flags.split(','))) + if 'oversubscribe' not in flags: + flags.append('oversubscribe') + newvalue = f"{unit}:{','.join(flags)}" + + osub_cmd = f'{varname}={newvalue}' + elif mpi_family in [toolchain.INTELMPI]: + _log.info("INTELMPI always oversubscribe by default, nothing to do...") + elif mpi_family in [toolchain.MVAPICH2, toolchain.MPICH, toolchain.MPICH2]: + _log.info("MPICH always oversubscribe by default, nothing to do...") + else: + raise EasyBuildError("Oversubscribe not supported for MPI family '%s'", mpi_family) + + mpi_cmd_template = f'%(oversubscribe)s {mpi_cmd_template}' + params.update({'oversubscribe': osub_cmd}) # just a placeholder + missing = [] for key in sorted(params.keys()): tmpl = '%(' + key + ')s' @@ -270,7 +314,7 @@ def mpi_cmd_prefix(self, nr_ranks=1): return result - def mpi_cmd_for(self, cmd, nr_ranks): + def mpi_cmd_for(self, cmd, nr_ranks, oversubscribe=False): """Construct an MPI command for the given command and number of ranks.""" # parameter values for mpirun command @@ -281,20 +325,20 @@ def mpi_cmd_for(self, cmd, nr_ranks): mpi_family = self.mpi_family() - mpi_version = None + # this fails when it's done too early (before modules for toolchain/dependencies are loaded), + # but it's safe to ignore this + mpi_version = self.get_software_version(self.MPI_MODULE_NAME, required=False)[0] if mpi_family == toolchain.INTELMPI: - # for Intel MPI, try to determine impi version - # this fails when it's done too early (before modules for toolchain/dependencies are loaded), - # but it's safe to ignore this - mpi_version = self.get_software_version(self.MPI_MODULE_NAME, required=False)[0] if not mpi_version: self.log.debug("Ignoring error when trying to determine %s version", self.MPI_MODULE_NAME) # impi version is required to determine correct MPI command template, # so we have to return early if we couldn't determine the impi version... return None - mpi_cmd_template, params = get_mpi_cmd_template(mpi_family, params, mpi_version=mpi_version) + mpi_cmd_template, params = get_mpi_cmd_template( + mpi_family, params, mpi_version=mpi_version, oversubscribe=oversubscribe + ) self.log.info("Using MPI command template '%s' (params: %s)", mpi_cmd_template, params) try: