Skip to content

Commit

Permalink
Added Support for Numpy BitGenerators
Browse files Browse the repository at this point in the history
  • Loading branch information
kc611 committed May 6, 2022
1 parent 0f5953d commit 3356b88
Show file tree
Hide file tree
Showing 10 changed files with 2,796 additions and 0 deletions.
96 changes: 96 additions & 0 deletions numba/core/boxing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1105,3 +1105,99 @@ def unbox_typeref(typ, val, c):
@box(types.LiteralStrKeyDict)
def box_LiteralStrKeyDict(typ, val, c):
return box_unsupported(typ, val, c)


@unbox(types.NumPyRandomBitGeneratorType)
def unbox_numpy_random_bitgenerator(typ, obj, c):
# The bit_generator instance has a `.ctypes` attr which is a namedtuple
# with the following members (types):
# * state_address (Python int)
# * state (ctypes.c_void_p)
# * next_uint64 (ctypes.CFunctionType instance)
# * next_uint32 (ctypes.CFunctionType instance)
# * next_double (ctypes.CFunctionType instance)
# * bit_generator (ctypes.c_void_p)

struct_ptr = cgutils.create_struct_proxy(typ)(c.context, c.builder)
struct_ptr.parent = obj
# c.pyapi.incref(obj) # ? need to hold ref to the underlying python obj

# Get the .ctypes attr
ctypes_binding = c.pyapi.object_getattr_string(obj, 'ctypes')

# Look up the "state_address" member and wire it into the struct
interface_state_address = c.pyapi.object_getattr_string(
ctypes_binding, 'state_address')
setattr(struct_ptr, 'state_address',
c.unbox(types.uintp, interface_state_address).value)

# Look up the "state" member and wire it into the struct
interface_state = c.pyapi.object_getattr_string(ctypes_binding, 'state')
interface_state_value = c.pyapi.object_getattr_string(
interface_state, 'value')
setattr(
struct_ptr,
'state',
c.unbox(
types.uintp,
interface_state_value).value)

# Want to store callable function pointers to these CFunctionTypes, so
# import ctypes and use it to cast the CFunctionTypes to c_void_p and
# store the results.
# First find ctypes.cast, and ctypes.c_void_p
ctypes_name = c.context.insert_const_string(c.builder.module, 'ctypes')
ctypes_module = c.pyapi.import_module_noblock(ctypes_name)
ct_cast = c.pyapi.object_getattr_string(ctypes_module, 'cast')
ct_voidptr_ty = c.pyapi.object_getattr_string(ctypes_module, 'c_void_p')

# This wires in the fnptrs refered to by name
def wire_in_fnptrs(name):
# Find the CFunctionType function
interface_next_fn = c.pyapi.object_getattr_string(
ctypes_binding, name)

# Want to do ctypes.cast(CFunctionType, ctypes.c_void_p), create an
# args tuple for that.
args = c.pyapi.tuple_pack([interface_next_fn, ct_voidptr_ty])

# Call ctypes.cast()
interface_next_fn_casted = c.pyapi.call(ct_cast, args)

# Fetch the .value attr on the resulting ctypes.c_void_p for storage
# in the function pointer slot.
interface_next_fn_casted_value = c.pyapi.object_getattr_string(
interface_next_fn_casted, 'value')

# Wire up
setattr(struct_ptr, f'fnptr_{name}',
c.unbox(types.uintp, interface_next_fn_casted_value).value)

wire_in_fnptrs('next_double')
wire_in_fnptrs('next_uint64')
wire_in_fnptrs('next_uint32')

# This is the same as the `state` member but its the bit_generator address,
# it's probably never needed.
interface_bit_generator = c.pyapi.object_getattr_string(
ctypes_binding, 'bit_generator')
interface_bit_generator_value = c.pyapi.object_getattr_string(
interface_bit_generator, 'value')
setattr(
struct_ptr,
'bit_generator',
c.unbox(
types.uintp,
interface_bit_generator_value).value)

return NativeValue(struct_ptr._getvalue())

_bit_gen_type = types.NumPyRandomBitGeneratorType('bit_generator')

@unbox(types.NumPyRandomGeneratorType)
def unbox_numpy_random_generator(typ, obj, c):
struct_ptr = cgutils.create_struct_proxy(typ)(c.context, c.builder)
bit_gen_inst = c.pyapi.object_getattr_string(obj, 'bit_generator')
unboxed = c.unbox(_bit_gen_type, bit_gen_inst).value
struct_ptr.bit_generator = unboxed
return NativeValue(struct_ptr._getvalue())
1 change: 1 addition & 0 deletions numba/core/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def load_additional_registries(self):
from numba.core import optional
from numba.misc import gdb_hook, literal
from numba.np import linalg, polynomial, arraymath, arrayobj
from numba.np.random import generator_core, generator_methods
from numba.typed import typeddict, dictimpl
from numba.typed import typedlist, listobject
from numba.experimental import jitclass, function_type
Expand Down
11 changes: 11 additions & 0 deletions numba/core/types/npytypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,3 +587,14 @@ def strides(self):
@property
def key(self):
return self.dtype, self.shape

class NumPyRandomBitGeneratorType(Type):
def __init__(self, *args, **kwargs):
super(NumPyRandomBitGeneratorType, self).__init__(*args, **kwargs)
self.name = 'NumPyRandomBitGeneratorType'


class NumPyRandomGeneratorType(Type):
def __init__(self, *args, **kwargs):
super(NumPyRandomGeneratorType, self).__init__(*args, **kwargs)
self.name = 'NumPyRandomGeneratorType'
15 changes: 15 additions & 0 deletions numba/core/typing/typeof.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
# terminal color markup
_termcolor = errors.termcolor()

try:
from numpy.random._bit_generator import BitGenerator
except ImportError:
from numpy.random.bit_generator import BitGenerator


class Purpose(enum.Enum):
# Value being typed is used as an argument
Expand Down Expand Up @@ -265,3 +270,13 @@ def _typeof_nb_type(val, c):
return types.NumberClass(val)
else:
return types.TypeRef(val)


@typeof_impl.register(BitGenerator)
def typeof_numpy_random_bitgen(val, c):
return types.NumPyRandomBitGeneratorType(val)


@typeof_impl.register(np.random.Generator)
def typeof_random_generator(val, c):
return types.NumPyRandomGeneratorType(val)
Empty file added numba/np/random/__init__.py
Empty file.

0 comments on commit 3356b88

Please sign in to comment.