diff --git a/.github/workflows/pyright.yml b/.github/workflows/pyright.yml new file mode 100644 index 00000000..73dc81f1 --- /dev/null +++ b/.github/workflows/pyright.yml @@ -0,0 +1,19 @@ +on: + - push + - pull_request + +name: Type checker +jobs: + pyright: + name: pyright + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@master + - uses: actions/setup-python@v5 + with: + python-version: '3.12' + - run: pip install uv + - run: uv pip install --system -e .[amber,ase,pymatgen] rdkit openbabel-wheel + - uses: jakebailey/pyright-action@v2 + with: + version: 1.1.363 diff --git a/benchmark/test_import.py b/benchmark/test_import.py index 04d46137..846d72b2 100644 --- a/benchmark/test_import.py +++ b/benchmark/test_import.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import subprocess import sys diff --git a/docs/conf.py b/docs/conf.py index 3f897fc9..e3c0b3d4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -11,6 +11,8 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # +from __future__ import annotations + import os import subprocess as sp import sys diff --git a/docs/make_format.py b/docs/make_format.py index 8a7878f9..2b3c03c6 100644 --- a/docs/make_format.py +++ b/docs/make_format.py @@ -1,8 +1,15 @@ +from __future__ import annotations + import csv import os +import sys from collections import defaultdict from inspect import Parameter, Signature, cleandoc, signature -from typing import Literal + +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal from numpydoc.docscrape import Parameter as numpydoc_Parameter from numpydoc.docscrape_sphinx import SphinxDocString diff --git a/docs/nb/try_dpdata.ipynb b/docs/nb/try_dpdata.ipynb index 7dc225b4..1a0a7328 100644 --- a/docs/nb/try_dpdata.ipynb +++ b/docs/nb/try_dpdata.ipynb @@ -13,6 +13,8 @@ "metadata": {}, "outputs": [], "source": [ + "from __future__ import annotations\n", + "\n", "import dpdata" ] }, diff --git a/dpdata/__about__.py b/dpdata/__about__.py index d5cfca64..3ee47d3c 100644 --- a/dpdata/__about__.py +++ b/dpdata/__about__.py @@ -1 +1,3 @@ +from __future__ import annotations + __version__ = "unknown" diff --git a/dpdata/__init__.py b/dpdata/__init__.py index 847554d3..f2cd233f 100644 --- a/dpdata/__init__.py +++ b/dpdata/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from . import lammps, md, vasp from .bond_order_system import BondOrderSystem from .system import LabeledSystem, MultiSystems, System diff --git a/dpdata/__main__.py b/dpdata/__main__.py index aad1556f..4c60f3f2 100644 --- a/dpdata/__main__.py +++ b/dpdata/__main__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dpdata.cli import dpdata_cli if __name__ == "__main__": diff --git a/dpdata/abacus/md.py b/dpdata/abacus/md.py index b96a0fd0..fa184177 100644 --- a/dpdata/abacus/md.py +++ b/dpdata/abacus/md.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import warnings diff --git a/dpdata/abacus/relax.py b/dpdata/abacus/relax.py index fb3c8da0..976243b8 100644 --- a/dpdata/abacus/relax.py +++ b/dpdata/abacus/relax.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import numpy as np diff --git a/dpdata/abacus/scf.py b/dpdata/abacus/scf.py index df50b010..193e4d4b 100644 --- a/dpdata/abacus/scf.py +++ b/dpdata/abacus/scf.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import re import warnings diff --git a/dpdata/amber/mask.py b/dpdata/amber/mask.py index e3ae1e8d..155e2a7b 100644 --- a/dpdata/amber/mask.py +++ b/dpdata/amber/mask.py @@ -1,5 +1,7 @@ """Amber mask.""" +from __future__ import annotations + try: import parmed except ImportError: diff --git a/dpdata/amber/md.py b/dpdata/amber/md.py index 91240121..f3217fbd 100644 --- a/dpdata/amber/md.py +++ b/dpdata/amber/md.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import re diff --git a/dpdata/amber/sqm.py b/dpdata/amber/sqm.py index 5dcbf995..1be3802a 100644 --- a/dpdata/amber/sqm.py +++ b/dpdata/amber/sqm.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from dpdata.periodic_table import ELEMENTS diff --git a/dpdata/ase_calculator.py b/dpdata/ase_calculator.py index c0579978..1de760a5 100644 --- a/dpdata/ase_calculator.py +++ b/dpdata/ase_calculator.py @@ -1,4 +1,6 @@ -from typing import TYPE_CHECKING, List, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING from ase.calculators.calculator import ( # noqa: TID253 Calculator, @@ -23,7 +25,10 @@ class DPDataCalculator(Calculator): dpdata driver """ - name = "dpdata" + @property + def name(self) -> str: + return "dpdata" + implemented_properties = ["energy", "free_energy", "forces", "virial", "stress"] def __init__(self, driver: Driver, **kwargs) -> None: @@ -32,9 +37,9 @@ def __init__(self, driver: Driver, **kwargs) -> None: def calculate( self, - atoms: Optional["Atoms"] = None, - properties: List[str] = ["energy", "forces"], - system_changes: List[str] = all_changes, + atoms: Atoms | None = None, + properties: list[str] = ["energy", "forces"], + system_changes: list[str] = all_changes, ): """Run calculation with a driver. @@ -48,10 +53,10 @@ def calculate( system_changes : List[str], optional unused, only for function signature compatibility, by default all_changes """ - if atoms is not None: - self.atoms = atoms.copy() + assert atoms is not None + atoms = atoms.copy() - system = dpdata.System(self.atoms, fmt="ase/structure") + system = dpdata.System(atoms, fmt="ase/structure") data = system.predict(driver=self.driver).data self.results["energy"] = data["energies"][0] diff --git a/dpdata/bond_order_system.py b/dpdata/bond_order_system.py index 1b6f903d..7a23acca 100644 --- a/dpdata/bond_order_system.py +++ b/dpdata/bond_order_system.py @@ -1,5 +1,7 @@ # %% # Bond Order System +from __future__ import annotations + from copy import deepcopy import numpy as np @@ -96,13 +98,14 @@ def from_fmt_obj(self, fmtobj, file_name, **kwargs): mol = fmtobj.from_bond_order_system(file_name, **kwargs) self.from_rdkit_mol(mol) if hasattr(fmtobj.from_bond_order_system, "post_func"): - for post_f in fmtobj.from_bond_order_system.post_func: + for post_f in fmtobj.from_bond_order_system.post_func: # type: ignore self.post_funcs.get_plugin(post_f)(self) return self def to_fmt_obj(self, fmtobj, *args, **kwargs): from rdkit.Chem import Conformer + assert self.rdkit_mol is not None self.rdkit_mol.RemoveAllConformers() for ii in range(self.get_nframes()): conf = Conformer() @@ -145,9 +148,9 @@ def get_formal_charges(self): """Return the formal charges on each atom.""" return self.data["formal_charges"] - def copy(self): + def copy(self): # type: ignore new_mol = deepcopy(self.rdkit_mol) - self.__class__(data=deepcopy(self.data), rdkit_mol=new_mol) + return self.__class__(data=deepcopy(self.data), rdkit_mol=new_mol) def __add__(self, other): raise NotImplementedError( diff --git a/dpdata/cli.py b/dpdata/cli.py index 2e39d17d..aadff1a8 100644 --- a/dpdata/cli.py +++ b/dpdata/cli.py @@ -1,7 +1,8 @@ """Command line interface for dpdata.""" +from __future__ import annotations + import argparse -from typing import Optional from . import __version__ from .system import LabeledSystem, MultiSystems, System @@ -59,11 +60,11 @@ def convert( *, from_file: str, from_format: str = "auto", - to_file: Optional[str] = None, - to_format: Optional[str] = None, + to_file: str | None = None, + to_format: str | None = None, no_labeled: bool = False, multi: bool = False, - type_map: Optional[list] = None, + type_map: list | None = None, **kwargs, ): """Convert files from one format to another one. diff --git a/dpdata/cp2k/cell.py b/dpdata/cp2k/cell.py index 7af73353..a3021b81 100644 --- a/dpdata/cp2k/cell.py +++ b/dpdata/cp2k/cell.py @@ -1,4 +1,5 @@ # %% +from __future__ import annotations import numpy as np diff --git a/dpdata/cp2k/output.py b/dpdata/cp2k/output.py index c84355c4..bd827595 100644 --- a/dpdata/cp2k/output.py +++ b/dpdata/cp2k/output.py @@ -1,4 +1,6 @@ # %% +from __future__ import annotations + import math import re from collections import OrderedDict diff --git a/dpdata/data_type.py b/dpdata/data_type.py index 64d4c5b1..bbc7401d 100644 --- a/dpdata/data_type.py +++ b/dpdata/data_type.py @@ -1,5 +1,7 @@ +from __future__ import annotations + from enum import Enum, unique -from typing import TYPE_CHECKING, Tuple +from typing import TYPE_CHECKING import numpy as np @@ -50,7 +52,7 @@ def __init__( self, name: str, dtype: type, - shape: Tuple[int, Axis] = None, + shape: tuple[int | Axis, ...] | None = None, required: bool = True, ) -> None: self.name = name @@ -58,8 +60,9 @@ def __init__( self.shape = shape self.required = required - def real_shape(self, system: "System") -> Tuple[int]: + def real_shape(self, system: System) -> tuple[int]: """Returns expected real shape of a system.""" + assert self.shape is not None shape = [] for ii in self.shape: if ii is Axis.NFRAMES: @@ -70,7 +73,7 @@ def real_shape(self, system: "System") -> Tuple[int]: shape.append(system.get_natoms()) elif ii is Axis.NBONDS: # BondOrderSystem - shape.append(system.get_nbonds()) + shape.append(system.get_nbonds()) # type: ignore elif ii == -1: shape.append(AnyInt(-1)) elif isinstance(ii, int): @@ -79,7 +82,7 @@ def real_shape(self, system: "System") -> Tuple[int]: raise RuntimeError("Shape is not an int!") return tuple(shape) - def check(self, system: "System"): + def check(self, system: System): """Check if a system has correct data of this type. Parameters diff --git a/dpdata/deepmd/comp.py b/dpdata/deepmd/comp.py index 7b909b16..ab004447 100644 --- a/dpdata/deepmd/comp.py +++ b/dpdata/deepmd/comp.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import glob import os import shutil diff --git a/dpdata/deepmd/mixed.py b/dpdata/deepmd/mixed.py index 0d0ad89d..b25107db 100644 --- a/dpdata/deepmd/mixed.py +++ b/dpdata/deepmd/mixed.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import glob import os import shutil diff --git a/dpdata/deepmd/raw.py b/dpdata/deepmd/raw.py index c7a64ec4..e772714a 100644 --- a/dpdata/deepmd/raw.py +++ b/dpdata/deepmd/raw.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import warnings diff --git a/dpdata/dftbplus/output.py b/dpdata/dftbplus/output.py index ba8f6c84..0f10c3ac 100644 --- a/dpdata/dftbplus/output.py +++ b/dpdata/dftbplus/output.py @@ -1,9 +1,9 @@ -from typing import Tuple +from __future__ import annotations import numpy as np -def read_dftb_plus(fn_1: str, fn_2: str) -> Tuple[str, np.ndarray, float, np.ndarray]: +def read_dftb_plus(fn_1: str, fn_2: str) -> tuple[str, np.ndarray, float, np.ndarray]: """Read from DFTB+ input and output. Parameters diff --git a/dpdata/driver.py b/dpdata/driver.py index 81d9a9ed..b5ff5340 100644 --- a/dpdata/driver.py +++ b/dpdata/driver.py @@ -1,12 +1,14 @@ """Driver plugin system.""" +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Callable, List, Union +from typing import TYPE_CHECKING, Callable from .plugin import Plugin if TYPE_CHECKING: - import ase + import ase.calculators.calculator class Driver(ABC): @@ -43,7 +45,7 @@ def register(key: str) -> Callable: return Driver.__DriverPlugin.register(key) @staticmethod - def get_driver(key: str) -> "Driver": + def get_driver(key: str) -> type[Driver]: """Get a driver plugin. Parameters @@ -97,7 +99,7 @@ def label(self, data: dict) -> dict: return NotImplemented @property - def ase_calculator(self) -> "ase.calculators.calculator.Calculator": + def ase_calculator(self) -> ase.calculators.calculator.Calculator: """Returns an ase calculator based on this driver.""" from .ase_calculator import DPDataCalculator @@ -130,7 +132,7 @@ class HybridDriver(Driver): This driver is the hybrid of SQM and DP. """ - def __init__(self, drivers: List[Union[dict, Driver]]) -> None: + def __init__(self, drivers: list[dict | Driver]) -> None: self.drivers = [] for driver in drivers: if isinstance(driver, Driver): @@ -157,6 +159,7 @@ def label(self, data: dict) -> dict: dict labeled data with energies and forces """ + labeled_data = {} for ii, driver in enumerate(self.drivers): lb_data = driver.label(data.copy()) if ii == 0: @@ -199,7 +202,7 @@ def register(key: str) -> Callable: return Minimizer.__MinimizerPlugin.register(key) @staticmethod - def get_minimizer(key: str) -> "Minimizer": + def get_minimizer(key: str) -> type[Minimizer]: """Get a minimizer plugin. Parameters diff --git a/dpdata/fhi_aims/output.py b/dpdata/fhi_aims/output.py index 9947a231..762e8bf4 100755 --- a/dpdata/fhi_aims/output.py +++ b/dpdata/fhi_aims/output.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import re import warnings diff --git a/dpdata/format.py b/dpdata/format.py index cd77561a..ade83c21 100644 --- a/dpdata/format.py +++ b/dpdata/format.py @@ -1,5 +1,7 @@ """Implement the format plugin system.""" +from __future__ import annotations + import os from abc import ABC @@ -163,7 +165,7 @@ def decorator(object): if not isinstance(func_name, (list, tuple, set)): object.post_func = (func_name,) else: - object.post_func = func_name + object.post_func = tuple(func_name) return object return decorator diff --git a/dpdata/gaussian/gjf.py b/dpdata/gaussian/gjf.py index 90aaf2f0..b83dad1c 100644 --- a/dpdata/gaussian/gjf.py +++ b/dpdata/gaussian/gjf.py @@ -3,18 +3,19 @@ # under LGPL 3.0 license """Generate Gaussian input file.""" +from __future__ import annotations + import itertools import re import uuid import warnings -from typing import List, Optional, Tuple, Union import numpy as np from dpdata.periodic_table import Element -def _crd2frag(symbols: List[str], crds: np.ndarray) -> Tuple[int, List[int]]: +def _crd2frag(symbols: list[str], crds: np.ndarray) -> tuple[int, list[int]]: """Detect fragments from coordinates. Parameters @@ -102,12 +103,12 @@ def detect_multiplicity(symbols: np.ndarray) -> int: def make_gaussian_input( sys_data: dict, - keywords: Union[str, List[str]], - multiplicity: Union[str, int] = "auto", + keywords: str | list[str], + multiplicity: str | int = "auto", charge: int = 0, fragment_guesses: bool = False, - basis_set: Optional[str] = None, - keywords_high_multiplicity: Optional[str] = None, + basis_set: str | None = None, + keywords_high_multiplicity: str | None = None, nproc: int = 1, ) -> str: """Make gaussian input file. diff --git a/dpdata/gaussian/log.py b/dpdata/gaussian/log.py index 66881dc1..204cf464 100644 --- a/dpdata/gaussian/log.py +++ b/dpdata/gaussian/log.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from ..periodic_table import ELEMENTS diff --git a/dpdata/gromacs/gro.py b/dpdata/gromacs/gro.py index b643eea8..aca2443b 100644 --- a/dpdata/gromacs/gro.py +++ b/dpdata/gromacs/gro.py @@ -1,4 +1,6 @@ #!/usr/bin/env python3 +from __future__ import annotations + import re import numpy as np diff --git a/dpdata/lammps/dump.py b/dpdata/lammps/dump.py index 906fed9e..f0ade2b0 100644 --- a/dpdata/lammps/dump.py +++ b/dpdata/lammps/dump.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import os import sys diff --git a/dpdata/lammps/lmp.py b/dpdata/lammps/lmp.py index 317b30ed..604b18d1 100644 --- a/dpdata/lammps/lmp.py +++ b/dpdata/lammps/lmp.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import numpy as np diff --git a/dpdata/md/msd.py b/dpdata/md/msd.py index cfb446dd..dfad9550 100644 --- a/dpdata/md/msd.py +++ b/dpdata/md/msd.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from .pbc import system_pbc_shift diff --git a/dpdata/md/pbc.py b/dpdata/md/pbc.py index 4eee7c65..e5757661 100644 --- a/dpdata/md/pbc.py +++ b/dpdata/md/pbc.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np diff --git a/dpdata/md/rdf.py b/dpdata/md/rdf.py index de8f1c74..b41be525 100644 --- a/dpdata/md/rdf.py +++ b/dpdata/md/rdf.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np diff --git a/dpdata/md/water.py b/dpdata/md/water.py index 0cb82cc9..cda4ad48 100644 --- a/dpdata/md/water.py +++ b/dpdata/md/water.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from .pbc import posi_diff, posi_shift diff --git a/dpdata/openmx/omx.py b/dpdata/openmx/omx.py index bd4b7031..d3afff00 100644 --- a/dpdata/openmx/omx.py +++ b/dpdata/openmx/omx.py @@ -1,4 +1,6 @@ #!/usr/bin/python3 +from __future__ import annotations + import numpy as np from ..unit import ( diff --git a/dpdata/orca/output.py b/dpdata/orca/output.py index 13f072f3..183c3c85 100644 --- a/dpdata/orca/output.py +++ b/dpdata/orca/output.py @@ -1,9 +1,9 @@ -from typing import Tuple +from __future__ import annotations import numpy as np -def read_orca_sp_output(fn: str) -> Tuple[np.ndarray, np.ndarray, float, np.ndarray]: +def read_orca_sp_output(fn: str) -> tuple[np.ndarray, np.ndarray, float, np.ndarray]: """Read from ORCA output. Note that both the energy and the gradient should be printed. diff --git a/dpdata/periodic_table.py b/dpdata/periodic_table.py index 6df1fd41..e6b56cb0 100644 --- a/dpdata/periodic_table.py +++ b/dpdata/periodic_table.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json from pathlib import Path diff --git a/dpdata/plugin.py b/dpdata/plugin.py index 20e51eb2..9e18e212 100644 --- a/dpdata/plugin.py +++ b/dpdata/plugin.py @@ -1,5 +1,7 @@ """Base of plugin systems.""" +from __future__ import annotations + class Plugin: """A class to register plugins. diff --git a/dpdata/plugins/3dmol.py b/dpdata/plugins/3dmol.py index ec994dd9..56ec2516 100644 --- a/dpdata/plugins/3dmol.py +++ b/dpdata/plugins/3dmol.py @@ -1,4 +1,4 @@ -from typing import Tuple +from __future__ import annotations import numpy as np @@ -17,7 +17,7 @@ def to_system( self, data: dict, f_idx: int = 0, - size: Tuple[int] = (300, 300), + size: tuple[int] = (300, 300), style: dict = {"stick": {}, "sphere": {"radius": 0.4}}, **kwargs, ): diff --git a/dpdata/plugins/__init__.py b/dpdata/plugins/__init__.py index 66364aa2..15634bc0 100644 --- a/dpdata/plugins/__init__.py +++ b/dpdata/plugins/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import importlib from pathlib import Path diff --git a/dpdata/plugins/abacus.py b/dpdata/plugins/abacus.py index 754221be..eb2d7786 100644 --- a/dpdata/plugins/abacus.py +++ b/dpdata/plugins/abacus.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dpdata.abacus.md import dpdata.abacus.relax import dpdata.abacus.scf diff --git a/dpdata/plugins/amber.py b/dpdata/plugins/amber.py index cdc92a30..42fce552 100644 --- a/dpdata/plugins/amber.py +++ b/dpdata/plugins/amber.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import subprocess as sp import tempfile @@ -124,7 +126,7 @@ class SQMDriver(Driver): -15.41111246 """ - def __init__(self, sqm_exec: str = "sqm", **kwargs: dict) -> None: + def __init__(self, sqm_exec: str = "sqm", **kwargs) -> None: self.sqm_exec = sqm_exec self.kwargs = kwargs diff --git a/dpdata/plugins/ase.py b/dpdata/plugins/ase.py index f3347c99..1d818483 100644 --- a/dpdata/plugins/ase.py +++ b/dpdata/plugins/ase.py @@ -1,4 +1,6 @@ -from typing import TYPE_CHECKING, Optional, Type +from __future__ import annotations + +from typing import TYPE_CHECKING import numpy as np @@ -22,7 +24,7 @@ class ASEStructureFormat(Format): automatic detection fails. """ - def from_system(self, atoms: "ase.Atoms", **kwargs) -> dict: + def from_system(self, atoms: ase.Atoms, **kwargs) -> dict: """Convert ase.Atoms to a System. Parameters @@ -56,7 +58,7 @@ def from_system(self, atoms: "ase.Atoms", **kwargs) -> dict: } return info_dict - def from_labeled_system(self, atoms: "ase.Atoms", **kwargs) -> dict: + def from_labeled_system(self, atoms: ase.Atoms, **kwargs) -> dict: """Convert ase.Atoms to a LabeledSystem. Energies and forces are calculated by the calculator. @@ -103,12 +105,12 @@ def from_labeled_system(self, atoms: "ase.Atoms", **kwargs) -> dict: def from_multi_systems( self, file_name: str, - begin: Optional[int] = None, - end: Optional[int] = None, - step: Optional[int] = None, - ase_fmt: Optional[str] = None, + begin: int | None = None, + end: int | None = None, + step: int | None = None, + ase_fmt: str | None = None, **kwargs, - ) -> "ase.Atoms": + ) -> ase.Atoms: """Convert a ASE supported file to ASE Atoms. It will finally be converted to MultiSystems. @@ -195,9 +197,9 @@ class ASETrajFormat(Format): def from_system( self, file_name: str, - begin: Optional[int] = 0, - end: Optional[int] = None, - step: Optional[int] = 1, + begin: int | None = 0, + end: int | None = None, + step: int | None = 1, **kwargs, ) -> dict: """Read ASE's trajectory file to `System` of multiple frames. @@ -239,9 +241,9 @@ def from_system( def from_labeled_system( self, file_name: str, - begin: Optional[int] = 0, - end: Optional[int] = None, - step: Optional[int] = 1, + begin: int | None = 0, + end: int | None = None, + step: int | None = 1, **kwargs, ) -> dict: """Read ASE's trajectory file to `System` of multiple frames. @@ -309,7 +311,7 @@ class ASEDriver(Driver): ASE calculator """ - def __init__(self, calculator: "ase.calculators.calculator.Calculator") -> None: + def __init__(self, calculator: ase.calculators.calculator.Calculator) -> None: """Setup the driver.""" self.calculator = calculator @@ -361,9 +363,9 @@ class ASEMinimizer(Minimizer): def __init__( self, driver: Driver, - optimizer: Optional[Type["Optimizer"]] = None, + optimizer: type[Optimizer] | None = None, fmax: float = 5e-3, - max_steps: Optional[int] = None, + max_steps: int | None = None, optimizer_kwargs: dict = {}, ) -> None: self.calculator = driver.ase_calculator diff --git a/dpdata/plugins/cp2k.py b/dpdata/plugins/cp2k.py index 162098f7..f5c1b539 100644 --- a/dpdata/plugins/cp2k.py +++ b/dpdata/plugins/cp2k.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import glob import dpdata.cp2k.output diff --git a/dpdata/plugins/dftbplus.py b/dpdata/plugins/dftbplus.py index 5c8b4682..247fedc9 100644 --- a/dpdata/plugins/dftbplus.py +++ b/dpdata/plugins/dftbplus.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from dpdata.dftbplus.output import read_dftb_plus diff --git a/dpdata/plugins/fhi_aims.py b/dpdata/plugins/fhi_aims.py index 45b181fc..3c198aff 100644 --- a/dpdata/plugins/fhi_aims.py +++ b/dpdata/plugins/fhi_aims.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dpdata.fhi_aims.output from dpdata.format import Format diff --git a/dpdata/plugins/gaussian.py b/dpdata/plugins/gaussian.py index a22ce863..b55447b9 100644 --- a/dpdata/plugins/gaussian.py +++ b/dpdata/plugins/gaussian.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import subprocess as sp import tempfile @@ -81,7 +83,7 @@ class GaussianDriver(Driver): -1102.714590995794 """ - def __init__(self, gaussian_exec: str = "g16", **kwargs: dict) -> None: + def __init__(self, gaussian_exec: str = "g16", **kwargs) -> None: self.gaussian_exec = gaussian_exec self.kwargs = kwargs diff --git a/dpdata/plugins/gromacs.py b/dpdata/plugins/gromacs.py index 20e50835..12dece71 100644 --- a/dpdata/plugins/gromacs.py +++ b/dpdata/plugins/gromacs.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dpdata.gromacs.gro from dpdata.format import Format diff --git a/dpdata/plugins/lammps.py b/dpdata/plugins/lammps.py index be89be9d..65e7f570 100644 --- a/dpdata/plugins/lammps.py +++ b/dpdata/plugins/lammps.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dpdata.lammps.dump import dpdata.lammps.lmp from dpdata.format import Format diff --git a/dpdata/plugins/list.py b/dpdata/plugins/list.py index 68a14074..f7036883 100644 --- a/dpdata/plugins/list.py +++ b/dpdata/plugins/list.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from dpdata.format import Format diff --git a/dpdata/plugins/n2p2.py b/dpdata/plugins/n2p2.py index 7162f09f..b70d6e6f 100644 --- a/dpdata/plugins/n2p2.py +++ b/dpdata/plugins/n2p2.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from dpdata.format import Format diff --git a/dpdata/plugins/openmx.py b/dpdata/plugins/openmx.py index 675d1d2c..4e16566d 100644 --- a/dpdata/plugins/openmx.py +++ b/dpdata/plugins/openmx.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dpdata.md.pbc import dpdata.openmx.omx from dpdata.format import Format diff --git a/dpdata/plugins/orca.py b/dpdata/plugins/orca.py index 2585743e..3d7fa38a 100644 --- a/dpdata/plugins/orca.py +++ b/dpdata/plugins/orca.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from dpdata.format import Format diff --git a/dpdata/plugins/psi4.py b/dpdata/plugins/psi4.py index ec7d9df1..c3b1ee1b 100644 --- a/dpdata/plugins/psi4.py +++ b/dpdata/plugins/psi4.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from dpdata.format import Format diff --git a/dpdata/plugins/pwmat.py b/dpdata/plugins/pwmat.py index 11257c4d..80f219b6 100644 --- a/dpdata/plugins/pwmat.py +++ b/dpdata/plugins/pwmat.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np import dpdata.pwmat.atomconfig diff --git a/dpdata/plugins/pymatgen.py b/dpdata/plugins/pymatgen.py index e7e527ff..322298c3 100644 --- a/dpdata/plugins/pymatgen.py +++ b/dpdata/plugins/pymatgen.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np import dpdata.pymatgen.molecule diff --git a/dpdata/plugins/qe.py b/dpdata/plugins/qe.py index 6a98eedd..682bb202 100644 --- a/dpdata/plugins/qe.py +++ b/dpdata/plugins/qe.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dpdata.md.pbc import dpdata.qe.scf import dpdata.qe.traj diff --git a/dpdata/plugins/rdkit.py b/dpdata/plugins/rdkit.py index c7cef07f..f01b277d 100644 --- a/dpdata/plugins/rdkit.py +++ b/dpdata/plugins/rdkit.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dpdata.rdkit.utils from dpdata.format import Format diff --git a/dpdata/plugins/siesta.py b/dpdata/plugins/siesta.py index 662b5c0e..906eeb51 100644 --- a/dpdata/plugins/siesta.py +++ b/dpdata/plugins/siesta.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import dpdata.siesta.aiMD_output import dpdata.siesta.output from dpdata.format import Format diff --git a/dpdata/plugins/vasp.py b/dpdata/plugins/vasp.py index c182bb95..d0681ceb 100644 --- a/dpdata/plugins/vasp.py +++ b/dpdata/plugins/vasp.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np import dpdata.vasp.outcar diff --git a/dpdata/plugins/xyz.py b/dpdata/plugins/xyz.py index fdb5bf3b..322bf77c 100644 --- a/dpdata/plugins/xyz.py +++ b/dpdata/plugins/xyz.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from dpdata.format import Format diff --git a/dpdata/psi4/input.py b/dpdata/psi4/input.py index ad053281..3959cb75 100644 --- a/dpdata/psi4/input.py +++ b/dpdata/psi4/input.py @@ -1,4 +1,9 @@ -import numpy as np +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import numpy as np # Angston is used in Psi4 by default template = """molecule {{ diff --git a/dpdata/psi4/output.py b/dpdata/psi4/output.py index e93858de..9ccf90e1 100644 --- a/dpdata/psi4/output.py +++ b/dpdata/psi4/output.py @@ -1,11 +1,11 @@ -from typing import Tuple +from __future__ import annotations import numpy as np from dpdata.unit import LengthConversion -def read_psi4_output(fn: str) -> Tuple[str, np.ndarray, float, np.ndarray]: +def read_psi4_output(fn: str) -> tuple[str, np.ndarray, float, np.ndarray]: """Read from Psi4 output. Note that both the energy and the gradient should be printed. diff --git a/dpdata/pwmat/atomconfig.py b/dpdata/pwmat/atomconfig.py index f128aa5f..62eff77c 100644 --- a/dpdata/pwmat/atomconfig.py +++ b/dpdata/pwmat/atomconfig.py @@ -1,4 +1,6 @@ #!/usr/bin/python3 +from __future__ import annotations + import numpy as np from ..periodic_table import ELEMENTS diff --git a/dpdata/pwmat/movement.py b/dpdata/pwmat/movement.py index 748744d6..ccfd819d 100644 --- a/dpdata/pwmat/movement.py +++ b/dpdata/pwmat/movement.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import warnings import numpy as np diff --git a/dpdata/py.typed b/dpdata/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/dpdata/pymatgen/molecule.py b/dpdata/pymatgen/molecule.py index fc05b07a..8d397984 100644 --- a/dpdata/pymatgen/molecule.py +++ b/dpdata/pymatgen/molecule.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections import Counter import numpy as np diff --git a/dpdata/pymatgen/structure.py b/dpdata/pymatgen/structure.py index 9f47baee..36e411c0 100644 --- a/dpdata/pymatgen/structure.py +++ b/dpdata/pymatgen/structure.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np diff --git a/dpdata/qe/scf.py b/dpdata/qe/scf.py index cd9c6f28..37e5fbab 100755 --- a/dpdata/qe/scf.py +++ b/dpdata/qe/scf.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import os diff --git a/dpdata/qe/traj.py b/dpdata/qe/traj.py index e27990cb..1fbf0f71 100644 --- a/dpdata/qe/traj.py +++ b/dpdata/qe/traj.py @@ -1,4 +1,6 @@ #!/usr/bin/python3 +from __future__ import annotations + import warnings import numpy as np diff --git a/dpdata/rdkit/sanitize.py b/dpdata/rdkit/sanitize.py index 45060abc..2b0d7663 100644 --- a/dpdata/rdkit/sanitize.py +++ b/dpdata/rdkit/sanitize.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import time from copy import deepcopy diff --git a/dpdata/rdkit/utils.py b/dpdata/rdkit/utils.py index 9c7e50af..efeef607 100644 --- a/dpdata/rdkit/utils.py +++ b/dpdata/rdkit/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np diff --git a/dpdata/siesta/aiMD_output.py b/dpdata/siesta/aiMD_output.py index 4e1890ec..daa4f6a2 100644 --- a/dpdata/siesta/aiMD_output.py +++ b/dpdata/siesta/aiMD_output.py @@ -1,4 +1,5 @@ # !/usr/bin/python3 +from __future__ import annotations import numpy as np diff --git a/dpdata/siesta/output.py b/dpdata/siesta/output.py index 7418d543..0c944d5b 100644 --- a/dpdata/siesta/output.py +++ b/dpdata/siesta/output.py @@ -1,4 +1,5 @@ #!/usr/bin/python3 +from __future__ import annotations import numpy as np diff --git a/dpdata/stat.py b/dpdata/stat.py index 8de64982..5ec39570 100644 --- a/dpdata/stat.py +++ b/dpdata/stat.py @@ -1,4 +1,6 @@ -from abc import ABCMeta, abstractproperty +from __future__ import annotations + +from abc import ABCMeta, abstractmethod from functools import lru_cache import numpy as np @@ -61,11 +63,13 @@ def __init__(self, system_1: SYSTEM_TYPE, system_2: SYSTEM_TYPE) -> None: self.system_1 = system_1 self.system_2 = system_2 - @abstractproperty + @property + @abstractmethod def e_errors(self) -> np.ndarray: """Energy errors.""" - @abstractproperty + @property + @abstractmethod def f_errors(self) -> np.ndarray: """Force errors.""" @@ -114,12 +118,16 @@ class Errors(ErrorsBase): @lru_cache() def e_errors(self) -> np.ndarray: """Energy errors.""" + assert isinstance(self.system_1, self.SYSTEM_TYPE) + assert isinstance(self.system_2, self.SYSTEM_TYPE) return self.system_1["energies"] - self.system_2["energies"] @property @lru_cache() def f_errors(self) -> np.ndarray: """Force errors.""" + assert isinstance(self.system_1, self.SYSTEM_TYPE) + assert isinstance(self.system_2, self.SYSTEM_TYPE) return (self.system_1["forces"] - self.system_2["forces"]).ravel() @@ -147,6 +155,8 @@ class MultiErrors(ErrorsBase): @lru_cache() def e_errors(self) -> np.ndarray: """Energy errors.""" + assert isinstance(self.system_1, self.SYSTEM_TYPE) + assert isinstance(self.system_2, self.SYSTEM_TYPE) errors = [] for nn in self.system_1.systems.keys(): ss1 = self.system_1[nn] @@ -158,6 +168,8 @@ def e_errors(self) -> np.ndarray: @lru_cache() def f_errors(self) -> np.ndarray: """Force errors.""" + assert isinstance(self.system_1, self.SYSTEM_TYPE) + assert isinstance(self.system_2, self.SYSTEM_TYPE) errors = [] for nn in self.system_1.systems.keys(): ss1 = self.system_1[nn] diff --git a/dpdata/system.py b/dpdata/system.py index 33b7e7cf..2614bc23 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -1,10 +1,24 @@ # %% +from __future__ import annotations + import glob import hashlib +import numbers import os +import sys import warnings from copy import deepcopy -from typing import Any, Dict, Optional, Tuple, Union +from typing import ( + TYPE_CHECKING, + Any, + Iterable, + overload, +) + +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal import numpy as np @@ -13,6 +27,7 @@ # ensure all plugins are loaded! import dpdata.plugins +import dpdata.plugins.deepmd from dpdata.amber.mask import load_param_file, pick_by_amber_mask from dpdata.data_type import Axis, DataError, DataType, get_data_types from dpdata.driver import Driver, Minimizer @@ -26,6 +41,9 @@ utf8len, ) +if TYPE_CHECKING: + import parmed + def load_format(fmt): fmt = fmt.lower() @@ -64,11 +82,11 @@ class System: Attributes ---------- - DTYPES : tuple[DataType] + DTYPES : tuple[DataType, ...] data types of this class """ - DTYPES = ( + DTYPES: tuple[DataType, ...] = ( DataType("atom_numbs", list, (Axis.NTYPES,)), DataType("atom_names", list, (Axis.NTYPES,)), DataType("atom_types", np.ndarray, (Axis.NATOMS,)), @@ -84,13 +102,14 @@ class System: def __init__( self, - file_name=None, - fmt="auto", - type_map=None, - begin=0, - step=1, - data=None, - convergence_check=True, + # some formats do not use string as input + file_name: Any = None, + fmt: str = "auto", + type_map: list[str] | None = None, + begin: int = 0, + step: int = 1, + data: dict[str, Any] | None = None, + convergence_check: bool = True, **kwargs, ): """Constructor. @@ -211,13 +230,13 @@ def check_data(self): post_funcs = Plugin() - def from_fmt(self, file_name, fmt="auto", **kwargs): + def from_fmt(self, file_name: Any, fmt: str = "auto", **kwargs: Any): fmt = fmt.lower() if fmt == "auto": fmt = os.path.basename(file_name).split(".")[-1].lower() return self.from_fmt_obj(load_format(fmt), file_name, **kwargs) - def from_fmt_obj(self, fmtobj, file_name, **kwargs): + def from_fmt_obj(self, fmtobj: Format, file_name: Any, **kwargs: Any): data = fmtobj.from_system(file_name, **kwargs) if data: if isinstance(data, (list, tuple)): @@ -227,11 +246,11 @@ def from_fmt_obj(self, fmtobj, file_name, **kwargs): self.data = {**self.data, **data} self.check_data() if hasattr(fmtobj.from_system, "post_func"): - for post_f in fmtobj.from_system.post_func: + for post_f in fmtobj.from_system.post_func: # type: ignore self.post_funcs.get_plugin(post_f)(self) return self - def to(self, fmt: str, *args, **kwargs) -> "System": + def to(self, fmt: str, *args: Any, **kwargs: Any) -> System: """Dump systems to the specific format. Parameters @@ -250,7 +269,7 @@ def to(self, fmt: str, *args, **kwargs) -> "System": """ return self.to_fmt_obj(load_format(fmt), *args, **kwargs) - def to_fmt_obj(self, fmtobj, *args, **kwargs): + def to_fmt_obj(self, fmtobj: Format, *args: Any, **kwargs: Any): return fmtobj.to_system(self.data, *args, **kwargs) def __repr__(self): @@ -268,13 +287,32 @@ def __str__(self): ret += "\n" + " ".join(map(str, self.get_atom_numbs())) return ret + @overload + def __getitem__(self, key: int | slice | list | np.ndarray) -> System: ... + @overload + def __getitem__( + self, key: Literal["atom_names", "real_atom_names"] + ) -> list[str]: ... + @overload + def __getitem__(self, key: Literal["atom_numbs"]) -> list[int]: ... + @overload + def __getitem__(self, key: Literal["nopbc"]) -> bool: ... + @overload + def __getitem__( + self, key: Literal["orig", "coords", "energies", "forces", "virials"] + ) -> np.ndarray: ... + @overload + def __getitem__(self, key: str) -> Any: + # other cases, for example customized data + ... + def __getitem__(self, key): """Returns proerty stored in System by key or by idx.""" if isinstance(key, (int, slice, list, np.ndarray)): return self.sub_system(key) return self.data[key] - def __len__(self): + def __len__(self) -> int: """Returns number of frames in the system.""" return self.get_nframes() @@ -293,13 +331,15 @@ def __add__(self, others): raise RuntimeError("Unspported data structure") return self.__class__.from_dict({"data": self_copy.data}) - def dump(self, filename, indent=4): + def dump(self, filename: str, indent: int = 4): """Dump .json or .yaml file.""" from monty.serialization import dumpfn dumpfn(self.as_dict(), filename, indent=indent) - def map_atom_types(self, type_map=None) -> np.ndarray: + def map_atom_types( + self, type_map: dict[str, int] | list[str] | None = None + ) -> np.ndarray: """Map the atom types of the system. Parameters @@ -338,7 +378,7 @@ def map_atom_types(self, type_map=None) -> np.ndarray: return new_atom_types @staticmethod - def load(filename): + def load(filename: str): """Rebuild System obj. from .json or .yaml file.""" from monty.serialization import loadfn @@ -347,7 +387,7 @@ def load(filename): @classmethod def from_dict(cls, data: dict): """Construct a System instance from a data dict.""" - from monty.serialization import MontyDecoder + from monty.serialization import MontyDecoder # type: ignore decoded = { k: MontyDecoder().process_decoded(v) @@ -356,7 +396,7 @@ def from_dict(cls, data: dict): } return cls(**decoded) - def as_dict(self): + def as_dict(self) -> dict: """Returns data dict of System instance.""" d = { "@module": self.__class__.__module__, @@ -365,23 +405,23 @@ def as_dict(self): } return d - def get_atom_names(self): + def get_atom_names(self) -> list[str]: """Returns name of atoms.""" return self.data["atom_names"] - def get_atom_types(self): + def get_atom_types(self) -> np.ndarray: """Returns type of atoms.""" return self.data["atom_types"] - def get_atom_numbs(self): + def get_atom_numbs(self) -> list[int]: """Returns number of atoms.""" return self.data["atom_numbs"] - def get_nframes(self): + def get_nframes(self) -> int: """Returns number of frames in the system.""" return len(self.data["cells"]) - def get_natoms(self): + def get_natoms(self) -> int: """Returns total number of atoms in the system.""" return len(self.data["atom_types"]) @@ -393,7 +433,7 @@ def copy(self): """Returns a copy of the system.""" return self.__class__.from_dict({"data": deepcopy(self.data)}) - def sub_system(self, f_idx): + def sub_system(self, f_idx: int | slice | list | np.ndarray): """Construct a subsystem from the system. Parameters @@ -408,15 +448,18 @@ def sub_system(self, f_idx): """ tmp = self.__class__() # convert int to array_like - if isinstance(f_idx, (int, np.int64)): + if isinstance(f_idx, numbers.Integral): f_idx = np.array([f_idx]) + assert not isinstance(f_idx, int) for tt in self.DTYPES: if tt.name not in self.data: # skip optional data continue if tt.shape is not None and Axis.NFRAMES in tt.shape: axis_nframes = tt.shape.index(Axis.NFRAMES) - new_shape = [slice(None) for _ in self.data[tt.name].shape] + new_shape: list[slice | np.ndarray | list] = [ + slice(None) for _ in self.data[tt.name].shape + ] new_shape[axis_nframes] = f_idx tmp.data[tt.name] = self.data[tt.name][tuple(new_shape)] else: @@ -424,7 +467,7 @@ def sub_system(self, f_idx): tmp.data[tt.name] = self.data[tt.name] return tmp - def append(self, system): + def append(self, system: System) -> bool: """Append a system to this system. Parameters @@ -480,7 +523,7 @@ def append(self, system): self.data["nopbc"] = False return True - def convert_to_mixed_type(self, type_map=None): + def convert_to_mixed_type(self, type_map: list[str] | None = None): """Convert the data dict to mixed type format structure, in order to append systems with different formula but the same number of atoms. Change the 'atom_names' to one placeholder type 'MIXED_TOKEN' and add 'real_atom_types' to store the real type @@ -506,7 +549,7 @@ def convert_to_mixed_type(self, type_map=None): self.data["atom_numbs"] = [natoms] self.data["atom_names"] = ["MIXED_TOKEN"] - def sort_atom_names(self, type_map=None): + def sort_atom_names(self, type_map: list[str] | None = None): """Sort atom_names of the system and reorder atom_numbs and atom_types accoarding to atom_names. If type_map is not given, atom_names will be sorted by alphabetical order. If type_map is given, atom_names will be type_map. @@ -518,7 +561,7 @@ def sort_atom_names(self, type_map=None): """ self.data = sort_atom_names(self.data, type_map=type_map) - def check_type_map(self, type_map): + def check_type_map(self, type_map: list[str] | None): """Assign atom_names to type_map if type_map is given and different from atom_names. @@ -530,7 +573,7 @@ def check_type_map(self, type_map): if type_map is not None and type_map != self.data["atom_names"]: self.sort_atom_names(type_map=type_map) - def apply_type_map(self, type_map): + def apply_type_map(self, type_map: list[str]): """Customize the element symbol order and it should maintain order consistency in dpgen or deepmd-kit. It is especially recommended for multiple complexsystems with multiple elements. @@ -560,13 +603,15 @@ def sort_atom_types(self) -> np.ndarray: continue if tt.shape is not None and Axis.NATOMS in tt.shape: axis_natoms = tt.shape.index(Axis.NATOMS) - new_shape = [slice(None) for _ in self.data[tt.name].shape] + new_shape: list[slice | np.ndarray] = [ + slice(None) for _ in self.data[tt.name].shape + ] new_shape[axis_natoms] = idx self.data[tt.name] = self.data[tt.name][tuple(new_shape)] return idx @property - def formula(self): + def formula(self) -> str: """Return the formula of this system, like C3H5O2.""" return "".join( [ @@ -578,7 +623,7 @@ def formula(self): ) @property - def uniq_formula(self): + def uniq_formula(self) -> str: """Return the uniq_formula of this system. The uniq_formula sort the elements in formula by names. Systems with the same uniq_formula can be append together. @@ -628,7 +673,7 @@ def short_name(self) -> str: return short_formula return self.formula_hash - def extend(self, systems): + def extend(self, systems: Iterable[System]): """Extend a system list to this system. Parameters @@ -646,7 +691,7 @@ def apply_pbc(self): self.data["coords"] = np.matmul(ncoord, self.data["cells"]) @post_funcs.register("remove_pbc") - def remove_pbc(self, protect_layer=9): + def remove_pbc(self, protect_layer: int = 9): """This method does NOT delete the definition of the cells, it (1) revises the cell to a cubic cell and ensures that the cell boundary to any atom in the system is no less than `protect_layer` @@ -661,7 +706,7 @@ def remove_pbc(self, protect_layer=9): assert protect_layer >= 0, "the protect_layer should be no less than 0" remove_pbc(self.data, protect_layer) - def affine_map(self, trans, f_idx=0): + def affine_map(self, trans, f_idx: int | numbers.Integral = 0): assert np.linalg.det(trans) != 0 self.data["cells"][f_idx] = np.matmul(self.data["cells"][f_idx], trans) self.data["coords"][f_idx] = np.matmul(self.data["coords"][f_idx], trans) @@ -679,7 +724,7 @@ def rot_lower_triangular(self): for ii in range(self.get_nframes()): self.rot_frame_lower_triangular(ii) - def rot_frame_lower_triangular(self, f_idx=0): + def rot_frame_lower_triangular(self, f_idx: int | numbers.Integral = 0): qq, rr = np.linalg.qr(self.data["cells"][f_idx].T) if np.linalg.det(qq) < 0: qq = -qq @@ -696,11 +741,11 @@ def rot_frame_lower_triangular(self, f_idx=0): self.affine_map(rot, f_idx=f_idx) return np.matmul(qq, rot) - def add_atom_names(self, atom_names): + def add_atom_names(self, atom_names: list[str]): """Add atom_names that do not exist.""" self.data = add_atom_names(self.data, atom_names) - def replicate(self, ncopy): + def replicate(self, ncopy: list[int] | tuple[int, int, int]): """Replicate the each frame in the system in 3 dimensions. Each frame in the system will become a supercell. @@ -732,7 +777,7 @@ def replicate(self, ncopy): np.array(np.copy(data["atom_numbs"])) * np.prod(ncopy) ) tmp.data["atom_types"] = np.sort( - np.tile(np.copy(data["atom_types"]), np.prod(ncopy)), kind="stable" + np.tile(np.copy(data["atom_types"]), np.prod(ncopy).item()), kind="stable" ) tmp.data["cells"] = np.copy(data["cells"]) for ii in range(3): @@ -752,7 +797,7 @@ def replicate(self, ncopy): ) return tmp - def replace(self, initial_atom_type, end_atom_type, replace_num): + def replace(self, initial_atom_type: str, end_atom_type: str, replace_num: int): if type(self) is not dpdata.System: raise RuntimeError( "Must use method replace() of the instance of class dpdata.System" @@ -797,7 +842,11 @@ def replace(self, initial_atom_type, end_atom_type, replace_num): self.sort_atom_types() def perturb( - self, pert_num, cell_pert_fraction, atom_pert_distance, atom_pert_style="normal" + self, + pert_num: int, + cell_pert_fraction: float, + atom_pert_distance: float, + atom_pert_style: str = "normal", ): """Perturb each frame in the system randomly. The cell will be deformed randomly, and atoms will be displaced by a random distance in random direction. @@ -865,7 +914,7 @@ def nopbc(self): return False @nopbc.setter - def nopbc(self, value): + def nopbc(self, value: bool): self.data["nopbc"] = value def shuffle(self): @@ -874,7 +923,9 @@ def shuffle(self): self.data = self.sub_system(idx).data return idx - def predict(self, *args: Any, driver: str = "dp", **kwargs: Any) -> "LabeledSystem": + def predict( + self, *args: Any, driver: str | Driver = "dp", **kwargs: Any + ) -> LabeledSystem: """Predict energies and forces by a driver. Parameters @@ -903,8 +954,8 @@ def predict(self, *args: Any, driver: str = "dp", **kwargs: Any) -> "LabeledSyst return LabeledSystem(data=data) def minimize( - self, *args: Any, minimizer: Union[str, Minimizer], **kwargs: Any - ) -> "LabeledSystem": + self, *args: Any, minimizer: str | Minimizer, **kwargs: Any + ) -> LabeledSystem: """Minimize the geometry. Parameters @@ -926,7 +977,11 @@ def minimize( data = minimizer.minimize(self.data.copy()) return LabeledSystem(data=data) - def pick_atom_idx(self, idx, nopbc=None): + def pick_atom_idx( + self, + idx: int | numbers.Integral | list[int] | slice | np.ndarray, + nopbc: bool | None = None, + ): """Pick atom index. Parameters @@ -942,15 +997,18 @@ def pick_atom_idx(self, idx, nopbc=None): new system """ new_sys = self.copy() - if isinstance(idx, (int, np.int64)): + if isinstance(idx, numbers.Integral): idx = np.array([idx]) + assert not isinstance(idx, int) for tt in self.DTYPES: if tt.name not in self.data: # skip optional data continue if tt.shape is not None and Axis.NATOMS in tt.shape: axis_natoms = tt.shape.index(Axis.NATOMS) - new_shape = [slice(None) for _ in self.data[tt.name].shape] + new_shape: list[slice | np.ndarray | list[int]] = [ + slice(None) for _ in self.data[tt.name].shape + ] new_shape[axis_natoms] = idx new_sys.data[tt.name] = self.data[tt.name][tuple(new_shape)] # recalculate atom_numbs according to atom_types @@ -962,7 +1020,7 @@ def pick_atom_idx(self, idx, nopbc=None): new_sys.nopbc = nopbc return new_sys - def remove_atom_names(self, atom_names): + def remove_atom_names(self, atom_names: str | list[str]): """Remove atom names and all such atoms. For example, you may not remove EP atoms in TIP4P/Ew water, which is not a real atom. @@ -988,7 +1046,13 @@ def remove_atom_names(self, atom_names): new_sys.data["atom_numbs"] = new_sys.data["atom_numbs"][: len(new_atom_names)] return new_sys - def pick_by_amber_mask(self, param, maskstr, pass_coords=False, nopbc=None): + def pick_by_amber_mask( + self, + param: str | parmed.Structure, + maskstr: str, + pass_coords: bool = False, + nopbc: bool | None = None, + ): """Pick atoms by amber mask. Parameters @@ -1018,7 +1082,7 @@ def pick_by_amber_mask(self, param, maskstr, pass_coords=False, nopbc=None): return self.pick_atom_idx(idx, nopbc=nopbc) @classmethod - def register_data_type(cls, *data_type: Tuple[DataType]): + def register_data_type(cls, *data_type: DataType): """Register data type. Parameters @@ -1038,7 +1102,7 @@ def register_data_type(cls, *data_type: Tuple[DataType]): cls.DTYPES = tuple(dtypes_dict.values()) -def get_cell_perturb_matrix(cell_pert_fraction): +def get_cell_perturb_matrix(cell_pert_fraction: float): if cell_pert_fraction < 0: raise RuntimeError("cell_pert_fraction can not be negative") e0 = np.random.rand(6) @@ -1053,7 +1117,10 @@ def get_cell_perturb_matrix(cell_pert_fraction): return cell_pert_matrix -def get_atom_perturb_vector(atom_pert_distance, atom_pert_style="normal"): +def get_atom_perturb_vector( + atom_pert_distance: float, + atom_pert_style: str = "normal", +): random_vector = None if atom_pert_distance < 0: raise RuntimeError("atom_pert_distance can not be negative") @@ -1123,7 +1190,7 @@ class LabeledSystem(System): The number of skipped frames when loading MD trajectory. """ - DTYPES = System.DTYPES + ( + DTYPES: tuple[DataType, ...] = System.DTYPES + ( DataType("energies", np.ndarray, (Axis.NFRAMES,)), DataType("forces", np.ndarray, (Axis.NFRAMES, Axis.NATOMS, 3)), DataType("virials", np.ndarray, (Axis.NFRAMES, 3, 3), required=False), @@ -1142,7 +1209,7 @@ def from_fmt_obj(self, fmtobj, file_name, **kwargs): self.data = {**self.data, **data} self.check_data() if hasattr(fmtobj.from_labeled_system, "post_func"): - for post_f in fmtobj.from_labeled_system.post_func: + for post_f in fmtobj.from_labeled_system.post_func: # type: ignore self.post_funcs.get_plugin(post_f)(self) return self @@ -1178,11 +1245,11 @@ def __add__(self, others): raise RuntimeError("Unspported data structure") return self.__class__.from_dict({"data": self_copy.data}) - def has_virial(self): + def has_virial(self) -> bool: # return ('virials' in self.data) and (len(self.data['virials']) > 0) return "virials" in self.data - def affine_map_fv(self, trans, f_idx): + def affine_map_fv(self, trans, f_idx: int | numbers.Integral): assert np.linalg.det(trans) != 0 self.data["forces"][f_idx] = np.matmul(self.data["forces"][f_idx], trans) if self.has_virial(): @@ -1190,12 +1257,12 @@ def affine_map_fv(self, trans, f_idx): trans.T, np.matmul(self.data["virials"][f_idx], trans) ) - def rot_frame_lower_triangular(self, f_idx=0): + def rot_frame_lower_triangular(self, f_idx: int | numbers.Integral = 0): trans = System.rot_frame_lower_triangular(self, f_idx=f_idx) self.affine_map_fv(trans, f_idx=f_idx) return trans - def correction(self, hl_sys): + def correction(self, hl_sys: LabeledSystem) -> LabeledSystem: """Get energy and force correction between self and a high-level LabeledSystem. The self's coordinates will be kept, but energy and forces will be replaced by the correction between these two systems. @@ -1224,7 +1291,7 @@ def correction(self, hl_sys): ) return corrected_sys - def remove_outlier(self, threshold: float = 8.0) -> "LabeledSystem": + def remove_outlier(self, threshold: float = 8.0) -> LabeledSystem: r"""Remove outlier frames from the system. Remove the frames whose energies satisfy the condition @@ -1275,14 +1342,16 @@ def __init__(self, *systems, type_map=None): type_map : list of str Maps atom type to name """ - self.systems = {} + self.systems: dict[str, System] = {} if type_map is not None: - self.atom_names = type_map + self.atom_names: list[str] = type_map else: - self.atom_names = [] + self.atom_names: list[str] = [] self.append(*systems) - def from_fmt_obj(self, fmtobj, directory, labeled=True, **kwargs): + def from_fmt_obj( + self, fmtobj: Format, directory, labeled: bool = True, **kwargs: Any + ): if not isinstance(fmtobj, dpdata.plugins.deepmd.DeePMDMixedFormat): for dd in fmtobj.from_multi_systems(directory, **kwargs): if labeled: @@ -1306,7 +1375,7 @@ def from_fmt_obj(self, fmtobj, directory, labeled=True, **kwargs): self.append(*system_list) return self - def to_fmt_obj(self, fmtobj, directory, *args, **kwargs): + def to_fmt_obj(self, fmtobj: Format, directory, *args: Any, **kwargs: Any): if not isinstance(fmtobj, dpdata.plugins.deepmd.DeePMDMixedFormat): for fn, ss in zip( fmtobj.to_multi_systems( @@ -1325,7 +1394,7 @@ def to_fmt_obj(self, fmtobj, directory, *args, **kwargs): ) return self - def to(self, fmt: str, *args, **kwargs) -> "MultiSystems": + def to(self, fmt: str, *args: Any, **kwargs: Any) -> MultiSystems: """Dump systems to the specific format. Parameters @@ -1369,13 +1438,19 @@ def __add__(self, others): raise RuntimeError("Unspported data structure") @classmethod - def from_file(cls, file_name, fmt, **kwargs): + def from_file(cls, file_name, fmt: str, **kwargs: Any): multi_systems = cls() multi_systems.load_systems_from_file(file_name=file_name, fmt=fmt, **kwargs) return multi_systems @classmethod - def from_dir(cls, dir_name, file_name, fmt="auto", type_map=None): + def from_dir( + cls, + dir_name: str, + file_name: str, + fmt: str = "auto", + type_map: list[str] | None = None, + ): multi_systems = cls() target_file_list = sorted( glob.glob(f"./{dir_name}/**/{file_name}", recursive=True) @@ -1386,15 +1461,16 @@ def from_dir(cls, dir_name, file_name, fmt="auto", type_map=None): ) return multi_systems - def load_systems_from_file(self, file_name=None, fmt=None, **kwargs): + def load_systems_from_file(self, file_name=None, fmt: str | None = None, **kwargs): + assert fmt is not None fmt = fmt.lower() return self.from_fmt_obj(load_format(fmt), file_name, **kwargs) - def get_nframes(self): + def get_nframes(self) -> int: """Returns number of frames in all systems.""" return sum(len(system) for system in self.systems.values()) - def append(self, *systems): + def append(self, *systems: System | MultiSystems): """Append systems or MultiSystems to systems. Parameters @@ -1411,7 +1487,7 @@ def append(self, *systems): else: raise RuntimeError("Object must be System or MultiSystems!") - def __append(self, system): + def __append(self, system: System): if not system.formula: return # prevent changing the original system @@ -1423,7 +1499,7 @@ def __append(self, system): else: self.systems[formula] = system.copy() - def check_atom_names(self, system): + def check_atom_names(self, system: System): """Make atom_names in all systems equal, prevent inconsistent atom_types.""" # new_in_system = set(system["atom_names"]) - set(self.atom_names) # new_in_self = set(self.atom_names) - set(system["atom_names"]) @@ -1444,7 +1520,9 @@ def check_atom_names(self, system): system.add_atom_names(new_in_self) system.sort_atom_names(type_map=self.atom_names) - def predict(self, *args: Any, driver="dp", **kwargs: Any) -> "MultiSystems": + def predict( + self, *args: Any, driver: str | Driver = "dp", **kwargs: Any + ) -> MultiSystems: """Predict energies and forces by a driver. Parameters @@ -1469,8 +1547,8 @@ def predict(self, *args: Any, driver="dp", **kwargs: Any) -> "MultiSystems": return new_multisystems def minimize( - self, *args: Any, minimizer: Union[str, Minimizer], **kwargs: Any - ) -> "MultiSystems": + self, *args: Any, minimizer: str | Minimizer, **kwargs: Any + ) -> MultiSystems: """Minimize geometry by a minimizer. Parameters @@ -1503,7 +1581,11 @@ def minimize( new_multisystems.append(ss.minimize(*args, minimizer=minimizer, **kwargs)) return new_multisystems - def pick_atom_idx(self, idx, nopbc=None): + def pick_atom_idx( + self, + idx: int | numbers.Integral | list[int] | slice | np.ndarray, + nopbc: bool | None = None, + ): """Pick atom index. Parameters @@ -1523,7 +1605,7 @@ def pick_atom_idx(self, idx, nopbc=None): new_sys.append(ss.pick_atom_idx(idx, nopbc=nopbc)) return new_sys - def correction(self, hl_sys: "MultiSystems"): + def correction(self, hl_sys: MultiSystems) -> MultiSystems: """Get energy and force correction between self (assumed low-level) and a high-level MultiSystems. The self's coordinates will be kept, but energy and forces will be replaced by the correction between these two systems. @@ -1558,12 +1640,14 @@ def correction(self, hl_sys: "MultiSystems"): for nn in self.systems.keys(): ll_ss = self[nn] hl_ss = hl_sys[nn] + assert isinstance(ll_ss, LabeledSystem) + assert isinstance(hl_ss, LabeledSystem) corrected_sys.append(ll_ss.correction(hl_ss)) return corrected_sys def train_test_split( - self, test_size: Union[float, int], seed: Optional[int] = None - ) -> Tuple["MultiSystems", "MultiSystems", Dict[str, np.ndarray]]: + self, test_size: float | int, seed: int | None = None + ) -> tuple[MultiSystems, MultiSystems, dict[str, np.ndarray]]: """Split systems into random train and test subsets. Parameters @@ -1619,7 +1703,7 @@ def train_test_split( return train_systems, test_systems, test_system_idx -def get_cls_name(cls: object) -> str: +def get_cls_name(cls: type[Any]) -> str: """Returns the fully qualified name of a class, such as `np.ndarray`. Parameters @@ -1654,7 +1738,7 @@ def add_format_methods(): for method, formatcls in Format.get_from_methods().items(): - def get_func(ff): + def get_func_from(ff): # ff is not initized when defining from_format so cannot be polluted def from_format(self, file_name, **kwargs): return self.from_fmt_obj(ff(), file_name, **kwargs) @@ -1662,22 +1746,22 @@ def from_format(self, file_name, **kwargs): from_format.__doc__ = f"Read data from :class:`{get_cls_name(ff)}` format." return from_format - setattr(System, method, get_func(formatcls)) - setattr(LabeledSystem, method, get_func(formatcls)) - setattr(MultiSystems, method, get_func(formatcls)) + setattr(System, method, get_func_from(formatcls)) + setattr(LabeledSystem, method, get_func_from(formatcls)) + setattr(MultiSystems, method, get_func_from(formatcls)) for method, formatcls in Format.get_to_methods().items(): - def get_func(ff): + def get_func_to(ff): def to_format(self, *args, **kwargs): return self.to_fmt_obj(ff(), *args, **kwargs) to_format.__doc__ = f"Dump data to :class:`{get_cls_name(ff)}` format." return to_format - setattr(System, method, get_func(formatcls)) - setattr(LabeledSystem, method, get_func(formatcls)) - setattr(MultiSystems, method, get_func(formatcls)) + setattr(System, method, get_func_to(formatcls)) + setattr(LabeledSystem, method, get_func_to(formatcls)) + setattr(MultiSystems, method, get_func_to(formatcls)) # at this point, System.DTYPES and LabeledSystem.DTYPES has been initialized System.register_data_type(*get_data_types(labeled=False)) diff --git a/dpdata/unit.py b/dpdata/unit.py index 5fc8fe1e..09981b96 100644 --- a/dpdata/unit.py +++ b/dpdata/unit.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import ABC from scipy import constants # noqa: TID253 diff --git a/dpdata/utils.py b/dpdata/utils.py index cf4a109e..e008120e 100644 --- a/dpdata/utils.py +++ b/dpdata/utils.py @@ -1,9 +1,34 @@ +from __future__ import annotations + +import sys +from typing import overload + +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal import numpy as np from dpdata.periodic_table import Element -def elements_index_map(elements, standard=False, inverse=False): +@overload +def elements_index_map( + elements: list[str], standard: bool, inverse: Literal[True] +) -> dict[int, str]: ... +@overload +def elements_index_map( + elements: list[str], standard: bool, inverse: Literal[False] = ... +) -> dict[str, int]: ... +@overload +def elements_index_map( + elements: list[str], standard: bool, inverse: bool = False +) -> dict[str, int] | dict[int, str]: ... + + +def elements_index_map( + elements: list[str], standard: bool = False, inverse: bool = False +) -> dict: if standard: elements.sort(key=lambda x: Element(x).Z) if inverse: diff --git a/dpdata/vasp/outcar.py b/dpdata/vasp/outcar.py index 0eddac91..0fa4cb68 100644 --- a/dpdata/vasp/outcar.py +++ b/dpdata/vasp/outcar.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import re import warnings diff --git a/dpdata/vasp/poscar.py b/dpdata/vasp/poscar.py index fde0f8fb..102e7904 100644 --- a/dpdata/vasp/poscar.py +++ b/dpdata/vasp/poscar.py @@ -1,4 +1,5 @@ #!/usr/bin/python3 +from __future__ import annotations import numpy as np diff --git a/dpdata/vasp/xml.py b/dpdata/vasp/xml.py index a534fd0c..352b107e 100755 --- a/dpdata/vasp/xml.py +++ b/dpdata/vasp/xml.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from __future__ import annotations import xml.etree.ElementTree as ET diff --git a/dpdata/xyz/quip_gap_xyz.py b/dpdata/xyz/quip_gap_xyz.py index 068bec1f..b23b27e0 100644 --- a/dpdata/xyz/quip_gap_xyz.py +++ b/dpdata/xyz/quip_gap_xyz.py @@ -1,5 +1,7 @@ #!/usr/bin/env python3 # %% +from __future__ import annotations + import re from collections import OrderedDict diff --git a/dpdata/xyz/xyz.py b/dpdata/xyz/xyz.py index 745a97b1..0c36ac32 100644 --- a/dpdata/xyz/xyz.py +++ b/dpdata/xyz/xyz.py @@ -1,4 +1,4 @@ -from typing import Tuple +from __future__ import annotations import numpy as np @@ -31,7 +31,7 @@ def coord_to_xyz(coord: np.ndarray, types: list) -> str: return "\n".join(buff) -def xyz_to_coord(xyz: str) -> Tuple[np.ndarray, list]: +def xyz_to_coord(xyz: str) -> tuple[np.ndarray, list]: """Convert xyz format to coordinates and types. Parameters diff --git a/plugin_example/dpdata_random/__init__.py b/plugin_example/dpdata_random/__init__.py index 22820e0f..cc14faca 100644 --- a/plugin_example/dpdata_random/__init__.py +++ b/plugin_example/dpdata_random/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from dpdata.format import Format diff --git a/pyproject.toml b/pyproject.toml index 1be79442..6efe0818 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ 'h5py', 'wcmatch', 'importlib_metadata>=1.4; python_version < "3.8"', + 'typing_extensions; python_version < "3.8"', ] requires-python = ">=3.7" readme = "README.md" @@ -83,6 +84,7 @@ select = [ "UP", # pyupgrade "I", # isort "TID253", # banned-module-level-imports + "TCH", # flake8-type-checking ] ignore = [ "E501", # line too long @@ -122,3 +124,11 @@ banned-module-level-imports = [ "monty", "scipy", ] + +[tool.ruff.lint.isort] +required-imports = ["from __future__ import annotations"] + +[tool.pyright] +include = [ + "dpdata/*.py", +] diff --git a/tests/comp_sys.py b/tests/comp_sys.py index f4663780..99879af6 100644 --- a/tests/comp_sys.py +++ b/tests/comp_sys.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np diff --git a/tests/context.py b/tests/context.py index 77a7557d..3214e28e 100644 --- a/tests/context.py +++ b/tests/context.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import sys diff --git a/tests/plugin/dpdata_plugin_test/__init__.py b/tests/plugin/dpdata_plugin_test/__init__.py index b3821cb3..ef26e7c1 100644 --- a/tests/plugin/dpdata_plugin_test/__init__.py +++ b/tests/plugin/dpdata_plugin_test/__init__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np from dpdata.data_type import Axis, DataType, register_data_type diff --git a/tests/poscars/poscar_ref_oh.py b/tests/poscars/poscar_ref_oh.py index f120183e..2d29aeeb 100644 --- a/tests/poscars/poscar_ref_oh.py +++ b/tests/poscars/poscar_ref_oh.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np diff --git a/tests/poscars/test_lammps_dump_s_su.py b/tests/poscars/test_lammps_dump_s_su.py index 28673dfc..967c767a 100644 --- a/tests/poscars/test_lammps_dump_s_su.py +++ b/tests/poscars/test_lammps_dump_s_su.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/pwmat/config_ref_ch4.py b/tests/pwmat/config_ref_ch4.py index 71aef7fe..72499398 100644 --- a/tests/pwmat/config_ref_ch4.py +++ b/tests/pwmat/config_ref_ch4.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np diff --git a/tests/pwmat/config_ref_oh.py b/tests/pwmat/config_ref_oh.py index 6f3e0561..ad546019 100644 --- a/tests/pwmat/config_ref_oh.py +++ b/tests/pwmat/config_ref_oh.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import numpy as np diff --git a/tests/test_abacus_md.py b/tests/test_abacus_md.py index 782ed521..ddcb7734 100644 --- a/tests/test_abacus_md.py +++ b/tests/test_abacus_md.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_abacus_pw_scf.py b/tests/test_abacus_pw_scf.py index eb712fbe..8d13dddc 100644 --- a/tests/test_abacus_pw_scf.py +++ b/tests/test_abacus_pw_scf.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import shutil import unittest diff --git a/tests/test_abacus_relax.py b/tests/test_abacus_relax.py index 65d73e53..b752a426 100644 --- a/tests/test_abacus_relax.py +++ b/tests/test_abacus_relax.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import shutil import unittest diff --git a/tests/test_abacus_stru_dump.py b/tests/test_abacus_stru_dump.py index 46cb5de6..356aa57f 100644 --- a/tests/test_abacus_stru_dump.py +++ b/tests/test_abacus_stru_dump.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_amber_md.py b/tests/test_amber_md.py index 3995371e..b0a06058 100644 --- a/tests/test_amber_md.py +++ b/tests/test_amber_md.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import shutil import unittest diff --git a/tests/test_amber_sqm.py b/tests/test_amber_sqm.py index 7f14ff84..b7f09110 100644 --- a/tests/test_amber_sqm.py +++ b/tests/test_amber_sqm.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import shutil import unittest diff --git a/tests/test_ase_traj.py b/tests/test_ase_traj.py index b6eab27e..8e4a6e12 100644 --- a/tests/test_ase_traj.py +++ b/tests/test_ase_traj.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from comp_sys import CompLabeledSys, CompSys, IsPBC diff --git a/tests/test_bond_order_system.py b/tests/test_bond_order_system.py index 41a167fb..104e18f1 100644 --- a/tests/test_bond_order_system.py +++ b/tests/test_bond_order_system.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import glob import os import unittest diff --git a/tests/test_cell_to_low_triangle.py b/tests/test_cell_to_low_triangle.py index c080c8e5..34d0a90a 100644 --- a/tests/test_cell_to_low_triangle.py +++ b/tests/test_cell_to_low_triangle.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_cli.py b/tests/test_cli.py index 200a1c1e..9d70db5f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import subprocess as sp import sys import unittest diff --git a/tests/test_corr.py b/tests/test_corr.py index 463c99af..a7c6f7c4 100644 --- a/tests/test_corr.py +++ b/tests/test_corr.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from comp_sys import CompLabeledSys, IsPBC diff --git a/tests/test_cp2k_aimd_output.py b/tests/test_cp2k_aimd_output.py index bce24250..46f292b1 100644 --- a/tests/test_cp2k_aimd_output.py +++ b/tests/test_cp2k_aimd_output.py @@ -1,4 +1,6 @@ # %% +from __future__ import annotations + import unittest from comp_sys import CompLabeledSys diff --git a/tests/test_cp2k_output.py b/tests/test_cp2k_output.py index 0e4b153d..da58e87c 100644 --- a/tests/test_cp2k_output.py +++ b/tests/test_cp2k_output.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from comp_sys import CompLabeledSys diff --git a/tests/test_custom_data_type.py b/tests/test_custom_data_type.py index 7e3278ea..e94ba5e0 100644 --- a/tests/test_custom_data_type.py +++ b/tests/test_custom_data_type.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import h5py # noqa: TID253 diff --git a/tests/test_deepmd_comp.py b/tests/test_deepmd_comp.py index 46f8e741..28428786 100644 --- a/tests/test_deepmd_comp.py +++ b/tests/test_deepmd_comp.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import shutil import unittest diff --git a/tests/test_deepmd_hdf5.py b/tests/test_deepmd_hdf5.py index 20d16c37..b4a22f3c 100644 --- a/tests/test_deepmd_hdf5.py +++ b/tests/test_deepmd_hdf5.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_deepmd_mixed.py b/tests/test_deepmd_mixed.py index 7e522e06..02044932 100644 --- a/tests/test_deepmd_mixed.py +++ b/tests/test_deepmd_mixed.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import shutil import unittest diff --git a/tests/test_deepmd_raw.py b/tests/test_deepmd_raw.py index 1b056726..af875fde 100644 --- a/tests/test_deepmd_raw.py +++ b/tests/test_deepmd_raw.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import shutil import unittest diff --git a/tests/test_dftbplus.py b/tests/test_dftbplus.py index 2a2913a5..29cdaa92 100644 --- a/tests/test_dftbplus.py +++ b/tests/test_dftbplus.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_elements_index.py b/tests/test_elements_index.py index 45408b4d..186d7b80 100644 --- a/tests/test_elements_index.py +++ b/tests/test_elements_index.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from dpdata.system import elements_index_map diff --git a/tests/test_empty.py b/tests/test_empty.py index 8787f954..12913bab 100644 --- a/tests/test_empty.py +++ b/tests/test_empty.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_fhi_md_multi_elem_output.py b/tests/test_fhi_md_multi_elem_output.py index a20c45bd..b11a52f5 100644 --- a/tests/test_fhi_md_multi_elem_output.py +++ b/tests/test_fhi_md_multi_elem_output.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_fhi_md_output.py b/tests/test_fhi_md_output.py index d205e391..391cc319 100644 --- a/tests/test_fhi_md_output.py +++ b/tests/test_fhi_md_output.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_fhi_output.py b/tests/test_fhi_output.py index 067e5f69..bd3582f3 100644 --- a/tests/test_fhi_output.py +++ b/tests/test_fhi_output.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_from_pymatgen.py b/tests/test_from_pymatgen.py index d3ddbe3e..7689a9d5 100644 --- a/tests/test_from_pymatgen.py +++ b/tests/test_from_pymatgen.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_gaussian_driver.py b/tests/test_gaussian_driver.py index 07150bc7..ff163848 100644 --- a/tests/test_gaussian_driver.py +++ b/tests/test_gaussian_driver.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import importlib import os import shutil diff --git a/tests/test_gaussian_gjf.py b/tests/test_gaussian_gjf.py index 2e5f4ea8..b3819946 100644 --- a/tests/test_gaussian_gjf.py +++ b/tests/test_gaussian_gjf.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_gaussian_log.py b/tests/test_gaussian_log.py index 6622e684..784fd594 100644 --- a/tests/test_gaussian_log.py +++ b/tests/test_gaussian_log.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_gromacs_gro.py b/tests/test_gromacs_gro.py index 2971755f..674c6510 100644 --- a/tests/test_gromacs_gro.py +++ b/tests/test_gromacs_gro.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_json.py b/tests/test_json.py index 545e5db8..0b6f1b9d 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from comp_sys import CompLabeledSys, IsPBC diff --git a/tests/test_lammps_dump_idx.py b/tests/test_lammps_dump_idx.py index 272cc222..39379158 100644 --- a/tests/test_lammps_dump_idx.py +++ b/tests/test_lammps_dump_idx.py @@ -1,4 +1,5 @@ # The index should map to that in the dump file +from __future__ import annotations import os import unittest diff --git a/tests/test_lammps_dump_shift_origin.py b/tests/test_lammps_dump_shift_origin.py index 4ecd6f87..a7444234 100644 --- a/tests/test_lammps_dump_shift_origin.py +++ b/tests/test_lammps_dump_shift_origin.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from comp_sys import CompSys, IsPBC diff --git a/tests/test_lammps_dump_skipload.py b/tests/test_lammps_dump_skipload.py index 224ec6d1..299e1db4 100644 --- a/tests/test_lammps_dump_skipload.py +++ b/tests/test_lammps_dump_skipload.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_lammps_dump_to_system.py b/tests/test_lammps_dump_to_system.py index af9748a5..4d634037 100644 --- a/tests/test_lammps_dump_to_system.py +++ b/tests/test_lammps_dump_to_system.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_lammps_dump_unfold.py b/tests/test_lammps_dump_unfold.py index 1e78d975..587602c8 100644 --- a/tests/test_lammps_dump_unfold.py +++ b/tests/test_lammps_dump_unfold.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_lammps_lmp_dump.py b/tests/test_lammps_lmp_dump.py index 8e9cfb32..25525f76 100644 --- a/tests/test_lammps_lmp_dump.py +++ b/tests/test_lammps_lmp_dump.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_lammps_lmp_to_system.py b/tests/test_lammps_lmp_to_system.py index 19e13312..444b1dd4 100644 --- a/tests/test_lammps_lmp_to_system.py +++ b/tests/test_lammps_lmp_to_system.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_lammps_read_from_trajs.py b/tests/test_lammps_read_from_trajs.py index f1e5afdd..578ae471 100644 --- a/tests/test_lammps_read_from_trajs.py +++ b/tests/test_lammps_read_from_trajs.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_msd.py b/tests/test_msd.py index 52b1ce93..7148b0b5 100644 --- a/tests/test_msd.py +++ b/tests/test_msd.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_multisystems.py b/tests/test_multisystems.py index 2bda13a9..88d4593a 100644 --- a/tests/test_multisystems.py +++ b/tests/test_multisystems.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import tempfile import unittest diff --git a/tests/test_n2p2.py b/tests/test_n2p2.py index 855a2752..32ac6447 100644 --- a/tests/test_n2p2.py +++ b/tests/test_n2p2.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_openmx.py b/tests/test_openmx.py index 0705ed0a..2122e8f4 100644 --- a/tests/test_openmx.py +++ b/tests/test_openmx.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_openmx_check_convergence.py b/tests/test_openmx_check_convergence.py index 362c89c5..b19ad6e8 100644 --- a/tests/test_openmx_check_convergence.py +++ b/tests/test_openmx_check_convergence.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_orca_spout.py b/tests/test_orca_spout.py index ecb1a5ca..d034fbb0 100644 --- a/tests/test_orca_spout.py +++ b/tests/test_orca_spout.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_periodic_table.py b/tests/test_periodic_table.py index 6b856e91..3cf36b99 100644 --- a/tests/test_periodic_table.py +++ b/tests/test_periodic_table.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from context import dpdata diff --git a/tests/test_perturb.py b/tests/test_perturb.py index b89a8c7f..eea71116 100644 --- a/tests/test_perturb.py +++ b/tests/test_perturb.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from unittest.mock import patch diff --git a/tests/test_pick_atom_idx.py b/tests/test_pick_atom_idx.py index 0dc06991..ef3368f3 100644 --- a/tests/test_pick_atom_idx.py +++ b/tests/test_pick_atom_idx.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from comp_sys import CompSys, IsNoPBC diff --git a/tests/test_predict.py b/tests/test_predict.py index f08125ab..6ab00be3 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_psi4.py b/tests/test_psi4.py index b9c2124e..93bfc408 100644 --- a/tests/test_psi4.py +++ b/tests/test_psi4.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import tempfile import textwrap import unittest diff --git a/tests/test_pwmat_config_dump.py b/tests/test_pwmat_config_dump.py index 9389c7a9..e4d5a5a8 100644 --- a/tests/test_pwmat_config_dump.py +++ b/tests/test_pwmat_config_dump.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_pwmat_config_to_system.py b/tests/test_pwmat_config_to_system.py index 0956f956..59fd7339 100644 --- a/tests/test_pwmat_config_to_system.py +++ b/tests/test_pwmat_config_to_system.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_pwmat_mlmd.py b/tests/test_pwmat_mlmd.py index 4a920c15..8dcdb1ef 100644 --- a/tests/test_pwmat_mlmd.py +++ b/tests/test_pwmat_mlmd.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_pwmat_movement.py b/tests/test_pwmat_movement.py index 68a9e681..14e976b2 100644 --- a/tests/test_pwmat_movement.py +++ b/tests/test_pwmat_movement.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_pymatgen_molecule.py b/tests/test_pymatgen_molecule.py index 231bd97f..e6a1b5ee 100644 --- a/tests/test_pymatgen_molecule.py +++ b/tests/test_pymatgen_molecule.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_qe_cp_traj.py b/tests/test_qe_cp_traj.py index 6a963106..9e062986 100644 --- a/tests/test_qe_cp_traj.py +++ b/tests/test_qe_cp_traj.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_qe_cp_traj_skipload.py b/tests/test_qe_cp_traj_skipload.py index 2964e716..43cbe88d 100644 --- a/tests/test_qe_cp_traj_skipload.py +++ b/tests/test_qe_cp_traj_skipload.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_qe_pw_scf.py b/tests/test_qe_pw_scf.py index 57a739fb..8703e7c2 100644 --- a/tests/test_qe_pw_scf.py +++ b/tests/test_qe_pw_scf.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_qe_pw_scf_crystal_atomic_positions.py b/tests/test_qe_pw_scf_crystal_atomic_positions.py index 01c4df21..383ea6cd 100644 --- a/tests/test_qe_pw_scf_crystal_atomic_positions.py +++ b/tests/test_qe_pw_scf_crystal_atomic_positions.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_qe_pw_scf_energy_bug.py b/tests/test_qe_pw_scf_energy_bug.py index 8360a7a9..b66ce924 100644 --- a/tests/test_qe_pw_scf_energy_bug.py +++ b/tests/test_qe_pw_scf_energy_bug.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from context import dpdata diff --git a/tests/test_quip_gap_xyz.py b/tests/test_quip_gap_xyz.py index b383bd2f..a265544c 100644 --- a/tests/test_quip_gap_xyz.py +++ b/tests/test_quip_gap_xyz.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from comp_sys import CompLabeledSys, IsPBC diff --git a/tests/test_remove_atom_names.py b/tests/test_remove_atom_names.py index d2d4abc7..9fbd8faf 100644 --- a/tests/test_remove_atom_names.py +++ b/tests/test_remove_atom_names.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from comp_sys import CompLabeledSys, IsNoPBC diff --git a/tests/test_remove_outlier.py b/tests/test_remove_outlier.py index b2cb52fc..c08de0bf 100644 --- a/tests/test_remove_outlier.py +++ b/tests/test_remove_outlier.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_remove_pbc.py b/tests/test_remove_pbc.py index d5befd77..d70a2f02 100644 --- a/tests/test_remove_pbc.py +++ b/tests/test_remove_pbc.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_replace.py b/tests/test_replace.py index b16c388b..b9194137 100644 --- a/tests/test_replace.py +++ b/tests/test_replace.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from unittest.mock import patch diff --git a/tests/test_replicate.py b/tests/test_replicate.py index 99104c3c..3add2dc0 100644 --- a/tests/test_replicate.py +++ b/tests/test_replicate.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from comp_sys import CompSys, IsPBC diff --git a/tests/test_shuffle.py b/tests/test_shuffle.py index 9c462214..3ac33c2f 100644 --- a/tests/test_shuffle.py +++ b/tests/test_shuffle.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from comp_sys import CompLabeledSys, IsPBC diff --git a/tests/test_siesta_aiMD_output.py b/tests/test_siesta_aiMD_output.py index a1ba31b6..4dcb0445 100644 --- a/tests/test_siesta_aiMD_output.py +++ b/tests/test_siesta_aiMD_output.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_siesta_output.py b/tests/test_siesta_output.py index 9ff0167a..c649f7d0 100644 --- a/tests/test_siesta_output.py +++ b/tests/test_siesta_output.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_split_dataset.py b/tests/test_split_dataset.py index a5419b7b..ac0960cf 100644 --- a/tests/test_split_dataset.py +++ b/tests/test_split_dataset.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_sqm_driver.py b/tests/test_sqm_driver.py index 3dbc6df4..d7c0da73 100644 --- a/tests/test_sqm_driver.py +++ b/tests/test_sqm_driver.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import shutil import unittest diff --git a/tests/test_stat.py b/tests/test_stat.py index 9ae8a175..863cea6c 100644 --- a/tests/test_stat.py +++ b/tests/test_stat.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from context import dpdata diff --git a/tests/test_system_append.py b/tests/test_system_append.py index a2c30b23..7c325113 100644 --- a/tests/test_system_append.py +++ b/tests/test_system_append.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_system_apply_pbc.py b/tests/test_system_apply_pbc.py index 9cf44ae0..2114cf6a 100644 --- a/tests/test_system_apply_pbc.py +++ b/tests/test_system_apply_pbc.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_system_set_type.py b/tests/test_system_set_type.py index 4bb14b62..d8362ec7 100644 --- a/tests/test_system_set_type.py +++ b/tests/test_system_set_type.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_to_ase.py b/tests/test_to_ase.py index 60dc931d..09b830ba 100644 --- a/tests/test_to_ase.py +++ b/tests/test_to_ase.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_to_list.py b/tests/test_to_list.py index d559ffce..998f1265 100644 --- a/tests/test_to_list.py +++ b/tests/test_to_list.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from comp_sys import CompLabeledSys, IsPBC diff --git a/tests/test_to_pymatgen.py b/tests/test_to_pymatgen.py index b55443d4..72d1b27a 100644 --- a/tests/test_to_pymatgen.py +++ b/tests/test_to_pymatgen.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_to_pymatgen_entry.py b/tests/test_to_pymatgen_entry.py index 7111dcdc..dfdeb468 100644 --- a/tests/test_to_pymatgen_entry.py +++ b/tests/test_to_pymatgen_entry.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_type_map.py b/tests/test_type_map.py index 2cc50865..92d25ada 100644 --- a/tests/test_type_map.py +++ b/tests/test_type_map.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest from itertools import permutations diff --git a/tests/test_vasp_outcar.py b/tests/test_vasp_outcar.py index fb2ec1c9..832b0a91 100644 --- a/tests/test_vasp_outcar.py +++ b/tests/test_vasp_outcar.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_vasp_poscar_dump.py b/tests/test_vasp_poscar_dump.py index a81cbe94..62f21598 100644 --- a/tests/test_vasp_poscar_dump.py +++ b/tests/test_vasp_poscar_dump.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_vasp_poscar_to_system.py b/tests/test_vasp_poscar_to_system.py index dcb83bfd..7457d33d 100644 --- a/tests/test_vasp_poscar_to_system.py +++ b/tests/test_vasp_poscar_to_system.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_vasp_unconverged_outcar.py b/tests/test_vasp_unconverged_outcar.py index 7e1b3535..1f3b3d2d 100644 --- a/tests/test_vasp_unconverged_outcar.py +++ b/tests/test_vasp_unconverged_outcar.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_vasp_xml.py b/tests/test_vasp_xml.py index cc0bbb41..0b917754 100644 --- a/tests/test_vasp_xml.py +++ b/tests/test_vasp_xml.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import unittest import numpy as np diff --git a/tests/test_water_ions.py b/tests/test_water_ions.py index 788030f3..40c1c143 100644 --- a/tests/test_water_ions.py +++ b/tests/test_water_ions.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import unittest diff --git a/tests/test_xyz.py b/tests/test_xyz.py index a84ad28b..d9bcf70e 100644 --- a/tests/test_xyz.py +++ b/tests/test_xyz.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import tempfile import unittest