Skip to content

Commit

Permalink
Require explicit enabling of numpy interop (#337)
Browse files Browse the repository at this point in the history
* feat: make numpy interop opt-in

This commit removes all automatic imports of numpy, and instead adds a new
function "enable_numpy_interop()" to comtypes.npsupport. This function is
the only one which actually imports numpy, any attempt to use numpy interop
related features without calling "enable_numpy_interop()" first will lead to
an ImportError being raised with a message explaining that the function needs
to be called before using numpy functionality.

Other parts of comtypes wishing to access numpy functions should call
comtypes.npsupport.get_numpy() which will return the module.

* make npsupport.isndarray raise an Exception if interop not enabled

Without numpy interop being enabled, we can't directly check if a variable is
an ndarray, but having the __array_interface__ attribute is a fairly good
estimator, so if an object with that attribute is passed to isndarray before
interop is enabled, a ValueError will be raised prompting the user to call
npsupport.enable_numpy_interop().

* make safearray_as_ndarray automatically enable np interop

Entering the safearray_as_ndarray context manager will internally call
npsupport.enable_numpy_interop(), as it is clear that the user wants to have
numpy support.

early return from repeated calls to npsupport.enable_numpy_interop()

Calling enable_numpy_interop() when interop is already enabled returns at the
top of the function.

* reorganise tests relating to np interop

I have gathered all the numpy related tests into a single test file,
test_npsupport.py, which has a couple of different TestCases internally,
reflecting the original organisation. test_npsupport will only be run if
importing numpy succeeds.

I also removed a lot of @Skip decorators relating to numpy dependence, all
tests currently pass (or are skipped) with or without numpy being installed
(on my system at least, Python 3.9.13 and numpy 1.23.1).

* fix syntax errors for older Python versions

Remove inline type annotations on a couple of functions in test_npsupport.py
and an f-string in npsupport.py

* refactor to use a singleton class instead of global variables

Modify npsupport to put all functionality inside a class Interop, which
exposes public interface methods "enable()", "isndarray()", "isdatetime64" and
the properties "numpy", "VARIANT_dtype", "typecodes", "datetime64" and
"com_null_date64". A singleton instance of the class (called interop) is
created in the npsupport namespace, to use numpy interop its "enable()" method
should be called. It is also still valid to use the "safearray_as_ndarray"
context manager to enable support as well.

Co-authored-by: Ben Rowland <ben.rowland@hallmarq.net>
  • Loading branch information
bennyrowland and Ben Rowland committed Aug 19, 2022
1 parent 57e4f24 commit cc9a013
Show file tree
Hide file tree
Showing 8 changed files with 492 additions and 455 deletions.
8 changes: 4 additions & 4 deletions comtypes/automation.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,9 @@ def _set_value(self, value):
com_days = delta.days + (delta.seconds + delta.microseconds * 1e-6) / 86400.
self.vt = VT_DATE
self._.VT_R8 = com_days
elif npsupport.isdatetime64(value):
elif npsupport.interop.isdatetime64(value):
com_days = value - npsupport.com_null_date64
com_days /= npsupport.numpy.timedelta64(1, 'D')
com_days /= npsupport.interop.numpy.timedelta64(1, 'D')
self.vt = VT_DATE
self._.VT_R8 = com_days
elif decimal is not None and isinstance(value, decimal.Decimal):
Expand All @@ -308,10 +308,10 @@ def _set_value(self, value):
obj = _midlSAFEARRAY(typ).create(value)
memmove(byref(self._), byref(obj), sizeof(obj))
self.vt = VT_ARRAY | obj._vartype_
elif npsupport.isndarray(value):
elif npsupport.interop.isndarray(value):
# Try to convert a simple array of basic types.
descr = value.dtype.descr[0][1]
typ = npsupport.typecodes.get(descr)
typ = npsupport.interop.typecodes.get(descr)
if typ is None:
# Try for variant
obj = _midlSAFEARRAY(VARIANT).create(value)
Expand Down
278 changes: 153 additions & 125 deletions comtypes/npsupport.py
Original file line number Diff line number Diff line change
@@ -1,125 +1,153 @@
""" Consolidation of numpy support utilities. """
import sys

try:
import numpy
except ImportError:
numpy = None


HAVE_NUMPY = numpy is not None

is_64bits = sys.maxsize > 2**32


def _make_variant_dtype():
""" Create a dtype for VARIANT. This requires support for Unions, which is
available in numpy version 1.7 or greater.
This does not support the decimal type.
Returns None if the dtype cannot be created.
"""

# pointer typecode
ptr_typecode = '<u8' if is_64bits else '<u4'

_tagBRECORD_format = [
('pvRecord', ptr_typecode),
('pRecInfo', ptr_typecode),
]

# overlapping typecodes only allowed in numpy version 1.7 or greater
U_VARIANT_format = dict(
names=[
'VT_BOOL', 'VT_I1', 'VT_I2', 'VT_I4', 'VT_I8', 'VT_INT', 'VT_UI1',
'VT_UI2', 'VT_UI4', 'VT_UI8', 'VT_UINT', 'VT_R4', 'VT_R8', 'VT_CY',
'c_wchar_p', 'c_void_p', 'pparray', 'bstrVal', '_tagBRECORD',
],
formats=[
'<i2', '<i1', '<i2', '<i4', '<i8', '<i4', '<u1', '<u2', '<u4',
'<u8', '<u4', '<f4', '<f8', '<i8', ptr_typecode, ptr_typecode,
ptr_typecode, ptr_typecode, _tagBRECORD_format,
],
offsets=[0] * 19 # This is what makes it a union
)

tagVARIANT_format = [
("vt", '<u2'),
("wReserved1", '<u2'),
("wReserved2", '<u2'),
("wReserved3", '<u2'),
("_", U_VARIANT_format),
]

return numpy.dtype(tagVARIANT_format)


def isndarray(value):
""" Check if a value is an ndarray.
This cannot succeed if numpy is not available.
"""
if not HAVE_NUMPY:
return False
return isinstance(value, numpy.ndarray)


def isdatetime64(value):
""" Check if a value is a datetime64.
This cannot succeed if datetime64 is not available.
"""
if not HAVE_NUMPY:
return False
return isinstance(value, datetime64)


def _check_ctypeslib_typecodes():
import numpy as np
from numpy import ctypeslib
try:
from numpy.ctypeslib import _typecodes
except ImportError:
from numpy.ctypeslib import as_ctypes_type

dtypes_to_ctypes = {}

for tp in set(np.sctypeDict.values()):
try:
ctype_for = as_ctypes_type(tp)
dtypes_to_ctypes[np.dtype(tp).str] = ctype_for
except NotImplementedError:
continue
ctypeslib._typecodes = dtypes_to_ctypes
return ctypeslib._typecodes


com_null_date64 = None
datetime64 = None
VARIANT_dtype = None
typecodes = {}

if HAVE_NUMPY:
typecodes = _check_ctypeslib_typecodes()
# dtype for VARIANT. This allows for packing of variants into an array, and
# subsequent conversion to a multi-dimensional safearray.
try:
VARIANT_dtype = _make_variant_dtype()
except ValueError:
pass

# This simplifies dependent modules
try:
from numpy import datetime64
except ImportError:
pass
else:
try:
# This does not work on numpy 1.6
com_null_date64 = datetime64("1899-12-30T00:00:00", "ns")
except TypeError:
pass
""" Consolidation of numpy support utilities. """
import sys

is_64bits = sys.maxsize > 2**32


class Interop:
""" Class encapsulating all the functionality necessary to allow interop of
comtypes with numpy. Needs to be enabled with the "enable()" method.
"""
def __init__(self):
self.enabled = False
self.VARIANT_dtype = None
self.typecodes = {}
self.datetime64 = None
self.com_null_date64 = None

def _make_variant_dtype(self):
""" Create a dtype for VARIANT. This requires support for Unions, which
is available in numpy version 1.7 or greater.
This does not support the decimal type.
Returns None if the dtype cannot be created.
"""
if not self.enabled:
return None
# pointer typecode
ptr_typecode = '<u8' if is_64bits else '<u4'

_tagBRECORD_format = [
('pvRecord', ptr_typecode),
('pRecInfo', ptr_typecode),
]

# overlapping typecodes only allowed in numpy version 1.7 or greater
U_VARIANT_format = dict(
names=[
'VT_BOOL', 'VT_I1', 'VT_I2', 'VT_I4', 'VT_I8', 'VT_INT',
'VT_UI1', 'VT_UI2', 'VT_UI4', 'VT_UI8', 'VT_UINT', 'VT_R4',
'VT_R8', 'VT_CY', 'c_wchar_p', 'c_void_p', 'pparray',
'bstrVal', '_tagBRECORD',
],
formats=[
'<i2', '<i1', '<i2', '<i4', '<i8', '<i4', '<u1', '<u2', '<u4',
'<u8', '<u4', '<f4', '<f8', '<i8', ptr_typecode, ptr_typecode,
ptr_typecode, ptr_typecode, _tagBRECORD_format,
],
offsets=[0] * 19 # This is what makes it a union
)

tagVARIANT_format = [
("vt", '<u2'),
("wReserved1", '<u2'),
("wReserved2", '<u2'),
("wReserved3", '<u2'),
("_", U_VARIANT_format),
]

return self.numpy.dtype(tagVARIANT_format)

def _check_ctypeslib_typecodes(self):
if not self.enabled:
return {}
import numpy as np
from numpy import ctypeslib
try:
from numpy.ctypeslib import _typecodes
except ImportError:
from numpy.ctypeslib import as_ctypes_type

dtypes_to_ctypes = {}

for tp in set(np.sctypeDict.values()):
try:
ctype_for = as_ctypes_type(tp)
dtypes_to_ctypes[np.dtype(tp).str] = ctype_for
except NotImplementedError:
continue
ctypeslib._typecodes = dtypes_to_ctypes
return dtypes_to_ctypes

def isndarray(self, value):
""" Check if a value is an ndarray.
This cannot succeed if numpy is not available.
"""
if not self.enabled:
if hasattr(value, "__array_interface__"):
raise ValueError(
(
"Argument {0} appears to be a numpy.ndarray, but "
"comtypes numpy support has not been enabled. Please "
"try calling comtypes.npsupport.enable_numpy_interop()"
" before passing ndarrays as parameters."
).format(value)
)
return False

return isinstance(value, self.numpy.ndarray)

def isdatetime64(self, value):
""" Check if a value is a datetime64.
This cannot succeed if datetime64 is not available.
"""
if not self.enabled:
return False
return isinstance(value, self.datetime64)

@property
def numpy(self):
""" The numpy package.
"""
if self.enabled:
import numpy
return numpy
raise ImportError(
"In comtypes>=1.2.0 numpy interop must be explicitly enabled with "
"comtypes.npsupport.enable_numpy_interop before attempting to use "
"numpy features."
)

def enable(self):
""" Enables numpy/comtypes interop.
"""
# don't do this twice
if self.enabled:
return
# first we have to be able to import numpy
import numpy
# if that succeeded we can be enabled
self.enabled = True
self.VARIANT_dtype = self._make_variant_dtype()
self.typecodes = self._check_ctypeslib_typecodes()
try:
from numpy import datetime64
self.datetime64 = datetime64
except ImportError:
self.datetime64 = None
if self.datetime64:
try:
# This does not work on numpy 1.6
self.com_null_date64 = self.datetime64("1899-12-30T00:00:00", "ns")
except TypeError:
self.com_null_date64 = None


interop = Interop()

__all__ = ["interop"]

0 comments on commit cc9a013

Please sign in to comment.