
Programmatically list the "Numpy" definitions not implemented by "jax.numpy". #3689

Closed
KelSolaar opened this issue Jul 8, 2020 · 5 comments
Labels
enhancement New feature or request


@KelSolaar

KelSolaar commented Jul 8, 2020

Hi,

I'm trying to use Jax as a computation backend for Colour, and I have no easy (and fast) way to programmatically find which Numpy definitions are or are not supported by jax.numpy: I can only discover it when they are called.

What I am currently looking at is a mechanism that routes the definitions to the selected backend and falls back to Numpy if they do not exist or are not implemented.

This is the relevant content of the `colour.ndarray.backend` module:

```python
import os

# Backend name, e.g. "Numpy" or "Jax"; compared case-insensitively.
_NDIMENSIONAL_ARRAY_BACKEND = os.environ.get(
    'COLOUR_SCIENCE__NDIMENSIONAL_ARRAY_BACKEND', 'Numpy').lower()


class NDimensionalArrayBackend(object):
    def __init__(self):
        import numpy

        self._failsafe = self._numpy = numpy

        try:
            import jax.numpy

            self._jax = jax.numpy
        except ImportError:
            self._jax = None

    def __getattr__(self, attribute):
        # The Numpy definition is the fallback for every backend.
        failsafe = getattr(self._failsafe, attribute)

        if _NDIMENSIONAL_ARRAY_BACKEND == 'numpy':
            return getattr(self._numpy, attribute)
        elif _NDIMENSIONAL_ARRAY_BACKEND == 'jax' and self._jax is not None:
            try:
                return getattr(self._jax, attribute)
            except AttributeError:
                # Only names missing from jax.numpy land here; "not
                # implemented" stubs are real attributes and are returned above.
                return failsafe
        else:
            return failsafe
```

Then the colour.ndarray.__init__ module is implemented as follows:

```python
from __future__ import absolute_import

import sys

from .backend import NDimensionalArrayBackend


class ndarray(NDimensionalArrayBackend):
    def __getattr__(self, attribute):
        return super(ndarray, self).__getattr__(attribute)


# Replace this module with an instance that routes attribute lookups.
sys.modules['colour.ndarray'] = ndarray()

del NDimensionalArrayBackend, sys
```

Thus, instead of `import numpy as np`, I can now `import colour.ndarray as np`, and the code is routed according to the `_NDIMENSIONAL_ARRAY_BACKEND` global.
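The module swap works because `import` returns whatever object is already registered in `sys.modules` under that name, and the instance's `__getattr__` then handles every attribute access. A minimal sketch of the same trick, using a hypothetical `Router` class, a hypothetical module name `demo_ndarray`, and the standard `math` module standing in for the array backend:

```python
import math
import sys


class Router(object):
    """Hypothetical stand-in for NDimensionalArrayBackend: forwards lookups."""

    def __init__(self, backend):
        self._backend = backend

    def __getattr__(self, attribute):
        # Only called when normal lookup fails, i.e. for backend names.
        return getattr(self._backend, attribute)


# Register an instance, not a module, under a (hypothetical) module name.
sys.modules['demo_ndarray'] = Router(math)

import demo_ndarray as np  # found directly in sys.modules, no file needed

print(np.sqrt(16.0))  # 4.0
```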

The problem is that if some of my code uses a Jax definition that is not implemented, e.g. `np.copy`, it raises an exception when called instead of falling back to Numpy.
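The `except AttributeError` fallback cannot help here: a not-implemented stub is a real attribute of `jax.numpy`, so `getattr` succeeds and the error only surfaces when the function is called. A minimal pure-Python sketch of that failure mode, where `fake_jnp`, `_not_implemented_stub` and `lookup` are hypothetical stand-ins (for `jax.numpy`, its stub wrapper and the `try`/`except` branch) so it runs without JAX:

```python
import functools
import types


def _not_implemented_stub(func):
    """Hypothetical stand-in for a wrapper that raises when called."""
    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        raise NotImplementedError(
            'Numpy function {} not yet implemented'.format(func.__name__))
    return wrapped


def copy(a):
    """Reference ("failsafe") implementation."""
    return list(a)


# The stub is a genuine attribute of the backend namespace.
fake_jnp = types.SimpleNamespace(copy=_not_implemented_stub(copy))


def lookup(backend, attribute, failsafe):
    # Mirrors the try/except AttributeError branch of the backend class.
    try:
        return getattr(backend, attribute)
    except AttributeError:
        return failsafe


func = lookup(fake_jnp, 'copy', failsafe=copy)
print(func is fake_jnp.copy)  # True: the stub, not the failsafe, is returned

try:
    func([1, 2, 3])
except NotImplementedError as error:
    print(error)  # Numpy function copy not yet implemented
```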

Looking at the `jax.numpy.__init__` module, it would be trivial to record the list of not-implemented definitions somewhere, since they are created here:

```python
globals()[func.__name__] = lax_numpy._not_implemented(func)
```

Hope that makes sense!

Cheers

Thomas

KelSolaar added a commit to colour-science/colour that referenced this issue Jul 8, 2020
@jakevdp
Collaborator

jakevdp commented Jul 8, 2020

We could add some mechanism for that. Note that not-implemented functions are also defined elsewhere, for example:

`jax/jax/numpy/lax_numpy.py`, lines 4325 to 4328 in fdd7f0c:

```python
# These methods are mentioned explicitly by nondiff_methods, so we create
# _not_implemented implementations of them here rather than in __init__.py.
# TODO(phawkins): implement these.
argpartition = _not_implemented(np.argpartition)
```

Would a global list of not-implemented functions do what you want? For example, a list of strings stored in `jax.numpy._NOT_IMPLEMENTED` or similar?

@jakevdp
Collaborator

jakevdp commented Jul 8, 2020

One way you can get the list in the current release is like this, though it's a bit of a hack:

```python
In [1]: import inspect
   ...: import jax.numpy
   ...:
   ...: def unimplemented():
   ...:     for name in dir(jax.numpy):
   ...:         f = getattr(jax.numpy, name)
   ...:         try:
   ...:             source = inspect.getsource(f)
   ...:         except (TypeError, OSError):
   ...:             continue
   ...:         if "Numpy function {} not yet implemented" in source:
   ...:             yield name
```

In [2]: list(unimplemented())
['<lambda>',
 '_add_newdoc_ufunc',
 '_fastCopyAndTranspose',
 'add_docstring',
 'add_newdoc',
 'alen',
 'apply_along_axis',
 'apply_over_axes',
 'argpartition',
 'array2string',
 'array_equiv',
 'array_split',
 'asanyarray',
 'asarray_chkfinite',
 'ascontiguousarray',
 'asfarray',
 'asfortranarray',
 'asmatrix',
 'asscalar',
 'base_repr',
 'binary_repr',
 'bmat',
 'busday_count',
 'busday_offset',
 'byte_bounds',
 'choose',
 'common_type',
 'compare_chararrays',
 'copy',
 'copyto',
 'datetime_as_string',
 'datetime_data',
 'delete',
 'deprecate',
 'diag_indices_from',
 'disp',
 'fill_diagonal',
 'find_common_type',
 'format_float_positional',
 'format_float_scientific',
 'frombuffer',
 'fromfile',
 'fromfunction',
 'fromiter',
 'frompyfunc',
 'fromregex',
 'fromstring',
 'fv',
 'genfromtxt',
 'get_array_wrap',
 'get_include',
 'get_printoptions',
 'getbufsize',
 'geterr',
 'geterrcall',
 'geterrobj',
 'histogram2d',
 'histogramdd',
 'i0',
 'info',
 'insert',
 'int_asbuffer',
 'interp',
 'intersect1d',
 'invert',
 'ipmt',
 'irr',
 'is_busday',
 'isfortran',
 'isnat',
 'issctype',
 'issubclass_',
 'lax_numpy',
 'lexsort',
 'loads',
 'loadtxt',
 'lookfor',
 'mafromtxt',
 'maximum_sctype',
 'may_share_memory',
 'min_scalar_type',
 'mintypecode',
 'mirr',
 'modf',
 'nanmedian',
 'nanpercentile',
 'nanquantile',
 'ndfromtxt',
 'nested_iters',
 'nper',
 'npv',
 'obj2sctype',
 'partition',
 'piecewise',
 'place',
 'pmt',
 'poly',
 'polyder',
 'polydiv',
 'polyfit',
 'polyint',
 'ppmt',
 'printoptions',
 'put',
 'put_along_axis',
 'putmask',
 'pv',
 'rate',
 'ravel_multi_index',
 'real_if_close',
 'recfromcsv',
 'recfromtxt',
 'require',
 'resize',
 'round_',
 'safe_eval',
 'savetxt',
 'savez_compressed',
 'sctype2char',
 'set_numeric_ops',
 'set_string_function',
 'setbufsize',
 'setdiff1d',
 'seterr',
 'seterrcall',
 'seterrobj',
 'setxor1d',
 'shares_memory',
 'show',
 'sort_complex',
 'source',
 'spacing',
 'tril_indices_from',
 'trim_zeros',
 'triu_indices_from',
 'typename',
 'union1d',
 'unwrap',
 'who']

@KelSolaar
Author

KelSolaar commented Jul 8, 2020

Hi @jakevdp,

Thanks, a simple attribute with a list of strings would be fantastic and super clean.

Excellent idea about introspecting the code; I had not thought of it, but this will do for our tests!

@jakevdp
Collaborator

jakevdp commented Jul 9, 2020

With #3697, you can now do this:

```python
import jax.numpy as jnp
print(jnp._NOT_IMPLEMENTED)
```

The variable contains a list of names of unimplemented functions.
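A minimal sketch of how the backend's `__getattr__` could consult that list, assuming `_NOT_IMPLEMENTED` holds name strings as described above. Because the leading underscore marks it as private, the sketch reads it with a `getattr` default rather than assuming it exists; `resolve` is a hypothetical helper, and the `SimpleNamespace` objects stand in for `jax.numpy` and `numpy` so the example runs without either installed:

```python
import types


def resolve(attribute, backend, failsafe):
    """Return `attribute` from `backend`, or from `failsafe` as a fallback.

    Names listed in the backend's `_NOT_IMPLEMENTED` (when present) are
    treated as missing, so their stubs are never returned.
    """
    not_implemented = getattr(backend, '_NOT_IMPLEMENTED', ())
    if backend is not None and attribute not in not_implemented:
        try:
            return getattr(backend, attribute)
        except AttributeError:
            pass
    return getattr(failsafe, attribute)


def _stub(*args, **kwargs):
    raise NotImplementedError('not yet implemented')


# Stand-ins for numpy and jax.numpy.
numpy_like = types.SimpleNamespace(copy=lambda a: list(a), sum=sum)
jax_like = types.SimpleNamespace(
    copy=_stub, sum=lambda a: sum(a), _NOT_IMPLEMENTED=['copy'])

print(resolve('copy', jax_like, numpy_like)([1, 2]))         # [1, 2]
print(resolve('sum', jax_like, numpy_like) is jax_like.sum)  # True
```

In `NDimensionalArrayBackend.__getattr__`, the Jax branch could then reduce to `return resolve(attribute, self._jax, self._failsafe)`.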

@KelSolaar
Author

Awesome, @jakevdp, and thanks again!

3 participants