Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions dmff/admp/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def generate_construct_local_frames(axis_types, axis_indices):
Bisector_filter = (axis_types == Bisector)
ZBisect_filter = (axis_types == ZBisect)
ThreeFold_filter = (axis_types == ThreeFold)
NoAxisType_filter = (axis_types == NoAxisType)

def construct_local_frames(positions, box):
'''
Expand Down Expand Up @@ -139,6 +140,13 @@ def construct_local_frames(positions, box):
vec_x = normalize(vec_x - vec_z * xz_projection, axis=1)
# up to this point, x-axis should be ready
vec_y = jnp.cross(vec_z, vec_x)

# NoAxisType
if np.sum(NoAxisType_filter) > 0:
vec_y = vec_y.at[NoAxisType_filter].set(jnp.array([0,1,0]))
vec_z = vec_z.at[NoAxisType_filter].set(jnp.array([0,0,1]))
vec_x = vec_x.at[NoAxisType_filter].set(jnp.array([1,0,0]))


return jnp.stack((vec_x, vec_y, vec_z), axis=1)

Expand Down
50 changes: 50 additions & 0 deletions tests/test_admp/test_noaxistype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import jax.numpy as jnp
import numpy as np
import numpy.testing as npt
import pytest
from dmff.admp.spatial import (build_quasi_internal,
generate_construct_local_frames, pbc_shift,
v_pbc_shift)


class TestSpatial:

@pytest.mark.parametrize(
"axis_types, axis_indices, positions, box, expected_local_frames",
[
(
np.array([5]),
np.array(
[
[-1, -1, -1],
]
),
jnp.array(
[
[0.992, 0.068, -0.073],
]
),
jnp.array([[50.000, 0.0, 0.0], [0.0, 50.000, 0.0], [0.0, 0.0, 50.000]]),
np.array(
[
[
[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
],
]
),
)
],
)
def test_generate_construct_local_frames(
self, axis_types, axis_indices, positions, box, expected_local_frames
):
construct_local_frame_fn = generate_construct_local_frames(
axis_types, axis_indices
)
assert construct_local_frame_fn
npt.assert_allclose(
construct_local_frame_fn(positions, box), expected_local_frames, rtol=1e-5
)