Hack some way to dispatch on number of dimensions of numpy/jax arrays #10

Closed
PhilipVinc opened this issue May 21, 2021 · 22 comments · Fixed by #73

Comments

@PhilipVinc
Collaborator

It would be nice (even though admittedly hacky) if we could dispatch on the number of dimensions of a jax/numpy array, which can be probed with .ndim on an instance.

This is not part of the type information in Python, but maybe we can hack it in?

I was thinking of creating a custom parametric type for signatures, but then the problem is resolving the call signature and bringing this information from the value domain to the type domain.

This happens in parametric.py if I am not mistaken, and would require changing this function to allow hacking in some custom types...
Do you have any idea on what would be the best way to implement this?

def type_of(obj):
    """Get the Plum type of an object.

    Args:
        obj (object): Object to get type of.

    Returns:
        ptype: Plum type of `obj`.
    """
    if isinstance(obj, list):
        return List(_types_of_iterable(obj))

    if isinstance(obj, tuple):
        return Tuple(*(type_of(x) for x in obj))

    return ptype(type(obj))
@wesselb
Member

wesselb commented Jun 1, 2021

Hey @PhilipVinc,

Sorry for not getting back to this earlier. I'm a bit busy at the moment.

I fully agree that being able to dispatch on the dimensionality of an array would be an amazing feature to have. Your proposed solution seems like something that could be made to work.

Here's one way of doing this. First, we make type_of extendable by using dispatch.

@_dispatch
def type_of(obj):
    """Get the Plum type of an object.

    Args:
        obj (object): Object to get type of.

    Returns:
        ptype: Plum type of `obj`.
    """
    if isinstance(obj, list):
        return List(_types_of_iterable(obj))

    if isinstance(obj, tuple):
        return Tuple(*(type_of(x) for x in obj))

    return ptype(type(obj))

Then this is possible:

import numpy as np
from plum import dispatch, parametric, type_of


@parametric
class NPArray(np.ndarray):
    pass


@type_of.dispatch
def type_of(x: np.ndarray):
    return NPArray[x.ndim]


@dispatch
def f(x: NPArray[1]):
    print("Vector!")


@dispatch
def f(x: NPArray[2]):
    print("Matrix!")


# Plum currently avoids unnecessary `type_of` calls, so `type_of` would need
# to be used by default for this to work. Here's a hack that works for now:
f._parametric = True

This would be the outcome:

>>> f(np.random.randn(5))
Vector!

>>> f(np.random.randn(5, 5))
Matrix!

>>> f(np.random.randn(5, 5, 5))
NotFoundLookupError: For function "f", signature Signature(plum.parametric.NPArray[3]) could not be resolved.

These extensions might be best off in LAB. For example, you would then import lab.jax as B and write def f(x: B.JAXArray[2]) or def f(x: B.Array[2]).

How does this sound?

@wesselb
Member

wesselb commented Jun 11, 2021

I've made the above change in the most recent release, 1.1.1, and added an entry to the README. I'll leave this issue open and close it once good implementations of backend-specific array types (like JAXArray[1]) are available.

@EelcoHoogendoorn

EelcoHoogendoorn commented May 9, 2022

Somewhat related to the above: do you consider it feasible/possible/desirable to dispatch on the basis of (a computable subset of) the value of arguments? I'm trying to do something akin (though not literally identical) to dispatch based on array shape. In practice, a finite number of concrete shapes would be in play, but I can't enumerate them upfront, since it depends on how the library is used. That is, I'd like to match my method dispatch on a computable function of such a hashable attribute (such as a shape).

@dispatch(property=lambda x: x.shape, condition=lambda s: s < (3, 4, 5))
def f(x: np.ndarray):
    print("This array has shape smaller than (3, 4, 5)")

Computing the condition might be expensive, but it should only happen once for every shape that occurs; after that, it should just be a dict lookup keyed on a hashable value, just like any other dynamically dispatched call.
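A minimal sketch of the "evaluate once per shape, then look it up" idea, using only the standard library and assuming the dispatch key is the array's shape tuple (NumPy arrays themselves are not hashable, but their `.shape` is). The names `shape_fits` and `smaller_than_345` are hypothetical. Note that Python compares tuples lexicographically, so `s < (3, 4, 5)` in the proposal above is not an elementwise check; the sketch compares dimension by dimension instead.

```python
from functools import lru_cache

calls = 0  # counts how often the (possibly expensive) condition actually runs


def shape_fits(shape, bound):
    """Elementwise check: same rank and every dimension <= the bound's."""
    return len(shape) == len(bound) and all(s <= b for s, b in zip(shape, bound))


@lru_cache(maxsize=None)
def smaller_than_345(shape):
    # Hypothetical condition. lru_cache memoizes it per distinct shape tuple,
    # so repeated calls with the same shape are just a dict lookup.
    global calls
    calls += 1
    return shape_fits(shape, (3, 4, 5))


print(smaller_than_345((2, 4, 5)))  # True
print(smaller_than_345((2, 4, 5)))  # True (cached, condition not re-run)
print(smaller_than_345((1, 9, 1)))  # False, although (1, 9, 1) < (3, 4, 5) lexicographically
print(calls)  # 2
```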

@PhilipVinc
Collaborator Author

You can already do it.
You need to define a parametric type with a custom type_of and then use it to define a dispatch rule.

See this test, where I do it for the number of dimensions (but you can generalise it to work on shapes) https://github.com/wesselb/plum/blob/fb29f9723f4b006457c74beb4cb09c811c1c983a/tests/test_parametric.py#L680

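import numpy as np
import pytest

from plum import NotFoundLookupError, dispatch, parametric, type_of

@parametric(runtime_type_of=True)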
class NPArray(np.ndarray):
    pass

@type_of.dispatch
def type_of_extension(x: np.ndarray):
    return NPArray[x.ndim]

@dispatch
def f(x: NPArray[1]):
    return "vector"

@dispatch
def f(x: NPArray[2]):
    return "matrix"

assert f(np.random.randn(10)) == "vector"
assert f(np.random.randn(10, 10)) == "matrix"
with pytest.raises(NotFoundLookupError):
    f(np.random.randn(10, 10, 10))

@PhilipVinc
Collaborator Author

Ah, sorry, now I see that what you are asking is a bit different...
I think that implementing this is going to be a bit tough...

@EelcoHoogendoorn

Yeah... I'm not seeing any other libraries that dispatch in any way on object value, just purely on type (this example of dispatching based on array dimension is the first example I'm seeing to the contrary).

So either this is a great idea, or one that's going to blow up in your face once you think about it more / work with it more?

For the use case I have in mind, though, the hashable attribute is de facto a dynamic extension of the typing system. And while my use case isn't precisely identical to dispatching based on array shape, I can in fact think of plenty of legitimate uses for that as well (say, specializing a 2D/3D cross product or small matrix determinants).

Certainly in many array-based languages, the shape of the array is effectively regarded as part of the type description. But some syntax to more flexibly match ranges of values (or general computable functions thereof), rather than just single types or collections thereof, would be a requirement for me, and I suppose that's something plum doesn't cover yet, from what I can tell.

@wesselb
Member

wesselb commented May 11, 2022

Hey @PhilipVinc and @EelcoHoogendoorn,

As @PhilipVinc says, I think this might be challenging, but it could conceivably be done.

I've hacked something very quick together. Would something like the below be what you're after?

import numpy as np

from plum import dispatch, parametric, type_of
from plum.parametric import CovariantMeta


class DynamicTypeParameter:
    """A type parameter which contains a function `check` which can check whether
    another type parameter is a type subparameter."""

    def __init__(self, parameter, check):
        self.parameter = parameter
        self.check = check

    def __str__(self):
        return str(self.parameter)

    def __repr__(self):
        return repr(self.parameter)


_original_is_sub_type_parameter = CovariantMeta._is_sub_type_parameter


def _new_is_sub_type_parameter(cls, par_cls, subclass, par_subclass):
    # Resolve dynamic type parameter for the parent class.
    if len(par_cls) == 1 and isinstance(par_cls[0], DynamicTypeParameter):
        par_check = par_cls[0].check
        par_parameter = (par_cls[0].parameter,)
    else:
        par_check = None

    # Resolve dynamic type parameter for the subclass.
    if len(par_subclass) == 1 and isinstance(par_subclass[0], DynamicTypeParameter):
        par_subclass = (par_subclass[0].parameter,)

    if par_check:
        # Use the dynamic check.
        return par_check(*par_subclass)
    else:
        return _original_is_sub_type_parameter(cls, par_cls, subclass, par_subclass)


CovariantMeta._is_sub_type_parameter = _new_is_sub_type_parameter


@parametric(runtime_type_of=True)
class _NPArray(np.ndarray):
    """A type for NumPy arrays where the type parameter specifies the number of
    dimensions."""


class _NPArrayClass:
    """Just some sugar to make indexing with square bracket work."""

    def __getitem__(self, shape):
        def check_subshape(other_shape):
            try:
                other_shape = tuple(other_shape)
            except TypeError:
                return False
            return len(other_shape) == len(shape) and all(
                s1 <= s2 for s1, s2 in zip(other_shape, shape)
            )

        return _NPArray[DynamicTypeParameter(shape, check_subshape)]


NPArray = _NPArrayClass()


@type_of.dispatch
def type_of(x: np.ndarray):
    # Hook into Plum's type inference system to produce an appropriate instance of
    # `NPArray` for NumPy arrays.
    return NPArray[x.shape]


@dispatch
def f(x: NPArray[(10,)]):
    return "vector of (10,)"


@dispatch
def f(x: NPArray[(10, 20)]):
    return "matrix of (10, 20)"


print(f(np.random.randn(8)))
# vector of (10,)

print(f(np.random.randn(11)))
# NotFoundLookupError: For function "f", signature Signature(__main__._NPArray[(11,)]) could not be resolved.

print(f(np.random.randn(8, 18)))
# matrix of (10, 20)

print(f(np.random.randn(8, 22)))
# NotFoundLookupError: For function "f", signature Signature(__main__._NPArray[(8, 22)]) could not be resolved.

@EelcoHoogendoorn

EelcoHoogendoorn commented May 11, 2022

Hey @wesselb, that's looking pretty neat. The ability to specify some kind of computable matching function is quite critical for my actual application, though. I suppose that would be a significant departure from the current syntax, which is parsed from the type annotations of the signature itself.

Still, one could imagine parsing the signature as the default, which would correspond to a more flexible syntax as illustrated below (where attribute maps from the input to a hashable attribute of the input, and condition maps from said attribute to a bool):

@dispatch(attribute=lambda x: type(x), condition=lambda a: a == int)
def f(x):
    print("This is an int")
# the above could be seen as a verbose form of the below
@dispatch
def f(x: int):
    print("This is an int")

Note: in the mockup code I cobbled together, the attribute selection function is fed into the dispatcher via its constructor rather than the annotation call, since it tends to be the same across a whole family of functions, so there's no need to repeat it every time.
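A rough sketch of such a dispatcher, with the attribute function passed to the constructor and each method registered with a condition on that attribute. This is not Plum's API: `AttributeDispatcher`, `register`, and the rule that conditions are tried in registration order are assumptions made for illustration. The selected method is cached per attribute value, so the attribute must be hashable.

```python
class AttributeDispatcher:
    """Dispatch on a hashable attribute of the first argument (hypothetical)."""

    def __init__(self, attribute):
        self.attribute = attribute  # maps an argument to a hashable key
        self.methods = []  # (condition, function) pairs in registration order
        self.cache = {}  # attribute value -> selected function

    def register(self, condition):
        def decorator(f):
            self.methods.append((condition, f))
            self.cache.clear()  # a new method may change earlier resolutions
            return self  # so the decorated name keeps referring to the dispatcher
        return decorator

    def __call__(self, x, *args, **kw_args):
        key = self.attribute(x)
        try:
            method = self.cache[key]
        except KeyError:
            # First matching condition wins; evaluated once per distinct key.
            for condition, f in self.methods:
                if condition(key):
                    method = f
                    break
            else:
                raise LookupError(f"No method matches attribute {key!r}.")
            self.cache[key] = method
        return method(x, *args, **kw_args)


f = AttributeDispatcher(attribute=lambda x: type(x))


@f.register(condition=lambda a: a is int)
def f(x):
    return "This is an int"


@f.register(condition=lambda a: issubclass(a, str))
def f(x):
    return "This is a string"


print(f(3))  # This is an int
print(f("hi"))  # This is a string
```

Binding the attribute once in the constructor keeps the per-method decorators short; a `len_dispatcher` would just be `AttributeDispatcher(attribute=len)`.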

@EelcoHoogendoorn

EelcoHoogendoorn commented May 11, 2022

I was wondering whether dispatching based on the object value is considered bad style, but according to Wikipedia: https://en.wikipedia.org/wiki/Multiple_dispatch

Multiple dispatch or multimethods is a feature of some programming languages in which a function or method can be dynamically dispatched based on the run-time (dynamic) type or, in the more general case, some other attribute of more than one of its [arguments](https://en.wikipedia.org/wiki/Parameter_(computer_programming)).

So it's not unheard of... even if I haven't seen it in the Python ecosystem before.

@wesselb
Member

wesselb commented May 14, 2022

The ability to specify some kind of computable matching function is quite critical for my actual application though.

Would something like how the example works suffice, or would you need something more flexible? To clarify, in the example, the second argument to DynamicTypeParameter is a function which checks if a given type parameter is a type subparameter. This is the relevant bit of the example:

...
        def check_subshape(other_shape):
            try:
                other_shape = tuple(other_shape)
            except TypeError:
                return False
            return len(other_shape) == len(shape) and all(
                s1 <= s2 for s1, s2 in zip(other_shape, shape)
            )

        return _NPArray[DynamicTypeParameter(shape, check_subshape)]
...

It is true that all information would have to be contained in the type parameter of the object, which is more limited than the general case which you outline where the function could take in the object itself.

I was wondering if this is considered bad style, to dispatch based on the object value; but according to wikipedia: (...)

This is really interesting. I think one could conceivably hack this together, but it might be challenging to make it really work well. For example, at the moment, types are hashable, which enables caching, and caching is crucial to obtain reasonable performance. If you use an arbitrary function to check whether an object belongs to the type family, then caching in this way wouldn't be possible anymore, because the object might not be hashable. If the object were hashable, then you could make it part of the type parameter and use the approach from the example; perhaps such a halfway house might be a reasonable solution?

@EelcoHoogendoorn

EelcoHoogendoorn commented May 15, 2022

Would something like how the example works suffice, or would you need something more flexible?

Yeah, I'd need something more flexible.

types are hashable

Yeah, in my example I also meant for the 'attribute' lambda to return some hashable, 'type-like' attribute of the object. A shape or dimension of an array would qualify, as an immutable, type-like descriptor of the underlying raw array. The dtype of the array would obviously qualify as well, even if it isn't part of the 'type' at the level of the Python type system.

But there are other use-cases for imbuing array elements with various type-like annotations beyond just their dtype; perhaps physical units, or in the case I have in mind, blades/grades of a geometric-algebra. Those hashable blade-descriptor objects would be quite involved objects themselves, with a variety of computable properties, on the basis of which we would want to specialize our function dispatch. That is, it would invite a richer syntax than just testing if the hashable attribute is inside some set, as per the typical type-based dispatch.

@wesselb
Member

wesselb commented May 15, 2022

Hmm, perhaps I'm misunderstanding, but I think what you're after is possible in the setup of the example. E.g., what about something like the following?

from plum import dispatch, parametric, type_of
from plum.parametric import CovariantMeta


class DynamicTypeParameter:
    """A type parameter which contains a function `check` which can check whether
    another type parameter is a type subparameter."""

    def __init__(self, check, parameter=None):
        self.check = check
        self.parameter = parameter

    def __str__(self):
        return str(self.parameter)

    def __repr__(self):
        return repr(self.parameter)


_original_is_sub_type_parameter = CovariantMeta._is_sub_type_parameter


def _new_is_sub_type_parameter(cls, par_cls, subclass, par_subclass):
    # Resolve dynamic type parameter for the parent class.
    if len(par_cls) == 1 and isinstance(par_cls[0], DynamicTypeParameter):
        par_check = par_cls[0].check
        par_parameter = (par_cls[0].parameter,)
    else:
        par_check = None

    # Resolve dynamic type parameter for the subclass.
    if len(par_subclass) == 1 and isinstance(par_subclass[0], DynamicTypeParameter):
        par_subclass = (par_subclass[0].parameter,)

    if par_check:
        # Use the dynamic check.
        return par_check(*par_subclass)
    else:
        return _original_is_sub_type_parameter(cls, par_cls, subclass, par_subclass)


CovariantMeta._is_sub_type_parameter = _new_is_sub_type_parameter


@parametric
class DynamicallyTypedObject:
    def __init__(self, obj):
        self.obj = obj

    @classmethod
    def __infer_type_parameter__(cls, obj):
        # Use the object itself as the type parameter.
        return obj


@dispatch
def f(x: DynamicallyTypedObject[
        DynamicTypeParameter(check=lambda p: hasattr(type(p), "__len__") and len(p) <= 10)
    ]):
    print("Method for `len(x) <= 10`!")


import torch  # Use PyTorch, because NumPy arrays cannot be hashed.

f(DynamicallyTypedObject(torch.ones(5)))
# Method for `len(x) <= 10`!

f(DynamicallyTypedObject(torch.ones(15)))
# NotFoundLookupError

We can add some sugar on top of this to make sure that it behaves like your proposal of @dispatch:

import numpy as np
from functools import wraps

from plum import dispatch, parametric, type_of
from plum.parametric import CovariantMeta


class DynamicTypeParameter:
    """A type parameter which contains a function `check` which can check whether
    another type parameter is a type subparameter."""

    def __init__(self, check, parameter=None):
        self.check = check
        self.parameter = parameter

    def __str__(self):
        return str(self.parameter)

    def __repr__(self):
        return repr(self.parameter)


_original_is_sub_type_parameter = CovariantMeta._is_sub_type_parameter


def _new_is_sub_type_parameter(cls, par_cls, subclass, par_subclass):
    # Resolve dynamic type parameter for the parent class.
    if len(par_cls) == 1 and isinstance(par_cls[0], DynamicTypeParameter):
        par_check = par_cls[0].check
        par_parameter = (par_cls[0].parameter,)
    else:
        par_check = None

    # Resolve dynamic type parameter for the subclass.
    if len(par_subclass) == 1 and isinstance(par_subclass[0], DynamicTypeParameter):
        par_subclass = (par_subclass[0].parameter,)

    if par_check:
        # Use the dynamic check.
        return par_check(*par_subclass)
    else:
        return _original_is_sub_type_parameter(cls, par_cls, subclass, par_subclass)


CovariantMeta._is_sub_type_parameter = _new_is_sub_type_parameter


def attribute_dispatch(attribute, check):
    def decorator(f):
        @parametric
        class _DynamicallyTypedObject:
            def __init__(self, obj):
                self.obj = obj

            @classmethod
            def __infer_type_parameter__(cls, obj):
                return attribute(obj)

        def wrapped_f(x: _DynamicallyTypedObject[DynamicTypeParameter(check)]):
            return f(x)

        wrapped_f.__name__ = f.__name__
        wrapped_f = dispatch(wrapped_f)

        @wraps(f)
        def second_wrapped_f(x):
            return wrapped_f(_DynamicallyTypedObject(x))

        return second_wrapped_f

    return decorator


def _safe_le(x, y):
    try:
        return x <= y
    except Exception:
        return False


@attribute_dispatch(attribute=lambda x: len(x), check=lambda x: _safe_le(x, 10))
def f(x):
    print("Method for `len(x) <= 10`!")



f(np.ones(5))
# Method for `len(x) <= 10`!

f(np.ones(15))
# NotFoundLookupError

@EelcoHoogendoorn

EelcoHoogendoorn commented May 15, 2022

I don't fully appreciate the internal mechanisms (I suppose it requires some extra complexity to shoehorn it into the existing mechanisms of plum), but yeah, the external API indeed looks the part!

One minor detail: I think it would usually be nice to bind the attribute separately to create a len_dispatcher (or whatever_dispatcher), since that part tends to be constant over a bunch of annotations.

@wesselb
Copy link
Member

wesselb commented May 19, 2022

With the minor detail that I think usually itd be nice to seperately bind the attribute to create a len/whatever_dispatcher since that part tends to be constant over a bunch of annotations usually I suppose.

That should certainly be possible!

I think I really like this idea and that it would be a valuable addition. However, you're indeed right that some precarious manoeuvring is required to fit it into the current internal mechanisms. There are still a few things unclear to me.

Firstly, to make this really work, the current mechanisms would require the ability to ask whether one check function is more specific than another. E.g., lambda x: len(x) <= 5 is more specific than lambda x: len(x) <= 10. One way to make this work is to not use lambdas, but to use objects which can be compared. For example, make a Shape object which is hashable such that any two shapes can be compared. This comes at the cost of boilerplate, but perhaps it's not too bad.
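One way to make "more specific" concrete with comparable, hashable objects: a minimal sketch of a `Shape` with an elementwise `<=`, plus a helper that picks the most specific matching bound and reports an ambiguity when there is no unique one. `Shape` and `most_specific` are hypothetical names, not part of Plum. Because elementwise `<=` is only a partial order, two bounds such as (5, 10) and (10, 5) can both match without either being more specific, which is exactly the kind of ambiguity a dispatcher would have to report.

```python
from dataclasses import dataclass


@dataclass(frozen=True)  # frozen makes instances hashable, so they can be cache keys
class Shape:
    dims: tuple

    def __le__(self, other):
        # Elementwise comparison: a partial order, not Python's lexicographic tuple order.
        return len(self.dims) == len(other.dims) and all(
            a <= b for a, b in zip(self.dims, other.dims)
        )


def most_specific(shape, bounds):
    """Return the unique smallest bound that `shape` fits into."""
    matches = [b for b in bounds if shape <= b]
    if not matches:
        raise LookupError(f"No bound matches {shape}.")
    # A bound is most specific if it is <= every other matching bound.
    best = [b for b in matches if all(b <= other for other in matches)]
    if len(best) != 1:
        raise LookupError(f"Ambiguous bounds for {shape}: {matches}.")
    return best[0]


bounds = [Shape((10,)), Shape((5,)), Shape((10, 20))]
print(most_specific(Shape((3,)), bounds))  # Shape(dims=(5,))
print(most_specific(Shape((7,)), bounds))  # Shape(dims=(10,))
```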

Secondly, in the prototype, the information about the shape is preserved by wrapping the object in another object with a type parameter. I think I still like the fundamental principle that all information necessary for dispatch should be retained when you take type(object).

What I would propose is an interface like the below. (Note that this is pseudo-code and doesn't actually run.)

import numpy as np

from plum import dispatch, parametric
from plum.util import Comparable


class Shape(Comparable):
    def __init__(self, *dims):
        self.dims = dims

    @dispatch
    def __le__(self, other: "Shape"):
        return ...


@parametric
class ShapedArray(np.ndarray):
    @classmethod
    def __infer_type_parameter__(cls, x, *args, **kw_args):
        # Use the shape of the wrapped array as the type parameter.
        return Shape(*x.shape)


@dispatch
def f(x: ShapedArray[Shape(10, 10)]):
    # Do something for objects with `shape <= Shape(10, 10)`.
    ...


@dispatch
def f(x: np.ndarray):
    # Ensure that the shape information is always included.
    return f(ShapedArray(x))


f(np.random.randn(15, 15))

What would you think about this proposal? Would that suffice for your use case?

@EelcoHoogendoorn

I suppose it would suffice, but considering the number of different kinds of matching functions that I need, and the fact that they don't get much reuse (many are one-offs), the need to wrap every condition into a type isn't very attractive.

While I like using nice libraries over reinventing the wheel, getting exactly the syntax I want takes me 40 lines of code added to my project.

@wesselb
Member

wesselb commented May 21, 2022

That's totally fair enough.

getting exactly the syntax I want takes me 40 lines of code added to my project.

How would you deal with the below ambiguity?

@dispatch(condition=lambda x: len(x) <= 10)
def f(x):
    print("This is an object of length at most 10")


@dispatch(condition=lambda x: len(x) <= 20)
def f(x):
    print("This is an object of length at most 20")

If you call f with an x such that len(x) = 5, both conditions would evaluate to True. The key idea of multiple dispatch is that the most specific method should be chosen. However, unless you inspect these lambdas, there is no way to know which of the conditions is most specific. EDIT: The wrapping into types, which admittedly isn't terribly attractive syntactically, forces you to choose an order, which is then used to determine which is most specific.

@EelcoHoogendoorn

What I currently do is let the order in which the annotated functions are registered determine the order in which they are matched, with an optional kwarg to override the insertion order, which might be useful when extending base library functionality.

Instead of maintaining an ordered list and letting that be decisive, it might be better to define an explicit priority number. That way it'd be easier to manage situations with multiple extension modules, where you don't want to risk an import-order-dependent result. You could just insert with a fractional priority, and maybe raise an error in case of ambiguity?
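The explicit-priority idea can be sketched as follows, assuming higher numbers win, fractional priorities are allowed, and two matching methods with equal priority raise an ambiguity error instead of silently depending on import order. `PriorityDispatcher` and its `register` method are hypothetical, not Plum's API.

```python
class PriorityDispatcher:
    """Pick the highest-priority method whose condition holds (hypothetical)."""

    def __init__(self):
        self.methods = []  # (priority, condition, function) triples

    def register(self, condition, priority=0.0):
        def decorator(f):
            self.methods.append((priority, condition, f))
            return self  # the decorated name keeps referring to the dispatcher
        return decorator

    def __call__(self, x):
        matching = [(p, f) for p, c, f in self.methods if c(x)]
        if not matching:
            raise LookupError("No method matches.")
        top = max(p for p, _ in matching)
        winners = [f for p, f in matching if p == top]
        if len(winners) > 1:
            # Equal priority: refuse to guess rather than depend on import order.
            raise LookupError(f"Ambiguous: {len(winners)} methods with priority {top}.")
        return winners[0](x)


f = PriorityDispatcher()


@f.register(condition=lambda x: len(x) <= 20, priority=1.0)
def f(x):
    return "length at most 20"


@f.register(condition=lambda x: len(x) <= 10, priority=2.0)
def f(x):
    return "length at most 10"


# An extension module can slot in between with a fractional priority.
@f.register(condition=lambda x: len(x) == 15, priority=1.5)
def f(x):
    return "length exactly 15"


print(f([0] * 5))  # length at most 10
print(f([0] * 15))  # length exactly 15
print(f([0] * 18))  # length at most 20
```

Unlike registration order, the result does not change if the extension module happens to be imported first; the price is that priorities have to be chosen by hand.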

@wesselb
Member

wesselb commented May 22, 2022

I see! If that approach or assigning a priority number suffices, then perhaps it is not necessary to go through a lot of technical hoops for a fully general implementation of the idea. My sense is that determining which condition function is more specific is the key difficulty (and also what lies at the heart of multiple dispatch), and a general implementation will necessarily involve at least a bit of extra boilerplate which specifies how any two condition functions can be compared.

@EelcoHoogendoorn

EelcoHoogendoorn commented May 23, 2022

Yeah... I don't know that there is a general answer to the question of how to implement such comparisons. The other option would be for the matching function to return a number rather than a bool. That might work in some situations where you want this to be some complex estimate of the input (like maybe an estimate of runtime for a given implementation), but otherwise it's just messy.

On a related note: if one is to use generic object attributes to decide multiple dispatch, it's nice to encourage people to implement a flyweight pattern for those objects, to guarantee trivial equality comparisons between them:

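class FlyweightMixin:
    # Identity-based hashing and equality: cheap, and equivalent to value
    # equality as long as every instance is created through the factory below.
    def __hash__(self):
        return id(self)

    def __eq__(self, other):
        return self is other


class FlyweightFactory:
    def __init__(self):
        self.flyweight_pool = {}

    def construct(self, key, value) -> FlyweightMixin:
        """All constructor calls of flyweights are supposed to be made via this factory method.

        `value` is a zero-argument callable that builds the instance on first use.
        """
        if key in self.flyweight_pool:
            return self.flyweight_pool[key]
        else:
            value = value()
            self.flyweight_pool[key] = value
            return value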

I guess it's not a hard requirement that the attributes you dispatch on have cheap value-equality comparisons... but if you care about performance, it sure is nice. In my case, these attributes contain potentially humongous numpy arrays that do not natively compare cheaply, and I do care about performance. If you end up implementing something like this, it might be useful to include something like this in the example. I'd already hacked in similar functionality, but it's nicer to factor out a known pattern with an existing name, and it took me a while to realize it could be quite simple indeed.

This is the same thing Python does under the hood to trivialize string comparisons and some other primitive types. Note that this is a pretty simple/rough example... one nice feature to add would be for the flyweight mixin to sabotage the normal constructor, so your code can't accidentally create one outside the pool, because that would be a bug that breaks the value-equality implementation.

@wesselb
Copy link
Member

wesselb commented May 25, 2022

yeah... I dont know that there is a general answer to the question of how to implement such comparisons... the other option would be for the matching function to return a number, rather than a bool... that might work in some situations where you want this to be some complex estimate of the input (like maybe an estimate of runtime for a given implementation), but otherwise its just messy.

Hmm, yeah. It's a hard problem!

On a related note: if one is to use generic object attributes to decide multiple dispatch, its nice to encourage people to implement a flyweight pattern for those objects; to guarantee trivial equality comparisons between them:

This is a nice suggestion! Certainly a cheap performance boost in certain scenarios. If we were to go down the path of implementing dispatch based on object attributes, it would be worth thinking about how to make this as efficient as possible and also how to reduce boilerplate as much as possible.

@rtbs-dev
Contributor

@wesselb Am I right in thinking that support for this would automatically be included if beartype dispatch is supported (re: #53)? See, e.g., the relevant beartype features. I'm pretty sure any of these situations could be checked with the new door API and beartype validators.

@wesselb wesselb mentioned this issue Oct 23, 2022
@wesselb
Member

wesselb commented Oct 23, 2022

@tbsexton, combined with your proposal to perform dispatch solely using isinstance checks, I think you might be right! This is exciting!

@wesselb wesselb mentioned this issue Feb 20, 2023

4 participants