Skip to content

Commit

Permalink
WIP: Offload AxesArray construction logic to _AxisMapping
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas committed Apr 30, 2023
1 parent 47e8793 commit 647e6ec
Show file tree
Hide file tree
Showing 2 changed files with 217 additions and 94 deletions.
251 changes: 169 additions & 82 deletions pysindy/utils/axes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import copy
import warnings
from typing import Collection
from typing import List
from typing import MutableMapping
from typing import Sequence
from typing import Union

import numpy as np
from sklearn.base import TransformerMixin
Expand All @@ -12,6 +13,87 @@
# Warning category emitted (via ``warnings.warn``) when the number of labeled
# axes does not match the number of dimensions of the array being labeled.
AxesWarning = type("AxesWarning", (SyntaxWarning,), {})


class _AxisMapping:
    """Convenience wrapper for a two-way map between axis names and
    indexes.

    Attributes:
        fwd_map: maps each axis name to the list of axis indexes carrying
            that name (sorted when the indexes were given as a sequence).
        reverse_map: maps each labeled axis index to its axis name.
    """

    def __init__(
        self,
        axes: Union[MutableMapping[str, Union[int, Sequence[int]]], None] = None,
        in_ndim: int = 0,
    ):
        """Build both lookup directions from an axes dictionary.

        Arguments:
            axes: axis name -> axis index, or sequence of axis indexes.
                The mapping is deep-copied, so the caller's object is never
                modified. ``None`` means "no labeled axes".
            in_ndim: number of dimensions of the array being labeled.

        Raises:
            ValueError: if one axis index is given more than one name, or
                an axis index is not smaller than ``in_ndim``.

        Warns:
            AxesWarning: if the number of labeled axes differs from
                ``in_ndim``.
        """
        if axes is None:
            axes = {}
        axes = copy.deepcopy(axes)
        self.fwd_map = {}
        self.reverse_map = {}

        def coerce_sequence(obj):
            # A bare index becomes a one-element list so that fwd_map values
            # are uniformly lists.
            if isinstance(obj, Sequence):
                return sorted(obj)
            return [obj]

        for ax_name, ax_ids in axes.items():
            ax_ids = coerce_sequence(ax_ids)
            self.fwd_map[ax_name] = ax_ids
            for ax_id in ax_ids:
                if ax_id in self.reverse_map:
                    raise ValueError(f"Assigned multiple definitions to axis {ax_id}")
                if ax_id >= in_ndim:
                    raise ValueError(
                        f"Assigned definition to axis {ax_id}, but array only has"
                        f" {in_ndim} axes"
                    )
                self.reverse_map[ax_id] = ax_name
        if len(self.reverse_map) != in_ndim:
            warnings.warn(
                f"{len(self.reverse_map)} axes labeled for array with {in_ndim} axes",
                AxesWarning,
            )

    @staticmethod
    def _compat_axes(in_dict: dict[str, Sequence]) -> dict[str, Union[Sequence, int]]:
        """Turn single-element axis index lists into ints"""
        return {
            ax_name: ax_ids[0] if len(ax_ids) == 1 else ax_ids
            for ax_name, ax_ids in in_dict.items()
        }

    @property
    def compat_axes(self):
        """The forward map, with single-element index lists unwrapped to ints."""
        return self._compat_axes(self.fwd_map)

    def reduce(self, axis: Union[int, None] = None):
        """Create an axes dict from self with specified axis
        removed and all greater axes decremented.

        The mapping itself is left untouched; a new dict is returned.

        Arguments:
            axis: the axis index to remove. By numpy ufunc convention,
                axis=None (default) removes _all_ axes.

        Returns:
            Axes dict in the same format as ``compat_axes``.

        Raises:
            KeyError: if ``axis`` is not a labeled axis index.
        """
        if axis is None:
            return {}
        removed_name = self.reverse_map[axis]
        new_axes = {}
        for ax_name, ax_ids in self.fwd_map.items():
            # Only indexes strictly above the removed axis shift down; lower
            # indexes sharing the same name must stay where they are.
            kept = [
                ax_id - 1 if ax_id > axis else ax_id
                for ax_id in ax_ids
                if ax_id != axis
            ]
            # Drop a name only when removing ``axis`` left it with no axes.
            if kept or ax_name != removed_name:
                new_axes[ax_name] = kept
        return self._compat_axes(new_axes)


class AxesArray(np.lib.mixins.NDArrayOperatorsMixin, np.ndarray):
"""A numpy-like array that keeps track of the meaning of its axes.
Expand All @@ -30,93 +112,85 @@ class AxesArray(np.lib.mixins.NDArrayOperatorsMixin, np.ndarray):

def __new__(cls, input_array, axes):
obj = np.asarray(input_array).view(cls)
defaults = {
"ax_time": None,
"ax_coord": None,
"ax_sample": None,
"ax_spatial": [],
}
n_axes = sum(1 for k, v in axes.items() if v)
if axes is None:
return obj
axes = {}
in_ndim = len(input_array.shape)
if n_axes != in_ndim:
warnings.warn(
f"{n_axes} axes labeled for array with {in_ndim} axes", AxesWarning
)
axes = {**defaults, **axes}
listed_axes = [
el for k, v in axes.items() if isinstance(v, Collection) for el in v
]
listed_axes += [
v
for k, v in axes.items()
if not isinstance(v, Collection) and v is not None
]
_reverse_map = {}
for axis in listed_axes:
if axis >= in_ndim:
raise ValueError(
f"Assigned definition to axis {axis}, but array only has"
f" {in_ndim} axes"
)
ax_names = [ax_name for ax_name in axes if axes[ax_name] == axis]
if len(ax_names) > 1:
raise ValueError(f"Assigned multiple definitions to axis {axis}")
_reverse_map[axis] = ax_names[0]
obj.__dict__.update({**axes})
obj.__dict__["_reverse_map"] = _reverse_map
obj.__ax_map = _AxisMapping(axes, in_ndim)
return obj

def __getitem__(self, key, /):
remove_axes = []
if isinstance(key, int):
remove_axes.append(key)
if isinstance(key, Sequence):
for axis, k in enumerate(key):
if isinstance(k, int):
remove_axes.append(axis)
new_item = super().__getitem__(key)
if not isinstance(new_item, AxesArray):
return new_item
for axis in remove_axes:
ax_name = self._reverse_map[axis]
if isinstance(new_item.__dict__[ax_name], int):
new_item.__dict__[ax_name] = None
else:
new_item.__dict__[ax_name].remove(axis)
new_item._reverse_map.pop(axis)
return new_item

def __array_finalize__(self, obj) -> None:
if obj is None:
return
self._reverse_map = copy.deepcopy(getattr(obj, "_reverse_map", {}))
self.ax_time = getattr(obj, "ax_time", None)
self.ax_coord = getattr(obj, "ax_coord", None)
self.ax_sample = getattr(obj, "ax_sample", None)
self.ax_spatial = getattr(obj, "ax_spatial", [])

@property
def n_spatial(self):
return tuple(self.shape[ax] for ax in self.ax_spatial)

@property
def n_time(self):
return self.shape[self.ax_time] if self.ax_time is not None else 1

@property
def n_sample(self):
return self.shape[self.ax_sample] if self.ax_sample is not None else 1
def axes(self):
return self.__ax_map.compat_axes

@property
def n_coord(self):
return self.shape[self.ax_coord] if self.ax_coord is not None else 1
def _reverse_map(self):
return self.__ax_map.reverse_map

@property
def shape(self):
return super().shape

def __getattr__(self, name):
    """Resolve dynamic ``ax_<name>`` and ``n_<name>`` attributes.

    ``ax_<name>`` returns the entry of ``self.axes`` stored under that key.
    ``n_<name>`` returns the length of the array along the axis (an int) or
    axes (a tuple of ints) labeled ``ax_<name>``.

    Raises:
        AttributeError: for any other attribute name, and for ``ax_``/``n_``
            names that do not correspond to a labeled axis. Raising
            AttributeError (rather than leaking KeyError/IndexError from the
            lookups) keeps ``hasattr`` and ``getattr(..., default)`` working.
    """
    missing = AttributeError(f"'{type(self)}' object has no attribute '{name}'")
    prefix, sep, suffix = name.partition("_")
    if prefix == "ax":
        try:
            return self.axes[name]
        except KeyError:
            raise missing from None
    if prefix == "n" and sep:
        try:
            ax_ids = self.__ax_map.fwd_map["ax_" + suffix]
        except KeyError:
            raise missing from None
        shape = tuple(self.shape[ax_id] for ax_id in ax_ids)
        if len(shape) == 1:
            return shape[0]
        return shape
    raise missing

# def __getitem__(self, key, /):
# pass
# return super().__getitem__(self, key)
# def __getitem__(self, key, /):
# remove_axes = []
# if isinstance(key, int):
# remove_axes.append(key)
# if isinstance(key, Sequence):
# for axis, k in enumerate(key):
# if isinstance(k, int):
# remove_axes.append(axis)
# new_item = super().__getitem__(key)
# if not isinstance(new_item, AxesArray):
# return new_item
# for axis in remove_axes:
# ax_name = self._reverse_map[axis]
# if isinstance(new_item.__dict__[ax_name], int):
# new_item.__dict__[ax_name] = None
# else:
# new_item.__dict__[ax_name].remove(axis)
# new_item._reverse_map.pop(axis)
# return new_item

def __array_wrap__(self, out_arr, context=None):
    """Delegate wrapping of a computed result to ``ndarray.__array_wrap__``.

    ``super().__array_wrap__`` is already bound to ``self``; passing ``self``
    again would shift ``out_arr`` into the ``context`` slot, so only the
    result array and the context are forwarded.
    """
    return super().__array_wrap__(out_arr, context)

def __array_finalize__(self, obj) -> None:
    """Attach an axis mapping to arrays created by view casting or templating.

    Arguments:
        obj: the object this array was created from, or None when the array
            is being constructed explicitly.
    """
    if obj is None:  # explicit construction via super().__new__().. not called?
        return
    # ``self.__ax_map`` is stored under its name-mangled form inside this
    # class, so that is the name to probe; hasattr(self, "__ax_map") looks
    # for a different attribute and can never find the mapping.
    has_map = hasattr(self, "_AxesArray__ax_map")
    # view from numpy array, called in constructor but also tests
    if not isinstance(obj, AxesArray) and self.shape == () and not has_map:
        self.__ax_map = _AxisMapping({})
    # required by ravel() and view() used in numpy testing. Also for zeros_like...
    elif isinstance(obj, AxesArray) and not has_map and self.shape == obj.shape:
        self.__ax_map = _AxisMapping(obj.axes, len(obj.shape))
    # maybe add errors for incompatible views?
    # NOTE(review): when neither branch applies, no mapping is set in this method.

def __array_ufunc__(
self, ufunc, method, *inputs, out=None, **kwargs
): # this method is called whenever you use a ufunc
Expand Down Expand Up @@ -145,17 +219,30 @@ def __array_ufunc__(
return
if ufunc.nout == 1:
results = (results,)
results = tuple(
(AxesArray(np.asarray(result), self.__dict__) if output is None else output)
for result, output in zip(results, outputs)
)
if method == "reduce" and (
"keepdims" not in kwargs.keys() or kwargs["keepdims"] is False
):
axes = None
if kwargs["axis"] is not None:
axes = self.__ax_map.reduce(axis=kwargs["axis"])
else:
axes = self.axes
final_results = []
for result, output in zip(results, outputs):
if output is not None:
final_results.append(output)
elif axes is None:
final_results.append(result)
else:
final_results.append(AxesArray(np.asarray(result), axes))
results = tuple(final_results)
return results[0] if len(results) == 1 else results

def __array_function__(self, func, types, args, kwargs):
if func not in HANDLED_FUNCTIONS:
arr = super(AxesArray, self).__array_function__(func, types, args, kwargs)
if isinstance(arr, np.ndarray):
return AxesArray(arr, axes=self.__dict__)
return AxesArray(arr, axes=self.axes)
elif arr is not None:
return arr
return
Expand All @@ -177,7 +264,7 @@ def decorator(func):
@implements(np.concatenate)
def concatenate(arrays, axis=0):
parents = [np.asarray(obj) for obj in arrays]
ax_list = [obj.__dict__ for obj in arrays if isinstance(obj, AxesArray)]
ax_list = [obj.axes for obj in arrays if isinstance(obj, AxesArray)]
for ax1, ax2 in zip(ax_list[:-1], ax_list[1:]):
if ax1 != ax2:
raise TypeError("Concatenating >1 AxesArray with incompatible axes")
Expand Down

0 comments on commit 647e6ec

Please sign in to comment.