In [1]:
import torch
import torch.nn as nn

In [2]:
dim = 512
ff_dim = 2045
n_layers = 8
n_heads = 8

## Transformer Encoder

In [3]:
encoder = nn.TransformerEncoderLayer(dim, nhead=n_heads, dim_feedforward=ff_dim)

In [4]:
print(encoder)

TransformerEncoderLayer(
  (self_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
  )
  (linear1): Linear(in_features=512, out_features=2045, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear2): Linear(in_features=2045, out_features=512, bias=True)
  (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (dropout1): Dropout(p=0.1, inplace=False)
  (dropout2): Dropout(p=0.1, inplace=False)
)


In [5]:
params = sum([p.numel() for p in encoder.parameters()])
print(f"encoder params: {params:,d}")

encoder params: 3,149,309


## Transformer Decoder

In [6]:
decoder = nn.TransformerDecoderLayer(dim, nhead=n_heads, dim_feedforward=2048)

In [7]:
print(str(decoder))  # nn.Module's text form lists every registered submodule

TransformerDecoderLayer(
  (self_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
  )
  (multihead_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
  )
  (linear1): Linear(in_features=512, out_features=2048, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear2): Linear(in_features=2048, out_features=512, bias=True)
  (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (norm3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (dropout1): Dropout(p=0.1, inplace=False)
  (dropout2): Dropout(p=0.1, inplace=False)
  (dropout3): Dropout(p=0.1, inplace=False)
)


In [8]:
# Transformer decoder parameter count, followed by the name and shape of every learnable tensor.
decoder_params = sum(p.numel() for p in decoder.parameters())
print(f"decoder params: {decoder_params:,d}\n")

print("transformer decoder params:")
for param_name, tensor in decoder.named_parameters():
    print(param_name, tensor.size())

decoder params: 4,204,032

transformer decoder params:
self_attn.in_proj_weight torch.Size([1536, 512])
self_attn.in_proj_bias torch.Size([1536])
self_attn.out_proj.weight torch.Size([512, 512])
self_attn.out_proj.bias torch.Size([512])
multihead_attn.in_proj_weight torch.Size([1536, 512])
multihead_attn.in_proj_bias torch.Size([1536])
multihead_attn.out_proj.weight torch.Size([512, 512])
multihead_attn.out_proj.bias torch.Size([512])
linear1.weight torch.Size([2048, 512])
linear1.bias torch.Size([2048])
linear2.weight torch.Size([512, 2048])
linear2.bias torch.Size([512])
norm1.weight torch.Size([512])
norm1.bias torch.Size([512])
norm2.weight torch.Size([512])
norm2.bias torch.Size([512])
norm3.weight torch.Size([512])
norm3.bias torch.Size([512])
