[activations] pytorch-1.11+ Tanh Gelu Approximation #15397

Open
stas00 opened this issue Jan 28, 2022 · 14 comments
Labels: Performance, WIP

Comments

stas00 (Contributor) commented Jan 28, 2022

🚀 Feature request

As kindly flagged by @vadimkantorov, pt-1.11 will ship a fast tanh GELU approximation, implemented in pytorch/pytorch#61439, so we could replace our manual implementation with the fast one when pt>=1.11 is detected.

For additional context, please see this thread: pytorch/pytorch#39853
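
A rough sketch of what the version gate could look like (not an actual transformers patch; packaging is assumed for version parsing, and the exact cutoff would need to track whichever release actually ships the approximate= argument):

import math
import torch
import torch.nn.functional as F
from packaging import version

def _python_gelu_new(x):
    # the existing manual tanh approximation
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

if version.parse(torch.__version__) >= version.parse("1.11"):
    # native fused kernel from pytorch/pytorch#61439
    def gelu_new(x):
        return F.gelu(x, approximate="tanh")
else:
    gelu_new = _python_gelu_new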

jaketae (Contributor) commented Jan 29, 2022

Hi @stas00, may I take a look into this? Below is a list of to-do's I can think of, but of course, there could be more.

  1. Compare output diffs: Using the torch version should not affect model output.
  2. Review affected models/code: The upstream PR seems to be a drop-in replacement for HF transformers' gelu_new (i.e., gelu_new(x) would become F.gelu(x, approximate="tanh")), but more investigation is needed to see whether other GELU occurrences in this repo can also be replaced.
  3. Performance benchmarks: Evaluate the speedup gains from using the new implementation.

What do you think?

stas00 (Contributor, Author) commented Jan 29, 2022

By all means, @jaketae - thank you!

The most important thing is numerical backward compatibility. Since transformers are used for research, we must not change the outputs even by a little bit, and any new change that does alter the outcomes needs to be explicitly enabled. At least that's my understanding of the "policy"; I don't know, though, how to quantify how much of a change is OK. When I proposed adding torch.jit to our activation functions, which makes them more accurate since accumulation then happens in fp32, I was told to create new functions instead. Hope that gives you enough forewarning of what will and won't be accepted.

jaketae (Contributor) commented Jan 29, 2022

I definitely agree numerical BC is key here. I think we can run extensive tests using (1) random tensors as input, and (2) a model's forward and backward passes. I assume we'll also have to check across devices and dtypes.
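
A minimal sketch of the kind of forward/backward comparison meant here (gelu_new is transformers' Python tanh approximation, quoted later in this thread; the shapes are arbitrary placeholders, and a torch build with the approximate= argument is assumed):

import math
import torch
import torch.nn.functional as F

def gelu_new(x):
    # transformers' Python tanh approximation of GELU
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

x1 = torch.randn(8, 768, requires_grad=True)
x2 = x1.detach().clone().requires_grad_(True)

y1 = gelu_new(x1)
y2 = F.gelu(x2, approximate="tanh")  # needs a torch build that ships the tanh approximation

# forward comparison
print("forward max abs diff:", (y1 - y2).abs().max().item())

# backward comparison: propagate the same upstream gradient through both
y1.backward(torch.ones_like(y1))
y2.backward(torch.ones_like(y2))
print("backward max abs diff:", (x1.grad - x2.grad).abs().max().item())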

Should this issue be tabled until PyTorch 1.11 is released? IIRC the current stable is 1.10.2. Alternatively, I could use the nightly build to get started.

stas00 (Contributor, Author) commented Jan 30, 2022

It's your call: you can wait till pt-1.11 is released, probably in a month or so, or you can start with the nightly, get everything ready, and merge it once it's released.

huggingface deleted a comment from the github-actions bot on Feb 28, 2022
stas00 added the WIP label on Feb 28, 2022
jaketae (Contributor) commented Mar 14, 2022

PyTorch 1.11 is out! I could maybe get started on this if you haven't already @stas00?

stas00 (Contributor, Author) commented Mar 14, 2022

go for it, Jaesung!

jaketae (Contributor) commented Mar 14, 2022

Upon more investigation, I realized 1.11 didn't ship with hyperbolic tangent GELU. I'll use the nightly instead.

jaketae (Contributor) commented Mar 22, 2022

Apologies for the delay @stas00. Here is a quick update of where I am at the moment.

1. Understanding PyTorch's GELU Approximation

The implementation can be found here:

    if (approximate == GeluType::Tanh) {
      AT_DISPATCH_FLOATING_TYPES_AND(
          ScalarType::BFloat16, it.dtype(), "GeluKernelImpl", [&]() {
            using Vec = vec::Vectorized<scalar_t>;
            const Vec kBetaVec(scalar_t(M_SQRT2 * M_2_SQRTPI * 0.5));
            const Vec kKappaVec(scalar_t(0.044715));
            const Vec kOneVec(scalar_t(1));
            const Vec kPointFiveVec(scalar_t(0.5));
            cpu_kernel_vec(
                it,
                [](scalar_t x) {
                  const scalar_t kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
                  const scalar_t kKappa = 0.044715;
                  auto x_cube = x * x * x;
                  auto inner = kBeta * (x + kKappa * x_cube);
                  return scalar_t(0.5) * x * (scalar_t(1) + std::tanh(inner));
                },
                // ... (vectorized lambda and closing braces omitted)
As noted in the docs, this boils down to

\text{GELU}(x) = 0.5 \cdot x \cdot \left(1 + \text{Tanh}\!\left(\sqrt{2/\pi}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right)

HF transformers has a number of GELU implementations, but the one which corresponds to this specific variant appears to be ACT2FN["gelu_new"], which is implemented as

return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))

Hence, I investigated whether the output of gelu_new(x) equals that of F.gelu(x, approximate="tanh").

2. Preliminary Experiments

A simple first-level check might be to generate a random tensor and compare the outputs of the two functions. Here is the experiment, with a link to the Colab notebook.

import math
import torch
import torch.nn.functional as F

# gelu_new is transformers' tanh approximation quoted above
def gelu_new(x):
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

NUM_TRIALS = 1000

cpu_equal = []
cpu_allclose = []

for _ in range(NUM_TRIALS):
    x_cpu = torch.randn(3, 3)
    torch_cpu = F.gelu(x_cpu, approximate="tanh")
    hf_cpu = gelu_new(x_cpu)
    cpu_equal.append(torch.equal(torch_cpu, hf_cpu))
    cpu_allclose.append(torch.allclose(torch_cpu, hf_cpu, rtol=1e-6))

# fraction of trials passing each check
print(sum(cpu_equal) / NUM_TRIALS)
print(sum(cpu_allclose) / NUM_TRIALS)

The same experiment was conducted on the GPU by replacing x_cpu = torch.randn(3, 3) with x_gpu = torch.randn(3, 3, device="cuda"). There is no particular reason behind the choice of tensor size (3, 3); it seemed small enough to check the tensor values by hand.

Given an rtol of 1e-6 and 1000 iterations per device, here are the results:

cpu equal: 0.213
cpu allclose: 0.991
gpu equal: 0.952
gpu allclose: 0.997

The two implementations agree more often on the GPU. In particular, torch.equal passes in only about 1 out of 5 runs on CPU.

3. Next Steps

Here is a non-exhaustive list of tasks.

  • Conduct more experiments with different tensor sizes, rtol parameters, and dtypes (see the sketch after this list)
  • Run basic experiments (e.g. training and inference) with some models with replaced activation functions to see if results are reproducible
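
A hedged sketch of such a sweep (the sizes, dtypes, and tolerance pairs are arbitrary picks; bfloat16 is included only because it appears in the CPU kernel's dispatch list above, and gelu_new is repeated here so the snippet runs standalone):

import itertools
import math
import torch
import torch.nn.functional as F

def gelu_new(x):
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

sizes = [(3, 3), (128, 768), (1024, 1024)]   # arbitrary choices
dtypes = [torch.float32, torch.bfloat16]     # bfloat16 is in the CPU kernel's dispatch list
tolerances = [(1e-6, 1e-6), (1.3e-6, 1e-5)]  # (rtol, atol); the second pair is torch's float32 default

for size, dtype, (rtol, atol) in itertools.product(sizes, dtypes, tolerances):
    x = torch.randn(*size, dtype=dtype)
    ref_f = gelu_new(x).float()
    out_f = F.gelu(x, approximate="tanh").float()
    max_diff = (ref_f - out_f).abs().max().item()
    ok = torch.allclose(ref_f, out_f, rtol=rtol, atol=atol)
    print(f"size={size} dtype={dtype} rtol={rtol} atol={atol} allclose={ok} max_abs_diff={max_diff:.2e}")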

Generally, my intuition is that replacing gelu_new with the new PyTorch GELU is a risky move, given that torch.equal does not pass in all cases. As you noted previously,

we must not change the outputs even by a little bit

Unless there is a performance overhead we are concerned with, I do not see a compelling reason to make what could be a dangerous transition.

stas00 (Contributor, Author) commented Mar 22, 2022

That's a great report, Jaesung! Thank you!

Well, bit-for-bit equality can't be expected since the fused kernel evaluates the same approximation in a different order, so it's really about deciding on the atol/rtol values and whether the results are acceptable.

And of course you'll want to experiment with much larger tensors than 3x3 before drawing conclusions.

If I try with more realistic sizes I get 0 matches with close or equal:
https://colab.research.google.com/drive/1RrG_R_gdOI3c9ut4Emh-ausg7vwsStKy?usp=sharing

stas00 (Contributor, Author) commented Mar 22, 2022

@vadimkantorov, please have a look - the results are very different between the nightly fast version and the slow python approximation function.

vadimkantorov commented
I guess we need to tag @rdspring1 who authored pytorch/pytorch#61439...

jaketae (Contributor) commented Apr 8, 2022

Hey @rdspring1, could you kindly look into this? In particular, we've observed that the number of torch.allclose failures increases as the input tensor gets larger.

In the meantime, @stas00, do you think there's something actionable on our end? Perhaps I could load a model, replace the HF GELUs with the PyTorch approximation, and see whether the model outputs differ (they most surely will). What I'm not sure about is how to quantify/analyze this difference, as it's difficult to say with certainty that a given amount of atol/rtol difference is "fine" (or not).
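
One possible way to quantify it, as a toy sketch rather than a real HF checkpoint run: push the same hidden states through a small MLP block twice, once per activation, and report max absolute and relative differences of the outputs (the block, its sizes, and the metrics here are all placeholder assumptions):

import math
import torch
import torch.nn.functional as F

def gelu_new(x):
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

torch.manual_seed(0)
hidden = torch.randn(4, 128, 768)      # placeholder "hidden states"
fc_in = torch.nn.Linear(768, 3072)
fc_out = torch.nn.Linear(3072, 768)

def mlp(x, act):
    # toy stand-in for a transformer MLP block
    return fc_out(act(fc_in(x)))

with torch.no_grad():
    out_hf = mlp(hidden, gelu_new)
    out_pt = mlp(hidden, lambda x: F.gelu(x, approximate="tanh"))

abs_diff = (out_hf - out_pt).abs()
rel_diff = abs_diff / out_hf.abs().clamp_min(1e-12)
print("max abs diff:", abs_diff.max().item())
print("max rel diff:", rel_diff.max().item())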

stas00 (Contributor, Author) commented Apr 8, 2022

The creator of this feature would probably know the expected tolerance best, so we'll likely need to wait for their reply before we can proceed, unless of course you'd like to dig into the source code and work it out yourself.

rdspring1 commented Apr 8, 2022

In my personal tests, I used torch.allclose(torch_cpu, hf_cpu, rtol=1e-6, atol=1e-6).

For pytorch internal testing, they use double, long, and complex128 for their numpy reference check.
https://github.com/pytorch/pytorch/blob/master/test/test_ops.py#L207-L215

Here is the numpy tanh gelu reference implementation:
https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_methods_invocations.py#L8238-L8250

Here are the tolerances used in pytorch.
https://github.com/pytorch/pytorch/blob/master/torch/testing/_comparison.py#L1182-L1213

For torch.float32, the default tolerances are rtol=1.3e-6 and atol=1e-5.
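
Putting those pointers together, a small sketch (mine, not from PyTorch's test suite) that checks both float32 implementations against a float64 reference using the float32 defaults quoted above:

import math
import torch
import torch.nn.functional as F

def gelu_new(x):
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

x32 = torch.randn(1024, 1024)
ref64 = gelu_new(x32.double())  # float64 reference, following the double-precision check above

hf32 = gelu_new(x32)
torch32 = F.gelu(x32, approximate="tanh")

# default float32 tolerances mentioned above: rtol=1.3e-6, atol=1e-5
for name, out in [("gelu_new (fp32)", hf32), ("F.gelu tanh (fp32)", torch32)]:
    ok = torch.allclose(out.double(), ref64, rtol=1.3e-6, atol=1e-5)
    print(name, "close to fp64 reference:", ok)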
