Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Objax variables could be used in place of native JAX arrays #207

Merged
merged 4 commits into from
Mar 19, 2021
Merged

Objax variables could be used in place of native JAX arrays #207

merged 4 commits into from
Mar 19, 2021

Conversation

AlexeyKurakin
Copy link
Member

@AlexeyKurakin AlexeyKurakin commented Mar 13, 2021

This is a prototype implementation of #111

This is reasonably well tested and complete implementation of JAX duck typing, which allows to use Objax variables in mathematical expressions without explicitly calling .value.
This change does not include documentation, updates of examples, etc... which will be done in future pull requests.

@AlexeyKurakin
Copy link
Member Author

One drawback of the proposed design is that it requires to define a bunch of __OP__ methods on BaseVar.
This is a python limitation, it does not have any elegant workaround and it's needed to support things like var1 + var2 where var1 and var2 are Objax variables.
It somewhat bloats the BaseVar class.

While there is no way to avoid it, it is possible to "move" definitions of __OP__ methods outside of the BaseVar class. For example, JAX does it this way: https://github.com/google/jax/blob/80966fe5bfb123a9b7b0200eb9aa53bcef5f1a2b/jax/_src/numpy/lax_numpy.py#L5390

In case of BaseVar it could be done in a following way:

# Definition of BaseVar
class BaseVar(abc.ABC):
    def __init__(self, reduce: Optional[Callable[[JaxArray], JaxArray]]):
        self._reduce = reduce

    @property
    @abc.abstractmethod
    def value(self) -> JaxArray:
        raise NotImplementedError('Pure method')

    @value.setter
    @abc.abstractmethod
    def value(self, tensor: JaxArray):
        raise NotImplementedError('Pure method')

    def assign(self, tensor: JaxArray, check=True):
        # ...

    def reduce(self, tensors: JaxArray):
        # ...

    def __repr__(self):
        # ...


# ... definition of other classses 


# In the end of the file
# Support for JAX duck typing
_UNARY_OPS = [
    '__neg__',
    #...
]
_BINARY_OPS = [
    '__add__',
    #...
]

for unary_op in _UNARY_OPS:
    setattr(BaseVar, unary_op, lambda self, other, op_name=unary_op: getattr(self.value, op_name)())

for binary_op in BINARY_OPS:
    setattr(BaseVar, binary_op, lambda self, other, op_name=binary_op: getattr(self.value, op_name)(get_jax_value(other)))

This would remove clutter from BaseVar code, however will add clutter in the end of file and potentially may create confusion about how BaseVar works.

So I would be interested to hear opinions about it

@AlexeyKurakin AlexeyKurakin changed the title [DRAFT] Objax variables could be used in place of native JAX arrays Objax variables could be used in place of native JAX arrays Mar 19, 2021
@AlexeyKurakin AlexeyKurakin merged commit 58bac37 into google:master Mar 19, 2021
@AlexeyKurakin AlexeyKurakin deleted the dotvalue branch March 19, 2021 16:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants