In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from transformers import BertModel

from utils.regularizers import *

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
# Square case: if Wq is orthogonal (Wq @ Wq.T = I) and S is skew-symmetric,
# then Wk = S @ Wq gives
#     M = Wq @ Wk.T = Wq @ Wq.T @ S.T = S.T,
# which is skew-symmetric (M.T + M = 0 up to floating-point round-off).
n = 4

## initialize Wq as an orthogonal matrix (QR decomposition)
Wq, _ = np.linalg.qr(np.random.rand(n, n))

## initialize a skew-symmetric matrix
S = np.random.rand(n, n)
S = S - S.T

Wk = S @ Wq
M = Wq @ Wk.T

print("Wq:\n", Wq)
print("Wk:\n", Wk)
print("M:\n", M)
print("M.T + M:\n", M.T + M)

# Check the claims numerically instead of eyeballing ~1e-16 entries above.
assert np.allclose(Wq @ Wq.T, np.eye(n)), "Wq is not orthogonal"
assert np.allclose(M, S.T), "expected M == S.T for orthogonal Wq"
assert np.allclose(M.T + M, 0.0), "M is not skew-symmetric"


Wq (orthogonal matrix):
 [[-0.43476016 -0.39724226  0.32263535  0.74100514]
 [-0.37819614 -0.34490959 -0.858437   -0.03302968]
 [-0.79577179  0.18673693  0.29460762 -0.49505851]
 [-0.18627545  0.82967781 -0.26869814  0.45247927]]
Wk:
 [[-0.01224729  0.44582636 -0.73823265  0.55324365]
 [-0.13487001  0.97781431 -0.33192696 -0.03971689]
 [-0.03123068 -0.36076954  0.39958714  0.15191116]
 [ 0.43583008  0.70114881  0.61213617 -0.74271538]]
M:
 [[ 9.00141379e-17 -4.66314855e-01  3.98378648e-01 -8.20866638e-01]
 [ 4.66314855e-01 -1.07504082e-16 -2.11793766e-01 -9.07610898e-01]
 [-3.98378648e-01  2.11793766e-01  3.02577298e-17  3.32136651e-01]
 [ 8.20866638e-01  9.07610898e-01 -3.32136651e-01  5.22922159e-17]]
M.T + M:
 [[ 1.80028276e-16  1.66533454e-16 -1.66533454e-16  0.00000000e+00]
 [ 1.66533454e-16 -2.15008165e-16  8.32667268e-17  0.00000000e+00]
 [-1.66533454e-16  8.32667268e-17  6.05154596e-17 -1.66533454e-16]
 [ 0.00000000e+00  0.00000000e+00 -1.66533454e-16  1.04584432e-16]]


In [16]:
# Rectangular case: Wq is an arbitrary n x m matrix (no orthogonality needed)
# and S is an m x m skew-symmetric matrix. With Wk = Wq @ S.T,
#     M = Wq @ Wk.T = Wq @ S @ Wq.T   (n x n, rank <= m),
# and M.T = Wq @ S.T @ Wq.T = -M, so M is skew-symmetric.
n = 4
m = 2

## initialize Wq
Wq = np.random.rand(n, m)

## initialize a skew-symmetric matrix
S = np.random.rand(m, m)
S = S - S.T

Wk = Wq @ S.T
M = Wq @ Wk.T

print("Wq:\n", Wq)
print("Wk:\n", Wk)
print("M:\n", M)
print("M.T + M:\n", M.T + M)

# Check the claims numerically instead of eyeballing ~1e-17 entries above.
assert np.allclose(M, Wq @ S @ Wq.T), "expected M == Wq @ S @ Wq.T"
assert np.allclose(M.T + M, 0.0), "M is not skew-symmetric"

Wq:
 [[0.91069962 0.34861717]
 [0.08239539 0.95749676]
 [0.47094519 0.17041381]
 [0.64504426 0.82436681]]
Wk:
 [[ 0.12660719 -0.3307385 ]
 [ 0.34773381 -0.02992351]
 [ 0.06188913 -0.17103302]
 [ 0.29938504 -0.23426053]]
M:
 [[ 1.49481878e-17  3.06249197e-01 -3.26263999e-03  1.90982601e-01]
 [-3.06249197e-01  2.72226007e-18 -1.58664184e-01 -1.99635751e-01]
 [ 3.26263999e-03  1.58664184e-01 -2.59802564e-18  1.01072716e-01]
 [-1.90982601e-01  1.99635751e-01 -1.01072716e-01  2.99529288e-18]]
M.T + M:
 [[ 2.98963756e-17  0.00000000e+00  4.33680869e-18  2.77555756e-17]
 [ 0.00000000e+00  5.44452014e-18 -2.77555756e-17  0.00000000e+00]
 [ 4.33680869e-18 -2.77555756e-17 -5.19605128e-18 -1.38777878e-17]
 [ 2.77555756e-17  0.00000000e+00 -1.38777878e-17  5.99058577e-18]]


In [12]:
import numpy as np

def random_skew_symmetric(m, rng=None):
    """Return a random m x m skew-symmetric matrix (S.T == -S).

    Draws an m x m matrix R with entries in [0, 1) and returns R - R.T, so the
    diagonal is exactly zero.

    Parameters
    ----------
    m : int
        Side length of the square matrix.
    rng : numpy.random.Generator, optional
        Source of randomness. When None (the default), the legacy global
        ``np.random`` state is used, exactly as before; pass a seeded
        Generator for reproducible draws.
    """
    R = np.random.rand(m, m) if rng is None else rng.random((m, m))
    return R - R.T

# Dimensions
n = 3  # Number of rows
m = 2  # Number of columns

# Step 1: Initialize matrix A (random n x m matrix)
A = np.random.rand(n, m)

# Step 2: Initialize a skew-symmetric matrix S of size m x m
S = random_skew_symmetric(m)

# Step 3: B = (S @ A.T).T = A @ S.T, an n x m matrix.
# (S @ A itself is not defined: S is m x m while A is n x m.)
B = np.dot(S, A.T).T

# Step 4: C = A @ B.T = A @ S @ A.T (n x n), and C.T = A @ S.T @ A.T = -C.
C = np.dot(A, B.T)

# Check if C is skew-symmetric
print("Matrix A (n x m):\n", A)
print("Matrix S (m x m, Skew-symmetric):\n", S)
print("Matrix B (S * A):\n", B)
print("Matrix C (A * B^T):\n", C)
print("C^T + C:\n", C.T + C)  # This should be close to zero matrix if C is skew-symmetric
assert np.allclose(C.T + C, 0.0), "C is not skew-symmetric"


Matrix A (n x m):
 [[0.37139719 0.7767818 ]
 [0.38608924 0.04250411]
 [0.02016199 0.65433717]]
Matrix S (m x m, Skew-symmetric):
 [[ 0.         -0.58189764]
 [ 0.58189764  0.        ]]
Matrix B (S * A):
 [[-0.4520075   0.21611515]
 [-0.02473304  0.22466442]
 [-0.38075726  0.01173221]]
Matrix C (A * B^T):
 [[ 7.95191620e-18  1.65329450e-01 -1.32298804e-01]
 [-1.65329450e-01  2.38203145e-19 -1.46507612e-01]
 [ 1.32298804e-01  1.46507612e-01  1.30538278e-19]]
C^T + C:
 [[1.59038324e-17 0.00000000e+00 0.00000000e+00]
 [0.00000000e+00 4.76406289e-19 0.00000000e+00]
 [0.00000000e+00 0.00000000e+00 2.61076557e-19]]


In [2]:
# Attention geometry: l layers, hidden size d, head size dh, h = d // dh heads.
dh = 64
l = 12
d = 768
h = d // dh

# BERT base (l = 12, d = 768, h = 12 ; 110M parameters)
model = BertModel.from_pretrained("bert-base-uncased")

# Later cells reshape the (d, d) query/key weights to (h, dh, d) and size the
# score arrays as (l, h). Some mismatches (e.g. a wrong dh) would not raise
# there, so fail fast if these constants disagree with the checkpoint.
assert d % dh == 0, "hidden size must be divisible by the head size"
assert (
    model.config.num_hidden_layers,
    model.config.hidden_size,
    model.config.num_attention_heads,
) == (l, d, h), (
    f"constants (l={l}, d={d}, h={h}) do not match the model config: "
    f"{model.config.num_hidden_layers=}, {model.config.hidden_size=}, "
    f"{model.config.num_attention_heads=}"
)

In [3]:
# Two independent (layers x heads) float arrays of zeros: scores_heads is
# filled head by head, scores_full by the batched computation, so the two
# can be compared afterwards.
scores_heads, scores_full = (np.zeros((l, h)) for _ in range(2))

In [25]:
# Per-layer, per-head symmetry score of the bilinear form M = Wq_h^T Wk_h
# (d x d) behind the attention logits q . k = x^T M y (biases ignored):
#     s = 0.5 * (1 + tr(M M) / ||M||_F^2),
# which is 1 for a symmetric M and 0 for a skew-symmetric M.
# By the cyclic property of the trace, with A = Wq Wq^T, B = Wk Wk^T and
# C = Wk Wq^T (each dh x dh):
#     tr(M M) = tr(C C)    and    ||M||_F^2 = tr(A B),
# so only dh x dh products are needed. The score is computed twice -- head by
# head and batched over heads -- so the two can be cross-checked.
layers = model.encoder.layer

for i, layer in enumerate(layers):

    self_attention = layer.attention.self
    # Public `.weight` attribute rather than the private `_parameters` dict.
    # The reshape treats the weight's rows as dh consecutive rows per head, so
    # Wq[j] / Wk[j] is head j's (dh, d) projection.
    # NOTE(review): assumes the model splits query/key outputs into heads in
    # this contiguous order -- confirm for non-BERT checkpoints.
    Wq = self_attention.query.weight.detach().reshape(h, dh, d)
    Wq_t = Wq.transpose(-1, -2)
    Wk = self_attention.key.weight.detach().reshape(h, dh, d)
    Wk_t = Wk.transpose(-1, -2)

    ## loop over heads
    for j in range(h):

        A = Wq[j] @ Wq_t[j]   # Wq Wq^T  (dh x dh)
        B = Wk[j] @ Wk_t[j]   # Wk Wk^T  (dh x dh)
        C = Wk[j] @ Wq_t[j]   # Wk Wq^T  (dh x dh)
        # 'ij,ji->' computes tr(X @ Y) without materialising the product.
        S = .5 * (1 + torch.einsum('ij,ji->', C, C) / torch.einsum('ij,ji->', A, B))
        scores_heads[i, j] = S.item()

    # Same quantity batched over heads: 'hij,hji->h' gives one trace per head.
    A = torch.matmul(Wq, Wq_t)
    B = torch.matmul(Wk, Wk_t)
    C = torch.matmul(Wk, Wq_t)
    S = .5 * (1 + torch.einsum('hij,hji->h', C, C) / torch.einsum('hij,hji->h', A, B))
    scores_full[i, :] = S.numpy()

In [27]:
# The per-head and batched symmetry scores should agree element-wise
# (default rtol/atol); displays a plain Python bool.
bool(np.isclose(scores_heads, scores_full).all())

True