In [1]:
import numpy as np
import torch

import pytorch_model_summary

import utils.functions as fns
from models.transformer import MultiHeadAttentionLayer
from models.transformer import MultiHeadSelfAttentionLayer
from models.transformer import get_transformer_encoder

  from .autonotebook import tqdm as notebook_tqdm


## Window Operations

In [None]:
# 4-class cyclic shifting
# Build a (2, 6, 6, 8) tensor in which every position repeats its own running
# index along the last axis, so moved entries are easy to spot when printed.
position_ids = torch.arange(72).view(2, 6, 6, 1).expand(-1, -1, -1, 8)
# Split the last axis into (4, 2) and bring the size-4 axis forward to dim 1.
split_heads = position_ids.view(2, 6, 6, 4, 2).permute(0, 3, 1, 2, 4).contiguous()

shifted_heads = fns.cyclic_shift(split_heads, 1)
# shifted_heads = shifted_heads.permute(0, 2, 3, 1, 4).contiguous().view(2, 6, 6, -1)
print(shifted_heads)

In [None]:
# window partitioning
# Passes `shifted_heads` (bound in the cyclic-shift cell) to fns.partition_window.
# NOTE(review): the second argument, 2, is presumably the window size -- confirm
# against utils.functions.
# The commented-out line below re-splits a flattened (2, 6, 6, -1) tensor into the
# (2, 4, 6, 6, 2) layout; it only matters if the flattening line that is commented
# out in the cyclic-shift cell is enabled.
# shifted_heads = shifted_heads.view(2, 6, 6, 4, 2).permute(0, 3, 1, 2, 4).contiguous()
partitions = fns.partition_window(shifted_heads, 2)
print(partitions)

In [None]:
# window merging
# Feeds `partitions` back through fns.merge_window with the same argument (2)
# that was used for partitioning.
# NOTE(review): presumably this is the inverse of partition_window and should
# reproduce `shifted_heads` -- verify against utils.functions.
merged = fns.merge_window(partitions, 2)
print(merged)

In [None]:
# Make masking matrix.
# NOTE(review): arguments are positional. Judging from the tensors built in the
# cells above they look like (heads=4, height=6, width=6, window size=2, shift=1)
# -- confirm against utils.functions.masking_matrix.
mask = fns.masking_matrix(4, 6, 6, 2, 1)
print(mask)

In [None]:
# Masking
# Pairwise dot products over the last two axes of `partitions`.
attn_values = partitions @ partitions.transpose(-2, -1)
# Overwrite, in place, every entry selected by `mask` with -1.
attn_values.masked_fill_(mask, -1)
print(attn_values)

In [None]:
# Make masking matrix when query != key.
# Bound to its own name so that the `mask` read by the "Masking" cell is not
# overwritten: re-running that cell after this one would otherwise pick up a
# mask built for a different configuration.
# NOTE(review): the first five arguments mirror the single-resolution call; the
# extra four (4, 4, 2, 1) presumably describe the key side -- confirm against
# utils.functions.masking_matrix.
cross_mask = fns.masking_matrix(4, 8, 8, 4, 2,
                                4, 4, 2, 1)
# Show entries 1-3 along the second axis, each followed by a blank line.
for index in (1, 2, 3):
    print(cross_mask[0, index], '\n')

## 2D Relative Position Bias (for Windows)

In [None]:
# Example window index
# Number the positions of a window_size x window_size window in row-major order.
window_size = 3
coord_index = np.arange(window_size ** 2).reshape(window_size, window_size)
print(coord_index)

In [None]:
# Coordinate indices along each axis
# A window of side w has 2w - 1 distinct offsets per axis. Row coordinates are
# pre-multiplied by that count so row and column offsets can later be summed
# into a single index.
axis_size = 2 * window_size - 1
axis_positions = np.arange(window_size)
coord_x = (axis_positions * axis_size).repeat(window_size)
coord_y = np.tile(axis_positions, window_size)
print(coord_x)
print(coord_y)

In [None]:
# Relative coordinate indices along each axis
# Entry [i, j] is the coordinate of position i minus the coordinate of position j.
relative_x = np.subtract.outer(coord_x, coord_x)
relative_y = np.subtract.outer(coord_y, coord_y)
print(relative_x)
print(relative_y)

In [None]:
# Relative coordinate indices in 2D window
# Entry [-1, 0] (last position minus first) is the largest combined offset and
# the matrix is antisymmetric, so adding it moves the smallest entry to zero.
combined_offsets = relative_x + relative_y
relative_coord = combined_offsets + combined_offsets[-1, 0]
print(relative_coord)

In [None]:
# Defined function
# Print the helper's result as a square (positions x positions) table.
for side in (2, 3):  # 2x2 window, then 3x3 window
    n_positions = side * side
    print(fns.relative_position_index(side).reshape((n_positions, n_positions)))

In [None]:
# Example window index when key != query
query_window_size = 4
key_window_size = 2
qk_ratio = query_window_size // key_window_size

In [None]:
# Coordinate indices along each axis
axis_size = 2 * query_window_size - qk_ratio

# Query positions use consecutive integer coordinates.
query_positions = np.arange(query_window_size)
query_coord_x = (query_positions * axis_size).repeat(query_window_size)
query_coord_y = np.tile(query_positions, query_window_size)
print(query_coord_x)
print(query_coord_y)

# Key positions are spaced qk_ratio apart in the same coordinate system.
key_positions = np.arange(key_window_size) * qk_ratio
key_coord_x = (key_positions * axis_size).repeat(key_window_size)
key_coord_y = np.tile(key_positions, key_window_size)
print(key_coord_x)
print(key_coord_y)

In [None]:
# Relative coordinate indices along each axis
# Entry [i, j] is the coordinate of query position i minus that of key position j.
relative_x = np.subtract.outer(query_coord_x, key_coord_x)
relative_y = np.subtract.outer(query_coord_y, key_coord_y)
print(relative_x)
print(relative_y)

In [None]:
# Relative coordinate indices in 2D window
# Entry [0, -1] (first query position minus last key position) is the smallest
# combined offset; subtracting it makes every index start from zero.
combined_offsets = relative_x + relative_y
relative_coord = combined_offsets - combined_offsets[0, -1]
print(relative_coord)

In [None]:
# Defined function
# NOTE(review): this cell repeats, verbatim, the single-argument calls already
# shown at the end of the query == key walkthrough above. It appears to be kept
# as a baseline next to the two-argument calls in the following cell; drop it
# if that comparison is not intended.
print(fns.relative_position_index(2).reshape((4, 4)))  # 2x2 window
print(fns.relative_position_index(3).reshape((9, 9)))  # 3x3 window

In [None]:
# 6x6 window against a 2x2 window, then against a 3x3 window, shown as a
# (36, key positions) table with a blank line between the two.
for position, key_side in enumerate((2, 3)):
    if position:
        print()
    print(fns.relative_position_index(6, key_side).reshape((36, key_side * key_side)))

## Multi-head Attention Layer

In [2]:
# Multi-head self-attention module
msa_module = MultiHeadSelfAttentionLayer(128, 4, 28, 28, 4, True)
# All-zero dummy batch used only to trace layer output shapes for the summary.
msa_input = torch.zeros(16, 28, 28, 128)
msa_summary = pytorch_model_summary.summary(msa_module, msa_input)
print(msa_summary)

-----------------------------------------------------------------------------
      Layer (type)              Output Shape         Param #     Tr. Param #
          Linear-1         [16, 28, 28, 128]          16,512          16,512
          Linear-2         [16, 28, 28, 256]          33,024          33,024
         Softmax-3     [16, 4, 7, 7, 16, 16]               0               0
          Linear-4         [16, 28, 28, 128]          16,512          16,512
Total params: 66,048
Trainable params: 66,048
Non-trainable params: 0
-----------------------------------------------------------------------------


In [3]:
# Clone the self-attention layer three times and chain the clones sequentially.
modules = fns.clone_layer(msa_module, 3)
model = torch.nn.Sequential(*modules)

# All-zero dummy batch used only to trace layer output shapes for the summary.
stack_input = torch.zeros(16, 28, 28, 128)
stack_summary = pytorch_model_summary.summary(model, stack_input)
print(stack_summary)

---------------------------------------------------------------------------------------
                    Layer (type)          Output Shape         Param #     Tr. Param #
   MultiHeadSelfAttentionLayer-1     [16, 28, 28, 128]          66,244          66,244
   MultiHeadSelfAttentionLayer-2     [16, 28, 28, 128]          66,244          66,244
   MultiHeadSelfAttentionLayer-3     [16, 28, 28, 128]          66,244          66,244
Total params: 198,732
Trainable params: 198,732
Non-trainable params: 0
---------------------------------------------------------------------------------------


In [4]:
# Multi-head attention module
sa_module = MultiHeadAttentionLayer(
    128, 4,
    56, 56, 8,  # query config
    28, 28, 4,  # key, value config
    True,
)
# All-zero dummy batches: the first is sized like the query configuration above,
# the second like the key/value configuration.
query_input = torch.zeros(16, 56, 56, 128)
key_value_input = torch.zeros(16, 28, 28, 128)
print(pytorch_model_summary.summary(sa_module, query_input, key_value_input))

-----------------------------------------------------------------------------
      Layer (type)              Output Shape         Param #     Tr. Param #
          Linear-1         [16, 56, 56, 128]          16,512          16,512
          Linear-2         [16, 28, 28, 256]          33,024          33,024
         Softmax-3     [16, 4, 7, 7, 64, 16]               0               0
          Linear-4         [16, 56, 56, 128]          16,512          16,512
Total params: 66,048
Trainable params: 66,048
Non-trainable params: 0
-----------------------------------------------------------------------------


## Transformer Encoder

In [2]:
# Collect the encoder hyper-parameters in one place, then build and summarise it.
encoder_kwargs = dict(
    d_embed=128,
    positional_encoding=None,
    relative_position_embedding=True,
    n_layer=12,
    n_head=4,
    d_ff=4 * 128,
    n_patch=28,
    window_size=4,
)
encoder = get_transformer_encoder(**encoder_kwargs)
# All-zero dummy batch used only to trace layer output shapes for the summary.
encoder_input = torch.zeros(16, 28, 28, 128)
print(pytorch_model_summary.summary(encoder, encoder_input))

-------------------------------------------------------------------------
      Layer (type)          Output Shape         Param #     Tr. Param #
    EncoderLayer-1     [16, 28, 28, 128]         198,468         198,468
    EncoderLayer-2     [16, 28, 28, 128]         198,468         198,468
    EncoderLayer-3     [16, 28, 28, 128]         198,468         198,468
    EncoderLayer-4     [16, 28, 28, 128]         198,468         198,468
    EncoderLayer-5     [16, 28, 28, 128]         198,468         198,468
    EncoderLayer-6     [16, 28, 28, 128]         198,468         198,468
    EncoderLayer-7     [16, 28, 28, 128]         198,468         198,468
    EncoderLayer-8     [16, 28, 28, 128]         198,468         198,468
    EncoderLayer-9     [16, 28, 28, 128]         198,468         198,468
   EncoderLayer-10     [16, 28, 28, 128]         198,468         198,468
   EncoderLayer-11     [16, 28, 28, 128]         198,468         198,468
   EncoderLayer-12     [16, 28, 28, 128]         1