The code is different from the original paper #54
Comments
Yep, also found those discrepancies!

```python
import torch
import torch.nn as nn


class Self_Attention(nn.Module):
    def __init__(self, inChannels, k=8):
        super(Self_Attention, self).__init__()
        embedding_channels = inChannels // k  # C_bar
        self.key = nn.Conv2d(inChannels, embedding_channels, 1)
        self.query = nn.Conv2d(inChannels, embedding_channels, 1)
        self.value = nn.Conv2d(inChannels, embedding_channels, 1)
        self.self_att = nn.Conv2d(embedding_channels, inChannels, 1)
        self.gamma = nn.Parameter(torch.tensor(0.0))
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """
        inputs:
            x: input feature map [Batch, Channel, Height, Width]
        returns:
            y: self-attention value + input feature [Batch, Channel, Height, Width]
            o: self-attention output [Batch, Channel, Height, Width]
        """
        batchsize, C, H, W = x.size()
        N = H * W                                   # Number of spatial locations
        f_x = self.key(x).view(batchsize, -1, N)    # Keys    [B, C_bar, N]
        g_x = self.query(x).view(batchsize, -1, N)  # Queries [B, C_bar, N]
        h_x = self.value(x).view(batchsize, -1, N)  # Values  [B, C_bar, N]
        s = torch.bmm(f_x.permute(0, 2, 1), g_x)    # Scores  [B, N, N]
        beta = self.softmax(s)                      # Attention map [B, N, N], normalized over key locations (dim=1)
        v = torch.bmm(h_x, beta)                    # Values x attention [B, C_bar, N]
        v = v.view(batchsize, -1, H, W)             # Recover input shape [B, C_bar, H, W]
        o = self.self_att(v)                        # Self-attention output [B, C, H, W]
        y = self.gamma * o + x                      # Learnable gamma + residual connection
        return y, o
```
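For anyone copy-pasting the class above, here is a quick smoke test I would use to sanity-check the shapes (a hypothetical example, assuming a 64-channel feature map; the sizes are arbitrary):

```python
att = Self_Attention(inChannels=64)
x = torch.randn(2, 64, 32, 32)   # [B, C, H, W]
y, o = att(x)
print(y.shape, o.shape)          # both torch.Size([2, 64, 32, 32])
```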
Apparently, as mentioned here, the max pooling inside the attention layer is just a design choice to reduce computation and memory overhead. This should close the issue.
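For reference, a minimal sketch of what that pooling could look like, assuming a 2x2 max pool on the key and value branches (the class name and pooling factor here are illustrative, not the repo's actual code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Self_Attention_Pooled(nn.Module):
    """Same layout as Self_Attention above, but with a 2x2 max pool on the
    key/value branches to shrink the attention map (hypothetical sketch)."""

    def __init__(self, inChannels, k=8):
        super().__init__()
        embedding_channels = inChannels // k  # C_bar
        self.key = nn.Conv2d(inChannels, embedding_channels, 1)
        self.query = nn.Conv2d(inChannels, embedding_channels, 1)
        self.value = nn.Conv2d(inChannels, embedding_channels, 1)
        self.self_att = nn.Conv2d(embedding_channels, inChannels, 1)
        self.gamma = nn.Parameter(torch.tensor(0.0))

    def forward(self, x):
        B, C, H, W = x.size()
        N = H * W
        f_x = F.max_pool2d(self.key(x), 2).view(B, -1, N // 4)    # Keys    [B, C_bar, N/4]
        h_x = F.max_pool2d(self.value(x), 2).view(B, -1, N // 4)  # Values  [B, C_bar, N/4]
        g_x = self.query(x).view(B, -1, N)                        # Queries [B, C_bar, N]
        s = torch.bmm(g_x.permute(0, 2, 1), f_x)                  # Scores  [B, N, N/4]
        beta = torch.softmax(s, dim=-1)                           # Normalize over pooled key locations
        v = torch.bmm(h_x, beta.permute(0, 2, 1))                 # Weighted values [B, C_bar, N]
        o = self.self_att(v.view(B, -1, H, W))                    # Output [B, C, H, W]
        return self.gamma * o + x, o
```

The pooling shrinks the attention map from [B, N, N] to [B, N, N/4], which is where the memory and compute saving comes from; the output shapes stay the same as in the unpooled version.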
Hello, first of all thank you for your code; it helped me study SAGAN in more depth. However, after reading the code I have the following questions: