In [43]:
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_


In [36]:
def get_relative_position_index(win_h: int, win_w: int) -> torch.Tensor:
    """Function to generate pair-wise relative position index for each token inside the window.
        Taken from Timms Swin V1 implementation.
    Args:
        win_h (int): Window/Grid height.
        win_w (int): Window/Grid width.
    Returns:
        relative_coords (torch.Tensor): Pair-wise relative position indexes [height * width, height * width].
    """
    coords = torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))
    # print(coords.shape)
    coords_flatten = torch.flatten(coords, 1)
    print('coords_flatten.shape: ', coords_flatten.shape)
    print('coords_flatten.shape: ', coords_flatten[:,:,None].shape)
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
    print('relative_coords.shape: ', relative_coords.shape)
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()
    print('relative_coords values: ', relative_coords[0, 0, 0], relative_coords[0, 0, 1])
    relative_coords[:, :, 0] += win_h - 1
    relative_coords[:, :, 1] += win_w - 1
    relative_coords[:, :, 0] *= 2 * win_w - 1
    
    print('relative_coords values after: ', relative_coords[0, 0, 0], relative_coords[0, 0, 1])
    return relative_coords.sum(-1)

print(get_relative_position_index(3, 3))
print(get_relative_position_index(3, 3).view(-1))



coords_flatten.shape:  torch.Size([2, 9])
coords_flatten.shape:  torch.Size([2, 9, 1])
relative_coords.shape:  torch.Size([2, 9, 9])
relative_coords values:  tensor(0) tensor(0)
relative_coords values after:  tensor(10) tensor(2)
tensor([[12, 11, 10,  7,  6,  5,  2,  1,  0],
        [13, 12, 11,  8,  7,  6,  3,  2,  1],
        [14, 13, 12,  9,  8,  7,  4,  3,  2],
        [17, 16, 15, 12, 11, 10,  7,  6,  5],
        [18, 17, 16, 13, 12, 11,  8,  7,  6],
        [19, 18, 17, 14, 13, 12,  9,  8,  7],
        [22, 21, 20, 17, 16, 15, 12, 11, 10],
        [23, 22, 21, 18, 17, 16, 13, 12, 11],
        [24, 23, 22, 19, 18, 17, 14, 13, 12]])
coords_flatten.shape:  torch.Size([2, 9])
coords_flatten.shape:  torch.Size([2, 9, 1])
relative_coords.shape:  torch.Size([2, 9, 9])
relative_coords values:  tensor(0) tensor(0)
relative_coords values after:  tensor(10) tensor(2)
tensor([12, 11, 10,  7,  6,  5,  2,  1,  0, 13, 12, 11,  8,  7,  6,  3,  2,  1,
        14, 13, 12,  9,  8,  7,  4,  3,  2, 1

In [45]:
relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * 3 - 1) * (2 * 3 - 1), 1)
        )
trunc_normal_(relative_position_bias_table, std=0.02)

Parameter containing:
tensor([[-0.0018],
        [ 0.0313],
        [-0.0024],
        [ 0.0076],
        [ 0.0075],
        [-0.0005],
        [-0.0253],
        [-0.0023],
        [-0.0064],
        [-0.0229],
        [ 0.0166],
        [-0.0106],
        [ 0.0071],
        [ 0.0084],
        [ 0.0272],
        [ 0.0013],
        [ 0.0029],
        [-0.0019],
        [ 0.0119],
        [-0.0239],
        [ 0.0146],
        [-0.0215],
        [ 0.0016],
        [ 0.0075],
        [-0.0018]], requires_grad=True)

In [55]:
relative_position_index = get_relative_position_index(3, 3).view(-1)

coords_flatten.shape:  torch.Size([2, 9])
coords_flatten.shape:  torch.Size([2, 9, 1])
relative_coords.shape:  torch.Size([2, 9, 9])
relative_coords values:  tensor(0) tensor(0)
relative_coords values after:  tensor(10) tensor(2)


In [57]:
relative_position_bias_table[relative_position_index].view(9, 9, -1).permute(2, 0, 1).contiguous().unsqueeze(0).size()

torch.Size([1, 1, 9, 9])

In [24]:
print(get_relative_position_index(3, 3))


coords_flatten.shape:  torch.Size([2, 9])
coords_flatten.shape:  torch.Size([2, 9, 1])
relative_coords.shape:  torch.Size([2, 9, 9])
relative_coords values:  tensor(0) tensor(2)
tensor([[12, 11, 10,  7,  6,  5,  2,  1,  0],
        [13, 12, 11,  8,  7,  6,  3,  2,  1],
        [14, 13, 12,  9,  8,  7,  4,  3,  2],
        [17, 16, 15, 12, 11, 10,  7,  6,  5],
        [18, 17, 16, 13, 12, 11,  8,  7,  6],
        [19, 18, 17, 14, 13, 12,  9,  8,  7],
        [22, 21, 20, 17, 16, 15, 12, 11, 10],
        [23, 22, 21, 18, 17, 16, 13, 12, 11],
        [24, 23, 22, 19, 18, 17, 14, 13, 12]])


In [11]:
win_h, win_w = 3, 3
torch.stack(torch.meshgrid([torch.arange(win_h), torch.arange(win_w)]))

tensor([[[0, 0, 0],
         [1, 1, 1],
         [2, 2, 2]],

        [[0, 1, 2],
         [0, 1, 2],
         [0, 1, 2]]])