Permalink
Browse files

reordering.py, test_ch.py: fixed stacking bug and added unit test for it

  • Loading branch information...
1 parent a5fe52f commit ada4b2ed46bdc91870cedda757c0761f6b513ed3 @mattloper committed Jan 4, 2015
Showing with 11 additions and 2 deletions.
  1. +2 −2 reordering.py
  2. +9 −0 test_ch.py
View
@@ -379,7 +379,7 @@ def compute_r(self):
def everything(self):
if not hasattr(self, '_everything'):
self._everything = np.arange(self.r.size).reshape(self.r.shape)
- self._everything = np.rollaxis(self._everything, self.axis, 0)
+ self._everything = np.swapaxes(self._everything, self.axis, 0)
return self._everything
def compute_dr_wrt(self, wrt):
@@ -402,7 +402,7 @@ def compute_dr_wrt(self, wrt):
if term is wrt:
JS += [_JS]
data += [_data]
- IS += [self.everything[offset:offset+tsz].ravel()]
+ IS += [np.swapaxes(self.everything[offset:offset+tsz], self.axis, 0).ravel()]
offset += tsz
IS = np.concatenate(IS).ravel()
JS = np.concatenate(JS).ravel()
View
@@ -215,6 +215,15 @@ def test_stacking(self):
self.assertFalse(np.any(residuals1))
self.assertFalse(np.any(residuals2))
+
+ d0 = ch.array(np.arange(60).reshape((10,6)))
+ d1 = ch.vstack((d0[:4], d0[4:]))
+ d2 = ch.hstack((d1[:,:3], d1[:,3:]))
+ tmp = d2.dr_wrt(d0).todense()
+ diff = tmp - np.eye(tmp.shape[0])
+ self.assertFalse(np.any(diff.ravel()))
+
+
#def test_drs(self):
# a = ch.Ch(2)

0 comments on commit ada4b2e

Please sign in to comment.