Skip to content

Commit

Permalink
clean: downgrade typing syntax and stdlibrary use to python 3.8
Browse files Browse the repository at this point in the history
Also gitignore env8 directory, where I keep my python3.8 environment
  • Loading branch information
Jacob-Stevens-Haas committed Jan 16, 2024
1 parent 323d115 commit cef7167
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 25 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ venv/
ENV/
env.bak/
venv.bak/
env8

# automatically generated by setuptools-scm
pysindy/version.py
Expand Down
3 changes: 2 additions & 1 deletion pysindy/differentiation/finite_difference.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import List
from typing import Union

import numpy as np
Expand Down Expand Up @@ -232,7 +233,7 @@ def _accumulate(self, coeffs, x):
)

def _differentiate(
self, x: NDArray, t: Union[NDArray, float, list[float]]
self, x: NDArray, t: Union[NDArray, float, List[float]]
) -> NDArray:
"""
Apply finite difference method.
Expand Down
51 changes: 27 additions & 24 deletions pysindy/utils/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from enum import Enum
from typing import Collection
from typing import Dict
from typing import get_args
from typing import List
from typing import Literal
from typing import NewType
Expand All @@ -36,9 +37,9 @@
HANDLED_FUNCTIONS = {}

AxesWarning = type("AxesWarning", (SyntaxWarning,), {})
BasicIndexer = Union[slice, int, type(Ellipsis), type(None), str]
Indexer = BasicIndexer | NDArray | list
StandardIndexer = Union[slice, int, type(None), NDArray[np.dtype(int)]]
BasicIndexer = Union[slice, int, type(Ellipsis), None, str]
Indexer = Union[BasicIndexer, NDArray, List]
StandardIndexer = Union[slice, int, None, NDArray[np.dtype(int)]]
OldIndex = NewType("OldIndex", int) # Before moving advanced axes adajent
KeyIndex = NewType("KeyIndex", int)
NewIndex = NewType("NewIndex", int)
Expand All @@ -52,8 +53,8 @@ class Sentinels(Enum):
class _AxisMapping:
"""Convenience wrapper for a two-way map between axis names and indexes."""

fwd_map: dict[str, list[int]]
reverse_map: dict[int, str]
fwd_map: Dict[str, List[int]]
reverse_map: Dict[int, str]

def __init__(
self,
Expand Down Expand Up @@ -90,14 +91,14 @@ def coerce_sequence(obj):
)

@staticmethod
def fwd_from_names(names: List[str]) -> dict[str, Sequence[int]]:
fwd_map: dict[str, Sequence[int]] = {}
def fwd_from_names(names: List[str]) -> Dict[str, Sequence[int]]:
fwd_map: Dict[str, Sequence[int]] = {}
for ax_ind, name in enumerate(names):
_compat_dict_append(fwd_map, name, [ax_ind])
return fwd_map

@staticmethod
def _compat_axes(in_dict: dict[str, list[int]]) -> dict[str, Union[list[int], int]]:
def _compat_axes(in_dict: Dict[str, List[int]]) -> Dict[str, Union[list[int], int]]:
"""Like fwd_map, but unpack single-element axis lists"""
axes = {}
for k, v in in_dict.items():
Expand Down Expand Up @@ -269,7 +270,7 @@ def __getattr__(self, name):
return shape
raise AttributeError(f"'{type(self)}' object has no attribute '{name}'")

def __getitem__(self, key: Indexer | Sequence[Indexer], /):
def __getitem__(self, key: Union[Indexer, Sequence[Indexer]], /):
if isinstance(key, tuple):
base_indexer = tuple(None if isinstance(k, str) else k for k in key)
else:
Expand Down Expand Up @@ -542,12 +543,14 @@ def einsum(

def _join_unique_names(l_of_s: List[str]) -> str:
ordered_uniques = dict.fromkeys(l_of_s).keys()
return "_".join(ax_name.removeprefix("ax_") for ax_name in ordered_uniques)
return "_".join(
ax_name[3:] if ax_name[:3] == "ax_" else ax_name for ax_name in ordered_uniques
)


def _label_einsum_scripts(
lscripts: list[str], operands: tuple[AxesArray]
) -> list[dict[str, str]]:
lscripts: List[str], operands: tuple[AxesArray]
) -> List[dict[str, str]]:
"""Create a list of what axis name each script refers to in its operand."""
allscript_names: List[Dict[str, List[str]]] = []
for lscr, op in zip(lscripts.split(","), operands):
Expand Down Expand Up @@ -644,9 +647,9 @@ def standardize_indexer(
if not any(ax_key is Ellipsis for ax_key in key):
key = [*key, Ellipsis]

new_key: list[Indexer] = []
new_key: List[Indexer] = []
for ax_key in key:
if not isinstance(ax_key, BasicIndexer):
if not isinstance(ax_key, get_args(BasicIndexer)):
ax_key = np.array(ax_key)
if ax_key.dtype == np.dtype(np.bool_):
new_key += ax_key.nonzero()
Expand All @@ -655,15 +658,15 @@ def standardize_indexer(

new_key = _expand_indexer_ellipsis(new_key, arr.ndim)
# Can't identify position of advanced indexers before expanding ellipses
adv_inds: list[KeyIndex] = []
adv_inds: List[KeyIndex] = []
for key_ind, ax_key in enumerate(new_key):
if isinstance(ax_key, np.ndarray):
adv_inds.append(KeyIndex(key_ind))

return new_key, tuple(adv_inds)


def _expand_indexer_ellipsis(key: list[Indexer], ndim: int) -> list[Indexer]:
def _expand_indexer_ellipsis(key: List[Indexer], ndim: int) -> List[Indexer]:
"""Replace ellipsis in indexers with the appropriate amount of slice(None)"""
# [...].index errors if list contains numpy array
ellind = [ind for ind, val in enumerate(key) if val is ...][0]
Expand All @@ -686,9 +689,9 @@ def _determine_adv_broadcasting(


def _rename_broadcast_axes(
new_axes: list[tuple[int, None | str | Literal[Sentinels.ADV_NAME]]],
adv_names: list[str],
) -> list[tuple[int, str]]:
new_axes: List[tuple[int, None | str | Literal[Sentinels.ADV_NAME]]],
adv_names: List[str],
) -> List[tuple[int, str]]:
"""Normalize sentinel and NoneType names"""

def _calc_bcast_name(*names: str) -> str:
Expand All @@ -713,7 +716,7 @@ def _calc_bcast_name(*names: str) -> str:

def replace_adv_indexers(
key: Sequence[StandardIndexer],
adv_inds: list[int],
adv_inds: List[int],
bcast_start_ax: int,
bcast_nd: int,
) -> tuple[
Expand All @@ -727,9 +730,9 @@ def replace_adv_indexers(


def _apply_indexing(
key: tuple[StandardIndexer], reverse_map: dict[int, str]
key: tuple[StandardIndexer], reverse_map: Dict[int, str]
) -> tuple[
list[int], list[tuple[int, None | str | Literal[Sentinels.ADV_NAME]]], list[str]
List[int], List[tuple[int, None | str | Literal[Sentinels.ADV_NAME]]], List[str]
]:
"""Determine where axes should be removed and added
Expand Down Expand Up @@ -810,8 +813,8 @@ def wrap_axes(axes: dict, obj):


T = TypeVar("T", bound=int) # TODO: Bind to a non-sequence after type-negation PEP
ItemOrList = Union[T, list[T]]
CompatDict = dict[str, ItemOrList[T]]
ItemOrList = Union[T, List[T]]
CompatDict = Dict[str, ItemOrList[T]]


def _compat_dict_append(
Expand Down

0 comments on commit cef7167

Please sign in to comment.