diff --git a/gpfit/plot_fit.py b/gpfit/plot_fit.py index 1dba6d7..fe28bc2 100644 --- a/gpfit/plot_fit.py +++ b/gpfit/plot_fit.py @@ -7,7 +7,7 @@ # pylint: disable=invalid-name # pylint: disable=too-many-locals -def plot_fit_1d(udata, wdata, K=1, fitclass='MA', plotspace='log'): +def plot_fit_1d(udata, wdata, K=1, fitclass='SMA', plotspace='log'): "Finds and plots a fit (MA or SMA) for 1D data" cstrt, _ = fit(np.log(udata), np.log(wdata), K, fitclass) @@ -50,11 +50,10 @@ def plot_fit_1d(udata, wdata, K=1, fitclass='MA', plotspace='log'): ax.set_xlabel('u') ax.legend(['Data'] + stringlist, loc='best') - plt.show() return f, ax if __name__ == "__main__": N = 51 U = np.logspace(0, np.log10(3), N) W = (U**2+3) / (U+1)**2 - plot_fit_1d(U, W, K=2, fitclass='MA', plotspace="linear") + plot_fit_1d(U, W, K=2, fitclass='SMA', plotspace="linear") diff --git a/gpfit/tests/run_tests.py b/gpfit/tests/run_tests.py index e3d7fe7..b9deb24 100644 --- a/gpfit/tests/run_tests.py +++ b/gpfit/tests/run_tests.py @@ -26,6 +26,9 @@ from gpfit.tests import t_print_fit TESTS += t_print_fit.TESTS +from gpfit.tests import t_plot_fit +TESTS += t_plot_fit.TESTS + from gpfit.tests import t_ex6_3 TESTS += t_ex6_3.TESTS diff --git a/gpfit/tests/t_plot_fit.py b/gpfit/tests/t_plot_fit.py new file mode 100644 index 0000000..12e6e38 --- /dev/null +++ b/gpfit/tests/t_plot_fit.py @@ -0,0 +1,26 @@ +"unit tests for gpfit.print_fit module" +import unittest +import numpy as np +from gpfit.plot_fit import plot_fit_1d + + +class TestPlotFit(unittest.TestCase): + "Unit tests for plot_fit_1d" + + def test_plot_fit_1d(self): + N = 51 + U = np.logspace(0, np.log10(3), N) + W = (U**2+3) / (U+1)**2 + plot_fit_1d(U, W, K=2, fitclass='SMA', plotspace="linear") + + +TESTS = [TestPlotFit] + +if __name__ == '__main__': + SUITE = unittest.TestSuite() + LOADER = unittest.TestLoader() + + for t in TESTS: + SUITE.addTests(LOADER.loadTestsFromTestCase(t)) + + unittest.TextTestRunner(verbosity=2).run(SUITE)