Skip to content

Commit

Permalink
Typing fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
hexane360 committed Jun 20, 2024
1 parent 113d028 commit a28c52f
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 124 deletions.
152 changes: 76 additions & 76 deletions atomlib/atomcell.py

Large diffs are not rendered by default.

87 changes: 44 additions & 43 deletions atomlib/atoms.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import polars.testing
import polars.type_aliases

from .types import to_vec3, VecLike, ParamSpec, Concatenate, TypeAlias
from .types import to_vec3, VecLike, ParamSpec, Concatenate, TypeAlias, Self
from .bbox import BBox3D
from .elem import get_elem, get_sym, get_mass
from .transform import Transform3D, IntoTransform3D, AffineTransform3D
Expand Down Expand Up @@ -162,8 +162,8 @@ def wrapper(self: HasAtomsT, *args: P.args, **kwargs: P.kwargs) -> HasAtomsT:

def _fwd_frame(
impl_f: t.Callable[Concatenate[polars.DataFrame, P], T]
) -> t.Callable[[t.Callable[Concatenate[HasAtomsT, P], t.Any]], t.Callable[Concatenate[HasAtomsT, P], T]]:
def inner(f: t.Callable[Concatenate[HasAtomsT, P], t.Any]) -> t.Callable[Concatenate[HasAtomsT, P], T]:
) -> t.Callable[[t.Callable[Concatenate[t.Any, P], T]], t.Callable[Concatenate[HasAtoms, P], T]]:
def inner(f: t.Callable[Concatenate[HasAtoms, P], T]) -> t.Callable[Concatenate[HasAtoms, P], T]:
@wraps(f)
def wrapper(self: HasAtoms, *args: P.args, **kwargs: P.kwargs) -> T:
return impl_f(self._get_frame(), *args, **kwargs)
Expand Down Expand Up @@ -198,7 +198,7 @@ def get_atoms(self, frame: t.Literal['local'] = 'local') -> Atoms:
...

@abc.abstractmethod
def with_atoms(self: HasAtomsT, atoms: HasAtoms, frame: t.Literal['local'] = 'local') -> HasAtomsT:
def with_atoms(self, atoms: HasAtoms, frame: t.Literal['local'] = 'local') -> Self:
"""
Return a copy of self with the inner [`Atoms`][atomlib.atoms.Atoms] replaced.
Expand Down Expand Up @@ -229,7 +229,7 @@ def _get_frame(self) -> polars.DataFrame:

@property
@_fwd_frame(lambda df: df.columns)
def columns(self) -> t.Sequence[str]:
def columns(self) -> t.List[str]:
"""
Return the column names in `self`.
Expand All @@ -240,7 +240,7 @@ def columns(self) -> t.Sequence[str]:

@property
@_fwd_frame(lambda df: df.dtypes)
def dtypes(self) -> t.Sequence[polars.DataType]:
def dtypes(self) -> t.List[polars.DataType]:
"""
Return the datatypes in `self`.
Expand Down Expand Up @@ -335,10 +335,10 @@ def drop(self, *columns: t.Union[str, t.Iterable[str]]) -> polars.DataFrame:
# row-wise operations

def filter(
self: HasAtomsT,
self,
*predicates: t.Union[None, IntoExprColumn, t.Iterable[IntoExprColumn], bool, t.List[bool], numpy.ndarray],
**constraints: t.Any,
) -> HasAtomsT:
) -> Self:
"""Filter `self`, removing rows which evaluate to `False`."""
# TODO clean up
preds_not_none: t.Tuple[t.Union[IntoExprColumn, t.Iterable[IntoExprColumn], bool, t.List[bool], numpy.ndarray], ...]
Expand Down Expand Up @@ -430,22 +430,22 @@ def concat(cls: t.Type[HasAtomsT],

@t.overload
def partition_by(
self: HasAtomsT, by: t.Union[str, t.Sequence[str]], *more_by: str,
self, by: t.Union[str, t.Sequence[str]], *more_by: str,
maintain_order: bool = True, include_key: bool = True, as_dict: t.Literal[False] = False
) -> t.List[HasAtomsT]:
) -> t.List[Self]:
...

@t.overload
def partition_by(
self: HasAtomsT, by: t.Union[str, t.Sequence[str]], *more_by: str,
self, by: t.Union[str, t.Sequence[str]], *more_by: str,
maintain_order: bool = True, include_key: bool = True, as_dict: t.Literal[True] = ...
) -> t.Dict[t.Any, HasAtomsT]:
) -> t.Dict[t.Any, Self]:
...

def partition_by(
self: HasAtomsT, by: t.Union[str, t.Sequence[str]], *more_by: str,
self, by: t.Union[str, t.Sequence[str]], *more_by: str,
maintain_order: bool = True, include_key: bool = True, as_dict: bool = False
) -> t.Union[t.List[HasAtomsT], t.Dict[t.Any, HasAtomsT]]:
) -> t.Union[t.List[Self], t.Dict[t.Any, Self]]:
"""
Group by the given columns and partition into separate dataframes.
Expand All @@ -467,7 +467,7 @@ def select(
self,
*exprs: t.Union[IntoExpr, t.Iterable[IntoExpr]],
**named_exprs: IntoExpr,
):
) -> polars.DataFrame:
"""
Select `exprs` from `self`, and return as a [`polars.DataFrame`][polars.DataFrame].
Expand All @@ -487,10 +487,10 @@ def select_schema(self, schema: SchemaDict) -> polars.DataFrame:
return _select_schema(self, schema)

def select_props(
self: HasAtomsT,
self,
*exprs: t.Union[IntoExpr, t.Iterable[IntoExpr]],
**named_exprs: IntoExpr
) -> HasAtomsT:
) -> Self:
"""
Select `exprs` from `self`, while keeping required columns.
Expand Down Expand Up @@ -546,10 +546,10 @@ def __contains__(self, key: str) -> bool:
"""Return whether `self` contains the given column."""
...

def __add__(self: HasAtomsT, other: IntoAtoms) -> HasAtomsT:
def __add__(self, other: IntoAtoms) -> HasAtoms:
return self.__class__.concat((self, other), how='inner')

def __radd__(self: HasAtomsT, other: IntoAtoms) -> HasAtomsT:
def __radd__(self, other: IntoAtoms) -> HasAtoms:
return self.__class__.concat((other, self), how='inner')

def __getitem__(self, column: str) -> polars.Series:
Expand All @@ -572,7 +572,8 @@ def bbox_atoms(self) -> BBox3D:

bbox = bbox_atoms

def transform_atoms(self: HasAtomsT, transform: IntoTransform3D, selection: t.Optional[AtomSelection] = None, *, transform_velocities: bool = False) -> HasAtomsT:
def transform_atoms(self, transform: IntoTransform3D, selection: t.Optional[AtomSelection] = None, *,
transform_velocities: bool = False) -> Self:
"""
Transform the atoms in `self` by `transform`.
If `selection` is given, only transform the atoms in `selection`.
Expand All @@ -587,7 +588,7 @@ def transform_atoms(self: HasAtomsT, transform: IntoTransform3D, selection: t.Op

transform = transform_atoms

def round_near_zero(self: HasAtomsT, tol: float = 1e-14) -> HasAtomsT:
def round_near_zero(self, tol: float = 1e-14) -> Self:
"""
Round atom position values near zero to zero.
"""
Expand All @@ -596,9 +597,9 @@ def round_near_zero(self: HasAtomsT, tol: float = 1e-14) -> HasAtomsT:
for col in range(3)
).list.to_array(3))

def crop(self: HasAtomsT, x_min: float = -numpy.inf, x_max: float = numpy.inf,
def crop(self, x_min: float = -numpy.inf, x_max: float = numpy.inf,
y_min: float = -numpy.inf, y_max: float = numpy.inf,
z_min: float = -numpy.inf, z_max: float = numpy.inf) -> HasAtomsT:
z_min: float = -numpy.inf, z_max: float = numpy.inf) -> Self:
"""
Crop, removing all atoms outside of the specified region, inclusive.
"""
Expand All @@ -611,12 +612,12 @@ def crop(self: HasAtomsT, x_min: float = -numpy.inf, x_max: float = numpy.inf,

crop_atoms = crop

def _wrap(self: HasAtomsT, eps: float = 1e-5) -> HasAtomsT:
def _wrap(self, eps: float = 1e-5) -> Self:
coords = (self.coords() + eps) % 1. - eps
return self.with_coords(coords)

def deduplicate(self: HasAtomsT, tol: float = 1e-3, subset: t.Iterable[str] = ('x', 'y', 'z', 'symbol'),
keep: UniqueKeepStrategy = 'first', maintain_order: bool = True) -> HasAtomsT:
def deduplicate(self, tol: float = 1e-3, subset: t.Iterable[str] = ('x', 'y', 'z', 'symbol'),
keep: UniqueKeepStrategy = 'first', maintain_order: bool = True) -> Self:
"""
De-duplicate atoms in `self`. Atoms of the same `symbol` that are closer than `tolerance`
to each other (by Euclidian distance) will be removed, leaving only the atom specified by
Expand Down Expand Up @@ -722,22 +723,22 @@ def masses(self) -> t.Optional[polars.Series]:
return self.try_get_column('mass')

@t.overload
def add_atom(self: HasAtomsT, elem: t.Union[int, str], x: ArrayLike, /, *,
def add_atom(self, elem: t.Union[int, str], x: ArrayLike, /, *,
y: None = None, z: None = None,
**kwargs: t.Any) -> HasAtomsT:
**kwargs: t.Any) -> Self:
...

@t.overload
def add_atom(self: HasAtomsT, elem: t.Union[int, str], /,
def add_atom(self, elem: t.Union[int, str], /,
x: float, y: float, z: float,
**kwargs: t.Any) -> HasAtomsT:
**kwargs: t.Any) -> Self:
...

def add_atom(self: HasAtomsT, elem: t.Union[int, str], /,
def add_atom(self, elem: t.Union[int, str], /,
x: t.Union[ArrayLike, float],
y: t.Optional[float] = None,
z: t.Optional[float] = None,
**kwargs: t.Any) -> HasAtomsT:
**kwargs: t.Any) -> Self:
"""
Return a copy of `self` with an extra atom.
Expand Down Expand Up @@ -802,7 +803,7 @@ def pos(self,

return selection

def with_index(self: HasAtomsT, index: t.Optional[AtomValues] = None) -> HasAtomsT:
def with_index(self, index: t.Optional[AtomValues] = None) -> Self:
"""
Returns `self` with a row index added in column 'i' (dtype [`polars.Int64`][polars.datatypes.Int64]).
If `index` is not specified, defaults to an existing index or a new index.
Expand All @@ -813,7 +814,7 @@ def with_index(self: HasAtomsT, index: t.Optional[AtomValues] = None) -> HasAtom
index = numpy.arange(len(self), dtype=numpy.int64)
return self.with_column(_values_to_expr(self, index, polars.Int64).alias('i'))

def with_wobble(self: HasAtomsT, wobble: t.Optional[AtomValues] = None) -> HasAtomsT:
def with_wobble(self, wobble: t.Optional[AtomValues] = None) -> Self:
"""
Return `self` with the given displacements in column 'wobble' (dtype [`polars.Float64`][polars.datatypes.Float64]).
If `wobble` is not specified, defaults to the already-existing wobbles or 0.
Expand All @@ -823,7 +824,7 @@ def with_wobble(self: HasAtomsT, wobble: t.Optional[AtomValues] = None) -> HasAt
wobble = 0. if wobble is None else wobble
return self.with_column(_values_to_expr(self, wobble, polars.Float64).alias('wobble'))

def with_occupancy(self: HasAtomsT, frac_occupancy: t.Optional[AtomValues] = None) -> HasAtomsT:
def with_occupancy(self, frac_occupancy: t.Optional[AtomValues] = None) -> Self:
"""
Return self with the given fractional occupancies (dtype [`polars.Float64`][polars.datatypes.Float64]).
If `frac_occupancy` is not specified, defaults to the already-existing occupancies or 1.
Expand All @@ -833,7 +834,7 @@ def with_occupancy(self: HasAtomsT, frac_occupancy: t.Optional[AtomValues] = Non
frac_occupancy = 1. if frac_occupancy is None else frac_occupancy
return self.with_column(_values_to_expr(self, frac_occupancy, polars.Float64).alias('frac_occupancy'))

def apply_wobble(self: HasAtomsT, rng: t.Union[numpy.random.Generator, int, None] = None) -> HasAtomsT:
def apply_wobble(self, rng: t.Union[numpy.random.Generator, int, None] = None) -> Self:
"""
Displace the atoms in `self` by the amount in the `wobble` column.
`wobble` is interpretated as a mean-squared displacement, which is distributed
Expand All @@ -848,7 +849,7 @@ def apply_wobble(self: HasAtomsT, rng: t.Union[numpy.random.Generator, int, None
coords += stddev[:, None] * rng.standard_normal(coords.shape)
return self.with_coords(coords)

def apply_occupancy(self: HasAtomsT, rng: t.Union[numpy.random.Generator, int, None] = None) -> HasAtomsT:
def apply_occupancy(self, rng: t.Union[numpy.random.Generator, int, None] = None) -> Self:
"""
For each atom in `self`, use its `frac_occupancy` to randomly decide whether to remove it.
"""
Expand All @@ -860,7 +861,7 @@ def apply_occupancy(self: HasAtomsT, rng: t.Union[numpy.random.Generator, int, N
choice = rng.binomial(1, frac).astype(numpy.bool_)
return self.filter(polars.lit(choice))

def with_type(self: HasAtomsT, types: t.Optional[AtomValues] = None) -> HasAtomsT:
def with_type(self, types: t.Optional[AtomValues] = None) -> Self:
"""
Return `self` with the given atom types in column 'type'.
If `types` is not specified, use the already existing types or auto-assign them.
Expand Down Expand Up @@ -888,7 +889,7 @@ def with_type(self: HasAtomsT, types: t.Optional[AtomValues] = None) -> HasAtoms
assert (new.get_column('type') == 0).sum() == 0
return new

def with_mass(self: HasAtomsT, mass: t.Optional[ArrayLike] = None) -> HasAtomsT:
def with_mass(self, mass: t.Optional[ArrayLike] = None) -> Self:
"""
Return `self` with the given atom masses in column `'mass'`.
If `mass` is not specified, use the already existing masses or auto-assign them.
Expand All @@ -911,7 +912,7 @@ def with_mass(self: HasAtomsT, mass: t.Optional[ArrayLike] = None) -> HasAtomsT:
assert (new.get_column('mass').abs() < 1e-10).sum() == 0
return new

def with_symbol(self: HasAtomsT, symbols: ArrayLike, selection: t.Optional[AtomSelection] = None) -> HasAtomsT:
def with_symbol(self, symbols: ArrayLike, selection: t.Optional[AtomSelection] = None) -> Self:
"""
Return `self` with the given atomic symbols.
"""
Expand All @@ -925,7 +926,7 @@ def with_symbol(self: HasAtomsT, symbols: ArrayLike, selection: t.Optional[AtomS
symbols = polars.Series('symbol', list(numpy.broadcast_to(symbols, len(self))), dtype=polars.Utf8)
return self.with_columns((symbols, get_elem(symbols)))

def with_coords(self: HasAtomsT, pts: ArrayLike, selection: t.Optional[AtomSelection] = None, *, frame: t.Literal['local'] = 'local') -> HasAtomsT:
def with_coords(self, pts: ArrayLike, selection: t.Optional[AtomSelection] = None, *, frame: t.Literal['local'] = 'local') -> Self:
"""
Return `self` replaced with the given atomic positions.
"""
Expand All @@ -940,8 +941,8 @@ def with_coords(self: HasAtomsT, pts: ArrayLike, selection: t.Optional[AtomSelec
pts = numpy.broadcast_to(pts, (len(self), 3))
return self.with_columns(polars.Series('coords', pts, polars.Array(polars.Float64, 3)))

def with_velocity(self: HasAtomsT, pts: t.Optional[ArrayLike] = None,
selection: t.Optional[AtomSelection] = None) -> HasAtomsT:
def with_velocity(self, pts: t.Optional[ArrayLike] = None,
selection: t.Optional[AtomSelection] = None) -> Self:
"""
Return `self` replaced with the given atomic velocities.
If `pts` is not specified, use the already existing velocities or zero.
Expand Down
2 changes: 1 addition & 1 deletion atomlib/elem.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def get_sym(elem: t.Union[int, polars.Series]):
_ = [_get_sym(t.cast(int, e)) for e in elem.to_list() if e is not None]
raise

return _get_sym(t.cast(int, elem))
return _get_sym(elem)


@t.overload
Expand Down
8 changes: 5 additions & 3 deletions atomlib/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import numpy

V = t.TypeVar('V')
V2 = t.TypeVar('V2')
T = t.TypeVar('T')
T_co = t.TypeVar('T_co', covariant=True)

WSPACE_RE = re.compile(r"\s+")
Expand Down Expand Up @@ -255,14 +257,14 @@ class Parser(t.Generic[T_co, V]):
"""Regex matching operators, brackets, and whitespace"""

@t.overload
def __init__(self: Parser[str, V], ops: t.Sequence[Op[V]],
def __init__(self: Parser[str, V2], ops: t.Sequence[Op[V2]],
parse_scalar: t.Optional[t.Callable[[str], str]] = None,
groups: t.Optional[t.Sequence[t.Tuple[str, str]]] = None):
...

@t.overload
def __init__(self: Parser[T_co, V], ops: t.Sequence[Op[V]],
parse_scalar: t.Callable[[str], T_co],
def __init__(self: Parser[T, V2], ops: t.Sequence[Op[V2]],
parse_scalar: t.Callable[[str], T],
groups: t.Optional[t.Sequence[t.Tuple[str, str]]] = None):
...

Expand Down
9 changes: 9 additions & 0 deletions atomlib/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@
"""Re-export of [`typing.TypeAlias`][typing.TypeAlias]"""


if sys.version_info < (3, 11):
import typing_extensions
Self = typing_extensions.Self
"""Re-export of [`typing.Self`][typing.Self]"""
else:
Self = t.Self
"""Re-export of [`typing.Self`][typing.Self]"""


Vec3 = NDArray[numpy.floating[t.Any]]
"""3D float vector, of shape (3,)."""

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ dependencies = [
"matplotlib~=3.5",
"requests~=2.28",
"lxml~=5.0",
"typing-extensions~=4.4;python_version<'3.10'",
"typing-extensions~=4.4;python_version<'3.11'",
"importlib_resources>=5.0", # importlib.resources backport
]

Expand Down

0 comments on commit a28c52f

Please sign in to comment.