In [1]:
import numpy as np
import torch
import time

import pytorch_model_summary

import utils.functions as fns
from models.transformer import MultiHeadAttentionLayer, MultiHeadSelfAttentionLayer
from models.transformer import get_transformer_encoder, get_transformer_decoder
from models.whole_models import SRTransformer, TwoStageDecoder

  from .autonotebook import tqdm as notebook_tqdm


## Window Operations

In [None]:
# 4-class cyclic shifting
# Toy feature map: batch 2, 6x6 grid, 8 channels, where every channel of a pixel
# holds that pixel's flat index so moved pixels are easy to spot in the printout.
n_batch, height, width = 2, 6, 6
pixel_ids = torch.arange(n_batch * height * width).view(n_batch, height, width, 1)
pixel_ids = pixel_ids.expand(-1, -1, -1, 8)
# Split the 8 channels into 4 groups of 2 and move the group axis right after batch:
# (2, 6, 6, 4, 2) -> (2, 4, 6, 6, 2). Presumably groups = attention heads (per the name).
split_heads = pixel_ids.view(n_batch, height, width, 4, 2).permute(0, 3, 1, 2, 4).contiguous()

shifted_heads = fns.cyclic_shift(split_heads, 1)
print(shifted_heads)

In [None]:
# Window partitioning: cut the shifted maps into windows of side 2.
window = 2
partitions = fns.partition_window(shifted_heads, window)
print(partitions)

In [None]:
# Window merging: fold the windows back into full maps.
merge_size = 2  # same window side as used for partitioning above
merged = fns.merge_window(partitions, merge_size)
print(merged)

In [None]:
# Build the masking matrix for the toy example above.
# The positional args presumably mean (n_head=4, height=6, width=6, window=2, shift=1),
# matching the demo cells above — TODO confirm against fns.masking_matrix.
mask_args = (4, 6, 6, 2, 1)
mask = fns.masking_matrix(*mask_args)
print(mask)

In [None]:
# Masking: dot-product scores within each window; masked entries are overwritten
# in place with -1 so they stand out in the printout.
partitions_t = partitions.transpose(-1, -2)
attn_values = (partitions @ partitions_t).masked_fill_(mask, -1)
print(attn_values)

In [None]:
# Masking matrix when the query and key grids differ.
# Args are grouped as in the original call: five leading values, then four more
# (presumably the key-side config) — TODO confirm against fns.masking_matrix.
mask = fns.masking_matrix(4, 8, 8, 4, 2,
                          4, 4, 2, 1)
for idx in (1, 2, 3):
    print(mask[0, idx], '\n')

## 2D Relative Position Bias (for Windows)

In [None]:
# Toy 3x3 window: label every cell with its row-major flat index.
window_size = 3
n_cells = window_size ** 2
coord_index = np.arange(n_cells).reshape(window_size, window_size)
print(coord_index)

In [None]:
# Per-cell coordinates along each axis, in row-major cell order.
# The row coordinate is scaled by (2 * window_size - 1), the number of distinct
# offsets along one axis, so row and column offsets can later be summed into a
# single collision-free index.
axis_size = 2 * window_size - 1
steps = np.arange(window_size)
coord_x = np.repeat(steps * axis_size, window_size)
coord_y = np.tile(steps, window_size)
print(coord_x)
print(coord_y)

In [None]:
# Pairwise offsets along each axis: entry [i, j] = coord[i] - coord[j]
# (row = query cell, column = key cell).
relative_x = np.subtract.outer(coord_x, coord_x)
relative_y = np.subtract.outer(coord_y, coord_y)
print(relative_x)
print(relative_y)

In [None]:
# Merge both axes into one offset index, then shift it to start at 0.
# relative_coord[-1, 0] (query = last cell, key = first cell) is the largest offset.
# Offsets are symmetric, so adding it moves the smallest offset to 0.
relative_coord = relative_x + relative_y
max_offset = relative_coord[-1, 0]
relative_coord = relative_coord + max_offset
print(relative_coord)

In [None]:
# Same table from the library helper, reshaped to (cells, cells) for 2x2 and 3x3 windows.
for ws in (2, 3):
    print(fns.relative_position_index(ws).reshape((ws * ws, ws * ws)))

In [None]:
# Example where query and key windows differ: 4x4 query window vs 2x2 key window.
query_window_size, key_window_size = 4, 2
# Number of query cells per key cell along one axis.
qk_ratio = query_window_size // key_window_size

In [None]:
# Per-axis coordinates on the query grid.
# Query cells sit at every position; key cells sit at every qk_ratio-th position.
# Rows are scaled by the number of distinct row/column offsets, (2 * query_window_size - qk_ratio).
axis_size = 2 * query_window_size - qk_ratio

q_steps = np.arange(query_window_size)
query_coord_x = np.repeat(q_steps * axis_size, query_window_size)
query_coord_y = np.tile(q_steps, query_window_size)
print(query_coord_x)
print(query_coord_y)

k_steps = np.arange(key_window_size) * qk_ratio
key_coord_x = np.repeat(k_steps * axis_size, key_window_size)
key_coord_y = np.tile(k_steps, key_window_size)
print(key_coord_x)
print(key_coord_y)

In [None]:
# Pairwise offsets along each axis: entry [i, j] = query_coord[i] - key_coord[j].
relative_x = np.subtract.outer(query_coord_x, key_coord_x)
relative_y = np.subtract.outer(query_coord_y, key_coord_y)
print(relative_x)
print(relative_y)

In [None]:
# Merge both axes into one offset index and shift it to start at 0.
# relative_coord[0, -1] (query = first cell, key = last cell) is the smallest offset.
relative_coord = relative_x + relative_y
min_offset = relative_coord[0, -1]
relative_coord = relative_coord - min_offset
print(relative_coord)

In [None]:
# Library helper for the query != key case: compare with the manual relative_coord
# from the previous cell (4x4 query window vs 2x2 key window).
# The earlier same-size demo is not repeated here.
# The reshape follows the (query cells, key cells) layout used in the next cell.
print(fns.relative_position_index(query_window_size, key_window_size)
      .reshape((query_window_size ** 2, key_window_size ** 2)))

In [None]:
# Library helper with a 6x6 query window against 2x2, then 3x3, key windows.
# A blank line separates the two tables.
query_ws = 6
for key_ws, tail in ((2, '\n\n'), (3, '\n')):
    index = fns.relative_position_index(query_ws, key_ws)
    print(index.reshape((query_ws ** 2, key_ws ** 2)), end=tail)

## Multi-head Attention Layer

In [2]:
# Multi-head self-attention module on a batch of 28x28 maps with 128-dim tokens.
# Positional args kept as-is; presumably (d_embed, n_head, height, width,
# window_size, relative_position_embedding) — TODO confirm the signature.
msa_module = MultiHeadSelfAttentionLayer(128, 4, 28, 28, 4, True)
dummy_tokens = torch.zeros(16, 28, 28, 128)
print(pytorch_model_summary.summary(msa_module, dummy_tokens))

-----------------------------------------------------------------------------
      Layer (type)              Output Shape         Param #     Tr. Param #
          Linear-1         [16, 28, 28, 128]          16,512          16,512
          Linear-2         [16, 28, 28, 256]          33,024          33,024
         Softmax-3     [16, 4, 7, 7, 16, 16]               0               0
          Linear-4         [16, 28, 28, 128]          16,512          16,512
Total params: 66,048
Trainable params: 66,048
Non-trainable params: 0
-----------------------------------------------------------------------------


In [3]:
# Stack three copies of the self-attention layer and summarize the stack.
model = torch.nn.Sequential(*fns.clone_layer(msa_module, 3))

print(pytorch_model_summary.summary(model, torch.zeros(16, 28, 28, 128)))

---------------------------------------------------------------------------------------
                    Layer (type)          Output Shape         Param #     Tr. Param #
   MultiHeadSelfAttentionLayer-1     [16, 28, 28, 128]          66,244          66,244
   MultiHeadSelfAttentionLayer-2     [16, 28, 28, 128]          66,244          66,244
   MultiHeadSelfAttentionLayer-3     [16, 28, 28, 128]          66,244          66,244
Total params: 198,732
Trainable params: 198,732
Non-trainable params: 0
---------------------------------------------------------------------------------------


In [4]:
# Multi-head attention module: 56x56 queries attend to 28x28 keys/values.
query_cfg = (56, 56, 8)  # query config
kv_cfg = (28, 28, 4)     # key, value config
sa_module = MultiHeadAttentionLayer(128, 4, *query_cfg, *kv_cfg, True)

query_in = torch.zeros(16, 56, 56, 128)
kv_in = torch.zeros(16, 28, 28, 128)
print(pytorch_model_summary.summary(sa_module, query_in, kv_in))

-----------------------------------------------------------------------------
      Layer (type)              Output Shape         Param #     Tr. Param #
          Linear-1         [16, 56, 56, 128]          16,512          16,512
          Linear-2         [16, 28, 28, 256]          33,024          33,024
         Softmax-3     [16, 4, 7, 7, 64, 16]               0               0
          Linear-4         [16, 56, 56, 128]          16,512          16,512
Total params: 66,048
Trainable params: 66,048
Non-trainable params: 0
-----------------------------------------------------------------------------


## Transformer Bodies

In [2]:
# 12-layer windowed encoder on 24x24 patches of 128-dim embeddings.
encoder_cfg = dict(d_embed=128,
                   positional_encoding=None,
                   relative_position_embedding=True,
                   n_layer=12,
                   n_head=4,
                   d_ff=128 * 4,
                   n_patch=24,
                   window_size=4)
encoder = get_transformer_encoder(**encoder_cfg)
print(pytorch_model_summary.summary(encoder, torch.zeros(16, 24, 24, encoder_cfg['d_embed'])))

-------------------------------------------------------------------------
      Layer (type)          Output Shape         Param #     Tr. Param #
    EncoderLayer-1     [16, 24, 24, 128]         198,468         198,468
    EncoderLayer-2     [16, 24, 24, 128]         198,468         198,468
    EncoderLayer-3     [16, 24, 24, 128]         198,468         198,468
    EncoderLayer-4     [16, 24, 24, 128]         198,468         198,468
    EncoderLayer-5     [16, 24, 24, 128]         198,468         198,468
    EncoderLayer-6     [16, 24, 24, 128]         198,468         198,468
    EncoderLayer-7     [16, 24, 24, 128]         198,468         198,468
    EncoderLayer-8     [16, 24, 24, 128]         198,468         198,468
    EncoderLayer-9     [16, 24, 24, 128]         198,468         198,468
   EncoderLayer-10     [16, 24, 24, 128]         198,468         198,468
   EncoderLayer-11     [16, 24, 24, 128]         198,468         198,468
   EncoderLayer-12     [16, 24, 24, 128]         1

In [3]:
# 12-layer decoder: 48x48 query patches attend to 24x24 key patches.
decoder_cfg = dict(d_embed=128,
                   positional_encoding=None,
                   relative_position_embedding=True,
                   n_layer=12,
                   n_head=4,
                   d_ff=128 * 4,
                   query_n_patch=48, query_window_size=8,  # query side
                   key_n_patch=24, key_window_size=4)      # key side
decoder = get_transformer_decoder(**decoder_cfg)

query_tokens = torch.zeros(16, 48, 48, 128)
key_tokens = torch.zeros(16, 24, 24, 128)
print(pytorch_model_summary.summary(decoder, query_tokens, key_tokens, show_input=True))

--------------------------------------------------------------------------------------------
      Layer (type)                              Input Shape         Param #     Tr. Param #
    DecoderLayer-1     [16, 48, 48, 128], [16, 24, 24, 128]         266,516         266,516
    DecoderLayer-2     [16, 48, 48, 128], [16, 24, 24, 128]         266,516         266,516
    DecoderLayer-3     [16, 48, 48, 128], [16, 24, 24, 128]         266,516         266,516
    DecoderLayer-4     [16, 48, 48, 128], [16, 24, 24, 128]         266,516         266,516
    DecoderLayer-5     [16, 48, 48, 128], [16, 24, 24, 128]         266,516         266,516
    DecoderLayer-6     [16, 48, 48, 128], [16, 24, 24, 128]         266,516         266,516
    DecoderLayer-7     [16, 48, 48, 128], [16, 24, 24, 128]         266,516         266,516
    DecoderLayer-8     [16, 48, 48, 128], [16, 24, 24, 128]         266,516         266,516
    DecoderLayer-9     [16, 48, 48, 128], [16, 24, 24, 128]         266,516    

## Whole SR Transformer Model

In [3]:
# x2 upscale (default SRTransformer): 48x48 low-res input, 96x96 second input.
# Use the GPU when present so the cell also runs on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
srTrans = SRTransformer().to(device)
print(pytorch_model_summary.summary(srTrans,
                                    torch.zeros(1, 3, 48, 48, device=device), torch.zeros(1, 3, 48*2, 48*2, device=device),
                                    show_input=True))

------------------------------------------------------------------------------------------------
            Layer (type)                            Input Shape         Param #     Tr. Param #
        EmbeddingLayer-1                         [1, 3, 48, 48]           1,664           1,664
    TransformerEncoder-2                       [1, 24, 24, 128]       2,381,616       2,381,616
        EmbeddingLayer-3                         [1, 3, 96, 96]           1,664           1,664
    TransformerDecoder-4     [1, 48, 48, 128], [1, 24, 24, 128]       3,198,192       3,198,192
   ReconstructionBlock-5                       [1, 48, 48, 128]          72,204          72,204
Total params: 5,655,340
Trainable params: 5,655,340
Non-trainable params: 0
------------------------------------------------------------------------------------------------


In [6]:
# x4 upscale: 48x48 low-res input, 192x192 second input, 12 decoder layers.
# Use the GPU when present so the cell also runs on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
srTrans = SRTransformer(upscale=4, decoder_n_layer=12).to(device)
print(pytorch_model_summary.summary(srTrans,
                                    torch.zeros(1, 3, 48, 48, device=device), torch.zeros(1, 3, 48*4, 48*4, device=device),
                                    show_input=True))

------------------------------------------------------------------------------------------------
            Layer (type)                            Input Shape         Param #     Tr. Param #
        EmbeddingLayer-1                         [1, 3, 48, 48]           1,664           1,664
    TransformerEncoder-2                       [1, 24, 24, 128]       2,381,616       2,381,616
        EmbeddingLayer-3                       [1, 3, 192, 192]           6,272           6,272
       OneStageDecoder-4     [1, 48, 48, 128], [1, 24, 24, 128]       3,262,992       3,262,992
   ReconstructionBlock-5                       [1, 96, 96, 128]          72,204          72,204
Total params: 5,724,748
Trainable params: 5,724,748
Non-trainable params: 0
------------------------------------------------------------------------------------------------


In [5]:
# x4 upscale with 24 decoder layers.
# Use the GPU when present so the cell also runs on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
srTrans = SRTransformer(upscale=4, decoder_n_layer=24).to(device)
print(pytorch_model_summary.summary(srTrans,
                                    torch.zeros(1, 3, 48, 48, device=device), torch.zeros(1, 3, 48*4, 48*4, device=device),
                                    show_input=True))

------------------------------------------------------------------------------------------------
            Layer (type)                            Input Shape         Param #     Tr. Param #
        EmbeddingLayer-1                         [1, 3, 48, 48]           1,664           1,664
    TransformerEncoder-2                       [1, 24, 24, 128]       2,381,616       2,381,616
        EmbeddingLayer-3                       [1, 3, 192, 192]           6,272           6,272
       OneStageDecoder-4     [1, 48, 48, 128], [1, 24, 24, 128]       6,459,936       6,459,936
   ReconstructionBlock-5                       [1, 96, 96, 128]          72,204          72,204
Total params: 8,921,692
Trainable params: 8,921,692
Non-trainable params: 0
------------------------------------------------------------------------------------------------


In [2]:
# Two-stage decoder for x4 upscale, summarized on its own.
# Positional args (2, 4, 128, 12, 4, 24, 0.1) kept as-is — TODO name them once
# TwoStageDecoder's signature is confirmed.
# Use the GPU when present so the cell also runs on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
twoStageDecoder = TwoStageDecoder(2, 4, 128, 12, 4, 24, 0.1).to(device)

x = torch.zeros(1, 24*2, 24*2, 128, device=device)          # decoder tokens (48x48)
z = torch.zeros(1, 24, 24, 128, device=device)              # presumably encoder output (24x24)
origin_img = torch.zeros(1, 3, 48*2, 48*2, device=device)   # image input (96x96)

print(pytorch_model_summary.summary(twoStageDecoder,
                                    x, z, origin_img,
                                    show_input=True))

------------------------------------------------------------------------------------------------
            Layer (type)                            Input Shape         Param #     Tr. Param #
    TransformerDecoder-1     [1, 48, 48, 128], [1, 24, 24, 128]       1,599,096       1,599,096
   ReconstructionBlock-2                       [1, 48, 48, 128]          72,204          72,204
        EmbeddingLayer-3                       [1, 3, 192, 192]           1,664           1,664
    TransformerDecoder-4     [1, 96, 96, 128], [1, 24, 24, 128]       1,597,848       1,597,848
Total params: 3,270,812
Trainable params: 3,270,812
Non-trainable params: 0
------------------------------------------------------------------------------------------------


In [4]:
# x4 upscale with the two-stage decoder. The second input is the 96x96 (x2)
# image; the model itself produces the final 192x192 output.
# Use the GPU when present so the cell also runs on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
srTrans = SRTransformer(upscale=4, intermediate_upscale=True, decoder_n_layer=12).to(device)
print(pytorch_model_summary.summary(srTrans,
                                    torch.zeros(1, 3, 48, 48, device=device), torch.zeros(1, 3, 48*2, 48*2, device=device),
                                    show_input=False))

------------------------------------------------------------------------------------------------
            Layer (type)                           Output Shape         Param #     Tr. Param #
        EmbeddingLayer-1                       [1, 24, 24, 128]           1,664           1,664
    TransformerEncoder-2                       [1, 24, 24, 128]       2,381,616       2,381,616
        EmbeddingLayer-3                       [1, 48, 48, 128]           1,664           1,664
       TwoStageDecoder-4     [1, 96, 96, 128], [1, 3, 192, 192]       3,270,812       3,270,812
   ReconstructionBlock-5                       [1, 3, 192, 192]          72,204          72,204
Total params: 5,727,960
Trainable params: 5,727,960
Non-trainable params: 0
------------------------------------------------------------------------------------------------


In [3]:
# x2 upscale - revised model
# Prefer GPU index 2 as originally pinned. Fall back to the default GPU, then CPU,
# so the cell does not crash on hosts with fewer than three GPUs.
if torch.cuda.device_count() > 2:
    device = torch.device('cuda:2')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

srTrans = SRTransformer(d_embed=180,
                        interpolated_decoder_input=False,
                        raw_decoder_input=False).to(device)

print(pytorch_model_summary.summary(srTrans,
                                    torch.zeros(1, 3, 48, 48, device=device), torch.zeros(1, 3, 48*2, 48*2, device=device),
                                    show_input=True))

------------------------------------------------------------------------------------------------
            Layer (type)                            Input Shape         Param #     Tr. Param #
        EmbeddingLayer-1                         [1, 3, 48, 48]           2,340           2,340
    TransformerEncoder-2                       [1, 24, 24, 180]       4,696,032       4,696,032
                Linear-3                       [1, 24, 24, 180]           8,688           8,688
        EmbeddingLayer-4                         [1, 3, 96, 96]           2,340           2,340
    TransformerDecoder-5     [1, 48, 48, 180], [1, 24, 24, 180]       6,286,368       6,286,368
   ReconstructionBlock-6                       [1, 48, 48, 180]         138,972         138,972
Total params: 11,134,740
Trainable params: 11,134,740
Non-trainable params: 0
------------------------------------------------------------------------------------------------


In [8]:
# 18-layer, 8-head windowed encoder on 24x24 patches of 192-dim embeddings.
encoder_cfg = dict(d_embed=192,
                   positional_encoding=None,
                   relative_position_embedding=True,
                   n_layer=18,
                   n_head=8,
                   d_ff=192 * 4,
                   n_patch=24,
                   window_size=4)
encoder = get_transformer_encoder(**encoder_cfg)
print(pytorch_model_summary.summary(encoder, torch.zeros(16, 24, 24, encoder_cfg['d_embed'])))

-------------------------------------------------------------------------
      Layer (type)          Output Shape         Param #     Tr. Param #
    EncoderLayer-1     [16, 24, 24, 192]         445,256         445,256
    EncoderLayer-2     [16, 24, 24, 192]         445,256         445,256
    EncoderLayer-3     [16, 24, 24, 192]         445,256         445,256
    EncoderLayer-4     [16, 24, 24, 192]         445,256         445,256
    EncoderLayer-5     [16, 24, 24, 192]         445,256         445,256
    EncoderLayer-6     [16, 24, 24, 192]         445,256         445,256
    EncoderLayer-7     [16, 24, 24, 192]         445,256         445,256
    EncoderLayer-8     [16, 24, 24, 192]         445,256         445,256
    EncoderLayer-9     [16, 24, 24, 192]         445,256         445,256
   EncoderLayer-10     [16, 24, 24, 192]         445,256         445,256
   EncoderLayer-11     [16, 24, 24, 192]         445,256         445,256
   EncoderLayer-12     [16, 24, 24, 192]         4