
Enforcing positivity (or other transformations) of TrainVars #165

Closed

wil-j-wil opened this issue Dec 2, 2020 · 10 comments

@wil-j-wil

Hi,

Is it possible to declare constraints on trainable variables, e.g. forcing them to be positive via an exponential or softplus transformation?

In an ideal world, we would be able to write something like:
self.variance = objax.TrainVar(np.array(1.0), transform=positive)

Thanks,

Will

p.s. thanks for the great work on objax so far; it's a pleasure to use.

@david-berthelot
Contributor

It can be done, but only by applying a function to the variable when it is used; basically, by using standard Python.

Example:

import objax
import jax.numpy as jn

class MyModule(objax.Module):
    def __init__(self):
        self.v = objax.TrainVar(jn.ones([]))

    def __call__(self, x):
        return jn.abs(self.v.value) * x

In this example I used jn.abs to make the variable's effective value positive; you can of course use any math expression you'd like.
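
For example, a smooth softplus-based variant of the same idea looks like this (a minimal sketch; the module name MyPositiveModule is just illustrative):

import jax
import jax.numpy as jn
import objax

class MyPositiveModule(objax.Module):
    def __init__(self):
        # The raw, unconstrained parameter that actually gets trained.
        self.v = objax.TrainVar(jn.zeros([]))

    def __call__(self, x):
        # softplus(v) > 0 for any real v, so the scale applied to x stays positive.
        return jax.nn.softplus(self.v.value) * x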

You can go as fancy as you want with standard Python:

import objax
import jax.numpy as jn

class MyModule(objax.Module):
    def __init__(self):
        self.v = objax.TrainVar(jn.ones([]))

    @property
    def positive_v(self):
        return jn.abs(self.v.value)

    def __call__(self, x):
        return self.positive_v * x

Would that work for you?

@wil-j-wil
Author

Hi,

Thanks for the quick response. The solution you suggest absolutely works; the only thing lacking is that it would be really nice if self.v.value itself returned the positive value. That way the code becomes cleaner, and when users want to inspect a variable they don't have to know which positivity transform to apply. Initialising the variables would also become slightly cleaner.

TensorFlow Probability has a similar feature called TransformedVariable: https://www.tensorflow.org/probability/api_docs/python/tfp/util/TransformedVariable?hl=en
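
For reference, the TFP feature looks roughly like this (TensorFlow Probability code, not Objax; the Softplus bijector is just an example choice): the variable is stored in unconstrained space, and reading it returns the constrained value.

import tensorflow_probability as tfp

tfb = tfp.bijectors
# Stored internally in unconstrained space; reading it back gives a positive value.
variance = tfp.util.TransformedVariable(1.0, bijector=tfb.Softplus())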

However, I can see that implementing such a feature could be more effort than it's worth, depending on whether others would use it too.

@david-berthelot
Contributor

Yes, I see your point. It's really a balancing act between very explicit behavior (a variable is just a container for a matrix/tensor, and any other properties of that variable come from modifications applied to it when it is used) and implicit behavior (a variable is a matrix combined with more functionality that can continuously grow over time).

We're trying really hard to avoid bloat and retain explicit behaviors, so that the framework stays really easy to learn thanks to its focused, minimal APIs.

That being said, I'm interested in hearing more opinions. @AlexeyKurakin What do you think?

@AlexeyKurakin
Member

AlexeyKurakin commented Dec 9, 2020

I don't think it's a good idea to add transformation functionality to existing Objax variables for the reasons David stated above.

However, as @wil-j-wil mentioned, we can have a wrapper which performs arbitrary transformations of an underlying variable.
Here is a possible example:

from typing import Callable

import jax.numpy as jn
import objax
from objax.typing import JaxArray


class TransformedVar(objax.BaseVar):
    def __init__(self, wrapped_var: objax.BaseVar, transform: Callable[[JaxArray], JaxArray]):
        self._v = wrapped_var
        self._transform = transform

    @property
    def value(self) -> JaxArray:
        """The value is read only as a safety measure to avoid accidentally making TrainVar non-differentiable.
        You can write a value to a TrainVar by using assign."""
        return self._transform(self._v.value)

    @value.setter
    def value(self, tensor: JaxArray):
        raise ValueError('Assignment of transformed variable is not allowed.')


v = objax.TrainVar(...)
positive_v = TransformedVar(v, lambda x: jn.abs(x))

At this point I would be reluctant to add such a wrapper to the core Objax code, to keep the codebase simple and concise. However, you can easily add it to your own code, and I think it should satisfy your needs.
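
For example, usage could look roughly like this (assuming the TransformedVar sketch above instantiates as written; writes go through the wrapped TrainVar rather than through the wrapper):

import jax.numpy as jn
import objax

raw = objax.TrainVar(jn.zeros(3))              # unconstrained parameter that gets trained
positive = TransformedVar(raw, jn.exp)         # reads back exp(raw) on access

print(positive.value)                          # [1. 1. 1.]
raw.assign(jn.log(jn.array([0.5, 2.0, 4.0])))  # updates go through the raw TrainVar
print(positive.value)                          # [0.5 2.  4. ]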

Moreover, we could add a tutorial which implements TransformedVar (inside the tutorial code) and shows how to use it to transform variables.
@wil-j-wil would you be willing to contribute a tutorial showing how to apply transformations to variables?

@david-berthelot
Contributor

Actually you could make it trainable:

from typing import Optional, Callable

import objax
from objax.typing import JaxArray


class TrainVarTransform(objax.TrainVar):
    def __init__(self, tensor: JaxArray, reduce: Optional[Callable[[JaxArray], JaxArray]] = lambda x: x.mean(0),
                 transform: Optional[Callable[[JaxArray], JaxArray]] = None):
        super().__init__(tensor, reduce)
        self.transform = transform

    @property
    def value(self) -> JaxArray:
        if self.transform:
            return self.transform(self._value)
        return self._value

You can inherit from TrainVar for trainable variables and make your own. You can also inherit from StateVar for non-trainable variables.
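
As a rough sketch of the non-trainable variant (assuming objax.StateVar stores its tensor in self._value the same way TrainVar does):

from typing import Callable, Optional

import objax
from objax.typing import JaxArray


class StateVarTransform(objax.StateVar):
    def __init__(self, tensor: JaxArray,
                 transform: Optional[Callable[[JaxArray], JaxArray]] = None):
        super().__init__(tensor)
        self.transform = transform

    @property
    def value(self) -> JaxArray:
        # The raw tensor is stored untransformed; the transform is applied on read.
        if self.transform:
            return self.transform(self._value)
        return self._value

    @value.setter
    def value(self, tensor: JaxArray):
        # Writes store the raw (untransformed) tensor, as plain StateVar does.
        self._value = tensor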

@wil-j-wil
Author

This sounds totally reasonable, and I see your point and agree there is no need to add to the core library.

@david-berthelot's idea is very elegant and would be ideal; however, I don't think it works quite as expected. The gradients seem to be off; maybe overriding the value property is causing something strange to happen here? Also, when I try to JIT the training loop with this approach, I get the following error:

jax.core.UnexpectedTracerError: Encountered an unexpected tracer. Perhaps this tracer escaped through global state from a previously traced function.
The functions being transformed should not save traced values to global state.

I will take a closer look later, and if we find a nice solution I'd be happy to contribute a short tutorial.

@david-berthelot
Contributor

Feel free to share a code snippet if you need help.

@wil-j-wil
Author

Here's the most basic working example I could come up with, adapted from the tutorials. I have defined TrainVarTransform exactly as @david-berthelot did above.

The SimpleModule class defines a TrainVar and then uses an explicit transformation, and gives the correct gradients. SimpleModuleTransform instead uses TrainVarTransform, but gives incorrect gradients.

import jax.numpy as jn
import objax

# TrainVarTransform is defined exactly as in the comment above.


class SimpleModule(objax.Module):

    def __init__(self, length):
        self.v1 = objax.TrainVar(jn.log(jn.ones((length,))))

    @property
    def v1_positive(self):
        return jn.exp(self.v1.value)

    def __call__(self, x_):
        return jn.dot(x_, self.v1_positive)


class SimpleModuleTransform(objax.Module):

    def __init__(self, length):
        self.v1 = TrainVarTransform(jn.log(jn.ones((length,))), transform=jn.exp)

    def __call__(self, x_):
        return jn.dot(x_, self.v1.value)


m = SimpleModule(3)
mt = SimpleModuleTransform(3)


x = jn.ones((3,))
y = jn.array((-1.0, 1.0))


def loss_fn(x_, y_):
    return ((m(x_) - y_) ** 2).sum()


def loss_fnt(x_, y_):
    return ((mt(x_) - y_) ** 2).sum()


print('loss_fn(x, y) = ', loss_fn(x, y))

module_vars = m.vars()
module_varst = mt.vars()

# Construct a module which computes gradients
gv = objax.GradValues(loss_fn, module_vars)
gvt = objax.GradValues(loss_fnt, module_varst)

# gv returns both gradients and values of original function
grads, value = gv(x, y)
gradst, valuet = gvt(x, y)

print('Gradients:')
for g, var_name in zip(grads, module_vars.keys()):
    print(g, ' w.r.t. ', var_name)
print()
print('Value: ', value)

print('Gradients:')
for g, var_name in zip(gradst, module_varst.keys()):
    print(g, ' w.r.t. ', var_name)
print()
print('Value: ', valuet)

This runs; however, the value and gradients for the second module don't match:

Gradients:
[12. 12. 12.]  w.r.t.  (SimpleModule).v1

Value:  [DeviceArray(20., dtype=float32)]

Gradients:
[88.66867 88.66867 88.66867]  w.r.t.  (SimpleModuleTransform).v1

Value:  [DeviceArray(135.00299, dtype=float32)]

and if I were to JIT SimpleModuleTransform then I would get the error message I posted above.

@david-berthelot
Contributor

Okay, I see what the problem is: .value is used as a variable to be updated (for example, .value += 0.5, which would not work if a transform were applied).

Instead, one could define a .tvalue property ("t" for transformed; the exact name is unimportant):

from typing import Optional, Callable

import jax.numpy as jn
import objax
from objax.typing import JaxArray


class TrainVarTransform(objax.TrainVar):
    def __init__(self, tensor: JaxArray, reduce: Optional[Callable[[JaxArray], JaxArray]] = lambda x: x.mean(0),
                 transform: Optional[Callable[[JaxArray], JaxArray]] = None):
        super().__init__(tensor, reduce)
        self.transform = transform

    @property
    def tvalue(self) -> JaxArray:
        if self.transform:
            return self.transform(self._value)
        return self._value


class SimpleModule(objax.Module):

    def __init__(self, length):
        self.v1 = objax.TrainVar(jn.log(jn.ones((length,))))

    @property
    def v1_positive(self):
        return jn.exp(self.v1.value)

    def __call__(self, x_):
        return jn.dot(x_, self.v1_positive)


class SimpleModuleTransform(objax.Module):

    def __init__(self, length):
        self.v1 = TrainVarTransform(jn.log(jn.ones((length,))), transform=jn.exp)

    def __call__(self, x_):
        return jn.dot(x_, self.v1.tvalue)


m = SimpleModule(3)
mt = SimpleModuleTransform(3)


x = jn.ones((3,))
y = jn.array((-1.0, 1.0))


def loss_fn(x_, y_):
    return ((m(x_) - y_) ** 2).sum()


def loss_fnt(x_, y_):
    return ((mt(x_) - y_) ** 2).sum()


print('loss_fn(x, y) = ', loss_fn(x, y))

module_vars = m.vars()
module_varst = mt.vars()

# Construct a module which computes gradients
gv = objax.GradValues(loss_fn, module_vars)
gvt = objax.GradValues(loss_fnt, module_varst)

# gv returns both gradients and values of original function
grads, value = gv(x, y)
gradst, valuet = gvt(x, y)

print('Gradients:')
for g, var_name in zip(grads, module_vars.keys()):
    print(g, ' w.r.t. ', var_name)
print()
print('Value: ', value)

print('Gradients:')
for g, var_name in zip(gradst, module_varst.keys()):
    print(g, ' w.r.t. ', var_name)
print()
print('Value: ', valuet)

That works, but it's not very satisfying since one now has to use .tvalue instead of .value, so we cannot simply replace a TrainVar in objax.nn.Linear with a TrainVarTransform. But at least it's a semi-elegant way of handling your use case.
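
To double-check the .tvalue approach end to end, here is a rough training-loop sketch following the usual Objax pattern (SGD optimizer plus Jit; the learning rate and step count are arbitrary), reusing mt, gvt, module_varst, x and y from the snippet above:

opt = objax.optimizer.SGD(module_varst)
lr = 0.1

def train_op(x_, y_):
    g, v = gvt(x_, y_)   # gradients w.r.t. the raw (untransformed) variables
    opt(lr, g)           # SGD updates the raw .value; the transform is not involved here
    return v

train_op = objax.Jit(train_op, gvt.vars() + opt.vars())

for _ in range(100):
    loss = train_op(x, y)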

@AlexeyKurakin
Member

It seems there are solutions for transformed variables which do not even require modifications to the Objax code.
At this point I'm closing the issue.
@wil-j-wil, if you have any more questions, feel free to reopen the issue.
