diff --git a/requirements/test.txt b/requirements/test.txt index a7277865..c31d2af1 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -4,3 +4,4 @@ codecov coverage pytest-cov pytest-env +sympy diff --git a/tests/__init__.py b/tests/__init__.py index 7f498f76..e69de29b 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,85 +0,0 @@ -#!/usr/bin/env python -############################################################################## -# -# diffpy.srfit by DANSE Diffraction group -# Simon J. L. Billinge -# (c) 2010 The Trustees of Columbia University -# in the City of New York. All rights reserved. -# -# File coded by: Pavol Juhas -# -# See AUTHORS.txt for a list of people who contributed. -# See LICENSE_DANSE.txt for license information. -# -############################################################################## -"""Unit tests for diffpy.srfit.""" - -import logging -import unittest - -# create logger instance for the tests subpackage -logging.basicConfig() -logger = logging.getLogger(__name__) -del logging - - -def testsuite(pattern=""): - """Create a unit tests suite for diffpy.srfit package. - - Parameters - ---------- - pattern : str, optional - Regular expression pattern for selecting test cases. - Select all tests when empty. Ignore the pattern when - any of unit test modules fails to import. - - Returns - ------- - suite : `unittest.TestSuite` - The TestSuite object containing the matching tests. - """ - import re - from itertools import chain - from os.path import dirname - - from pkg_resources import resource_filename - - loader = unittest.defaultTestLoader - thisdir = resource_filename(__name__, "") - depth = __name__.count(".") + 1 - topdir = thisdir - for i in range(depth): - topdir = dirname(topdir) - suite_all = loader.discover(thisdir, top_level_dir=topdir) - # always filter the suite by pattern to test-cover the selection code. - suite = unittest.TestSuite() - rx = re.compile(pattern) - tsuites = list(chain.from_iterable(suite_all)) - tsok = all(isinstance(ts, unittest.TestSuite) for ts in tsuites) - if not tsok: # pragma: no cover - return suite_all - tcases = chain.from_iterable(tsuites) - for tc in tcases: - tcwords = tc.id().split(".") - shortname = ".".join(tcwords[-3:]) - if rx.search(shortname): - suite.addTest(tc) - # verify all tests are found for an empty pattern. - assert pattern or suite_all.countTestCases() == suite.countTestCases() - return suite - - -def test(): - """Execute all unit tests for the diffpy.srfit package. - - Returns - ------- - result : `unittest.TestResult` - """ - suite = testsuite() - runner = unittest.TextTestRunner() - result = runner.run(suite) - return result - - -# End of file diff --git a/tests/conftest.py b/tests/conftest.py index e3b63139..bd53d38d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,10 +1,92 @@ +import importlib.resources import json +import logging +import sys +from functools import lru_cache from pathlib import Path import pytest +import six +import diffpy.srfit.equation.literals as literals +from diffpy.srfit.sas.sasimport import sasimport -@pytest.fixture +logger = logging.getLogger(__name__) + + +@lru_cache() +def has_sas(): + try: + sasimport("sas.pr.invertor") + sasimport("sas.models") + return True + except ImportError: + return False + + +# diffpy.structure +@lru_cache() +def has_diffpy_structure(): + try: + import diffpy.structure as m + + del m + return True + except ImportError: + return False + logger.warning( + "Cannot import diffpy.structure, Structure tests skipped." + ) + + +@lru_cache() +def has_pyobjcryst(): + try: + import pyobjcryst as m + + del m + return True + except ImportError: + return False + logger.warning("Cannot import pyobjcryst, pyobjcryst tests skipped.") + + +# diffpy.srreal + + +@lru_cache() +def has_diffpy_srreal(): + try: + import diffpy.srreal.pdfcalculator as m + + del m + return True + except ImportError: + return False + logger.warning("Cannot import diffpy.srreal, PDF tests skipped.") + + +@pytest.fixture(scope="session") +def sas_available(): + return has_sas() + + +@pytest.fixture(scope="session") +def diffpy_structure_available(): + return has_diffpy_structure() + + +@pytest.fixture(scope="session") +def diffpy_srreal_available(): + return has_diffpy_srreal() + + +@pytest.fixture(scope="session") +def pyobjcryst_available(): + return has_pyobjcryst() + + +@pytest.fixture(scope="session") def user_filesystem(tmp_path): base_dir = Path(tmp_path) home_dir = base_dir / "home_dir" @@ -17,3 +99,61 @@ def user_filesystem(tmp_path): json.dump(home_config_data, f) yield tmp_path + + +@pytest.fixture(scope="session") +def datafile(): + """Fixture to load a test data file from the testdata package directory.""" + + def _datafile(filename): + return importlib.resources.files("tests.testdata").joinpath(filename) + + return _datafile + + +@pytest.fixture(scope="session") +def make_args(): + def _makeArgs(num): + args = [] + for i in range(num): + j = i + 1 + args.append(literals.Argument(name="v%i" % j, value=j)) + return args + + return _makeArgs + + +@pytest.fixture(scope="session") +def noObserversInGlobalBuilders(): + def _noObserversInGlobalBuilders(): + """True if no observer function leaks to global builder objects. + + Ensure objects are not immortal due to a reference from static + value. + """ + from diffpy.srfit.equation.builder import _builders + + rv = True + for n, b in _builders.items(): + if b.literal and b.literal._observers: + rv = False + break + return rv + + return _noObserversInGlobalBuilders() + + +@pytest.fixture(scope="session") +def capturestdout(): + def _capturestdout(f, *args, **kwargs): + """Capture the standard output from a call of function f.""" + savestdout = sys.stdout + fp = six.StringIO() + try: + sys.stdout = fp + f(*args, **kwargs) + finally: + sys.stdout = savestdout + return fp.getvalue() + + return _capturestdout diff --git a/tests/test_builder.py b/tests/test_builder.py index 81f5eb9f..dfabc8fb 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -14,274 +14,273 @@ ############################################################################## """Tests for refinableobj module.""" -import unittest - import numpy +import pytest import diffpy.srfit.equation.builder as builder import diffpy.srfit.equation.literals as literals -from .utils import _makeArgs, noObserversInGlobalBuilders - -class TestBuilder(unittest.TestCase): +def testRegisterArg(make_args, noObserversInGlobalBuilders): - def testRegisterArg(self): + factory = builder.EquationFactory() - factory = builder.EquationFactory() + v1 = make_args(1)[0] - v1 = _makeArgs(1)[0] + b1 = factory.registerArgument("v1", v1) + assert factory.builders["v1"] is b1 + assert b1.literal is v1 - b1 = factory.registerArgument("v1", v1) - self.assertTrue(factory.builders["v1"] is b1) - self.assertTrue(b1.literal is v1) + eq = factory.makeEquation("v1") - eq = factory.makeEquation("v1") + assert v1 is eq.args[0] + assert 1 == len(eq.args) - self.assertTrue(v1 is eq.args[0]) - self.assertEqual(1, len(eq.args)) + # Try to parse an equation with buildargs turned off + with pytest.raises(ValueError): + factory.makeEquation("v1 + v2", False) - # Try to parse an equation with buildargs turned off - self.assertRaises(ValueError, factory.makeEquation, "v1 + v2", False) + # Make sure we can still use constants + eq = factory.makeEquation("v1 + 2", False) + assert v1 is eq.args[0] + assert 1 == len(eq.args) + assert noObserversInGlobalBuilders + return - # Make sure we can still use constants - eq = factory.makeEquation("v1 + 2", False) - self.assertTrue(v1 is eq.args[0]) - self.assertEqual(1, len(eq.args)) - self.assertTrue(noObserversInGlobalBuilders()) - return - def testRegisterOperator(self): - """Try to use an operator without arguments in an equation.""" +def testRegisterOperator(make_args, noObserversInGlobalBuilders): + """Try to use an operator without arguments in an equation.""" - factory = builder.EquationFactory() - v1, v2, v3, v4 = _makeArgs(4) + factory = builder.EquationFactory() + v1, v2, v3, v4 = make_args(4) - op = literals.AdditionOperator() + op = literals.AdditionOperator() - op.addLiteral(v1) - op.addLiteral(v2) + op.addLiteral(v1) + op.addLiteral(v2) - factory.registerArgument("v3", v3) - factory.registerArgument("v4", v4) - factory.registerOperator("op", op) + factory.registerArgument("v3", v3) + factory.registerArgument("v4", v4) + factory.registerOperator("op", op) - # Build an equation where op is treated as a terminal node - eq = factory.makeEquation("op") - self.assertAlmostEqual(3, eq()) + # Build an equation where op is treated as a terminal node + eq = factory.makeEquation("op") + assert 3 == pytest.approx(eq()) - eq = factory.makeEquation("v3*op") - self.assertAlmostEqual(9, eq()) + eq = factory.makeEquation("v3*op") + assert 9 == pytest.approx(eq()) - # Now use the op like a function - eq = factory.makeEquation("op(v3, v4)") - self.assertAlmostEqual(7, eq()) + # Now use the op like a function + eq = factory.makeEquation("op(v3, v4)") + assert 7 == pytest.approx(eq()) - # Make sure we can still access op as itself. - eq = factory.makeEquation("op") - self.assertAlmostEqual(3, eq()) + # Make sure we can still access op as itself. + eq = factory.makeEquation("op") + assert 3 == pytest.approx(eq()) - self.assertTrue(noObserversInGlobalBuilders()) - return - - def testSwapping(self): + assert noObserversInGlobalBuilders + return - def g1(v1, v2, v3, v4): - return (v1 + v2) * (v3 + v4) - def g2(v1): - return 0.5 * v1 - - factory = builder.EquationFactory() - v1, v2, v3, v4, v5 = _makeArgs(5) - - factory.registerArgument("v1", v1) - factory.registerArgument("v2", v2) - factory.registerArgument("v3", v3) - factory.registerArgument("v4", v4) - b = factory.registerFunction("g", g1, ["v1", "v2", "v3", "v4"]) - - # Now associate args with the wrapped function - op = b.literal - self.assertTrue(op.operation == g1) - self.assertTrue(v1 in op.args) - self.assertTrue(v2 in op.args) - self.assertTrue(v3 in op.args) - self.assertTrue(v4 in op.args) - self.assertAlmostEqual(21, op.value) - - eq1 = factory.makeEquation("g") - self.assertTrue(eq1.root is op) - self.assertAlmostEqual(21, eq1()) - - # Swap out an argument by registering it under a taken name - b = factory.registerArgument("v4", v5) - self.assertTrue(factory.builders["v4"] is b) - self.assertTrue(b.literal is v5) - self.assertTrue(op._value is None) - self.assertTrue(op.args == [v1, v2, v3, v5]) - self.assertAlmostEqual(24, eq1()) - - # Now swap out the function - b = factory.registerFunction("g", g2, ["v1"]) - op = b.literal - self.assertTrue(op.operation == g2) - self.assertTrue(v1 in op.args) - self.assertTrue(eq1.root is op) - self.assertAlmostEqual(0.5, op.value) - self.assertAlmostEqual(0.5, eq1()) - - # Make an equation - eqeq = factory.makeEquation("v1 + v2") - # Register this "g" - b = factory.registerFunction("g", eqeq, eqeq.argdict.keys()) - op = b.literal - self.assertTrue(v1 in op.args) - self.assertTrue(v2 in op.args) - self.assertTrue(eq1.root is op) - self.assertAlmostEqual(3, op.value) - self.assertAlmostEqual(3, eq1()) - - self.assertTrue(noObserversInGlobalBuilders()) - return - - def testParseEquation(self): - - from numpy import array_equal, divide, e, sin, sqrt - - factory = builder.EquationFactory() - - # Scalar equation - eq = factory.makeEquation("A*sin(0.5*x)+divide(B,C)") - A = 1 - x = numpy.pi - B = 4.0 - C = 2.0 - eq.A.setValue(A) - eq.x.setValue(x) - eq.B.setValue(B) - eq.C.setValue(C) - f = lambda A, x, B, C: A * sin(0.5 * x) + divide(B, C) - self.assertTrue(array_equal(eq(), f(A, x, B, C))) - - # Make sure that the arguments of eq are listed in the order in which - # they appear in the equations. - self.assertEqual(eq.args, [eq.A, eq.x, eq.B, eq.C]) - - # Vector equation - eq = factory.makeEquation("sqrt(e**(-0.5*(x/sigma)**2))") - x = numpy.arange(0, 1, 0.05) - sigma = 0.1 - eq.x.setValue(x) - eq.sigma.setValue(sigma) - f = lambda x, sigma: sqrt(e ** (-0.5 * (x / sigma) ** 2)) - self.assertTrue(numpy.allclose(eq(), f(x, sigma))) - - self.assertEqual(eq.args, [eq.x, eq.sigma]) - - # Equation with constants - factory.registerConstant("x", x) - eq = factory.makeEquation("sqrt(e**(-0.5*(x/sigma)**2))") - self.assertTrue("sigma" in eq.argdict) - self.assertTrue("x" not in eq.argdict) - self.assertTrue(numpy.allclose(eq(sigma=sigma), f(x, sigma))) - - self.assertEqual(eq.args, [eq.sigma]) - - # Equation with user-defined functions - factory.registerFunction("myfunc", eq, ["sigma"]) - eq2 = factory.makeEquation("c*myfunc(sigma)") - self.assertTrue(numpy.allclose(eq2(c=2, sigma=sigma), 2 * f(x, sigma))) - self.assertTrue("sigma" in eq2.argdict) - self.assertTrue("c" in eq2.argdict) - self.assertEqual(eq2.args, [eq2.c, eq2.sigma]) - - self.assertTrue(noObserversInGlobalBuilders()) - return - - def test_parse_constant(self): - """Verify parsing of constant numeric expressions.""" - factory = builder.EquationFactory() - eq = factory.makeEquation("3.12 + 2") - self.assertTrue(isinstance(eq, builder.Equation)) - self.assertEqual(set(), factory.equations) - self.assertEqual(5.12, eq()) - self.assertRaises(ValueError, eq, 3) - return - - def testBuildEquation(self): - - from numpy import array_equal - - # simple equation - sin = builder.getBuilder("sin") - a = builder.ArgumentBuilder(name="a", value=1) - A = builder.ArgumentBuilder(name="A", value=2) - x = numpy.arange(0, numpy.pi, 0.1) - - beq = A * sin(a * x) - eq = beq.getEquation() - - self.assertTrue("a" in eq.argdict) - self.assertTrue("A" in eq.argdict) - self.assertTrue(array_equal(eq(), 2 * numpy.sin(x))) - - self.assertEqual(eq.args, [eq.A, eq.a]) - - # Check the number of arguments - self.assertRaises(ValueError, sin) - - # custom function - def _f(a, b): - return (a - b) * 1.0 / (a + b) - - f = builder.wrapFunction("f", _f, 2, 1) - a = builder.ArgumentBuilder(name="a", value=2) - b = builder.ArgumentBuilder(name="b", value=1) - - beq = sin(f(a, b)) - eq = beq.getEquation() - self.assertEqual(eq(), numpy.sin(_f(2, 1))) - - # complex function - sqrt = builder.getBuilder("sqrt") - e = numpy.e - _x = numpy.arange(0, 1, 0.05) - x = builder.ArgumentBuilder(name="x", value=_x, const=True) - sigma = builder.ArgumentBuilder(name="sigma", value=0.1) - beq = sqrt(e ** (-0.5 * (x / sigma) ** 2)) - eq = beq.getEquation() - f = lambda x, sigma: sqrt(e ** (-0.5 * (x / sigma) ** 2)) - self.assertTrue( - numpy.allclose(eq(), numpy.sqrt(e ** (-0.5 * (_x / 0.1) ** 2))) - ) - - # Equation with Equation - A = builder.ArgumentBuilder(name="A", value=2) - B = builder.ArgumentBuilder(name="B", value=4) - beq = A + B - eq = beq.getEquation() - E = builder.wrapOperator("eq", eq) - eq2 = (2 * E).getEquation() - # Make sure these evaluate to the same thing - self.assertEqual(eq.args, [A.literal, B.literal]) - self.assertEqual(2 * eq(), eq2()) - # Pass new arguments to the equation - C = builder.ArgumentBuilder(name="C", value=5) - D = builder.ArgumentBuilder(name="D", value=6) - eq3 = (E(C, D) + 1).getEquation() - self.assertEqual(12, eq3()) - # Pass old and new arguments to the equation - # If things work right, A has been given the value of C in the last - # evaluation (5) - eq4 = (3 * E(A, D) - 1).getEquation() - self.assertEqual(32, eq4()) - # Try to pass the wrong number of arguments - self.assertRaises(ValueError, E, A) - self.assertRaises(ValueError, E, A, B, C) - - self.assertTrue(noObserversInGlobalBuilders()) - return +def testSwapping(make_args, noObserversInGlobalBuilders): + + def g1(v1, v2, v3, v4): + return (v1 + v2) * (v3 + v4) + + def g2(v1): + return 0.5 * v1 + + factory = builder.EquationFactory() + v1, v2, v3, v4, v5 = make_args(5) + + factory.registerArgument("v1", v1) + factory.registerArgument("v2", v2) + factory.registerArgument("v3", v3) + factory.registerArgument("v4", v4) + b = factory.registerFunction("g", g1, ["v1", "v2", "v3", "v4"]) + + # Now associate args with the wrapped function + op = b.literal + assert op.operation == g1 + assert v1 in op.args + assert v2 in op.args + assert v3 in op.args + assert v4 in op.args + assert round(abs(21 - op.value), 7) == 0 + + eq1 = factory.makeEquation("g") + assert eq1.root is op + assert round(abs(21 - eq1()), 7) == 0 + + # Swap out an argument by registering it under a taken name + b = factory.registerArgument("v4", v5) + assert factory.builders["v4"] is b + assert b.literal is v5 + assert op._value is None + assert op.args == [v1, v2, v3, v5] + assert round(abs(24 - eq1()), 7) == 0 + + # Now swap out the function + b = factory.registerFunction("g", g2, ["v1"]) + op = b.literal + assert op.operation == g2 + assert v1 in op.args + assert eq1.root is op + assert round(abs(0.5 - op.value), 7) == 0 + assert round(abs(0.5 - eq1()), 7) == 0 + + # Make an equation + eqeq = factory.makeEquation("v1 + v2") + # Register this "g" + b = factory.registerFunction("g", eqeq, eqeq.argdict.keys()) + op = b.literal + assert v1 in op.args + assert v2 in op.args + assert eq1.root is op + assert round(abs(3 - op.value), 7) == 0 + assert round(abs(3 - eq1()), 7) == 0 + assert noObserversInGlobalBuilders + return + + +def testParseEquation(noObserversInGlobalBuilders): + + from numpy import array_equal, divide, e, sin, sqrt + + factory = builder.EquationFactory() + + # Scalar equation + eq = factory.makeEquation("A*sin(0.5*x)+divide(B,C)") + A = 1 + x = numpy.pi + B = 4.0 + C = 2.0 + eq.A.setValue(A) + eq.x.setValue(x) + eq.B.setValue(B) + eq.C.setValue(C) + f = lambda A, x, B, C: A * sin(0.5 * x) + divide(B, C) + assert array_equal(eq(), f(A, x, B, C)) + + # Make sure that the arguments of eq are listed in the order in which + # they appear in the equations. + assert eq.args == [eq.A, eq.x, eq.B, eq.C] + + # Vector equation + eq = factory.makeEquation("sqrt(e**(-0.5*(x/sigma)**2))") + x = numpy.arange(0, 1, 0.05) + sigma = 0.1 + eq.x.setValue(x) + eq.sigma.setValue(sigma) + f = lambda x, sigma: sqrt(e ** (-0.5 * (x / sigma) ** 2)) + assert numpy.allclose(eq(), f(x, sigma)) + + assert eq.args == [eq.x, eq.sigma] + + # Equation with constants + factory.registerConstant("x", x) + eq = factory.makeEquation("sqrt(e**(-0.5*(x/sigma)**2))") + assert "sigma" in eq.argdict + assert "x" not in eq.argdict + assert numpy.allclose(eq(sigma=sigma), f(x, sigma)) + assert eq.args == [eq.sigma] + + # Equation with user-defined functions + factory.registerFunction("myfunc", eq, ["sigma"]) + eq2 = factory.makeEquation("c*myfunc(sigma)") + assert numpy.allclose(eq2(c=2, sigma=sigma), 2 * f(x, sigma)) + assert "sigma" in eq2.argdict + assert "c" in eq2.argdict + assert eq2.args == [eq2.c, eq2.sigma] + assert noObserversInGlobalBuilders + return + + +def test_parse_constant(): + """Verify parsing of constant numeric expressions.""" + factory = builder.EquationFactory() + eq = factory.makeEquation("3.12 + 2") + assert isinstance(eq, builder.Equation) + assert set() == factory.equations + assert 5.12 == eq() + with pytest.raises(ValueError): + eq(3) + return + + +def testBuildEquation(noObserversInGlobalBuilders): + + from numpy import array_equal + + # simple equation + sin = builder.getBuilder("sin") + a = builder.ArgumentBuilder(name="a", value=1) + A = builder.ArgumentBuilder(name="A", value=2) + x = numpy.arange(0, numpy.pi, 0.1) + + beq = A * sin(a * x) + eq = beq.getEquation() + + assert "a" in eq.argdict + assert "A" in eq.argdict + assert array_equal(eq(), 2 * numpy.sin(x)) + + assert eq.args == [eq.A, eq.a] + + # Check the number of arguments + with pytest.raises(ValueError): + sin() + + # custom function + def _f(a, b): + return (a - b) * 1.0 / (a + b) + + f = builder.wrapFunction("f", _f, 2, 1) + a = builder.ArgumentBuilder(name="a", value=2) + b = builder.ArgumentBuilder(name="b", value=1) + + beq = sin(f(a, b)) + eq = beq.getEquation() + assert eq() == numpy.sin(_f(2, 1)) + + # complex function + sqrt = builder.getBuilder("sqrt") + e = numpy.e + _x = numpy.arange(0, 1, 0.05) + x = builder.ArgumentBuilder(name="x", value=_x, const=True) + sigma = builder.ArgumentBuilder(name="sigma", value=0.1) + beq = sqrt(e ** (-0.5 * (x / sigma) ** 2)) + eq = beq.getEquation() + f = lambda x, sigma: sqrt(e ** (-0.5 * (x / sigma) ** 2)) + assert numpy.allclose(eq(), numpy.sqrt(e ** (-0.5 * (_x / 0.1) ** 2))) + + # Equation with Equation + A = builder.ArgumentBuilder(name="A", value=2) + B = builder.ArgumentBuilder(name="B", value=4) + beq = A + B + eq = beq.getEquation() + E = builder.wrapOperator("eq", eq) + eq2 = (2 * E).getEquation() + # Make sure these evaluate to the same thing + assert eq.args == [A.literal, B.literal] + assert 2 * eq() == eq2() + # Pass new arguments to the equation + C = builder.ArgumentBuilder(name="C", value=5) + D = builder.ArgumentBuilder(name="D", value=6) + eq3 = (E(C, D) + 1).getEquation() + assert 12 == eq3() + # Pass old and new arguments to the equation + # If things work right, A has been given the value of C in the last + # evaluation (5) + eq4 = (3 * E(A, D) - 1).getEquation() + assert 32 == eq4() + # Try to pass the wrong number of arguments + with pytest.raises(ValueError): + E(A) + with pytest.raises(ValueError): + E(A, B, C) + assert noObserversInGlobalBuilders + return if __name__ == "__main__": diff --git a/tests/test_characteristicfunctions.py b/tests/test_characteristicfunctions.py index ddd02c32..dd846baf 100644 --- a/tests/test_characteristicfunctions.py +++ b/tests/test_characteristicfunctions.py @@ -17,136 +17,138 @@ import unittest import numpy +import pytest +import diffpy.srfit.pdf.characteristicfunctions as cf from diffpy.srfit.sas.sasimport import sasimport -from .utils import _msg_nosas, has_sas - # Global variables to be assigned in setUp cf = None # ---------------------------------------------------------------------------- -@unittest.skipUnless(has_sas, _msg_nosas) -class TestSASCF(unittest.TestCase): - - def setUp(self): - global cf - import diffpy.srfit.pdf.characteristicfunctions as cf - - return - - def testSphere(self): - radius = 25 - # Calculate sphere cf from SphereModel - SphereModel = sasimport("sas.models.SphereModel").SphereModel - model = SphereModel() - model.setParam("radius", radius) - ff = cf.SASCF("sphere", model) - r = numpy.arange(1, 60, 0.1, dtype=float) - fr1 = ff(r) - - # Calculate sphere cf analytically - fr2 = cf.sphericalCF(r, 2 * radius) - diff = fr1 - fr2 - res = numpy.dot(diff, diff) - res /= numpy.dot(fr2, fr2) - self.assertAlmostEqual(0, res, 4) - return - - def testSpheroid(self): - prad = 20.9 - erad = 33.114 - # Calculate cf from EllipsoidModel - EllipsoidModel = sasimport("sas.models.EllipsoidModel").EllipsoidModel - model = EllipsoidModel() - model.setParam("radius_a", prad) - model.setParam("radius_b", erad) - ff = cf.SASCF("spheroid", model) - r = numpy.arange(0, 100, 1 / numpy.pi, dtype=float) - fr1 = ff(r) - - # Calculate cf analytically - fr2 = cf.spheroidalCF(r, erad, prad) - diff = fr1 - fr2 - res = numpy.dot(diff, diff) - res /= numpy.dot(fr2, fr2) - self.assertAlmostEqual(0, res, 4) - return - - def testShell(self): - radius = 19.2 - thickness = 7.8 - # Calculate cf from VesicleModel - VesicleModel = sasimport("sas.models.VesicleModel").VesicleModel - model = VesicleModel() - model.setParam("radius", radius) - model.setParam("thickness", thickness) - ff = cf.SASCF("vesicle", model) - r = numpy.arange(0, 99.45, 0.1, dtype=float) - fr1 = ff(r) - - # Calculate sphere cf analytically - fr2 = cf.shellCF(r, radius, thickness) - diff = fr1 - fr2 - res = numpy.dot(diff, diff) - res /= numpy.dot(fr2, fr2) - self.assertAlmostEqual(0, res, 4) - return - - def testCylinder(self): - """Make sure cylinder works over different r-ranges.""" - radius = 100 - length = 30 - - CylinderModel = sasimport("sas.models.CylinderModel").CylinderModel - model = CylinderModel() - model.setParam("radius", radius) - model.setParam("length", length) - - ff = cf.SASCF("cylinder", model) - - r1 = numpy.arange(0, 10, 0.1, dtype=float) - r2 = numpy.arange(0, 50, 0.1, dtype=float) - r3 = numpy.arange(0, 100, 0.1, dtype=float) - r4 = numpy.arange(0, 500, 0.1, dtype=float) - - fr1 = ff(r1) - fr2 = ff(r2) - fr3 = ff(r3) - fr4 = ff(r4) - - d = fr1 - numpy.interp(r1, r2, fr2) - res12 = numpy.dot(d, d) - res12 /= numpy.dot(fr1, fr1) - self.assertAlmostEqual(0, res12, 4) - - d = fr1 - numpy.interp(r1, r3, fr3) - res13 = numpy.dot(d, d) - res13 /= numpy.dot(fr1, fr1) - self.assertAlmostEqual(0, res13, 4) - - d = fr1 - numpy.interp(r1, r4, fr4) - res14 = numpy.dot(d, d) - res14 /= numpy.dot(fr1, fr1) - self.assertAlmostEqual(0, res14, 4) - - d = fr2 - numpy.interp(r2, r3, fr3) - res23 = numpy.dot(d, d) - res23 /= numpy.dot(fr2, fr2) - self.assertAlmostEqual(0, res23, 4) - - d = fr2 - numpy.interp(r2, r4, fr4) - res24 = numpy.dot(d, d) - res24 /= numpy.dot(fr2, fr2) - self.assertAlmostEqual(0, res24, 4) - - d = fr3 - numpy.interp(r3, r4, fr4) - res34 = numpy.dot(d, d) - res34 /= numpy.dot(fr3, fr3) - self.assertAlmostEqual(0, res34, 4) - return +def testSphere(sas_available): + if not sas_available: + pytest.skip("sas package not available") + radius = 25 + # Calculate sphere cf from SphereModel + SphereModel = sasimport("sas.models.SphereModel").SphereModel + model = SphereModel() + model.setParam("radius", radius) + ff = cf.SASCF("sphere", model) + r = numpy.arange(1, 60, 0.1, dtype=float) + fr1 = ff(r) + + # Calculate sphere cf analytically + fr2 = cf.sphericalCF(r, 2 * radius) + diff = fr1 - fr2 + res = numpy.dot(diff, diff) + res /= numpy.dot(fr2, fr2) + assert res == pytest.approx(0, abs=1e-4) + return + + +def testSpheroid(sas_available): + if not sas_available: + pytest.skip("sas package not available") + prad = 20.9 + erad = 33.114 + # Calculate cf from EllipsoidModel + EllipsoidModel = sasimport("sas.models.EllipsoidModel").EllipsoidModel + model = EllipsoidModel() + model.setParam("radius_a", prad) + model.setParam("radius_b", erad) + ff = cf.SASCF("spheroid", model) + r = numpy.arange(0, 100, 1 / numpy.pi, dtype=float) + fr1 = ff(r) + + # Calculate cf analytically + fr2 = cf.spheroidalCF(r, erad, prad) + diff = fr1 - fr2 + res = numpy.dot(diff, diff) + res /= numpy.dot(fr2, fr2) + assert res == pytest.approx(0, abs=1e-4) + return + + +def testShell(sas_available): + if not sas_available: + pytest.skip("sas package not available") + radius = 19.2 + thickness = 7.8 + # Calculate cf from VesicleModel + VesicleModel = sasimport("sas.models.VesicleModel").VesicleModel + model = VesicleModel() + model.setParam("radius", radius) + model.setParam("thickness", thickness) + ff = cf.SASCF("vesicle", model) + r = numpy.arange(0, 99.45, 0.1, dtype=float) + fr1 = ff(r) + + # Calculate sphere cf analytically + fr2 = cf.shellCF(r, radius, thickness) + diff = fr1 - fr2 + res = numpy.dot(diff, diff) + res /= numpy.dot(fr2, fr2) + assert res == pytest.approx(0, abs=1e-4) + return + + +def testCylinder(sas_available): + if not sas_available: + pytest.skip("sas package not available") + """Make sure cylinder works over different r-ranges.""" + radius = 100 + length = 30 + + CylinderModel = sasimport("sas.models.CylinderModel").CylinderModel + model = CylinderModel() + model.setParam("radius", radius) + model.setParam("length", length) + + ff = cf.SASCF("cylinder", model) + + r1 = numpy.arange(0, 10, 0.1, dtype=float) + r2 = numpy.arange(0, 50, 0.1, dtype=float) + r3 = numpy.arange(0, 100, 0.1, dtype=float) + r4 = numpy.arange(0, 500, 0.1, dtype=float) + + fr1 = ff(r1) + fr2 = ff(r2) + fr3 = ff(r3) + fr4 = ff(r4) + + d = fr1 - numpy.interp(r1, r2, fr2) + res12 = numpy.dot(d, d) + res12 /= numpy.dot(fr1, fr1) + assert res12 == pytest.approx(0, abs=1e-4) + + d = fr1 - numpy.interp(r1, r3, fr3) + res13 = numpy.dot(d, d) + res13 /= numpy.dot(fr1, fr1) + assert res13 == pytest.approx(0, abs=1e-4) + + d = fr1 - numpy.interp(r1, r4, fr4) + res14 = numpy.dot(d, d) + res14 /= numpy.dot(fr1, fr1) + assert res14 == pytest.approx(0, abs=1e-4) + + d = fr2 - numpy.interp(r2, r3, fr3) + res23 = numpy.dot(d, d) + res23 /= numpy.dot(fr2, fr2) + assert res23 == pytest.approx(0, abs=1e-4) + + d = fr2 - numpy.interp(r2, r4, fr4) + res24 = numpy.dot(d, d) + res24 /= numpy.dot(fr2, fr2) + assert res24 == pytest.approx(0, abs=1e-4) + + d = fr3 - numpy.interp(r3, r4, fr4) + res34 = numpy.dot(d, d) + res34 /= numpy.dot(fr3, fr3) + assert res34 == pytest.approx(0, abs=1e-4) + return # End of class TestSASCF diff --git a/tests/test_contribution.py b/tests/test_contribution.py index 60616a8b..4461647d 100644 --- a/tests/test_contribution.py +++ b/tests/test_contribution.py @@ -13,9 +13,10 @@ # ############################################################################## """Tests for refinableobj module.""" - import unittest +import numpy as np +import pytest from numpy import arange, array_equal, dot, sin from diffpy.srfit.exceptions import SrFitError @@ -24,8 +25,6 @@ from diffpy.srfit.fitbase.profile import Profile from diffpy.srfit.fitbase.profilegenerator import ProfileGenerator -from .utils import noObserversInGlobalBuilders - class TestContribution(unittest.TestCase): @@ -144,108 +143,6 @@ def testReplacements(self): self.assertEqual(len(xobs2), len(fc.residual())) return - def testResidual(self): - """Test the residual, which requires all other methods.""" - fc = self.fitcontribution - profile = self.profile - gen = self.gen - - # Add the calculator and profile - fc.setProfile(profile) - self.assertTrue(fc.profile is profile) - fc.addProfileGenerator(gen, "I") - self.assertTrue(fc._eq._value is None) - self.assertTrue(fc._reseq._value is None) - self.assertEqual(1, len(fc._generators)) - self.assertTrue(gen.name in fc._generators) - - # Let's create some data - xobs = arange(0, 10, 0.5) - yobs = xobs - profile.setObservedProfile(xobs, yobs) - - # Check our fitting equation. - self.assertTrue(array_equal(fc._eq(), gen(xobs))) - - # Now calculate the residual - chiv = fc.residual() - self.assertAlmostEqual(0, dot(chiv, chiv)) - - # Now change the equation - fc.setEquation("2*I") - self.assertTrue(fc._eq._value is None) - self.assertTrue(fc._reseq._value is None) - chiv = fc.residual() - self.assertAlmostEqual(dot(yobs, yobs), dot(chiv, chiv)) - - # Try to add a parameter - c = Parameter("c", 2) - fc._addParameter(c) - fc.setEquation("c*I") - self.assertTrue(fc._eq._value is None) - self.assertTrue(fc._reseq._value is None) - chiv = fc.residual() - self.assertAlmostEqual(dot(yobs, yobs), dot(chiv, chiv)) - - # Try something more complex - c.setValue(3) - fc.setEquation("c**2*sin(I)") - self.assertTrue(fc._eq._value is None) - self.assertTrue(fc._reseq._value is None) - xobs = arange(0, 10, 0.5) - yobs = 9 * sin(xobs) - profile.setObservedProfile(xobs, yobs) - self.assertTrue(fc._eq._value is None) - self.assertTrue(fc._reseq._value is None) - - chiv = fc.residual() - self.assertAlmostEqual(0, dot(chiv, chiv)) - - # Choose a new residual. - fc.setEquation("2*I") - fc.setResidualEquation("resv") - chiv = fc.residual() - self.assertAlmostEqual( - sum((2 * xobs - yobs) ** 2) / sum(yobs**2), dot(chiv, chiv) - ) - - # Make a custom residual. - fc.setResidualEquation("abs(eq-y)**0.5") - chiv = fc.residual() - self.assertAlmostEqual(sum(abs(2 * xobs - yobs)), dot(chiv, chiv)) - - # Test configuration checks - fc1 = FitContribution("test1") - self.assertRaises(SrFitError, fc1.setResidualEquation, "chiv") - fc1.setProfile(self.profile) - self.assertRaises(SrFitError, fc1.setResidualEquation, "chiv") - fc1.setEquation("A * x") - fc1.setResidualEquation("chiv") - self.assertTrue(noObserversInGlobalBuilders()) - return - - def test_setEquation(self): - """Check replacement of removed parameters.""" - fc = self.fitcontribution - fc.setEquation("x + 5") - fc.x.setValue(2) - self.assertEqual(7, fc.evaluate()) - fc.removeParameter(fc.x) - x = arange(0, 10, 0.5) - fc.newParameter("x", x) - self.assertTrue(array_equal(5 + x, fc.evaluate())) - self.assertTrue(noObserversInGlobalBuilders()) - return - - def test_getEquation(self): - """Check getting the current profile simulation formula.""" - fc = self.fitcontribution - self.assertEqual("", fc.getEquation()) - fc.setEquation("A * sin(x + 5)") - self.assertEqual("(A * sin((x + 5)))", fc.getEquation()) - self.assertTrue(noObserversInGlobalBuilders()) - return - def test_getResidualEquation(self): """Check getting the current formula for residual equation.""" fc = self.fitcontribution @@ -287,5 +184,112 @@ def test_registerFunction(self): return +def testResidual(noObserversInGlobalBuilders): + """Test the residual, which requires all other methods.""" + gen = ProfileGenerator("test") + profile = Profile() + fc = FitContribution("test") + + # Add the calculator and profile + fc.setProfile(profile) + assert fc.profile is profile + fc.addProfileGenerator(gen, "I") + assert fc._eq._value is None + assert fc._reseq._value is None + assert 1 == len(fc._generators) + assert gen.name in fc._generators + + # Let's create some data) + xobs = arange(0, 10, 0.5) + yobs = xobs + profile.setObservedProfile(xobs, yobs) + + # Check our fitting equation. + assert np.array_equal(fc._eq(), gen(xobs)) + + # Now calculate the residual + chiv = fc.residual() + assert dot(chiv, chiv) == pytest.approx(0) + + # Now change the equation + fc.setEquation("2*I") + assert fc._eq._value is None + assert fc._reseq._value is None + chiv = fc.residual() + assert dot(chiv, chiv) == pytest.approx(dot(yobs, yobs)) + + # Try to add a parameter + c = Parameter("c", 2) + fc._addParameter(c) + fc.setEquation("c*I") + assert fc._eq._value is None + assert fc._reseq._value is None + chiv = fc.residual() + assert dot(chiv, chiv) == pytest.approx(dot(yobs, yobs)) + + # Try something more complex + c.setValue(3) + fc.setEquation("c**2*sin(I)") + assert fc._eq._value is None + assert fc._reseq._value is None + xobs = arange(0, 10, 0.5) + yobs = 9 * sin(xobs) + profile.setObservedProfile(xobs, yobs) + assert fc._eq._value is None + assert fc._reseq._value is None + + chiv = fc.residual() + assert dot(chiv, chiv) == pytest.approx(0) + + # Choose a new residual. + fc.setEquation("2*I") + fc.setResidualEquation("resv") + chiv = fc.residual() + assert dot(chiv, chiv) == pytest.approx( + sum((2 * xobs - yobs) ** 2) / sum(yobs**2) + ) + + # Make a custom residual. + fc.setResidualEquation("abs(eq-y)**0.5") + chiv = fc.residual() + assert dot(chiv, chiv) == pytest.approx(sum(abs(2 * xobs - yobs))) + + # Test configuration checks + fc1 = FitContribution("test1") + with pytest.raises(SrFitError): + fc1.setResidualEquation("chiv") + fc1.setProfile(profile) + with pytest.raises(SrFitError): + fc1.setResidualEquation("chiv") + fc1.setEquation("A * x") + fc1.setResidualEquation("chiv") + assert noObserversInGlobalBuilders + return + + +def test_setEquation(noObserversInGlobalBuilders): + """Check replacement of removed parameters.""" + fc = FitContribution("test") + fc.setEquation("x + 5") + fc.x.setValue(2) + assert 7 == fc.evaluate() + fc.removeParameter(fc.x) + x = arange(0, 10, 0.5) + fc.newParameter("x", x) + assert np.array_equal(5 + x, fc.evaluate()) + assert noObserversInGlobalBuilders + return + + +def test_getEquation(noObserversInGlobalBuilders): + """Check getting the current profile simulation formula.""" + fc = FitContribution("test") + assert "" == fc.getEquation() + fc.setEquation("A * sin(x + 5)") + assert "(A * sin((x + 5)))" == fc.getEquation() + assert noObserversInGlobalBuilders + return + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_diffpyparset.py b/tests/test_diffpyparset.py index b4868de8..8463d80d 100644 --- a/tests/test_diffpyparset.py +++ b/tests/test_diffpyparset.py @@ -17,122 +17,121 @@ import pickle import unittest -import numpy - -from .utils import _msg_nostructure, has_structure - -# Global variables to be assigned in setUp -Atom = Lattice = Structure = DiffpyStructureParSet = None - -# ---------------------------------------------------------------------------- - - -@unittest.skipUnless(has_structure, _msg_nostructure) -class TestParameterAdapter(unittest.TestCase): - - def setUp(self): - global Atom, Lattice, Structure, DiffpyStructureParSet - from diffpy.srfit.structure.diffpyparset import DiffpyStructureParSet - from diffpy.structure import Atom, Lattice, Structure - - return - - def testDiffpyStructureParSet(self): - """Test the structure conversion.""" - - a1 = Atom("Cu", xyz=numpy.array([0.0, 0.1, 0.2]), Uisoequiv=0.003) - a2 = Atom("Ag", xyz=numpy.array([0.3, 0.4, 0.5]), Uisoequiv=0.002) - l = Lattice(2.5, 2.5, 2.5, 90, 90, 90) - - dsstru = Structure([a1, a2], l) - # Structure makes copies - a1 = dsstru[0] - a2 = dsstru[1] - - s = DiffpyStructureParSet("CuAg", dsstru) - - self.assertEqual(s.name, "CuAg") - - def _testAtoms(): - # Check the atoms thoroughly - self.assertEqual(a1.element, s.Cu0.element) - self.assertEqual(a2.element, s.Ag0.element) - self.assertEqual(a1.Uisoequiv, s.Cu0.Uiso.getValue()) - self.assertEqual(a2.Uisoequiv, s.Ag0.Uiso.getValue()) - self.assertEqual(a1.Bisoequiv, s.Cu0.Biso.getValue()) - self.assertEqual(a2.Bisoequiv, s.Ag0.Biso.getValue()) - for i in range(1, 4): - for j in range(i, 4): - uijstru = getattr(a1, "U%i%i" % (i, j)) - uij = getattr(s.Cu0, "U%i%i" % (i, j)).getValue() - uji = getattr(s.Cu0, "U%i%i" % (j, i)).getValue() - self.assertEqual(uijstru, uij) - self.assertEqual(uijstru, uji) - bijstru = getattr(a1, "B%i%i" % (i, j)) - bij = getattr(s.Cu0, "B%i%i" % (i, j)).getValue() - bji = getattr(s.Cu0, "B%i%i" % (j, i)).getValue() - self.assertEqual(bijstru, bij) - self.assertEqual(bijstru, bji) - - self.assertEqual(a1.xyz[0], s.Cu0.x.getValue()) - self.assertEqual(a1.xyz[1], s.Cu0.y.getValue()) - self.assertEqual(a1.xyz[2], s.Cu0.z.getValue()) - return - - def _testLattice(): - - # Test the lattice - self.assertEqual(dsstru.lattice.a, s.lattice.a.getValue()) - self.assertEqual(dsstru.lattice.b, s.lattice.b.getValue()) - self.assertEqual(dsstru.lattice.c, s.lattice.c.getValue()) - self.assertEqual(dsstru.lattice.alpha, s.lattice.alpha.getValue()) - self.assertEqual(dsstru.lattice.beta, s.lattice.beta.getValue()) - self.assertEqual(dsstru.lattice.gamma, s.lattice.gamma.getValue()) - - _testAtoms() - _testLattice() - - # Now change some values from the diffpy Structure - a1.xyz[1] = 0.123 - a1.U11 = 0.321 - a1.B32 = 0.111 - dsstru.lattice.setLatPar(a=3.0, gamma=121) - _testAtoms() - _testLattice() - - # Now change values from the srfit DiffpyStructureParSet - s.Cu0.x.setValue(0.456) - s.Cu0.U22.setValue(0.441) - s.Cu0.B13.setValue(0.550) - d = dsstru.lattice.dist(a1.xyz, a2.xyz) - s.lattice.b.setValue(4.6) - s.lattice.alpha.setValue(91.3) - _testAtoms() - _testLattice() - # Make sure the distance changed - self.assertNotEqual(d, dsstru.lattice.dist(a1.xyz, a2.xyz)) +import numpy as np +import pytest + +from diffpy.srfit.structure.diffpyparset import DiffpyStructureParSet + + +def testDiffpyStructureParSet(diffpy_structure_available): + """Test the structure conversion.""" + if not diffpy_structure_available: + pytest.skip("diffpy.structure package not available") + from diffpy.structure import Atom, Lattice, Structure + + a1 = Atom("Cu", xyz=np.array([0.0, 0.1, 0.2]), Uisoequiv=0.003) + a2 = Atom("Ag", xyz=np.array([0.3, 0.4, 0.5]), Uisoequiv=0.002) + lattice = Lattice(2.5, 2.5, 2.5, 90, 90, 90) + + dsstru = Structure([a1, a2], lattice) + # Structure makes copies + a1 = dsstru[0] + a2 = dsstru[1] + + s = DiffpyStructureParSet("CuAg", dsstru) + + assert s.name == "CuAg" + + def _testAtoms(): + # Check the atoms thoroughly + assert a1.element == s.Cu0.element + assert a2.element == s.Ag0.element + assert a1.Uisoequiv == s.Cu0.Uiso.getValue() + assert a2.Uisoequiv == s.Ag0.Uiso.getValue() + assert a1.Bisoequiv == s.Cu0.Biso.getValue() + assert a2.Bisoequiv == s.Ag0.Biso.getValue() + for i in range(1, 4): + for j in range(i, 4): + uijstru = getattr(a1, "U%i%i" % (i, j)) + uij = getattr(s.Cu0, "U%i%i" % (i, j)).getValue() + uji = getattr(s.Cu0, "U%i%i" % (j, i)).getValue() + assert uijstru == uij + assert uijstru == uji + bijstru = getattr(a1, "B%i%i" % (i, j)) + bij = getattr(s.Cu0, "B%i%i" % (i, j)).getValue() + bji = getattr(s.Cu0, "B%i%i" % (j, i)).getValue() + assert bijstru == bij + assert bijstru == bji + + assert a1.xyz[0] == s.Cu0.x.getValue() + assert a1.xyz[1] == s.Cu0.y.getValue() + assert a1.xyz[2] == s.Cu0.z.getValue() return - def test___repr__(self): - """Test representation of DiffpyStructureParSet objects.""" - lat = Lattice(3, 3, 2, 90, 90, 90) - atom = Atom("C", [0, 0.2, 0.5]) - stru = Structure([atom], lattice=lat) - dsps = DiffpyStructureParSet("dsps", stru) - self.assertEqual(repr(stru), repr(dsps)) - self.assertEqual(repr(lat), repr(dsps.lattice)) - self.assertEqual(repr(atom), repr(dsps.atoms[0])) - return - - def test_pickling(self): - """Test pickling of DiffpyStructureParSet.""" - stru = Structure([Atom("C", [0, 0.2, 0.5])]) - dsps = DiffpyStructureParSet("dsps", stru) - data = pickle.dumps(dsps) - dsps2 = pickle.loads(data) - self.assertEqual(1, len(dsps2.atoms)) - self.assertEqual(0.2, dsps2.atoms[0].y.value) - return + def _testLattice(): + + # Test the lattice + assert dsstru.lattice.a == s.lattice.a.getValue() + assert dsstru.lattice.b == s.lattice.b.getValue() + assert dsstru.lattice.c == s.lattice.c.getValue() + assert dsstru.lattice.alpha == s.lattice.alpha.getValue() + assert dsstru.lattice.beta == s.lattice.beta.getValue() + assert dsstru.lattice.gamma == s.lattice.gamma.getValue() + + _testAtoms() + _testLattice() + + # Now change some values from the diffpy Structure + a1.xyz[1] = 0.123 + a1.U11 = 0.321 + a1.B32 = 0.111 + dsstru.lattice.setLatPar(a=3.0, gamma=121) + _testAtoms() + _testLattice() + + # Now change values from the srfit DiffpyStructureParSet + s.Cu0.x.setValue(0.456) + s.Cu0.U22.setValue(0.441) + s.Cu0.B13.setValue(0.550) + d = dsstru.lattice.dist(a1.xyz, a2.xyz) + s.lattice.b.setValue(4.6) + s.lattice.alpha.setValue(91.3) + _testAtoms() + _testLattice() + # Make sure the distance changed + assert d != dsstru.lattice.dist(a1.xyz, a2.xyz) + return + + +def test___repr__(diffpy_structure_available): + """Test representation of DiffpyStructureParSet objects.""" + if not diffpy_structure_available: + pytest.skip("diffpy.structure package not available") + from diffpy.structure import Atom, Lattice, Structure + + lat = Lattice(3, 3, 2, 90, 90, 90) + atom = Atom("C", [0, 0.2, 0.5]) + stru = Structure([atom], lattice=lat) + dsps = DiffpyStructureParSet("dsps", stru) + assert repr(stru) == repr(dsps) + assert repr(lat) == repr(dsps.lattice) + assert repr(atom) == repr(dsps.atoms[0]) + return + + +def test_pickling(diffpy_structure_available): + """Test pickling of DiffpyStructureParSet.""" + if not diffpy_structure_available: + pytest.skip("diffpy.structure package not available") + from diffpy.structure import Atom, Structure + + stru = Structure([Atom("C", [0, 0.2, 0.5])]) + dsps = DiffpyStructureParSet("dsps", stru) + data = pickle.dumps(dsps) + dsps2 = pickle.loads(data) + assert 1 == len(dsps2.atoms) + assert 0.2 == dsps2.atoms[0].y.value + return # End of class TestParameterAdapter diff --git a/tests/test_equation.py b/tests/test_equation.py index 1c9fe6b9..ff44a3d8 100644 --- a/tests/test_equation.py +++ b/tests/test_equation.py @@ -14,181 +14,192 @@ ############################################################################## """Tests for refinableobj module.""" -import unittest +import pytest import diffpy.srfit.equation.literals as literals from diffpy.srfit.equation import Equation -from .utils import _makeArgs, noObserversInGlobalBuilders - - -class TestEquation(unittest.TestCase): - - def testSimpleFunction(self): - """Test a simple function.""" - - # Make some variables - v1, v2, v3, v4, c = _makeArgs(5) - c.name = "c" - c.const = True - - # Make some operations - mult = literals.MultiplicationOperator() - root = mult2 = literals.MultiplicationOperator() - plus = literals.AdditionOperator() - minus = literals.SubtractionOperator() - - # Create the equation c*(v1+v3)*(v4-v2) - plus.addLiteral(v1) - plus.addLiteral(v3) - minus.addLiteral(v4) - minus.addLiteral(v2) - mult.addLiteral(plus) - mult.addLiteral(minus) - mult2.addLiteral(mult) - mult2.addLiteral(c) - - # Set the values of the variables. - # The equation should evaluate to 2.5*(1+3)*(4-2) = 20 - v1.setValue(1) - v2.setValue(2) - v3.setValue(3) - v4.setValue(4) - c.setValue(2.5) - - # Make an equation and test - eq = Equation("eq", mult2) - - self.assertTrue(eq._value is None) - args = eq.args - self.assertTrue(v1 in args) - self.assertTrue(v2 in args) - self.assertTrue(v3 in args) - self.assertTrue(v4 in args) - self.assertTrue(c not in args) - self.assertTrue(root is eq.root) - - self.assertTrue(v1 is eq.v1) - self.assertTrue(v2 is eq.v2) - self.assertTrue(v3 is eq.v3) - self.assertTrue(v4 is eq.v4) - - self.assertEqual(20, eq()) # 20 = 2.5*(1+3)*(4-2) - self.assertEqual(20, eq.getValue()) # same as above - self.assertEqual(20, eq.value) # same as above - self.assertEqual(25, eq(v1=2)) # 25 = 2.5*(2+3)*(4-2) - self.assertEqual(50, eq(v2=0)) # 50 = 2.5*(2+3)*(4-0) - self.assertEqual(30, eq(v3=1)) # 30 = 2.5*(2+1)*(4-0) - self.assertEqual(0, eq(v4=0)) # 20 = 2.5*(2+1)*(0-0) - - # Try some swapping - eq.swap(v4, v1) - self.assertTrue(eq._value is None) - self.assertEqual(15, eq()) # 15 = 2.5*(2+1)*(2-0) - args = eq.args - self.assertTrue(v4 not in args) - - # Try to create a dependency loop - self.assertRaises(ValueError, eq.swap, v1, eq.root) - self.assertRaises(ValueError, eq.swap, v1, plus) - self.assertRaises(ValueError, eq.swap, v1, minus) - self.assertRaises(ValueError, eq.swap, v1, mult) - self.assertRaises(ValueError, eq.swap, v1, root) - - # Swap the root - eq.swap(eq.root, v1) - self.assertTrue(eq._value is None) - self.assertEqual(v1.value, eq()) - - self.assertTrue(noObserversInGlobalBuilders()) - return - - def testEmbeddedEquation(self): - """Test a simple function.""" - - # Make some variables - v1, v2, v3, v4, c = _makeArgs(5) - c.name = "c" - c.const = True - - # Make some operations - mult = literals.MultiplicationOperator() - mult2 = literals.MultiplicationOperator() - plus = literals.AdditionOperator() - minus = literals.SubtractionOperator() - - # Create the equation c*(v1+v3)*(v4-v2) - plus.addLiteral(v1) - plus.addLiteral(v3) - minus.addLiteral(v4) - minus.addLiteral(v2) - mult.addLiteral(plus) - mult.addLiteral(minus) - mult2.addLiteral(mult) - mult2.addLiteral(c) - - # Set the values of the variables. - # The equation should evaluate to 2.5*(1+3)*(4-2) = 20 - v1.setValue(1) - v2.setValue(2) - v3.setValue(3) - v4.setValue(4) - c.setValue(2.5) - - # Make an equation and test - root = Equation("root", mult2) - eq = Equation("eq", root) - - self.assertTrue(eq._value is None) - args = eq.args - self.assertTrue(v1 in args) - self.assertTrue(v2 in args) - self.assertTrue(v3 in args) - self.assertTrue(v4 in args) - self.assertTrue(c not in args) - self.assertTrue(root is eq.root) - - self.assertTrue(v1 is eq.v1) - self.assertTrue(v2 is eq.v2) - self.assertTrue(v3 is eq.v3) - self.assertTrue(v4 is eq.v4) - - # Make sure the right messages get sent - v1.value = 0 - self.assertTrue(root._value is None) - self.assertTrue(eq._value is None) - v1.value = 1 - - self.assertEqual(20, eq()) # 20 = 2.5*(1+3)*(4-2) - self.assertEqual(20, eq.getValue()) # same as above - self.assertEqual(20, eq.value) # same as above - self.assertEqual(25, eq(v1=2)) # 25 = 2.5*(2+3)*(4-2) - self.assertEqual(50, eq(v2=0)) # 50 = 2.5*(2+3)*(4-0) - self.assertEqual(30, eq(v3=1)) # 30 = 2.5*(2+1)*(4-0) - self.assertEqual(0, eq(v4=0)) # 20 = 2.5*(2+1)*(0-0) - - # Try some swapping. - eq.swap(v4, v1) - self.assertTrue(eq._value is None) - self.assertEqual(15, eq()) # 15 = 2.5*(2+1)*(2-0) - args = eq.args - self.assertTrue(v4 not in args) - - # Try to create a dependency loop - self.assertRaises(ValueError, eq.swap, v1, eq.root) - self.assertRaises(ValueError, eq.swap, v1, plus) - self.assertRaises(ValueError, eq.swap, v1, minus) - self.assertRaises(ValueError, eq.swap, v1, mult) - self.assertRaises(ValueError, eq.swap, v1, root) - - # Swap the root - eq.swap(eq.root, v1) - self.assertTrue(eq._value is None) - self.assertEqual(v1.value, eq()) - - self.assertTrue(noObserversInGlobalBuilders()) - return - - -if __name__ == "__main__": - unittest.main() + +def testSimpleFunction(make_args, noObserversInGlobalBuilders): + """Test a simple function.""" + + # Make some variables + v1, v2, v3, v4, c = make_args(5) + c.name = "c" + c.const = True + + # Make some operations + mult = literals.MultiplicationOperator() + root = mult2 = literals.MultiplicationOperator() + plus = literals.AdditionOperator() + minus = literals.SubtractionOperator() + + # Create the equation c*(v1+v3)*(v4-v2) + plus.addLiteral(v1) + plus.addLiteral(v3) + minus.addLiteral(v4) + minus.addLiteral(v2) + mult.addLiteral(plus) + mult.addLiteral(minus) + mult2.addLiteral(mult) + mult2.addLiteral(c) + + # Set the values of the variables. + # The equation should evaluate to 2.5*(1+3)*(4-2) = 20 + v1.setValue(1) + v2.setValue(2) + v3.setValue(3) + v4.setValue(4) + c.setValue(2.5) + + # Make an equation and test + eq = Equation("eq", mult2) + + assert eq._value is None + args = eq.args + assert v1 in args + assert v2 in args + assert v3 in args + assert v4 in args + assert c not in args + assert root is eq.root + + assert v1 is eq.v1 + assert v2 is eq.v2 + assert v3 is eq.v3 + assert v4 is eq.v4 + + assert 20 == eq() # 20 = 2.5*(1+3)*(4-2) + assert 20 == eq.getValue() # same as above + assert 20 == eq.value # same as above + assert 25 == eq(v1=2) # 25 = 2.5*(2+3)*(4-2) + assert 50 == eq(v2=0) # 50 = 2.5*(2+3)*(4-0) + assert 30 == eq(v3=1) # 30 = 2.5*(2+1)*(4-0) + assert 0 == eq(v4=0) # 20 = 2.5*(2+1)*(0-0) + + # Try some swapping + eq.swap(v4, v1) + assert eq._value is None + assert 15 == eq() # 15 = 2.5*(2+1)*(2-0) + args = eq.args + assert v4 not in args + + # Try to create a dependency loop + with pytest.raises(ValueError): + eq.swap(v1, eq.root) + + with pytest.raises(ValueError): + eq.swap(v1, plus) + + with pytest.raises(ValueError): + eq.swap(v1, minus) + + with pytest.raises(ValueError): + eq.swap(v1, mult) + + with pytest.raises(ValueError): + eq.swap(v1, root) + + # Swap the root + eq.swap(eq.root, v1) + assert eq._value is None + assert v1.value, eq() + + assert noObserversInGlobalBuilders + return + + +def testEmbeddedEquation(make_args, noObserversInGlobalBuilders): + """Test a simple function.""" + + # Make some variables + v1, v2, v3, v4, c = make_args(5) + c.name = "c" + c.const = True + + # Make some operations + mult = literals.MultiplicationOperator() + mult2 = literals.MultiplicationOperator() + plus = literals.AdditionOperator() + minus = literals.SubtractionOperator() + + # Create the equation c*(v1+v3)*(v4-v2) + plus.addLiteral(v1) + plus.addLiteral(v3) + minus.addLiteral(v4) + minus.addLiteral(v2) + mult.addLiteral(plus) + mult.addLiteral(minus) + mult2.addLiteral(mult) + mult2.addLiteral(c) + + # Set the values of the variables. + # The equation should evaluate to 2.5*(1+3)*(4-2) = 20 + v1.setValue(1) + v2.setValue(2) + v3.setValue(3) + v4.setValue(4) + c.setValue(2.5) + + # Make an equation and test + root = Equation("root", mult2) + eq = Equation("eq", root) + + assert eq._value is None + args = eq.args + assert v1 in args + assert v2 in args + assert v3 in args + assert v4 in args + assert c not in args + assert root is eq.root + + assert v1 is eq.v1 + assert v2 is eq.v2 + assert v3 is eq.v3 + assert v4 is eq.v4 + + # Make sure the right messages get sent + v1.value = 0 + assert root._value is None + assert eq._value is None + v1.value = 1 + + assert 20 == eq() # 20 = 2.5*(1+3)*(4-2) + assert 20 == eq.getValue() # same as above + assert 20 == eq.value # same as above + assert 25 == eq(v1=2) # 25 = 2.5*(2+3)*(4-2) + assert 50 == eq(v2=0) # 50 = 2.5*(2+3)*(4-0) + assert 30 == eq(v3=1) # 30 = 2.5*(2+1)*(4-0) + assert 0 == eq(v4=0) # 20 = 2.5*(2+1)*(0-0) + + # Try some swapping. + eq.swap(v4, v1) + assert eq._value is None + assert 15 == eq() # 15 = 2.5*(2+1)*(2-0) + args = eq.args + assert v4 not in args + + # Try to create a dependency loop + with pytest.raises(ValueError): + eq.swap(v1, eq.root) + + with pytest.raises(ValueError): + eq.swap(v1, plus) + + with pytest.raises(ValueError): + eq.swap(v1, minus) + + with pytest.raises(ValueError): + eq.swap(v1, mult) + + with pytest.raises(ValueError): + eq.swap(v1, root) + + # Swap the root + eq.swap(eq.root, v1) + assert eq._value is None + assert v1.value == eq() + + assert noObserversInGlobalBuilders + return diff --git a/tests/test_fitrecipe.py b/tests/test_fitrecipe.py index 1945e063..1a2b2368 100644 --- a/tests/test_fitrecipe.py +++ b/tests/test_fitrecipe.py @@ -23,8 +23,6 @@ from diffpy.srfit.fitbase.parameter import Parameter from diffpy.srfit.fitbase.profile import Profile -from .utils import capturestdout - class TestFitRecipe(unittest.TestCase): @@ -239,32 +237,51 @@ def testResidual(self): return - def testPrintFitHook(self): - "check output from default PrintFitHook." - self.recipe.addVar(self.fitcontribution.c) - self.recipe.restrain("c", lb=5) - (pfh,) = self.recipe.getFitHooks() - out = capturestdout(self.recipe.scalarResidual) - self.assertEqual("", out) - pfh.verbose = 1 - out = capturestdout(self.recipe.scalarResidual) - self.assertTrue(out.strip().isdigit()) - self.assertFalse("\nRestraints:" in out) - pfh.verbose = 2 - out = capturestdout(self.recipe.scalarResidual) - self.assertTrue("\nResidual:" in out) - self.assertTrue("\nRestraints:" in out) - self.assertFalse("\nVariables" in out) - pfh.verbose = 3 - out = capturestdout(self.recipe.scalarResidual) - self.assertTrue("\nVariables" in out) - self.assertTrue("c = " in out) - return - # End of class TestFitRecipe + # ---------------------------------------------------------------------------- +def testPrintFitHook(capturestdout): + "check output from default PrintFitHook." + recipe = FitRecipe("recipe") + recipe.fithooks[0].verbose = 0 + + # Set up the Profile + profile = Profile() + x = linspace(0, pi, 10) + y = sin(x) + profile.setObservedProfile(x, y) + + # Set up the FitContribution + fitcontribution = FitContribution("cont") + fitcontribution.setProfile(profile) + fitcontribution.setEquation("A*sin(k*x + c)") + fitcontribution.A.setValue(1) + fitcontribution.k.setValue(1) + fitcontribution.c.setValue(0) + + recipe.addContribution(fitcontribution) + + recipe.addVar(fitcontribution.c) + recipe.restrain("c", lb=5) + (pfh,) = recipe.getFitHooks() + out = capturestdout(recipe.scalarResidual) + assert "" == out + pfh.verbose = 1 + out = capturestdout(recipe.scalarResidual) + assert out.strip().isdigit() + assert "\nRestraints:" not in out + pfh.verbose = 2 + out = capturestdout(recipe.scalarResidual) + assert "\nResidual:" in out + assert "\nRestraints:" in out + assert "\nVariables" not in out + pfh.verbose = 3 + out = capturestdout(recipe.scalarResidual) + assert "\nVariables" in out + assert "c = " in out + return if __name__ == "__main__": diff --git a/tests/test_fitresults.py b/tests/test_fitresults.py index 841c7a73..5702b03f 100644 --- a/tests/test_fitresults.py +++ b/tests/test_fitresults.py @@ -16,64 +16,76 @@ import unittest +import pytest + from diffpy.srfit.fitbase.fitrecipe import FitRecipe from diffpy.srfit.fitbase.fitresults import initializeRecipe -from .utils import datafile +def testInitializeFromFileName(datafile): + recipe = FitRecipe("recipe") + recipe.newVar("A", 0) + recipe.newVar("sig", 0) + recipe.newVar("x0", 0) + filename = datafile("results.res") + Aval = 5.77619823e-01 + sigval = -9.22758690e-01 + x0val = 6.12422115e00 + + assert 0 == recipe.A.value + assert 0 == recipe.sig.value + assert 0 == recipe.x0.value + initializeRecipe(recipe, filename) + assert Aval == pytest.approx(recipe.A.value) + assert sigval == pytest.approx(recipe.sig.value) + assert x0val == pytest.approx(recipe.x0.value) + return -class TestInitializeRecipe(unittest.TestCase): - def setUp(self): - self.recipe = recipe = FitRecipe("recipe") - recipe.newVar("A", 0) - recipe.newVar("sig", 0) - recipe.newVar("x0", 0) - self.filename = datafile("results.res") +def testInitializeFromFileObj(datafile): + recipe = FitRecipe("recipe") + recipe.newVar("A", 0) + recipe.newVar("sig", 0) + recipe.newVar("x0", 0) + filename = datafile("results.res") + Aval = 5.77619823e-01 + sigval = -9.22758690e-01 + x0val = 6.12422115e00 - self.Aval = 5.77619823e-01 - self.sigval = -9.22758690e-01 - self.x0val = 6.12422115e00 - return + assert 0 == recipe.A.value + assert 0 == recipe.sig.value + assert 0 == recipe.x0.value + infile = open(filename, "r") + initializeRecipe(recipe, infile) + assert not infile.closed + infile.close() + assert Aval == pytest.approx(recipe.A.value) + assert sigval == pytest.approx(recipe.sig.value) + assert x0val == pytest.approx(recipe.x0.value) + return - def testInitializeFromFileName(self): - recipe = self.recipe - self.assertEqual(0, recipe.A.value) - self.assertEqual(0, recipe.sig.value) - self.assertEqual(0, recipe.x0.value) - initializeRecipe(recipe, self.filename) - self.assertAlmostEqual(self.Aval, recipe.A.value) - self.assertAlmostEqual(self.sigval, recipe.sig.value) - self.assertAlmostEqual(self.x0val, recipe.x0.value) - return - def testInitializeFromFileObj(self): - recipe = self.recipe - self.assertEqual(0, recipe.A.value) - self.assertEqual(0, recipe.sig.value) - self.assertEqual(0, recipe.x0.value) - infile = open(self.filename, "r") - initializeRecipe(recipe, infile) - self.assertFalse(infile.closed) - infile.close() - self.assertAlmostEqual(self.Aval, recipe.A.value) - self.assertAlmostEqual(self.sigval, recipe.sig.value) - self.assertAlmostEqual(self.x0val, recipe.x0.value) - return +def testInitializeFromString(datafile): + recipe = FitRecipe("recipe") + recipe.newVar("A", 0) + recipe.newVar("sig", 0) + recipe.newVar("x0", 0) + filename = datafile("results.res") + Aval = 5.77619823e-01 + sigval = -9.22758690e-01 + x0val = 6.12422115e00 - def testInitializeFromString(self): - recipe = self.recipe - self.assertEqual(0, recipe.A.value) - self.assertEqual(0, recipe.sig.value) - self.assertEqual(0, recipe.x0.value) - infile = open(self.filename, "r") - resstr = infile.read() - infile.close() - initializeRecipe(recipe, resstr) - self.assertAlmostEqual(self.Aval, recipe.A.value) - self.assertAlmostEqual(self.sigval, recipe.sig.value) - self.assertAlmostEqual(self.x0val, recipe.x0.value) - return + assert 0 == recipe.A.value + assert 0 == recipe.sig.value + assert 0 == recipe.x0.value + infile = open(filename, "r") + resstr = infile.read() + infile.close() + initializeRecipe(recipe, resstr) + assert Aval == pytest.approx(recipe.A.value) + assert sigval == pytest.approx(recipe.sig.value) + assert x0val == pytest.approx(recipe.x0.value) + return if __name__ == "__main__": diff --git a/tests/test_objcrystparset.py b/tests/test_objcrystparset.py index 2e516136..c21a6d28 100644 --- a/tests/test_objcrystparset.py +++ b/tests/test_objcrystparset.py @@ -17,8 +17,7 @@ import unittest import numpy - -from .utils import _msg_nopyobjcryst, has_pyobjcryst +import pytest # Global variables to be assigned in setUp ObjCrystCrystalParSet = spacegroups = None @@ -112,10 +111,13 @@ def makeC60(): # ---------------------------------------------------------------------------- -@unittest.skipUnless(has_pyobjcryst, _msg_nopyobjcryst) -class TestParameterAdapter(unittest.TestCase): +class TestParameterAdapter: + @pytest.fixture(autouse=True) + def setup(self, pyobjcryst_available): + # shared setup + if not pyobjcryst_available: + pytest.skip("pyobjcryst package not available") - def setUp(self): global ObjCrystCrystalParSet, Crystal, Atom, Molecule global ScatteringPowerAtom from pyobjcryst.atom import Atom @@ -134,60 +136,83 @@ def tearDown(self): del self.ocmol return + def testImplicitBondAngleRestraints(self): + """Test the structure with implicit bond angles.""" + occryst = self.occryst + ocmol = self.ocmol + + # Add some bond angles to the molecule + ocmol.AddBondAngle(ocmol[0], ocmol[5], ocmol[8], 1.1, 0.1, 0.1) + ocmol.AddBondAngle(ocmol[0], ocmol[7], ocmol[44], 1.3, 0.1, 0.1) + + # make our crystal + cryst = ObjCrystCrystalParSet("bucky", occryst) + m = cryst.c60 + m.wrapRestraints() + + # make sure that we have some restraints in the molecule + assert 2 == len(m._restraints) + + # make sure these evaluate to whatver we get from objcryst + res0, res1 = m._restraints + p0 = set([res0.penalty(), res1.penalty()]) + angles = ocmol.GetBondAngleList() + p1 = set([angles[0].GetLogLikelihood(), angles[1].GetLogLikelihood()]) + assert p0 == p1 + + return + def testObjCrystParSet(self): """Test the structure conversion.""" occryst = self.occryst ocmol = self.ocmol - cryst = ObjCrystCrystalParSet("bucky", occryst) m = cryst.c60 - self.assertEqual(cryst.name, "bucky") + assert cryst.name == "bucky" def _testCrystal(): - # Test the lattice - self.assertAlmostEqual(occryst.a, cryst.a.value) - self.assertAlmostEqual(occryst.b, cryst.b.getValue()) - self.assertAlmostEqual(occryst.c, cryst.c.getValue()) - self.assertAlmostEqual(occryst.alpha, cryst.alpha.getValue()) - self.assertAlmostEqual(occryst.beta, cryst.beta.getValue()) - self.assertAlmostEqual(occryst.gamma, cryst.gamma.getValue()) - + assert occryst.a == pytest.approx(cryst.a.value) + assert occryst.b == pytest.approx(cryst.b.getValue()) + assert occryst.c == pytest.approx(cryst.c.getValue()) + assert occryst.alpha == pytest.approx(cryst.alpha.getValue()) + assert occryst.beta == pytest.approx(cryst.beta.getValue()) + assert occryst.gamma == pytest.approx(cryst.gamma.getValue()) return def _testMolecule(): # Test position / occupancy - self.assertAlmostEqual(ocmol.X, m.x.getValue()) - self.assertAlmostEqual(ocmol.Y, m.y.getValue()) - self.assertAlmostEqual(ocmol.Z, m.z.getValue()) - self.assertAlmostEqual(ocmol.Occupancy, m.occ.getValue()) + assert ocmol.X == pytest.approx(m.x.getValue()) + assert ocmol.Y == pytest.approx(m.y.getValue()) + assert ocmol.Z == pytest.approx(m.z.getValue()) + assert ocmol.Occupancy == pytest.approx(m.occ.getValue()) # Test orientation - self.assertAlmostEqual(ocmol.Q0, m.q0.getValue()) - self.assertAlmostEqual(ocmol.Q1, m.q1.getValue()) - self.assertAlmostEqual(ocmol.Q2, m.q2.getValue()) - self.assertAlmostEqual(ocmol.Q3, m.q3.getValue()) + assert ocmol.Q0 == pytest.approx(m.q0.getValue()) + assert ocmol.Q1 == pytest.approx(m.q1.getValue()) + assert ocmol.Q2 == pytest.approx(m.q2.getValue()) + assert ocmol.Q3 == pytest.approx(m.q3.getValue()) # Check the atoms thoroughly for i in range(len(ocmol)): oca = ocmol[i] ocsp = oca.GetScatteringPower() a = m.atoms[i] - self.assertEqual(ocsp.GetSymbol(), a.element) - self.assertAlmostEqual(oca.X, a.x.getValue()) - self.assertAlmostEqual(oca.Y, a.y.getValue()) - self.assertAlmostEqual(oca.Z, a.z.getValue()) - self.assertAlmostEqual(oca.Occupancy, a.occ.getValue()) - self.assertAlmostEqual(ocsp.Biso, a.Biso.getValue()) + assert ocsp.GetSymbol() == a.element + assert oca.X == pytest.approx(a.x.getValue()) + assert oca.Y == pytest.approx(a.y.getValue()) + assert oca.Z == pytest.approx(a.z.getValue()) + assert oca.Occupancy == pytest.approx(a.occ.getValue()) + assert ocsp.Biso == pytest.approx(a.Biso.getValue()) return _testCrystal() _testMolecule() - ## Now change some values from ObjCryst + # Now change some values from ObjCryst ocmol[0].X *= 1.1 ocmol[0].Occupancy *= 1.1 ocmol[0].GetScatteringPower().Biso *= 1.1 @@ -197,7 +222,7 @@ def _testMolecule(): _testCrystal() _testMolecule() - ## Now change values from the srfit StructureParSet + # Now change values from the srfit StructureParSet cryst.c60.C44.x.setValue(1.1) cryst.c60.C44.occ.setValue(1.1) cryst.c60.C44.Biso.setValue(1.1) @@ -223,40 +248,14 @@ def testImplicitBondLengthRestraints(self): m.wrapRestraints() # make sure that we have some restraints in the molecule - self.assertTrue(2, len(m._restraints)) + assert 2 == len(m._restraints) # make sure these evaluate to whatver we get from objcryst res0, res1 = m._restraints p0 = set([res0.penalty(), res1.penalty()]) bonds = ocmol.GetBondList() p1 = set([bonds[0].GetLogLikelihood(), bonds[1].GetLogLikelihood()]) - self.assertEqual(p0, p1) - - return - - def testImplicitBondAngleRestraints(self): - """Test the structure with implicit bond angles.""" - occryst = self.occryst - ocmol = self.ocmol - - # Add some bond angles to the molecule - ocmol.AddBondAngle(ocmol[0], ocmol[5], ocmol[8], 1.1, 0.1, 0.1) - ocmol.AddBondAngle(ocmol[0], ocmol[7], ocmol[44], 1.3, 0.1, 0.1) - - # make our crystal - cryst = ObjCrystCrystalParSet("bucky", occryst) - m = cryst.c60 - m.wrapRestraints() - - # make sure that we have some restraints in the molecule - self.assertTrue(2, len(m._restraints)) - - # make sure these evaluate to whatver we get from objcryst - res0, res1 = m._restraints - p0 = set([res0.penalty(), res1.penalty()]) - angles = ocmol.GetBondAngleList() - p1 = set([angles[0].GetLogLikelihood(), angles[1].GetLogLikelihood()]) - self.assertEqual(p0, p1) + assert p0 == p1 return @@ -279,14 +278,14 @@ def testImplicitDihedralAngleRestraints(self): m.wrapRestraints() # make sure that we have some restraints in the molecule - self.assertTrue(2, len(m._restraints)) + assert 2 == len(m._restraints) # make sure these evaluate to whatver we get from objcryst res0, res1 = m._restraints p0 = set([res0.penalty(), res1.penalty()]) angles = ocmol.GetDihedralAngleList() p1 = set([angles[0].GetLogLikelihood(), angles[1].GetLogLikelihood()]) - self.assertEqual(p0, p1) + assert p0 == p1 return @@ -309,14 +308,14 @@ def testExplicitBondLengthRestraints(self): res1 = m.restrainBondLength(m.atoms[0], m.atoms[7], 3.3, 0.1, 0.1) # make sure that we have some restraints in the molecule - self.assertTrue(2, len(m._restraints)) + assert 2 == len(m._restraints) # make sure these evaluate to whatver we get from objcryst p0 = [res0.penalty(), res1.penalty()] bonds = ocmol.GetBondList() - self.assertEqual(2, len(bonds)) + assert 2 == len(bonds) p1 = [b.GetLogLikelihood() for b in bonds] - self.assertEqual(p0, p1) + assert p0 == p1 return @@ -342,13 +341,13 @@ def testExplicitBondAngleRestraints(self): ) # make sure that we have some restraints in the molecule - self.assertTrue(2, len(m._restraints)) + assert 2 == len(m._restraints) # make sure these evaluate to whatver we get from objcryst p0 = set([res0.penalty(), res1.penalty()]) angles = ocmol.GetBondAngleList() p1 = set([angles[0].GetLogLikelihood(), angles[1].GetLogLikelihood()]) - self.assertEqual(p0, p1) + assert p0 == p1 return @@ -370,13 +369,13 @@ def testExplicitDihedralAngleRestraints(self): ) # make sure that we have some restraints in the molecule - self.assertTrue(2, len(m._restraints)) + assert 2 == len(m._restraints) # make sure these evaluate to whatver we get from objcryst p0 = set([res0.penalty(), res1.penalty()]) angles = ocmol.GetDihedralAngleList() p1 = set([angles[0].GetLogLikelihood(), angles[1].GetLogLikelihood()]) - self.assertEqual(p0, p1) + assert p0 == p1 return @@ -405,7 +404,7 @@ def testExplicitBondLengthParameter(self): dd = xyz0 - xyz7 d0 = numpy.dot(dd, dd) ** 0.5 - self.assertAlmostEqual(d0, p1.getValue(), 6) + assert d0 == pytest.approx(p1.getValue(), abs=1e-6) # Record the unit direction of change for later u = dd / d0 @@ -415,7 +414,7 @@ def testExplicitBondLengthParameter(self): p1.setValue(scale * d0) # Verify that it has changed. - self.assertAlmostEqual(scale * d0, p1.getValue()) + assert scale * d0 == pytest.approx(p1.getValue(), abs=1e-6) xyz0a = numpy.array( [a0.x.getValue(), a0.y.getValue(), a0.z.getValue()] @@ -430,19 +429,19 @@ def testExplicitBondLengthParameter(self): dda = xyz0a - xyz7a d1 = numpy.dot(dda, dda) ** 0.5 - self.assertAlmostEqual(scale * d0, d1) + assert scale * d0 == pytest.approx(d1, abs=1e-6) # Verify that only the second and third atoms have moved. - self.assertTrue(numpy.array_equal(xyz0, xyz0a)) + assert numpy.array_equal(xyz0, xyz0a) xyz7calc = xyz7 + (1 - scale) * d0 * u for i in range(3): - self.assertAlmostEqual(xyz7a[i], xyz7calc[i], 6) + assert xyz7a[i] == pytest.approx(xyz7calc[i], abs=1e-5) xyz20calc = xyz20 + (1 - scale) * d0 * u for i in range(3): - self.assertAlmostEqual(xyz20a[i], xyz20calc[i], 6) + assert xyz20a[i] == pytest.approx(xyz20calc[i], abs=1e-6) return @@ -480,14 +479,14 @@ def testExplicitBondAngleParameter(self): # Have another atom tag along for the ride p1.addAtoms([a25]) - self.assertAlmostEqual(angle0, p1.getValue(), 6) + assert angle0 == pytest.approx(p1.getValue(), abs=1e-6) # Change the value scale = 1.05 p1.setValue(scale * angle0) # Verify that it has changed. - self.assertAlmostEqual(scale * angle0, p1.getValue(), 6) + assert scale * angle0 == pytest.approx(p1.getValue(), abs=1e-6) xyz0a = numpy.array( [a0.x.getValue(), a0.y.getValue(), a0.z.getValue()] @@ -509,14 +508,14 @@ def testExplicitBondAngleParameter(self): angle1 = numpy.arccos(numpy.dot(v1a, v2a) / (d1a * d2a)) - self.assertAlmostEqual(scale * angle0, angle1) + assert scale * angle0 == pytest.approx(angle1, abs=1e-6) # Verify that only the last two atoms have moved. - self.assertTrue(numpy.array_equal(xyz0, xyz0a)) - self.assertTrue(numpy.array_equal(xyz7, xyz7a)) - self.assertFalse(numpy.array_equal(xyz20, xyz20a)) - self.assertFalse(numpy.array_equal(xyz25, xyz25a)) + assert numpy.array_equal(xyz0, xyz0a) + assert numpy.array_equal(xyz7, xyz7a) + assert not numpy.array_equal(xyz20, xyz20a) + assert not numpy.array_equal(xyz25, xyz25a) return @@ -561,14 +560,14 @@ def testExplicitDihedralAngleParameter(self): # Have another atom tag along for the ride p1.addAtoms([a33]) - self.assertAlmostEqual(angle0, p1.getValue(), 6) + assert angle0 == pytest.approx(p1.getValue(), abs=1e-6) # Change the value scale = 1.05 p1.setValue(scale * angle0) # Verify that it has changed. - self.assertAlmostEqual(scale * angle0, p1.getValue(), 6) + assert scale * angle0 == pytest.approx(p1.getValue(), abs=1e-6) xyz0a = numpy.array( [a0.x.getValue(), a0.y.getValue(), a0.z.getValue()] @@ -595,34 +594,34 @@ def testExplicitDihedralAngleParameter(self): d123a = numpy.dot(v123a, v123a) ** 0.5 d234a = numpy.dot(v234a, v234a) ** 0.5 angle1 = -numpy.arccos(numpy.dot(v123a, v234a) / (d123a * d234a)) - - self.assertAlmostEqual(scale * angle0, angle1) + assert scale * angle0 == pytest.approx(angle1, abs=1e-6) # Verify that only the last two atoms have moved. - self.assertTrue(numpy.array_equal(xyz0, xyz0a)) - self.assertTrue(numpy.array_equal(xyz7, xyz7a)) - self.assertTrue(numpy.array_equal(xyz20, xyz20a)) - self.assertFalse(numpy.array_equal(xyz25, xyz25a)) - self.assertFalse(numpy.array_equal(xyz33, xyz33a)) + assert numpy.array_equal(xyz0, xyz0a) + assert numpy.array_equal(xyz7, xyz7a) + assert numpy.array_equal(xyz20, xyz20a) + assert not numpy.array_equal(xyz25, xyz25a) + assert not numpy.array_equal(xyz33, xyz33a) return -# End of class TestParameterAdapter - -# ---------------------------------------------------------------------------- - - -@unittest.skipUnless(has_pyobjcryst, _msg_nopyobjcryst) -class TestCreateSpaceGroup(unittest.TestCase): +class TestCreateSpaceGroup: """Test space group creation from pyobjcryst structures. This makes sure that the space groups created by the structure parameter set are correct. """ - def setUp(self): + @pytest.fixture(autouse=True) + def setup(self, diffpy_structure_available, pyobjcryst_available): + # shared setup + if not diffpy_structure_available: + pytest.skip("diffpy.structure package not available") + if not pyobjcryst_available: + pytest.skip("pyobjcryst package not available") + global ObjCrystCrystalParSet, spacegroups from diffpy.srfit.structure.objcrystparset import ObjCrystCrystalParSet from diffpy.structure import spacegroups @@ -665,7 +664,7 @@ def xtestCreateSpaceGroup(self): sg = spacegroups.GetSpaceGroup(shn) sgnew = self.getObjCrystParSetSpaceGroup(sg) # print("dbsg: " + repr(self.sgsEquivalent(sg, sgnew))) - self.assertTrue(self.sgsEquivalent(sg, sgnew)) + assert self.sgsEquivalent(sg, sgnew) return diff --git a/tests/test_pdf.py b/tests/test_pdf.py index f9ad04bc..929b16ed 100644 --- a/tests/test_pdf.py +++ b/tests/test_pdf.py @@ -17,303 +17,301 @@ import io import pickle import unittest +from itertools import chain import numpy +import pytest from diffpy.srfit.exceptions import SrFitError from diffpy.srfit.pdf import PDFContribution, PDFGenerator, PDFParser -from .utils import ( - _msg_nosrreal, - _msg_nostructure, - datafile, - has_srreal, - has_structure, -) - -# ---------------------------------------------------------------------------- - - -class TestPDFParset(unittest.TestCase): - - def setUp(self): - return - - def testParser1(self): - data = datafile("ni-q27r100-neutron.gr") - parser = PDFParser() - parser.parseFile(data) - - meta = parser._meta - - self.assertEqual(data, meta["filename"]) - self.assertEqual(1, meta["nbanks"]) - self.assertEqual("N", meta["stype"]) - self.assertEqual(27, meta["qmax"]) - self.assertEqual(300, meta.get("temperature")) - self.assertEqual(None, meta.get("qdamp")) - self.assertEqual(None, meta.get("qbroad")) - self.assertEqual(None, meta.get("spdiameter")) - self.assertEqual(None, meta.get("scale")) - self.assertEqual(None, meta.get("doping")) - - x, y, dx, dy = parser.getData() - self.assertTrue(dx is None) - self.assertTrue(dy is None) - - testx = numpy.linspace(0.01, 100, 10000) - diff = testx - x - res = numpy.dot(diff, diff) - self.assertAlmostEqual(0, res) - - testy = numpy.array( - [ - 1.144, - 2.258, - 3.312, - 4.279, - 5.135, - 5.862, - 6.445, - 6.875, - 7.150, - 7.272, - ] - ) - diff = testy - y[:10] - res = numpy.dot(diff, diff) - self.assertAlmostEqual(0, res) - - return - - def testParser2(self): - data = datafile("si-q27r60-xray.gr") - parser = PDFParser() - parser.parseFile(data) - - meta = parser._meta - - self.assertEqual(data, meta["filename"]) - self.assertEqual(1, meta["nbanks"]) - self.assertEqual("X", meta["stype"]) - self.assertEqual(27, meta["qmax"]) - self.assertEqual(300, meta.get("temperature")) - self.assertEqual(None, meta.get("qdamp")) - self.assertEqual(None, meta.get("qbroad")) - self.assertEqual(None, meta.get("spdiameter")) - self.assertEqual(None, meta.get("scale")) - self.assertEqual(None, meta.get("doping")) - - x, y, dx, dy = parser.getData() - testx = numpy.linspace(0.01, 60, 5999, endpoint=False) - diff = testx - x - res = numpy.dot(diff, diff) - self.assertAlmostEqual(0, res) - - testy = numpy.array( - [ - 0.1105784, - 0.2199684, - 0.3270088, - 0.4305913, - 0.5296853, - 0.6233606, - 0.7108060, - 0.7913456, - 0.8644501, - 0.9297440, - ] - ) - diff = testy - y[:10] - res = numpy.dot(diff, diff) - self.assertAlmostEqual(0, res) - - testdy = numpy.array( - [ - 0.001802192, - 0.003521449, - 0.005079115, - 0.006404892, - 0.007440527, - 0.008142955, - 0.008486813, - 0.008466340, - 0.008096858, - 0.007416456, - ] - ) - diff = testdy - dy[:10] - res = numpy.dot(diff, diff) - self.assertAlmostEqual(0, res) - - self.assertTrue(dx is None) - return - - -# End of class TestPDFParset - # ---------------------------------------------------------------------------- -@unittest.skipUnless(has_srreal, _msg_nosrreal) -@unittest.skipUnless(has_structure, _msg_nostructure) -class TestPDFGenerator(unittest.TestCase): - - def setUp(self): - self.gen = PDFGenerator() - return - - def testGenerator(self): - qmax = 27.0 - gen = self.gen - gen.setScatteringType("N") - self.assertEqual("N", gen.getScatteringType()) - gen.setQmax(qmax) - self.assertAlmostEqual(qmax, gen.getQmax()) - from diffpy.structure import PDFFitStructure - - stru = PDFFitStructure() - ciffile = datafile("ni.cif") - stru.read(ciffile) - for i in range(4): - stru[i].Bisoequiv = 1 - gen.setStructure(stru) - - calc = gen._calc - # Test parameters - for par in gen.iterPars(recurse=False): - pname = par.name - defval = calc._getDoubleAttr(pname) - self.assertEqual(defval, par.getValue()) - # Test setting values - par.setValue(1.0) - self.assertEqual(1.0, par.getValue()) - par.setValue(defval) - self.assertEqual(defval, par.getValue()) - - r = numpy.arange(0, 10, 0.1) - y = gen(r) - - # Now create a reference PDF. Since the calculator is testing its - # output, we just have to make sure we can calculate from the - # PDFGenerator interface. - from diffpy.srreal.pdfcalculator import PDFCalculator - - calc = PDFCalculator() - calc.rstep = r[1] - r[0] - calc.rmin = r[0] - calc.rmax = r[-1] + 0.5 * calc.rstep - calc.qmax = qmax - calc.setScatteringFactorTableByType("N") - calc.eval(stru) - yref = calc.pdf - - diff = y - yref - res = numpy.dot(diff, diff) - self.assertAlmostEqual(0, res) - return - - def test_setQmin(self): - """Verify qmin is propagated to the calculator object.""" - gen = self.gen - self.assertEqual(0, gen.getQmin()) - self.assertEqual(0, gen._calc.qmin) - gen.setQmin(0.93) - self.assertEqual(0.93, gen.getQmin()) - self.assertEqual(0.93, gen._calc.qmin) - return - - -# End of class TestPDFGenerator - -# ---------------------------------------------------------------------------- - - -@unittest.skipUnless(has_srreal, _msg_nosrreal) -@unittest.skipUnless(has_structure, _msg_nostructure) -class TestPDFContribution(unittest.TestCase): - - def setUp(self): - self.pc = PDFContribution("pdf") - return - - def test_setQmax(self): - """Check PDFContribution.setQmax()""" - from diffpy.structure import Structure - - pc = self.pc - pc.setQmax(21) - pc.addStructure("empty", Structure()) - self.assertEqual(21, pc.empty.getQmax()) - pc.setQmax(22) - self.assertEqual(22, pc.getQmax()) - self.assertEqual(22, pc.empty.getQmax()) - return - - def test_getQmax(self): - """Check PDFContribution.getQmax()""" - from diffpy.structure import Structure - - # cover all code branches in PDFContribution._getMetaValue - # (1) contribution metadata - pc1 = self.pc - self.assertIsNone(pc1.getQmax()) - pc1.setQmax(17) - self.assertEqual(17, pc1.getQmax()) - # (2) contribution metadata - pc2 = PDFContribution("pdf") - pc2.addStructure("empty", Structure()) - pc2.empty.setQmax(18) - self.assertEqual(18, pc2.getQmax()) - # (3) profile metadata - pc3 = PDFContribution("pdf") - pc3.profile.meta["qmax"] = 19 - self.assertEqual(19, pc3.getQmax()) - return - - def test_savetxt(self): - "check PDFContribution.savetxt()" - from diffpy.structure import Structure - - pc = self.pc - pc.loadData(datafile("si-q27r60-xray.gr")) - pc.setCalculationRange(0, 10) - pc.addStructure("empty", Structure()) - fp = io.BytesIO() - self.assertRaises(SrFitError, pc.savetxt, fp) - pc.evaluate() +def testParser1(datafile): + data = datafile("ni-q27r100-neutron.gr") + parser = PDFParser() + parser.parseFile(data) + + meta = parser._meta + + assert data == meta["filename"] + assert 1 == meta["nbanks"] + assert "N" == meta["stype"] + assert 27 == meta["qmax"] + assert 300 == meta.get("temperature") + assert meta.get("qdamp") is None + assert meta.get("qbroad") is None + assert meta.get("spdiameter") is None + assert meta.get("scale") is None + assert meta.get("doping") is None + + x, y, dx, dy = parser.getData() + assert dx is None + assert dy is None + + testx = numpy.linspace(0.01, 100, 10000) + diff = testx - x + res = numpy.dot(diff, diff) + assert 0 == pytest.approx(res) + + testy = numpy.array( + [ + 1.144, + 2.258, + 3.312, + 4.279, + 5.135, + 5.862, + 6.445, + 6.875, + 7.150, + 7.272, + ] + ) + diff = testy - y[:10] + res = numpy.dot(diff, diff) + assert 0 == pytest.approx(res) + + return + + +def testParser2(datafile): + data = datafile("si-q27r60-xray.gr") + parser = PDFParser() + parser.parseFile(data) + + meta = parser._meta + + assert data == meta["filename"] + assert 1 == meta["nbanks"] + assert "X" == meta["stype"] + assert 27 == meta["qmax"] + assert 300 == meta.get("temperature") + assert meta.get("qdamp") is None + assert meta.get("qbroad") is None + assert meta.get("spdiameter") is None + assert meta.get("scale") is None + assert meta.get("doping") is None + + x, y, dx, dy = parser.getData() + testx = numpy.linspace(0.01, 60, 5999, endpoint=False) + diff = testx - x + res = numpy.dot(diff, diff) + assert 0 == pytest.approx(res) + + testy = numpy.array( + [ + 0.1105784, + 0.2199684, + 0.3270088, + 0.4305913, + 0.5296853, + 0.6233606, + 0.7108060, + 0.7913456, + 0.8644501, + 0.9297440, + ] + ) + diff = testy - y[:10] + res = numpy.dot(diff, diff) + assert 0 == pytest.approx(res) + + testdy = numpy.array( + [ + 0.001802192, + 0.003521449, + 0.005079115, + 0.006404892, + 0.007440527, + 0.008142955, + 0.008486813, + 0.008466340, + 0.008096858, + 0.007416456, + ] + ) + diff = testdy - dy[:10] + res = numpy.dot(diff, diff) + assert 0 == pytest.approx(res) + + assert dx is None + return + + +def testGenerator( + diffpy_srreal_available, diffpy_structure_available, datafile +): + if not diffpy_structure_available: + pytest.skip("diffpy.structure package not available") + if not diffpy_srreal_available: + pytest.skip("diffpy.srreal package not available") + + from diffpy.srreal.pdfcalculator import PDFCalculator + from diffpy.structure import PDFFitStructure + + qmax = 27.0 + gen = PDFGenerator() + gen.setScatteringType("N") + assert "N" == gen.getScatteringType() + gen.setQmax(qmax) + assert qmax == pytest.approx(gen.getQmax()) + + stru = PDFFitStructure() + ciffile = datafile("ni.cif") + stru.read(ciffile) + for i in range(4): + stru[i].Bisoequiv = 1 + gen.setStructure(stru) + + calc = gen._calc + # Test parameters + for par in gen.iterPars(recurse=False): + pname = par.name + defval = calc._getDoubleAttr(pname) + assert defval == par.getValue() + # Test setting values + par.setValue(1.0) + assert 1.0 == par.getValue() + par.setValue(defval) + assert defval == par.getValue() + + r = numpy.arange(0, 10, 0.1) + y = gen(r) + + # Now create a reference PDF. Since the calculator is testing its + # output, we just have to make sure we can calculate from the + # PDFGenerator interface. + + calc = PDFCalculator() + calc.rstep = r[1] - r[0] + calc.rmin = r[0] + calc.rmax = r[-1] + 0.5 * calc.rstep + calc.qmax = qmax + calc.setScatteringFactorTableByType("N") + calc.eval(stru) + yref = calc.pdf + + diff = y - yref + res = numpy.dot(diff, diff) + assert 0 == pytest.approx(res) + return + + +def test_setQmin(diffpy_structure_available, diffpy_srreal_available): + """Verify qmin is propagated to the calculator object.""" + if not diffpy_srreal_available: + pytest.skip("diffpy.srreal package not available") + + gen = PDFGenerator() + assert 0 == gen.getQmin() + assert 0 == gen._calc.qmin + gen.setQmin(0.93) + assert 0.93 == gen.getQmin() + assert 0.93 == gen._calc.qmin + return + + +def test_setQmax(diffpy_structure_available, diffpy_srreal_available): + """Check PDFContribution.setQmax()""" + if not diffpy_structure_available: + pytest.skip("diffpy.structure package not available") + from diffpy.structure import Structure + + if not diffpy_srreal_available: + pytest.skip("diffpy.structure package not available") + + pc = PDFContribution("pdf") + pc.setQmax(21) + pc.addStructure("empty", Structure()) + assert 21 == pc.empty.getQmax() + pc.setQmax(22) + assert 22 == pc.getQmax() + assert 22 == pc.empty.getQmax() + return + + +def test_getQmax(diffpy_structure_available, diffpy_srreal_available): + """Check PDFContribution.getQmax()""" + if not diffpy_structure_available: + pytest.skip("diffpy.structure package not available") + from diffpy.structure import Structure + + if not diffpy_srreal_available: + pytest.skip("diffpy.structure package not available") + + # cover all code branches in PDFContribution._getMetaValue + # (1) contribution metadata + pc1 = PDFContribution("pdf") + assert pc1.getQmax() is None + pc1.setQmax(17) + assert 17 == pc1.getQmax() + # (2) contribution metadata + pc2 = PDFContribution("pdf") + pc2.addStructure("empty", Structure()) + pc2.empty.setQmax(18) + assert 18 == pc2.getQmax() + # (3) profile metadata + pc3 = PDFContribution("pdf") + pc3.profile.meta["qmax"] = 19 + assert 19 == pc3.getQmax() + return + + +def test_savetxt( + diffpy_structure_available, diffpy_srreal_available, datafile +): + "check PDFContribution.savetxt()" + if not diffpy_structure_available: + pytest.skip("diffpy.structure package not available") + from diffpy.structure import Structure + + if not diffpy_srreal_available: + pytest.skip("diffpy.structure package not available") + + pc = PDFContribution("pdf") + pc.loadData(datafile("si-q27r60-xray.gr")) + pc.setCalculationRange(0, 10) + pc.addStructure("empty", Structure()) + fp = io.BytesIO() + with pytest.raises(SrFitError): pc.savetxt(fp) - txt = fp.getvalue().decode() - nlines = len(txt.strip().split("\n")) - self.assertEqual(1001, nlines) - return - - def test_pickling(self): - "validate PDFContribution.residual() after pickling." - from itertools import chain - - from diffpy.structure import loadStructure - - pc = self.pc - pc.loadData(datafile("ni-q27r100-neutron.gr")) - ni = loadStructure(datafile("ni.cif")) - ni.Uisoequiv = 0.003 - pc.addStructure("ni", ni) - pc.setCalculationRange(0, 10) - pc2 = pickle.loads(pickle.dumps(pc)) - res0 = pc.residual() - self.assertTrue(numpy.array_equal(res0, pc2.residual())) - for p in chain(pc.iterPars("Uiso"), pc2.iterPars("Uiso")): - p.value = 0.004 - res1 = pc.residual() - self.assertFalse(numpy.allclose(res0, res1)) - self.assertTrue(numpy.array_equal(res1, pc2.residual())) - return - - -# End of class TestPDFContribution + pc.evaluate() + pc.savetxt(fp) + txt = fp.getvalue().decode() + nlines = len(txt.strip().split("\n")) + assert 1001 == nlines + return + + +def test_pickling( + diffpy_structure_available, diffpy_srreal_available, datafile +): + "validate PDFContribution.residual() after pickling." + if not diffpy_structure_available: + pytest.skip("diffpy.structure package not available") + from diffpy.structure import loadStructure + + if not diffpy_srreal_available: + pytest.skip("diffpy.structure package not available") + + pc = PDFContribution("pdf") + pc.loadData(datafile("ni-q27r100-neutron.gr")) + ni = loadStructure(datafile("ni.cif")) + ni.Uisoequiv = 0.003 + pc.addStructure("ni", ni) + pc.setCalculationRange(0, 10) + pc2 = pickle.loads(pickle.dumps(pc)) + res0 = pc.residual() + assert numpy.array_equal(res0, pc2.residual()) + for p in chain(pc.iterPars("Uiso"), pc2.iterPars("Uiso")): + p.value = 0.004 + res1 = pc.residual() + assert not numpy.allclose(res0, res1) + assert numpy.array_equal(res1, pc2.residual()) + return -# ---------------------------------------------------------------------------- if __name__ == "__main__": unittest.main() diff --git a/tests/test_profile.py b/tests/test_profile.py index f19adef6..29da2162 100644 --- a/tests/test_profile.py +++ b/tests/test_profile.py @@ -18,13 +18,12 @@ import re import unittest +import pytest from numpy import allclose, arange, array, array_equal, ones_like from diffpy.srfit.exceptions import SrFitError from diffpy.srfit.fitbase.profile import Profile -from .utils import datafile - class TestProfile(unittest.TestCase): @@ -178,37 +177,6 @@ def testSetCalculationPoints(self): return - def testLoadtxt(self): - """Test the loadtxt method.""" - - prof = self.profile - data = datafile("testdata.txt") - - def _test(p): - self.assertAlmostEqual(1e-2, p.x[0]) - self.assertAlmostEqual(1.105784e-1, p.y[0]) - self.assertAlmostEqual(1.802192e-3, p.dy[0]) - - # Test normal load - prof.loadtxt(data, usecols=(0, 1, 3)) - _test(prof) - - # Test trying to not set unpack - prof.loadtxt(data, usecols=(0, 1, 3), unpack=False) - _test(prof) - prof.loadtxt(data, float, "#", None, None, 0, (0, 1, 3), False) - _test(prof) - - # Try not including dy - prof.loadtxt(data, usecols=(0, 1)) - self.assertAlmostEqual(1e-2, prof.x[0]) - self.assertAlmostEqual(1.105784e-1, prof.y[0]) - self.assertAlmostEqual(1, prof.dy[0]) - - # Try to include too little - self.assertRaises(ValueError, prof.loadtxt, data, usecols=(0,)) - return - def test_savetxt(self): "Check the savetxt method." prof = self.profile @@ -226,9 +194,38 @@ def test_savetxt(self): return -# End of class TestProfile +def testLoadtxt(datafile): + """Test the loadtxt method.""" + + prof = Profile() + data = datafile("testdata.txt") + + def _test(p): + assert 1e-2 == pytest.approx(p.x[0]) + assert 1.105784e-1 == pytest.approx(p.y[0]) + assert 1.802192e-3 == pytest.approx(p.dy[0]) + + # Test normal load + prof.loadtxt(data, usecols=(0, 1, 3)) + _test(prof) + + # Test trying to not set unpack + prof.loadtxt(data, usecols=(0, 1, 3), unpack=False) + _test(prof) + prof.loadtxt(data, float, "#", None, None, 0, (0, 1, 3), False) + _test(prof) + + # Try not including dy + prof.loadtxt(data, usecols=(0, 1)) + assert 1e-2 == pytest.approx(prof.x[0]) + assert 1.105784e-1 == pytest.approx(prof.y[0]) + assert 1 == pytest.approx(prof.dy[0]) + + # Try to include too little + with pytest.raises(ValueError): + prof.loadtxt(data, usecols=(0,)) + return -# ---------------------------------------------------------------------------- if __name__ == "__main__": unittest.main() diff --git a/tests/test_recipeorganizer.py b/tests/test_recipeorganizer.py index 60b2e4da..88909b91 100644 --- a/tests/test_recipeorganizer.py +++ b/tests/test_recipeorganizer.py @@ -27,8 +27,6 @@ equationFromString, ) -from .utils import capturestdout - # ---------------------------------------------------------------------------- @@ -519,52 +517,57 @@ def test_releaseOldEquations(self): self.assertEqual(0, len(self.m._eqfactory.equations)) return - def test_show(self): - """Verify output from the show function.""" - - def capture_show(*args, **kwargs): - rv = capturestdout(self.m.show, *args, **kwargs) - return rv - - self.assertEqual("", capture_show()) - self.m._newParameter("x", 1) - self.m._newParameter("y", 2) - out1 = capture_show() - lines1 = out1.strip().split("\n") - self.assertEqual(4, len(lines1)) - self.assertTrue("Parameters" in lines1) - self.assertFalse("Constraints" in lines1) - self.assertFalse("Restraints" in lines1) - self.m._newParameter("z", 7) - self.m.constrain("y", "3 * z") - out2 = capture_show() - lines2 = out2.strip().split("\n") - self.assertEqual(9, len(lines2)) - self.assertTrue("Parameters" in lines2) - self.assertTrue("Constraints" in lines2) - self.assertFalse("Restraints" in lines2) - self.m.restrain("z", lb=2, ub=3, sig=0.001) - out3 = capture_show() - lines3 = out3.strip().split("\n") - self.assertEqual(13, len(lines3)) - self.assertTrue("Parameters" in lines3) - self.assertTrue("Constraints" in lines3) - self.assertTrue("Restraints" in lines3) - out4 = capture_show(pattern="x") - lines4 = out4.strip().split("\n") - self.assertEqual(9, len(lines4)) - out5 = capture_show(pattern="^") - self.assertEqual(out3, out5) - # check output with another level of hierarchy - self.m._addObject(RecipeOrganizer("foo"), self.m._containers) - self.m.foo._newParameter("bar", 13) - out6 = capture_show() - self.assertTrue("foo.bar" in out6) - # filter out foo.bar - out7 = capture_show("^(?!foo).") - self.assertFalse("foo.bar" in out7) - self.assertEqual(out3, out7) - return + +def test_show(capturestdout): + """Verify output from the show function.""" + organizer = RecipeOrganizer("test") + # Add a managed container so we can do more in-depth tests. + organizer._containers = {} + organizer._manage(organizer._containers) + + def capture_show(*args, **kwargs): + rv = capturestdout(organizer.show, *args, **kwargs) + return rv + + assert "" == capture_show() + organizer._newParameter("x", 1) + organizer._newParameter("y", 2) + out1 = capture_show() + lines1 = out1.strip().split("\n") + assert 4 == len(lines1) + assert "Parameters" in lines1 + assert "Constraints" not in lines1 + assert "Restraints" not in lines1 + organizer._newParameter("z", 7) + organizer.constrain("y", "3 * z") + out2 = capture_show() + lines2 = out2.strip().split("\n") + assert 9 == len(lines2) + assert "Parameters" in lines2 + assert "Constraints" in lines2 + assert "Restraints" not in lines2 + organizer.restrain("z", lb=2, ub=3, sig=0.001) + out3 = capture_show() + lines3 = out3.strip().split("\n") + assert 13 == len(lines3) + assert "Parameters" in lines3 + assert "Constraints" in lines3 + assert "Restraints" in lines3 + out4 = capture_show(pattern="x") + lines4 = out4.strip().split("\n") + assert 9 == len(lines4) + out5 = capture_show(pattern="^") + assert out3 == out5 + # check output with another level of hierarchy + organizer._addObject(RecipeOrganizer("foo"), organizer._containers) + organizer.foo._newParameter("bar", 13) + out6 = capture_show() + assert "foo.bar" in out6 + # filter out foo.bar + out7 = capture_show("^(?!foo).") + assert "foo.bar" not in out7 + assert out3 == out7 + return # ---------------------------------------------------------------------------- diff --git a/tests/test_sas.py b/tests/test_sas.py index 110c863e..3a7ee226 100644 --- a/tests/test_sas.py +++ b/tests/test_sas.py @@ -14,167 +14,151 @@ ############################################################################## """Tests for sas package.""" -import unittest - import numpy +import pytest from diffpy.srfit.sas import SASGenerator, SASParser, SASProfile from diffpy.srfit.sas.sasimport import sasimport -from .utils import _msg_nosas, datafile, has_sas - # ---------------------------------------------------------------------------- - - -@unittest.skipUnless(has_sas, _msg_nosas) -class TestSASParser(unittest.TestCase): - - def testParser(self): - data = datafile("sas_ascii_test_1.txt") - parser = SASParser() - parser.parseFile(data) - - x, y, dx, dy = parser.getData() - - testx = numpy.array( - [ - 0.002618, - 0.007854, - 0.01309, - 0.01832, - 0.02356, - 0.02879, - 0.03402, - 0.03925, - 0.04448, - 0.0497, - ] - ) - diff = testx - x - res = numpy.dot(diff, diff) - self.assertAlmostEqual(0, res) - - testy = numpy.array( - [ - 0.02198, - 0.02201, - 0.02695, - 0.02645, - 0.03024, - 0.3927, - 7.305, - 17.43, - 13.43, - 8.346, - ] - ) - diff = testy - y - res = numpy.dot(diff, diff) - self.assertAlmostEqual(0, res) - - testdy = numpy.array( - [ - 0.002704, - 0.001643, - 0.002452, - 0.001769, - 0.001531, - 0.1697, - 1.006, - 0.5351, - 0.3677, - 0.191, - ] - ) - diff = testdy - dy - res = numpy.dot(diff, diff) - self.assertAlmostEqual(0, res) - - testdx = numpy.array( - [ - 0.0004091, - 0.005587, - 0.005598, - 0.005624, - 0.005707, - 0.005975, - 0.006264, - 0.006344, - 0.006424, - 0.006516, - ] - ) - diff = testdx - dx - res = numpy.dot(diff, diff) - self.assertAlmostEqual(0, res) - - return +# FIXME: adjust sensitivity of the pytest.approx statements when ready to test +# with sasview installed. + + +def testParser(sas_available, datafile): + if not sas_available: + pytest.skip("sas package not available") + + data = datafile("sas_ascii_test_1.txt") + parser = SASParser() + parser.parseFile(data) + x, y, dx, dy = parser.getData() + testx = numpy.array( + [ + 0.002618, + 0.007854, + 0.01309, + 0.01832, + 0.02356, + 0.02879, + 0.03402, + 0.03925, + 0.04448, + 0.0497, + ] + ) + diff = testx - x + res = numpy.dot(diff, diff) + assert 0 == pytest.approx(res) + + testy = numpy.array( + [ + 0.02198, + 0.02201, + 0.02695, + 0.02645, + 0.03024, + 0.3927, + 7.305, + 17.43, + 13.43, + 8.346, + ] + ) + diff = testy - y + res = numpy.dot(diff, diff) + assert 0 == pytest.approx(res) + + testdy = numpy.array( + [ + 0.002704, + 0.001643, + 0.002452, + 0.001769, + 0.001531, + 0.1697, + 1.006, + 0.5351, + 0.3677, + 0.191, + ] + ) + diff = testdy - dy + res = numpy.dot(diff, diff) + assert 0 == pytest.approx(res) + + testdx = numpy.array( + [ + 0.0004091, + 0.005587, + 0.005598, + 0.005624, + 0.005707, + 0.005975, + 0.006264, + 0.006344, + 0.006424, + 0.006516, + ] + ) + diff = testdx - dx + res = numpy.dot(diff, diff) + assert 0 == pytest.approx(res) + return # End of class TestSASParser -# ---------------------------------------------------------------------------- - -@unittest.skipUnless(has_sas, _msg_nosas) -class TestSASGenerator(unittest.TestCase): - - def testGenerator(self): - - # Test generator output - SphereModel = sasimport("sas.models.SphereModel").SphereModel - model = SphereModel() - gen = SASGenerator("sphere", model) - - for pname in model.params: - defval = model.getParam(pname) - par = gen.get(pname) - self.assertEqual(defval, par.getValue()) - # Test setting values - par.setValue(1.0) - self.assertEqual(1.0, par.getValue()) - self.assertEqual(1.0, model.getParam(pname)) - par.setValue(defval) - self.assertEqual(defval, par.getValue()) - self.assertEqual(defval, model.getParam(pname)) - - r = numpy.arange(1, 10, 0.1, dtype=float) - y = gen(r) - refy = model.evalDistribution(r) - diff = y - refy - res = numpy.dot(diff, diff) - self.assertAlmostEqual(0, res) - - return - - def testGenerator2(self): - - # Test generator with a profile - EllipsoidModel = sasimport("sas.models.EllipsoidModel").EllipsoidModel - model = EllipsoidModel() - gen = SASGenerator("ellipsoid", model) - - # Load the data using SAS tools - Loader = sasimport("sas.dataloader.loader").Loader - loader = Loader() - data = datafile("sas_ellipsoid_testdata.txt") - datainfo = loader.load(data) - profile = SASProfile(datainfo) - - gen.setProfile(profile) - gen.scale.value = 1.0 - gen.radius_a.value = 20 - gen.radius_b.value = 400 - gen.background.value = 0.01 - - y = gen(profile.xobs) - diff = profile.yobs - y - res = numpy.dot(diff, diff) - self.assertAlmostEqual(0, res) - return - - -# End of class TestSASGenerator - -if __name__ == "__main__": - unittest.main() +def test_generator(sas_available): + if not sas_available: + pytest.skip("sas package not available") + SphereModel = sasimport("sas.models.SphereModel").SphereModel + model = SphereModel() + gen = SASGenerator("sphere", model) + for pname in model.params: + defval = model.getParam(pname) + par = gen.get(pname) + assert defval == par.getValue() + # Test setting values + par.setValue(1.0) + assert 1.0 == par.getValue() + assert 1.0 == model.getParam(pname) + par.setValue(defval) + assert defval == par.getValue() + assert defval == model.getParam(pname) + + r = numpy.arange(1, 10, 0.1, dtype=float) + y = gen(r) + refy = model.evalDistribution(r) + diff = y - refy + res = numpy.dot(diff, diff) + assert 0 == pytest.approx(res) + return + + +def testGenerator2(sas_available, datafile): + if not sas_available: + pytest.skip("sas package not available") + EllipsoidModel = sasimport("sas.models.EllipsoidModel").EllipsoidModel + model = EllipsoidModel() + gen = SASGenerator("ellipsoid", model) + + # Load the data using SAS tools + Loader = sasimport("sas.dataloader.loader").Loader + loader = Loader() + data = datafile("sas_ellipsoid_testdata.txt") + datainfo = loader.load(data) + profile = SASProfile(datainfo) + + gen.setProfile(profile) + gen.scale.value = 1.0 + gen.radius_a.value = 20 + gen.radius_b.value = 400 + gen.background.value = 0.01 + + y = gen(profile.xobs) + diff = profile.yobs - y + res = numpy.dot(diff, diff) + assert 0 == pytest.approx(res) + return diff --git a/tests/test_sgconstraints.py b/tests/test_sgconstraints.py index f057ab5d..f0696321 100644 --- a/tests/test_sgconstraints.py +++ b/tests/test_sgconstraints.py @@ -17,189 +17,189 @@ import unittest import numpy - -from .utils import ( - _msg_nopyobjcryst, - _msg_nostructure, - datafile, - has_pyobjcryst, - has_structure, -) +import pytest # ---------------------------------------------------------------------------- -class TestSGConstraints(unittest.TestCase): - - @unittest.skipUnless(has_pyobjcryst, _msg_nopyobjcryst) - def test_ObjCryst_constrainSpaceGroup(self): - """Make sure that all Parameters are constrained properly. - - This tests constrainSpaceGroup from - diffpy.srfit.structure.sgconstraints, which is performed - automatically when an ObjCrystCrystalParSet is created. - """ - from diffpy.srfit.structure.objcrystparset import ObjCrystCrystalParSet - - pi = numpy.pi - - occryst = makeLaMnO3() - stru = ObjCrystCrystalParSet(occryst.GetName(), occryst) - # Make sure we actually create the constraints - stru._constrainSpaceGroup() - # Make the space group parameters individually - stru.sgpars.latpars - stru.sgpars.xyzpars - stru.sgpars.adppars - - # Check the orthorhombic lattice - l = stru.getLattice() - self.assertTrue(l.alpha.const) - self.assertTrue(l.beta.const) - self.assertTrue(l.gamma.const) - self.assertEqual(pi / 2, l.alpha.getValue()) - self.assertEqual(pi / 2, l.beta.getValue()) - self.assertEqual(pi / 2, l.gamma.getValue()) - - self.assertFalse(l.a.const) - self.assertFalse(l.b.const) - self.assertFalse(l.c.const) - self.assertEqual(0, len(l._constraints)) - - # Now make sure the scatterers are constrained properly - scatterers = stru.getScatterers() - la = scatterers[0] - self.assertFalse(la.x.const) - self.assertFalse(la.y.const) - self.assertTrue(la.z.const) - self.assertEqual(0, len(la._constraints)) - - mn = scatterers[1] - self.assertTrue(mn.x.const) - self.assertTrue(mn.y.const) - self.assertTrue(mn.z.const) - self.assertEqual(0, len(mn._constraints)) - - o1 = scatterers[2] - self.assertFalse(o1.x.const) - self.assertFalse(o1.y.const) - self.assertTrue(o1.z.const) - self.assertEqual(0, len(o1._constraints)) - - o2 = scatterers[3] - self.assertFalse(o2.x.const) - self.assertFalse(o2.y.const) - self.assertFalse(o2.z.const) - self.assertEqual(0, len(o2._constraints)) - - # Make sure we can't constrain these - self.assertRaises(ValueError, mn.constrain, mn.x, "y") - self.assertRaises(ValueError, mn.constrain, mn.y, "z") - self.assertRaises(ValueError, mn.constrain, mn.z, "x") - - # Nor can we make them into variables - from diffpy.srfit.fitbase.fitrecipe import FitRecipe - - f = FitRecipe() - self.assertRaises(ValueError, f.addVar, mn.x) - - return - - @unittest.skipUnless(has_structure, _msg_nostructure) - def test_DiffPy_constrainAsSpaceGroup(self): - """Test the constrainAsSpaceGroup function.""" - from diffpy.srfit.structure.diffpyparset import DiffpyStructureParSet - from diffpy.srfit.structure.sgconstraints import constrainAsSpaceGroup - - stru = makeLaMnO3_P1() - parset = DiffpyStructureParSet("LaMnO3", stru) - - sgpars = constrainAsSpaceGroup( - parset, - "P b n m", - scatterers=parset.getScatterers()[::2], - constrainadps=True, - ) - - # Make sure that the new parameters were created - for par in sgpars: - self.assertNotEqual(None, par) - self.assertNotEqual(None, par.getValue()) - - # Test the unconstrained atoms - for scatterer in parset.getScatterers()[1::2]: - self.assertFalse(scatterer.x.const) - self.assertFalse(scatterer.y.const) - self.assertFalse(scatterer.z.const) - self.assertFalse(scatterer.U11.const) - self.assertFalse(scatterer.U22.const) - self.assertFalse(scatterer.U33.const) - self.assertFalse(scatterer.U12.const) - self.assertFalse(scatterer.U13.const) - self.assertFalse(scatterer.U23.const) - self.assertEqual(0, len(scatterer._constraints)) - - proxied = [p.par for p in sgpars] - - def _consttest(par): - return par.const - - def _constrainedtest(par): - return par.constrained - - def _proxytest(par): - return par in proxied - - def _alltests(par): - return _consttest(par) or _constrainedtest(par) or _proxytest(par) - - for idx, scatterer in enumerate(parset.getScatterers()[::2]): - # Under this scheme, atom 6 is free to vary - test = False - for par in [scatterer.x, scatterer.y, scatterer.z]: - test |= _alltests(par) - self.assertTrue(test) - - test = False - for par in [ - scatterer.U11, - scatterer.U22, - scatterer.U33, - scatterer.U12, - scatterer.U13, - scatterer.U23, - ]: - test |= _alltests(par) - - self.assertTrue(test) - - return - - @unittest.skipUnless(has_structure, _msg_nostructure) - def test_ConstrainAsSpaceGroup_args(self): - """Test the arguments processing of constrainAsSpaceGroup function.""" - from diffpy.srfit.structure.diffpyparset import DiffpyStructureParSet - from diffpy.srfit.structure.sgconstraints import constrainAsSpaceGroup - from diffpy.structure.spacegroups import GetSpaceGroup - - stru = makeLaMnO3_P1() - parset = DiffpyStructureParSet("LaMnO3", stru) - sgpars = constrainAsSpaceGroup(parset, "P b n m") - sg = GetSpaceGroup("P b n m") - parset2 = DiffpyStructureParSet("LMO", makeLaMnO3_P1()) - sgpars2 = constrainAsSpaceGroup(parset2, sg) - list(sgpars) - list(sgpars2) - self.assertEqual(sgpars.names, sgpars2.names) - return - - -# End of class TestSGConstraints - -# Local helper functions ----------------------------------------------------- - - -def makeLaMnO3_P1(): +def test_ObjCryst_constrainSpaceGroup(pyobjcryst_available): + """Make sure that all Parameters are constrained properly. + + This tests constrainSpaceGroup from + diffpy.srfit.structure.sgconstraints, which is performed + automatically when an ObjCrystCrystalParSet is created. + """ + if not pyobjcryst_available: + pytest.skip("pyobjcrysta package not available") + + from diffpy.srfit.structure.objcrystparset import ObjCrystCrystalParSet + + pi = numpy.pi + + occryst = makeLaMnO3() + stru = ObjCrystCrystalParSet(occryst.GetName(), occryst) + # Make sure we actually create the constraints + stru._constrainSpaceGroup() + # Make the space group parameters individually + stru.sgpars.latpars + stru.sgpars.xyzpars + stru.sgpars.adppars + + # Check the orthorhombic lattice + lattice = stru.getLattice() + assert lattice.alpha.const + assert lattice.beta.const + assert lattice.gamma.const + assert pi / 2 == lattice.alpha.getValue() + assert pi / 2 == lattice.beta.getValue() + assert pi / 2 == lattice.gamma.getValue() + + assert not lattice.a.const + assert not lattice.b.const + assert not lattice.c.const + assert 0 == len(lattice._constraints) + + # Now make sure the scatterers are constrained properly + scatterers = stru.getScatterers() + la = scatterers[0] + assert not la.x.const + assert not la.y.const + assert la.z.const + assert 0 == len(la._constraints) + + mn = scatterers[1] + assert mn.x.const + assert mn.y.const + assert mn.z.const + assert 0 == len(mn._constraints) + + o1 = scatterers[2] + assert not o1.x.const + assert not o1.y.const + assert o1.z.const + assert 0 == len(o1._constraints) + + o2 = scatterers[3] + assert not o2.x.const + assert not o2.y.const + assert not o2.z.const + assert 0 == len(o2._constraints) + + # Make sure we can't constrain these + with pytest.raises(ValueError): + mn.constrain(mn.x, "y") + + with pytest.raises(ValueError): + mn.constrain(mn.y, "z") + + with pytest.raises(ValueError): + mn.constrain(mn.z, "x") + + # Nor can we make them into variables + from diffpy.srfit.fitbase.fitrecipe import FitRecipe + + f = FitRecipe() + with pytest.raises(ValueError): + f.addVar(mn.x) + + return + + +def test_DiffPy_constrainAsSpaceGroup(datafile, pyobjcryst_available): + """Test the constrainAsSpaceGroup function.""" + if not pyobjcryst_available: + pytest.skip("pyobjcrysta package not available") + + from diffpy.srfit.structure.diffpyparset import DiffpyStructureParSet + from diffpy.srfit.structure.sgconstraints import constrainAsSpaceGroup + + stru = makeLaMnO3_P1(datafile) + parset = DiffpyStructureParSet("LaMnO3", stru) + + sgpars = constrainAsSpaceGroup( + parset, + "P b n m", + scatterers=parset.getScatterers()[::2], + constrainadps=True, + ) + + # Make sure that the new parameters were created + for par in sgpars: + assert par is not None + assert par.getValue() is not None + + # Test the unconstrained atoms + for scatterer in parset.getScatterers()[1::2]: + assert not scatterer.x.const + assert not scatterer.y.const + assert not scatterer.z.const + assert not scatterer.U11.const + assert not scatterer.U22.const + assert not scatterer.U33.const + assert not scatterer.U12.const + assert not scatterer.U13.const + assert not scatterer.U23.const + assert 0 == len(scatterer._constraints) + + proxied = [p.par for p in sgpars] + + def _consttest(par): + return par.const + + def _constrainedtest(par): + return par.constrained + + def _proxytest(par): + return par in proxied + + def _alltests(par): + return _consttest(par) or _constrainedtest(par) or _proxytest(par) + + for idx, scatterer in enumerate(parset.getScatterers()[::2]): + # Under this scheme, atom 6 is free to vary + test = False + for par in [scatterer.x, scatterer.y, scatterer.z]: + test |= _alltests(par) + assert test + + test = False + for par in [ + scatterer.U11, + scatterer.U22, + scatterer.U33, + scatterer.U12, + scatterer.U13, + scatterer.U23, + ]: + test |= _alltests(par) + + assert test + + return + + +def test_ConstrainAsSpaceGroup_args(pyobjcryst_available, datafile): + """Test the arguments processing of constrainAsSpaceGroup function.""" + if not pyobjcryst_available: + pytest.skip("pyobjcrysta package not available") + + from diffpy.srfit.structure.diffpyparset import DiffpyStructureParSet + from diffpy.srfit.structure.sgconstraints import constrainAsSpaceGroup + from diffpy.structure.spacegroups import GetSpaceGroup + + stru = makeLaMnO3_P1(datafile) + parset = DiffpyStructureParSet("LaMnO3", stru) + sgpars = constrainAsSpaceGroup(parset, "P b n m") + sg = GetSpaceGroup("P b n m") + parset2 = DiffpyStructureParSet("LMO", makeLaMnO3_P1(datafile)) + sgpars2 = constrainAsSpaceGroup(parset2, sg) + list(sgpars) + list(sgpars2) + assert sgpars.names == sgpars2.names + return + + +def makeLaMnO3_P1(datafile): from diffpy.structure import Structure stru = Structure() diff --git a/tests/test_speed.py b/tests/test_speed.py index 19af219b..6bc75014 100644 --- a/tests/test_speed.py +++ b/tests/test_speed.py @@ -14,8 +14,6 @@ ############################################################################## """Tests for refinableobj module.""" -from __future__ import print_function - import random import numpy @@ -23,16 +21,14 @@ import diffpy.srfit.equation.literals as literals import diffpy.srfit.equation.visitors as visitors -from .utils import _makeArgs - x = numpy.arange(0, 20, 0.05) -def makeLazyEquation(): +def makeLazyEquation(make_args): """Make a lazy equation and see how fast it is.""" # Make some variables - v1, v2, v3, v4, v5, v6, v7 = _makeArgs(7) + v1, v2, v3, v4, v5, v6, v7 = make_args(7) # Make some operations mult = literals.MultiplicationOperator() @@ -463,14 +459,17 @@ def profileTest(): return -if __name__ == "__main__": - for i in range(1, 13): - speedTest2(i) - """ - for i in range(1, 9): - weightedTest(i) - """ - """From diffpy.srfit.equation.builder import EquationFactory import random - import cProfile cProfile.run('profileTest()', 'prof') import pstats p = - pstats.Stats('prof') p.strip_dirs() p.sort_stats('time') p.print_stats(10) - profileTest()""" +# if __name__ == "__main__": +# for i in range(1, 13): +# speedTest2(i) +# """ +# for i in range(1, 9): +# weightedTest(i) +# """ +# """From diffpy.srfit.equation.builder import +# EquationFactory import random +# import cProfile cProfile.run('profileTest()', 'prof') +# import pstats p = +# pstats.Stats('prof') p.strip_dirs() p.sort_stats('time') +# p.print_stats(10) +# profileTest()""" diff --git a/tests/test_visitors.py b/tests/test_visitors.py index 58a7ffcd..546edaaa 100644 --- a/tests/test_visitors.py +++ b/tests/test_visitors.py @@ -16,19 +16,22 @@ import unittest +import pytest + import diffpy.srfit.equation.literals as literals import diffpy.srfit.equation.visitors as visitors -from .utils import _makeArgs - -class TestValidator(unittest.TestCase): +class TestValidator: + @pytest.fixture(autouse=True) + def setup(self, make_args): + self.make_args = make_args def testSimpleFunction(self): """Test a simple function.""" # Make some variables - v1, v2, v3, v4 = _makeArgs(4) + v1, v2, v3, v4 = self.make_args(4) # Make some operations mult = literals.MultiplicationOperator() @@ -51,25 +54,25 @@ def testSimpleFunction(self): # Now validate validator = visitors.Validator() mult.identify(validator) - self.assertEqual(4, len(validator.errors)) + assert 4 == len(validator.errors) # Fix the equation minus.addLiteral(v3) validator.reset() mult.identify(validator) - self.assertEqual(3, len(validator.errors)) + assert 3 == len(validator.errors) # Fix the name of plus plus.name = "add" validator.reset() mult.identify(validator) - self.assertEqual(2, len(validator.errors)) + assert 2 == len(validator.errors) # Fix the symbol of plus plus.symbol = "+" validator.reset() mult.identify(validator) - self.assertEqual(1, len(validator.errors)) + assert 1 == len(validator.errors) # Fix the operation of plus import numpy @@ -77,24 +80,27 @@ def testSimpleFunction(self): plus.operation = numpy.add validator.reset() mult.identify(validator) - self.assertEqual(0, len(validator.errors)) + assert 0 == len(validator.errors) # Add another literal to minus minus.addLiteral(v1) validator.reset() mult.identify(validator) - self.assertEqual(1, len(validator.errors)) + assert 1 == len(validator.errors) return -class TestArgFinder(unittest.TestCase): +class TestArgFinder: + @pytest.fixture(autouse=True) + def setup(self, make_args): + self.make_args = make_args def testSimpleFunction(self): """Test a simple function.""" # Make some variables - v1, v2, v3, v4 = _makeArgs(4) + v1, v2, v3, v4 = self.make_args(4) # Make some operations mult = literals.MultiplicationOperator() @@ -118,33 +124,36 @@ def testSimpleFunction(self): # now get the args args = visitors.getArgs(mult) - self.assertEqual(4, len(args)) - self.assertTrue(v1 in args) - self.assertTrue(v2 in args) - self.assertTrue(v3 in args) - self.assertTrue(v4 in args) + assert 4 == len(args) + assert v1 in args + assert v2 in args + assert v3 in args + assert v4 in args return def testArg(self): """Test just an Argument equation.""" # Make some variables - v1 = _makeArgs(1)[0] + v1 = self.make_args(1)[0] args = visitors.getArgs(v1) - self.assertEqual(1, len(args)) - self.assertTrue(args[0] is v1) + assert 1 == len(args) + assert args[0] == v1 return -class TestSwapper(unittest.TestCase): +class TestSwapper: + @pytest.fixture(autouse=True) + def setup(self, make_args): + self.make_args = make_args def testSimpleFunction(self): """Test a simple function.""" # Make some variables - v1, v2, v3, v4, v5 = _makeArgs(5) + v1, v2, v3, v4, v5 = self.make_args(5) # Make some operations mult = literals.MultiplicationOperator() @@ -168,43 +177,44 @@ def testSimpleFunction(self): v5.setValue(5) # Evaluate - self.assertEqual(8, mult.value) + assert 8 == mult.value # Now swap an argument visitors.swap(mult, v2, v5) # Check that the operator value is invalidated - self.assertTrue(mult._value is None) - self.assertFalse(v2.hasObserver(minus._flush)) - self.assertTrue(v5.hasObserver(minus._flush)) + assert mult._value is None + assert not v2.hasObserver(minus._flush) + assert v5.hasObserver(minus._flush) # now get the args args = visitors.getArgs(mult) - self.assertEqual(4, len(args)) - self.assertTrue(v1 in args) - self.assertTrue(v2 not in args) - self.assertTrue(v3 in args) - self.assertTrue(v4 in args) - self.assertTrue(v5 in args) + assert 4 == len(args) + assert v1 in args + assert v2 not in args + assert v3 in args + assert v4 in args + assert v5 in args # Re-evaluate (1+3)*(4-5) = -4 - self.assertEqual(-4, mult.value) + assert -4 == mult.value # Swap out the "-" operator plus2 = literals.AdditionOperator() visitors.swap(mult, minus, plus2) - self.assertTrue(mult._value is None) - self.assertFalse(minus.hasObserver(mult._flush)) - self.assertTrue(plus2.hasObserver(mult._flush)) + assert mult._value is None + assert not minus.hasObserver(mult._flush) + assert plus2.hasObserver(mult._flush) # plus2 has no arguments yet. Verify this. - self.assertRaises(TypeError, mult.getValue) + with pytest.raises(TypeError): + mult.getValue() # Add the arguments to plus2. plus2.addLiteral(v4) plus2.addLiteral(v5) # Re-evaluate (1+3)*(4+5) = 36 - self.assertEqual(36, mult.value) + assert 36 == mult.value return diff --git a/tests/utils.py b/tests/utils.py index 79a77353..4574cf27 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -20,93 +20,11 @@ import diffpy.srfit.equation.literals as literals from diffpy.srfit.sas.sasimport import sasimport - -from . import logger - -# Resolve availability of optional third-party packages. - -# srfit-sasview or sasview - -try: - _msg_nosas = "No module named 'sas.pr.invertor'" - sasimport("sas.pr.invertor") - _msg_nosas = "No module named 'sas.models'" - sasimport("sas.models") - has_sas = True -except ImportError as e: - has_sas = False - logger.warning("%s, SaS tests skipped.", e) - -# diffpy.structure - -_msg_nostructure = "No module named 'diffpy.structure'" -try: - import diffpy.structure as m - - del m - has_structure = True -except ImportError: - has_structure = False - logger.warning("Cannot import diffpy.structure, Structure tests skipped.") - -# pyobjcryst - -_msg_nopyobjcryst = "No module named 'pyobjcryst'" -try: - import pyobjcryst as m - - del m - has_pyobjcryst = True -except ImportError: - has_pyobjcryst = False - logger.warning("Cannot import pyobjcryst, pyobjcryst tests skipped.") - -# diffpy.srreal - -_msg_nosrreal = "No module named 'diffpy.srreal'" -try: - import diffpy.srreal.pdfcalculator as m - - del m - has_srreal = True -except ImportError: - has_srreal = False - logger.warning("Cannot import diffpy.srreal, PDF tests skipped.") +from tests import logger # Helper functions for testing ----------------------------------------------- -def _makeArgs(num): - args = [] - for i in range(num): - j = i + 1 - args.append(literals.Argument(name="v%i" % j, value=j)) - return args - - -def noObserversInGlobalBuilders(): - """True if no observer function leaks to global builder objects. - - Ensure objects are not immortal due to a reference from static - value. - """ - from diffpy.srfit.equation.builder import _builders - - rv = True - for n, b in _builders.items(): - if b.literal and b.literal._observers: - rv = False - break - return rv - - -def datafile(filename): - from pkg_resources import resource_filename - - rv = resource_filename(__name__, "testdata/" + filename) - return rv - - def capturestdout(f, *args, **kwargs): """Capture the standard output from a call of function f.""" savestdout = sys.stdout