Skip to content

Commit

Permalink
don’t run pylint (#3)
Browse files Browse the repository at this point in the history
* don’t run pylint
* use jax config
* minimum python version
* fix py versions
* ignore pypy for now

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Nov 7, 2023
1 parent 3956150 commit 4bfacc1
Show file tree
Hide file tree
Showing 12 changed files with 177 additions and 46 deletions.
18 changes: 10 additions & 8 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ jobs:
- uses: pre-commit/action@v3.0.0
with:
extra_args: --hook-stage manual --all-files
- name: Run PyLint
run: |
echo "::add-matcher::$GITHUB_WORKSPACE/.github/matchers/pylint.json"
pipx run nox -s pylint
# - name: Run PyLint
# run: |
# echo "::add-matcher::$GITHUB_WORKSPACE/.github/matchers/pylint.json"
# pipx run nox -s pylint

checks:
name: Check Python ${{ matrix.python-version }} on ${{ matrix.runs-on }}
Expand All @@ -40,12 +40,14 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.8", "3.12"]
python-version: ["3.11", "3.12"]
runs-on: [ubuntu-latest, macos-latest, windows-latest]

include:
- python-version: pypy-3.10
runs-on: ubuntu-latest
# TODO: figure out OpenBLAS install
# ``ERROR: Dependency "OpenBLAS" not found, tried pkgconfig and cmake``
# include:
# - python-version: pypy-3.11
# runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v4
Expand Down
10 changes: 3 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ authors = [
]
description = "Galactic Dynamix in Jax"
readme = "README.md"
requires-python = ">=3.8"
requires-python = ">=3.11"
classifiers = [
"Development Status :: 1 - Planning",
"Intended Audience :: Science/Research",
Expand All @@ -20,9 +20,6 @@ classifiers = [
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering",
Expand All @@ -33,7 +30,6 @@ dependencies = [
"astropy >= 5.3",
"diffrax",
"equinox",
"gala >= 1.7",
"jax",
"jax_cosmo",
"typing_extensions",
Expand Down Expand Up @@ -110,7 +106,7 @@ disallow_untyped_defs = true
disallow_incomplete_defs = true

[[tool.mypy.overrides]]
module = ["astropy.*", "diffrax.*", "equinox.*", "gala.*", "jax.*", "jax_cosmo.*"]
module = ["astropy.*", "diffrax.*", "equinox.*", "jax.*", "jax_cosmo.*"]
ignore_missing_imports = true


Expand Down Expand Up @@ -158,7 +154,7 @@ isort.required-imports = ["from __future__ import annotations"]


[tool.pylint]
py-version = "3.8"
py-version = "3.11"
ignore-paths = [".*/_version.py"]
reports.output-format = "colorized"
similarities.ignore-imports = "yes"
Expand Down
2 changes: 1 addition & 1 deletion src/galdynamix/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

__all__ = ["__version__"]

from jax.config import config
from jax import config

from ._version import version as __version__

Expand Down
8 changes: 6 additions & 2 deletions src/galdynamix/integrate/_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

__all__ = ["Integrator"]
__all__ = ["AbstractIntegrator"]

import abc
from typing import Any, Protocol
Expand All @@ -14,8 +14,12 @@ def __call__(self, t: jt.Array, xv: jt.Array, args: Any) -> jt.Array:
...


class Integrator(eqx.Module): # type: ignore[misc]
class AbstractIntegrator(eqx.Module): # type: ignore[misc]
"""Integrator Class."""

F: FCallable
"""The function to integrate."""
# TODO: should this be moved to be the first argument of the run method?

@abc.abstractmethod
def run(
Expand Down
4 changes: 2 additions & 2 deletions src/galdynamix/integrate/_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
)
from diffrax import SaveAt as DiffraxSaveAt

from galdynamix.integrate._base import Integrator
from galdynamix.integrate._base import AbstractIntegrator


class DiffraxIntegrator(Integrator):
class DiffraxIntegrator(AbstractIntegrator):
"""Thin wrapper around ``diffrax.diffeqsolve``."""

_: KW_ONLY
Expand Down
22 changes: 14 additions & 8 deletions src/galdynamix/potential/_potential/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
import jax.numpy as xp
import jax.typing as jt
from astropy.constants import G as apy_G
from gala.units import UnitSystem, dimensionless

from galdynamix.integrate._base import AbstractIntegrator
from galdynamix.integrate._builtin import DiffraxIntegrator
from galdynamix.potential._potential.param.field import ParameterField
from galdynamix.units import UnitSystem, dimensionless
from galdynamix.utils import partial_jit


Expand Down Expand Up @@ -94,17 +96,21 @@ def acceleration(self, q: jt.Array, /, t: jt.Array) -> jt.Array:
# Convenience methods

@partial_jit()
def _vel_acc(self, t: jt.Array, qp: jt.Array, args: Any) -> jt.Array:
def _vel_acc(self, t: jt.Array, qp: jt.Array, args: tuple[Any, ...]) -> jt.Array:
return xp.hstack([qp[3:], self.acceleration(qp[:3], t)])

@partial_jit()
@partial_jit(static_argnames=("Integrator", "integrator_kw"))
def integrate_orbit(
self, w0: jt.Array, t0: jt.Array, t1: jt.Array, ts: jt.Array | None
self,
w0: jt.Array,
t0: jt.Array,
t1: jt.Array,
ts: jt.Array | None,
*,
Integrator: type[AbstractIntegrator] = DiffraxIntegrator,
integrator_kw: dict[str, Any] | None = None,
) -> jt.Array:
# TODO: allow passing in integrator options
from galdynamix.integrate._builtin import DiffraxIntegrator as Integrator

return Integrator(self._vel_acc).run(w0, t0, t1, ts)
return Integrator(self._vel_acc, **(integrator_kw or {})).run(w0, t0, t1, ts)


# ===========================================================================
Expand Down
8 changes: 2 additions & 6 deletions src/galdynamix/potential/_potential/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,15 @@

from dataclasses import KW_ONLY

import astropy.units as u
import equinox as eqx
import jax
import jax.numpy as xp
import jax.typing as jt
from gala.units import UnitSystem
from jax_cosmo.scipy.interpolate import InterpolatedUnivariateSpline

from galdynamix.potential._potential.base import AbstractPotential
from galdynamix.potential._potential.param import AbstractParameter, ParameterField
from galdynamix.units import galactic
from galdynamix.utils import partial_jit
from galdynamix.utils.dataclasses import field

Expand Down Expand Up @@ -135,9 +134,6 @@ def potential_energy(self, q: jt.Array, /, t: jt.Array) -> jt.Array:
# -------------------------------------------------------------------


usys = UnitSystem(u.kpc, u.Myr, u.Msun, u.radian)


@jax.jit # type: ignore[misc]
def get_splines(x_eval: jt.Array, x: jt.Array, y: jt.Array) -> Any:
return InterpolatedUnivariateSpline(x, y, k=3)(x_eval)
Expand All @@ -152,7 +148,7 @@ def single_subhalo_potential(
TODO: custom unit specification/subhalo potential specficiation.
Currently supports units kpc, Myr, Msun, rad.
"""
pot_single = Isochrone(m=params["m"], a=params["a"], units=usys)
pot_single = Isochrone(m=params["m"], a=params["a"], units=galactic)
return pot_single.potential_energy(q, t)


Expand Down
8 changes: 5 additions & 3 deletions src/galdynamix/potential/_potential/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import equinox as eqx
import jax.numpy as xp
import jax.typing as jt
from gala.units import UnitSystem, dimensionless

from galdynamix.units import UnitSystem, dimensionless
from galdynamix.utils import ImmutableDict, partial_jit

from .base import AbstractPotentialBase
Expand All @@ -25,7 +25,9 @@ class CompositePotential(ImmutableDict[AbstractPotentialBase], AbstractPotential

_data: dict[str, AbstractPotentialBase]
_: KW_ONLY
units: UnitSystem = eqx.field(static=True)
units: UnitSystem = eqx.field(
static=True, converter=lambda x: dimensionless if x is None else UnitSystem(x)
)
_G: float = eqx.field(init=False, static=True)

def __init__(
Expand All @@ -37,7 +39,7 @@ def __init__(
**kwargs: AbstractPotentialBase,
) -> None:
super().__init__(potentials, **kwargs) # type: ignore[arg-type]
self.units = dimensionless if units is None else UnitSystem(units)
self.units = self.__dataclass_fields__["units"].metadata["converter"](units)
# TODO: check unit systems of contained potentials to make sure they match.

self._init_units()
Expand Down
129 changes: 129 additions & 0 deletions src/galdynamix/units.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""Paired down UnitSystem class from gala.
See gala's license below.
```
The MIT License (MIT)
Copyright (c) 2012-2023 Adrian M. Price-Whelan
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```
"""

from __future__ import annotations

from typing import Any, ClassVar

__all__ = [
"UnitSystem",
"DimensionlessUnitSystem",
"galactic",
"dimensionless",
"solarsystem",
]


import astropy.units as u


class UnitSystem:
"""Represents a system of units."""

_core_units: list[u.UnitBase]
_registry: dict[u.PhysicalType, u.UnitBase]

_required_physical_types: ClassVar[list[u.PhysicalType]] = [
u.get_physical_type("length"),
u.get_physical_type("time"),
u.get_physical_type("mass"),
u.get_physical_type("angle"),
]

def __init__(self, units: UnitSystem | u.UnitBase, *args: u.UnitBase):
if isinstance(units, UnitSystem):
if len(args) > 0:
msg = "If passing in a UnitSystem instance, you cannot pass in additional units."
raise ValueError(msg)

self._registry = units._registry.copy()
self._core_units = units._core_units
return

units = (units, *args)

self._registry = {}
for unit in units:
unit_ = (
unit if isinstance(unit, u.UnitBase) else u.def_unit(f"{unit!s}", unit)
)
if unit_.physical_type in self._registry:
msg = f"Multiple units passed in with type {unit_.physical_type!r}"
raise ValueError(msg)
self._registry[unit_.physical_type] = unit_

self._core_units = []
for phys_type in self._required_physical_types:
if phys_type not in self._registry:
msg = f"You must specify a unit for the physical type {phys_type!r}"
raise ValueError(msg)
self._core_units.append(self._registry[phys_type])

def __getitem__(self, key: str | u.PhysicalType) -> u.UnitBase:
key = u.get_physical_type(key)
return self._registry[key]

def __len__(self) -> int:
return len(self._core_units)

def __iter__(self) -> u.UnitBase:
yield from self._core_units

def __repr__(self) -> str:
return f"UnitSystem({', '.join(str(uu) for uu in self._core_units)})"

def __eq__(self, other: Any) -> bool:
return bool(self._registry == other._registry)

def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)


class DimensionlessUnitSystem(UnitSystem):
_required_physical_types: ClassVar[list[u.PhysicalType]] = []

def __init__(self) -> None:
self._core_units = [u.one]
self._registry = {"dimensionless": u.one}

def __getitem__(self, key: str) -> u.UnitBase:
return u.one

def __str__(self) -> str:
return "UnitSystem(dimensionless)"


# define galactic unit system
galactic = UnitSystem(u.kpc, u.Myr, u.Msun, u.radian, u.km / u.s)

# solar system units
solarsystem = UnitSystem(u.au, u.M_sun, u.yr, u.radian)

# dimensionless
dimensionless = DimensionlessUnitSystem()
8 changes: 2 additions & 6 deletions src/galdynamix/utils/_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,10 @@

__all__ = ["ImmutableDict"]

from collections.abc import Mapping
from collections.abc import ItemsView, Iterator, KeysView, Mapping, ValuesView
from typing import (
ItemsView,
Iterator,
KeysView,
Self,
TypeVar,
ValuesView,
)

from jax.tree_util import register_pytree_node_class
Expand Down Expand Up @@ -78,4 +74,4 @@ def tree_unflatten(
a re-constructed object of the registered type, using the specified
children and auxiliary data.
"""
return cls(tuple(zip(aux_data, children))) # type: ignore[arg-type]
return cls(tuple(zip(aux_data, children, strict=True))) # type: ignore[arg-type]
3 changes: 1 addition & 2 deletions src/galdynamix/utils/_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@

__all__ = ["partial_jit"]

from collections.abc import Iterable, Sequence
from collections.abc import Callable, Iterable, Sequence
from functools import partial
from typing import (
Any,
Callable,
NotRequired,
TypedDict,
TypeVar,
Expand Down
Loading

0 comments on commit 4bfacc1

Please sign in to comment.