Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automated code and docstring formatting #3

Merged
merged 8 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 36 additions & 9 deletions .github/workflows/ci_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,21 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8]
python-version: [3.11]

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
echo "NUMBA_DISABLE_JIT=false" >> $GITHUB_ENV # disable numba.jit as it causes issues with JAX
python -m pip install --upgrade pip
python -m pip install --upgrade pytest
python -m pip install flake8 pytest pytest-cov coveralls
python -m pip install flake8 pytest pytest-cov
python -m pip install codecov
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
if [ -f test_requirements.txt ]; then pip install -r test_requirements.txt; fi
python -m pip install .
Expand All @@ -40,10 +41,36 @@ jobs:
# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest --cov=jaxtronomy
pytest --cov=./ --cov-report=xml
codecov

- name: Coveralls
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v3
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
coveralls --service=github
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
with:
files: ./coverage.xml

# - name: Coveralls
# env:
# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# run: |
# coveralls --service=github

# from this source for coveralls: https://github.com/marketplace/actions/coveralls-github-action
#- uses: actions/checkout@v1

#- name: Use Node.js 10.x
# uses: actions/setup-node@v1
# with:
# node-version: 10.x

#- name: npm install, make test-coverage
# run: |
# npm install
# make test-coverage

#- name: Coveralls
# uses: coverallsapp/github-action@master
# with:
# github-token: ${{ secrets.GITHUB_TOKEN }}
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -128,5 +128,5 @@ dmypy.json
# Pyre type checker
.pyre/

# macOS cache files
.idea/*
**/.DS_Store
33 changes: 33 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
ci:
autofix_commit_msg: |
[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
autofix_prs: true
autoupdate_branch: ''
autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate'
autoupdate_schedule: weekly
skip: []
submodules: false

repos:
- repo: https://github.com/psf/black
rev: 24.2.0
hooks:
- id: black
# It is recommended to specify the latest version of Python
# supported by your project here, or alternatively use
# pre-commit's default_language_version, see
# https://pre-commit.com/#top_level-default_language_version
language_version: python3
- repo: https://github.com/psf/black
rev: 24.2.0
hooks:
- id: black-jupyter
language_version: python3
- repo: https://github.com/PyCQA/docformatter
rev: v1.7.5
hooks:
- id: docformatter
additional_dependencies: [tomli]
args: [-r, --black, --in-place]
12 changes: 12 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,18 @@ JAXtronomy
.. image:: https://img.shields.io/badge/License-BSD_3--Clause-blue.svg
:target: https://github.com/lenstronomy/lenstronomy/blob/main/LICENSE

.. image:: https://codecov.io/gh/lenstronomy/JAXtronomy/graph/badge.svg?token=6EJAX8CF62
:target: https://codecov.io/gh/lenstronomy/JAXtronomy

.. image:: https://img.shields.io/badge/code%20style-black-000000.svg
:target: https://github.com/psf/black

.. image:: https://img.shields.io/badge/%20formatter-docformatter-fedcba.svg
:target: https://github.com/PyCQA/docformatter

.. image:: https://img.shields.io/badge/%20style-sphinx-0a507a.svg
:target: https://www.sphinx-doc.org/en/master/usage/index.html



**JAX port of lenstronomy, for parallelized, GPU accelerated, and differentiable gravitational lensing and image simulations.**
Expand Down
10 changes: 10 additions & 0 deletions codecov.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
comment: # this is a top-level key
layout: " diff, flags, files"
behavior: default
require_changes: false # if true: only post the comment if coverage changes
require_base: false # [true :: must have a base report to post]
require_head: true # [true :: must have a head report to post]
ignore:
- "setup.py"
- "test/*"
- "test/**/*.py"
1 change: 0 additions & 1 deletion jaxtronomy/LensModel/Profiles/p_jaffe.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,4 +354,3 @@ def _sort_ra_rs(Ra, Rs):
Ra = jnp.where(Rs < Ra, Rs, Ra)
Rs = jnp.where(Rs < Ra, Ra, Rs)
return Ra, Rs

4 changes: 2 additions & 2 deletions jaxtronomy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__author__ = 'lenstronomy developers'
__version__ = '0.0.1rc1'
__author__ = "lenstronomy developers"
__version__ = "0.0.1rc1"
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
license="BSD-3",
install_requires=requires,
tests_require=tests_require,
keywords='lenstronomy',
keywords="lenstronomy",
classifiers=[
"Development Status :: 1 - Alpha",
"Intended Audience :: Science/Research",
Expand All @@ -29,4 +29,4 @@
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
],
)
)
32 changes: 23 additions & 9 deletions test/test_LensModel/test_Profiles/test_p_jaffe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import numpy.testing as npt
import pytest
import jax
jax.config.update('jax_enable_x64', True) # 64-bit floats, consistent with numpy

jax.config.update("jax_enable_x64", True) # 64-bit floats, consistent with numpy
import jax.numpy as jnp


Expand Down Expand Up @@ -71,7 +72,9 @@ def test_hessian(self):
sigma0 = 1.0
Ra, Rs = 0.5, 0.8
f_xx, f_xy, f_yx, f_yy = self.profile.hessian(x, y, sigma0, Ra, Rs)
f_xx_ref, f_xy_ref, f_yx_ref, f_yy_ref = self.profile_ref.hessian(x, y, sigma0, Ra, Rs)
f_xx_ref, f_xy_ref, f_yx_ref, f_yy_ref = self.profile_ref.hessian(
x, y, sigma0, Ra, Rs
)
npt.assert_almost_equal(f_xx, f_xx_ref, decimal=8)
npt.assert_almost_equal(f_xy, f_xy_ref, decimal=8)
npt.assert_almost_equal(f_yy, f_yy_ref, decimal=8)
Expand All @@ -80,7 +83,9 @@ def test_hessian(self):
x = np.array([1, 3, 4])
y = np.array([2, 1, 1])
f_xx, f_xy, f_yx, f_yy = self.profile.hessian(x, y, sigma0, Ra, Rs)
f_xx_ref, f_xy_ref, f_yx_ref, f_yy_ref = self.profile_ref.hessian(x, y, sigma0, Ra, Rs)
f_xx_ref, f_xy_ref, f_yx_ref, f_yy_ref = self.profile_ref.hessian(
x, y, sigma0, Ra, Rs
)
npt.assert_array_almost_equal(f_xx, f_xx_ref, decimal=8)
npt.assert_array_almost_equal(f_xy, f_xy_ref, decimal=8)
npt.assert_array_almost_equal(f_yy, f_yy_ref, decimal=8)
Expand Down Expand Up @@ -143,16 +148,25 @@ def test_jax_jit(self):
sigma0 = 1.0
Ra, Rs = 0.5, 0.8
jitted = jax.jit(self.profile.function)
npt.assert_almost_equal(self.profile.function(x, y, sigma0, Ra, Rs),
jitted(x, y, sigma0, Ra, Rs), decimal=8)
npt.assert_almost_equal(
self.profile.function(x, y, sigma0, Ra, Rs),
jitted(x, y, sigma0, Ra, Rs),
decimal=8,
)

jitted = jax.jit(self.profile.derivatives)
npt.assert_array_almost_equal(self.profile.derivatives(x, y, sigma0, Ra, Rs),
jitted(x, y, sigma0, Ra, Rs), decimal=8)
npt.assert_array_almost_equal(
self.profile.derivatives(x, y, sigma0, Ra, Rs),
jitted(x, y, sigma0, Ra, Rs),
decimal=8,
)

jitted = jax.jit(self.profile.hessian)
npt.assert_array_almost_equal(self.profile.hessian(x, y, sigma0, Ra, Rs),
jitted(x, y, sigma0, Ra, Rs), decimal=8)
npt.assert_array_almost_equal(
self.profile.hessian(x, y, sigma0, Ra, Rs),
jitted(x, y, sigma0, Ra, Rs),
decimal=8,
)


if __name__ == "__main__":
Expand Down
32 changes: 23 additions & 9 deletions test/test_LensModel/test_Profiles/test_p_jaffe_ellipse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@


from jaxtronomy.LensModel.Profiles.p_jaffe_ellipse import PJaffe_Ellipse
from lenstronomy.LensModel.Profiles.p_jaffe_ellipse import PJaffe_Ellipse as PJaffe_Ellipse_ref
from lenstronomy.LensModel.Profiles.p_jaffe_ellipse import (
PJaffe_Ellipse as PJaffe_Ellipse_ref,
)
import jaxtronomy.Util.param_util as param_util

import numpy as np
import numpy.testing as npt
import pytest
import jax
jax.config.update('jax_enable_x64', True) # 64-bit floats, consistent with numpy

jax.config.update("jax_enable_x64", True) # 64-bit floats, consistent with numpy
import jax.numpy as jnp


Expand Down Expand Up @@ -126,7 +129,9 @@ def test_hessian(self):

def test_mass_3d_lens(self):
mass = self.profile.mass_3d_lens(r=1, sigma0=1, Ra=0.5, Rs=0.8, e1=0, e2=0)
mass_ref = self.profile_ref.mass_3d_lens(r=1, sigma0=1, Ra=0.5, Rs=0.8, e1=0, e2=0)
mass_ref = self.profile_ref.mass_3d_lens(
r=1, sigma0=1, Ra=0.5, Rs=0.8, e1=0, e2=0
)
npt.assert_almost_equal(mass, mass_ref, decimal=8)

def test_jax_jit(self):
Expand All @@ -137,16 +142,25 @@ def test_jax_jit(self):
q, phi_G = 0.8, 0.1
e1, e2 = param_util.phi_q2_ellipticity(phi_G, q)
jitted = jax.jit(self.profile.function)
npt.assert_almost_equal(self.profile.function(x, y, sigma0, Ra, Rs, e1, e2),
jitted(x, y, sigma0, Ra, Rs, e1, e2), decimal=8)
npt.assert_almost_equal(
self.profile.function(x, y, sigma0, Ra, Rs, e1, e2),
jitted(x, y, sigma0, Ra, Rs, e1, e2),
decimal=8,
)

jitted = jax.jit(self.profile.derivatives)
npt.assert_array_almost_equal(self.profile.derivatives(x, y, sigma0, Ra, Rs, e1, e2),
jitted(x, y, sigma0, Ra, Rs, e1, e2), decimal=8)
npt.assert_array_almost_equal(
self.profile.derivatives(x, y, sigma0, Ra, Rs, e1, e2),
jitted(x, y, sigma0, Ra, Rs, e1, e2),
decimal=8,
)

jitted = jax.jit(self.profile.hessian)
npt.assert_array_almost_equal(self.profile.hessian(x, y, sigma0, Ra, Rs, e1, e2),
jitted(x, y, sigma0, Ra, Rs, e1, e2), decimal=8)
npt.assert_array_almost_equal(
self.profile.hessian(x, y, sigma0, Ra, Rs, e1, e2),
jitted(x, y, sigma0, Ra, Rs, e1, e2),
decimal=8,
)


if __name__ == "__main__":
Expand Down
3 changes: 2 additions & 1 deletion test/test_Util/test_param_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import jaxtronomy.Util.param_util as param_util
import lenstronomy.Util.param_util as param_util_ref
import jax
jax.config.update('jax_enable_x64', True) # 64-bit floats, consistent with numpy

jax.config.update("jax_enable_x64", True) # 64-bit floats, consistent with numpy
import jax.numpy as jnp


Expand Down
5 changes: 5 additions & 0 deletions test/test_lenstronomy_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
def test_lenstronomy_version():
"""Tests the import of lenstronomy."""
import lenstronomy

assert lenstronomy.__version__ == "1.11.7"
Loading