Enforcing positivity (or other transformations) of TrainVars #165

Hi,

Is it possible to declare constraints on trainable variables, e.g. forcing them to be positive via an exponential or softplus transformation?

In an ideal world, we would be able to write something like:

```python
self.variance = objax.TrainVar(np.array(1.0), transform=positive)
```

Thanks,
Will

p.s. thanks for the great work on objax so far, it's a pleasure to use.

Comments
It can be done by applying a function to the variable when it is used; basically by using standard Python. Example:

```python
import objax
import jax.numpy as jn


class MyModule(objax.Module):
    def __init__(self):
        self.v = objax.TrainVar(jn.ones([]))

    def __call__(self, x):
        return jn.abs(self.v.value) * x
```

In this example I used `jn.abs` to keep the value positive. You can go as fancy as you want in your use of standard Python:

```python
import objax
import jax.numpy as jn


class MyModule(objax.Module):
    def __init__(self):
        self.v = objax.TrainVar(jn.ones([]))

    @property
    def positive_v(self):
        return jn.abs(self.v.value)

    def __call__(self, x):
        return self.positive_v * x
```

Would that work for you?
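The same pattern works with any transform. For illustration, here is a hypothetical variant using softplus (one of the transforms mentioned in the original question); the stored `TrainVar` stays unconstrained and only the value read through the property is positive:

```python
import jax
import jax.numpy as jn
import objax


# Hypothetical variant of the property pattern above, using softplus instead of abs.
class MyModuleSoftplus(objax.Module):
    def __init__(self):
        # The stored value is unconstrained; it is mapped to a positive number on read.
        self.v = objax.TrainVar(jn.zeros([]))

    @property
    def positive_v(self):
        return jax.nn.softplus(self.v.value)

    def __call__(self, x):
        return self.positive_v * x
```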
Hi, thanks for the quick response. The solution you suggest absolutely works; the only thing lacking is that it would be really nice if reading the variable's value returned the transformed result directly. TensorFlow Probability has a similar feature called TransformedVariable: https://www.tensorflow.org/probability/api_docs/python/tfp/util/TransformedVariable?hl=en However, I can see that implementing such a feature could be more effort than it's worth, depending on whether others would use it too.
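For context, this is roughly what the referenced TFP feature looks like, based on the linked docs (a sketch, not Objax code):

```python
import tensorflow_probability as tfp

# The underlying tf.Variable lives in unconstrained space; reading the
# TransformedVariable yields the constrained (here: positive) value.
variance = tfp.util.TransformedVariable(1.0, bijector=tfp.bijectors.Softplus())
```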
Yes, I see your point. It's really a balancing act between very explicit behavior (a variable is just a container for a matrix/tensor; any other properties of that variable come from modifications applied to it when it is used) and implicit behavior (a variable is a matrix combined with extra functionality that can keep growing over time). We're trying hard to avoid bloat and retain explicit behaviors so the framework stays easy to learn, with focused and minimal APIs. That being said, I'm interested in hearing more opinions. @AlexeyKurakin What do you think?
I don't think it's a good idea to add transformation functionality to existing Objax variables, for the reasons David stated above. However, as @wil-j-wil mentioned, we can have a wrapper which performs arbitrary transformations of an underlying variable:

```python
from typing import Callable

import jax.numpy as jn
import objax
from objax.typing import JaxArray


class TransformedVar(objax.BaseVar):
    def __init__(self, wrapped_var: objax.BaseVar, transform: Callable[[JaxArray], JaxArray]):
        self._v = wrapped_var
        self._transform = transform

    @property
    def value(self) -> JaxArray:
        """The value is read only as a safety measure to avoid accidentally making TrainVar non-differentiable.
        You can write a value to a TrainVar by using assign."""
        return self._transform(self._v.value)

    @value.setter
    def value(self, tensor: JaxArray):
        raise ValueError('Assignment of transformed variable is not allowed.')


v = objax.TrainVar(...)
positive_v = TransformedVar(v, lambda x: jn.abs(x))
```

At this point I would be reluctant to add such a wrapper to the core Objax code, to keep the codebase simple and concise. However, you can easily add it to your own code and I think it should satisfy your needs. Moreover, we could add a tutorial which implements such a wrapper.
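To make the intent concrete, here is a hypothetical usage sketch of the wrapper above (the names `positive_v` and the toy loss are made up for illustration): the plain `TrainVar` stays the trainable variable, and the wrapper is only a read-only positive view of it, so gradients are still taken with respect to the unconstrained value.

```python
import jax.numpy as jn
import objax

# Hypothetical usage of the TransformedVar sketch above.
v = objax.TrainVar(jn.zeros([]))        # unconstrained parameter; this is what gets trained
positive_v = TransformedVar(v, jn.exp)  # read-only, always-positive view of v


def loss(x):
    return ((positive_v.value * x - 1.0) ** 2).sum()


gv = objax.GradValues(loss, objax.VarCollection({'v': v}))
grads, values = gv(jn.ones(3))          # gradients are w.r.t. the raw, unconstrained v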
Actually you could make it trainable:

```python
from typing import Callable, Optional

import objax
from objax.typing import JaxArray


class TrainVarTransform(objax.TrainVar):
    def __init__(self, tensor: JaxArray, reduce: Optional[Callable[[JaxArray], JaxArray]] = lambda x: x.mean(0),
                 transform: Optional[Callable[[JaxArray], JaxArray]] = None):
        super().__init__(tensor, reduce)
        self.transform = transform

    @property
    def value(self) -> JaxArray:
        if self.transform:
            return self.transform(self._value)
        return self._value
```

You can inherit from objax.TrainVar, so the variable stays trainable, and only override how its value is read.
This sounds totally reasonable, and I see your point and agree there is no need to add this to the core library. @david-berthelot's idea is very elegant and would be ideal, however I don't think it works quite as expected. The gradients seem to be off; maybe overriding the value property interferes with how the gradients are computed?
I will take a closer look later, and if we find a nice solution I'd be happy to contribute a short tutorial.
Feel free to share a code snippet if you need help.
Here's the most basic working example I could come up with, adapting from the tutorials. I have defined a SimpleModule that applies the transform explicitly via a property, and a SimpleModuleTransform that relies on the TrainVarTransform subclass from the comment above:

```python
import jax.numpy as jn
import objax

# TrainVarTransform is the subclass defined in the previous comment.


class SimpleModule(objax.Module):
    def __init__(self, length):
        self.v1 = objax.TrainVar(jn.log(jn.ones((length,))))

    @property
    def v1_positive(self):
        return jn.exp(self.v1.value)

    def __call__(self, x_):
        return jn.dot(x_, self.v1_positive)


class SimpleModuleTransform(objax.Module):
    def __init__(self, length):
        self.v1 = TrainVarTransform(jn.log(jn.ones((length,))), transform=jn.exp)

    def __call__(self, x_):
        return jn.dot(x_, self.v1.value)


m = SimpleModule(3)
mt = SimpleModuleTransform(3)

x = jn.ones((3,))
y = jn.array((-1.0, 1.0))


def loss_fn(x_, y_):
    return ((m(x_) - y_) ** 2).sum()


def loss_fnt(x_, y_):
    return ((mt(x_) - y_) ** 2).sum()


print('loss_fn(x, y) = ', loss_fn(x, y))

module_vars = m.vars()
module_varst = mt.vars()

# Construct a module which computes gradients
gv = objax.GradValues(loss_fn, module_vars)
gvt = objax.GradValues(loss_fnt, module_varst)

# gv returns both gradients and values of original function
grads, value = gv(x, y)
gradst, valuet = gvt(x, y)

print('Gradients:')
for g, var_name in zip(grads, module_vars.keys()):
    print(g, ' w.r.t. ', var_name)
print()
print('Value: ', value)

print('Gradients:')
for g, var_name in zip(gradst, module_varst.keys()):
    print(g, ' w.r.t. ', var_name)
print()
print('Value: ', valuet)
```

This runs, however the value and gradients for the second module don't match those of the first, and the same issue shows up if I were to JIT the functions.
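One way to narrow this down might be the following check (a hypothetical diagnostic; it assumes VarCollection.tensors() reads each variable through its value property):

```python
# Both modules compute the same forward value outside of GradValues...
print(m(x), mt(x))
# ...but the variable collections report different tensors, because
# TrainVarTransform's overridden value already returns the exp-transformed value.
print(module_vars.tensors())   # raw, log-space parameters
print(module_varst.tensors())  # already transformed (positive) parameters
```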
Okay, I see what the problem is: it seems the value property is also read internally by Objax when GradValues collects the variables, so overriding it to apply the transform changes what the gradient machinery sees. Instead, one could define a separate property, tvalue, which applies the transform and leaves value untouched:

```python
from typing import Optional, Callable

import jax.numpy as jn
import objax
from objax.typing import JaxArray


class TrainVarTransform(objax.TrainVar):
    def __init__(self, tensor: JaxArray, reduce: Optional[Callable[[JaxArray], JaxArray]] = lambda x: x.mean(0),
                 transform: Optional[Callable[[JaxArray], JaxArray]] = None):
        super().__init__(tensor, reduce)
        self.transform = transform

    @property
    def tvalue(self) -> JaxArray:
        if self.transform:
            return self.transform(self._value)
        return self._value


class SimpleModule(objax.Module):
    def __init__(self, length):
        self.v1 = objax.TrainVar(jn.log(jn.ones((length,))))

    @property
    def v1_positive(self):
        return jn.exp(self.v1.value)

    def __call__(self, x_):
        return jn.dot(x_, self.v1_positive)


class SimpleModuleTransform(objax.Module):
    def __init__(self, length):
        self.v1 = TrainVarTransform(jn.log(jn.ones((length,))), transform=jn.exp)

    def __call__(self, x_):
        return jn.dot(x_, self.v1.tvalue)


m = SimpleModule(3)
mt = SimpleModuleTransform(3)

x = jn.ones((3,))
y = jn.array((-1.0, 1.0))


def loss_fn(x_, y_):
    return ((m(x_) - y_) ** 2).sum()


def loss_fnt(x_, y_):
    return ((mt(x_) - y_) ** 2).sum()


print('loss_fn(x, y) = ', loss_fn(x, y))

module_vars = m.vars()
module_varst = mt.vars()

# Construct a module which computes gradients
gv = objax.GradValues(loss_fn, module_vars)
gvt = objax.GradValues(loss_fnt, module_varst)

# gv returns both gradients and values of original function
grads, value = gv(x, y)
gradst, valuet = gvt(x, y)

print('Gradients:')
for g, var_name in zip(grads, module_vars.keys()):
    print(g, ' w.r.t. ', var_name)
print()
print('Value: ', value)

print('Gradients:')
for g, var_name in zip(gradst, module_varst.keys()):
    print(g, ' w.r.t. ', var_name)
print()
print('Value: ', valuet)
```

That works, but it's not very satisfying since one now has to remember to use tvalue instead of value.
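For completeness, a rough training-loop sketch continuing from the snippet above (assuming objax.optimizer.SGD; the learning rate and step count are arbitrary). Only the unconstrained raw value of v1 is updated by the optimizer, while the model always reads the positive value through tvalue:

```python
# Rough sketch continuing from the code above: train the transformed module with plain SGD.
opt = objax.optimizer.SGD(module_varst)


def train_step(x_, y_, lr=0.1):
    grads, values = gvt(x_, y_)
    opt(lr, grads)  # updates the raw, unconstrained value stored in mt.v1
    return values


for _ in range(100):
    train_step(x, y)

print('raw v1:      ', mt.v1.value)   # unconstrained (can be negative)
print('positive v1: ', mt.v1.tvalue)  # transformed, always positive
```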
It seems there are some solutions for transformed variables which don't even require modifications to the Objax code.