
ConViT GPSA layer is missing attention normalization, inconsistent with the original paper and the torch version #664

Open
wtomin opened this issue May 30, 2023 · 0 comments

wtomin commented May 30, 2023

The `get_attention` function in MindCV's GPSA layer implementation does not normalize the attention weights `attn`:

    def get_attention(self, x: Tensor) -> Tensor:
        B, N, C = x.shape
        q = ops.reshape(self.q(x), (B, N, self.num_heads, C // self.num_heads))
        q = ops.transpose(q, (0, 2, 1, 3))
        k = ops.reshape(self.k(x), (B, N, self.num_heads, C // self.num_heads))
        k = ops.transpose(k, (0, 2, 3, 1))

        pos_score = self.pos_proj(self.rel_indices)
        pos_score = ops.transpose(pos_score, (0, 3, 1, 2))
        pos_score = self.softmax(pos_score)
        patch_score = self.batch_matmul(q, k)
        patch_score = ops.mul(patch_score, self.scale)
        patch_score = self.softmax(patch_score)

        gating = ops.reshape(self.gating_param, (1, -1, 1, 1))
        gating = ops.Sigmoid()(gating)
        attn = (1.0 - gating) * patch_score + gating * pos_score
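        # NOTE: the torch version below divides attn by its sum over the last
        # axis at this point; that normalization step is missing here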
        attn = self.attn_drop(attn)
        return attn

For comparison, here is the torch version from the Hugging Face implementation:

    def get_attention(self, x):
        B, N, C = x.shape
        qk = self.qk(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k = qk[0], qk[1]
        pos_score = self.rel_indices.expand(B, -1, -1, -1)
        pos_score = self.pos_proj(pos_score).permute(0, 3, 1, 2)
        patch_score = (q @ k.transpose(-2, -1)) * self.scale
        patch_score = patch_score.softmax(dim=-1)
        pos_score = pos_score.softmax(dim=-1)

        gating = self.gating_param.view(1, -1, 1, 1)
        attn = (1. - torch.sigmoid(gating)) * patch_score + torch.sigmoid(gating) * pos_score
        attn /= attn.sum(dim=-1).unsqueeze(-1) # attention normalized by its sum
        attn = self.attn_drop(attn)
        return attn

Although it is not clear whether this normalization has a large or small impact on performance, I think it is best to keep the implementation consistent with the original paper.
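
For reference, a minimal sketch of how the tail of `get_attention` could look with the normalization added, keeping the MindSpore style of the snippet above (the choice of `ops.ReduceSum` is my suggestion; the step itself mirrors the `attn /= attn.sum(dim=-1).unsqueeze(-1)` line in the torch version and is not a tested patch):

        gating = ops.reshape(self.gating_param, (1, -1, 1, 1))
        gating = ops.Sigmoid()(gating)
        attn = (1.0 - gating) * patch_score + gating * pos_score
        # normalize attention by its sum over the last axis,
        # matching the original paper and the torch implementation
        attn = attn / ops.ReduceSum(keep_dims=True)(attn, -1)
        attn = self.attn_drop(attn)
        return attn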

@wtomin wtomin added the bug Something isn't working label May 30, 2023