Skip to content
This repository
Browse code

Support numba.addressof() (issue #141)

  • Loading branch information...
commit 59400c3634b2c15fc8e5c934faeb2315ca458904 1 parent 45d5a5b
Mark Florisson markflorisson88 authored
3  numba/pipeline.py
@@ -545,6 +545,9 @@ def transform(self, ast, env):
545 545 func_env.numba_wrapper_func = numbawrapper
546 546 func_env.llvm_wrapper_func = lfuncwrapper
547 547
  548 + # Set pointer to function for external code and numba.addressof()
  549 + numbawrapper.lfunc_pointer = func_env.lfunc_pointer
  550 +
548 551 return ast
549 552
550 553 class ErrorReporting(PipelineStage):
49 numba/special.py
... ... @@ -1,16 +1,56 @@
1 1 # -*- coding: utf-8 -*-
  2 +
2 3 """
3 4 Special compiler-recognized numba functions and attributes.
4 5 """
  6 +
5 7 from __future__ import print_function, division, absolute_import
6 8
7   -__all__ = ['NULL', 'typeof', 'python', 'nopython']
  9 +__all__ = ['NULL', 'typeof', 'python', 'nopython', 'addressof']
  10 +
  11 +import ctypes
  12 +
  13 +from numba import error
  14 +
  15 +#------------------------------------------------------------------------
  16 +# Pointers
  17 +#------------------------------------------------------------------------
8 18
9 19 class NumbaDotNULL(object):
10 20 "NULL pointer"
11 21
12 22 NULL = NumbaDotNULL()
13 23
  24 +
  25 +def addressof(obj, propagate=True):
  26 + """
  27 + Take the address of a compiled jit function.
  28 +
  29 + :param obj: the jit function
  30 + :param write_unraisable: whether to write uncaught exceptions to stderr
  31 + :param propagate: whether to always propagate exceptions
  32 +
  33 + :return: ctypes function pointer
  34 + """
  35 + from numba import numbawrapper
  36 +
  37 + if not propagate:
  38 + raise ValueError("Writing unraisable exception is not yet supported")
  39 +
  40 + if not isinstance(obj, numbawrapper.numbafunction_type):
  41 + raise TypeError("Object is not a jit function")
  42 +
  43 + if obj.lfunc_pointer is None:
  44 + raise ValueError(
  45 + "Jit function does not have pointer")
  46 +
  47 + ctypes_sig = obj.signature.to_ctypes()
  48 + return ctypes.cast(obj.lfunc_pointer, ctypes_sig)
  49 +
  50 +#------------------------------------------------------------------------
  51 +# Types
  52 +#------------------------------------------------------------------------
  53 +
14 54 def typeof(variable):
15 55 """
16 56 Get the type of a variable.
@@ -21,6 +61,10 @@ def typeof(variable):
21 61 context = NumbaEnvironment.get_environment().context
22 62 return context.typemapper.from_python(variable)
23 63
  64 +#------------------------------------------------------------------------
  65 +# python/nopython context managers
  66 +#------------------------------------------------------------------------
  67 +
24 68 class NoopContext(object):
25 69
26 70 def __init__(self, name):
@@ -36,4 +80,5 @@ def __repr__(self):
36 80 return self.name
37 81
38 82 python = NoopContext("python")
39   -nopython = NoopContext("nopython")
  83 +nopython = NoopContext("nopython")
  84 +
52 numba/tests/test_addressof.py
... ... @@ -0,0 +1,52 @@
  1 +# -*- coding: utf-8 -*-
  2 +
  3 +"""
  4 +Test numba.addressof().
  5 +"""
  6 +
  7 +from __future__ import print_function, division, absolute_import
  8 +
  9 +import ctypes
  10 +
  11 +import numba
  12 +from numba import *
  13 +
  14 +
  15 +@jit(int32(int32, int32))
  16 +def func(a, b):
  17 + return a * b
  18 +
  19 +@autojit
  20 +def error_func():
  21 + pass
  22 +
  23 +#------------------------------------------------------------------------
  24 +# Tests
  25 +#------------------------------------------------------------------------
  26 +
  27 +def test_addressof(arg):
  28 + """
  29 + >>> func = test_addressof(func)
  30 + >>> assert func.restype == ctypes.c_int32
  31 + >>> assert func.argtypes == (ctypes.c_int32, ctypes.c_int32)
  32 + >>> func(5, 2)
  33 + 10
  34 + """
  35 + return numba.addressof(arg)
  36 +
  37 +def test_addressof_error(arg, **kwds):
  38 + """
  39 + >>> test_addressof_error(error_func)
  40 + Traceback (most recent call last):
  41 + ...
  42 + TypeError: Object is not a jit function
  43 +
  44 + >>> test_addressof_error(func, propagate=False)
  45 + Traceback (most recent call last):
  46 + ...
  47 + ValueError: Writing unraisable exception is not yet supported
  48 + """
  49 + return numba.addressof(arg, **kwds)
  50 +
  51 +
  52 +numba.testmod()

0 comments on commit 59400c3

Please sign in to comment.
Something went wrong with that request. Please try again.