In [1]:
import torch
from torch import nn
from torch.nn import functional as F

# IMPLEMENTATION OF BILINEAR ATTENTION FOR SHIFT-NET

In [2]:
# Two random 3-D inputs of shape (256, 32, 32), moved to the GPU.
# NOTE(review): presumably (channels, height, width) feature maps -- confirm.
x = torch.randn((256, 32, 32)).cuda()
y = torch.randn((256, 32, 32)).cuda()

In [3]:
# Sliding-window settings passed to _unfold: 1x1 patches, moved 1 element at a time.
stride = 1
patch_size = 1

In [4]:
def _unfold(img, patch_size, stride):
    n_dim = 3
    assert img.dim() == n_dim, 'image must be of dimension 3.'

    kH, kW = patch_size, patch_size
    dH, dW = stride, stride
    input_windows = img.unfold(1, kH, dH).unfold(2, kW, dW)

    i_1, i_2, i_3, i_4, i_5 = input_windows.size(0), input_windows.size(1), input_windows.size(
        2), input_windows.size(3), input_windows.size(4)

    return input_windows.permute(1, 2, 0, 3, 4).contiguous().view(i_2 * i_3, i_1, i_4, i_5)

In [5]:
# Einsum equivalent of the permute(1, 2, 0, 3, 4) used in _unfold:
#torch.einsum('abcde->bcade', input_windows)

In [6]:
# Unfold both inputs into patches; the first call is run under the %prun profiler.
%prun x_unfolded = _unfold(x, patch_size, stride)
y_unfolded = _unfold(y, patch_size, stride)

 

         15 function calls in 0.002 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.002    0.002    0.002    0.002 {method 'contiguous' of 'torch._C._TensorBase' objects}
        2    0.000    0.000    0.000    0.000 {method 'unfold' of 'torch._C._TensorBase' objects}
        1    0.000    0.000    0.002    0.002 <ipython-input-4-8160f3836a48>:1(_unfold)
        1    0.000    0.000    0.000    0.000 {method 'view' of 'torch._C._TensorBase' objects}
        1    0.000    0.000    0.000    0.000 {method 'permute' of 'torch._C._TensorBase' objects}
        1    0.000    0.000    0.002    0.002 {built-in method builtins.exec}
        5    0.000    0.000    0.000    0.000 {method 'size' of 'torch._C._TensorBase' objects}
        1    0.000    0.000    0.000    0.000 {method 'dim' of 'torch._C._TensorBase' objects}
        1    0.000    0.000    0.002    0.002 <string>:1(<module>)
        1    0.000    0.000    0.0

In [7]:
def _filter(input_windows, flag, value):
    ## EXTRACT MASK OR NOT DEPENDING ON VALUE
    input_window = input_windows[flag == value]
    return input_window.view(input_window.size(0), -1)

In [8]:
# Random boolean mask with one entry per unfolded patch (True with probability 0.25).
# The length follows the patch count instead of a hard-coded 1024, so it stays
# valid when the input size, patch_size or stride change.
flag = torch.rand(x_unfolded.size(0)) > 0.75

In [9]:
# Patches of x where flag == 1 and patches of y where flag == 0,
# each flattened to one row by _filter.
x_unfolded_filtered = _filter(x_unfolded, flag, 1)
y_unfolded_filtered = _filter(y_unfolded, flag, 0)

In [10]:
x_unfolded_filtered.shape

torch.Size([248, 256])

In [11]:
y_unfolded_filtered.shape

torch.Size([776, 256])

In [12]:
def create_var(x, y, dim=1024, _v=1):
    """Create the random matrices ``U`` and ``V``.

    The matrices are allocated directly on the target device (CUDA when it
    is available, otherwise CPU) rather than being built on the host and
    copied over, so the function also runs on CPU-only machines. Samples
    therefore come from the target device's random generator.

    Args:
        x: tensor; only ``x.size(0)`` is read, and only when ``_v == 0``.
        y: tensor; only ``y.size(0)`` is read, and only when ``_v == 0``.
        dim: side length of the square matrices when ``_v != 0``.
        _v: variant selector.
            ``0`` -- "real implementation": ``U`` is ``(N, K)`` and ``V`` is
            ``(M, K)`` with ``N = K = x.size(0)`` and ``M = y.size(0)``.
            anything else -- the author's shortcut: both are ``(dim, dim)``.

    Returns:
        Tuple ``(U, V)`` of standard-normal tensors on the same device.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if _v == 0:
        # VERSION 1, REAL IMPLEMENTATION: shapes follow the inputs.
        N = K = x.size(0)
        M = y.size(0)
        u_shape, v_shape = (N, K), (M, K)
    else:
        # VERSION 2: fixed-size square matrices.
        u_shape = v_shape = (dim, dim)

    U = torch.randn(u_shape, device=device)
    V = torch.randn(v_shape, device=device)
    return U, V

In [13]:
%time U, V = create_var(x_unfolded_filtered, y_unfolded_filtered)

CPU times: user 28 ms, sys: 4 ms, total: 32 ms
Wall time: 30.9 ms


In [14]:
print(U.shape)
print(V.shape)

torch.Size([1024, 1024])
torch.Size([1024, 1024])


In [27]:
def pprint(*args, print_=True):
    """Debug print: forward ``args`` to ``print`` unless ``print_`` is falsy."""
    if not print_:
        return
    print(*args)
        
def BAN(x, y, U, V):
    """Emulate one bilinear-attention pass over a randomly masked input pair.

    A random patch mask is drawn, ``x`` and ``y`` are unfolded into patches
    with the module-level ``patch_size`` and ``stride``, the flagged patches
    of ``x`` and the unflagged patches of ``y`` are selected, ``U`` and ``V``
    are sub-sampled with the same mask, and a softmax attention map is
    combined with the projected features.

    Args:
        x: 3-D tensor accepted by ``_unfold``.
        y: 3-D tensor accepted by ``_unfold``; it must unfold into the same
            number of patches as ``x`` because both share one mask.
        U: 2-D tensor indexed by patch position on both axes.
        V: 2-D tensor indexed by patch position on both axes.

    Returns:
        2-D tensor of shape ``(K, p)`` (``K`` = number of flagged patches,
        ``p`` = flattened patch size), standardised with its global mean
        and standard deviation.
    """

    def _bilinear_attention_map(x, y, U, V, v1, P):
        # Each step logs operand shapes and the einsum spec through pprint.
        pprint(v1.shape, P.shape, 'i,j->ij')
        tmp = torch.einsum('i,j->ij', [v1, P])
        pprint(tmp.shape)
        pprint()

        pprint(x.shape, U.shape, 'ki,kj->ij')
        XT_U = torch.einsum('ki,kj->ij', [x, U])
        pprint(XT_U.shape)
        pprint()

        pprint(tmp.shape, XT_U.shape, 'ij,ij->ij')
        tmp1 = torch.einsum('ij,ij->ij', [tmp, XT_U])
        pprint(tmp1.shape)
        pprint()

        pprint(V.shape, y.shape, 'ki,kj->ij')
        VT_y = torch.einsum('ki,kj->ij', [V, y])
        pprint(VT_y.shape)
        pprint()

        pprint(tmp1.shape, VT_y.shape, 'ik,kj->ij')
        tmp = torch.einsum('ik,kj->ij', [tmp1, VT_y])
        pprint(tmp.shape)
        pprint()

        # A max-normalisation of the logits was tried here and left disabled:
        #tmp /= tmp.max()
        A = F.softmax(tmp, dim=1)
        # Routed through pprint (was a bare print) so that silencing pprint
        # also silences this dump, e.g. while timing the function.
        pprint(A)
        pprint(A.shape)
        pprint()
        return A, XT_U, VT_y

    def _filter_M(U, V, flag):
        # Row/column index grids built from the mask:
        #   mask_indexes     -> (1, K) positions where flag == 1
        #   non_mask_indexes -> (M, 1) positions where flag == 0
        mask_indexes = (flag == 1).nonzero().t()
        non_mask_indexes = (flag == 0).nonzero()
        # Broadcasting (K, 1) with (1, K) picks the K x K sub-matrix of U.
        U = U[mask_indexes.t(), mask_indexes]
        # Broadcasting (M, 1) with (1, K) picks the M x K sub-matrix of V.
        V = V[non_mask_indexes, mask_indexes]
        return U, V

    def _emulate(x, y):
        x_unfolded = _unfold(x, patch_size, stride)
        y_unfolded = _unfold(y, patch_size, stride)
        # One mask entry per patch (True with probability 0.25); the length
        # follows the unfolded input instead of a hard-coded 1024.
        flag = torch.rand(x_unfolded.size(0)) > 0.75
        X = _filter(x_unfolded, flag, 1)
        Y = _filter(y_unfolded, flag, 0)

        K = X.size(0)
        p = X.size(1)

        # Random vectors created on the same device as the patches
        # rather than unconditionally on CUDA.
        P = torch.randn(K, device=X.device)
        v1 = torch.randn(p, device=X.device)
        return X, Y, P, v1, flag

    ##EMULATE NEW MASK GENERATED
    pprint('EMULATE')
    X, Y, P, v1, flag = _emulate(x, y)
    U, V = _filter_M(U, V, flag)
    pprint(X.shape, Y.shape, U.shape, V.shape)
    pprint('EMULATE')
    pprint()

    A, XT_U, VT_y = _bilinear_attention_map(X, Y, U, V, v1, P)
    # Shapes are logged in the same order as the einsum operands.
    pprint(XT_U.shape, A.shape, VT_y.shape, v1.shape, 'ki,kk,jk,a->ia')
    # NOTE(review): 'kk' reads only the diagonal of the attention map, and the
    # free index 'a' makes the result an outer product with v1 (rank 1).
    # Confirm this is the intended contraction.
    A = torch.einsum('ki,kk,jk,a->ia', [XT_U, A, VT_y, v1])
    pprint(A.shape)
    # Standardise with the global mean and standard deviation.
    A = (A - A.mean()) / A.std()
    return A

In [28]:
%time out = BAN(x, y, U, V)

EMULATE
torch.Size([236, 256]) torch.Size([788, 256]) torch.Size([236, 236]) torch.Size([788, 236])
EMULATE

torch.Size([256]) torch.Size([236]) i,j->ij
torch.Size([256, 236])

torch.Size([236, 256]) torch.Size([236, 236]) ki,kj->ij
torch.Size([256, 236])

torch.Size([256, 236]) torch.Size([256, 236]) ij,ij->ij
torch.Size([256, 236])

torch.Size([788, 236]) torch.Size([788, 256]) ki,kj->ij
torch.Size([236, 256])

torch.Size([256, 236]) torch.Size([236, 256]) ik,kj->ij
torch.Size([256, 256])

tensor([[  -228.7440,   2260.8074,  -2245.1704,  ...,    264.9022,
          -3032.8220,  -1499.6331],
        [ 12068.5996,   1665.9946,   2667.0586,  ...,  10224.2930,
         -13163.8760,   5809.5278],
        [ 18997.7148,  10164.2715,  -3481.1406,  ...,  13996.4180,
         -10000.0488,  22184.8184],
        ...,
        [   883.2329,  -4868.3423, -10449.9395,  ...,   3700.0425,
          -3574.8711,   1997.4094],
        [  -895.6641,  -6653.4136,   2140.6934,  ...,   9737.2305,
           

In [29]:
def pprint(*args, print_=False):
    """Debug print that is silent by default.

    Redefining the helper with ``print_=False`` mutes every call that does
    not pass ``print_`` explicitly; pass ``print_=True`` to print ``args``.
    """
    if not print_:
        return
    print(*args)

In [22]:
%timeit out = BAN(x, y, U, V)

2.18 ms ± 45.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [23]:
out.shape

torch.Size([261, 256])

In [24]:
out

tensor([[ 0.6222, -1.6228,  1.3755,  ..., -0.9942, -0.6951,  0.3489],
        [ 0.7233, -1.8864,  1.5991,  ..., -1.1557, -0.8080,  0.4057],
        [-0.3618,  0.9409, -0.7989,  ...,  0.5761,  0.4026, -0.2033],
        ...,
        [-0.4025,  1.0468, -0.8888,  ...,  0.6410,  0.4479, -0.2261],
        [-2.0871,  5.4360, -4.6116,  ...,  3.3295,  2.3273, -1.1715],
        [-0.8837,  2.3006, -1.9523,  ...,  1.4090,  0.9848, -0.4962]],
       device='cuda:0')

# MOST OF THE TIME WAS SPENT CREATING THE BIG U AND V MATRICES