Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Fetching contributors…

Cannot retrieve contributors at this time

116 lines (102 sloc) 3.98 kb
"""
Utilities that manipulate strides to achieve desirable effects.
An explanation of strides can be found in the "ndarray.rst" file in the
NumPy reference guide.
"""
import numpy as np
# Names exported by ``from <module> import *``: only ``broadcast_arrays``.
# The helpers defined below (``DummyArray``, ``as_strided``) are not listed,
# but remain importable by explicit name.
__all__ = ['broadcast_arrays']
class DummyArray(object):
    """Minimal carrier for an ``__array_interface__`` dictionary.

    Instances have no behavior of their own: they exist so that an
    interface dict can be handed to ``np.asarray``, and so that an optional
    ``base`` object (typically the array that owns the memory) stays
    referenced for as long as the carrier does.

    Parameters
    ----------
    interface : dict
        Stored unchanged as the ``__array_interface__`` attribute.
    base : object, optional
        Stored unchanged as the ``base`` attribute (``None`` if omitted).
    """

    def __init__(self, interface, base=None):
        # The two assignments are independent; order is irrelevant.
        self.base = base
        self.__array_interface__ = interface
def as_strided(x, shape=None, strides=None):
    """ Make an ndarray from the given array with the given shape and strides.

    No data is copied and no bounds checking is done: the result reuses the
    memory described by the array interface of `x`, so the caller is
    responsible for choosing a `shape`/`strides` pair that stays inside it.

    Parameters
    ----------
    x : array_like
        Source data. Inputs that are not already ndarrays are converted
        with `np.asarray` first (a no-op for ndarrays).
    shape : sequence of int, optional
        Shape of the result. If omitted, the shape from
        ``x.__array_interface__`` is kept.
    strides : sequence of int, optional
        Byte strides of the result. If omitted, the strides from
        ``x.__array_interface__`` are kept (``None`` there means
        C-contiguous for the resulting shape).

    Returns
    -------
    view : ndarray
        A base-class ndarray with the same dtype as `x`, sharing its memory.
    """
    # Accept any array_like: plain sequences have no __array_interface__.
    x = np.asarray(x)
    interface = dict(x.__array_interface__)
    if shape is not None:
        interface['shape'] = tuple(shape)
    if strides is not None:
        interface['strides'] = tuple(strides)
    # DummyArray keeps a reference to `x` so its memory outlives the view.
    array = np.asarray(DummyArray(interface, base=x))
    # The __array_interface__ round trip does not preserve every dtype
    # (e.g. some structured dtypes come back as plain void); restore the
    # original one. Item sizes match, so this is a pure reinterpretation.
    if array.dtype != x.dtype:
        array = array.view(x.dtype)
    return array
def broadcast_arrays(*args):
    """
    Broadcast any number of arrays against each other.

    Parameters
    ----------
    `*args` : array_likes
        The arrays to broadcast.

    Returns
    -------
    broadcasted : list of arrays
        These arrays are views on the original arrays. They are typically
        not contiguous. Furthermore, more than one element of a
        broadcasted array may refer to a single memory location. If you
        need to write to the arrays, make copies first.

        If every input already has the same shape, nothing is broadcast
        and the inputs (after `np.asarray` conversion) are returned as-is.
        An empty list is returned when no arrays are given.

    Raises
    ------
    ValueError
        If two or more arrays have incompatible (non-1, unequal) lengths
        on some axis.

    Examples
    --------
    >>> x = np.array([[1,2,3]])
    >>> y = np.array([[1],[2],[3]])
    >>> np.broadcast_arrays(x, y)
    [array([[1, 2, 3],
           [1, 2, 3],
           [1, 2, 3]]), array([[1, 1, 1],
           [2, 2, 2],
           [3, 3, 3]])]

    Here is a useful idiom for getting contiguous copies instead of
    non-contiguous views.

    >>> [np.array(a) for a in np.broadcast_arrays(x, y)]
    [array([[1, 2, 3],
           [1, 2, 3],
           [1, 2, 3]]), array([[1, 1, 1],
           [2, 2, 2],
           [3, 3, 3]])]
    """
    # Materialize a list: the arrays are traversed several times below and
    # may be returned directly, so a lazy ``map`` iterator would be consumed
    # by the first pass (Python 3).
    args = [np.asarray(a) for a in args]
    if not args:
        return []
    if len(set(a.shape for a in args)) == 1:
        # Common case where nothing needs to be broadcasted.
        return args

    biggest = max(a.ndim for a in args)
    # Prepend dimensions of length 1 (with stride 0) to each array so that
    # all of them have the same number of dimensions.
    shapes = [[1] * (biggest - a.ndim) + list(a.shape) for a in args]
    strides = [[0] * (biggest - a.ndim) + list(a.strides) for a in args]

    # Check each dimension for compatibility. A dimension length of 1 is
    # accepted as compatible with any other length.
    for axis in range(biggest):
        lengths = set(shape[axis] for shape in shapes)
        lengths.discard(1)
        if len(lengths) > 1:
            # At least two different non-1 lengths on this axis.
            raise ValueError("shape mismatch: two or more arrays have "
                "incompatible dimensions on axis %r." % (axis,))
        if not lengths:
            # Every array has length 1 here; nothing is broadcasted.
            continue
        (new_length,) = lengths
        # Arrays broadcast from length 1 get stride 0 so they repeat data.
        for shape, stride in zip(shapes, strides):
            if shape[axis] == 1:
                shape[axis] = new_length
                stride[axis] = 0

    # Construct the new arrays.
    return [as_strided(a, shape=sh, strides=st)
            for a, sh, st in zip(args, shapes, strides)]
Jump to Line
Something went wrong with that request. Please try again.