diff --git a/neutronpy/form_facs.py b/neutronpy/form_facs.py index 42e10c6..40a0aae 100644 --- a/neutronpy/form_facs.py +++ b/neutronpy/form_facs.py @@ -1,5 +1,6 @@ import neutronpy.constants as const import numpy as np +from numbers import Number class _Atom(object): @@ -133,23 +134,38 @@ def calc_str_fac(self, hkl): h, k, l = hkl # Determines shape of input variables to build FQ = 0 array - if type(h) is float and type(k) is float and type(l) is float: - FQ = 0. * 1j - else: - # if one dimension is zero, flatten FQ to 2D - if type(h) is float: - FQ = np.zeros(k.shape) * 1j - elif type(k) is float: - FQ = np.zeros(l.shape) * 1j - elif type(l) is float: - FQ = np.zeros(h.shape) * 1j - else: - FQ = np.zeros(h.shape) * 1j + dims = [] + for x in hkl: + if isinstance(x, np.ndarray): + dims.append(x.shape) + elif isinstance(x, (list, tuple)): + dims.append((len(x),)) + elif isinstance(x, Number): + dims.append(1) + + dims_str = np.array([str(x) for x in dims], dtype=str) + if np.unique(dims_str).size == 1: + FQ = np.zeros(dims[0]) + elif np.unique(dims_str[np.where([isinstance(x, tuple) for x in dims])]).size == 1: + FQ = np.zeros(dims[np.where([isinstance(x, tuple) for x in dims])[0][0]]) + else: + raise ValueError("Dimensions of 'hkl' elements are not compatible. An " \ + "element must be either the same shape as an other " \ + "non-decimal element, or a decimal number.") + + # Ensures input arrays are complex ndarrays + if isinstance(h, (np.ndarray, list, tuple)): + h = np.array(h).astype(complex, casting='unsafe') + if isinstance(k, (np.ndarray, list, tuple)): + k = np.array(k).astype(complex, casting='unsafe') + if isinstance(l, (np.ndarray, list, tuple)): + l = np.array(l).astype(complex, casting='unsafe') + # construct structure factor for atom in self.atoms: FQ += atom.b * np.exp(1j * 2. * np.pi * (h * atom.pos[0] + k * atom.pos[1] + l * atom.pos[2])) * \ - np.exp(-(2. * np.pi * (h * atom.dpos[0] + k * atom.dpos[1] + l * atom.dpos[2])) ** 2) + np.exp(-(2. * np.pi * (h * atom.dpos[0] + k * atom.dpos[1] + l * atom.dpos[2])) ** 2) return FQ diff --git a/tests/test_form_facs.py b/tests/test_form_facs.py index 860b1d7..f95348b 100644 --- a/tests/test_form_facs.py +++ b/tests/test_form_facs.py @@ -1,4 +1,5 @@ from neutronpy import form_facs +import numpy as np import unittest @@ -18,7 +19,27 @@ def __init__(self, *args, **kwargs): def test_str_fac(self): structure = form_facs.Material(self.input) - self.assertAlmostEqual(abs(structure.calc_str_fac((2., 0., 0.)) ** 2), 1583878.155915682, 6) + self.assertAlmostEqual(np.abs(structure.calc_str_fac((2., 0., 0.))) ** 2, 1583878.155915682, 6) + self.assertAlmostEqual(np.abs(structure.calc_str_fac((2, 0, 0))) ** 2, 1583878.155915682, 6) + self.assertAlmostEqual(np.abs(structure.calc_str_fac((0, 2., 0))) ** 2, 1583878.155915682, 6) + self.assertAlmostEqual(np.abs(structure.calc_str_fac((0, 2, 0))) ** 2, 1583878.155915682, 6) + + ndarray_example = np.linspace(0.5, 1.5, 21) + self.assertAlmostEqual(np.sum(abs(structure.calc_str_fac((ndarray_example, 0, 0))) ** 2), 1261922.414836668, 6) + self.assertAlmostEqual(np.sum(abs(structure.calc_str_fac((0, ndarray_example, 0))) ** 2), 1261922.414836668, 6) + self.assertAlmostEqual(np.sum(abs(structure.calc_str_fac((0, 0, ndarray_example))) ** 2), 16294175.79743738, 6) + self.assertAlmostEqual(np.sum(abs(structure.calc_str_fac((ndarray_example, ndarray_example, 0))) ** 2), 5572585.110405569, 6) + + + list_example = list(ndarray_example) + self.assertAlmostEqual(np.sum(abs(structure.calc_str_fac((list_example, 0, 0))) ** 2), 1261922.414836668, 6) + self.assertAlmostEqual(np.sum(abs(structure.calc_str_fac((0, list_example, 0))) ** 2), 1261922.414836668, 6) + self.assertAlmostEqual(np.sum(abs(structure.calc_str_fac((0, 0, list_example))) ** 2), 16294175.79743738, 6) + + tuple_example = tuple(ndarray_example) + self.assertAlmostEqual(np.sum(abs(structure.calc_str_fac((tuple_example, 0, 0))) ** 2), 1261922.414836668, 6) + self.assertAlmostEqual(np.sum(abs(structure.calc_str_fac((0, tuple_example, 0))) ** 2), 1261922.414836668, 6) + self.assertAlmostEqual(np.sum(abs(structure.calc_str_fac((0, 0, tuple_example))) ** 2), 16294175.79743738, 6) class MagneticFormFactor(unittest.TestCase):