
The missing parameter in jax.nn.hard_tanh activation #17632

Closed
samthakur587 opened this issue Sep 16, 2023 · 16 comments
Assignees
Labels
bug Something isn't working

Comments

@samthakur587

Description

@jax.jit
def hard_tanh(x: Array) -> Array:
  r"""Hard :math:`\mathrm{tanh}` activation function.

  Computes the element-wise function:

  .. math::
    \mathrm{hard\_tanh}(x) = \begin{cases}
      -1, & x < -1\\
      x, & -1 \le x \le 1\\
      1, & 1 < x
    \end{cases}

  Args:
    x : input array

  Returns:
    An array.
  """
  return jnp.where(x > 1, 1, jnp.where(x < -1, -1, x))

In the implementation of hard_tanh, min_val and max_val should be variables, so that the linear region is not fixed to the constant range -1 to +1. They should be taken as inputs when this method is called, to set the linear region of the hard_tanh function. Being able to change the linear region range would be much more helpful than only having the constant range -1 to 1, as mentioned in this paper.

This is the modified code to explain the issue.

@jax.jit
def hard_tanh(x: Array, min_val: Array = -1.0, max_val: Array = 1.0) -> Array:
  r"""Hard :math:`\mathrm{tanh}` activation function.

  Computes the element-wise function:

  .. math::
    \mathrm{hard\_tanh}(x) = \begin{cases}
      \mathrm{min\_val}, & x < \mathrm{min\_val}\\
      x, & \mathrm{min\_val} \le x \le \mathrm{max\_val}\\
      \mathrm{max\_val}, & \mathrm{max\_val} < x
    \end{cases}

  Args:
    x : input array
    min_val : array or scalar, lower bound of the linear region (default: -1.0)
    max_val : array or scalar, upper bound of the linear region (default: 1.0)

  Returns:
    An array.
  """
  return jnp.where(x > max_val, max_val, jnp.where(x < min_val, min_val, x))
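
For illustration, the proposed signature would allow an asymmetric clamp such as a bounded relu; the standalone copy of the function and the sample bounds below are only a sketch of how it would be used, not part of the proposal itself.

import jax.numpy as jnp

# Standalone copy of the proposed function so the snippet runs on its own.
def hard_tanh(x, min_val=-1.0, max_val=1.0):
  return jnp.where(x > max_val, max_val, jnp.where(x < min_val, min_val, x))

x = jnp.linspace(-5.0, 8.0, 14)
print(hard_tanh(x))                            # default clamp to [-1, 1]
print(hard_tanh(x, min_val=0.0, max_val=6.0))  # bounded-relu-style clamp, illustrative bounds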

What jax/jaxlib version are you using?

latest

Which accelerator(s) are you using?

cpu

Additional system info

python 3.9

NVIDIA GPU info

No response

@samthakur587 samthakur587 added the bug Something isn't working label Sep 16, 2023
@jakevdp
Collaborator

jakevdp commented Sep 20, 2023

I think the intent here is that if you want to scale the hard_tanh cutoffs, you can do so by scaling the input. For example:

x = jnp.linspace(-5, 5)
cutoff = 2.0
y1 = cutoff * jax.nn.hard_tanh(x / cutoff)
y2 = hard_tanh(x, min_val=-cutoff, max_val=cutoff)
jnp.allclose(y1, y2)
# True

This is similar, for example, to how you would scale sigmoid and other activation functions.

What do you think?

@jakevdp jakevdp self-assigned this Sep 20, 2023
@samthakur587
Author

Hi! @jakevdp sorry for the late reply. Yeah, you are right, we can use the input-scaling method to change the linear region range in the hardtanh activation function, similar to other activation functions, e.g. sigmoid and step threshold.

What if we want different min_val and max_val, though? For example, we might want an activation function that is like relu but has an upper bound max_val and a lower bound min_val = 0. Then I think hardtanh with max_val and min_val is the better option.

x = jnp.linspace(-5, 5)
cutoff = 2
y1 = cutoff * jax.nn.hard_tanh(x / cutoff)
y2 = hard_tanh(x, min_val=0, max_val=cutoff)
jnp.allclose(y1, y2)
# False: scaling the input cannot reproduce asymmetric bounds

And I think that just adding the max_val and min_val parameters to the hardtanh function is computationally cheaper than scaling the input. That is why I raised this issue.

Thanks

@jakevdp
Collaborator

jakevdp commented Sep 25, 2023

Thanks – if you want to change the central point of the activation function, you could do nn.hard_tanh(x - offset). I know it's a different parameterization than expressing it via min_val and max_val, but it should work without needing to change the function's API.

I know adding min_val and max_val would be an easy change to make, but I'm not totally clear on why one would need them when other activation functions in jax.nn don't have similar parameters. Do you have examples of a reference that uses a hard tanh activation while separately tuning these parameters?
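
For concreteness, combining the suggested shift with the earlier scaling trick recovers arbitrary bounds without any API change (the helper name and the bounds below are made up for this sketch):

import jax
import jax.numpy as jnp

def clamp_via_hard_tanh(x, lo, hi):
  # Shift and scale so that hard_tanh's [-1, 1] linear region maps onto [lo, hi].
  center = (hi + lo) / 2.0
  half_width = (hi - lo) / 2.0
  return half_width * jax.nn.hard_tanh((x - center) / half_width) + center

x = jnp.linspace(-5.0, 5.0, 101)
print(jnp.allclose(clamp_via_hard_tanh(x, 0.0, 2.0), jnp.clip(x, 0.0, 2.0)))  # True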

@samthakur587
Author

Hi! @jakevdp I have just implemented the hardtanh activation in ivy (Unified AI), where I implemented it for all the backends: numpy, tensorflow, torch and jax. I found that the torch implementation requires min_val & max_val; here are the docs.

So if jax had these parameters as well, it would be very helpful when switching the backend from jax to torch, and the same for the other backends.

Advantages of the parameters:

  1. Fewer operations than scaling the input and applying a cutoff:
    with scaling we first scale the input and then multiply the hardtanh output by the cutoff, so it needs more operations than the parameter implementation.
  2. Lower runtime:
    since the parameter implementation needs fewer operations, its computation time is lower than the scaling implementation. hardtanh is mostly used in NLP tasks, where deep RNNs, LSTMs and Transformers require a large number of operations, so the computation cost can affect training. I have measured the time difference (see the rough timing sketch after this list).

(Screenshot 2023-09-27 010349: timing comparison)

  3. Framework consistency:
    keeping the parameters the same across frameworks is helpful; if torch and jax have the same parameters, the function is easier to use for someone who knows torch and is not familiar with jax.
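
A rough way to check the runtime claim (a sketch, not a rigorous benchmark; the array size, cutoff, and the use of jnp.clip as a stand-in for the parameterized version are choices made only for illustration):

import timeit
import jax
import jax.numpy as jnp

x = jnp.linspace(-5.0, 5.0, 1_000_000)
cutoff = 2.0

scaled = jax.jit(lambda x: cutoff * jax.nn.hard_tanh(x / cutoff))
clipped = jax.jit(lambda x: jnp.clip(x, -cutoff, cutoff))  # stand-in for hard_tanh(x, min_val, max_val)

# Warm up compilation so only execution time is measured.
scaled(x).block_until_ready()
clipped(x).block_until_ready()

print("scaled :", timeit.timeit(lambda: scaled(x).block_until_ready(), number=100))
print("clipped:", timeit.timeit(lambda: clipped(x).block_until_ready(), number=100))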

You can keep it the same, but this is the reason why I raised this issue. I have also learned a lot from this issue as you guided me; I hope this contribution is valuable for you as well. 😄

Thanks
SamThakur

@jewillco

A more general version of hard_tanh is jax.numpy.clip; it allows arbitrary upper and lower bounds. You could implement the more general hard_tanh using that as hard_tanh(x, limit) = jax.numpy.clip(x, -limit, limit).
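
A minimal sketch of that suggestion (the wrapper name is just for illustration):

import jax
import jax.numpy as jnp

def hard_tanh_with_limit(x, limit=1.0):
  # Equivalent to hard_tanh with the linear region widened to [-limit, limit].
  return jnp.clip(x, -limit, limit)

x = jnp.linspace(-5.0, 5.0, 101)
print(jnp.allclose(hard_tanh_with_limit(x), jax.nn.hard_tanh(x)))  # True at the default limit of 1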

@jakevdp
Collaborator

jakevdp commented Oct 11, 2023

Thanks for the links to pytorch equivalents. I must say I find it strange that Hardtanh has these parameters when similar activation functions like Tanh and Hardsigmoid do not. It suggests to me that there is not much need for adding such knobs to activation functions in general.

Do you know what it is that makes hard_tanh special in this respect? Is there some use-case you have where you need to tune the limits of hard_tanh? Or are you mainly interested in adding this parameter because it exists in the equivalent pytorch function?

@samthakur587
Author

Hi! @jakevdp yeah, we can implement hardtanh without these parameters, but the parameters have some benefits when creating a frontend function with two or more different backends (jax, torch, etc.). In that case we have to pass some variable along with the input tensor or array, like the cutoff = 2 you suggested, to change the min and max value of hardtanh. So we would have to write the function as hardtanh(input, cutoff=1), and we would also have to handle the case where the min and max values are different. Then why not just add these two parameters, min and max, and save the extra operations and the extra case to handle? This is how I implemented hardtanh for the ivy experimental API here, and I am not able to use your 'jax.nn.hard_tanh()' API there because we need these two parameters to set the linear region range.

Also, hardtanh was made because it is computationally cheaper. It does, however, saturate for magnitudes of x greater than 1, and scaling the tanh beyond +1 and -1 does not make sense because it makes it more computationally inefficient; the same goes for hardsigmoid.

Here are some more references for hardtanh using these parameters.

  1. paddleapddle : docs
  2. mlpack : PR/issue
  3. pytorch : docs

If you are interested in adding this parameter in jax as well, I am ready to make the PR.

@jakevdp
Collaborator

jakevdp commented Oct 12, 2023

Thanks – I'm still curious about two things:

  1. why is hard_tanh the only activation function for which this kind of flexibility is implemented? Why is this necessary when no similar activation functions have such parameterizations?
  2. are you interested in this because of any particular use-case, or is it only because pytorch and others have this parameterization? Is there any application for which this is important?

Regarding the computational efficiency argument: I'm not totally convinced that's relevant. The XLA compiler can generally fuse operations like this to the point where different ways of expressing things lower to the same efficient operation at runtime.
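
One way to see this (a sketch; the exact fusion behaviour depends on the backend) is to inspect the jaxprs of the two formulations, which each consist of a handful of element-wise ops that XLA can typically fuse into a single kernel:

import jax
import jax.numpy as jnp

cutoff = 2.0
scaled = lambda x: cutoff * jax.nn.hard_tanh(x / cutoff)   # scaling formulation
clipped = lambda x: jnp.clip(x, -cutoff, cutoff)           # parameterized/clip formulation

x = jnp.linspace(-5.0, 5.0, 8)
print(jax.make_jaxpr(scaled)(x))   # a few element-wise ops (divide, compare/select, multiply)
print(jax.make_jaxpr(clipped)(x))  # element-wise min/max-style clamp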

@jewillco

Why not use jax.numpy.clip instead? That will already do everything you want.

@jakevdp
Collaborator

jakevdp commented Oct 12, 2023

To be clear, I'm not against adding min/max parameterizations to jax.nn.hard_tanh. But it increases JAX's API surface & testing surface (if only marginally), and I'd like to understand which users would benefit. This function has existed for four years (since #1262) and to my knowledge nobody has asked for this sort of parameterization before.

@jewillco

Wouldn't this new version of hard_tanh just be an alias of clip with defaults for the a_min and a_max parameters? That would not be hard to do if people wanted to do it, though.

@samthakur587
Author

@jakevdp the reason why I raised this issue is that I am not able to call the jax.nn.hard_tanh API while implementing this PR, because I have to pass at least some parameter to change the range. So it would be a benefit in my case, because the other backends don't have this problem.

@jewillco

Why not call jax.numpy.clip instead?

@samthakur587
Author

Hi! @jakevdp and @jewillco, thanks for your guidance. I understand that this will work fine without these parameters and no change is needed, and I also don't want to increase JAX's API surface & testing surface (if only marginally).

Thanks. Can I close this issue?

@jewillco

I think you can implement exactly what you want (including the extra parameters you desire) just by calling clip rather than hard_tanh; that way, you get all of the functionality you want for your code and no changes are needed in JAX.

@jakevdp
Collaborator

jakevdp commented Nov 8, 2023

I think we can close this issue – thanks for the discussion!

@jakevdp jakevdp closed this as completed Nov 8, 2023