Usage of alpha #10

Closed
alibabadoufu opened this issue Oct 13, 2019 · 3 comments

@alibabadoufu

Hi,

May I know if we need to define a new trainable parameter for each head per layer for the alpha value? Could anyone be kind enough to show a simple example of how it could be used in a normal transformer?

Thanks!

@goncalomcorreia commented Oct 14, 2019

Hi,

We did define a trainable parameter for each head per layer. Here's a small code snippet that we used:

import torch
import torch.nn as nn


class AlphaChooser(nn.Module):
    """Learns one alpha per attention head, constrained to (1, 2]."""

    def __init__(self, head_count):
        """head_count (int): number of attention heads"""
        super(AlphaChooser, self).__init__()
        self.pre_alpha = nn.Parameter(torch.randn(head_count))

    def forward(self):
        # Squash the unconstrained parameter into (1, 2): alpha=1 recovers
        # softmax, alpha=2 is sparsemax. The clamp keeps alpha strictly above 1.
        alpha = 1 + torch.sigmoid(self.pre_alpha)
        return torch.clamp(alpha, min=1.01, max=2)

However, it's possible to have a single alpha per layer, or per transformer block!
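
For reference, here is a minimal sketch (not code from the repository) of how the per-head alphas could feed into entmax_bisect inside attention, reusing the AlphaChooser above. The tensor names and shapes are illustrative assumptions; entmax_bisect normalizes over the last dimension, so alpha is reshaped to broadcast per head:

import torch
from entmax import entmax_bisect

head_count = 4
alpha_chooser = AlphaChooser(head_count)      # AlphaChooser as defined above

# dummy attention scores: [batch, heads, query_len, key_len]
scores = torch.randn(2, head_count, 10, 10)

# reshape so each head's alpha broadcasts over batch and positions
alpha = alpha_chooser().view(1, head_count, 1, 1)

# entmax_bisect normalizes over the last dimension (the key positions)
p_attn = entmax_bisect(scores, alpha, n_iter=25)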

@alibabadoufu (Author)

Thanks so much, @goncalomcorreia! Nice work!

@hihihihiwsf

Hello!

How can I compute entmax_bisect when alpha has more than one element?
When I use

p_attn = entmax_bisect(x, alpha, n_iter=25)

where alpha = tensor([1.3691, 1.5766, 1.7588, 1.9206], grad_fn=...) and x has shape [batch_size, 4, d, d], I get the error:

The expanded size of the tensor (1) must match the existing size (4) at non-singleton dimension 3.

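Not an official answer, but the error looks like a broadcasting issue: entmax_bisect expands alpha to the shape of x with the normalized (last) dimension collapsed to 1, so a flat alpha of shape [4] cannot expand against x of shape [batch_size, 4, d, d]. A sketch of a likely fix, assuming the four alphas are meant to be one per attention head and continuing the snippet above, is to add explicit broadcast dimensions:

# x: [batch_size, 4, d, d]; alpha: one value per head, shape [4]
alpha = alpha.view(1, 4, 1, 1)   # broadcast over batch and both d dimensions
p_attn = entmax_bisect(x, alpha, n_iter=25)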
