diff --git a/CAT/attachment/ligand_anchoring.py b/CAT/attachment/ligand_anchoring.py index 697280e7..596952d5 100644 --- a/CAT/attachment/ligand_anchoring.py +++ b/CAT/attachment/ligand_anchoring.py @@ -34,7 +34,7 @@ from rdkit import Chem from ..logger import logger -from ..utils import get_template, AnchorTup, KindEnum, get_formula, FormatEnum +from ..utils import get_template, AnchorTup, KindEnum, get_formula, FormatEnum, MultiAnchorEnum from ..mol_utils import separate_mod # noqa: F401 from ..workflows import MOL, FORMULA, HDF5_INDEX, OPT from ..settings_dataframe import SettingsDataFrame @@ -201,12 +201,24 @@ def find_substructure( ligand_idx_dict[anchor_tup].append(idx_tup) ref_set.add(anchor_idx_tup) + # Apply some further filtering to the ligands if condition is not None: - if not condition(sum((len(i) for i in ligand_idx_dict.values()), 0)): + if not condition(sum((len(j) for j in ligand_idx_dict.values()), 0)): err = (f"Failed to satisfy the passed condition ({condition!r}) for " f"ligand: {ligand.properties.name!r}") logger.error(err) return [] + else: + for anchor_tup, j in ligand_idx_dict.items(): + if anchor_tup.multi_anchor_filter == MultiAnchorEnum.ALL: + pass + elif anchor_tup.multi_anchor_filter == MultiAnchorEnum.FIRST and len(j) > 1: + ligand_idx_dict[anchor_tup] = j[:1] + elif anchor_tup.multi_anchor_filter == MultiAnchorEnum.RAISE and len(j) > 1: + logger.error( + f"Found multiple valid functional groups for {ligand.properties.name!r}" + ) + return [] ret = [] idx_dict_items = chain.from_iterable(zip(repeat(k), v) for k, v in ligand_idx_dict.items()) diff --git a/CAT/data_handling/anchor_parsing.py b/CAT/data_handling/anchor_parsing.py index 249531e6..c6f4e5bf 100644 --- a/CAT/data_handling/anchor_parsing.py +++ b/CAT/data_handling/anchor_parsing.py @@ -9,7 +9,7 @@ from schema import Schema, Use, Optional from typing_extensions import TypedDict, SupportsIndex -from ..utils import AnchorTup, KindEnum, FormatEnum +from ..utils import AnchorTup, KindEnum, FormatEnum, MultiAnchorEnum from ..attachment.ligand_anchoring import _smiles_to_rdmol, get_functional_groups __all__ = ["parse_anchors"] @@ -26,6 +26,7 @@ class _UnparsedAnchorDict(_UnparsedAnchorDictBase, total=False): dihedral: "None | SupportsFloat | SupportsIndex | bytes | str" kind: "None | str | KindEnum" group_format: "None | str | FormatEnum" + multi_anchor_filter: "None | str | MultiAnchorEnum" class _AnchorDict(TypedDict): @@ -35,7 +36,8 @@ class _AnchorDict(TypedDict): kind: KindEnum angle_offset: "None | float" dihedral: "None | float" - group_format: "FormatEnum" + group_format: FormatEnum + multi_anchor_filter: MultiAnchorEnum def _parse_group_idx(item: "SupportsIndex | Iterable[SupportsIndex]") -> Tuple[int, ...]: @@ -85,6 +87,17 @@ def _parse_group_format(typ: "None | str | FormatEnum") -> FormatEnum: raise TypeError("`group_format` expected None or a string") +def _parse_multi_anchor_filter(typ: "None | str | MultiAnchorEnum") -> MultiAnchorEnum: + """Parse the ``multi_anchor_filter`` option.""" + if typ is None: + return MultiAnchorEnum.ALL + elif isinstance(typ, MultiAnchorEnum): + return typ + elif isinstance(typ, str): + return MultiAnchorEnum[typ.upper()] + raise TypeError("`multi_anchor_filter` expected None or a string") + + _UNIT_PATTERN = re.compile(r"([\.\_0-9]+)(\s+)?(\w+)?") @@ -143,6 +156,7 @@ def _symbol_to_rdmol(symbol: str) -> Chem.Mol: Optional("angle_offset", default=None): Use(_parse_angle_offset), Optional("dihedral", default=None): Use(_parse_angle_offset), Optional("group_format", default=FormatEnum.SMILES): Use(_parse_group_format), + Optional("multi_anchor_filter", default=MultiAnchorEnum.ALL): Use(_parse_multi_anchor_filter), }) #: A collection of symbols used for different kinds of dummy atoms. diff --git a/CAT/utils.py b/CAT/utils.py index b791c32b..7a422e0b 100644 --- a/CAT/utils.py +++ b/CAT/utils.py @@ -550,6 +550,14 @@ class AllignmentEnum(enum.Enum): SURFACE = 1 +class MultiAnchorEnum(enum.Enum): + """An enum with different actions for when ligands with multiple anchors are found.""" + + ALL = 0 + FIRST = 1 + RAISE = 2 + + class AnchorTup(NamedTuple): """A named tuple with anchoring operation instructions.""" @@ -562,6 +570,7 @@ class AnchorTup(NamedTuple): angle_offset: "None | float" = None dihedral: "None | float" = None group_format: FormatEnum = FormatEnum.SMILES + multi_anchor_filter: MultiAnchorEnum = MultiAnchorEnum.ALL class AllignmentTup(NamedTuple): diff --git a/docs/4_optional.rst b/docs/4_optional.rst index e73eea13..83ccc8a7 100644 --- a/docs/4_optional.rst +++ b/docs/4_optional.rst @@ -654,6 +654,7 @@ Ligand * :attr:`anchor.kind` * :attr:`anchor.angle_offset` * :attr:`anchor.dihedral` + * :attr:`anchor.multi_anchor_filter` .. note:: @@ -790,6 +791,20 @@ Ligand but if so desired one can explicitly pass the unit: ``dihedral: "0.5 rad"``. + .. attribute:: optional.ligand.anchor.multi_anchor_filter + + :Parameter: * **Type** - :class:`str` + * **Default value** – :data:`"ALL"` + + How ligands with multiple valid anchor sites are to-be treated. + + Accepts one of the following options: + + * ``"all"``: Construct a new ligand for each valid anchor/ligand combination. + * ``"first"``: Pick only the first valid functional group, all others are ignored. + * ``"raise"``: Treat a ligand as invalid if it has multiple valid anchoring sites. + + .. attribute:: optional.ligand.split :Parameter: * **Type** - :class:`bool` diff --git a/tests/test_ligand_anchoring.py b/tests/test_ligand_anchoring.py index 4918e142..01f815cc 100644 --- a/tests/test_ligand_anchoring.py +++ b/tests/test_ligand_anchoring.py @@ -21,7 +21,7 @@ from schema import SchemaError from packaging.version import Version -from CAT.utils import get_template, KindEnum, AnchorTup, FormatEnum +from CAT.utils import get_template, KindEnum, AnchorTup, FormatEnum, MultiAnchorEnum from CAT.base import prep_input from CAT.attachment.ligand_anchoring import ( get_functional_groups, _smiles_to_rdmol, find_substructure, init_ligand_anchoring @@ -279,6 +279,7 @@ class TestInputParsing: invalid_group=({"group": "OC", "group_idx": 0, "group": 1.0}, SchemaError), invalid_group_format=({"group": "OC", "group_idx": 0, "group_format": 1}, SchemaError), invalid_kind=({"group": "OC", "group_idx": 0, "kind": 1}, SchemaError), + invalid_multi_anchor_filter=({"group": "OC", "group_idx": 0, "multi_anchor_filter": 1}, SchemaError), ) @pytest.mark.parametrize("inp,exc_type", PARAM_RAISE.values(), ids=PARAM_RAISE.keys()) @@ -323,6 +324,9 @@ def test_raise_core(self, inp: Any, exc_type: "type[Exception]") -> None: group_format_none={"group": "OCC", "group_idx": 0, "group_format": None}, group_format_str={"group": "OCC", "group_idx": 0, "group_format": "SMARTS"}, group_format_enum={"group": "OCC", "group_idx": 0, "group_format": FormatEnum.SMARTS}, + multi_anchor_filter_none={"group": "OCC", "group_idx": 0, "multi_anchor_filter": None}, + multi_anchor_filter_str={"group": "OCC", "group_idx": 0, "multi_anchor_filter": "ALL"}, + multi_anchor_filter_enum={"group": "OCC", "group_idx": 0, "multi_anchor_filter": MultiAnchorEnum.ALL}, ) _PARAM_PASS2 = OrderedDict( idx_scalar=AnchorTup(None, group="OCC", group_idx=(0,)), @@ -348,6 +352,9 @@ def test_raise_core(self, inp: Any, exc_type: "type[Exception]") -> None: group_format_none=AnchorTup(None, group="OCC", group_idx=(0,), group_format=FormatEnum.SMILES), group_format_str=AnchorTup(None, group="OCC", group_idx=(0,), group_format=FormatEnum.SMARTS), group_format_enum=AnchorTup(None, group="OCC", group_idx=(0,), group_format=FormatEnum.SMARTS), + multi_anchor_filter_none=AnchorTup(None, group="OCC", group_idx=(0,), multi_anchor_filter=MultiAnchorEnum.ALL), + multi_anchor_filter_str=AnchorTup(None, group="OCC", group_idx=(0,), multi_anchor_filter=MultiAnchorEnum.ALL), + multi_anchor_filter_enum=AnchorTup(None, group="OCC", group_idx=(0,), multi_anchor_filter=MultiAnchorEnum.ALL), ) PARAM_PASS = OrderedDict({ k: (v1, v2) for (k, v1), v2 in zip(_PARAM_PASS1.items(), _PARAM_PASS2.values())