diff --git a/ommprotocol/md.py b/ommprotocol/md.py index 519b2f4..b1769d2 100644 --- a/ommprotocol/md.py +++ b/ommprotocol/md.py @@ -31,7 +31,10 @@ from .utils import (random_string, assert_not_exists, timed_input, available_platforms, warned_getattr) + logger = logging.getLogger(__name__) +OPENMM_VERSION = tuple(map(int, mm.__version__.split('.'))) + ########################### # Defaults @@ -219,7 +222,7 @@ def __init__(self, handler, positions=None, velocities=None, box=None, self.barostat_interval = int(barostat_interval) # Hardware self._platform = platform - self.platform_properties = platform_properties + self.platform_properties = {} if platform_properties is None else platform_properties # Output parameters self.project_name = project_name if project_name is not None else self._PROJECTNAME self.name = name if name is not None else random_string(length=5) @@ -421,8 +424,9 @@ def platform(self): if self._platform is None: return None, platform = mm.Platform.getPlatformByName(self._platform) - if self.platform_properties is None: - return platform, + if self._platform.upper() == 'CUDA' and OPENMM_VERSION < (7, 2, 3) \ + and 'DisablePmeStream' not in self.platform_properties: + self.platform_properties['DisablePmeStream'] = 'true' # Patch to allow env-defined GPUs device = self.platform_properties.get('DeviceIndex', '')