Skip to content

Commit

Permalink
Typing fix
Browse files Browse the repository at this point in the history
  • Loading branch information
hexane360 committed Jan 26, 2024
1 parent 5ed914c commit 2fde729
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions atomlib/atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import polars.testing
import polars.type_aliases

from .types import to_vec3, VecLike
from .types import to_vec3, VecLike, ParamSpec, Concatenate
from .bbox import BBox3D
from .elem import get_elem, get_sym, get_mass
from .transform import Transform3D, IntoTransform3D, AffineTransform3D
Expand Down Expand Up @@ -144,13 +144,13 @@ def _select_schema(df: t.Union[polars.DataFrame, HasAtoms], schema: SchemaDict)


HasAtomsT = t.TypeVar('HasAtomsT', bound='HasAtoms')
P = t.ParamSpec('P')
P = ParamSpec('P')
T = t.TypeVar('T')


def _map_unchecked(
f: t.Callable[t.Concatenate[HasAtomsT, P], polars.DataFrame]
) -> t.Callable[t.Concatenate[HasAtomsT, P], HasAtomsT]:
f: t.Callable[Concatenate[HasAtomsT, P], polars.DataFrame]
) -> t.Callable[Concatenate[HasAtomsT, P], HasAtomsT]:

@wraps(f)
def wrapper(self: HasAtomsT, *args: P.args, **kwargs: P.kwargs) -> HasAtomsT:
Expand All @@ -161,9 +161,9 @@ def wrapper(self: HasAtomsT, *args: P.args, **kwargs: P.kwargs) -> HasAtomsT:


def _frame_delegate(
impl_f: t.Callable[t.Concatenate[polars.DataFrame, P], T]
) -> t.Callable[[t.Callable[t.Concatenate[HasAtomsT, P], t.Any]], t.Callable[t.Concatenate[HasAtomsT, P], T]]:
def inner(f: t.Callable[t.Concatenate[HasAtomsT, P], t.Any]) -> t.Callable[t.Concatenate[HasAtomsT, P], T]:
impl_f: t.Callable[Concatenate[polars.DataFrame, P], T]
) -> t.Callable[[t.Callable[Concatenate[HasAtomsT, P], t.Any]], t.Callable[Concatenate[HasAtomsT, P], T]]:
def inner(f: t.Callable[Concatenate[HasAtomsT, P], t.Any]) -> t.Callable[Concatenate[HasAtomsT, P], T]:
@wraps(f)
def wrapper(self: HasAtoms, *args: P.args, **kwargs: P.kwargs) -> T:
return impl_f(self._get_frame(), *args, **kwargs)
Expand Down Expand Up @@ -258,7 +258,7 @@ def get_column_index(self, name: str) -> int:
def group_by(self, by: t.Union[IntoExpr, t.Iterable[IntoExpr]], *more_by: IntoExpr, maintain_order: bool = False):
...

def pipe(self: HasAtomsT, function: t.Callable[t.Concatenate[HasAtomsT, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
def pipe(self: HasAtomsT, function: t.Callable[Concatenate[HasAtomsT, P], T], *args: P.args, **kwargs: P.kwargs) -> T:
"""Apply `function` to `self` (in method-call syntax)."""
return function(self, *args, **kwargs)

Expand Down

0 comments on commit 2fde729

Please sign in to comment.