Hack some way to dispatch on number of dimensions of numpy/jax arrays #10
Hey @PhilipVinc, sorry for not getting back to this earlier. I'm a bit busy at the moment. I fully agree that being able to dispatch on the dimensionality of an array would be an amazing feature to have. Your proposed solution seems like something that could be made to work. Here's one way of doing this. First, we make `type_of` extendable:

@_dispatch
def type_of(obj):
    """Get the Plum type of an object.

    Args:
        obj (object): Object to get type of.

    Returns:
        ptype: Plum type of `obj`.
    """
    if isinstance(obj, list):
        return List(_types_of_iterable(obj))
    if isinstance(obj, tuple):
        return Tuple(*(type_of(x) for x in obj))
    return ptype(type(obj))

Then this is possible:

import numpy as np
from plum import dispatch, parametric, type_of

@parametric
class NPArray(np.ndarray):
    pass

@type_of.dispatch
def type_of(x: np.ndarray):
    return NPArray[x.ndim]

@dispatch
def f(x: NPArray[1]):
    print("Vector!")

@dispatch
def f(x: NPArray[2]):
    print("Matrix!")

# Plum currently avoids unnecessary `type_of` calls. We would enable the use
# of `type_of` by default. Here's a hack that works for now:
f._parametric = True

This would be the outcome:

>>> f(np.random.randn(5))
Vector!
>>> f(np.random.randn(5, 5))
Matrix!
>>> f(np.random.randn(5, 5, 5))
NotFoundLookupError: For function "f", signature Signature(plum.parametric.NPArray[3]) could not be resolved.

These extensions might be best off in LAB. How does this sound? |
I've made the above change in the most recent release. |
Somewhat related to the above: do you consider it feasible/possible/desirable to be able to dispatch on the basis of (a computable subset of) the value of arguments? I'm trying to do something akin (though not literally identical) to dispatch based on array shape. In practice, a finite number of concrete shapes would be in play, but I can't enumerate them upfront, since that depends on the use of the library. That is, I'd like to match my method dispatch on a computable function of such a hashable type (such as a shape).
Computing the condition might be expensive, but it should only happen once for every shape that occurs; after that it should just be a dict lookup based on a hashable type, like any other dynamically dispatched call.
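Roughly this cost model, in sketch form (names invented, nothing plum-specific):

# Illustrative cost model: the possibly-expensive condition runs at most once
# per distinct shape; every later call is a plain dict lookup.
_resolved = {}

def select(shape, methods):
    """`methods` is a list of (condition, implementation) pairs."""
    if shape not in _resolved:
        for condition, implementation in methods:
            if condition(shape):  # expensive, but evaluated once per shape
                _resolved[shape] = implementation
                break
        else:
            raise LookupError(f"No method for shape {shape!r}.")
    return _resolved[shape] |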
You can already do it. See this test, where I do it for the number of dimensions (but you can generalise it to work on shapes): https://github.com/wesselb/plum/blob/fb29f9723f4b006457c74beb4cb09c811c1c983a/tests/test_parametric.py#L680

import numpy as np
import pytest

from plum import dispatch, parametric, type_of, NotFoundLookupError

@parametric(runtime_type_of=True)
class NPArray(np.ndarray):
    pass

@type_of.dispatch
def type_of_extension(x: np.ndarray):
    return NPArray[x.ndim]

@dispatch
def f(x: NPArray[1]):
    return "vector"

@dispatch
def f(x: NPArray[2]):
    return "matrix"

assert f(np.random.randn(10)) == "vector"
assert f(np.random.randn(10, 10)) == "matrix"
with pytest.raises(NotFoundLookupError):
    f(np.random.randn(10, 10, 10)) |
Ah sorry, now I see that what you are asking is a bit different. |
Yeah... I'm not seeing any other libraries that dispatch in any way on object value, just purely on type (this example of dispatching based on array dimension is the first example I'm seeing to the contrary). So either this is a great idea, or one that's going to blow up in your face once you think about it more / work with it more? For the use case I have in mind, the hashable attribute is de facto a dynamic extension of the typing system though; and while my use case isn't precisely identical to dispatching based on array shape, I can in fact think of plenty of legitimate uses for that as well (say, specialising a 2d/3d cross product or small matrix determinants). Certainly in many array-based languages, the shape of the array is effectively regarded as part of the type description. But some syntax to more flexibly match ranges of values (or general computable functions thereof), rather than just single types or collections thereof, would be a requirement for me; and I suppose that isn't currently covered by plum yet, from what I can tell. |
Hey @PhilipVinc and @EelcoHoogendoorn,

As @PhilipVinc says, I think this might be challenging, but it could conceivably be done. I've quickly hacked something together. Would something like the below be what you're after?

import numpy as np

from plum import dispatch, parametric, type_of
from plum.parametric import CovariantMeta

class DynamicTypeParameter:
    """A type parameter which contains a function `check` which can check whether
    another type parameter is a type subparameter."""

    def __init__(self, parameter, check):
        self.parameter = parameter
        self.check = check

    def __str__(self):
        return str(self.parameter)

    def __repr__(self):
        return repr(self.parameter)

_original_is_sub_type_parameter = CovariantMeta._is_sub_type_parameter

def _new_is_sub_type_parameter(cls, par_cls, subclass, par_subclass):
    # Resolve dynamic type parameter for the parent class.
    if len(par_cls) == 1 and isinstance(par_cls[0], DynamicTypeParameter):
        par_check = par_cls[0].check
        par_parameter = (par_cls[0].parameter,)
    else:
        par_check = None
    # Resolve dynamic type parameter for the subclass.
    if len(par_subclass) == 1 and isinstance(par_subclass[0], DynamicTypeParameter):
        par_subclass = (par_subclass[0].parameter,)
    if par_check:
        # Use the dynamic check.
        return par_check(*par_subclass)
    else:
        return _original_is_sub_type_parameter(cls, par_cls, subclass, par_subclass)

CovariantMeta._is_sub_type_parameter = _new_is_sub_type_parameter

@parametric(runtime_type_of=True)
class _NPArray(np.ndarray):
    """A type for NumPy arrays where the type parameter specifies the number of
    dimensions."""

class _NPArrayClass:
    """Just some sugar to make indexing with square brackets work."""

    def __getitem__(self, shape):
        def check_subshape(other_shape):
            try:
                other_shape = tuple(other_shape)
            except TypeError:
                return False
            return len(other_shape) == len(shape) and all(
                s1 <= s2 for s1, s2 in zip(other_shape, shape)
            )

        return _NPArray[DynamicTypeParameter(shape, check_subshape)]

NPArray = _NPArrayClass()

@type_of.dispatch
def type_of(x: np.ndarray):
    # Hook into Plum's type inference system to produce an appropriate instance of
    # `NPArray` for NumPy arrays.
    return NPArray[x.shape]

@dispatch
def f(x: NPArray[(10,)]):
    return "vector of (10,)"

@dispatch
def f(x: NPArray[(10, 20)]):
    return "matrix of (10, 20)"

print(f(np.random.randn(8)))
# vector of (10,)

print(f(np.random.randn(11)))
# NotFoundLookupError: For function "f", signature Signature(__main__._NPArray[(11,)]) could not be resolved.

print(f(np.random.randn(8, 18)))
# matrix of (10, 20)

print(f(np.random.randn(8, 22)))
# NotFoundLookupError: For function "f", signature Signature(__main__._NPArray[(8, 22)]) could not be resolved. |
Hey @wesselb; that's looking pretty neat. The ability to specify some kind of computable matching function is quite critical for my actual application though. Though I suppose that would be a significant departure from the current syntax, as parsed from the type annotations of the signature itself. Still, one could imagine parsing the signature as the default, which would correspond to a more flexible syntax as illustrated by the below (where `attribute` maps from the input to a hashable attribute of the input, and `condition` from said attribute to a bool):
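Something along these lines, say (a runnable mockup; `Dispatcher` is an invented name, not an existing plum class):

# Mockup: `attribute` maps the input to a hashable attribute, and each
# method's `condition` maps that attribute to a bool. The first registered
# method whose condition holds wins.
class Dispatcher:
    def __init__(self, attribute):
        self.attribute = attribute
        self.methods = []  # (condition, implementation), in registration order

    def __call__(self, condition):
        def register(implementation):
            self.methods.append((condition, implementation))
            def dispatched(x):
                key = self.attribute(x)
                for cond, impl in self.methods:
                    if cond(key):
                        return impl(x)
                raise LookupError(f"No method matches attribute {key!r}.")
            return dispatched
        return register

shape_dispatch = Dispatcher(attribute=lambda x: x.shape)

@shape_dispatch(condition=lambda shape: len(shape) == 1)
def f(x):
    return "vector"

@shape_dispatch(condition=lambda shape: len(shape) == 2)
def f(x):
    return "matrix"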
Note: in the mockup code I cobbled together, the attribute selection function is fed into the dispatcher via its constructor rather than the annotation call, since it tends to be the same across a whole family of functions, so there's no need to repeat it all the time. |
I was wondering if it is considered bad style to dispatch based on the object value; but according to Wikipedia: https://en.wikipedia.org/wiki/Multiple_dispatch
So it's not unheard of... even if I haven't seen it in the Python ecosystem before. |
Would something like how the example works suffice, or would you need something more flexible? To clarify, in the example, the second argument to `DynamicTypeParameter` can be an arbitrary function of the type parameter:

...
        def check_subshape(other_shape):
            try:
                other_shape = tuple(other_shape)
            except TypeError:
                return False
            return len(other_shape) == len(shape) and all(
                s1 <= s2 for s1, s2 in zip(other_shape, shape)
            )

        return _NPArray[DynamicTypeParameter(shape, check_subshape)]
...

It is true that all information would have to be contained in the type parameter of the object, which is more limited than the general case which you outline, where the function could take in the object itself.
This is really interesting. I think one could conceivably hack this together, but it might be challenging to make it really work well. For example, at the moment, types are hashable, which enables caching, and caching is crucial to obtain reasonable performance. If you use an arbitrary function to check whether an object belongs to the type family, then caching in this way wouldn't be possible anymore, because the object might not be hashable. If the object were hashable, then you could make it part of the type parameter and use the approach from the example; perhaps such a halfway house might be a reasonable solution? |
Yeah, I'd need something more flexible.
Yeah, in my example I also meant for the 'attribute' lambda to return some hashable, 'type-like' information attribute of the object. A shape or dimension of an array would qualify, being an immutable, type-like descriptor of the underlying raw array. The dtype of the array itself would obviously qualify as well, even if it's not part of the 'type' at the level of the Python type system. But there are other use cases for imbuing array elements with various type-like annotations beyond just their dtype: perhaps physical units, or, in the case I have in mind, blades/grades of a geometric algebra. Those hashable blade-descriptor objects would be quite involved objects themselves, with a variety of computable properties, on the basis of which we would want to specialise our function dispatch. That is, it would invite a richer syntax than just testing whether the hashable attribute is inside some set, as per typical type-based dispatch.
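To make that concrete, a toy sketch of the kind of hashable, type-like descriptor I mean (`Unit` is invented purely for illustration):

from dataclasses import dataclass

@dataclass(frozen=True)  # frozen -> hashable, so usable as a dispatch key
class Unit:
    """A toy type-like annotation for an array, beyond its dtype."""
    name: str
    length_power: int  # exponent of the length dimension: a computable property

metres = Unit("m", 1)
square_metres = Unit("m^2", 2)

# One would want to dispatch not just on `unit in {metres, square_metres}`
# but on computable properties such as `unit.length_power == 2`. |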
Hmm, perhaps I'm misunderstanding, but I think what you're after is possible in the setup of the example. E.g., what about something like the following?

from plum import dispatch, parametric, type_of
from plum.parametric import CovariantMeta

class DynamicTypeParameter:
    """A type parameter which contains a function `check` which can check whether
    another type parameter is a type subparameter."""

    def __init__(self, check, parameter=None):
        self.check = check
        self.parameter = parameter

    def __str__(self):
        return str(self.parameter)

    def __repr__(self):
        return repr(self.parameter)

_original_is_sub_type_parameter = CovariantMeta._is_sub_type_parameter

def _new_is_sub_type_parameter(cls, par_cls, subclass, par_subclass):
    # Resolve dynamic type parameter for the parent class.
    if len(par_cls) == 1 and isinstance(par_cls[0], DynamicTypeParameter):
        par_check = par_cls[0].check
        par_parameter = (par_cls[0].parameter,)
    else:
        par_check = None
    # Resolve dynamic type parameter for the subclass.
    if len(par_subclass) == 1 and isinstance(par_subclass[0], DynamicTypeParameter):
        par_subclass = (par_subclass[0].parameter,)
    if par_check:
        # Use the dynamic check.
        return par_check(*par_subclass)
    else:
        return _original_is_sub_type_parameter(cls, par_cls, subclass, par_subclass)

CovariantMeta._is_sub_type_parameter = _new_is_sub_type_parameter

@parametric
class DynamicallyTypedObject:
    def __init__(self, obj):
        self.obj = obj

    @classmethod
    def __infer_type_parameter__(cls, obj):
        # Use the object itself as the type parameter.
        return obj

@dispatch
def f(x: DynamicallyTypedObject[
    DynamicTypeParameter(check=lambda p: hasattr(type(p), "__len__") and len(p) <= 10)
]):
    print("Method for `len(x) <= 10`!")

import torch  # Use PyTorch, because NumPy arrays cannot be hashed.

f(DynamicallyTypedObject(torch.ones(5)))
# Method for `len(x) <= 10`!

f(DynamicallyTypedObject(torch.ones(15)))
# NotFoundLookupError

We can add some sugar on top of this to make sure that it behaves like your proposal:

import numpy as np

from functools import wraps

from plum import dispatch, parametric, type_of
from plum.parametric import CovariantMeta

class DynamicTypeParameter:
    """A type parameter which contains a function `check` which can check whether
    another type parameter is a type subparameter."""

    def __init__(self, check, parameter=None):
        self.check = check
        self.parameter = parameter

    def __str__(self):
        return str(self.parameter)

    def __repr__(self):
        return repr(self.parameter)

_original_is_sub_type_parameter = CovariantMeta._is_sub_type_parameter

def _new_is_sub_type_parameter(cls, par_cls, subclass, par_subclass):
    # Resolve dynamic type parameter for the parent class.
    if len(par_cls) == 1 and isinstance(par_cls[0], DynamicTypeParameter):
        par_check = par_cls[0].check
        par_parameter = (par_cls[0].parameter,)
    else:
        par_check = None
    # Resolve dynamic type parameter for the subclass.
    if len(par_subclass) == 1 and isinstance(par_subclass[0], DynamicTypeParameter):
        par_subclass = (par_subclass[0].parameter,)
    if par_check:
        # Use the dynamic check.
        return par_check(*par_subclass)
    else:
        return _original_is_sub_type_parameter(cls, par_cls, subclass, par_subclass)

CovariantMeta._is_sub_type_parameter = _new_is_sub_type_parameter

def attribute_dispatch(attribute, check):
    def decorator(f):
        @parametric
        class _DynamicallyTypedObject:
            def __init__(self, obj):
                self.obj = obj

            @classmethod
            def __infer_type_parameter__(cls, obj):
                return attribute(obj)

        def wrapped_f(x: _DynamicallyTypedObject[DynamicTypeParameter(check)]):
            return f(x)

        wrapped_f.__name__ = f.__name__
        wrapped_f = dispatch(wrapped_f)

        @wraps(f)
        def second_wrapped_f(x):
            return wrapped_f(_DynamicallyTypedObject(x))

        return second_wrapped_f

    return decorator

def _safe_le(x, y):
    try:
        return x <= y
    except Exception:
        return False

@attribute_dispatch(attribute=lambda x: len(x), check=lambda x: _safe_le(x, 10))
def f(x):
    print("Method for `len(x) <= 10`!")

f(np.ones(5))
# Method for `len(x) <= 10`!

f(np.ones(15))
# NotFoundLookupError |
I don't fully appreciate the internal mechanisms (I suppose it requires some extra complexity to shoehorn it into the existing mechanisms of plum), but yeah, the external API indeed looks the part! With the minor detail that it would usually be nice to separately bind the attribute, to create a len/whatever dispatcher, since that part tends to be constant over a whole bunch of annotations.
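For instance, just partially applying the `attribute_dispatch` sketch above (a usage sketch; `len_dispatch` is an invented name):

from functools import partial

# Bind the attribute once and reuse the resulting dispatcher across a family
# of functions, varying only the condition per method.
len_dispatch = partial(attribute_dispatch, attribute=lambda x: len(x))

@len_dispatch(check=lambda n: _safe_le(n, 10))
def g(x):
    print("Short input!") |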
That should certainly be possible!

I think I really like this idea and that it would be a valuable addition. However, you're indeed right that some precarious manoeuvring is required to fit it into the current internal mechanisms. There are still a few things unclear to me.

Firstly, to make this really work, the current mechanisms would require the ability to ask whether one check is more specific than another.

Secondly, in the prototype, the information about the shape is preserved by wrapping the object in another object with a type parameter. I think I still like the fundamental principle that all information necessary for dispatch should be retained when you take the type of an object.

What I would propose is an interface like the below. (Note that this is pseudo-code and doesn't actually run.)

import numpy as np

from plum import dispatch, parametric
from plum.util import Comparable

class Shape(Comparable):
    def __init__(self, *dims):
        self.dims = dims

    @dispatch
    def __le__(self, other: "Shape"):
        return ...

@parametric
class ShapedArray(np.ndarray):
    def __infer_type_parameter__(cls, *args, **kw_args):
        return Shape(*self.shape)

@dispatch
def f(x: ShapedArray[Shape(10, 10)]):
    ...  # Do something for objects with `shape <= Shape(10, 10)`.

@dispatch
def f(x: np.ndarray):
    # Ensure that the shape information is always included.
    return f(ShapedArray(x))

f(np.random.randn(15, 15))

What would you think about this proposal? Would that suffice for your use case?
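For instance, the elided `__le__` could mirror the `check_subshape` logic from the prototype above (a standalone sketch, not plum code):

class Shape:
    def __init__(self, *dims):
        self.dims = dims

    def __le__(self, other):
        # Same semantics as `check_subshape`: equal rank, and every dimension
        # no larger than the corresponding one in `other`.
        return len(self.dims) == len(other.dims) and all(
            d1 <= d2 for d1, d2 in zip(self.dims, other.dims)
        )

assert Shape(8, 18) <= Shape(10, 20)
assert not Shape(8, 22) <= Shape(10, 20) |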
I suppose it would suffice; but considering the number of different kinds of matching functions that I need, and the fact that they don't really have a ton of reuse (many are one-off), the need to wrap every condition into a type isn't very attractive. While I like using nice libraries over reinventing the wheel, getting exactly the syntax I want takes me 40 lines of code added to my project. |
That's totally fair enough.

How would you deal with the below ambiguity?

@dispatch(condition=lambda x: len(x) <= 10)
def f(x):
    print("This is an object of length at most 10")

@dispatch(condition=lambda x: len(x) <= 20)
def f(x):
    print("This is an object of length at most 20")

If you call `f` with an argument of length at most 10, both conditions match. Which method should be called? |
What I currently do is let the order in which the annotations are registered be defining for the order in which they are matched, with an optional kwarg to overload the insertion order, which can be useful when extending base library functionality. Instead of maintaining an ordered list and letting that be defining... it might be better to define an explicit priority number. That way it'd be easier to manage the situation where there are multiple extension modules, and you don't want to end up with the risk of an import-order dependent result. You could just insert with fractional priority, and maybe raise an error in case of ambiguity?
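A minimal sketch of that resolution rule (illustrative only; `register` and `resolve` are invented names):

# Candidate methods carry an explicit (possibly fractional) priority;
# resolution picks the matching method with the highest priority and raises
# an error on ties, i.e. on genuine ambiguity.
_methods = []  # list of (priority, condition, implementation)

def register(priority, condition, implementation):
    _methods.append((priority, condition, implementation))

def resolve(x):
    matches = [(p, impl) for p, cond, impl in _methods if cond(x)]
    if not matches:
        raise LookupError("No method matches.")
    top = max(p for p, _ in matches)
    winners = [impl for p, impl in matches if p == top]
    if len(winners) > 1:
        raise TypeError("Ambiguous: multiple methods share the top priority.")
    return winners[0] |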
I see! If that approach or assigning a priority number suffices, then perhaps it is not necessary to go through a lot of technical hoops for a fully general implementation of the idea. My sense is that determining which check is more specific than another would be very difficult to do in general. |
Yeah... I don't know that there is a general answer to the question of how to implement such comparisons... The other option would be for the matching function to return a number rather than a bool... That might work in some situations where you want this to be some complex estimate of the input (like maybe an estimate of runtime for a given implementation), but otherwise it's just messy.

On a related note: if one is to use generic object attributes to decide multiple dispatch, it's nice to encourage people to implement a flyweight pattern for those objects, to guarantee trivial equality comparisons between them:
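Something like this rough sketch, with a `get` classmethod standing in for pooled construction:

class Flyweight:
    """Mixin that interns instances in a pool keyed on a hashable key, so
    that value equality reduces to (cheap) identity comparison."""

    _pool = {}

    @classmethod
    def get(cls, *key):
        # Construct at most one instance per distinct key; afterwards,
        # `a == b` can simply be `a is b`.
        try:
            return cls._pool[(cls, key)]
        except KeyError:
            instance = cls(*key)
            cls._pool[(cls, key)] = instance
            return instance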
I guess it's not a hard requirement that the attributes you dispatch on have a certain cost of value-equality comparison... but if you care about performance, it sure is nice. In my case these attributes contain potentially humongous numpy arrays that do not natively compare cheaply, and I do care about performance. If you end up going through with implementing something like this, it might be useful to include something like this in the example. I'd already hacked in similar functionality, but it's nicer to factor out such a known pattern with an existing name; and it took me a while to realise it could be quite simple indeed. This is the same thing Python does under the hood to trivialise string comparisons and some other primitive types. Note that this is a pretty simple/rough example... One nice feature to add would be for the flyweight mixin to sabotage the normal constructor, so your code can't accidentally create one outside the pool; that'd be a bug that would break the value-equality implementation. |
Hmm, yeah. It's a hard problem!
This is a nice suggestion! Certainly a cheap performance boost in certain scenarios. If we were to go down the path of implementing dispatch based on object attributes, it would be worth thinking about how to make this as efficient as possible and also how to reduce boilerplate as much as possible. |
@wesselb Am I right in thinking that support for this would be automatically included in the case that beartype dispatch is supported, re: #53? See e.g. the relevant beartype features. I'm pretty sure any of these types of situations could be checked with the new beartype validator API.
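For instance (a sketch using beartype's validator syntax; whether plum can dispatch on such annotations is exactly what #53 is about):

from typing import Annotated

import numpy as np
from beartype.vale import Is

# A type that only accepts two-dimensional NumPy arrays: the value-level
# check travels inside the annotation itself.
Matrix = Annotated[np.ndarray, Is[lambda arr: arr.ndim == 2]] |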
@tbsexton, combined with your proposal to perform dispatch solely using beartype, I think that could indeed work! |
It would be nice (even though admittedly hacky) if we could dispatch on the number of dimensions of a jax/numpy array, which can be probed with `.ndim` on an instance. This is admittedly not part of the type information in Python, but maybe we can hack it in?
I was thinking of creating a custom parametric type for signatures, but then the problem is resolving the call signature and bringing this information from the value domain to the type domain.
This happens in `parametric.py` if I am not mistaken, and would require changing this function to allow hacking in some custom types... Do you have any idea on what would be the best way to implement this?