Skip to content

Commit

Permalink
TST: Improve the test coverage for the anchor input parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
Bas van Beek committed Nov 29, 2021
1 parent 109826e commit ab2563e
Showing 1 changed file with 120 additions and 1 deletion.
121 changes: 120 additions & 1 deletion tests/test_ligand_anchoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
import sys
import math
import h5py
from shutil import rmtree
from os.path import join
Expand All @@ -14,8 +15,9 @@
from rdkit import Chem
from scm.plams import from_smiles, Molecule
from assertionlib import assertion
from schema import SchemaError

from CAT.utils import get_template
from CAT.utils import get_template, KindEnum, AnchorTup
from CAT.base import prep_input
from CAT.attachment.ligand_anchoring import (
get_functional_groups, _smiles_to_rdmol, find_substructure, init_ligand_anchoring
Expand Down Expand Up @@ -256,3 +258,120 @@ def test_init_ligand_anchoring() -> None:
rmtree(join(PATH, 'ligand'))
rmtree(join(PATH, 'qd'))
rmtree(join(PATH, 'database'))


class TestInputParsing:
PARAM_RAISE: "OrderedDict[str, tuple[Any, type[Exception]]]" = OrderedDict(
invalid_smiles=("test", ValueError),
idx_len_0=({"group": "OC", "group_idx": []}, SchemaError),
idx_len_2=({"group": "OC", "group_idx": 0, "angle_offset": 45}, ValueError),
idx_len_3=({"group": "OC", "group_idx": 0, "dihedral": 0.5}, ValueError),
duplicate_idx=({"group": "OC", "group_idx": [0, 0]}, SchemaError),
invalid_idx_type=({"group": "OC", "group_idx": 0.0}, SchemaError),
idx_intersection=({"group": "OC", "group_idx": 0, "remove": 0}, ValueError),
out_of_bounds_idx=({"group": "OC", "group_idx": 99}, IndexError),
out_of_bounds_remove=({"group": "OC", "group_idx": 0, "remove": 99}, IndexError),
angle_unit=({"group": "OC", "group_idx": 0, "angle_offset": "0.5 bob"}, SchemaError),
angle_invalid=({"group": "OC", "group_idx": 0, "angle_offset": "bob"}, SchemaError),
)

@pytest.mark.parametrize("inp,exc_type", PARAM_RAISE.values(), ids=PARAM_RAISE.keys())
def test_raise(self, inp: Any, exc_type: "type[Exception]") -> None:
with pytest.raises(exc_type):
parse_anchors(inp)

_PARAM_PASS1 = OrderedDict(
idx_scalar={"group": "OCC", "group_idx": 0},
idx_list={"group": "OCC", "group_idx": [0]},
list=[{"group": "OCC", "group_idx": 0}],
str=["O(C)[H]"],
angle_unit={"group": "OCC", "group_idx": range(3), "angle_offset": "1 rad"},
angle_no_unit={"group": "OCC", "group_idx": range(3), "angle_offset": "180"},
angle_none={"group": "OCC", "group_idx": range(3), "angle_offset": None},
angle_float={"group": "OCC", "group_idx": range(3), "angle_offset": 180.0},
remove_none={"group": "OCC", "group_idx": 0, "remove": None},
kind_none={"group": "OCC", "group_idx": 0, "kind": None},
kind_str={"group": "OCC", "group_idx": 0, "kind": "mean"},
kind_enum={"group": "OCC", "group_idx": 0, "kind": KindEnum.MEAN_TRANSLATE},
)
_PARAM_PASS2 = OrderedDict(
idx_scalar=AnchorTup(None, group="OCC", group_idx=(0,)),
idx_list=AnchorTup(None, group="OCC", group_idx=(0,)),
list=AnchorTup(None, group="OCC", group_idx=(0,)),
str=AnchorTup(None, group="O(C)[H]", group_idx=(0,), remove=(2,)),
angle_unit=AnchorTup(None, group="OCC", group_idx=(0, 1, 2), angle_offset=1.0),
angle_no_unit=AnchorTup(None, group="OCC", group_idx=(0, 1, 2), angle_offset=math.pi),
angle_none=AnchorTup(None, group="OCC", group_idx=(0, 1, 2), angle_offset=None),
angle_float=AnchorTup(None, group="OCC", group_idx=(0, 1, 2), angle_offset=math.pi),
remove_none=AnchorTup(None, group="OCC", group_idx=(0,), remove=None),
kind_none=AnchorTup(None, group="OCC", group_idx=(0,), kind=KindEnum.FIRST),
kind_str=AnchorTup(None, group="OCC", group_idx=(0,), kind=KindEnum.MEAN),
kind_enum=AnchorTup(None, group="OCC", group_idx=(0,), kind=KindEnum.MEAN_TRANSLATE),
)
PARAM_PASS = OrderedDict({
k: (v1, v2) for (k, v1), v2 in zip(_PARAM_PASS1.items(), _PARAM_PASS2.values())
})

@pytest.mark.parametrize("inp,ref", PARAM_PASS.values(), ids=PARAM_PASS.keys())
def test_pass(self, inp: Any, ref: AnchorTup) -> None:
out_tup = parse_anchors(inp)
assertion.len_eq(out_tup, 1)
out = out_tup[0]

assertion.isinstance(out.mol, Chem.Mol)
if out.angle_offset is not None:
assertion.isclose(out.angle_offset, ref.angle_offset)
assertion.eq(out._replace(mol=None, angle_offset=None), ref._replace(angle_offset=None))

@pytest.mark.parametrize("split", [True, False], ids=["split", "no_split"])
def test_rdkit_mol(self, split: bool) -> None:
remove = (1,) if split else None
mol = _smiles_to_rdmol("[O-]C")
out = parse_anchors(mol, split=split)

ref = AnchorTup(mol, group=None, group_idx=(0,), remove=remove)
assertion.len_eq(out, 1)
assertion.eq(out[0], ref)

def test_anchor_tup(self) -> None:
ref = AnchorTup(_smiles_to_rdmol("[O-]C"), group="[O-]C", group_idx=(0,))
out = parse_anchors(ref)
assertion.len_eq(out, 1)
assertion.eq(out[0], ref)

SPLIT_REF = [
"C[N+].[F-]",
"C[N+].[Cl-]",
"C[N+].[Br-]",
"C[N+].[I-]",
"[H]NC",
"[H]PC",
"[H]OP",
"[H]OC",
"[H]SC",
"[H]OS",
]
NO_SPLIT_REF = [
"C[N+]",
"CN",
"C[N-]",
"CP",
"C[P-]",
"OP",
"[O-]P",
"CO",
"C[O-]",
"CS",
"C[S-]",
"OS",
"[O-]S",
]

@pytest.mark.parametrize("split,ref", [
(True, SPLIT_REF),
(False, NO_SPLIT_REF),
], ids=["split", "no_split"])
def test_none(self, split: bool, ref: "list[str]") -> None:
out = parse_anchors(split=split)
smiles = [Chem.MolToSmiles(tup.mol) for tup in out]
assertion.eq(smiles, ref)

0 comments on commit ab2563e

Please sign in to comment.