Skip to content

Commit

Permalink
chore: improve type annotations (#659)
Browse files Browse the repository at this point in the history
I give up fix type annotations in all files, but it's meaningful to fix
the top-module files `dpdata/*.py`.

<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->

## Summary by CodeRabbit

- **New Features**
  - Added `name` property to `DPDataCalculator` class.
  - Introduced assertion for `rdkit_mol` in bond order system.

- **Improvements**
  - Enhanced type hinting across various functions and methods.
- Added decorators and type assertion checks for better code validation.

- **Maintenance**
- Introduced a GitHub Actions workflow for running Pyright type checker.
  - Updated dependencies and configurations in `pyproject.toml`.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
  • Loading branch information
njzjz committed May 17, 2024
1 parent a7bf93d commit 626e692
Show file tree
Hide file tree
Showing 180 changed files with 653 additions and 162 deletions.
19 changes: 19 additions & 0 deletions .github/workflows/pyright.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
on:
- push
- pull_request

name: Type checker
jobs:
pyright:
name: pyright
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@master
- uses: actions/setup-python@v5
with:
python-version: '3.12'
- run: pip install uv
- run: uv pip install --system -e .[amber,ase,pymatgen] rdkit openbabel-wheel
- uses: jakebailey/pyright-action@v2
with:
version: 1.1.363
2 changes: 2 additions & 0 deletions benchmark/test_import.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import subprocess
import sys

Expand Down
2 changes: 2 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
from __future__ import annotations

import os
import subprocess as sp
import sys
Expand Down
9 changes: 8 additions & 1 deletion docs/make_format.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from __future__ import annotations

import csv
import os
import sys
from collections import defaultdict
from inspect import Parameter, Signature, cleandoc, signature
from typing import Literal

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal

from numpydoc.docscrape import Parameter as numpydoc_Parameter
from numpydoc.docscrape_sphinx import SphinxDocString
Expand Down
2 changes: 2 additions & 0 deletions docs/nb/try_dpdata.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
"metadata": {},
"outputs": [],
"source": [
"from __future__ import annotations\n",
"\n",
"import dpdata"
]
},
Expand Down
2 changes: 2 additions & 0 deletions dpdata/__about__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from __future__ import annotations

__version__ = "unknown"
2 changes: 2 additions & 0 deletions dpdata/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from . import lammps, md, vasp
from .bond_order_system import BondOrderSystem
from .system import LabeledSystem, MultiSystems, System
Expand Down
2 changes: 2 additions & 0 deletions dpdata/__main__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from dpdata.cli import dpdata_cli

if __name__ == "__main__":
Expand Down
2 changes: 2 additions & 0 deletions dpdata/abacus/md.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import os
import warnings

Expand Down
2 changes: 2 additions & 0 deletions dpdata/abacus/relax.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import os

import numpy as np
Expand Down
2 changes: 2 additions & 0 deletions dpdata/abacus/scf.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import os
import re
import warnings
Expand Down
2 changes: 2 additions & 0 deletions dpdata/amber/mask.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Amber mask."""

from __future__ import annotations

try:
import parmed
except ImportError:
Expand Down
2 changes: 2 additions & 0 deletions dpdata/amber/md.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import os
import re

Expand Down
2 changes: 2 additions & 0 deletions dpdata/amber/sqm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import numpy as np

from dpdata.periodic_table import ELEMENTS
Expand Down
21 changes: 13 additions & 8 deletions dpdata/ase_calculator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import TYPE_CHECKING, List, Optional
from __future__ import annotations

from typing import TYPE_CHECKING

from ase.calculators.calculator import ( # noqa: TID253
Calculator,
Expand All @@ -23,7 +25,10 @@ class DPDataCalculator(Calculator):
dpdata driver
"""

name = "dpdata"
@property
def name(self) -> str:
return "dpdata"

implemented_properties = ["energy", "free_energy", "forces", "virial", "stress"]

def __init__(self, driver: Driver, **kwargs) -> None:
Expand All @@ -32,9 +37,9 @@ def __init__(self, driver: Driver, **kwargs) -> None:

def calculate(
self,
atoms: Optional["Atoms"] = None,
properties: List[str] = ["energy", "forces"],
system_changes: List[str] = all_changes,
atoms: Atoms | None = None,
properties: list[str] = ["energy", "forces"],
system_changes: list[str] = all_changes,
):
"""Run calculation with a driver.
Expand All @@ -48,10 +53,10 @@ def calculate(
system_changes : List[str], optional
unused, only for function signature compatibility, by default all_changes
"""
if atoms is not None:
self.atoms = atoms.copy()
assert atoms is not None
atoms = atoms.copy()

system = dpdata.System(self.atoms, fmt="ase/structure")
system = dpdata.System(atoms, fmt="ase/structure")
data = system.predict(driver=self.driver).data

self.results["energy"] = data["energies"][0]
Expand Down
9 changes: 6 additions & 3 deletions dpdata/bond_order_system.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# %%
# Bond Order System
from __future__ import annotations

from copy import deepcopy

import numpy as np
Expand Down Expand Up @@ -96,13 +98,14 @@ def from_fmt_obj(self, fmtobj, file_name, **kwargs):
mol = fmtobj.from_bond_order_system(file_name, **kwargs)
self.from_rdkit_mol(mol)
if hasattr(fmtobj.from_bond_order_system, "post_func"):
for post_f in fmtobj.from_bond_order_system.post_func:
for post_f in fmtobj.from_bond_order_system.post_func: # type: ignore
self.post_funcs.get_plugin(post_f)(self)
return self

def to_fmt_obj(self, fmtobj, *args, **kwargs):
from rdkit.Chem import Conformer

assert self.rdkit_mol is not None
self.rdkit_mol.RemoveAllConformers()
for ii in range(self.get_nframes()):
conf = Conformer()
Expand Down Expand Up @@ -145,9 +148,9 @@ def get_formal_charges(self):
"""Return the formal charges on each atom."""
return self.data["formal_charges"]

def copy(self):
def copy(self): # type: ignore
new_mol = deepcopy(self.rdkit_mol)
self.__class__(data=deepcopy(self.data), rdkit_mol=new_mol)
return self.__class__(data=deepcopy(self.data), rdkit_mol=new_mol)

def __add__(self, other):
raise NotImplementedError(
Expand Down
9 changes: 5 additions & 4 deletions dpdata/cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Command line interface for dpdata."""

from __future__ import annotations

import argparse
from typing import Optional

from . import __version__
from .system import LabeledSystem, MultiSystems, System
Expand Down Expand Up @@ -59,11 +60,11 @@ def convert(
*,
from_file: str,
from_format: str = "auto",
to_file: Optional[str] = None,
to_format: Optional[str] = None,
to_file: str | None = None,
to_format: str | None = None,
no_labeled: bool = False,
multi: bool = False,
type_map: Optional[list] = None,
type_map: list | None = None,
**kwargs,
):
"""Convert files from one format to another one.
Expand Down
1 change: 1 addition & 0 deletions dpdata/cp2k/cell.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# %%
from __future__ import annotations

import numpy as np

Expand Down
2 changes: 2 additions & 0 deletions dpdata/cp2k/output.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# %%
from __future__ import annotations

import math
import re
from collections import OrderedDict
Expand Down
13 changes: 8 additions & 5 deletions dpdata/data_type.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from enum import Enum, unique
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING

import numpy as np

Expand Down Expand Up @@ -50,16 +52,17 @@ def __init__(
self,
name: str,
dtype: type,
shape: Tuple[int, Axis] = None,
shape: tuple[int | Axis, ...] | None = None,
required: bool = True,
) -> None:
self.name = name
self.dtype = dtype
self.shape = shape
self.required = required

def real_shape(self, system: "System") -> Tuple[int]:
def real_shape(self, system: System) -> tuple[int]:
"""Returns expected real shape of a system."""
assert self.shape is not None
shape = []
for ii in self.shape:
if ii is Axis.NFRAMES:
Expand All @@ -70,7 +73,7 @@ def real_shape(self, system: "System") -> Tuple[int]:
shape.append(system.get_natoms())
elif ii is Axis.NBONDS:
# BondOrderSystem
shape.append(system.get_nbonds())
shape.append(system.get_nbonds()) # type: ignore
elif ii == -1:
shape.append(AnyInt(-1))
elif isinstance(ii, int):
Expand All @@ -79,7 +82,7 @@ def real_shape(self, system: "System") -> Tuple[int]:
raise RuntimeError("Shape is not an int!")
return tuple(shape)

def check(self, system: "System"):
def check(self, system: System):
"""Check if a system has correct data of this type.
Parameters
Expand Down
2 changes: 2 additions & 0 deletions dpdata/deepmd/comp.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import glob
import os
import shutil
Expand Down
2 changes: 2 additions & 0 deletions dpdata/deepmd/mixed.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import glob
import os
import shutil
Expand Down
2 changes: 2 additions & 0 deletions dpdata/deepmd/raw.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import os
import warnings

Expand Down
4 changes: 2 additions & 2 deletions dpdata/dftbplus/output.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Tuple
from __future__ import annotations

import numpy as np


def read_dftb_plus(fn_1: str, fn_2: str) -> Tuple[str, np.ndarray, float, np.ndarray]:
def read_dftb_plus(fn_1: str, fn_2: str) -> tuple[str, np.ndarray, float, np.ndarray]:
"""Read from DFTB+ input and output.
Parameters
Expand Down
15 changes: 9 additions & 6 deletions dpdata/driver.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Driver plugin system."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Callable, List, Union
from typing import TYPE_CHECKING, Callable

from .plugin import Plugin

if TYPE_CHECKING:
import ase
import ase.calculators.calculator


class Driver(ABC):
Expand Down Expand Up @@ -43,7 +45,7 @@ def register(key: str) -> Callable:
return Driver.__DriverPlugin.register(key)

@staticmethod
def get_driver(key: str) -> "Driver":
def get_driver(key: str) -> type[Driver]:
"""Get a driver plugin.
Parameters
Expand Down Expand Up @@ -97,7 +99,7 @@ def label(self, data: dict) -> dict:
return NotImplemented

@property
def ase_calculator(self) -> "ase.calculators.calculator.Calculator":
def ase_calculator(self) -> ase.calculators.calculator.Calculator:
"""Returns an ase calculator based on this driver."""
from .ase_calculator import DPDataCalculator

Expand Down Expand Up @@ -130,7 +132,7 @@ class HybridDriver(Driver):
This driver is the hybrid of SQM and DP.
"""

def __init__(self, drivers: List[Union[dict, Driver]]) -> None:
def __init__(self, drivers: list[dict | Driver]) -> None:
self.drivers = []
for driver in drivers:
if isinstance(driver, Driver):
Expand All @@ -157,6 +159,7 @@ def label(self, data: dict) -> dict:
dict
labeled data with energies and forces
"""
labeled_data = {}
for ii, driver in enumerate(self.drivers):
lb_data = driver.label(data.copy())
if ii == 0:
Expand Down Expand Up @@ -199,7 +202,7 @@ def register(key: str) -> Callable:
return Minimizer.__MinimizerPlugin.register(key)

@staticmethod
def get_minimizer(key: str) -> "Minimizer":
def get_minimizer(key: str) -> type[Minimizer]:
"""Get a minimizer plugin.
Parameters
Expand Down
2 changes: 2 additions & 0 deletions dpdata/fhi_aims/output.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import re
import warnings

Expand Down
4 changes: 3 additions & 1 deletion dpdata/format.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Implement the format plugin system."""

from __future__ import annotations

import os
from abc import ABC

Expand Down Expand Up @@ -163,7 +165,7 @@ def decorator(object):
if not isinstance(func_name, (list, tuple, set)):
object.post_func = (func_name,)
else:
object.post_func = func_name
object.post_func = tuple(func_name)
return object

return decorator
Expand Down
Loading

0 comments on commit 626e692

Please sign in to comment.