Skip to content

Commit

Permalink
ENH: Add support for parsing polyatomic core anchors
Browse files Browse the repository at this point in the history
  • Loading branch information
Bas van Beek committed Sep 15, 2022
1 parent 86840d9 commit 519d8d2
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 18 deletions.
54 changes: 39 additions & 15 deletions CAT/attachment/core_anchoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from .perp_surface import get_surface_vec
from .distribution import distribute_idx
from ..utils import AllignmentEnum, AllignmentTup, AnchorTup
from ..utils import AllignmentEnum, AllignmentTup, AnchorTup, KindEnum

if TYPE_CHECKING:
from numpy.typing import NDArray
Expand All @@ -47,13 +47,16 @@ def set_core_anchors(
# Get the indices of all anchor atom ligand placeholders in the core
anchors = mol.properties.dummies
if anchors is None:
anchor_idx, remove_idx = find_core_substructure(mol, anchor_tup)
anchor_idx, anchor_idx_group, remove_idx = find_core_substructure(mol, anchor_tup)
else:
anchor_idx = np.fromiter(anchors, count=len(anchors), dtype=np.int64)
anchor_idx -= 1
remove_idx = anchor_idx.copy()
anchor_idx_group = anchor_idx.reshape(-1, 1).copy()
if subset_kwargs:
anchor_idx_old = anchor_idx
anchor_idx = distribute_idx(mol, anchor_idx, **subset_kwargs)
anchor_idx_group = remove_idx[np.isin(anchor_idx_old, anchor_idx_ol1)]
if not len(anchor_idx):
raise MoleculeError(f"No valid anchoring groups found in the core {formula!r}")

Expand All @@ -75,7 +78,9 @@ def set_core_anchors(
)

# Define all core vectors
mol.properties.core_vec = _get_core_vectors(mol, mol.properties.dummies, allignment_tup)
mol.properties.core_vec = _get_core_vectors(
mol, anchor_idx_group, remove_idx, allignment_tup, anchor_tup
)

# Delete all core anchor atoms
if remove_idx is not None:
Expand All @@ -88,19 +93,37 @@ def set_core_anchors(

def _get_core_vectors(
core: Molecule,
dummies: Iterable[Atom],
anchor_group_idx: "NDArray[i8]",
remove_idx: "None | NDArray[i8]",
allignment: AllignmentTup,
anchor_tup: AnchorTup,
) -> "NDArray[f8]":
"""Return a 2D array with all core (unit) vectors."""
anchor = Molecule.as_array(None, atom_subset=dummies)
core_ar = np.array(core)

# Put the (effective) coordinates of the anchors into an (n, 3) array
if anchor_tup.kind == KindEnum.FIRST:
anchor_atoms = core_ar[anchor_group_idx[:, anchor_tup.group_idx[0]]]
elif anchor_tup.kind == KindEnum.MEAN:
anchor_idx = anchor_group_idx[:, np.fromiter(anchor_tup.group_idx, np.int64)]
anchor_atoms = core_ar[_anchor_idx].mean(axis=1)
else:
raise ValueError

# Define vectors based on the various allignment options
if allignment.kind == AllignmentEnum.SPHERE:
vec = np.array(core.get_center_of_mass()) - anchor
vec = np.array(core.get_center_of_mass()) - anchor_atoms
vec /= np.linalg.norm(vec, axis=1)[..., None]
elif allignment.kind == AllignmentEnum.SURFACE:
vec = -get_surface_vec(np.array(core), anchor)
if remove_idx is None:
vec = -get_surface_vec(core_ar, anchor_atoms)
else:
no_anchor_mask = np.ones(len(core_ar), dtype=np.bool_)
no_anchor_mask[remove_idx] = False
vec = -get_surface_vec(core_ar[no_anchor_mask], anchor_atoms)
elif allignment.kind == AllignmentEnum.ANCHOR:
raise NotImplementedError
anchor_mean = core[anchor_group_idx].mean(axis=1)
vec = anchor_atoms - anchor_mean
else:
raise ValueError(f"Unknown allignment kind: {allignment.kind}")

Expand All @@ -112,7 +135,7 @@ def _get_core_vectors(
def find_core_substructure(
mol: Molecule,
anchor_tup: AnchorTup,
) -> Tuple["NDArray[i8]", "None | NDArray[i8]"]:
) -> Tuple["NDArray[i8]", "NDArray[i8]", "None | NDArray[i8]"]:
"""Identify substructures within the passed core based on **anchor_tup**.
Returns two indice-arrays, respectivelly containing the indices of the anchor
Expand All @@ -124,23 +147,24 @@ def find_core_substructure(
remove = anchor_tup.remove

# Remove all duplicate matches, each heteroatom (match[0]) should have <= 1 entry
ref_set = set()
ref_list = []
anchor_list = []
remove_list = []
for idx_tup in matches:
anchor_idx_tup = tuple(idx_tup[i] for i in anchor_tup.group_idx)
if anchor_idx_tup in ref_set:
if anchor_idx_tup in ref_list:
continue # Skip duplicates
else:
ref_set.add(anchor_idx_tup)
ref_list.append(anchor_idx_tup)

if remove is not None:
remove_list += [idx_tup[i] for i in remove]
anchor_list.append(anchor_idx_tup[0])

anchor_array = np.fromiter(anchor_list, dtype=np.int64, count=len(anchor_list))
anchor_idx_array = np.fromiter(anchor_list, dtype=np.int64, count=len(anchor_list))
anchor_group_array = np.array(ref_list, dtype=np.int64)
if remove is not None:
remove_array = np.fromiter(remove_list, dtype=np.int64, count=len(remove_list))
return anchor_array, remove_array
return anchor_idx_array, anchor_group_array, remove_array
else:
return anchor_array, None
return anchor_idx_array, anchor_group_array, None
4 changes: 2 additions & 2 deletions CAT/data_handling/anchor_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,8 @@ def parse_anchors(
raise TypeError("`dihedral != None` is not supported for core anchors")
elif angle_offset is not None:
raise TypeError("`angle_offset != None` is not supported for core anchors")
elif kwargs["kind"] != KindEnum.FIRST:
raise NotImplementedError('`kind != "first"` is not yet supported')
elif kwargs["kind"] == KindEnum.MEAN_TRANSLATE:
raise ValueError('`kind != "mean translate"` is not supported for core anchors')
else:
# Check that at least 3 atoms are available for `angle_offset`
# (so a plane can be defined)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_ligand_anchoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def test_raise(self, inp: Any, exc_type: "type[Exception]") -> None:
angle_offset=({"group": "Cl", "group_idx": 0, "angle_offset": 90}, TypeError),
dihedral=({"group": "Cl", "group_idx": 0, "dihedral": 90}, TypeError),
multiple=(["OC", "OCC"], NotImplementedError),
kind=({"group": "Cl", "group_idx": 0, "kind": "mean"}, NotImplementedError),
kind=({"group": "Cl", "group_idx": 0, "kind": "mean_translate"}, ValueError),
)

@pytest.mark.parametrize("inp,exc_type", PARAM_RAISE_CORE.values(), ids=PARAM_RAISE_CORE)
Expand Down

0 comments on commit 519d8d2

Please sign in to comment.