In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from tests.test import test_layer, summary_layer
from deberta.config import Config

In [2]:
from deberta.attentions import DisentangledSelfAttention

In [3]:
# Shared hyperparameters: this single `config` object is passed to every layer
# constructed in the cells below.
config = Config(
    hidden_dim=768,              # matches the 768-wide Linear/LayerNorm modules printed for the attention layer
    embedding_dim=1024,          # InputEmbedding reports 1024-wide embeddings before a Linear maps them to 768
    max_seq_len=512,             # sequence length used for every test input; also the Embedding(512, 768) row count
    padding_idx=0,               # NOTE(review): presumably the token id treated as padding — confirm in deberta/config.py
    vocab_size=30522,            # upper bound used for the random token ids generated below
    position_biased_input=True,  # NOTE(review): looks like it toggles adding position embeddings to the input — confirm
    num_heads=12,                # 12 * 64 == 768, i.e. num_heads * num_head_dim equals hidden_dim
    num_head_dim=64,
    layernorm_eps=1e-9,          # shows up as eps=1e-09 in the printed LayerNorm repr
    hidden_dropout_prob=0.1,     # shows up as Dropout(p=0.1) in the printed repr
    num_hidden_layers=12,        # the BaseNetwork run below yields a list of 12 tensors
)

In [4]:
# Build the disentangled self-attention module from the shared config.
layer = DisentangledSelfAttention(config)
# Bare last expression: the notebook displays the module's repr (its submodule tree).
layer

DisentangledSelfAttention(
  (query_layer): Linear(in_features=768, out_features=768, bias=True)
  (key_layer): Linear(in_features=768, out_features=768, bias=True)
  (value_layer): Linear(in_features=768, out_features=768, bias=True)
  (relative_position_embedding): RelativePositionEmbedding(
    (relative_position_embedding_layer): Embedding(512, 768)
    (relative_position_query_layer): Linear(in_features=768, out_features=768, bias=True)
    (relative_position_key_layer): Linear(in_features=768, out_features=768, bias=True)
  )
  (feedforward): AttentionFeedForward(
    (dense): Linear(in_features=768, out_features=768, bias=True)
    (layernorm): LayerNorm((768,), eps=1e-09, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [5]:
# Smoke-test the attention layer. Per the printed result, test_layer reports an
# input of shape (10, 512, 768) and a tensor output of the same shape.
# NOTE(review): test_layer presumably builds the input tensor from this shape
# tuple — confirm in tests/test.py.
output = test_layer(layer, (10, 512, 768))

  from .autonotebook import tqdm as notebook_tqdm


input shape: torch.Size([10, 512, 768])
output type: <class 'torch.Tensor'>
output shape: torch.Size([10, 512, 768])


In [6]:
# Print a per-submodule table of output shapes and parameter counts for the
# attention layer, using the same (10, 512, 768) input shape as the smoke test.
summary_layer(layer, (10, 512, 768))

Layer (type:depth-idx)                   Output Shape              Param #
DisentangledSelfAttention                [10, 512, 768]            --
├─Linear: 1-1                            [10, 512, 768]            590,592
├─Linear: 1-2                            [10, 512, 768]            590,592
├─Linear: 1-3                            [10, 512, 768]            590,592
├─RelativePositionEmbedding: 1-4         [10, 512, 768]            --
│    └─Embedding: 2-1                    [512, 768]                393,216
│    └─Linear: 2-2                       [512, 768]                590,592
│    └─Linear: 2-3                       [512, 768]                590,592
├─AttentionFeedForward: 1-5              [10, 512, 768]            --
│    └─Linear: 2-4                       [10, 512, 768]            590,592
│    └─Dropout: 2-5                      [10, 512, 768]            --
│    └─LayerNorm: 2-6                    [10, 512, 768]            1,536
Total params: 3,938,304
Trainable params: 3,938

In [7]:
from deberta.networks import InputEmbedding

# Rebind `layer` to an InputEmbedding built from the shared config; from here on
# `layer` no longer refers to the attention module created earlier.
layer = InputEmbedding(config)

In [8]:
# Random token ids for a batch of 2 sequences of length 512. The exclusive
# upper bound 30522 equals the vocab_size given to Config, and 512 equals
# max_seq_len.
arr = torch.randint(low=0, high=30522, size=(2, 512))

# Run the embedding layer on the ids; the printed result shows a dict holding
# 'embeddings' and 'position_embeddings' tensors.
output = test_layer(layer, input_data=arr)

input shape: torch.Size([2, 512])
output type: <class 'dict'>
embeddings shape: torch.Size([2, 512, 768])
position_embeddings shape: torch.Size([2, 512, 1024])


In [9]:
# Print the shape/parameter summary table for the InputEmbedding layer, driven
# by the same token-id tensor `arr` used in the smoke test above.
summary_layer(layer, input_data=arr)

Layer (type:depth-idx)                   Output Shape              Param #
InputEmbedding                           [2, 512, 1024]            --
├─Embedding: 1-1                         [2, 512, 1024]            31,254,528
├─Embedding: 1-2                         [2, 512, 1024]            524,288
├─Linear: 1-3                            [2, 512, 768]             786,432
├─LayerNorm: 1-4                         [2, 512, 768]             1,536
Total params: 32,566,784
Trainable params: 32,566,784
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 65.13
Input size (MB): 0.01
Forward/backward pass size (MB): 29.36
Params size (MB): 130.27
Estimated Total Size (MB): 159.64


In [10]:
from deberta.layers import RelativePositionEmbedding

# Rebind `layer` again, this time to a standalone RelativePositionEmbedding
# built from the shared config.
layer = RelativePositionEmbedding(config)

In [11]:
# Take the 'embeddings' tensor out of `output`.
# NOTE(review): assumes `output` still holds the dict returned by the
# InputEmbedding test_layer call (cells run in order) — a fresh, in-order run is required.
hidden_states = output['embeddings']
# Run the relative-position layer on those embeddings. This rebinds `output`;
# per the printed result it becomes a tuple of two (2, 512, 768) tensors.
output = test_layer(layer, input_data=hidden_states)

input shape: torch.Size([2, 512, 768])
output type: <class 'tuple'>
output 0 shape: torch.Size([2, 512, 768])
output 1 shape: torch.Size([2, 512, 768])


In [12]:
from deberta.networks import BaseNetwork

# Fresh InputEmbedding used to produce inputs for the full network below.
embedding_layer = InputEmbedding(config)
# Rebind `layer` to the full BaseNetwork built from the same shared config.
layer = BaseNetwork(config)

In [13]:
# Random token ids for a batch of 10 sequences of length 512; 30522 is the
# exclusive upper bound and equals the vocab_size given to Config.
input_data = torch.randint(low=0, high=30522, size=(10, 512))

# Embed the ids (the printed result shows a dict with 'embeddings' and
# 'position_embeddings'), keeping the result in `output` for the next cell.
output = test_layer(embedding_layer, input_data=input_data)
# Then print the shape/parameter summary table for the embedding layer on the same ids.
summary_layer(embedding_layer, input_data=input_data)

input shape: torch.Size([10, 512])
output type: <class 'dict'>
embeddings shape: torch.Size([10, 512, 768])
position_embeddings shape: torch.Size([10, 512, 1024])
Layer (type:depth-idx)                   Output Shape              Param #
InputEmbedding                           [10, 512, 1024]           --
├─Embedding: 1-1                         [10, 512, 1024]           31,254,528
├─Embedding: 1-2                         [10, 512, 1024]           524,288
├─Linear: 1-3                            [10, 512, 768]            786,432
├─LayerNorm: 1-4                         [10, 512, 768]            1,536
Total params: 32,566,784
Trainable params: 32,566,784
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 325.67
Input size (MB): 0.04
Forward/backward pass size (MB): 146.80
Params size (MB): 130.27
Estimated Total Size (MB): 277.11


In [14]:
# Use the 'embeddings' tensor from the previous cell's embedding output as the
# network input. This must be read before `output` is rebound on the next line.
inputs = output['embeddings']
# Run BaseNetwork (bound to `layer` in an earlier cell). Per the printed result,
# `output` becomes a tuple: element 0 is a (10, 512, 768) tensor, element 1 a list.
output = test_layer(layer, input_data=inputs)
# Print the nested shape/parameter summary table for the whole network.
summary_layer(layer, input_data=inputs)

input shape: torch.Size([10, 512, 768])
output type: <class 'tuple'>
output 0 shape: torch.Size([10, 512, 768])
output 1 type: <class 'list'>
Layer (type:depth-idx)                                  Output Shape              Param #
BaseNetwork                                             [10, 512, 768]            --
├─ModuleList: 1-1                                       --                        --
│    └─TransformerBlock: 2-1                            [10, 512, 768]            --
│    │    └─DisentangledSelfAttention: 3-1              [10, 512, 768]            --
│    │    │    └─Linear: 4-1                            [10, 512, 768]            590,592
│    │    │    └─Linear: 4-2                            [10, 512, 768]            590,592
│    │    │    └─Linear: 4-3                            [10, 512, 768]            590,592
│    │    │    └─RelativePositionEmbedding: 4-4         [10, 512, 768]            --
│    │    │    │    └─Embedding: 5-1                    [512, 768]       

In [19]:
# Print the shape of every tensor in the list returned as the second element of
# the BaseNetwork output, one line per entry. Iterating the list directly
# replaces the index-based `range(len(...))` loop; the printed lines are unchanged.
# NOTE(review): these are presumably the per-block hidden states — confirm in
# deberta/networks.py.
for hidden_state in output[1]:
    print(hidden_state.shape)

torch.Size([10, 512, 768])
torch.Size([10, 512, 768])
torch.Size([10, 512, 768])
torch.Size([10, 512, 768])
torch.Size([10, 512, 768])
torch.Size([10, 512, 768])
torch.Size([10, 512, 768])
torch.Size([10, 512, 768])
torch.Size([10, 512, 768])
torch.Size([10, 512, 768])
torch.Size([10, 512, 768])
torch.Size([10, 512, 768])
