Skip to content

Commit

Permalink
Fix rendering for numpy scalars
Browse files Browse the repository at this point in the history
In numpy 2.0, `repr(np.float64(0.5))` is `"np.float64(0.5)"`. Before, it was `"0.5"` (as for a Python `float`).
  • Loading branch information
mstimberg committed Apr 17, 2024
1 parent 193983d commit 491b9b7
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
6 changes: 5 additions & 1 deletion brian2/parsing/rendering.py
@@ -1,6 +1,7 @@
import ast
import numbers

import numpy as np
import sympy

from brian2.core.functions import DEFAULT_CONSTANTS, DEFAULT_FUNCTIONS
Expand Down Expand Up @@ -87,6 +88,9 @@ def render_Name(self, node):
return node.id

def render_Constant(self, node):
if isinstance(node.value, np.number):
# repr prints the dtype in numpy 2.0
return repr(node.value.item())
return repr(node.value)

def render_Call(self, node):
Expand Down Expand Up @@ -344,7 +348,7 @@ def render_Constant(self, node):
elif node.value is False:
return "false"
else:
return repr(node.value)
return super().render_Constant(node)

def render_Name(self, node):
if node.id == "inf":
Expand Down
21 changes: 21 additions & 0 deletions brian2/tests/test_codegen.py
Expand Up @@ -10,6 +10,7 @@
from brian2 import _cache_dirs_and_extensions, clear_cache, prefs
from brian2.codegen.codeobject import CodeObject
from brian2.codegen.cpp_prefs import compiler_supports_c99, get_compiler_and_args
from brian2.codegen.generators.cython_generator import CythonNodeRenderer
from brian2.codegen.optimisation import optimise_statements
from brian2.codegen.runtime.cython_rt import CythonCodeObject
from brian2.codegen.statements import Statement
Expand All @@ -22,6 +23,7 @@
from brian2.core.functions import DEFAULT_CONSTANTS, DEFAULT_FUNCTIONS, Function
from brian2.core.variables import ArrayVariable, Constant, Subexpression, Variable
from brian2.devices.device import auto_target, device
from brian2.parsing.rendering import CPPNodeRenderer, NodeRenderer, NumpyNodeRenderer
from brian2.parsing.sympytools import str_to_sympy, sympy_to_str
from brian2.units import ms, second
from brian2.units.fundamentalunits import Unit
Expand Down Expand Up @@ -618,6 +620,25 @@ def test_msvc_flags():
assert len(previously_stored_flags[hostname])


@pytest.mark.codegen_independent
@pytest.mark.parametrize(
"renderer",
[
NodeRenderer(),
NumpyNodeRenderer(),
CythonNodeRenderer(),
CPPNodeRenderer(),
],
)
def test_number_rendering(renderer):
import ast

for number in [0.5, np.float32(0.5), np.float64(0.5)]:
# In numpy 2.0, repr(np.float64(0.5)) is 'np.float64(0.5)'
node = ast.Constant(value=number)
assert renderer.render_node(node) == "0.5"


if __name__ == "__main__":
test_auto_target()
test_analyse_identifiers()
Expand Down

0 comments on commit 491b9b7

Please sign in to comment.