Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

Already on GitHub? Sign in to your account

Rolling window #31

Closed
wants to merge 1 commit into
from
Jump to file or symbol
Failed to load files and symbols.
+38 −1
Split
View
@@ -7,7 +7,7 @@
"""
import numpy as np
-__all__ = ['broadcast_arrays']
+__all__ = ['broadcast_arrays', 'rolling_window']
class DummyArray(object):
""" Dummy object that just exists to hang __array_interface__ dictionaries
@@ -113,3 +113,40 @@ def broadcast_arrays(*args):
broadcasted = [as_strided(x, shape=sh, strides=st) for (x,sh,st) in
zip(args, shapes, strides)]
return broadcasted
+
+def rolling_window(a, window):
+ """
+ Make an ndarray with a rolling window of the last dimension
+
+ Parameters
+ ----------
+ a : array_like
+ Array to add rolling window to
+ window : int
+ Size of rolling window
+
+ Returns
+ -------
+ Array that is a view of the original array with a added dimension
+ of size w.
+
+ Examples
+ --------
+ >>> x=np.arange(10).reshape((2,5))
+ >>> np.rolling_window(x, 3)
+ array([[[0, 1, 2], [1, 2, 3], [2, 3, 4]],
+ [[5, 6, 7], [6, 7, 8], [7, 8, 9]]])
+
+ Calculate rolling mean of last dimension:
+ >>> np.mean(np.rolling_window(x, 3), -1)
+ array([[ 1., 2., 3.],
+ [ 6., 7., 8.]])
+
+ """
+ if window < 1:
+ raise ValueError, "`window` must be at least 1."
+ if window > a.shape[-1]:
+ raise ValueError, "`window` is too long."
+ shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
+ strides = a.strides + (a.strides[-1],)
+ return as_strided(a, shape=shape, strides=strides)