In [1]:
import numpy as np
import torch

import utils.functions as fns

  from .autonotebook import tqdm as notebook_tqdm


## Window Operations

In [3]:
# 4-class cyclic shifting
# Toy input: the ids 0..71 arranged as (2, 6, 6) and repeated 8 times along a new
# trailing axis, so every entry at a grid position carries that position's id.
split_heads = torch.arange(72).view(2, 6, 6, 1).expand(-1, -1, -1, 8)
# Split the trailing axis of 8 into (4, 2), then move the size-4 axis to dim 1,
# producing a contiguous (2, 4, 6, 6, 2) tensor.
split_heads = split_heads.view(2, 6, 6, 4, 2)
split_heads = split_heads.permute(0, 3, 1, 2, 4).contiguous()

# Shift with the project helper; the second argument (1) is presumably the shift
# amount — TODO confirm against utils.functions.cyclic_shift.
# In the recorded output, group 0 is unchanged and group 1 has each row rolled one
# step to the right (5, 0, 1, 2, 3, 4); the remaining groups are cut off there.
shifted_heads = fns.cyclic_shift(split_heads, 1)
# Disabled: flatten the result back to a (2, 6, 6, 8) layout.
# shifted_heads = shifted_heads.permute(0, 2, 3, 1, 4).contiguous().view(2, 6, 6, -1)
print(shifted_heads)

tensor([[[[[ 0,  0],
           [ 1,  1],
           [ 2,  2],
           [ 3,  3],
           [ 4,  4],
           [ 5,  5]],

          [[ 6,  6],
           [ 7,  7],
           [ 8,  8],
           [ 9,  9],
           [10, 10],
           [11, 11]],

          [[12, 12],
           [13, 13],
           [14, 14],
           [15, 15],
           [16, 16],
           [17, 17]],

          [[18, 18],
           [19, 19],
           [20, 20],
           [21, 21],
           [22, 22],
           [23, 23]],

          [[24, 24],
           [25, 25],
           [26, 26],
           [27, 27],
           [28, 28],
           [29, 29]],

          [[30, 30],
           [31, 31],
           [32, 32],
           [33, 33],
           [34, 34],
           [35, 35]]],


         [[[ 5,  5],
           [ 0,  0],
           [ 1,  1],
           [ 2,  2],
           [ 3,  3],
           [ 4,  4]],

          [[11, 11],
           [ 6,  6],
           [ 7,  7],
           [ 8,  8],
           [ 9,  9

In [4]:
# window partitioning
# Cut every 6x6 map in `shifted_heads` into windows with the project helper. The
# second argument (2) is presumably the window side length — TODO confirm against
# utils.functions.partition_window.
# In the recorded output the first window holds ids 0, 1, 6, 7 (a 2x2 patch
# flattened to 4 tokens), and the windows of one map are grouped 3 x 3.
# Disabled layout conversion; it appears to be the counterpart of the disabled
# flattening line in the cyclic-shift cell:
# shifted_heads = shifted_heads.view(2, 6, 6, 4, 2).permute(0, 3, 1, 2, 4).contiguous()
partitions = fns.partition_window(shifted_heads, 2)
print(partitions)

tensor([[[[[[ 0,  0],
            [ 1,  1],
            [ 6,  6],
            [ 7,  7]],

           [[ 2,  2],
            [ 3,  3],
            [ 8,  8],
            [ 9,  9]],

           [[ 4,  4],
            [ 5,  5],
            [10, 10],
            [11, 11]]],


          [[[12, 12],
            [13, 13],
            [18, 18],
            [19, 19]],

           [[14, 14],
            [15, 15],
            [20, 20],
            [21, 21]],

           [[16, 16],
            [17, 17],
            [22, 22],
            [23, 23]]],


          [[[24, 24],
            [25, 25],
            [30, 30],
            [31, 31]],

           [[26, 26],
            [27, 27],
            [32, 32],
            [33, 33]],

           [[28, 28],
            [29, 29],
            [34, 34],
            [35, 35]]]],



         [[[[ 5,  5],
            [ 0,  0],
            [11, 11],
            [ 6,  6]],

           [[ 1,  1],
            [ 2,  2],
            [ 7,  7],
            [ 8,  8]],

  

In [5]:
# window merging
# Reassemble the windows in `partitions` into full maps with the project helper.
# The second argument (2) presumably has to match the window size used when
# partitioning — TODO confirm against utils.functions.merge_window.
# The visible part of the recorded output matches the `shifted_heads` printout,
# i.e. merging looks like the inverse of partitioning for this input.
merged = fns.merge_window(partitions, 2)
print(merged)

tensor([[[[[ 0,  0],
           [ 1,  1],
           [ 2,  2],
           [ 3,  3],
           [ 4,  4],
           [ 5,  5]],

          [[ 6,  6],
           [ 7,  7],
           [ 8,  8],
           [ 9,  9],
           [10, 10],
           [11, 11]],

          [[12, 12],
           [13, 13],
           [14, 14],
           [15, 15],
           [16, 16],
           [17, 17]],

          [[18, 18],
           [19, 19],
           [20, 20],
           [21, 21],
           [22, 22],
           [23, 23]],

          [[24, 24],
           [25, 25],
           [26, 26],
           [27, 27],
           [28, 28],
           [29, 29]],

          [[30, 30],
           [31, 31],
           [32, 32],
           [33, 33],
           [34, 34],
           [35, 35]]],


         [[[ 5,  5],
           [ 0,  0],
           [ 1,  1],
           [ 2,  2],
           [ 3,  3],
           [ 4,  4]],

          [[11, 11],
           [ 6,  6],
           [ 7,  7],
           [ 8,  8],
           [ 9,  9

In [2]:
# Make masking matrix.
# NOTE(review): the arguments (4, 6, 6, 2, 1) presumably mean 4 groups, a 6x6 map,
# window size 2 and shift 1, mirroring the values used in the other cells —
# confirm against utils.functions.masking_matrix.
# The recorded output is a boolean tensor whose innermost blocks are 4x4; the
# leading windows that are visible in it are all False.
mask = fns.masking_matrix(4, 6, 6, 2, 1)
print(mask)

tensor([[[[[[False, False, False, False],
            [False, False, False, False],
            [False, False, False, False],
            [False, False, False, False]],

           [[False, False, False, False],
            [False, False, False, False],
            [False, False, False, False],
            [False, False, False, False]],

           [[False, False, False, False],
            [False, False, False, False],
            [False, False, False, False],
            [False, False, False, False]]],


          [[[False, False, False, False],
            [False, False, False, False],
            [False, False, False, False],
            [False, False, False, False]],

           [[False, False, False, False],
            [False, False, False, False],
            [False, False, False, False],
            [False, False, False, False]],

           [[False, False, False, False],
            [False, False, False, False],
            [False, False, False, False],
            [False, Fa

In [8]:
# Masking
# Pairwise dot products between the tokens of each window: the last two axes of
# `partitions` are multiplied by their own transpose, giving one square score
# matrix per window (4x4 in the recorded output).
attn_values = partitions @ partitions.transpose(-2, -1)
# Overwrite, in place, every score whose `mask` entry is True with the marker -1.
attn_values.masked_fill_(mask, -1)
print(attn_values)

tensor([[[[[[    0,     0,     0,     0],
            [    0,     2,    12,    14],
            [    0,    12,    72,    84],
            [    0,    14,    84,    98]],

           [[    8,    12,    32,    36],
            [   12,    18,    48,    54],
            [   32,    48,   128,   144],
            [   36,    54,   144,   162]],

           [[   32,    40,    80,    88],
            [   40,    50,   100,   110],
            [   80,   100,   200,   220],
            [   88,   110,   220,   242]]],


          [[[  288,   312,   432,   456],
            [  312,   338,   468,   494],
            [  432,   468,   648,   684],
            [  456,   494,   684,   722]],

           [[  392,   420,   560,   588],
            [  420,   450,   600,   630],
            [  560,   600,   800,   840],
            [  588,   630,   840,   882]],

           [[  512,   544,   704,   736],
            [  544,   578,   748,   782],
            [  704,   748,   968,  1012],
            [  736,   