In [1]:
import torch
from transformers import BertModel

# Load the pretrained BERT encoder and put it in inference mode (disables dropout).
# `.eval()` returns the module itself, so the call can be chained onto the load.
model = BertModel.from_pretrained('bert-base-uncased').eval()
# Bare last expression so the notebook displays the module tree.
model

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

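Each head's computation in attention (we assume matrices are stored and applied as in PyTorch, i.e. a linear layer computes $XW^T$, and we omit biases and scaling; $\sigma$ denotes the row-wise softmax and the $W$ matrices are that head's projection matrices) is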
$$
\sigma \left[ XW_Q^T W_KX^T \right] X W_V^T W_O^T
$$

In [13]:
def compute_energy_rank(matrix, energy_ratio=0.01):
    """Return the energy-based (effective) rank of ``matrix``.

    The energy of a matrix is the sum of its squared singular values. The
    energy-based rank is the smallest number ``k`` of leading singular values
    whose cumulative energy reaches ``(1 - energy_ratio)`` of the total.

    Args:
        matrix: Tensor whose singular values are computed. Expected to be a
            single 2-D matrix (the cumulative sum runs over dim 0 of the
            singular-value vector, so batched inputs are not supported).
        energy_ratio: Fraction of the total energy that may be discarded.
            The default ``0.01`` keeps 99% of the energy.

    Returns:
        int: The energy-based rank. It is ``0`` for an empty or all-zero
        matrix and never exceeds the number of singular values.
    """
    # svdvals already returns the singular values in descending order,
    # so no additional sort is required.
    energies = torch.linalg.svdvals(matrix) ** 2
    num_values = energies.numel()
    total_energy = torch.sum(energies)

    # A matrix without any energy has no significant direction at all;
    # without this guard the "+ 1" below would report rank 1.
    if num_values == 0 or total_energy.item() == 0:
        return 0

    cumulative_energy = torch.cumsum(energies, dim=0)
    # Number of prefixes that still fall short of the target, plus the one
    # singular value that finally reaches it.
    rank = torch.sum(cumulative_energy < (1 - energy_ratio) * total_energy).item() + 1

    # cumsum and sum may round differently, so with a tiny energy_ratio every
    # prefix could compare "<" the target; clamp to the valid range.
    return min(rank, num_values)

# Report, layer by layer, the energy-based rank of the full attention
# projections (all heads concatenated) and of the two products that appear
# in the attention formula.
for i, layer in enumerate(model.encoder.layer):
    attn = layer.attention
    query_w = attn.self.query.weight.detach()
    key_w = attn.self.key.weight.detach()
    value_w = attn.self.value.weight.detach()
    out_w = attn.output.dense.weight.detach()

    # (label, matrix) pairs in reporting order; the two products come last.
    labelled = [
        ("  Energy-based Rank of WQ:", query_w),
        ("  Energy-based Rank of WK:", key_w),
        ("  Energy-based Rank of WV:", value_w),
        ("  Energy-based Rank of WO:", out_w),
        ("  Energy-based Rank of WV_WO:", value_w.T @ out_w.T),
        ("  Energy-based Rank of WQ_WK:", query_w.T @ key_w),
    ]
    # Compute every rank before printing anything for this layer.
    layer_ranks = [(label, compute_energy_rank(mat)) for label, mat in labelled]

    print(f"Layer {i}:")
    for label, rank in layer_ranks:
        print(label, rank)

Layer 0:
  Energy-based Rank of WQ: 531
  Energy-based Rank of WK: 532
  Energy-based Rank of WV: 553
  Energy-based Rank of WO: 549
  Energy-based Rank of WV_WO: 345
  Energy-based Rank of WQ_WK: 314
Layer 1:
  Energy-based Rank of WQ: 538
  Energy-based Rank of WK: 532
  Energy-based Rank of WV: 554
  Energy-based Rank of WO: 550
  Energy-based Rank of WV_WO: 381
  Energy-based Rank of WQ_WK: 321
Layer 2:
  Energy-based Rank of WQ: 523
  Energy-based Rank of WK: 513
  Energy-based Rank of WV: 550
  Energy-based Rank of WO: 546
  Energy-based Rank of WV_WO: 404
  Energy-based Rank of WQ_WK: 205
Layer 3:
  Energy-based Rank of WQ: 541
  Energy-based Rank of WK: 536
  Energy-based Rank of WV: 546
  Energy-based Rank of WO: 552
  Energy-based Rank of WV_WO: 419
  Energy-based Rank of WQ_WK: 340
Layer 4:
  Energy-based Rank of WQ: 540
  Energy-based Rank of WK: 541
  Energy-based Rank of WV: 553
  Energy-based Rank of WO: 558
  Energy-based Rank of WV_WO: 419
  Energy-based Rank of WQ_WK:

In [14]:
@torch.no_grad()
def per_head_ranks(attention, energy_ratio=0.01):
    """Compute energy-based ranks of the attention projections for each head.

    Args:
        attention: Attention module exposing ``self.query``, ``self.key``,
            ``self.value`` and ``output.dense`` linear layers, together with
            ``self.num_attention_heads`` and ``self.attention_head_size``
            (e.g. the ``attention`` sub-module of a BERT encoder layer).
        energy_ratio: Forwarded to ``compute_energy_rank`` for every matrix.

    Returns:
        list[dict]: One dict per head, in head order, with the keys ``head``,
        ``rank_WQ``, ``rank_WK``, ``rank_WV``, ``rank_WO``, ``rank_WV_WO``
        and ``rank_WQ_WK``.
    """
    self_attn = attention.self
    query_w = self_attn.query.weight.detach()
    key_w = self_attn.key.weight.detach()
    value_w = self_attn.value.weight.detach()
    out_w = attention.output.dense.weight.detach()

    n_heads = self_attn.num_attention_heads
    d_head = self_attn.attention_head_size

    def energy_rank(mat):
        # Same energy threshold for every matrix of every head.
        return compute_energy_rank(mat, energy_ratio=energy_ratio)

    per_head = []
    for head_idx in range(n_heads):
        # Heads are stacked along the rows of Q/K/V and along the columns of O,
        # so the same index range selects a head in both layouts.
        lo = head_idx * d_head
        hi = lo + d_head

        q_h = query_w[lo:hi, :]   # [head_dim, hidden]
        k_h = key_w[lo:hi, :]     # [head_dim, hidden]
        v_h = value_w[lo:hi, :]   # [head_dim, hidden]
        o_h = out_w[:, lo:hi]     # [hidden, head_dim]

        # Per-head products, each [hidden, hidden].
        vo_h = v_h.T @ o_h.T
        qk_h = q_h.T @ k_h

        per_head.append({
            "head": head_idx,
            "rank_WQ": energy_rank(q_h),
            "rank_WK": energy_rank(k_h),
            "rank_WV": energy_rank(v_h),
            "rank_WO": energy_rank(o_h),
            "rank_WV_WO": energy_rank(vo_h),
            "rank_WQ_WK": energy_rank(qk_h),
        })

    return per_head

# Fraction of spectral energy that may be discarded when counting the rank.
energy_ratio = 0.01
for i, layer in enumerate(model.encoder.layer):
    attention = layer.attention
    ranks = per_head_ranks(attention, energy_ratio=energy_ratio)

    # Collect the whole layer report first, then emit it in one go.
    report = [f"Layer {i}:"]
    for r in ranks:
        report.append(
            f"  Head {r['head']:2d}: "
            f"rank WQ {r['rank_WQ']:3d}, "
            f"rank WK {r['rank_WK']:3d}, "
            f"rank WV {r['rank_WV']:3d}, "
            f"rank WO {r['rank_WO']:3d}, "
            f"rank WV_WO {r['rank_WV_WO']:3d}, "
            f"rank WQ_WK {r['rank_WQ_WK']:3d}"
        )
    print("\n".join(report))



Layer 0:
  Head  0: rank WQ  61, rank WK  62, rank WV  62, rank WO  62, rank WV_WO  59, rank WQ_WK  55
  Head  1: rank WQ  63, rank WK  63, rank WV  63, rank WO  63, rank WV_WO  62, rank WQ_WK  61
  Head  2: rank WQ  63, rank WK  63, rank WV  63, rank WO  63, rank WV_WO  63, rank WQ_WK  63
  Head  3: rank WQ  62, rank WK  62, rank WV  62, rank WO  63, rank WV_WO  61, rank WQ_WK  57
  Head  4: rank WQ  60, rank WK  62, rank WV  63, rank WO  62, rank WV_WO  59, rank WQ_WK  52
  Head  5: rank WQ  63, rank WK  63, rank WV  63, rank WO  63, rank WV_WO  63, rank WQ_WK  61
  Head  6: rank WQ  60, rank WK  62, rank WV  63, rank WO  63, rank WV_WO  61, rank WQ_WK  53
  Head  7: rank WQ  62, rank WK  63, rank WV  63, rank WO  63, rank WV_WO  61, rank WQ_WK  57
  Head  8: rank WQ  63, rank WK  63, rank WV  63, rank WO  63, rank WV_WO  61, rank WQ_WK  59
  Head  9: rank WQ  63, rank WK  63, rank WV  63, rank WO  63, rank WV_WO  62, rank WQ_WK  60
  Head 10: rank WQ  63, rank WK  63, rank WV  63, r

It is a fact that
$$
\operatorname{rank}(AB) = \operatorname{dim} (\operatorname{im}(B)) - \operatorname{dim}\left(\operatorname{im}(B) \cap \operatorname{ker}(A)\right).
$$


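As per the first code cell, $W_Q^TW_K$ generally has a significantly lower (energy-based) rank than $W_Q$ and $W_K$. Read through the identity above with $A = W_Q^T$ and $B = W_K$, this means that the image of one factor overlaps substantially with the (approximate) kernel of the other.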

However, for a given head $h$, the second cell shows that the rank of $W_{Q, h}^TW_{K, h}$ is about the same as the ranks of the individual matrices $W_{Q, h}$ and $W_{K, h}$. This means that the image of one factor and the kernel of the other barely overlap.

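Note that $W_{Q, h}^T \in \mathbb{R}^{D \times H}$ and $W_{K, h} \in \mathbb{R}^{H \times D}$, while $W_Q^T \in \mathbb{R}^{D \times H \cdot N}$ and $W_{K} \in \mathbb{R}^{H \cdot N \times D}$. Here $D$ is the hidden size, $H$ the per-head dimension and $N$ the number of heads (for `bert-base-uncased`: $D = 768$, $H = 64$, $N = 12$).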

We assume $H < D$. What we can observe is
$$
\operatorname{dim} (\operatorname{im}(W^T_{Q, h})) \approx H, \\
\operatorname{dim} (\operatorname{im}(W_{K, h})) \approx H, \\
\operatorname{dim}\left(\operatorname{im}(W^T_{Q, h}) \cap \operatorname{ker}(W_{K, h})\right) \approx 0, \\
\operatorname{dim}\left(\operatorname{im}(W_{K, h}) \cap \operatorname{ker}(W^T_{Q, h})\right) \approx 0,
$$

but

$$
\operatorname{dim} (\operatorname{im}(W^T_Q)) \approx R, \\
\operatorname{dim} (\operatorname{im}(W_K)) \approx R, \\
\operatorname{dim}\left(\operatorname{im}(W^T_Q) \cap \operatorname{ker}(W_K)\right) > 0, \\
\operatorname{dim}\left(\operatorname{im}(W_K) \cap \operatorname{ker}(W^T_Q)\right) > 0,
$$
with $\min (N \cdot H, D) > R > H$. Here $>$ (rather than $\approx$) is used only where the gap is significant.

Note: multiplying the full matrices involves a summation across heads:
$$
\begin{bmatrix}
W_{Q,1}^T & W_{Q,2}^T & \cdots & W_{Q,N}^T
\end{bmatrix}
\begin{bmatrix}
W_{K,1}\\
W_{K,2}\\
\vdots\\
W_{K,N}
\end{bmatrix}
=
\sum_{h=1}^{N} W_{Q,h}^T W_{K,h}.
$$


