Skip to content

Commit

Permalink
Merge pull request #3 from ajshajib/feature/autoformat
Browse files Browse the repository at this point in the history
Automated code and docstring formatting
  • Loading branch information
sibirrer committed Mar 15, 2024
2 parents 54f253a + 8eaf2b6 commit 87a86af
Show file tree
Hide file tree
Showing 12 changed files with 149 additions and 34 deletions.
45 changes: 36 additions & 9 deletions .github/workflows/ci_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,21 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.8]
python-version: [3.11]

steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
echo "NUMBA_DISABLE_JIT=false" >> $GITHUB_ENV # disable numba.jit as it causes issues with JAX
python -m pip install --upgrade pip
python -m pip install --upgrade pytest
python -m pip install flake8 pytest pytest-cov coveralls
python -m pip install flake8 pytest pytest-cov
python -m pip install codecov
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
if [ -f test_requirements.txt ]; then pip install -r test_requirements.txt; fi
python -m pip install .
Expand All @@ -40,10 +41,36 @@ jobs:
# flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
- name: Test with pytest
run: |
pytest --cov=jaxtronomy
pytest --cov=./ --cov-report=xml
codecov
- name: Coveralls
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v3
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
coveralls --service=github
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
with:
files: ./coverage.xml

# - name: Coveralls
# env:
# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# run: |
# coveralls --service=github

# from this source for coveralls: https://github.com/marketplace/actions/coveralls-github-action
#- uses: actions/checkout@v1

#- name: Use Node.js 10.x
# uses: actions/setup-node@v1
# with:
# node-version: 10.x

#- name: npm install, make test-coverage
# run: |
# npm install
# make test-coverage

#- name: Coveralls
# uses: coverallsapp/github-action@master
# with:
# github-token: ${{ secrets.GITHUB_TOKEN }}
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -128,5 +128,5 @@ dmypy.json
# Pyre type checker
.pyre/

# macOS cache files
.idea/*
**/.DS_Store
33 changes: 33 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
ci:
autofix_commit_msg: |
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
autofix_prs: true
autoupdate_branch: ''
autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate'
autoupdate_schedule: weekly
skip: []
submodules: false

repos:
- repo: https://github.com/psf/black
rev: 24.2.0
hooks:
- id: black
# It is recommended to specify the latest version of Python
# supported by your project here, or alternatively use
# pre-commit's default_language_version, see
# https://pre-commit.com/#top_level-default_language_version
language_version: python3
- repo: https://github.com/psf/black
rev: 24.2.0
hooks:
- id: black-jupyter
language_version: python3
- repo: https://github.com/PyCQA/docformatter
rev: v1.7.5
hooks:
- id: docformatter
additional_dependencies: [tomli]
args: [-r, --black, --in-place]
12 changes: 12 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,18 @@ JAXtronomy
.. image:: https://img.shields.io/badge/License-BSD_3--Clause-blue.svg
:target: https://github.com/lenstronomy/lenstronomy/blob/main/LICENSE

.. image:: https://codecov.io/gh/lenstronomy/JAXtronomy/graph/badge.svg?token=6EJAX8CF62
:target: https://codecov.io/gh/lenstronomy/JAXtronomy

.. image:: https://img.shields.io/badge/code%20style-black-000000.svg
:target: https://github.com/psf/black

.. image:: https://img.shields.io/badge/%20formatter-docformatter-fedcba.svg
:target: https://github.com/PyCQA/docformatter

.. image:: https://img.shields.io/badge/%20style-sphinx-0a507a.svg
:target: https://www.sphinx-doc.org/en/master/usage/index.html



**JAX port of lenstronomy, for parallelized, GPU accelerated, and differentiable gravitational lensing and image simulations.**
Expand Down
10 changes: 10 additions & 0 deletions codecov.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
comment: # this is a top-level key
layout: " diff, flags, files"
behavior: default
require_changes: false # if true: only post the comment if coverage changes
require_base: false # [true :: must have a base report to post]
require_head: true # [true :: must have a head report to post]
ignore:
- "setup.py"
- "test/*"
- "test/**/*.py"
1 change: 0 additions & 1 deletion jaxtronomy/LensModel/Profiles/p_jaffe.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,4 +354,3 @@ def _sort_ra_rs(Ra, Rs):
Ra = jnp.where(Rs < Ra, Rs, Ra)
Rs = jnp.where(Rs < Ra, Ra, Rs)
return Ra, Rs

4 changes: 2 additions & 2 deletions jaxtronomy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__author__ = 'lenstronomy developers'
__version__ = '0.0.1rc1'
__author__ = "lenstronomy developers"
__version__ = "0.0.1rc1"
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
license="BSD-3",
install_requires=requires,
tests_require=tests_require,
keywords='lenstronomy',
keywords="lenstronomy",
classifiers=[
"Development Status :: 1 - Alpha",
"Intended Audience :: Science/Research",
Expand All @@ -29,4 +29,4 @@
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
],
)
)
32 changes: 23 additions & 9 deletions test/test_LensModel/test_Profiles/test_p_jaffe.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import numpy.testing as npt
import pytest
import jax
jax.config.update('jax_enable_x64', True) # 64-bit floats, consistent with numpy

jax.config.update("jax_enable_x64", True) # 64-bit floats, consistent with numpy
import jax.numpy as jnp


Expand Down Expand Up @@ -71,7 +72,9 @@ def test_hessian(self):
sigma0 = 1.0
Ra, Rs = 0.5, 0.8
f_xx, f_xy, f_yx, f_yy = self.profile.hessian(x, y, sigma0, Ra, Rs)
f_xx_ref, f_xy_ref, f_yx_ref, f_yy_ref = self.profile_ref.hessian(x, y, sigma0, Ra, Rs)
f_xx_ref, f_xy_ref, f_yx_ref, f_yy_ref = self.profile_ref.hessian(
x, y, sigma0, Ra, Rs
)
npt.assert_almost_equal(f_xx, f_xx_ref, decimal=8)
npt.assert_almost_equal(f_xy, f_xy_ref, decimal=8)
npt.assert_almost_equal(f_yy, f_yy_ref, decimal=8)
Expand All @@ -80,7 +83,9 @@ def test_hessian(self):
x = np.array([1, 3, 4])
y = np.array([2, 1, 1])
f_xx, f_xy, f_yx, f_yy = self.profile.hessian(x, y, sigma0, Ra, Rs)
f_xx_ref, f_xy_ref, f_yx_ref, f_yy_ref = self.profile_ref.hessian(x, y, sigma0, Ra, Rs)
f_xx_ref, f_xy_ref, f_yx_ref, f_yy_ref = self.profile_ref.hessian(
x, y, sigma0, Ra, Rs
)
npt.assert_array_almost_equal(f_xx, f_xx_ref, decimal=8)
npt.assert_array_almost_equal(f_xy, f_xy_ref, decimal=8)
npt.assert_array_almost_equal(f_yy, f_yy_ref, decimal=8)
Expand Down Expand Up @@ -143,16 +148,25 @@ def test_jax_jit(self):
sigma0 = 1.0
Ra, Rs = 0.5, 0.8
jitted = jax.jit(self.profile.function)
npt.assert_almost_equal(self.profile.function(x, y, sigma0, Ra, Rs),
jitted(x, y, sigma0, Ra, Rs), decimal=8)
npt.assert_almost_equal(
self.profile.function(x, y, sigma0, Ra, Rs),
jitted(x, y, sigma0, Ra, Rs),
decimal=8,
)

jitted = jax.jit(self.profile.derivatives)
npt.assert_array_almost_equal(self.profile.derivatives(x, y, sigma0, Ra, Rs),
jitted(x, y, sigma0, Ra, Rs), decimal=8)
npt.assert_array_almost_equal(
self.profile.derivatives(x, y, sigma0, Ra, Rs),
jitted(x, y, sigma0, Ra, Rs),
decimal=8,
)

jitted = jax.jit(self.profile.hessian)
npt.assert_array_almost_equal(self.profile.hessian(x, y, sigma0, Ra, Rs),
jitted(x, y, sigma0, Ra, Rs), decimal=8)
npt.assert_array_almost_equal(
self.profile.hessian(x, y, sigma0, Ra, Rs),
jitted(x, y, sigma0, Ra, Rs),
decimal=8,
)


if __name__ == "__main__":
Expand Down
32 changes: 23 additions & 9 deletions test/test_LensModel/test_Profiles/test_p_jaffe_ellipse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@


from jaxtronomy.LensModel.Profiles.p_jaffe_ellipse import PJaffe_Ellipse
from lenstronomy.LensModel.Profiles.p_jaffe_ellipse import PJaffe_Ellipse as PJaffe_Ellipse_ref
from lenstronomy.LensModel.Profiles.p_jaffe_ellipse import (
PJaffe_Ellipse as PJaffe_Ellipse_ref,
)
import jaxtronomy.Util.param_util as param_util

import numpy as np
import numpy.testing as npt
import pytest
import jax
jax.config.update('jax_enable_x64', True) # 64-bit floats, consistent with numpy

jax.config.update("jax_enable_x64", True) # 64-bit floats, consistent with numpy
import jax.numpy as jnp


Expand Down Expand Up @@ -126,7 +129,9 @@ def test_hessian(self):

def test_mass_3d_lens(self):
mass = self.profile.mass_3d_lens(r=1, sigma0=1, Ra=0.5, Rs=0.8, e1=0, e2=0)
mass_ref = self.profile_ref.mass_3d_lens(r=1, sigma0=1, Ra=0.5, Rs=0.8, e1=0, e2=0)
mass_ref = self.profile_ref.mass_3d_lens(
r=1, sigma0=1, Ra=0.5, Rs=0.8, e1=0, e2=0
)
npt.assert_almost_equal(mass, mass_ref, decimal=8)

def test_jax_jit(self):
Expand All @@ -137,16 +142,25 @@ def test_jax_jit(self):
q, phi_G = 0.8, 0.1
e1, e2 = param_util.phi_q2_ellipticity(phi_G, q)
jitted = jax.jit(self.profile.function)
npt.assert_almost_equal(self.profile.function(x, y, sigma0, Ra, Rs, e1, e2),
jitted(x, y, sigma0, Ra, Rs, e1, e2), decimal=8)
npt.assert_almost_equal(
self.profile.function(x, y, sigma0, Ra, Rs, e1, e2),
jitted(x, y, sigma0, Ra, Rs, e1, e2),
decimal=8,
)

jitted = jax.jit(self.profile.derivatives)
npt.assert_array_almost_equal(self.profile.derivatives(x, y, sigma0, Ra, Rs, e1, e2),
jitted(x, y, sigma0, Ra, Rs, e1, e2), decimal=8)
npt.assert_array_almost_equal(
self.profile.derivatives(x, y, sigma0, Ra, Rs, e1, e2),
jitted(x, y, sigma0, Ra, Rs, e1, e2),
decimal=8,
)

jitted = jax.jit(self.profile.hessian)
npt.assert_array_almost_equal(self.profile.hessian(x, y, sigma0, Ra, Rs, e1, e2),
jitted(x, y, sigma0, Ra, Rs, e1, e2), decimal=8)
npt.assert_array_almost_equal(
self.profile.hessian(x, y, sigma0, Ra, Rs, e1, e2),
jitted(x, y, sigma0, Ra, Rs, e1, e2),
decimal=8,
)


if __name__ == "__main__":
Expand Down
3 changes: 2 additions & 1 deletion test/test_Util/test_param_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import jaxtronomy.Util.param_util as param_util
import lenstronomy.Util.param_util as param_util_ref
import jax
jax.config.update('jax_enable_x64', True) # 64-bit floats, consistent with numpy

jax.config.update("jax_enable_x64", True) # 64-bit floats, consistent with numpy
import jax.numpy as jnp


Expand Down
5 changes: 5 additions & 0 deletions test/test_lenstronomy_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
def test_lenstronomy_version():
"""Tests the import of lenstronomy."""
import lenstronomy

assert lenstronomy.__version__ == "1.11.7"

0 comments on commit 87a86af

Please sign in to comment.