Skip to content

Commit

Permalink
More typing fixes (#374)
Browse files Browse the repository at this point in the history
  • Loading branch information
facelessuser committed Nov 26, 2023
1 parent 4d2d193 commit 85fdc1d
Show file tree
Hide file tree
Showing 16 changed files with 119 additions and 71 deletions.
64 changes: 42 additions & 22 deletions coloraide/algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
import itertools as it
from .deprecate import deprecated
from .types import (
ArrayLike, MatrixLike, VectorLike, TensorLike, Array, Matrix, Tensor, Vector,
Shape, ShapeLike, DimHints, SupportsFloatOrInt, MathType
ArrayLike, MatrixLike, VectorLike, TensorLike, Array, Matrix, Tensor, Vector, VectorBool, MatrixBool, TensorBool,
MatrixInt, MathType, Shape, ShapeLike, DimHints, SupportsFloatOrInt
)
from typing import Callable, Sequence, Iterator, Any, Iterable, overload

Expand Down Expand Up @@ -661,7 +661,7 @@ def interpolate(points: list[Vector], method: str = 'linear') -> Interpolate:
################################
# Matrix/linear algebra math
################################
def pretty(value: Array | float, *, _depth: int = 0, _shape: Shape | None = None) -> str:
def pretty(value: float | ArrayLike, *, _depth: int = 0, _shape: Shape | None = None) -> str:
"""Format the print output."""

if _shape is None:
Expand All @@ -677,7 +677,7 @@ def pretty(value: Array | float, *, _depth: int = 0, _shape: Shape | None = None
return str(value)


def pprint(value: Array | float) -> None:
def pprint(value: float | ArrayLike) -> None:
"""Print the matrix or value."""

print(pretty(value))
Expand All @@ -686,13 +686,13 @@ def pprint(value: Array | float) -> None:
def all(a: float | ArrayLike) -> bool: # noqa: A001
"""Return true if all elements are "true"."""

return _all(flatiter(a))
return _all(flatiter(a)) # type: ignore[arg-type]


def any(a: float | ArrayLike) -> bool: # noqa: A001
"""Return true if all elements are "true"."""

return _any(flatiter(a))
return _any(flatiter(a)) # type: ignore[arg-type]


def vdot(a: VectorLike, b: VectorLike) -> float:
Expand Down Expand Up @@ -804,12 +804,22 @@ def cross(a: VectorLike, b: VectorLike) -> Vector:


@overload
def cross(a: MatrixLike, b: Any) -> Matrix:
def cross(a: MatrixLike, b: VectorLike | MatrixLike) -> Matrix:
...


@overload
def cross(a: Any, b: MatrixLike) -> Matrix:
def cross(a: VectorLike | MatrixLike, b: MatrixLike) -> Matrix:
...


@overload
def cross(a: TensorLike, b: Any) -> Tensor:
...


@overload
def cross(a: Any, b: TensorLike) -> Tensor:
...


Expand Down Expand Up @@ -1174,7 +1184,7 @@ def matmul(
raise ValueError('Inputs require at least 1 dimension, scalars are not allowed')


def _matrix_chain_order(shapes: Sequence[Shape]) -> list[list[int]]:
def _matrix_chain_order(shapes: Sequence[Shape]) -> MatrixInt:
"""
Calculate chain order.
Expand All @@ -1193,7 +1203,7 @@ def _matrix_chain_order(shapes: Sequence[Shape]) -> list[list[int]]:

n = len(shapes)
m = full((n, n), 0) # type: Any
s = full((n, n), 0) # type: list[list[int]] # type: ignore[assignment]
s = full((n, n), 0) # type: MatrixInt # type: ignore[assignment]
p = [a[0] for a in shapes] + [shapes[-1][1]]

for d in range(1, n):
Expand All @@ -1208,7 +1218,7 @@ def _matrix_chain_order(shapes: Sequence[Shape]) -> list[list[int]]:
return s


def _multi_dot(arrays: Sequence[ArrayLike], indexes: list[list[int]], i: int, j: int) -> ArrayLike:
def _multi_dot(arrays: Sequence[ArrayLike], indexes: MatrixInt, i: int, j: int) -> ArrayLike:
"""Recursively dot the matrices in the array."""

if i != j:
Expand Down Expand Up @@ -1262,7 +1272,7 @@ def multi_dot(arrays: Sequence[ArrayLike]) -> Any:
is_vector = True

# Make sure everything is a 2-D matrix as the next calculations only work for 2-D.
if not _all(len(s) == 2 for s in shapes):
if not _all(len(s) == 2 for s in shapes): # type: ignore[arg-type]
raise ValueError('All arrays must be 2-D matrices')

# No need to do the expensive and complicated chain order algorithm for only 3.
Expand Down Expand Up @@ -1312,7 +1322,7 @@ def __init__(self, array: ArrayLike | float, old: Shape, new: Shape) -> None:
self._chunk_subindex = 0
self._chunk_max = 0
self._chunk_index = 0
self._chunk = [] # type: list[float]
self._chunk = [] # type: Vector

# Unravel the data as it will be quicker to slice the data in a flattened form
# than iterating over the dimensions to replicate the data.
Expand Down Expand Up @@ -1952,17 +1962,17 @@ def isclose(a: float, b: float, *, dims: DimHints | None = ..., **kwargs: Any) -


@overload
def isclose(a: VectorLike, b: VectorLike, *, dims: DimHints | None = ..., **kwargs: Any) -> list[bool]:
def isclose(a: VectorLike, b: VectorLike, *, dims: DimHints | None = ..., **kwargs: Any) -> VectorBool:
...


@overload
def isclose(a: MatrixLike, b: MatrixLike, *, dims: DimHints | None = ..., **kwargs: Any) -> list[list[bool]]:
def isclose(a: MatrixLike, b: MatrixLike, *, dims: DimHints | None = ..., **kwargs: Any) -> MatrixBool:
...


@overload
def isclose(a: TensorLike, b: TensorLike, *, dims: DimHints | None = ..., **kwargs: Any) -> list[list[list[bool]]]:
def isclose(a: TensorLike, b: TensorLike, *, dims: DimHints | None = ..., **kwargs: Any) -> TensorBool:
...


Expand All @@ -1975,17 +1985,17 @@ def isnan(a: float, *, dims: DimHints | None = ..., **kwargs: Any) -> bool:


@overload
def isnan(a: VectorLike, *, dims: DimHints | None = ..., **kwargs: Any) -> list[bool]:
def isnan(a: VectorLike, *, dims: DimHints | None = ..., **kwargs: Any) -> VectorBool:
...


@overload
def isnan(a: MatrixLike, *, dims: DimHints | None = ..., **kwargs: Any) -> list[list[bool]]:
def isnan(a: MatrixLike, *, dims: DimHints | None = ..., **kwargs: Any) -> MatrixBool:
...


@overload
def isnan(a: TensorLike, *, dims: DimHints | None = ..., **kwargs: Any) -> list[list[list[bool]]]:
def isnan(a: TensorLike, *, dims: DimHints | None = ..., **kwargs: Any) -> TensorBool:
...


Expand Down Expand Up @@ -2424,7 +2434,7 @@ def reshape(array: ArrayLike | float, new_shape: int | ShapeLike) -> float | Arr
return m # type: ignore[no-any-return]


def _shape(a: Any, s: Shape) -> Shape:
def _shape(a: ArrayLike | float, s: Shape) -> Shape:
"""
Get the shape of the array.
Expand Down Expand Up @@ -2458,7 +2468,7 @@ def shape(a: ArrayLike | float) -> Shape:
return _shape(a, ())


def fill_diagonal(matrix: MatrixLike, val: float | ArrayLike, wrap: bool = False) -> None:
def fill_diagonal(matrix: Matrix | Tensor, val: float | ArrayLike, wrap: bool = False) -> None:
"""Fill an N-D matrix diagonal."""

s = shape(matrix)
Expand Down Expand Up @@ -2628,7 +2638,7 @@ def lu(
size = s[1]
wide = True
for _ in range(diff):
matrix.append([0.0] * size) # type: ignore[list-item] # noqa: PERF401
matrix.append([0.0] * size) # type: ignore[arg-type] # noqa: PERF401
# Tall
else:
tall = True
Expand Down Expand Up @@ -2970,6 +2980,16 @@ def inv(matrix: MatrixLike | TensorLike) -> Matrix | Tensor:
return _back_sub_matrix(u, _forward_sub_matrix(l, p, s2), s2)


@overload
def vstack(arrays: Sequence[float | Vector | Matrix]) -> Matrix:
...


@overload
def vstack(arrays: Sequence[Tensor]) -> Tensor:
...


def vstack(arrays: Sequence[ArrayLike | float]) -> Matrix | Tensor:
"""Vertical stack."""

Expand Down
10 changes: 5 additions & 5 deletions coloraide/color.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def _parse(
data: VectorLike | None = None,
alpha: float = util.DEF_ALPHA,
**kwargs: Any
) -> tuple[Space, list[float]]:
) -> tuple[Space, Vector]:
"""Parse the color."""

# Parse a color string or color space name and coordinates
Expand Down Expand Up @@ -1004,7 +1004,7 @@ def discrete(
max_delta_e: float = 0,
delta_e: str | None = None,
delta_e_args: dict[str, Any] | None = None,
domain: list[float] | None = None,
domain: Vector | None = None,
**interpolate_args: Any
) -> Interpolator:
"""Create a discrete interpolation."""
Expand All @@ -1031,7 +1031,7 @@ def interpolate(
hue: str = util.DEF_HUE_ADJ,
premultiplied: bool = True,
extrapolate: bool = False,
domain: list[float] | None = None,
domain: Vector | None = None,
method: str | None = None,
padding: float | tuple[float, float] | None = None,
carryforward: bool | None = None,
Expand Down Expand Up @@ -1208,10 +1208,10 @@ def get(self, name: str, *, nans: bool = True) -> float:
...

@overload
def get(self, name: list[str] | tuple[str, ...], *, nans: bool = True) -> list[float]:
def get(self, name: list[str] | tuple[str, ...], *, nans: bool = True) -> Vector:
...

def get(self, name: str | list[str] | tuple[str, ...], *, nans: bool = True) -> float | list[float]:
def get(self, name: str | list[str] | tuple[str, ...], *, nans: bool = True) -> float | Vector:
"""Get channel."""

# Handle single channel
Expand Down
6 changes: 3 additions & 3 deletions coloraide/gamut/pointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ..spaces.lch import lab_to_lch, lch_to_lab
from .. import algebra as alg
from .. import util
from ..types import Vector
from ..types import Vector, Matrix
from typing import TYPE_CHECKING

if TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -181,7 +181,7 @@ def in_pointer_gamut(color: Color, tolerance: float) -> bool:
return c <= (get_chroma_limit(l, h) + tolerance)


def pointer_gamut_boundary(lightness: float | None = None) -> list[Vector]:
def pointer_gamut_boundary(lightness: float | None = None) -> Matrix:
"""
Calculate the Pointer gamut boundary points for the given lightness.
Expand All @@ -192,7 +192,7 @@ def pointer_gamut_boundary(lightness: float | None = None) -> list[Vector]:
# Maximum Pointer gamut boundary
# For each hue, find the lightness/chroma point that is furthest away from the white point.
if lightness is None:
max_gamut = [] # type: list[Vector]
max_gamut = [] # type: Matrix
for i, h in enumerate(LCH_H):
max_dxy = 0.0
max_xyy = [0.0, 0.0, 0.0]
Expand Down
16 changes: 8 additions & 8 deletions coloraide/interpolate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from abc import ABCMeta, abstractmethod
from .. import algebra as alg
from .. spaces import HSVish, HSLish, Cylindrical, RGBish, LChish, Labish
from ..types import Vector, ColorInput, Plugin
from ..types import Matrix, Vector, ColorInput, Plugin
from typing import Callable, Sequence, Mapping, Any, TYPE_CHECKING

if TYPE_CHECKING: # pragma: no cover
Expand Down Expand Up @@ -52,7 +52,7 @@ def hint(mid: float) -> Callable[..., float]:
return functools.partial(midpoint, h=mid)


def normalize_domain(d: list[float]) -> list[float]:
def normalize_domain(d: Vector) -> Vector:
"""Normalize domain between 0 and 1."""

total = d[-1] - d[0]
Expand All @@ -70,7 +70,7 @@ class Interpolator(metaclass=ABCMeta):

def __init__(
self,
coordinates: list[Vector],
coordinates: Matrix,
channel_names: Sequence[str],
create: type[Color],
easings: list[Callable[..., float] | None],
Expand Down Expand Up @@ -114,7 +114,7 @@ def __init__(
self.padding(padding)

# Set the domain
self._domain = [] # type: list[float]
self._domain = [] # type: Vector
if domain is not None:
self.domain(domain)

Expand Down Expand Up @@ -206,7 +206,7 @@ def domain(self, domain: Sequence[float]) -> None:

# Ensure domain ascends.
# If we have a domain of length 1, we will duplicate it.
d = [] # type: list[float]
d = [] # type: Vector
if domain:
length = len(domain)

Expand Down Expand Up @@ -458,7 +458,7 @@ class Interpolate(Plugin, metaclass=ABCMeta):
@abstractmethod
def interpolator(
self,
coordinates: list[Vector],
coordinates: Matrix,
channel_names: Sequence[str],
create: type[Color],
easings: list[Callable[..., float] | None],
Expand All @@ -468,7 +468,7 @@ def interpolator(
progress: Mapping[str, Callable[..., float]] | Callable[..., float] | None,
premultiplied: bool,
extrapolate: bool = False,
domain: list[float] | None = None,
domain: Vector | None = None,
padding: float | tuple[float, float] | None = None,
hue: str = 'shorter',
**kwargs: Any
Expand Down Expand Up @@ -650,7 +650,7 @@ def interpolator(
hue: str,
premultiplied: bool,
extrapolate: bool,
domain: list[float] | None = None,
domain: Vector | None = None,
padding: float | tuple[float, float] | None = None,
carryforward: bool = False,
powerless: bool = False,
Expand Down
Loading

0 comments on commit 85fdc1d

Please sign in to comment.