In [1]:
import torch

In [11]:
from attention import _NonLocalBlockND

In [97]:
a = torch.randn(4, 128, 1024)

In [84]:
class NONLocalBlock1D_mutual(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock1D_mutual, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=1, sub_sample=sub_sample,
                                              bn_layer=bn_layer)

        if bn_layer:
            self.W = nn.Sequential(
                nn.Conv1d(in_channels=self.inter_channels*2, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                nn.BatchNorm1d(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = nn.Conv1d(in_channels=self.inter_channels*2, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)        
        
    def forward(self, x):
        '''
        :param x: (b, c, t, h, w)
        :return:
        '''

        batch_size = x.size(0)
        n_point = x.size(2)
        data_1 = x[0:1] # Serve as key

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(data_1).view(1, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)

        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)

        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)
        

        # fg_attention = torch.max(f_div_C, dim=1)[0][:,:,None].repeat(1, 1, n_point)
        fg_attention = torch.mean(f_div_C, dim=1)[:,:,None].repeat(1, 1, n_point)
        bg_attention = 1 - fg_attention
    
        fg_attention_features = torch.matmul(fg_attention, g_x)
        bg_attention_features = torch.matmul(bg_attention, g_x)

        y = torch.cat([fg_attention_features, bg_attention_features], dim=2)
#         y = (fg_attention_features + bg_attention_features) / 2
#         y = fg_attention_features
#         y = y.permute(0, 2, 1).contiguous()
#         y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y.permute(0, 2, 1).contiguous())
        z = W_y + x

        
        return z


In [88]:
class NONLocalBlock1D_mutual(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock1D_mutual, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=1, sub_sample=sub_sample,
                                              bn_layer=bn_layer)

        if bn_layer:
            self.W = nn.Sequential(
                nn.Conv1d(in_channels=self.inter_channels*2, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                nn.BatchNorm1d(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = nn.Conv1d(in_channels=self.inter_channels*2, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)        
        
    def forward(self, x):
        '''
        :param x: (b, c, t, h, w)
        :return:
        '''

        batch_size = x.size(0)
        n_point = x.size(2)
        data_1 = x[0:1] # Serve as key # (1, 128, 10224)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1) # (4, 1024, 128)

        theta_x = self.theta(data_1).view(1, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1) # (1, 1024, 128)

        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) # (4, 128, 512)

        f = torch.matmul(theta_x, phi_x)   #  (1, 1024, 128)  dot (4, 128, 512) = (4,1024 (theta), 512(phi) )
        f_div_C = F.softmax(f, dim=-1) # sum along phi = 1
        

        # fg_attention = torch.max(f_div_C, dim=1)[0][:,:,None].repeat(1, 1, n_point)
        fg_attention = torch.mean(f_div_C, dim=1)[:,:,None].repeat(1, 1, n_point)
        bg_attention = 1 - fg_attention
    
        fg_attention_features = torch.matmul(fg_attention, g_x)
        bg_attention_features = torch.matmul(bg_attention, g_x)

        y = torch.cat([fg_attention_features, bg_attention_features], dim=2)
#         y = (fg_attention_features + bg_attention_features) / 2
#         y = fg_attention_features
#         y = y.permute(0, 2, 1).contiguous()
#         y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y.permute(0, 2, 1).contiguous())
        z = W_y + x
        return z

In [94]:
mutual = NONLocalBlock1D_mutual(128, sub_sample=False, bn_layer=True)

In [95]:
mutual(a).shape

torch.Size([4, 128, 1024])

In [96]:
a.shape

torch.Size([4, 128, 1024])

In [None]:
class NONLocalBlock1D_bmutual(_NonLocalBlockND):
    def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
        super(NONLocalBlock1D_mutual, self).__init__(in_channels,
                                              inter_channels=inter_channels,
                                              dimension=1, sub_sample=sub_sample,
                                              bn_layer=bn_layer)

        if bn_layer:
            self.W = nn.Sequential(
                nn.Conv1d(in_channels=self.inter_channels*2, out_channels=self.in_channels,
                        kernel_size=1, stride=1, padding=0),
                nn.BatchNorm1d(self.in_channels)
            )
            nn.init.constant_(self.W[1].weight, 0)
            nn.init.constant_(self.W[1].bias, 0)
        else:
            self.W = nn.Conv1d(in_channels=self.inter_channels*2, out_channels=self.in_channels,
                             kernel_size=1, stride=1, padding=0)
            nn.init.constant_(self.W.weight, 0)
            nn.init.constant_(self.W.bias, 0)        
        
    def forward(self, x):
        '''
        :param x: (b, c, t, h, w)
        :return:
        '''

        batch_size = x.size(0)
        n_point = x.size(2)
        query_data = x[1:] # Serve as query # (6, 128, 1024)

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x[0] # (1024, 128)

        theta_x = self.theta(data_1).view(1, self.inter_channels, -1)
        theta_x = theta_x.view(-1, self.inter_channels) # (1024*(b-1), 128)

        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)[0] # (1024, 128)
        phi_x = phi_x.permute(1, 0)

        f = torch.matmul(theta_x, phi_x)   #  (1024*(b-1), 128)  dot (1024, 128) = (1024*(b-1), 1024)
        f_div_C = F.softmax(f, dim=-1) # sum along phi = 1
        

        # fg_attention = torch.max(f_div_C, dim=1)[0][:,:,None].repeat(1, 1, n_point)
        fg_attention = torch.mean(f_div_C, dim=1)[:,:,None].repeat(1, 1, n_point)
        bg_attention = 1 - fg_attention
    
        fg_attention_features = torch.matmul(fg_attention, g_x)
        bg_attention_features = torch.matmul(bg_attention, g_x)

        y = torch.cat([fg_attention_features, bg_attention_features], dim=2)
#         y = (fg_attention_features + bg_attention_features) / 2
#         y = fg_attention_features
#         y = y.permute(0, 2, 1).contiguous()
#         y = y.view(batch_size, self.inter_channels, *x.size()[2:])
#         W_y = self.W(y.permute(0, 2, 1).contiguous())
#         z = W_y + x
        return y

In [None]:

        g_x = self.g(x).view(batch_size, self.inter_channels, -1)
        g_x = g_x.permute(0, 2, 1)

        theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
        theta_x = theta_x.permute(0, 2, 1)
        phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
        f = torch.matmul(theta_x, phi_x)
        f_div_C = F.softmax(f, dim=-1)

        y = torch.matmul(f_div_C, g_x)
        y = y.permute(0, 2, 1).contiguous()
        y = y.view(batch_size, self.inter_channels, *x.size()[2:])
        W_y = self.W(y)
        z = W_y + x

In [92]:
anchor = a[0:1]

In [21]:
ref = a

In [24]:
ref.permute(0,2,1).shape

torch.Size([4, 1024, 128])

In [31]:
d = F.softmax(torch.matmul(anchor, ref.permute(0,2,1)), 1)

In [39]:
torch.mean(d, 1)

tensor([0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078,
        0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078,
        0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078,
        0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078,
        0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078,
        0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078,
        0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078,
        0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078,
        0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078,
        0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078,
        0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078,
        0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078, 0.0078,
        0.0078, 0.0078, 0.0078, 0.0078, 

In [51]:
mutual(a).shape

RuntimeError: shape '[4, 64, 1024]' is invalid for input of size 8388608