In [1]:
import torch
from torch import nn

In [4]:
# Shape demo: insert a singleton axis at position 1 of a (2, 357) tensor,
# then tile that axis 4 times to obtain a (2, 4, 357) tensor.
x = torch.randn(2, 357)
x_tiled = x[:, None, :].repeat(1, 4, 1)
print(x_tiled.shape)

torch.Size([2, 4, 357])


In [3]:
# Hyper-parameters of the encoder stack, named once so the feed-forward width
# stays tied to the model width instead of repeating the literal 512.
D_MODEL = 512
N_HEADS = 4
N_LAYERS = 4
FFN_MULT = 4  # feed-forward hidden size = FFN_MULT * D_MODEL

# A single pre-norm (norm_first=True) encoder layer operating on
# (batch, seq, feature) inputs (batch_first=True) with GELU activation.
encoder_layer = nn.TransformerEncoderLayer(
    d_model=D_MODEL,
    nhead=N_HEADS,
    dim_feedforward=FFN_MULT * D_MODEL,
    dropout=0.1,
    activation='gelu',
    batch_first=True,
    norm_first=True,
)

# Stack N_LAYERS copies of the layer.
# The nested-tensor fast path is not available for norm_first=True layers; per
# the PyTorch implementation, leaving enable_nested_tensor at its default (True)
# only produces a UserWarning and is then disabled internally. Turning it off
# explicitly keeps the same model while avoiding the warning.
encoder = nn.TransformerEncoder(
    encoder_layer=encoder_layer,
    num_layers=N_LAYERS,
    enable_nested_tensor=False,
)




In [None]:
# List the qualified name of every submodule in the encoder, one per line.
# The first entry is the root module itself, whose name is the empty string.
module_names = [name for name, _ in encoder.named_modules()]
print(*module_names, sep="\n")



layers
layers.0
layers.0.self_attn
layers.0.self_attn.out_proj
layers.0.linear1
layers.0.dropout
layers.0.linear2
layers.0.norm1
layers.0.norm2
layers.0.dropout1
layers.0.dropout2
layers.1
layers.1.self_attn
layers.1.self_attn.out_proj
layers.1.linear1
layers.1.dropout
layers.1.linear2
layers.1.norm1
layers.1.norm2
layers.1.dropout1
layers.1.dropout2
layers.2
layers.2.self_attn
layers.2.self_attn.out_proj
layers.2.linear1
layers.2.dropout
layers.2.linear2
layers.2.norm1
layers.2.norm2
layers.2.dropout1
layers.2.dropout2
layers.3
layers.3.self_attn
layers.3.self_attn.out_proj
layers.3.linear1
layers.3.dropout
layers.3.linear2
layers.3.norm1
layers.3.norm2
layers.3.dropout1
layers.3.dropout2


In [5]:
# Find where the query/key/value projections live inside the encoder.
# Each nn.MultiheadAttention found here exposes one packed `in_proj_weight`;
# the code slices it into three consecutive row blocks of `embed_dim` rows
# and reports them as the q, k and v projection weights, in that order.

for name, module in encoder.named_modules():
    if not isinstance(module, nn.MultiheadAttention):
        continue

    embed_dim = module.embed_dim
    packed_weight = module.in_proj_weight

    print(f"MultiheadAttention module: {name}")
    print("  q_proj_weight shape:", packed_weight[:embed_dim].shape)
    print("  k_proj_weight shape:", packed_weight[embed_dim:2 * embed_dim].shape)
    print("  v_proj_weight shape:", packed_weight[2 * embed_dim:].shape)
    print("  in_proj_weight shape:", packed_weight.shape)
    print("  in_proj_bias shape:", module.in_proj_bias.shape)
    print("  out_proj:", module.out_proj)
    print()

# Alternative traversal: visit each TransformerEncoderLayer first, then report
# the MultiheadAttention submodule(s) found inside it (names are relative to
# the layer).
for name, module in encoder.named_modules():
    if not isinstance(module, nn.TransformerEncoderLayer):
        continue

    print(f"--- {name} ---")
    for subname, submodule in module.named_modules():
        if not isinstance(submodule, nn.MultiheadAttention):
            continue
        print(f"  MultiheadAttention: {subname}")
        print(f"    in_proj_weight shape: {submodule.in_proj_weight.shape}")
        print(f"    out_proj: {submodule.out_proj}")


MultiheadAttention module: layers.0.self_attn
  q_proj_weight shape: torch.Size([512, 512])
  k_proj_weight shape: torch.Size([512, 512])
  v_proj_weight shape: torch.Size([512, 512])
  in_proj_weight shape: torch.Size([1536, 512])
  in_proj_bias shape: torch.Size([1536])
  out_proj: NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)

MultiheadAttention module: layers.1.self_attn
  q_proj_weight shape: torch.Size([512, 512])
  k_proj_weight shape: torch.Size([512, 512])
  v_proj_weight shape: torch.Size([512, 512])
  in_proj_weight shape: torch.Size([1536, 512])
  in_proj_bias shape: torch.Size([1536])
  out_proj: NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)

MultiheadAttention module: layers.2.self_attn
  q_proj_weight shape: torch.Size([512, 512])
  k_proj_weight shape: torch.Size([512, 512])
  v_proj_weight shape: torch.Size([512, 512])
  in_proj_weight shape: torch.Size([1536, 512])
  in_proj_bias shape: torch.Size([1536])