Skip to content

Commit

Permalink
added tests for split method
Browse files Browse the repository at this point in the history
  • Loading branch information
maekke97 committed Apr 13, 2017
1 parent 37af089 commit a67d7aa
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 7 deletions.
8 changes: 3 additions & 5 deletions HierMat/hmat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""hmat.py: :class:`HMat`, :func:`build_hmatrix`, :func:`recursion_build_hmatrix`, :class:`StructureWarning`
"""
import numbers
import operator

import numpy

Expand Down Expand Up @@ -148,7 +147,8 @@ def __eq__(self, other):
length = len(self.blocks)
if len(other.blocks) != length:
return False
if self.blocks != other.blocks:
eqs = [a == b for a in self.blocks for b in other.blocks]
if sum(eqs) != length: # ignore order of blocks
return False
if not isinstance(self.content, type(other.content)):
return False
Expand Down Expand Up @@ -190,8 +190,6 @@ def __add__(self, other):
return self._add_rmat(other)
elif isinstance(other, numpy.matrix):
return self._add_matrix(other)
elif isinstance(other, numbers.Number):
return self._add_matrix(numpy.matrix(other))
else:
raise NotImplementedError('unsupported operand type(s) for +: {0} and {1}'.format(type(self), type(other)))

Expand Down Expand Up @@ -219,7 +217,7 @@ def _add_hmat(self, other):
return HMat(content=self.content + other.content, shape=self.shape, root_index=self.root_index)
# if we get here, both have children
if len(self.blocks) == len(other.blocks):
blocks = map(operator.add, self.blocks, other.blocks)
blocks = [self[index] + other[index] for index in self.block_structure()]
return HMat(blocks=blocks, shape=self.shape, root_index=self.root_index)
else:
raise ValueError('can not add {0} and {1}. number of blocks is different'.format(self, other))
Expand Down
31 changes: 29 additions & 2 deletions tests/test_hmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,18 @@ def test_add(self):
addend3 = HMat(content=numpy.matrix(numpy.ones((4, 2))), shape=(4, 2), root_index=(3, 0))
addend4 = HMat(content=numpy.matrix(numpy.ones((4, 4))), shape=(4, 4), root_index=(3, 2))
addend_hmat = HMat(blocks=[addend1, addend2, addend3, addend4], shape=(7, 6), root_index=(0, 0))
splitter_mat = HMat(content=numpy.matrix(numpy.zeros((7, 6))), shape=(7, 6), root_index=(0, 0))
self.assertEqual(addend_hmat + splitter_mat, addend_hmat)
self.assertEqual(splitter_mat + addend_hmat, addend_hmat)
res = addend_hmat + self.hmat
self.assertEqual(res, addend_hmat)
self.assertRaises(ValueError, addend1.__add__, addend2)
addend = HMat(content=numpy.matrix(numpy.ones((3, 2))), shape=(3, 2), root_index=(0, 0))
self.assertRaises(ValueError, addend.__add__, addend2)
addend = HMat(content=numpy.matrix(numpy.ones((7, 6))), shape=(7, 6), root_index=(0, 0))
self.assertRaises(ValueError, addend.__add__, addend_hmat)
self.assertRaises(NotImplementedError, addend_hmat.__add__, 'bla')
self.assertRaises(NotImplementedError, addend_hmat.__add__, numpy.ones((7, 6)))
addend_hmat = HMat(content=numpy.matrix(1), shape=(1, 1), root_index=(0, 0))
self.assertEqual(addend_hmat + 0, addend_hmat)
addend_hmat = HMat(blocks=[addend1, addend2, addend3], shape=(7, 6), root_index=(0, 0))
self.assertRaises(ValueError, self.hmat.__add__, addend_hmat)
check = HMat(content=numpy.matrix(2 * numpy.ones((3, 4))), shape=(3, 4), root_index=(0, 0))
Expand All @@ -146,6 +149,16 @@ def test_add(self):
check = 2 * addend_hmat
self.assertEqual(check, res)

def test_rmat(self):
addend1 = HMat(content=numpy.matrix(numpy.ones((3, 4))), shape=(3, 4), root_index=(0, 0))
addend2 = HMat(content=numpy.matrix(numpy.ones((3, 2))), shape=(3, 2), root_index=(0, 4))
addend3 = HMat(content=numpy.matrix(numpy.ones((4, 2))), shape=(4, 2), root_index=(3, 0))
addend4 = HMat(content=numpy.matrix(numpy.ones((4, 4))), shape=(4, 4), root_index=(3, 2))
addend_hmat = HMat(blocks=[addend1, addend2, addend3, addend4], shape=(7, 6), root_index=(0, 0))
mat = numpy.matrix(numpy.zeros((7, 6)))
res = addend_hmat.__radd__(mat)
self.assertEqual(addend_hmat, res)

def test_repr(self):
check = '<HMat with {content}>'.format(content=self.hmat_lvl2.blocks)
self.assertEqual(self.hmat_lvl2.__repr__(), check)
Expand Down Expand Up @@ -245,6 +258,12 @@ def test_mul_with_hmat(self):
check_rmat = RMat(numpy.matrix(3*numpy.ones((3, 1))), right_mat=numpy.matrix(numpy.ones((3, 1))))
check = HMat(content=check_rmat, shape=(3, 3), root_index=(0, 0))
self.assertEqual(hmat * hmat1, check)
blocks = [HMat(content=numpy.matrix(1), shape=(1, 1), root_index=(i, j)) for i in xrange(3) for j in xrange(3)]
block_mat = HMat(blocks=blocks, shape=(3, 3), root_index=(0, 0))
hmat = HMat(content=numpy.matrix(numpy.ones((3, 3))), shape=(3, 3), root_index=(0, 0))
hmat1 = HMat(content=rmat, shape=(3, 3), root_index=(0, 0))
self.assertRaises(NotImplementedError, hmat.__mul__, block_mat)
self.assertEqual(hmat1 * block_mat, 3*hmat)
res1 = HMat(content=numpy.matrix(6 * numpy.ones((3, 3))), shape=(3, 3), root_index=(0, 0))
res2 = HMat(content=numpy.matrix(6 * numpy.ones((3, 2))), shape=(3, 2), root_index=(0, 3))
res3 = HMat(content=numpy.matrix(6 * numpy.ones((2, 3))), shape=(2, 3), root_index=(3, 0))
Expand All @@ -262,6 +281,14 @@ def test_mul_with_hmat(self):

def test_split(self):
self.assertRaises(NotImplementedError, self.hmat.split, self.hmat.block_structure())
check = HMat(content='bla', shape=(2, 2), root_index=(0, 0))
self.assertRaises(NotImplementedError, check.split, {(0, 0): (1, 1)})
splitter = HMat(content=RMat(numpy.matrix(numpy.ones((2, 1))), numpy.matrix(numpy.ones((2, 1)))),
shape=(2, 2), root_index=(0, 0))
check_blocks = [HMat(content=RMat(numpy.matrix(numpy.ones((1, 1))), numpy.matrix(numpy.ones((1, 1)))),
shape=(1, 1), root_index=(i, j)) for i in xrange(2) for j in xrange(2)]
check = HMat(blocks=check_blocks, shape=(2, 2), root_index=(0, 0))
self.assertEqual(splitter.split(check.block_structure()), check)

def test_build_hmatrix(self):
full_func = lambda x: numpy.matrix(numpy.ones(x.shape()))
Expand Down

0 comments on commit a67d7aa

Please sign in to comment.