Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions paths_cli/compiling/core.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import json
import yaml

from collections import namedtuple

import logging

from .errors import InputError
from paths_cli.utils import import_thing
from paths_cli.plugin_management import OPSPlugin
from paths_cli.compiling.errors import InputError


def listify(obj):
listified = False
Expand All @@ -16,16 +13,20 @@ def listify(obj):
listified = True
return obj, listified


def unlistify(obj, listified):
if listified:
assert len(obj) == 1
obj = obj[0]
return obj


REQUIRED_PARAMETER = object()


class Parameter:
SCHEMA = "http://openpathsampling.org/schemas/sim-setup/draft01.json"

def __init__(self, name, loader, *, json_type=None, description=None,
default=REQUIRED_PARAMETER, aliases=None):
if isinstance(json_type, str):
Expand Down Expand Up @@ -150,6 +151,7 @@ class InstanceCompilerPlugin(OPSPlugin):
"""
SCHEMA = "http://openpathsampling.org/schemas/sim-setup/draft01.json"
category = None

def __init__(self, builder, parameters, name=None, aliases=None,
requires_ops=(1, 0), requires_cli=(0, 3)):
super().__init__(requires_ops, requires_cli)
Expand Down Expand Up @@ -239,7 +241,7 @@ def __call__(self, dct):
ops_dct = self.compile_attrs(dct)
self.logger.debug("Building...")
self.logger.debug(ops_dct)
obj = self.builder(**ops_dct)
obj = self.builder(**ops_dct)
self.logger.debug(obj)
return obj

Expand All @@ -260,7 +262,7 @@ def _compile_str(self, name):
self.logger.debug(f"Looking for '{name}'")
try:
return self.named_objs[name]
except KeyError as e:
except KeyError:
raise InputError.unknown_name(self.label, name)

def _compile_dict(self, dct):
Expand Down
12 changes: 5 additions & 7 deletions paths_cli/compiling/cvs.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import os
import importlib

from .core import Parameter, Builder
from .tools import custom_eval
from .topology import build_topology
from .errors import InputError
from paths_cli.compiling.core import Parameter, Builder
from paths_cli.compiling.tools import custom_eval
from paths_cli.compiling.topology import build_topology
from paths_cli.compiling.errors import InputError
from paths_cli.utils import import_thing
from paths_cli.compiling.plugins import CVCompilerPlugin, CategoryPlugin

Expand All @@ -21,6 +18,7 @@ def __call__(self, source):
# on ImportError, we leave the error unchanged
return func


def _cv_kwargs_remapper(dct):
kwargs = dct.pop('kwargs', {})
dct.update(kwargs)
Expand Down
8 changes: 5 additions & 3 deletions paths_cli/compiling/engines.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .topology import build_topology
from .core import Builder
from paths_cli.compiling.topology import build_topology
from paths_cli.compiling.core import Builder
from paths_cli.compiling.core import Parameter
from .tools import custom_eval_int_strict_pos
from paths_cli.compiling.tools import custom_eval_int_strict_pos
from paths_cli.compiling.plugins import EngineCompilerPlugin, CategoryPlugin


def load_openmm_xml(filename):
from paths_cli.compat.openmm import HAS_OPENMM, mm
if not HAS_OPENMM: # -no-cov-
Expand All @@ -14,6 +15,7 @@ def load_openmm_xml(filename):

return obj


def _openmm_options(dct):
n_steps_per_frame = dct.pop('n_steps_per_frame')
n_frames_max = dct.pop('n_frames_max')
Expand Down
1 change: 0 additions & 1 deletion paths_cli/compiling/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,3 @@ def invalid_input(cls, value, attr):
@classmethod
def unknown_name(cls, type_name, name):
return cls(f"Unable to find object named {name} in {type_name}")

8 changes: 6 additions & 2 deletions paths_cli/compiling/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@
parameters=[
Parameter('cv', compiler_for('cv'), description="the collective "
"variable for this interface set"),
Parameter('minvals', custom_eval), # TODO fill in JSON types
Parameter('maxvals', custom_eval), # TODO fill in JSON types
Parameter('minvals', custom_eval), # TODO fill in JSON types
Parameter('maxvals', custom_eval), # TODO fill in JSON types
],
name='interface-set'
)


def mistis_trans_info(dct):
dct = dct.copy()
transitions = dct.pop('transitions')
Expand All @@ -31,6 +32,7 @@ def mistis_trans_info(dct):
dct['trans_info'] = trans_info
return dct


def tis_trans_info(dct):
# remap TIS into MISTIS format
dct = dct.copy()
Expand All @@ -42,6 +44,7 @@ def tis_trans_info(dct):
'interfaces': interface_set}]
return mistis_trans_info(dct)


TPS_NETWORK_PLUGIN = NetworkCompilerPlugin(
builder=Builder('openpathsampling.TPSNetwork'),
parameters=[
Expand All @@ -60,6 +63,7 @@ def tis_trans_info(dct):
name='mistis'
)


TIS_NETWORK_PLUGIN = NetworkCompilerPlugin(
builder=Builder('openpathsampling.MISTISNetwork'),
parameters=[Parameter('trans_info', tis_trans_info)],
Expand Down
8 changes: 7 additions & 1 deletion paths_cli/compiling/plugins.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from paths_cli.compiling.core import InstanceCompilerPlugin
from paths_cli.plugin_management import OPSPlugin


class CategoryPlugin(OPSPlugin):
"""
Category plugins only need to be made for top-level
"""
def __init__(self, plugin_class, aliases=None, requires_ops=(1, 0),
requires_cli=(0,4)):
requires_cli=(0, 3)):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems like an actual bugfix? should this be mentioned somewhere?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd argue that it isn't a bugfix on the basis that it isn't in released code. However, I edited the main post of this PR to make mention of the change.

super().__init__(requires_ops, requires_cli)
self.plugin_class = plugin_class
if aliases is None:
Expand All @@ -25,17 +26,22 @@ def __repr__(self):
class EngineCompilerPlugin(InstanceCompilerPlugin):
category = 'engine'


class CVCompilerPlugin(InstanceCompilerPlugin):
category = 'cv'


class VolumeCompilerPlugin(InstanceCompilerPlugin):
category = 'volume'


class NetworkCompilerPlugin(InstanceCompilerPlugin):
category = 'network'


class SchemeCompilerPlugin(InstanceCompilerPlugin):
category = 'scheme'


class StrategyCompilerPlugin(InstanceCompilerPlugin):
category = 'strategy'
13 changes: 12 additions & 1 deletion paths_cli/compiling/root_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import logging
logger = logging.getLogger(__name__)


class CategoryCompilerRegistrationError(Exception):
pass

Expand All @@ -22,6 +23,7 @@ class CategoryCompilerRegistrationError(Exception):

COMPILE_ORDER = _DEFAULT_COMPILE_ORDER.copy()


def clean_input_key(key):
# TODO: move this to core
"""
Expand All @@ -34,12 +36,14 @@ def clean_input_key(key):
key = key.replace("-", "_")
return key


### Managing known compilers and aliases to the known compilers ############

_COMPILERS = {} # mapping: {canonical_name: CategoryCompiler}
_ALIASES = {} # mapping: {alias: canonical_name}
# NOTE: _ALIASES does *not* include self-mapping of the canonical names


def _canonical_name(alias):
"""Take an alias or a compiler name and return the compiler name

Expand All @@ -51,6 +55,7 @@ def _canonical_name(alias):
alias_to_canonical.update({pname: pname for pname in _COMPILERS})
return alias_to_canonical.get(alias, None)


def _get_compiler(category):
"""
_get_compiler must only be used after the CategoryCompilers have been
Expand All @@ -69,6 +74,7 @@ def _get_compiler(category):
_COMPILERS[category] = CategoryCompiler(None, category)
return _COMPILERS[canonical_name]


def _register_compiler_plugin(plugin):
DUPLICATE_ERROR = CategoryCompilerRegistrationError(
f"The category '{plugin.name}' has been reserved by another plugin"
Expand All @@ -87,7 +93,7 @@ def _register_compiler_plugin(plugin):


### Handling delayed loading of compilers ##################################
#

# Many objects need to use compilers to create their input parameters. In
# order for them to be able to access dynamically-loaded plugins, we delay
# the loading of the compiler by using a proxy object.
Expand All @@ -111,6 +117,7 @@ def named_objs(self):
def __call__(self, dct):
return self._proxy(dct)


def compiler_for(category):
"""Delayed compiler calling.

Expand Down Expand Up @@ -142,11 +149,13 @@ def _get_registration_names(plugin):
found_names.add(name)
return ordered_names


def _register_builder_plugin(plugin):
compiler = _get_compiler(plugin.category)
for name in _get_registration_names(plugin):
compiler.register_builder(plugin, name)


def register_plugins(plugins):
builders = []
compilers = []
Expand All @@ -162,6 +171,7 @@ def register_plugins(plugins):
for plugin in builders:
_register_builder_plugin(plugin)


### Performing the compiling of user input #################################

def _sort_user_categories(user_categories):
Expand All @@ -179,6 +189,7 @@ def _sort_user_categories(user_categories):
)
return sorted_keys


def do_compile(dct):
"""Main function for compiling user input to objects.
"""
Expand Down
1 change: 1 addition & 0 deletions paths_cli/compiling/schemes.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
name='spring-shooting',
)


class BuildSchemeStrategy:
def __init__(self, scheme_class, default_global_strategy):
self.scheme_class = scheme_class
Expand Down
4 changes: 4 additions & 0 deletions paths_cli/compiling/shooting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,21 @@
from paths_cli.compiling.root_compiler import compiler_for
from paths_cli.compiling.tools import custom_eval


build_uniform_selector = InstanceCompilerPlugin(
builder=Builder('openpathsampling.UniformSelector'),
parameters=[],
name='uniform',
)


def _remapping_gaussian_stddev(dct):
dct['alpha'] = 0.5 / dct.pop('stddev')**2
dct['collectivevariable'] = dct.pop('cv')
dct['l_0'] = dct.pop('mean')
return dct


build_gaussian_selector = InstanceCompilerPlugin(
builder=Builder('openpathsampling.GaussianBiasSelector',
remapper=_remapping_gaussian_stddev),
Expand All @@ -27,6 +30,7 @@ def _remapping_gaussian_stddev(dct):
name='gaussian',
)


shooting_selector_compiler = CategoryCompiler(
type_dispatch={
'uniform': build_uniform_selector,
Expand Down
5 changes: 4 additions & 1 deletion paths_cli/compiling/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,19 @@
)
from paths_cli.compiling.root_compiler import compiler_for


def _strategy_name(class_name):
return f"openpathsampling.strategies.{class_name}"


def _group_parameter(group_name):
return Parameter('group', str, default=group_name,
description="the group name for these movers")


# TODO: maybe this moves into shooting once we have the metadata?
SP_SELECTOR_PARAMETER = Parameter('selector', shooting_selector_compiler,
default=None)

ENGINE_PARAMETER = Parameter('engine', compiler_for('engine'),
description="the engine for moves of this "
"type")
Expand All @@ -39,6 +41,7 @@ def _group_parameter(group_name):
name='one-way-shooting',
)
build_one_way_shooting_strategy = ONE_WAY_SHOOTING_STRATEGY_PLUGIN

# build_two_way_shooting_strategy = StrategyCompilerPlugin(
# builder=Builder(_strategy_name("TwoWayShootingStrategy")),
# parameters = [
Expand Down
8 changes: 6 additions & 2 deletions paths_cli/compiling/tools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
from .errors import InputError
from paths_cli.compiling.errors import InputError


def custom_eval(obj, named_objs=None):
"""Parse user input to allow simple math.
Expand All @@ -19,10 +20,12 @@ def custom_eval(obj, named_objs=None):
}
return eval(string, namespace)


def custom_eval_int(obj, named_objs=None):
val = custom_eval(obj, named_objs)
return int(val)


def custom_eval_int_strict_pos(obj, named_objs=None):
val = custom_eval_int(obj, named_objs)
if val <= 0:
Expand All @@ -33,6 +36,7 @@ def custom_eval_int_strict_pos(obj, named_objs=None):
class UnknownAtomsError(RuntimeError):
pass


def mdtraj_parse_atomlist(inp_str, n_atoms, topology=None):
"""
n_atoms: int
Expand All @@ -47,7 +51,7 @@ def mdtraj_parse_atomlist(inp_str, n_atoms, topology=None):
raise TypeError("Input is not integers")
if arr.shape != (1, n_atoms):
# try to clean it up
if len(arr.shape) == 1 and arr.shape[0] == n_atoms:
if len(arr.shape) == 1 and arr.shape[0] == n_atoms:
arr.shape = (1, n_atoms)
else:
raise TypeError(f"Invalid input. Requires {n_atoms} "
Expand Down
Loading