Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TST: Improve the test coverage for the anchor input parsing #205

Merged
merged 4 commits into from
Nov 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 6 additions & 6 deletions CAT/data_handling/anchor_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import re
import operator
from typing import Union, Tuple, Collection, Iterable, SupportsFloat
from typing import Union, Tuple, Iterable, SupportsFloat

from rdkit.Chem import Mol
from scm.plams import Units
Expand All @@ -17,11 +17,11 @@

class _UnparsedAnchorDictBase(TypedDict):
group: str
anchor_idx: "SupportsIndex | Collection[SupportsIndex]"
anchor_idx: "SupportsIndex | Iterable[SupportsIndex]"


class _UnparsedAnchorDict(_UnparsedAnchorDictBase, total=False):
remove: "None | SupportsIndex | Collection[SupportsIndex]"
remove: "None | SupportsIndex | Iterable[SupportsIndex]"
angle_offset: "None | SupportsFloat | SupportsIndex | bytes | str"


Expand Down Expand Up @@ -108,14 +108,14 @@ def parse_anchors(
Mol,
AnchorTup,
_UnparsedAnchorDict,
"Collection[str | Mol | AnchorTup | _UnparsedAnchorDict]",
"Iterable[str | Mol | AnchorTup | _UnparsedAnchorDict]",
] = None,
split: bool = True,
) -> Tuple[AnchorTup, ...]:
"""Parse the user-specified anchors."""
if patterns is None:
patterns = get_functional_groups(None, split)
elif isinstance(patterns, (Mol, str, dict)):
elif isinstance(patterns, (Mol, str, dict, AnchorTup)):
patterns = [patterns]

ret = []
Expand Down Expand Up @@ -151,7 +151,7 @@ def parse_anchors(
# (so the third dihedral-defining vector can be defined)
dihedral = kwargs["dihedral"]
if dihedral is not None and len(group_idx) < 2:
raise ValueError("`group_idx` must contain at least 3 atoms when "
raise ValueError("`group_idx` must contain at least 2 atoms when "
"`dihedral` is specified")

# Check that the indices in `group_idx` and `remove` are not out of bounds
Expand Down
121 changes: 120 additions & 1 deletion tests/test_ligand_anchoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
import sys
import math
import h5py
from shutil import rmtree
from os.path import join
Expand All @@ -14,8 +15,9 @@
from rdkit import Chem
from scm.plams import from_smiles, Molecule
from assertionlib import assertion
from schema import SchemaError

from CAT.utils import get_template
from CAT.utils import get_template, KindEnum, AnchorTup
from CAT.base import prep_input
from CAT.attachment.ligand_anchoring import (
get_functional_groups, _smiles_to_rdmol, find_substructure, init_ligand_anchoring
Expand Down Expand Up @@ -256,3 +258,120 @@ def test_init_ligand_anchoring() -> None:
rmtree(join(PATH, 'ligand'))
rmtree(join(PATH, 'qd'))
rmtree(join(PATH, 'database'))


class TestInputParsing:
PARAM_RAISE: "OrderedDict[str, tuple[Any, type[Exception]]]" = OrderedDict(
invalid_smiles=("test", ValueError),
idx_len_0=({"group": "OC", "group_idx": []}, SchemaError),
idx_len_2=({"group": "OC", "group_idx": 0, "angle_offset": 45}, ValueError),
idx_len_3=({"group": "OC", "group_idx": 0, "dihedral": 0.5}, ValueError),
duplicate_idx=({"group": "OC", "group_idx": [0, 0]}, SchemaError),
invalid_idx_type=({"group": "OC", "group_idx": 0.0}, SchemaError),
idx_intersection=({"group": "OC", "group_idx": 0, "remove": 0}, ValueError),
out_of_bounds_idx=({"group": "OC", "group_idx": 99}, IndexError),
out_of_bounds_remove=({"group": "OC", "group_idx": 0, "remove": 99}, IndexError),
angle_unit=({"group": "OC", "group_idx": 0, "angle_offset": "0.5 bob"}, SchemaError),
angle_invalid=({"group": "OC", "group_idx": 0, "angle_offset": "bob"}, SchemaError),
)

@pytest.mark.parametrize("inp,exc_type", PARAM_RAISE.values(), ids=PARAM_RAISE.keys())
def test_raise(self, inp: Any, exc_type: "type[Exception]") -> None:
with pytest.raises(exc_type):
parse_anchors(inp)

_PARAM_PASS1 = OrderedDict(
idx_scalar={"group": "OCC", "group_idx": 0},
idx_list={"group": "OCC", "group_idx": [0]},
list=[{"group": "OCC", "group_idx": 0}],
str=["O(C)[H]"],
angle_unit={"group": "OCC", "group_idx": range(3), "angle_offset": "1 rad"},
angle_no_unit={"group": "OCC", "group_idx": range(3), "angle_offset": "180"},
angle_none={"group": "OCC", "group_idx": range(3), "angle_offset": None},
angle_float={"group": "OCC", "group_idx": range(3), "angle_offset": 180.0},
remove_none={"group": "OCC", "group_idx": 0, "remove": None},
kind_none={"group": "OCC", "group_idx": 0, "kind": None},
kind_str={"group": "OCC", "group_idx": 0, "kind": "mean"},
kind_enum={"group": "OCC", "group_idx": 0, "kind": KindEnum.MEAN_TRANSLATE},
)
_PARAM_PASS2 = OrderedDict(
idx_scalar=AnchorTup(None, group="OCC", group_idx=(0,)),
idx_list=AnchorTup(None, group="OCC", group_idx=(0,)),
list=AnchorTup(None, group="OCC", group_idx=(0,)),
str=AnchorTup(None, group="O(C)[H]", group_idx=(0,), remove=(2,)),
angle_unit=AnchorTup(None, group="OCC", group_idx=(0, 1, 2), angle_offset=1.0),
angle_no_unit=AnchorTup(None, group="OCC", group_idx=(0, 1, 2), angle_offset=math.pi),
angle_none=AnchorTup(None, group="OCC", group_idx=(0, 1, 2), angle_offset=None),
angle_float=AnchorTup(None, group="OCC", group_idx=(0, 1, 2), angle_offset=math.pi),
remove_none=AnchorTup(None, group="OCC", group_idx=(0,), remove=None),
kind_none=AnchorTup(None, group="OCC", group_idx=(0,), kind=KindEnum.FIRST),
kind_str=AnchorTup(None, group="OCC", group_idx=(0,), kind=KindEnum.MEAN),
kind_enum=AnchorTup(None, group="OCC", group_idx=(0,), kind=KindEnum.MEAN_TRANSLATE),
)
PARAM_PASS = OrderedDict({
k: (v1, v2) for (k, v1), v2 in zip(_PARAM_PASS1.items(), _PARAM_PASS2.values())
})

@pytest.mark.parametrize("inp,ref", PARAM_PASS.values(), ids=PARAM_PASS.keys())
def test_pass(self, inp: Any, ref: AnchorTup) -> None:
out_tup = parse_anchors(inp)
assertion.len_eq(out_tup, 1)
out = out_tup[0]

assertion.isinstance(out.mol, Chem.Mol)
if out.angle_offset is not None:
assertion.isclose(out.angle_offset, ref.angle_offset)
assertion.eq(out._replace(mol=None, angle_offset=None), ref._replace(angle_offset=None))

@pytest.mark.parametrize("split", [True, False], ids=["split", "no_split"])
def test_rdkit_mol(self, split: bool) -> None:
remove = (1,) if split else None
mol = _smiles_to_rdmol("[O-]C")
out = parse_anchors(mol, split=split)

ref = AnchorTup(mol, group=None, group_idx=(0,), remove=remove)
assertion.len_eq(out, 1)
assertion.eq(out[0], ref)

def test_anchor_tup(self) -> None:
ref = AnchorTup(_smiles_to_rdmol("[O-]C"), group="[O-]C", group_idx=(0,))
out = parse_anchors(ref)
assertion.len_eq(out, 1)
assertion.eq(out[0], ref)

SPLIT_REF = [
"C[N+].[F-]",
"C[N+].[Cl-]",
"C[N+].[Br-]",
"C[N+].[I-]",
"[H]NC",
"[H]PC",
"[H]OP",
"[H]OC",
"[H]SC",
"[H]OS",
]
NO_SPLIT_REF = [
"C[N+]",
"CN",
"C[N-]",
"CP",
"C[P-]",
"OP",
"[O-]P",
"CO",
"C[O-]",
"CS",
"C[S-]",
"OS",
"[O-]S",
]

@pytest.mark.parametrize("split,ref", [
(True, SPLIT_REF),
(False, NO_SPLIT_REF),
], ids=["split", "no_split"])
def test_none(self, split: bool, ref: "list[str]") -> None:
out = parse_anchors(split=split)
smiles = [Chem.MolToSmiles(tup.mol) for tup in out]
assertion.eq(smiles, ref)