Skip to content

Commit

Permalink
Fix type-hints
Browse files Browse the repository at this point in the history
  • Loading branch information
espdev committed Mar 25, 2020
1 parent f38fd7c commit 50ce128
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 55 deletions.
2 changes: 0 additions & 2 deletions csaps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
)
from csaps._types import (
UnivariateDataType,
UnivariateVectorizedDataType,
MultivariateDataType,
NdGridDataType,
)
Expand All @@ -46,7 +45,6 @@

# Type-hints
'UnivariateDataType',
'UnivariateVectorizedDataType',
'MultivariateDataType',
'NdGridDataType',
]
95 changes: 70 additions & 25 deletions csaps/_shortcut.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,50 +6,95 @@
"""

from collections import abc as c_abc
from typing import Optional, Union, Sequence, NamedTuple
from typing import Optional, Union, Sequence, NamedTuple, overload

import numpy as np

from ._base import ISmoothingSpline
from ._sspumv import CubicSmoothingSpline
from ._sspndg import ndgrid_prepare_data_sites, NdGridCubicSmoothingSpline
from ._types import (
UnivariateDataType,
UnivariateVectorizedDataType,
NdGridDataType,
)

_XDataType = Union[UnivariateDataType, NdGridDataType]
_YDataType = Union[UnivariateVectorizedDataType, np.ndarray]
_XiDataType = Optional[Union[UnivariateDataType, NdGridDataType]]
_WeightsDataType = Optional[Union[UnivariateDataType, NdGridDataType]]
_SmoothDataType = Optional[Union[float, Sequence[Optional[float]]]]
from ._types import UnivariateDataType, MultivariateDataType, NdGridDataType


class AutoSmoothingResult(NamedTuple):
"""The result for auto smoothing for `csaps` function"""

values: _YDataType
values: MultivariateDataType
"""Smoothed data values"""

smooth: _SmoothDataType
smooth: Union[float, Sequence[Optional[float]]]
"""The calculated smoothing parameter"""


_ReturnType = Union[
_YDataType,
AutoSmoothingResult,
ISmoothingSpline,
]
# **************************************
# csaps signatures
#
@overload
def csaps(xdata: UnivariateDataType,
ydata: MultivariateDataType,
*,
weights: Optional[UnivariateDataType] = None,
smooth: Optional[float] = None,
axis: Optional[int] = None) -> ISmoothingSpline: ...


@overload
def csaps(xdata: UnivariateDataType,
ydata: MultivariateDataType,
xidata: UnivariateDataType,
*,
weights: Optional[UnivariateDataType] = None,
axis: Optional[int] = None) -> AutoSmoothingResult: ...


@overload
def csaps(xdata: UnivariateDataType,
ydata: MultivariateDataType,
xidata: UnivariateDataType,
*,
smooth: float,
weights: Optional[UnivariateDataType] = None,
axis: Optional[int] = None) -> MultivariateDataType: ...


@overload
def csaps(xdata: NdGridDataType,
ydata: MultivariateDataType,
*,
weights: Optional[NdGridDataType] = None,
smooth: Optional[Sequence[float]] = None,
axis: Optional[int] = None) -> ISmoothingSpline: ...


@overload
def csaps(xdata: NdGridDataType,
ydata: MultivariateDataType,
xidata: NdGridDataType,
*,
weights: Optional[NdGridDataType] = None,
axis: Optional[int] = None) -> AutoSmoothingResult: ...


@overload
def csaps(xdata: NdGridDataType,
ydata: MultivariateDataType,
xidata: NdGridDataType,
*,
smooth: Sequence[float],
weights: Optional[NdGridDataType] = None,
axis: Optional[int] = None) -> MultivariateDataType: ...
#
# csaps signatures
# **************************************


def csaps(xdata: _XDataType,
ydata: _YDataType,
xidata: _XiDataType = None,
def csaps(xdata: Union[UnivariateDataType, NdGridDataType],
ydata: MultivariateDataType,
xidata: Optional[Union[UnivariateDataType, NdGridDataType]] = None,
*,
weights: _WeightsDataType = None,
smooth: _SmoothDataType = None,
axis: Optional[int] = None) -> _ReturnType:
weights: Optional[Union[UnivariateDataType, NdGridDataType]] = None,
smooth: Optional[Union[float, Sequence[float]]] = None,
axis: Optional[int] = None) -> Union[MultivariateDataType, ISmoothingSpline, AutoSmoothingResult]:
"""Smooths the univariate/multivariate/gridded data or computes the corresponding splines
This function might be used as the main API for smoothing any data.
Expand Down
6 changes: 3 additions & 3 deletions csaps/_sspumv.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import scipy.sparse.linalg as la

from ._base import SplinePPFormBase, ISmoothingSpline
from ._types import UnivariateDataType, UnivariateVectorizedDataType, MultivariateDataType
from ._types import UnivariateDataType, MultivariateDataType
from ._reshape import from_2d, to_2d


Expand Down Expand Up @@ -120,7 +120,7 @@ class CubicSmoothingSpline(ISmoothingSpline[SplinePPForm, float, UnivariateDataT

def __init__(self,
xdata: UnivariateDataType,
ydata: UnivariateVectorizedDataType,
ydata: MultivariateDataType,
weights: ty.Optional[UnivariateDataType] = None,
smooth: ty.Optional[float] = None,
axis: int = -1):
Expand Down Expand Up @@ -301,7 +301,7 @@ class UnivariateCubicSmoothingSpline(ISmoothingSpline[SplinePPForm, float, Univa

def __init__(self,
xdata: UnivariateDataType,
ydata: UnivariateVectorizedDataType,
ydata: MultivariateDataType,
weights: ty.Optional[UnivariateDataType] = None,
smooth: ty.Optional[float] = None,
axis: int = -1) -> None:
Expand Down
36 changes: 11 additions & 25 deletions csaps/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,18 @@
"""

import typing as ty
import numbers
from collections import abc
from typing import Union, Sequence, Tuple, TypeVar
from numbers import Number
import numpy as np


UnivariateDataType = ty.Union[
np.ndarray,
ty.Sequence[numbers.Number]
]
UnivariateDataType = Union[np.ndarray, Sequence[Number]]
MultivariateDataType = Union[np.ndarray, abc.Sequence]
NdGridDataType = Sequence[UnivariateDataType]

UnivariateVectorizedDataType = ty.Union[
UnivariateDataType,
# FIXME: mypy does not support recursive types
# https://github.com/python/mypy/issues/731
# ty.Sequence['UnivariateVectorizedDataType']
]

MultivariateDataType = ty.Union[
np.ndarray,
ty.Sequence[UnivariateDataType]
]

NdGridDataType = ty.Sequence[UnivariateDataType]

TData = ty.TypeVar('TData', np.ndarray, ty.Sequence[np.ndarray])
TProps = ty.TypeVar('TProps', int, ty.Tuple[int, ...])
TSmooth = ty.TypeVar('TSmooth', float, ty.Tuple[float, ...])
TXi = ty.TypeVar('TXi', UnivariateDataType, NdGridDataType)
TSpline = ty.TypeVar('TSpline')
TData = TypeVar('TData', np.ndarray, Sequence[np.ndarray])
TProps = TypeVar('TProps', int, Tuple[int, ...])
TSmooth = TypeVar('TSmooth', float, Tuple[float, ...])
TXi = TypeVar('TXi', UnivariateDataType, NdGridDataType)
TSpline = TypeVar('TSpline')

0 comments on commit 50ce128

Please sign in to comment.