Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make topoaa.molX and mol_ parameter expandable #350

Merged
merged 11 commits into from
Mar 4, 2022
2 changes: 2 additions & 0 deletions src/haddock/core/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,5 @@
valid_run_dir_chars = string.ascii_letters + string.digits + "._-/\\"

RUNDIR = "run_dir"

max_molecules_allowed = 20
67 changes: 67 additions & 0 deletions src/haddock/gear/expandable_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,18 @@
However, because `<name>_<integer>` is too much of a simple rule, we
need to define in this module which parameters are actualy expandable.
If you are developing here look for the `type_simplest_ep` variable.

4. Parameters that are expandable to the max number of molecules:

Those are the parameters starting with `mol`. For example:
`mol_fix_origin_1`, which refers to the `fix_origin` parameter for
molecule 1. These parameters are allowed to expand only to the maximum
of input molecules, and at most to the max number of molecules allowed.
"""
from copy import deepcopy
from functools import partial

from haddock.core.defaults import max_molecules_allowed
from haddock.core.exceptions import ConfigurationError


Expand Down Expand Up @@ -333,6 +341,65 @@ def read_simplest_expandable(expparams, config):
return new


def get_mol_parameters(config):
"""Identify expandable `mol` parameters."""
return set(param for param in config if is_mol_parameter(param))


def is_mol_parameter(param):
"""Identify if a parameter is a `mol` parameter."""
parts = param.split("_")
return param.startswith("mol_") \
and parts[-1].isdigit() \
and len(parts) > 2


def read_mol_parameters(
user_config,
default_groups,
max_mols=max_molecules_allowed,
):
"""
Read the mol parameters in the user_config following expectations.

Parameters
----------
user_config : dict
The user configuration dictionary.

default_groups : dict or set.
The mol parameters present in the default configuration file for
the specific module. These are defined by `get_mol_parameters`.

max_mols : int
HADDOCK3 has a limit in the number of different molecules it
accepts for a calculation. Expandable parameters affecting molecules
should not be allowed to go beyond that number. Defaults to
`core.default.max_molecules_allowed`.

Returns
-------
set
The allowed parameters according to the default config and the
max allowed molecules.
"""
# removes the integer suffix from the default mol parameters
default_names = [remove_trail_idx(p) for p in default_groups]

new = set()
for param in get_mol_parameters(user_config):
param_name = remove_trail_idx(param)
param_idx = param.split("_")[-1]
if param_name in default_names and int(param_idx) <= max_mols:
new.add(param)
return new


def remove_trail_idx(param):
"""Remove the trailing integer from a parameter."""
return "_".join(param.split("_")[:-1])


def make_param_name_single_index(param_parts):
"""
Make the key name from param parts.
Expand Down
136 changes: 115 additions & 21 deletions src/haddock/gear/prepare_run.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,28 @@
"""Logic pertraining to preparing the run files and folders."""
import importlib
import itertools as it
import shutil
import string
import sys
from contextlib import contextmanager, suppress
from copy import deepcopy
from functools import wraps
from functools import lru_cache, wraps
from pathlib import Path

from haddock import contact_us, haddock3_source_path, log
from haddock.core.defaults import RUNDIR
from haddock.core.defaults import RUNDIR, max_molecules_allowed
from haddock.core.exceptions import ConfigurationError, ModuleError
from haddock.gear.config_reader import get_module_name, read_config
from haddock.gear.expandable_parameters import (
get_mol_parameters,
get_multiple_index_groups,
get_single_index_groups,
is_mol_parameter,
read_mol_parameters,
read_multiple_idx_groups_user_config,
read_simplest_expandable,
read_single_idx_groups_user_config,
remove_trail_idx,
type_simplest_ep,
)
from haddock.gear.greetings import get_goodbye_help
Expand Down Expand Up @@ -55,6 +61,21 @@ def wrapper(*args, **kwargs):
return wrapper


@lru_cache
def _read_defaults(module_name):
"""Read the defaults.yaml given a module name."""
module_name_ = get_module_name(module_name)
pdef = Path(
haddock3_source_path,
'modules',
modules_category[module_name_],
module_name_,
'defaults.yaml',
).resolve()

return read_from_yaml_config(pdef)


def setup_run(workflow_path, restart_from=None):
"""
Set up HADDOCK3 run.
Expand Down Expand Up @@ -100,12 +121,18 @@ def setup_run(workflow_path, restart_from=None):

# copy molecules parameter to topology module
copy_molecules_to_topology(params)
if len(params["topoaa"]["molecules"]) > max_molecules_allowed:
raise ConfigurationError("Too many molecules defined, max is {max_molecules_allowed}.") # noqa: E501

# separate general from modules parameters
_modules_keys = identify_modules(params)
general_params = remove_dict_keys(params, _modules_keys)
modules_params = remove_dict_keys(params, list(general_params.keys()))

# populate topology molecules
populate_topology_molecule_params(modules_params["topoaa"])
populate_mol_parameters(modules_params)

# validations
validate_modules_params(modules_params)
check_if_modules_are_installed(modules_params)
Expand Down Expand Up @@ -170,27 +197,26 @@ def validate_modules_params(modules_params):
If there is any parameter given by the user that is not defined
in the defaults.cfg of the module.
"""
for module_name, args in modules_params.items():
module_name = get_module_name(module_name)
pdef = Path(
haddock3_source_path,
'modules',
modules_category[module_name],
module_name,
'defaults.yaml',
).resolve()
# needed definition before starting the loop
max_mols = len(modules_params["topoaa"]["molecules"])

defaults = read_from_yaml_config(pdef)
for module_name, args in modules_params.items():
defaults = _read_defaults(module_name)
if not defaults:
return

block_params = get_expandable_parameters(args, defaults, module_name)
expandable_params = get_expandable_parameters(
args,
defaults,
module_name,
max_mols,
)

diff = set(extract_keys_recursive(args)) \
- set(extract_keys_recursive(defaults)) \
- set(config_mandatory_general_parameters) \
- set(non_mandatory_general_parameters_defaults.keys()) \
- block_params
- expandable_params

if diff:
_msg = (
Expand Down Expand Up @@ -373,7 +399,7 @@ def check_specific_validations(params):
v_rundir(params[RUNDIR])


def get_expandable_parameters(user_config, defaults, module_name):
def get_expandable_parameters(user_config, defaults, module_name, max_mols):
"""
Get configuration expandable blocks.

Expand All @@ -384,35 +410,43 @@ def get_expandable_parameters(user_config, defaults, module_name):

defaults : dict
The default configuration file defined for the module.

module_name : str
The name the module being processed.

max_mols : int
The max number of molecules allowed.
"""
# the topoaa module is an exception because it has subdictionaries
# for the `mol` parameter. Instead of defining a general recursive
# function, I decided to add a simple if/else exception.
# no other module should have subdictionaries has parameters
if module_name == "topoaa":
ap = set() # allowed_parameters
ap.update(_get_blocks(user_config, defaults, module_name))
for i in range(1, 20):
ap.update(_get_expandable(user_config, defaults, module_name, max_mols))
for i in range(1, max_mols + 1):
key = f"mol{i}"
with suppress(KeyError):
ap.update(
_get_blocks(
_get_expandable(
user_config[key],
defaults[key],
defaults["mol1"],
module_name,
max_mols,
)
)

return ap

else:
return _get_blocks(user_config, defaults, module_name)
return _get_expandable(user_config, defaults, module_name, max_mols)


# reading parameter blocks
def _get_blocks(user_config, defaults, module_name):
def _get_expandable(user_config, defaults, module_name, max_mols):
type_1 = get_single_index_groups(defaults)
type_2 = get_multiple_index_groups(defaults)
type_4 = get_mol_parameters(defaults)

allowed_params = set()
allowed_params.update(read_single_idx_groups_user_config(user_config, type_1)) # noqa: E501
Expand All @@ -422,4 +456,64 @@ def _get_blocks(user_config, defaults, module_name):
type_3 = type_simplest_ep[module_name]
allowed_params.update(read_simplest_expandable(type_3, user_config))

_ = read_mol_parameters(user_config, type_4, max_mols=max_mols)
allowed_params.update(_)

return allowed_params


def populate_topology_molecule_params(topoaa):
"""Populate topoaa `molX` subdictionaries."""
topoaa_dft = _read_defaults("topoaa")

# list of possible prot_segids
uppers = list(string.ascii_uppercase)[::-1]

# removes from the list those prot_segids that are already defined
for param in topoaa:
if param.startswith("mol") and param[3:].isdigit():
with suppress(KeyError):
uppers.remove(topoaa[param]["prot_segid"])

# populates the prot_segids just for those that were not defined
# in the user configuration file. Other parameters are populated as
# well. `prot_segid` is the only one differing per molecule.
for i in range(1, len(topoaa["molecules"]) + 1):
mol = f"mol{i}"
if not(mol in topoaa and "prot_segid" in topoaa[mol]):
topoaa_dft["mol1"]["prot_segid"] = uppers.pop()

topoaa[mol] = recursive_dict_update(
topoaa_dft["mol1"],
topoaa[mol] if mol in topoaa else {},
)
return


def populate_mol_parameters(modules_params):
"""
Populate modules parameters.

Parameters
----------
modules_params : dict
A dictionary containing the parameters for all modules.

Returns
-------
None
Alter the dictionary in place.
"""
for module_name, _ in modules_params.items():
defaults = _read_defaults(module_name)

mol_params = (p for p in list(defaults.keys()) if is_mol_parameter(p))
num_mols = range(1, len(modules_params["topoaa"]["molecules"]) + 1)

for param, i in it.product(mol_params, num_mols):
param_name = remove_trail_idx(param)
modules_params[module_name].setdefault(
f"{param_name}_{i}",
defaults[param],
)
return
2 changes: 1 addition & 1 deletion src/haddock/libs/libutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def file_exists(

def recursive_dict_update(d, u):
"""
Update dictionary recursively.
Update dictionary `d` according to `u` recursively.

https://stackoverflow.com/questions/3232943

Expand Down
Loading