Skip to content

Commit

Permalink
TST: Add tests for the new dihedral option
Browse files Browse the repository at this point in the history
  • Loading branch information
Bas van Beek committed Nov 19, 2021
1 parent 1c3204b commit 7c17563
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 5 deletions.
24 changes: 24 additions & 0 deletions tests/test_files/CAT_dihedral.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
path: .

input_cores:
- Cd68Se55.xyz

input_ligands:
- 'CC(O)=O'

optional:
core:
dirname: core
anchor: Cl

ligand:
dirname: ligand
optimize: True
anchor:
group: "[H]OC=O"
group_idx: [1, 2, 3]
remove: 0
dihedral: 45

qd:
dirname: QD
Binary file added tests/test_files/test_dihedral_dihed_180.npy
Binary file not shown.
Binary file added tests/test_files/test_dihedral_dihed_45.npy
Binary file not shown.
Binary file added tests/test_files/test_dihedral_dihed_45_deg.npy
Binary file not shown.
71 changes: 66 additions & 5 deletions tests/test_ligand_attach.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,28 @@
"""Tests for :mod:`CAT.attachment.ligand_attach`."""

from os.path import join
import sys
import shutil
from pathlib import Path
from typing import Generator, NamedTuple, TYPE_CHECKING

import yaml
import pytest
import numpy as np

from assertionlib import assertion
from scm.plams import Settings, Molecule

from CAT.base import prep
from CAT.attachment.ligand_attach import (_get_rotmat1, _get_rotmat2)
from CAT.workflows import MOL

if TYPE_CHECKING:
import _pytest

PATH = join('tests', 'test_files')
PATH = Path('tests') / 'test_files'

LIG_PATH = PATH / 'ligand'
QD_PATH = PATH / 'qd'
DB_PATH = PATH / 'database'


def test_get_rotmat1() -> None:
Expand Down Expand Up @@ -50,11 +64,58 @@ def test_get_rotmat2() -> None:
vec1 = np.array([1, 0, 0], dtype=float)
vec2 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=float)

ref1 = np.load(join(PATH, 'rotmat2_1.npy'))
ref2 = np.load(join(PATH, 'rotmat2_2.npy'))
ref1 = np.load(PATH / 'rotmat2_1.npy')
ref2 = np.load(PATH / 'rotmat2_2.npy')

rotmat1 = _get_rotmat2(vec1)
np.testing.assert_allclose(rotmat1, ref1)

rotmat2 = _get_rotmat2(vec2)
np.testing.assert_allclose(rotmat2, ref2)


class DihedTup(NamedTuple):
mol: Molecule
ref: np.recarray


class TestDihedral:
PARAMS = {
"dihed_45": 45,
"dihed_45_deg": "45 deg",
"dihed_180": 180.0,
}

@pytest.fixture(scope="class", autouse=True, name="output", params=PARAMS.items(), ids=PARAMS)
def run_cat(self, request: "_pytest.fixtures.SubRequest") -> Generator[DihedTup, None, None]:
# Setup
name, dihed = request.param # type: str, str | float
yaml_path = PATH / 'CAT_dihedral.yaml'
with open(yaml_path, 'r') as f:
arg = Settings(yaml.load(f, Loader=yaml.FullLoader))

arg.path = PATH
arg.optional.ligand.anchor.dihedral = dihed
qd_df, _, _ = prep(arg)
qd = qd_df[MOL].iloc[0]

ref = np.load(PATH / f"test_dihedral_{name}.npy").view(np.recarray)
yield DihedTup(qd, ref)

# Teardown
files = [LIG_PATH, QD_PATH, DB_PATH]
for f in files:
shutil.rmtree(f, ignore_errors=True)

def test_atoms(self, output: DihedTup) -> None:
dtype = [("symbols", "U2"), ("coords", "f8", 3)]
atoms = np.fromiter(
[(at.symbol, at.coords) for at in output.mol], dtype=dtype
).view(np.recarray)

assertion.eq(atoms.dtype, output.ref.dtype)
np.testing.assert_array_equal(atoms.symbols, output.ref.symbols)

if sys.version_info >= (3, 9):
pytest.xfail("Geometries must be updated for RDKit >2019.09.2")
np.testing.assert_allclose(atoms.coords, output.ref.coords)

0 comments on commit 7c17563

Please sign in to comment.