Skip to content

Commit

Permalink
Compat fix for older numpy versions (#517)
Browse files Browse the repository at this point in the history
* Compat fix for older numpy versions

* Adding test environment for old numpy versions

* Fixing lint

---------

Co-authored-by: Dan Foreman-Mackey <dfm@dfm.io>
  • Loading branch information
MuellerSeb and dfm committed Apr 19, 2024
1 parent 9433c91 commit 9c5f59b
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
9 changes: 7 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,17 @@ jobs:
COVERALLS_PARALLEL: true
COVERALLS_FLAG_NAME: ${{ matrix.python-version }}-${{ matrix.os }}

leading_edge:
numpy_edge:
runs-on: ${{ matrix.os }}
strategy:
matrix:
python-version: ["3.12"]
numpy-version: ["numpy>=2.0.0rc1"]
os: ["ubuntu-latest"]
include:
- python-version: "3.9"
numpy-version: "numpy<1.25"
os: "ubuntu-latest"

steps:
- name: Checkout
Expand All @@ -68,7 +73,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install -U pip
python -m pip install pip install pytest "numpy>=2.0.0rc1"
python -m pip install pip install pytest "${{ matrix.numpy-version }}"
python -m pip install -e.
- name: Run tests
run: pytest
Expand Down
9 changes: 8 additions & 1 deletion src/emcee/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@
# for py2.7, will be an Exception in 3.8
from collections import Iterable

try:
# Try to import from numpy.exceptions (available in NumPy 1.25 and later)
from numpy.exceptions import VisibleDeprecationWarning
except ImportError:
# Fallback to the top-level numpy import (for older versions)
from numpy import VisibleDeprecationWarning


class EnsembleSampler(object):
"""An ensemble MCMC sampler
Expand Down Expand Up @@ -511,7 +518,7 @@ def compute_log_prob(self, coords):
try:
with warnings.catch_warnings(record=True):
warnings.simplefilter(
"error", np.exceptions.VisibleDeprecationWarning
"error", VisibleDeprecationWarning
)
try:
dt = np.atleast_1d(blob[0]).dtype
Expand Down

0 comments on commit 9c5f59b

Please sign in to comment.