Skip to content

Commit

Permalink
add unit test for plot_fit_1d
Browse files Browse the repository at this point in the history
Also:
* Default SMA instead of MA
* Don't show plot
  • Loading branch information
pgkirsch committed Jul 9, 2021
1 parent bb25aee commit 425517a
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 3 deletions.
5 changes: 2 additions & 3 deletions gpfit/plot_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

# pylint: disable=invalid-name
# pylint: disable=too-many-locals
def plot_fit_1d(udata, wdata, K=1, fitclass='MA', plotspace='log'):
def plot_fit_1d(udata, wdata, K=1, fitclass='SMA', plotspace='log'):
"Finds and plots a fit (MA or SMA) for 1D data"

cstrt, _ = fit(np.log(udata), np.log(wdata), K, fitclass)
Expand Down Expand Up @@ -50,11 +50,10 @@ def plot_fit_1d(udata, wdata, K=1, fitclass='MA', plotspace='log'):
ax.set_xlabel('u')
ax.legend(['Data'] + stringlist, loc='best')

plt.show()
return f, ax

if __name__ == "__main__":
N = 51
U = np.logspace(0, np.log10(3), N)
W = (U**2+3) / (U+1)**2
plot_fit_1d(U, W, K=2, fitclass='MA', plotspace="linear")
plot_fit_1d(U, W, K=2, fitclass='SMA', plotspace="linear")
3 changes: 3 additions & 0 deletions gpfit/tests/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
from gpfit.tests import t_print_fit
TESTS += t_print_fit.TESTS

from gpfit.tests import t_plot_fit
TESTS += t_plot_fit.TESTS

from gpfit.tests import t_ex6_3
TESTS += t_ex6_3.TESTS

Expand Down
26 changes: 26 additions & 0 deletions gpfit/tests/t_plot_fit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"unit tests for gpfit.print_fit module"
import unittest
import numpy as np
from gpfit.plot_fit import plot_fit_1d


class TestPlotFit(unittest.TestCase):
"Unit tests for plot_fit_1d"

def test_plot_fit_1d(self):
N = 51
U = np.logspace(0, np.log10(3), N)
W = (U**2+3) / (U+1)**2
plot_fit_1d(U, W, K=2, fitclass='SMA', plotspace="linear")


TESTS = [TestPlotFit]

if __name__ == '__main__':
SUITE = unittest.TestSuite()
LOADER = unittest.TestLoader()

for t in TESTS:
SUITE.addTests(LOADER.loadTestsFromTestCase(t))

unittest.TextTestRunner(verbosity=2).run(SUITE)

0 comments on commit 425517a

Please sign in to comment.