In [1]:
import torch
from torch import nn
import torch.nn.functional as F

from fastai.callback.hook import Hooks

In [14]:
# Toy dimensions shared by every cell below.
torch.manual_seed(0)  # seed so random inputs/weights reproduce on Restart & Run All

bs, seq_len, d_in = 5, 10, 32
d_h = 32
# The attention / transformer layers below are built with d_h as their width
# but are fed `x` directly, so the input feature size must equal d_h.
assert d_in == d_h, "x is passed straight into layers of width d_h"
x = torch.randn(bs, seq_len, d_in)

In [15]:
# Bare multi-head attention: 4 heads over width d_h, no bias terms.
attn = nn.MultiheadAttention(
    embed_dim=d_h,
    num_heads=4,
    bias=False,
    batch_first=True,  # inputs/outputs are (batch, seq, feature)
)

# With bias=False only two weights exist: in_proj_weight stacks the
# query/key/value projections (3 * d_h rows), plus out_proj.weight.
for n, p in attn.named_parameters():
    print("{:<12} {}".format(n, p.shape))

in_proj_weight torch.Size([96, 32])
out_proj.weight torch.Size([32, 32])


In [18]:
# Self-attention: query, key and value are all x. average_attn_weights=False
# keeps a separate attention map per head instead of averaging over heads.
out, attn_weights = attn(
    query=x,
    key=x,
    value=x,
    need_weights=True,
    average_attn_weights=False,
)

In [19]:
# Output keeps x's shape; per-head weights come out as (bs, heads, seq_len, seq_len).
(out.shape, attn_weights.shape)

(torch.Size([5, 10, 32]), torch.Size([5, 4, 10, 10]))

In [29]:
# A single encoder layer with a deliberately small feed-forward hidden size.
enc = nn.TransformerEncoderLayer(
    d_model=d_h,
    nhead=4,
    dim_feedforward=128,
    batch_first=True,
)

In [32]:
# List every learnable tensor of the encoder layer with its shape.
for n, p in enc.named_parameters():
    print("{:<28} {}".format(n, p.shape))

self_attn.in_proj_weight     torch.Size([96, 32])
self_attn.in_proj_bias       torch.Size([96])
self_attn.out_proj.weight    torch.Size([32, 32])
self_attn.out_proj.bias      torch.Size([32])
linear1.weight               torch.Size([128, 32])
linear1.bias                 torch.Size([128])
linear2.weight               torch.Size([32, 128])
linear2.bias                 torch.Size([32])
norm1.weight                 torch.Size([32])
norm1.bias                   torch.Size([32])
norm2.weight                 torch.Size([32])
norm2.bias                   torch.Size([32])


In [27]:
out = enc(x)
# Self-attention + feed-forward sublayers preserve the input's shape.
assert out.shape == x.shape, f"unexpected output shape {out.shape}"
print(out.shape)

torch.Size([5, 10, 32])


In [41]:
# Decoder layer: two attention blocks (self_attn, multihead_attn) plus a
# feed-forward network. dim_feedforward is left at its default; the shapes
# printed below show that default is 2048.
dec = nn.TransformerDecoderLayer(d_model=d_h, nhead=4, batch_first=True)
for n, p in dec.named_parameters():
    print("{:<32} {}".format(n, p.shape))

self_attn.in_proj_weight         torch.Size([96, 32])
self_attn.in_proj_bias           torch.Size([96])
self_attn.out_proj.weight        torch.Size([32, 32])
self_attn.out_proj.bias          torch.Size([32])
multihead_attn.in_proj_weight    torch.Size([96, 32])
multihead_attn.in_proj_bias      torch.Size([96])
multihead_attn.out_proj.weight   torch.Size([32, 32])
multihead_attn.out_proj.bias     torch.Size([32])
linear1.weight                   torch.Size([2048, 32])
linear1.bias                     torch.Size([2048])
linear2.weight                   torch.Size([32, 2048])
linear2.bias                     torch.Size([32])
norm1.weight                     torch.Size([32])
norm1.bias                       torch.Size([32])
norm2.weight                     torch.Size([32])
norm2.bias                       torch.Size([32])
norm3.weight                     torch.Size([32])
norm3.bias                       torch.Size([32])


In [42]:
# Show the decoder layer's module tree (its repr) as the cell output.
dec

TransformerDecoderLayer(
  (self_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
  )
  (multihead_attn): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
  )
  (linear1): Linear(in_features=32, out_features=2048, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (linear2): Linear(in_features=2048, out_features=32, bias=True)
  (norm1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
  (norm2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
  (norm3): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
  (dropout1): Dropout(p=0.1, inplace=False)
  (dropout2): Dropout(p=0.1, inplace=False)
  (dropout3): Dropout(p=0.1, inplace=False)
)

In [43]:
# The memory sequence (e.g. an encoder's output) is deliberately one step
# shorter than the target, showing the two lengths need not match.
y = torch.randn(bs, seq_len-1, d_h)
out = dec(tgt=x, memory=y)
# The decoder output follows the target's shape, not the memory's.
assert out.shape == x.shape, f"unexpected output shape {out.shape}"
print(out.shape)

torch.Size([5, 10, 32])


In [38]:
# Stack two copies of a default-width encoder layer into a full encoder.
# NOTE: this rebinds `enc`, which previously held a single TransformerEncoderLayer.
enc = nn.TransformerEncoder(
    encoder_layer=nn.TransformerEncoderLayer(d_model=d_h, nhead=4, batch_first=True),
    num_layers=2,
)

In [39]:
out = enc(x)
# Stacking encoder layers still preserves the input's shape.
assert out.shape == x.shape, f"unexpected output shape {out.shape}"
print(out.shape)

torch.Size([5, 10, 32])


In [44]:
# Show the stacked encoder's module tree (its repr) as the cell output.
enc

TransformerEncoder(
  (layers): ModuleList(
    (0): TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
      )
      (linear1): Linear(in_features=32, out_features=2048, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=2048, out_features=32, bias=True)
      (norm1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
      (dropout1): Dropout(p=0.1, inplace=False)
      (dropout2): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerEncoderLayer(
      (self_attn): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
      )
      (linear1): Linear(in_features=32, out_features=2048, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (linear2): Linear(in_features=2048, out_features=32, b