Skip to content

Commit

Permalink
[refactor]: move operatives to protools in core as they must prove th…
Browse files Browse the repository at this point in the history
…emselves useful before placement in a public module (like tools) of openseize, passing pylint, mypy, doctest, codespell
  • Loading branch information
mscaudill committed Aug 29, 2023
1 parent 212c85b commit e4b64c6
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 395 deletions.
127 changes: 105 additions & 22 deletions src/openseize/core/protools.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from typing import Optional, Tuple, Union
from functools import partial
from itertools import zip_longest

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -189,16 +190,17 @@ def multiply_along_axis(pro: Producer,
arr: npt.NDArray,
axis: int,
) -> Producer:
"""Multiplies each produced array of a producer by a 1-D array along a
single axis.
"""Multiplies produced arrays by a 1-D array along a single axis.
Args:
pro:
A producer of ndarrays to be multiplied along axis.
arr:
A 1-D array whose length must match producers shape along axis.
A 1-D array whose length must match producers length along a single
axis.
axis:
The axis along which to multiply.
The axis along which to multiply. This function supports
multiplication along any single axis including the production axis.
Examples:
>>> x = np.arange(10000).reshape(2, 4, 1250)
Expand All @@ -214,46 +216,127 @@ def multiply_along_axis(pro: Producer,
"""

arr = np.array(arr)

# FIXME I need to take care of when multiplication is along producing axis
if axis == pro.axis:
if len(arr) != pro.shape[pro.axis]:
msg = ('Multiplication along the production axis requires '
'length of arr to match chunksize {} != {}.')
raise ValueError(msg.format(len(arr), pro.chunksize))
if arr.ndim > 1:
raise ValueError('Dimensions of multiplier arr must be exactly 1.')

# ensure the arr shape matches the producers shape along axis
elif len(arr) != pro.shape[axis]:
if len(arr) != pro.shape[axis]:
msg = 'operands could not be broadcast together with shapes {} {}'
raise ValueError(msg.format(pro.shape, arr.shape))

# reshape the input array to be broadcastable with produced arrays
ndims = len(pro.shape)
shape = np.ones(ndims, dtype=int)
shape[axis] = len(arr)
y = arr.reshape(shape)
x = arr.reshape(shape) #type: Union[npt.NDArray, Producer]

# if multiplying along pro axis convert arr 'x' to producer
if axis == pro.axis:
x = producer(x, chunksize=pro.chunksize, axis=pro.axis)

func = partial(_multiply_gen, pro, y)
func = partial(_multiply_gen, pro, x)
return producer(func, chunksize=pro.chunksize, axis=pro.axis,
shape=pro.shape)


def _multiply_gen(pro, arr):
def _multiply_gen(pro, multiplier):
"""A generating helper function that multiplies produced arrays by an
ndarray.
ndarray or producer of ndarrays.
This helper function is a generating function (not a producer) and is not
intended to be called externally.
intended to be called externally. It assumes that multipliers shape is
broadcastable to producers shape.
Args:
pro:
A producer of ndarrays.
arr:
An ndarray of the same dims as each produced array.
multiplier:
An ndarray or a producer of ndarrays. The number of dims of this
object must match the dims of pro and have shape of 1 along all axes
except 1 axis whose length must equal the length of the producer
along this axis.
Yields:
The element-wise product of each produced array with arr.
The element-wise product of each produced array with multiplier.
"""

# non-production axis multiplication factors
factors = zip_longest(pro, multiplier, fillvalue=multiplier)

# production axis multiplication factors
if isinstance(multiplier, Producer):
factors = zip(pro, multiplier)

for arr, mult in factors:
yield arr * mult


def slice_along_axis(pro: Producer,
start: Optional[int] = None,
stop: Optional[int] = None,
step: Optional[int] = None,
axis: int = -1,
) -> Producer:
"""Returns a producer producing values between start and stop in step
increments along axis.
Args:
pro:
A producer instance to slice along axis.
start:
The start index of the slice along axis. If None, slice will start
at 0.
stop:
The stop index of the slice along axis. If None slice will extend to
last element(s) of producer along axis.
step:
The size of index steps between start and stop of slice.
axis:
The axis of the producer to be sliced.
Examples:
>>> x = np.random.random((4,10000))
>>> pro = producer(x, chunksize=1000, axis=-1)
>>> sliced_pro = slice_along_axis(pro, 100, 200)
>>> np.allclose(x[:,100:200], sliced_pro.to_array())
True
Returns:
A producer of ndarrays.
"""

for x in pro:
yield x * arr
# get start, stop, step indices for the slicing axis
start, stop, step = slice(start, stop, step).indices(pro.shape[axis])

if axis == pro.axis:
# slicing along production axis is just masking
mask = np.zeros(pro.shape[axis], dtype=bool)
mask[start:stop:step] = True
return producer(pro, pro.chunksize, pro.axis, mask=mask)

# slicing along non-production axis changes shape of produced arrays
new_shape = list(pro.shape)
new_shape[axis] = (stop - start) // step
func = partial(_slice_along_gen, pro, start, stop, step, axis)
return producer(func, pro.chunksize, pro.axis, shape=new_shape)


def _slice_along_gen(pro, start, stop, step, axis):
"""A generating helper function for slicing a producer along
a non-production axis between start and stop in step increments.
Args:
pro:
A producer instance to slice.
start:
The start index of the slice. May be None.
stop:
The stop index of the slice. May be None.
step:
The step size between start and stop to slice with. May be None.
axis:
The non-production axis along which to slice.
"""

for arr in pro:
yield arraytools.slice_along_axis(arr, start, stop, step, axis=axis)
Loading

0 comments on commit e4b64c6

Please sign in to comment.