Skip to content

Commit

Permalink
Extract dtype.type in wrapper
Browse files Browse the repository at this point in the history
If a key to the template metaclass is a numpy dtype object, use
the dtype.type result as the template class lookup key.
  • Loading branch information
natelust committed Jun 14, 2018
1 parent 860253a commit 115743d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
16 changes: 14 additions & 2 deletions python/lsst/utils/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
import sys
import types

import numpy as np

__all__ = ("continueClass", "inClass", "TemplateMeta")


Expand Down Expand Up @@ -265,8 +267,18 @@ def __call__(self, *args, **kwds):
# the abstract base class.
# If the ABC defines a "TEMPLATE_PARAMS" attribute, we use those strings
# as the kwargs we should intercept to find the right type.
key = tuple(kwds.pop(p, d) for p, d in zip(self.TEMPLATE_PARAMS,
self.TEMPLATE_DEFAULTS))

# Generate a type mapping key from input keywords. If the type returned
# from the keyword lookup is a numpy dtype object, fetch the underlying
# type of the dtype
key = []
for p, d in zip(self.TEMPLATE_PARAMS, self.TEMPLATE_DEFAULTS):
tempKey = kwds.pop(p, d)
if isinstance(tempKey, np.dtype):
tempKey = tempKey.type
key.append(tempKey)
key = tuple(key)

# indices are only tuples if there are multiple elements
cls = self._registry.get(key[0] if len(key) == 1 else key, None)
if cls is None:
Expand Down
12 changes: 8 additions & 4 deletions tests/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,14 @@ def testInheritanceHooks(self):

def testConstruction(self):
self.register()
f = self.Example(dtype=np.float32)
self.assertIsInstance(f, self.Example)
self.assertIsInstance(f, self.ExampleF)
self.assertNotIsInstance(f, self.ExampleD)
f1 = self.Example(dtype=np.float32)
# Test that numpy dtype objects resolve to their underlying type
f2 = self.Example(dtype=np.dtype(np.float32))
for f in (f1, f2):
self.assertIsInstance(f, self.Example)
self.assertIsInstance(f, self.ExampleF)
self.assertNotIsInstance(f, self.ExampleD)

with self.assertRaises(TypeError):
self.Example()
with self.assertRaises(TypeError):
Expand Down

0 comments on commit 115743d

Please sign in to comment.