Skip to content

Commit

Permalink
Fix issue #31.
Browse files Browse the repository at this point in the history
  • Loading branch information
matthias-baer committed Dec 21, 2020
1 parent f896ad4 commit 05cff82
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 7 deletions.
2 changes: 2 additions & 0 deletions findiff/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ def apply(self, rhs, *args, **kwargs):
def stencil(self, shape, h=None, acc=None, old_stl=None):
if h is None and self.spac is not None:
h = self.spac
if acc is None and self.acc is not None:
acc = self.acc
return self.pds.stencil(shape, h, acc, old_stl)

def matrix(self, shape, h=None, acc=None):
Expand Down
48 changes: 41 additions & 7 deletions test/test_bugs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import sys

sys.path.insert(1, '..')

import unittest
Expand All @@ -10,9 +11,8 @@
class TestOldBugs(unittest.TestCase):

def test_findiff_should_raise_exception_when_applied_to_unevaluated_function(self):

def f(x, y):
return 5*x**2 - 5*x + 10*y**2 -10*y
return 5 * x ** 2 - 5 * x + 10 * y ** 2 - 10 * y

d_dx = FinDiff(1, 0.01)
self.assertRaises(ValueError, lambda ff: d_dx(ff), f)
Expand All @@ -31,17 +31,51 @@ def test_high_accuracy_results_in_type_error(self):

def test_matrix_repr_with_different_accs(self):
# issue 28
shape = (11, )
shape = (11,)
d1 = findiff.FinDiff(0, 1, 2).matrix(shape)
d2 = findiff.FinDiff(0, 1, 2, acc=4).matrix(shape)

self.assertTrue(np.max(np.abs((d1 - d2).toarray())) > 1)

x = np.linspace(0, 10, 11)
f = x**2
f = x ** 2
df = d2.dot(f)
np.testing.assert_almost_equal(2*np.ones_like(f), df)
np.testing.assert_almost_equal(2 * np.ones_like(f), df)

def test_accuracy_should_be_passed_down_to_stencil(self):
# issue 31

shape = 11, 11
dx = 1.
d1x = FinDiff(0, dx, 1, acc=4)
stencil1 = d1x.stencil(shape)

expected = {
('L', 'L'): {(0, 0): -2.083333333333331, (1, 0): 3.9999999999999916, (2, 0): -2.999999999999989,
(3, 0): 1.3333333333333268, (4, 0): -0.24999999999999858},
('L', 'C'): {(0, 0): -2.083333333333331, (1, 0): 3.9999999999999916, (2, 0): -2.999999999999989,
(3, 0): 1.3333333333333268, (4, 0): -0.24999999999999858},
('L', 'H'): {(0, 0): -2.083333333333331, (1, 0): 3.9999999999999916, (2, 0): -2.999999999999989,
(3, 0): 1.3333333333333268, (4, 0): -0.24999999999999858},
('C', 'L'): {(-2, 0): 0.08333333333333333, (-1, 0): -0.6666666666666666, (0, 0): 0.0,
(1, 0): 0.6666666666666666, (2, 0): -0.08333333333333333},
('C', 'C'): {(-2, 0): 0.08333333333333333, (-1, 0): -0.6666666666666666, (0, 0): 0.0,
(1, 0): 0.6666666666666666, (2, 0): -0.08333333333333333},
('C', 'H'): {(-2, 0): 0.08333333333333333, (-1, 0): -0.6666666666666666, (0, 0): 0.0,
(1, 0): 0.6666666666666666, (2, 0): -0.08333333333333333},
('H', 'L'): {(-4, 0): 0.24999999999999958, (-3, 0): -1.3333333333333313, (-2, 0): 2.9999999999999956,
(-1, 0): -3.999999999999996, (0, 0): 2.0833333333333317},
('H', 'C'): {(-4, 0): 0.24999999999999958, (-3, 0): -1.3333333333333313, (-2, 0): 2.9999999999999956,
(-1, 0): -3.999999999999996, (0, 0): 2.0833333333333317},
('H', 'H'): {(-4, 0): 0.24999999999999958, (-3, 0): -1.3333333333333313, (-2, 0): 2.9999999999999956,
(-1, 0): -3.999999999999996, (0, 0): 2.0833333333333317},
}

for char_pt in stencil1.data:
stl = stencil1.data[char_pt]
self.assertDictEqual(expected[char_pt], stl)



if __name__ == '__main__':
unittest.main()
if __name__ == '__main__':
unittest.main()

0 comments on commit 05cff82

Please sign in to comment.