In [4]:
import numpy as np
import torch

import pytorch_model_summary

import utils.functions as fns
from models.transformer import MultiHeadAttentionLayer
from models.transformer import get_transformer_encoder

## Window Operations

In [None]:
# 4-class cyclic shifting (see fns.cyclic_shift).
# Toy input: batch of 2, a 6x6 grid, 4 heads x 2 channels per head. Every channel of a grid
# cell holds that cell's flat index, so the effect of the shift is easy to read off.
BATCH, GRID_H, GRID_W, N_HEADS, HEAD_DIM = 2, 6, 6, 4, 2

split_heads = torch.arange(BATCH * GRID_H * GRID_W).view(BATCH, GRID_H, GRID_W, 1)
split_heads = split_heads.expand(-1, -1, -1, N_HEADS * HEAD_DIM)
# (B, H, W, heads*dim) -> (B, heads, H, W, dim)
split_heads = (
    split_heads.view(BATCH, GRID_H, GRID_W, N_HEADS, HEAD_DIM)
    .permute(0, 3, 1, 2, 4)
    .contiguous()
)

# NOTE(review): second argument assumed to be the shift amount — confirm against fns.cyclic_shift.
shifted_heads = fns.cyclic_shift(split_heads, 1)
print(shifted_heads)

In [None]:
# Window partitioning: split the shifted grid into windows.
# NOTE(review): second argument assumed to be the window size (2 -> 2x2 windows) — confirm
# against fns.partition_window.
partitions = fns.partition_window(shifted_heads, 2)
print(partitions)

In [None]:
# Window merging: stitch the windows back into the full grid.
merged = fns.merge_window(partitions, 2)
print(merged)

# Round-trip check: expected True if merge_window inverts partition_window.
# torch.equal returns False (rather than raising) when the shapes differ.
print(torch.equal(merged, shifted_heads))

In [None]:
# Attention mask for the shifted windows, printed to inspect which token pairs are blocked.
# NOTE(review): the positional args look like (heads=4, H=6, W=6, window=2, shift=1), mirroring
# the toy tensor built above — confirm against fns.masking_matrix.
mask_args = (4, 6, 6, 2, 1)
mask = fns.masking_matrix(*mask_args)
print(mask)

In [None]:
# Self-similarity scores within each window, then overwrite the masked pairs.
# The toy values come from torch.arange, so (assuming the shift/partition helpers only
# rearrange them) the scores are >= 0 and -1 stands out as a marker. A real attention layer
# would fill with a large negative value before the softmax instead.
keys_t = partitions.transpose(-1, -2)
attn_values = partitions @ keys_t
attn_values.masked_fill_(mask, -1)
print(attn_values)

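## 2D Relative Position Bias (for Windows)

For a window of size $M$, every token pair has a relative (row, column) offset $(\Delta r, \Delta c) \in [-(M-1), M-1]^2$. The cells below map each offset to a single non-negative index
$$\text{idx} = (\Delta r + M - 1)(2M - 1) + (\Delta c + M - 1) \in \{0, \dots, (2M-1)^2 - 1\},$$
which selects an entry of a bias table with $(2M-1)^2$ entries.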

In [None]:
# Flat token indices laid out on a window_size x window_size window (row-major).
window_size = 3
coord_index = np.arange(window_size ** 2).reshape(window_size, window_size)
print(coord_index)

In [None]:
# Per-token coordinates, flattened row-major: token k sits at row k // M and column k % M.
# coord_x holds the row pre-scaled by axis_size = 2M - 1 (the number of possible relative
# offsets along one axis). coord_y holds the column. Summing row and column differences
# later therefore yields a single 2D index.
axis_size = 2 * window_size - 1
token_ids = np.arange(window_size * window_size)
coord_x = (token_ids // window_size) * axis_size
coord_y = token_ids % window_size
print(coord_x)
print(coord_y)

In [None]:
# Pairwise differences: relative_x[i, j] = coord_x[i] - coord_x[j] (likewise for y).
relative_x = np.subtract.outer(coord_x, coord_x)
relative_y = np.subtract.outer(coord_y, coord_y)
print(relative_x)
print(relative_y)

In [None]:
# Combine both axes into one index, then shift it so the smallest index is 0.
# The combined matrix is antisymmetric and its largest entry is relative_coord[-1, 0],
# so subtracting the minimum is the same as adding that largest offset.
relative_coord = relative_x + relative_y
relative_coord = relative_coord - relative_coord.min()
print(relative_coord)

In [None]:
# Packaged helper: show its index for 2x2 and 3x3 windows, reshaped to an (M*M, M*M) matrix.
for ws in (2, 3):
    n_tokens = ws * ws
    print(fns.relative_position_index(ws).reshape((n_tokens, n_tokens)))

## Multi-head Attention Layer

In [5]:
# Windowed multi-head attention over a 28x28 grid of 256-dim tokens with 7x7 windows.
# Reading the summary: Linear 256 -> 768 is presumably the fused Q/K/V projection (3 * 256).
# The softmax shape [16, 4, 4, 4, 49, 49] fits (batch, heads, 4x4 windows, 49x49 scores per window).
# NOTE(review): the meaning of the trailing positional flag (True) is not visible here — confirm
# against models.transformer.MultiHeadAttentionLayer.
msa_module = MultiHeadAttentionLayer(256, 4, 28, 28, 7, True)
dummy_tokens = torch.zeros(16, 28, 28, 256)  # (batch, H, W, d_model)
print(pytorch_model_summary.summary(msa_module, dummy_tokens))

# The summary total (263,168) is 676 lower than the count reported for this same module when it
# is nested in a Sequential (263,844). 676 = 4 * 13 * 13 is consistent with a parameter held
# directly by msa_module, e.g. a per-head (2*7-1)^2 relative position bias table (unconfirmed).
# The top-level summary does not list it, so count every parameter directly:
n_params = sum(p.numel() for p in msa_module.parameters())
print(f"All parameters of msa_module: {n_params:,}")

-----------------------------------------------------------------------------
      Layer (type)              Output Shape         Param #     Tr. Param #
          Linear-1         [16, 28, 28, 768]         197,376         197,376
         Softmax-2     [16, 4, 4, 4, 49, 49]               0               0
          Linear-3         [16, 28, 28, 256]          65,792          65,792
Total params: 263,168
Trainable params: 263,168
Non-trainable params: 0
-----------------------------------------------------------------------------


In [6]:
# Stack three clones of the attention layer (via fns.clone_layer) and summarise the stack.
modules = fns.clone_layer(msa_module, 3)
model = torch.nn.Sequential(*modules)

sample_input = torch.zeros(16, 28, 28, 256)
print(pytorch_model_summary.summary(model, sample_input))

-----------------------------------------------------------------------------------
                Layer (type)          Output Shape         Param #     Tr. Param #
   MultiHeadAttentionLayer-1     [16, 28, 28, 256]         263,844         263,844
   MultiHeadAttentionLayer-2     [16, 28, 28, 256]         263,844         263,844
   MultiHeadAttentionLayer-3     [16, 28, 28, 256]         263,844         263,844
Total params: 791,532
Trainable params: 791,532
Non-trainable params: 0
-----------------------------------------------------------------------------------


## Transformer Encoder

In [7]:
# Build the encoder with get_transformer_encoder's defaults and summarise it on a dummy batch.
encoder = get_transformer_encoder()
encoder_input = torch.zeros(16, 28, 28, 256)
print(pytorch_model_summary.summary(encoder, encoder_input))

-------------------------------------------------------------------------
      Layer (type)          Output Shape         Param #     Tr. Param #
    EncoderLayer-1     [16, 28, 28, 256]         790,436         790,436
    EncoderLayer-2     [16, 28, 28, 256]         790,436         790,436
    EncoderLayer-3     [16, 28, 28, 256]         790,436         790,436
    EncoderLayer-4     [16, 28, 28, 256]         790,436         790,436
    EncoderLayer-5     [16, 28, 28, 256]         790,436         790,436
    EncoderLayer-6     [16, 28, 28, 256]         790,436         790,436
Total params: 4,742,616
Trainable params: 4,742,616
Non-trainable params: 0
-------------------------------------------------------------------------
