Skip to content

Commit

Permalink
Fix kron for multiple-dimensions. kron is defined so tile(b, s) is th…
Browse files Browse the repository at this point in the history
…e same as kron(ones(s,b.dtype), b)
  • Loading branch information
teoliphant committed Oct 9, 2006
1 parent b7f719a commit 9c9f739
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 10 deletions.
23 changes: 16 additions & 7 deletions numpy/lib/shape_base.py
@@ -1,7 +1,7 @@
__all__ = ['atleast_1d','atleast_2d','atleast_3d','vstack','hstack',
'column_stack','row_stack', 'dstack','array_split','split','hsplit',
'vsplit','dsplit','apply_over_axes','expand_dims',
'apply_along_axis', 'kron', 'tile']
'apply_along_axis', 'kron', 'tile', 'get_array_wrap']

import numpy.core.numeric as _nx
from numpy.core.numeric import asarray, zeros, newaxis, outer, \
Expand Down Expand Up @@ -526,7 +526,7 @@ def dsplit(ary,indices_or_sections):
raise ValueError, 'vsplit only works on arrays of 3 or more dimensions'
return split(ary,indices_or_sections,2)

def _getwrapper(*args):
def get_array_wrap(*args):
"""Find the wrapper for the array with the highest priority.
In case of ties, leftmost wins. If no wrapper is found, return None
Expand All @@ -547,19 +547,28 @@ def kron(a,b):
[ ... ... ],
[ a[m-1,0]*b, a[m-1,1]*b, ... , a[m-1,n-1]*b ]]
"""
wrapper = _getwrapper(a, b)
wrapper = get_array_wrap(a, b)
b = asanyarray(b)
a = array(a,copy=False,subok=True,ndmin=b.ndim)
ndb, nda = b.ndim, a.ndim
if (nda == 0 or ndb == 0):
return a * b
as = a.shape
bs = b.shape
if not a.flags.contiguous:
a = reshape(a, as)
if not b.flags.contiguous:
b = reshape(b, bs)
o = outer(a,b)
result = o.reshape(as + bs)
axis = a.ndim-1
for k in xrange(b.ndim):
nd = ndb
if (ndb != nda):
if (ndb > nda):
as = (1,)*(ndb-nda) + as
else:
bs = (1,)*(nda-ndb) + bs
nd = nda
result = outer(a,b).reshape(as+bs)
axis = nd-1
for k in xrange(nd):
result = concatenate(result, axis=axis)
if wrapper is not None:
result = wrapper(result)
Expand Down
17 changes: 14 additions & 3 deletions numpy/lib/tests/test_shape_base.py
Expand Up @@ -11,8 +11,6 @@ def check_simple(self):
a = ones((20,10),'d')
assert_array_equal(apply_along_axis(len,0,a),len(a)*ones(shape(a)[1]))
def check_simple101(self,level=11):
# This test causes segmentation fault (Numeric 23.3,23.6,Python 2.3.4)
# when enabled and shape(a)[1]>100. See Issue 202.
a = ones((10,101),'d')
assert_array_equal(apply_along_axis(len,0,a),len(a)*ones(shape(a)[1]))

Expand Down Expand Up @@ -370,6 +368,7 @@ class myarray(ndarray):
assert_equal(type(kron(a,ma)), ndarray)
assert_equal(type(kron(ma,a)), myarray)


class test_tile(NumpyTestCase):
def check_basic(self):
a = array([0,1,2])
Expand All @@ -380,7 +379,19 @@ def check_basic(self):
assert_equal(tile(b, 2), [[1,2,1,2],[3,4,3,4]])
assert_equal(tile(b,(2,1)),[[1,2],[3,4],[1,2],[3,4]])
assert_equal(tile(b,(2,2)),[[1,2,1,2],[3,4,3,4],[1,2,1,2],[3,4,3,4]])


def check_kroncompare(self):
import numpy.random as nr
reps=[(2,),(1,2),(2,1),(2,2),(2,3,2),(3,2)]
shape=[(3,),(2,3),(3,4,3),(3,2,3),(4,3,2,4),(2,2)]
for s in shape:
b = nr.randint(0,10,size=s)
for r in reps:
a = ones(r, b.dtype)
large = tile(b, r)
klarge = kron(a, b)
assert_equal(large, klarge)

# Utility

def compare_results(res,desired):
Expand Down

0 comments on commit 9c9f739

Please sign in to comment.