Skip to content

Commit

Permalink
Merge pull request #97 from convexengineering/another_plotfit_fix
Browse files Browse the repository at this point in the history
plotfit fix
  • Loading branch information
pgkirsch committed Jul 9, 2021
2 parents bd3b9ca + 48d8eeb commit 28094ed
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 8 deletions.
15 changes: 7 additions & 8 deletions gpfit/plot_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

# pylint: disable=invalid-name
# pylint: disable=too-many-locals
def plot_fit_1d(udata, wdata, K=1, fitclass='MA', plotspace='log'):
def plot_fit_1d(udata, wdata, K=1, fitclass='SMA', plotspace='log'):
"Finds and plots a fit (MA or SMA) for 1D data"

cstrt, _ = fit(np.log(udata), np.log(wdata), K, fitclass)
Expand All @@ -24,11 +24,11 @@ def plot_fit_1d(udata, wdata, K=1, fitclass='MA', plotspace='log'):


if fitclass == 'SMA':
wexps, = cstrt.left.exps
alpha, = list(wexps.values())
uvarkey, = cstrt.right.varkeys
A = [d[uvarkey]/alpha for d in cstrt.right.exps]
B = np.log(cstrt.right.cs) / alpha
wexps, = cstrt[0].left.exps
alpha = list(wexps.values())[0]
uvarkey, = cstrt[0].right.varkeys
A = [d[uvarkey]/alpha for d in cstrt[0].right.exps]
B = np.log(cstrt[0].right.cs) / alpha

ww = 0
for k in range(K):
Expand All @@ -50,11 +50,10 @@ def plot_fit_1d(udata, wdata, K=1, fitclass='MA', plotspace='log'):
ax.set_xlabel('u')
ax.legend(['Data'] + stringlist, loc='best')

plt.show()
return f, ax

if __name__ == "__main__":
N = 51
U = np.logspace(0, np.log10(3), N)
W = (U**2+3) / (U+1)**2
plot_fit_1d(U, W, K=2, fitclass='MA', plotspace="linear")
plot_fit_1d(U, W, K=2, fitclass='SMA', plotspace="linear")
3 changes: 3 additions & 0 deletions gpfit/tests/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
from gpfit.tests import t_print_fit
TESTS += t_print_fit.TESTS

from gpfit.tests import t_plot_fit
TESTS += t_plot_fit.TESTS

from gpfit.tests import t_ex6_3
TESTS += t_ex6_3.TESTS

Expand Down
26 changes: 26 additions & 0 deletions gpfit/tests/t_plot_fit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"unit tests for gpfit.print_fit module"
import unittest
import numpy as np
from gpfit.plot_fit import plot_fit_1d


class TestPlotFit(unittest.TestCase):
"Unit tests for plot_fit_1d"

def test_plot_fit_1d(self):
N = 51
U = np.logspace(0, np.log10(3), N)
W = (U**2+3) / (U+1)**2
plot_fit_1d(U, W, K=2, fitclass='SMA', plotspace="linear")


TESTS = [TestPlotFit]

if __name__ == '__main__':
SUITE = unittest.TestSuite()
LOADER = unittest.TestLoader()

for t in TESTS:
SUITE.addTests(LOADER.loadTestsFromTestCase(t))

unittest.TextTestRunner(verbosity=2).run(SUITE)

0 comments on commit 28094ed

Please sign in to comment.