In [1]:
import torch
from torch import nn
from torch.nn import functional as F

# IMPLEMENTATION OF BILINEAR ATTENTION FOR SHIFT-NET

In [2]:
# Seed the global RNG so that a fresh Restart & Run All reproduces every random
# tensor drawn in this notebook (inputs, masks, U/V, emulated BAN vectors).
torch.manual_seed(0)
# Use the GPU when present; on a CUDA machine this is equivalent to `.cuda()`.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Two random 256x32x32 feature maps; `_unfold` treats dims 1 and 2 as the spatial grid.
x = torch.rand((256, 32, 32)).to(DEVICE)
y = torch.rand((256, 32, 32)).to(DEVICE)

In [3]:
# Unfold settings shared by every `_unfold` call below (including those inside
# BAN): 1x1 patches taken at every spatial position.
patch_size, stride = 1, 1

In [4]:
def _unfold(img, patch_size, stride, with_indexes=False):
    n_dim = 3
    assert img.dim() == n_dim, 'image must be of dimension 3.'

    kH, kW = patch_size, patch_size
    dH, dW = stride, stride
    input_windows = img.unfold(1, kH, dH).unfold(2, kW, dW)

    i_1, i_2, i_3, i_4, i_5 = input_windows.size(0), input_windows.size(1), input_windows.size(
        2), input_windows.size(3), input_windows.size(4)

    if with_indexes:
        input_windows = input_windows.permute(1, 2, 0, 3, 4).contiguous().view(i_2 * i_3, i_1)
        return input_windows, i_2, i_3, i_1, i_4
    else:
        input_windows = input_windows.permute(1, 2, 0, 3, 4).contiguous().view(i_2 * i_3, i_1, i_4, i_5)
    return input_windows

In [5]:
# Unfold both feature maps into per-position patches. For the 256x32x32 inputs
# with patch_size = stride = 1 this yields shape (1024, 256, 1, 1): one 1x1 patch
# holding all 256 channels for each of the 32*32 spatial positions.
x_unfolded = _unfold(x, patch_size, stride)
y_unfolded = _unfold(y, patch_size, stride)

In [6]:
def _filter(input_windows, flag, value):
    ## EXTRACT MASK OR NOT DEPENDING ON VALUE
    input_window = input_windows[flag == value]
    return input_window.view(input_window.size(0), -1)

In [7]:
# Random patch mask: each entry is True with probability 0.25 and marks a masked
# patch. It has one entry per unfolded patch, sized from the unfold output instead
# of a hard-coded 1024, so it stays valid for other input sizes or patch settings.
flag = torch.rand(x_unfolded.size(0)) > 0.75

In [8]:
# Split the patches with the mask: masked ones (flag == 1) are taken from x and
# unmasked ones (flag == 0) from y. Each patch is flattened to a row vector.
x_unfolded_filtered = _filter(x_unfolded, flag, 1)
y_unfolded_filtered = _filter(y_unfolded, flag, 0)

In [9]:
# (number of masked patches, flattened patch length)
x_unfolded_filtered.shape

torch.Size([264, 256])

In [10]:
# (number of unmasked patches, flattened patch length); together with the masked
# count, this sums to the total number of patches.
y_unfolded_filtered.shape

torch.Size([760, 256])

In [39]:
def create_var(x, y, dim=1024, _v=1, device='cuda'):
    """Allocate random projection matrices U and V for the bilinear attention.

    Args:
        x: 2-D tensor of masked patches; only ``x.size(0)`` (N) is used.
        y: 2-D tensor of unmasked patches; only ``y.size(0)`` (M) is used.
        dim: side of the square U and V when ``_v != 0``.
        _v: 0 -> size U and V from the actual patch counts: U is (N, N) and
            V is (M, N). Any other value -> allocate full (dim, dim) matrices,
            which BAN's ``_filter_M`` later crops to the current mask.
        device: device the matrices are moved to (default 'cuda', as before).

    Returns:
        Tuple (U, V) of float tensors with standard-normal entries.
    """
    if _v == 0:
        n_masked = x.size(0)
        u_shape, v_shape = (n_masked, n_masked), (y.size(0), n_masked)
    else:
        u_shape = v_shape = (dim, dim)
    # Sample with the CPU generator and then move, as `torch.randn(...).cuda()`
    # did, so seeded runs draw exactly the same values as before.
    U = torch.randn(u_shape).to(device)
    V = torch.randn(v_shape).to(device)
    return U, V

In [40]:
# With the defaults (_v=1, dim=1024) this allocates full 1024x1024 U and V once;
# BAN crops them to the emulated mask on every call.
%time U, V = create_var(x_unfolded_filtered, y_unfolded_filtered)

Wall time: 484 ms


In [41]:
# Both are (dim, dim) because create_var was called with its default _v=1.
print(U.shape)
print(V.shape)

torch.Size([1024, 1024])
torch.Size([1024, 1024])


In [79]:
def BAN(x, y, U, V):
    """Debug version of bilinear attention (BAN) on an emulated Shift-Net mask.

    Prints the operand shapes and the einsum equation before every contraction so
    the shape flow can be inspected.

    Args:
        x, y: 3-D feature maps, unfolded with the global ``patch_size`` / ``stride``.
        U, V: square projection matrices with side >= the number of patches
            (e.g. from ``create_var`` with the default ``_v=1``). They are cropped
            to the emulated mask internally; the caller's tensors are not modified.

    Returns:
        Tensor of shape (K, p): K masked patches, p = flattened patch length.
    """
    def _bilinear_attention_map(x, y, U, V, v1, P):
        # Logits ((v1 outer P) * X^T U) @ (V^T Y), shape (p, r), softmax over dim 1.
        print(v1.shape, P.shape, 'i,j->ij')
        tmp = torch.einsum('i,j->ij', v1, P)  # (p, K)
        print(tmp.shape)
        print()

        print(x.shape, U.shape, 'ki,kj->ij')
        XT_U = torch.einsum('ki,kj->ij', [x, U])  # X^T U: (p, K)
        print(XT_U.shape)
        print()

        print(tmp.shape, XT_U.shape, 'ij,ij->ij')
        tmp1 = torch.einsum('ij,ij->ij', [tmp, XT_U])  # element-wise product: (p, K)
        print(tmp1.shape)
        print()

        print(V.shape, y.shape, 'ki,kj->ij')
        VT_y = torch.einsum('ki,kj->ij', [V, y])  # V^T Y: (K, r)
        print(VT_y.shape)
        print()

        print(tmp1.shape, VT_y.shape, 'ik,kj->ij')
        tmp = torch.einsum('ik,kj->ij', [tmp1, VT_y])  # (p, r)
        print(tmp.shape)
        print()

        # NOTE(review): dividing by the max rescales the logits before softmax. If
        # the max is negative this reverses their order, and if it is 0 it divides
        # by zero. Softmax is shift-invariant, so `tmp - tmp.max()` is the usual
        # stabiliser. Confirm the intent.
        tmp /= tmp.max()

        A = F.softmax(tmp, dim=1)
        print(A.shape)
        print()
        return A, XT_U, VT_y

    def _filter_M(U, V, flag):
        # Crop the full projection matrices to the current mask:
        # U -> (K, K) masked x masked, V -> (M, K) unmasked x masked.
        mask_indexes = (flag == 1).nonzero().t()   # (1, K) row of indices
        non_mask_indexes = (flag == 0).nonzero()   # (M, 1) column of indices
        # Indexing with a column and a row of indices broadcasts to a sub-matrix.
        U = U[mask_indexes.t(), mask_indexes]
        V = V[non_mask_indexes, mask_indexes]
        return U, V

    def _emulate(x, y):
        # Stand-in for the real pipeline: a random patch mask plus random vectors
        # P (length K) and v1 (length p).
        x_unfolded = _unfold(x, patch_size, stride)
        y_unfolded = _unfold(y, patch_size, stride)
        # One flag per patch, sized from the unfold output instead of a hard-coded
        # 1024 so other input sizes or patch settings work.
        flag = torch.rand(x_unfolded.size(0)) > 0.75
        X = _filter(x_unfolded, flag, 1)
        Y = _filter(y_unfolded, flag, 0)

        K = X.size(0)
        p = X.size(1)

        # Sample with the CPU generator and move to the inputs' device. This draws
        # the same values as `.cuda()` on a GPU run, but also works on CPU.
        P = torch.randn(K).to(X.device)
        v1 = torch.randn(p).to(X.device)
        return X, Y, P, v1, flag

    ## EMULATE NEW MASK GENERATED
    print('EMULATE')
    X, Y, P, v1, flag = _emulate(x, y)
    U, V = _filter_M(U, V, flag)
    print('EMULATE')
    print()

    A, XT_U, VT_y = _bilinear_attention_map(X, Y, U, V, v1, P)
    print(A.shape, XT_U.shape, VT_y.shape, 'ki,kk,jk->i')
    # NOTE(review): 'kk' reads only the diagonal of the (p, r) attention map A
    # (this requires p == r). Confirm this is the intended pooling.
    A = torch.einsum('ki,kk,jk->i', [XT_U, A, VT_y])  # (K,)
    print(A.shape, v1.shape, 'i,j->ij')
    A = torch.einsum('i,j->ij', [A, v1])  # (K, p)
    return A

In [80]:
# Single run of the printing BAN defined above (shape trace below). `out` is
# inspected in the last cells.
%time out = BAN(x, y, U, V)

EMULATE
EMULATE

torch.Size([256]) torch.Size([259]) i,j->ij
torch.Size([256, 259])

torch.Size([259, 256]) torch.Size([259, 259]) ki,kj->ij
torch.Size([256, 259])

torch.Size([256, 259]) torch.Size([256, 259]) ij,ij->ij
torch.Size([256, 259])

torch.Size([765, 259]) torch.Size([765, 256]) ki,kj->ij
torch.Size([259, 256])

torch.Size([256, 259]) torch.Size([259, 256]) ik,kj->ij
torch.Size([256, 256])

torch.Size([256, 256])

torch.Size([256, 256]) torch.Size([256, 259]) torch.Size([259, 256]) ki,kk,jk->i
torch.Size([259]) torch.Size([256]) i,j->ij
Wall time: 9.97 ms


In [92]:
# SAME AS ABOVE, BUT WITHOUT THE PRINTS, FOR AN ACCURATE TIMING BENCHMARK
def BAN(x, y, U, V):
    """Bilinear attention (BAN) on an emulated Shift-Net mask, without debug prints.

    This is the same computation as the printing version defined earlier, with no
    ``print`` calls so it can be timed. Redefining ``BAN`` here shadows the
    earlier definition for every later cell.

    Args:
        x, y: 3-D feature maps, unfolded with the global ``patch_size`` / ``stride``.
        U, V: square projection matrices with side >= the number of patches
            (e.g. from ``create_var`` with the default ``_v=1``). They are cropped
            to the emulated mask internally; the caller's tensors are not modified.

    Returns:
        Tensor of shape (K, p): K masked patches, p = flattened patch length.
    """
    def _bilinear_attention_map(x, y, U, V, v1, P):
        # Logits ((v1 outer P) * X^T U) @ (V^T Y), shape (p, r), softmax over dim 1.
        outer = torch.einsum('i,j->ij', v1, P)              # (p, K)
        XT_U = torch.einsum('ki,kj->ij', x, U)              # (p, K)
        weighted = torch.einsum('ij,ij->ij', outer, XT_U)   # (p, K)
        VT_y = torch.einsum('ki,kj->ij', V, y)              # (K, r)
        logits = torch.einsum('ik,kj->ij', weighted, VT_y)  # (p, r)
        # NOTE(review): dividing by the max rescales the logits before softmax. If
        # the max is negative this reverses their order, and if it is 0 it divides
        # by zero. Softmax is shift-invariant, so `logits - logits.max()` is the
        # usual stabiliser. Confirm the intent.
        logits /= logits.max()
        A = F.softmax(logits, dim=1)
        return A, XT_U, VT_y

    def _filter_M(U, V, flag):
        # Crop the full projection matrices to the current mask:
        # U -> (K, K) masked x masked, V -> (M, K) unmasked x masked.
        mask_indexes = (flag == 1).nonzero().t()   # (1, K) row of indices
        non_mask_indexes = (flag == 0).nonzero()   # (M, 1) column of indices
        # Indexing with a column and a row of indices broadcasts to a sub-matrix.
        U = U[mask_indexes.t(), mask_indexes]
        V = V[non_mask_indexes, mask_indexes]
        return U, V

    def _emulate(x, y, U, V):
        # Stand-in for the real pipeline: a random patch mask, random vectors
        # P (length K) and v1 (length p), and U/V cropped to that mask.
        x_unfolded = _unfold(x, patch_size, stride)
        y_unfolded = _unfold(y, patch_size, stride)
        # One flag per patch, sized from the unfold output instead of a hard-coded
        # 1024 so other input sizes or patch settings work.
        flag = torch.rand(x_unfolded.size(0)) > 0.75
        X = _filter(x_unfolded, flag, 1)
        Y = _filter(y_unfolded, flag, 0)

        K = X.size(0)
        p = X.size(1)

        # Sample with the CPU generator and move to the inputs' device. This draws
        # the same values as `.cuda()` on a GPU run, but also works on CPU.
        P = torch.randn(K).to(X.device)
        v1 = torch.randn(p).to(X.device)
        U, V = _filter_M(U, V, flag)
        return X, Y, P, v1, flag, U, V

    X, Y, P, v1, _, u, v = _emulate(x, y, U, V)

    A, XT_U, VT_y = _bilinear_attention_map(X, Y, u, v, v1, P)
    # NOTE(review): 'kk' reads only the diagonal of the (p, r) attention map A
    # (this requires p == r). Confirm this is the intended pooling.
    pooled = torch.einsum('ki,kk,jk->i', XT_U, A, VT_y)  # (K,)
    return torch.einsum('i,j->ij', pooled, v1)           # (K, p)

In [93]:
%timeit BAN(x, y, U, V)

3.63 ms ± 220 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [94]:
# Shape of `out` from the debug run above: (masked patches, flattened patch length).
out.shape

torch.Size([259, 256])

In [95]:
# Raw BAN output from the debug run (not normalised); PyTorch truncates the display.
out

tensor([[ -779.3818,   213.2228,  -803.4930,  ...,  -116.9878,   656.7700,
           198.3965],
        [-1502.6038,   411.0814, -1549.0889,  ...,  -225.5457,  1266.2153,
           382.4972],
        [ 1524.9980,  -417.2081,  1572.1760,  ...,   228.9072, -1285.0865,
          -388.1978],
        ...,
        [ 2398.3250,  -656.1323,  2472.5205,  ...,   359.9965, -2021.0225,
          -610.5086],
        [-2057.4216,   562.8682, -2121.0708,  ...,  -308.8257,  1733.7496,
           523.7295],
        [ -151.3851,    41.4158,  -156.0684,  ...,   -22.7234,   127.5693,
            38.5360]], device='cuda:0')

# MOST OF THE TIME WAS SPENT CREATING THE LARGE U AND V MATRICES