The code is different from the original paper #54

Open
lmisssunl opened this issue Nov 30, 2020 · 2 comments

lmisssunl commented Nov 30, 2020

Hello, first of all thank you for your code; it has helped me study SAGAN in more depth. After reading it, however, I have the following questions:

  1. In the original paper, β_{j,i} and h(x) are combined as ∑_{i=1}^{N} β_{j,i} h(x_i), which is not what the torch.bmm matrix multiplication in your code computes (see the paper's formulation summarized after this list).
  2. For the third 1×1 convolution in the paper (i.e. h(x)), the output channel dimension is C//8, not C.
  3. The paper applies a final 1×1 convolution after the attention step, but it does not appear in your code (perhaps because, as in point 2, you changed the output of h(x) from C//8 to C, the final 1×1 convolution was no longer needed).
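
For reference, this is how I read the attention formulation in the paper (f, g, h, v are all 1×1 convolutions; f, g, h output C̄ = C/8 channels and v maps C̄ back to C):

    % SAGAN self-attention as written in the paper (my reading)
    s_{ij} = f(x_i)^{\top} g(x_j)
    \beta_{j,i} = \frac{\exp(s_{ij})}{\sum_{i=1}^{N} \exp(s_{ij})}
    o_j = v\!\left( \sum_{i=1}^{N} \beta_{j,i}\, h(x_i) \right)
    y_i = \gamma\, o_i + x_i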
lmisssunl reopened this Nov 30, 2020

valillon commented May 17, 2021

Yep, I also found those discrepancies!
The following could solve it.
Also make sure the softmax iterates over the keys for a given query (dim=1).

import torch
import torch.nn as nn

class Self_Attention(nn.Module):
    def __init__(self, inChannels, k=8):
        super(Self_Attention, self).__init__()
        embedding_channels = inChannels // k  # C_bar
        self.key      = nn.Conv2d(inChannels, embedding_channels, 1)
        self.query    = nn.Conv2d(inChannels, embedding_channels, 1)
        self.value    = nn.Conv2d(inChannels, embedding_channels, 1)
        self.self_att = nn.Conv2d(embedding_channels, inChannels, 1)
        self.gamma    = nn.Parameter(torch.tensor(0.0))
        self.softmax  = nn.Softmax(dim=1)

    def forward(self, x):
        """
            inputs:
                x: input feature map [Batch, Channel, Height, Width]
            returns:
                y: gamma * (self-attention output) + input feature   [Batch, Channel, Height, Width]
                o: self-attention output                             [Batch, Channel, Height, Width]
        """
        batchsize, C, H, W = x.size()
        N = H * W                                       # Number of features
        f_x = self.key(x).view(batchsize,   -1, N)      # Keys                  [B, C_bar, N]
        g_x = self.query(x).view(batchsize, -1, N)      # Queries               [B, C_bar, N]
        h_x = self.value(x).view(batchsize, -1, N)      # Values                [B, C_bar, N]

        s =  torch.bmm(f_x.permute(0,2,1), g_x)         # Scores                [B, N, N]
        beta = self.softmax(s)                          # Attention Map         [B, N, N]

        v = torch.bmm(h_x, beta)                        # Value x Softmax       [B, C_bar, N]
        v = v.view(batchsize, -1, H, W)                 # Recover input shape   [B, C_bar, H, W]
        o = self.self_att(v)                            # Self-Attention output [B, C, H, W]
        
        y = self.gamma * o + x                          # Learnable gamma + residual
        return y, o
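
A quick shape sanity check for the layer above (sizes are arbitrary, just to illustrate what the forward pass returns):

    layer = Self_Attention(inChannels=64, k=8)
    x = torch.randn(2, 64, 16, 16)      # [B, C, H, W]
    y, o = layer(x)
    print(y.shape, o.shape)             # both torch.Size([2, 64, 16, 16])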

valillon commented

Apparently, as mentioned here, the max pooling inside the attention layer is just a design choice to save computation/memory overhead. This should close the issue.
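
In case it helps, here is a minimal sketch of what that pooling could look like on top of the class above. This is my own adaptation of the idea, not the authors' code; the class name, the pooling factor of 2, and placing the pooling on the key/value branches are all assumptions.

    class Self_Attention_Pooled(Self_Attention):
        """Sketch: max-pool the key/value branches to shrink the attention map (assumes even H and W)."""
        def __init__(self, inChannels, k=8):
            super().__init__(inChannels, k)
            self.pool = nn.MaxPool2d(2)                             # halves H and W on the key/value paths

        def forward(self, x):
            batchsize, C, H, W = x.size()
            N  = H * W                                              # query locations
            Np = (H // 2) * (W // 2)                                # pooled key/value locations
            f_x = self.pool(self.key(x)).view(batchsize, -1, Np)   # Keys    [B, C_bar, Np]
            g_x = self.query(x).view(batchsize, -1, N)              # Queries [B, C_bar, N]
            h_x = self.pool(self.value(x)).view(batchsize, -1, Np)  # Values  [B, C_bar, Np]

            s = torch.bmm(f_x.permute(0, 2, 1), g_x)                # Scores  [B, Np, N]
            beta = self.softmax(s)                                  # softmax over pooled keys (dim=1)

            v = torch.bmm(h_x, beta).view(batchsize, -1, H, W)      # [B, C_bar, H, W]
            o = self.self_att(v)                                    # [B, C, H, W]
            return self.gamma * o + x, o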
