The missing parameter in jax.nn.hard_tanh activation #17632
Comments
I think the intent here is that if you want to scale the cutoff of the activation, you can scale the input and output:

```python
x = jnp.linspace(-5, 5)
cutoff = 2
y1 = cutoff * jax.nn.hard_tanh(x / cutoff)
y2 = hard_tanh(x, min_val=-cutoff, max_val=cutoff)
jnp.allclose(y1, y2)
# True
```

This is similar, for example, to how you would scale other activation functions. What do you think?
Hi @jakevdp, sorry for the late reply. You are right that we can use input scaling to change the linear-region range of the hardtanh activation, similar to other activation functions such as sigmoid or a step threshold. But what if we want different min_val and max_val values? Say we want an activation function that is like ReLU but has an upper bound max_val and a lower bound min_val = 0. Then I think hardtanh with max_val and min_val is the better option:

```python
x = jnp.linspace(-5, 5)
cutoff = 2
y1 = cutoff * jax.nn.hard_tanh(x / cutoff)
y2 = hard_tanh(x, min_val=0, max_val=cutoff)
jnp.allclose(y1, y2)
# False: symmetric scaling cannot reproduce the asymmetric [0, cutoff] range
```

I also think that just adding max_val and min_val to the hardtanh function is computationally cheaper than scaling the input. That is why I raised this issue. Thanks
Thanks – if you want to change the central point of the activation function, you could shift the input and output in addition to scaling them. I know adding parameters directly would be more convenient, though.
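For concreteness, here is a minimal sketch of that shift-and-scale approach; the hard_tanh_general name is hypothetical, not part of jax.nn:

```python
import jax
import jax.numpy as jnp

def hard_tanh_general(x, min_val=-1.0, max_val=1.0):
    # Hypothetical helper: recover arbitrary [min_val, max_val] bounds from
    # the standard hard_tanh by shifting to the interval's center, scaling by
    # its half-width, then undoing the transform on the output.
    center = (min_val + max_val) / 2
    scale = (max_val - min_val) / 2
    return center + scale * jax.nn.hard_tanh((x - center) / scale)

x = jnp.linspace(-5, 5)
print(jnp.allclose(hard_tanh_general(x, min_val=0.0, max_val=2.0),
                   jnp.clip(x, 0.0, 2.0)))  # True
```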
Hi @jakevdp! I have just implemented this function in a frontend that supports multiple backends. If JAX had these parameters, it would be very helpful when changing the backend from JAX to Torch, and the same for other backends; that is the main advantage of the parameters.

You can keep the function the same, but this is the reason I raised this issue. I have also learned a lot from this issue with your guidance, and I hope this contribution is valuable to you as well. 😄 Thanks
A more general version of hard_tanh, with min_val and max_val parameters, already exists in PyTorch as torch.nn.Hardtanh and torch.nn.functional.hardtanh.
Thanks for the links to the PyTorch equivalents. I must say I find it strange that Hardtanh has these parameters when similar activation functions like Tanh and Hardsigmoid do not. It suggests to me that there is not much need for adding such knobs to activation functions in general. Do you know what it is that makes Hardtanh different in this respect?
Hi @jakevdp, yes, we could implement hardtanh without these parameters, but they have some benefits when creating a frontend function that supports two or more backends such as JAX, Torch, etc. In that case we would have to pass extra variables along with the input tensor or array, as you suggested we can do with scaling.

Also, hardtanh was made because it is computationally cheaper than tanh. It does, however, saturate for magnitudes of x greater than 1, and scaling tanh to more than +1 and -1 does not make much sense because it makes the computation less efficient. The same goes for hardsigmoid.

Here are some more references for hardtanh using these parameters. If you are interested in adding these parameters to JAX as well, I am ready to make the PR.
Thanks – I'm still curious about two things:
Regarding the computational efficiency argument: I'm not totally convinced that's relevant. The XLA compiler can generally fuse operations like this to the point where different ways of expressing things lower to the same efficient operation at runtime.
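As an illustration of that point, a small sketch using only public JAX APIs (jax.jit, jax.make_jaxpr); the printed jaxpr shows the scaled form is just a few elementwise primitives:

```python
import jax
import jax.numpy as jnp

@jax.jit
def scaled_hard_tanh(x):
    cutoff = 2.0
    # The divide, clamp, and multiply are all cheap elementwise ops that
    # XLA fuses into a single kernel at compile time.
    return cutoff * jax.nn.hard_tanh(x / cutoff)

# Inspect the traced computation to see the handful of elementwise primitives.
print(jax.make_jaxpr(scaled_hard_tanh)(jnp.linspace(-5, 5)))
```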
Why not use jnp.clip?
To be clear, I'm not against adding min_val and max_val in principle.
Wouldn't this new version of hard_tanh be exactly equivalent to jnp.clip?
@jakevdp The reason I raised this issue is that I am not able to call the jax.nn.hard_tanh API while implementing this PR, because I have to pass at least some parameters to change the range. So adding them would benefit my case, since the other backends do not have this problem.
Why not call jnp.clip?
I think you can implement exactly what you want (including the extra parameters you desire) just by calling jnp.clip.
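A minimal sketch of that suggestion; the hardtanh wrapper and its PyTorch-style defaults are assumptions for illustration, not part of jax.nn:

```python
import jax
import jax.numpy as jnp

def hardtanh(x, min_val=-1.0, max_val=1.0):
    # Assumed frontend-style wrapper: the generalized hardtanh is just a clip.
    return jnp.clip(x, min_val, max_val)

x = jnp.linspace(-5, 5)
print(jnp.allclose(hardtanh(x), jax.nn.hard_tanh(x)))              # True (defaults)
print(jnp.allclose(hardtanh(x, 0.0, 2.0), jnp.clip(x, 0.0, 2.0)))  # True (asymmetric)
```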
I think we can close this issue – thanks for the discussion!
Description
In the implementation of hard_tanh, min_val and max_val should be variables, so that the linear region is not fixed between -1 and +1. They should be taken as inputs when this method is called, to set the linear region of the hard_tanh function. Being able to change the range of the linear region in hard_tanh would be much more helpful than only having the constant range of -1 to 1,
as mentioned in this paper.
This is modified code to explain the issue.
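A hypothetical sketch of what the proposed change might look like, with min_val and max_val exposed as parameters instead of the hard-coded -1 and +1; the where-based form loosely mirrors how jax.nn.hard_tanh clamps its input, but this is an illustration, not the issue's original snippet:

```python
import jax.numpy as jnp

def hard_tanh(x, min_val=-1.0, max_val=1.0):
    # Proposed variant (sketch): clipping bounds become parameters, so the
    # linear region is [min_val, max_val] rather than always [-1, 1].
    return jnp.where(x > max_val, max_val, jnp.where(x < min_val, min_val, x))
```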
What jax/jaxlib version are you using?
latest
Which accelerator(s) are you using?
cpu
Additional system info
python 3.9
NVIDIA GPU info
No response