In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from transformers import BertModel

from utils.regularizers import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Model under study: BERT base (l = 12, d = 768, h = 12 ; 110M parameters).
MODEL_NAME = "bert-base-uncased"
model = BertModel.from_pretrained(MODEL_NAME)

# Read the attention geometry from the checkpoint's config instead of hard-coding it,
# so the dimensions can never disagree with the loaded weights (for bert-base-uncased
# this gives l = 12, d = 768, h = 12, dh = 64, exactly as before).
l = model.config.num_hidden_layers    # number of encoder layers
d = model.config.hidden_size          # model width
h = model.config.num_attention_heads  # attention heads per layer
dh = d // h                           # per-head dimension
assert h * dh == d, "hidden_size must be divisible by num_attention_heads"

In [3]:
# Score tables indexed as [layer, head]: the first is filled head by head, the
# second with one batched computation per layer, so the two can be compared.
scores_heads = np.zeros(shape=(l, h))
scores_full = np.zeros_like(scores_heads)

In [25]:
# For every head, measure how symmetric its bilinear attention form
# M = W_Q^T W_K (d x d) is:
#     S = ||(M + M^T)/2||_F^2 / ||M||_F^2 = 0.5 * (1 + tr(M M) / ||M||_F^2)
# By the cyclic property of the trace, with A = W_Q W_Q^T, B = W_K W_K^T and
# C = W_K W_Q^T (all dh x dh):
#     tr(M M) = tr(C C),    ||M||_F^2 = tr(A B)
# so only small dh x dh products are needed. S = 1 for a symmetric M and 0 for an
# antisymmetric one. Only the weight matrices enter; the Linear biases are ignored.
# scores_heads is filled head by head and scores_full with one batched einsum per
# layer; the following cell compares the two.
layers = model.encoder.layer

for i, layer in enumerate(layers):

    self_attention = layer.attention.self
    # nn.Linear stores weight as (out_features, in_features) = (d, d). The reshape
    # assumes the output rows are grouped head by head, so Wq[j] is W_Q of head j (dh x d).
    Wq = self_attention.query.weight.detach().reshape(h, dh, d)
    Wq_t = Wq.transpose(-1, -2)
    Wk = self_attention.key.weight.detach().reshape(h, dh, d)
    Wk_t = Wk.transpose(-1, -2)

    # loop over heads
    for j in range(h):

        A = Wq[j] @ Wq_t[j]
        B = Wk[j] @ Wk_t[j]
        C = Wk[j] @ Wq_t[j]
        # 'ij,ji->' is tr(X Y) without materialising the product
        S = .5 * (1 + torch.einsum('ij,ji->', C, C) / torch.einsum('ij,ji->', A, B))
        scores_heads[i, j] = S.item()

    # same score for all heads at once (batched over the leading head axis)
    A = torch.matmul(Wq, Wq_t)
    B = torch.matmul(Wk, Wk_t)
    C = torch.matmul(Wk, Wq_t)
    S = .5 * (1 + torch.einsum('hij,hji->h', C, C) / torch.einsum('hij,hji->h', A, B))
    scores_full[i, :] = S.numpy()

In [27]:
# Sanity check: the per-head loop and the batched einsum must produce the same scores.
# Asserting (not just displaying the boolean) makes Restart & Run All fail loudly if
# the two ever diverge; the boolean is still shown as the cell output.
scores_match = np.allclose(scores_heads, scores_full)
assert scores_match, "per-head and batched symmetry scores disagree"
scores_match

True