
entmax_bisect bugs with fp16/bf16 #30

Open
bpopeters opened this issue Aug 7, 2023 · 1 comment

@bpopeters
Collaborator

It would be great if entmax worked with torch.float16 and torch.bfloat16. Unfortunately, it currently does not: there are bugs in both the bisection-based and the exact algorithms. Here I'll document a numerical stability problem in the bisection-based algorithm that affects both torch.float16 and torch.bfloat16 (don't believe the propaganda that says bf16 is a drop-in replacement for float32).

Let's say you have a 32-bit vector of logits whose largest element is sufficiently negative.

a = torch.zeros(128, device="cuda").fill_(-5)  # torch.float32
a[0] = 0
a -= 1000  # largest element is now -1000, the rest are -1005

With alpha=1.5, the correct output for this vector is a one-hot distribution peaked on index 0. We get this behavior with both entmax.entmax15 and entmax.entmax_bisect.

p1 = entmax.entmax15(a)
p2 = entmax.entmax_bisect(a, alpha=1.5)

p1[0] == p2[0] == 1  # True

Ok, great. But what happens if we use torch.float16?

b = a.to(torch.float16)

p3 = entmax.entmax_bisect(b, alpha=1.5)
p3.isnan().all()  # True

and what about torch.bfloat16?

c = a.to(torch.bfloat16)

p4 = entmax.entmax_bisect(c, alpha=1.5)
p4.isnan().all()  # True

Well that's not good! (solution after this commercial break)

@bpopeters
Collaborator Author

As it turns out, we can use the same fix for both torch.float16 and torch.bfloat16. This isn't obvious, because the two formats make different tradeoffs relative to full-precision floats: fp16 keeps most of the mantissa at the cost of reduced range, while bf16 keeps the range at the cost of a reduced mantissa. But strangely enough, the classic softmax stability trick of subtracting the largest logit from the vector works for both. This makes intuitive sense for torch.float16 because of its reduced range, but what about for torch.bfloat16?

It turns out that bfloat16 has weird problems far away from zero:

x = torch.tensor(0, dtype=torch.bfloat16, device="cuda")
x == (x - 1)  # False, of course

y = torch.tensor(-500, dtype=torch.bfloat16, device="cuda")
y == (y - 1)  # True?!? bf16 can't represent -501, so the subtraction rounds back to -500
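To put numbers on that tradeoff, here's a quick check with torch.finfo (nothing entmax-specific, just standard PyTorch dtype info); it also explains the spacing in the example above:

import torch

# fp16: small range, fine spacing; bf16: huge range, coarse spacing
print(torch.finfo(torch.float16).max)   # 65504.0
print(torch.finfo(torch.bfloat16).max)  # ~3.39e38
print(torch.finfo(torch.float16).eps)   # ~0.000977 (2**-10)
print(torch.finfo(torch.bfloat16).eps)  # 0.0078125 (2**-7)

# Around magnitude 500, the gap between adjacent bf16 values is
# 2**-7 * 2**8 = 2, so -500 - 1 rounds straight back to -500.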

So that's why the softmax stability trick works: after subtracting the max, every logit is at most 0, and the values that matter sit close to zero, where both formats have enough resolution (and, for fp16, enough range). Bringing it back to our earlier examples:

b = a.to(torch.float16)
entmax.entmax_bisect(b - b.max(), alpha=1.5)  # one-hot!

c = a.to(torch.bfloat16)
entmax.entmax_bisect(c - c.max(), alpha=1.5)  # one-hot!
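Until a fix lands in the library, you can apply the shift yourself. Here's a minimal sketch of such a wrapper; the name entmax_bisect_stable is my own, and I'm assuming entmax_bisect's dim keyword argument:

import torch
import entmax

def entmax_bisect_stable(x, alpha=1.5, dim=-1):
    """Hypothetical wrapper: shift logits so the max along `dim` is 0
    before bisection. The shift doesn't change the entmax output, but it
    keeps low-precision inputs in a numerically safe range."""
    x = x - x.max(dim=dim, keepdim=True).values
    return entmax.entmax_bisect(x, alpha=alpha, dim=dim)

# Using the half-precision vectors from above:
p3 = entmax_bisect_stable(b, alpha=1.5)  # one-hot, no NaNs
p4 = entmax_bisect_stable(c, alpha=1.5)  # one-hot, no NaNs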

Tests and PR upcoming.

@bpopeters bpopeters self-assigned this Aug 7, 2023
@bpopeters bpopeters added the bug Something isn't working label Aug 7, 2023