Skip to content

Commit

Permalink
WIP but not at a stable point
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas committed Jun 5, 2023
1 parent 0acda3b commit a54e684
Showing 1 changed file with 103 additions and 23 deletions.
126 changes: 103 additions & 23 deletions pysindy/utils/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List
from typing import MutableMapping
from typing import NewType
from typing import Optional
from typing import Sequence
from typing import Union

Expand All @@ -18,6 +19,11 @@
OldIndex = NewType("OldIndex", int)
KeyIndex = NewType("KeyIndex", int)
NewIndex = NewType("NewIndex", int)
# ListOrItem = list[T] | T
PartialReIndexer = tuple[KeyIndex, Optional[OldIndex], str]
CompleteReIndexer = tuple[
list[KeyIndex], Optional[list[OldIndex]], Optional[list[NewIndex]]
]


class _AxisMapping:
Expand Down Expand Up @@ -180,12 +186,48 @@ def __getattr__(self, name):
return shape
raise AttributeError(f"'{type(self)}' object has no attribute '{name}'")

def __getitem__(self, key, /):
def __getitem__(self, key: Indexer | Sequence[Indexer], /):
output = super().__getitem__(key)
if not isinstance(output, AxesArray):
return output
in_dim = self.shape
key, adv_ids = _standardize_indexer(self, key)
key, adv_inds = _standardize_indexer(self, key)
if adv_inds:
adjacent, bcast_nd, bcast_start_axis = _determine_adv_broadcasting(adv_inds)
else:
adjacent, bcast_nd, bcast_start_axis = True, 0, 0
old_index = OldIndex(0)
pindexers: list[PartialReIndexer | list[PartialReIndexer]] = []
for key_ind, indexer in enumerate(key):
if isinstance(indexer, int | slice | np.ndarray):
pindexers.append((key_ind, old_index, indexer))
old_index += 1
elif indexer is None:
pindexers.append((key_ind, [None], None))
else:
raise TypeError(
f"AxesArray indexer of type {type(indexer)} not understood"
)
if not adjacent:
_move_idxs_to_front(key, adv_inds)
adv_inds = range(len(adv_inds))
pindexers = _squeeze_to_sublist(pindexers, adv_inds)
cindexers: list[CompleteReIndexer] = []
curr_axis = 0
for pindexer in enumerate(pindexers):
if isinstance(pindexer, list): # advanced indexing bundle
bcast_idxers = _adv_broadcast_magic(key, adv_inds, pindexer)
cindexers += bcast_idxers
curr_axis += bcast_nd
elif pindexer[-1] is None:
cindexers.append((*pindexer[:-1], curr_axis))
curr_axis += 1
elif isinstance(pindexer[-1], int):
cindexers.append((*pindexer[:-1], None))
elif isinstance(pindexer[-1], slice):
cindexers.append((*pindexer[:-1], curr_axis))
curr_axis += 1

remove_axes = []
new_axes = []
leftshift = 0
Expand All @@ -197,27 +239,19 @@ def __getitem__(self, key, /):
elif isinstance(indexer, int):
remove_axes.append(key_ind - rightshift)
leftshift += 1
if adv_ids:
adv_ids = sorted(adv_ids)
if adv_inds:
adv_inds = sorted(adv_inds)
source_axis = [ # after basic indexing applied # noqa
len([id for id in range(idx_id) if key[id] is not None])
for idx_id in adv_ids
for idx_id in adv_inds
]
adv_indexers = [np.array(key[i]) for i in adv_ids] # noqa
adv_indexers = [np.array(key[i]) for i in adv_inds] # noqa
bcast_nd = np.broadcast(*adv_indexers).nd
adjacent = all(i + 1 == j for i, j in zip(adv_ids[:-1], adv_ids[1:]))
bcast_start_axis = 0 if not adjacent else min(adv_ids)
adjacent = all(i + 1 == j for i, j in zip(adv_inds[:-1], adv_inds[1:]))
bcast_start_axis = 0 if not adjacent else min(adv_inds)
adv_map = {}

def _compare_bcast_shapes(result_ndim, base_shape):
"""Identify which broadcast shape axes are due to base_shape"""
return [
result_ndim - 1 - ax_id
for ax_id, length in enumerate(reversed(base_shape))
if length > 1
]

for idx_id, idxer in zip(adv_ids, adv_indexers):
for idx_id, idxer in zip(adv_inds, adv_indexers):
base_idxer_ax_name = self._reverse_map[ # count non-None keys
len([id for id in range(idx_id) if key[id] is not None])
]
Expand Down Expand Up @@ -253,7 +287,7 @@ def _compare_bcast_shapes(result_ndim, base_shape):
# otherwise moved to beginning
remove_axes.append(adv_map.keys()) # Error: remove_axis takes ints

out_obj = np.broadcast(np.array(key[i]) for i in adv_ids) # noqa
out_obj = np.broadcast(np.array(key[i]) for i in adv_inds) # noqa
pass
# mulligan structured arrays, etc.
new_map = _AxisMapping(
Expand Down Expand Up @@ -381,19 +415,19 @@ def _standardize_indexer(
Returns:
A tuple of the normalized indexer as well as the indexes of
fancy indexers
advanced indexers
"""
if not isinstance(key, tuple):
key = (key,)
if not any(ax_key is Ellipsis for ax_key in key):
key = (*key, Ellipsis)
new_key = []
fancy_inds = []
adv_inds = []
slicedim = 0
for indexer_ind, ax_key in enumerate(key):
if not isinstance(ax_key, BasicIndexer):
ax_key = np.array(ax_key)
fancy_inds.append(indexer_ind)
adv_inds.append(indexer_ind)
new_key.append(ax_key)
if isinstance(ax_key, slice | int | np.ndarray):
slicedim += 1
Expand All @@ -403,8 +437,54 @@ def _standardize_indexer(
if isinstance(v, type(Ellipsis)):
ellind = i
new_key[ellind : ellind + 1] = ellipsis_dims * (slice(None),)
fancy_inds = [ind if ind < ellind else ind + ellind for ind in fancy_inds]
return tuple(new_key), tuple(fancy_inds)
adv_inds = [ind if ind < ellind else ind + ellind for ind in adv_inds]
return tuple(new_key), tuple(adv_inds)


def _adv_broadcast_magic(*args):
raise NotImplementedError


def _compare_bcast_shapes(result_ndim: int, base_shape: tuple[int]) -> list[int]:
"""Identify which broadcast shape axes are due to base_shape
Args:
result_ndim: number of dimensions broadcast shape has
base_shape: shape of one element of broadcasting
Result:
tuple of axes in broadcast result that come from base shape
"""
return [
result_ndim - 1 - ax_id
for ax_id, length in enumerate(reversed(base_shape))
if length > 1
]


def _move_idxs_to_front(li: list, idxs: Sequence) -> None:
"""Move all items at indexes specified to the front of a list"""
front = []
for idx in reversed(idxs):
obj = li.pop(idx)
front.insert(0, obj)
li = front + li


def _determine_adv_broadcasting(
key: Indexer | Sequence[Indexer], adv_inds: Sequence[OldIndex]
) -> tuple:
"""Calculate the shape and location for the result of advanced indexing"""
adjacent = all(i + 1 == j for i, j in zip(adv_inds[:-1], adv_inds[1:]))
adv_indexers = [np.array(key[i]) for i in adv_inds]
bcast_nd = np.broadcast(*adv_indexers).nd
bcast_start_axis = 0 if not adjacent else min(adv_inds)
return adjacent, bcast_nd, bcast_start_axis


def _squeeze_to_sublist(li: list, idxs: Sequence) -> list:
"Turn contiguous elements of a list into a sub-list in the same position"
return li[: min(idxs)] + [li[idx] for idx in idxs] + li[max(idxs) :]


def comprehend_axes(x):
Expand Down

0 comments on commit a54e684

Please sign in to comment.