Skip to content

Commit

Permalink
ENH NDVar.argmax()/.argmin(): axis parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
christianbrodbeck committed May 10, 2021
1 parent 1e698e7 commit eecaae1
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 12 deletions.
67 changes: 55 additions & 12 deletions eelbrain/_data_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -3636,30 +3636,70 @@ def any(self, axis: AxisArg = None, **regions) -> Union[NDVar, Var, bool]:
"""
return self._aggregate_over_dims(axis, regions, np.any)

def argmax(self):
"""Find the index of the largest value.
def argmax(
self,
axis: Union[str, int] = None,
name: str = None,
) -> Union[float, str, tuple, NDVar, Var]:
"""Find the index of the largest value
``ndvar[ndvar.argmax()]`` is equivalent to ``ndvar.max()``.
Parameters
----------
axis
Axis along which to find the maximum (by default find the maximum
in the whole :class:`NDVar`).
name
Name of the output :class:`NDVar` (default is the current name).
Returns
-------
argmax : index | tuple
argmax
Index appropriate for the NDVar's dimensions. If NDVar has more
than one dimensions, a tuple of indices.
"""
if axis is not None:
if isinstance(axis, str):
axis = self.get_axis(axis)
dim = self.dims[axis]
x = np.argmax(self.x, axis)
x = dim._dim_index(x)
dims = [dim_ for i, dim_ in enumerate(self.dims) if i != axis]
return self._package_aggregated_output(x, dims, name)
return self._dim_index_unravel(self.x.argmax())

def argmin(self):
"""Find the index of the smallest value.
def argmin(
self,
axis: Union[str, int] = None,
name: str = None,
) -> Union[float, str, tuple, NDVar, Var]:
"""Find the index of the smallest value
``ndvar[ndvar.argmin()]`` is equivalent to ``ndvar.min()``.
Parameters
----------
axis
Axis along which to find the minimum (by default find the minimum
in the whole :class:`NDVar`).
name
Name of the output :class:`NDVar` (default is the current name).
Returns
-------
argmin : index | tuple
argmin
Index appropriate for the NDVar's dimensions. If NDVar has more
than one dimensions, a tuple of indices.
"""
if axis is not None:
if isinstance(axis, str):
axis = self.get_axis(axis)
dim = self.dims[axis]
x = np.argmin(self.x, axis)
x = dim._dim_index(x)
dims = [dim_ for i, dim_ in enumerate(self.dims) if i != axis]
return self._package_aggregated_output(x, dims, name)
return self._dim_index_unravel(self.x.argmin())

def _array_index(self, arg):
Expand Down Expand Up @@ -4647,7 +4687,7 @@ def norm(self, dim, ord=2, name=None):
mask = all_masked
x = np.ma.masked_array(x, mask)
dims = self.dims[:axis] + self.dims[axis + 1:]
return self._package_aggregated_output(x, dims, name, self.info)
return self._package_aggregated_output(x, dims, name)

def ols(self, x, name=None):
"""Sample-wise ordinary least squares regressions
Expand Down Expand Up @@ -4744,14 +4784,17 @@ def ols_t(self, x, name=None):
"dependent variable (%i)" % (len(x), len(self)))

t = stats.lm_t(self.x, x._parametrize())[2][1:] # drop intercept
return NDVar(t, (Case, *self.dims[1:]), name or self.name, self.info)
if name is None:
name = self.name
return NDVar(t, (Case, *self.dims[1:]), name, self.info)

def _package_aggregated_output(self, x, dims, name, info=None):
args = op_name(self, info=info, name=name)
ndims = len(dims)
if ndims == 0:
if len(dims) == 0:
return x
elif ndims == 1 and isinstance(dims[0], Case):
if info is None:
info = self.info
args = op_name(self, info=info, name=name)
if len(dims) == 1 and isinstance(dims[0], Case):
return Var(x, *args)
else:
return NDVar(x, dims, *args)
Expand Down
6 changes: 6 additions & 0 deletions eelbrain/tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1311,6 +1311,12 @@ def test_ndvar_indexing():
assert x[('L10', 0.1)] == 20
assert x.sub(source='L10').argmax() == 0.1
assert x.sub(time=0.1).argmax() == 'L10'
# across axis
x9 = x[:9]
assert_array_equal(x9.argmax('time'), x9.x.argmax(1) * 0.01)
assert_array_equal(x9.argmin('time'), x9.x.argmin(1) * 0.01)
assert x9[0].argmax('time') == 0.04
assert x9[0].argmin('time') == 0.00

# broadcasting
u = ds[0, 'uts']
Expand Down

0 comments on commit eecaae1

Please sign in to comment.