```python
# PermuteFormer - apply the position-dependent permutation P
q = q.gather(-1, self.permutation[:, :, :q.shape[2]].expand_as(q))
k = k.gather(-1, self.permutation[:, :, :k.shape[2]].expand_as(k))

# Apply the feature map to the queries and keys
Q = torch.nn.functional.elu(q) + 1
K = torch.nn.functional.elu(k) + 1

# PermuteFormer - position-dependent ratio scaling
Q *= (self.ratio.unsqueeze(-1) ** torch.arange(Q.shape[2], device=Q.device).unsqueeze(0)).unsqueeze(-1)
K *= ((1 / self.ratio).unsqueeze(-1) ** torch.arange(K.shape[2], device=K.device).unsqueeze(0)).unsqueeze(-1)

if mask is not None:
    K.masked_fill_(mask.unsqueeze(1).unsqueeze(-1), 0.0)

# Compute the KV matrix
KV = torch.einsum("nhsd,nhsm->nhmd", K, v)
# Compute the normalizer
Z = 1 / (torch.einsum("nhld,nhd->nlh", Q, K.sum(dim=2)) + self.eps)
# Finally compute and return the new values
V = torch.einsum("nhld,nhmd,nlh->nlhm", Q, KV, Z)
```
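As a side note on the numerics: the `(1 / ratio) ** position` factor applied to K overflows float32 quickly once the sequence gets moderately long, which is a likely source of the NaNs. A minimal NumPy sketch (the ratio value 0.9 and length 1024 are illustrative assumptions, not values from the code above):

```python
import numpy as np

# The K-side scale grows as (1/ratio)**position. In float32 it overflows
# to inf once (1/ratio)**i exceeds ~3.4e38; any inf entering the einsum
# then produces 0 * inf or inf - inf, i.e. NaN.
ratio = 0.9   # hypothetical value; anything not very close to 1 overflows
L = 1024      # moderate sequence length
scale = np.float32(1.0 / ratio) ** np.arange(L, dtype=np.float32)
print(np.isinf(scale).sum())  # number of positions whose scale overflowed
```

With `ratio = 0.9` the scale already leaves float32 range after roughly 840 positions, so any sequence longer than that produces non-finite values.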
But I always got a "nan" issue after 1~5 steps.
From my perspective, this may be caused by the ratio-scaling step above.
A trivial solution is to set the ratio to something very close to 1, so that K does not explode as the sequence length grows.
For the approach in the paper *Transformers are RNNs*, there is a better solution. Instead of multiplying K by large numbers and Q by small numbers, modifying the recurrences in equations (18) and (19) to s_i = r * s_{i-1} + ... and z_i = r * z_{i-1} + ... has the same effect. But AFAIK this requires modifying the CUDA code in the fast-transformers package.
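To illustrate that suggestion, here is a minimal NumPy sketch (not the fast-transformers CUDA kernel; names and shapes are illustrative) comparing the overflow-prone Q/K rescaling against the equivalent decayed recurrence s_i = r·s_{i-1} + k_i v_iᵀ, z_i = r·z_{i-1} + k_i, in which every state stays bounded:

```python
import numpy as np

def scaled_linear_attention(Q, K, V, r):
    """Causal linear attention with position rescaling:
    Q_i *= r**i, K_j *= (1/r)**j. Overflow-prone for long sequences."""
    L = Q.shape[0]
    i = np.arange(L)[:, None]
    Qs = Q * r ** i
    Ks = K * (1.0 / r) ** i
    out = np.empty_like(V)
    for t in range(L):
        S = Ks[: t + 1].T @ V[: t + 1]   # running sum of k_j v_j^T
        z = Ks[: t + 1].sum(axis=0)      # running sum of k_j
        out[t] = (Qs[t] @ S) / (Qs[t] @ z)
    return out

def recurrent_linear_attention(Q, K, V, r):
    """Same computation as a decayed recurrence:
    s_i = r * s_{i-1} + k_i v_i^T,  z_i = r * z_{i-1} + k_i."""
    s = np.zeros((K.shape[1], V.shape[1]))
    z = np.zeros(K.shape[1])
    out = np.empty_like(V)
    for t in range(Q.shape[0]):
        s = r * s + np.outer(K[t], V[t])
        z = r * z + K[t]
        out[t] = (Q[t] @ s) / (Q[t] @ z)
    return out

rng = np.random.default_rng(0)
L, d, m, r = 8, 4, 3, 0.95
Q = rng.random((L, d)) + 1.0   # positive features, like elu(x) + 1
K = rng.random((L, d)) + 1.0
V = rng.standard_normal((L, m))
print(np.allclose(scaled_linear_attention(Q, K, V, r),
                  recurrent_linear_attention(Q, K, V, r)))  # prints True
```

The two agree because the scaled dot product (q_t r^t)·(k_j r^{-j}) equals r^{t-j}(q_t·k_j), which is exactly what the decay factor r accumulates per step.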
Thanks for replying.
I still have another question.
In the paper, the property that the similarity function depends only on relative positions rather than absolute ones is a restriction on the result of the matrix product of q and k.
But it seems k and v are multiplied first, and only then q.
So how is the restriction on q × k guaranteed to carry over to q × (k × v)?
Thanks a lot.
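For what it's worth, the two orders give identical results because matrix multiplication is associative: Q(KᵀV) = (QKᵀ)V, so any relative-position structure imposed on QKᵀ is preserved when KᵀV is computed first. A quick NumPy check (shapes are arbitrary illustrations):

```python
import numpy as np

rng = np.random.default_rng(0)
Q = rng.standard_normal((5, 4))   # (seq_len, dim)
K = rng.standard_normal((5, 4))
V = rng.standard_normal((5, 3))

left = (Q @ K.T) @ V   # O(L^2) softmax-style order
right = Q @ (K.T @ V)  # O(L) linear-attention order

print(np.allclose(left, right))  # prints True
```

The causal case applies the same identity to each prefix K[:t]ᵀV[:t], so the equality still holds row by row.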
Hi,
I tried to integrate this method into the linear attention of
http://proceedings.mlr.press/v119/katharopoulos20a/katharopoulos20a.pdf
with the code above, but I always got a "nan" issue after 1~5 steps.
From my perspective, this may be caused by the ratio-scaling step, which multiplies Q by a very small number and K by a very big number when the position index is large.
Did I integrate it the correct way? Or do you have any suggestions?
Thanks.