In [22]:
import numpy as np
from utils.funs import scores, scores_efficient

In [39]:
### numerical evaluation for the time complexity of computing the symmetry score
## -----------
### - d:   embedding dimension
### - d_h: head dimension

## method with numpy.linalg.norm
## ----------
## - compute the M matrix explicitly -> O(d^2 d_h)
## - compute M + M.T 
## - compute np.linalg.norm for M and .5 * (M + M.T)
## - compute the ratio of the two norms
def scores_norm(Wq, Wk):
    """Symmetry score of M = Wq @ Wk.T, computed from explicit Frobenius norms.

    Forms M, takes its symmetric part .5 * (M + M.T), and returns the squared
    Frobenius norm of the symmetric part divided by the squared Frobenius
    norm of M itself.

    Parameters
    ----------
    Wq, Wk : 2-D arrays with the same shape (so that M is square).

    Returns
    -------
    The ratio ||.5 * (M + M.T)||_F^2 / ||M||_F^2.
    """
    ## explicit product: this is the expensive step of this method
    bilinear = Wq @ Wk.T
    symmetric_part = .5 * (bilinear + bilinear.T)

    sym_energy = np.linalg.norm(symmetric_part, 'fro')**2
    total_energy = np.linalg.norm(bilinear, 'fro')**2

    return sym_energy / total_energy

## method with np.einsum (leverage the factorization M = Wq @ Wk.T and cyclicity of the trace)
## ----------
## - compute A = Wq.T @ Wq, B = Wk.T @ Wk, and C = Wk.T @ Wq (without computing M explicitly)  -> O(d d_h^2)
## - compute np.einsum of C with itself (fast way to compute the trace tr(C @ C) = tr(M @ M))
## - compute np.einsum of A and B (fast way to compute the trace tr(A @ B) = ||M||_F^2)
## - combine the two traces into the score: .5 * (1 + tr(C @ C) / tr(A @ B))
def scores_trace(Wq, Wk):
    """Symmetry score of M = Wq @ Wk.T, computed without ever forming M.

    Only the three small products Wq.T @ Wq, Wk.T @ Wk and Wk.T @ Wq are
    built; by cyclicity of the trace, tr(M @ M) and ||M||_F^2 = tr(M.T @ M)
    can be read off from them.

    Parameters
    ----------
    Wq, Wk : 2-D arrays with the same shape.

    Returns
    -------
    .5 * (1 + tr(C @ C) / tr(A @ B)) with A = Wq.T @ Wq, B = Wk.T @ Wk,
    C = Wk.T @ Wq.
    """
    ## small products: (k x n) @ (n x k) -> O(nk^2)
    gram_q = Wq.T @ Wq
    gram_k = Wk.T @ Wk
    cross = Wk.T @ Wq

    ## 'ij,ji->' sums X[i, j] * Y[j, i], i.e. the trace of X @ Y,
    ## without materialising the product
    trace_cross_sq = np.einsum('ij,ji->', cross, cross)
    trace_grams = np.einsum('ij,ji->', gram_q, gram_k)

    return .5 * (1 + (trace_cross_sq / trace_grams))

In [27]:
## test with BERT large
## BERT large (l = 24, d = 1024, h = 16 ; 340M parameters)

from transformers import BertModel

dh = 64      ## head dimension
d = 1024     ## embedding dimension
h = d // dh  ## number of heads (= 16)

model = BertModel.from_pretrained("bert-large-uncased")

## query / key projection weights of the first encoder layer
## - read them through the public `.weight` attribute rather than the
##   private `_parameters` dict (same tensor, supported API)
## - transpose, then split the second axis into h heads of size dh
##   -> arrays of shape (d, h, dh), indexed per head as W[:, j, :]
## - assumes query / key are nn.Linear modules, whose weight is stored
##   as (out_features, in_features) -- TODO confirm
Wq = model.encoder.layer[0].attention.self.query.weight.T.view(d,h,dh).detach().numpy()
Wk = model.encoder.layer[0].attention.self.key.weight.T.view(d,h,dh).detach().numpy()

In [29]:
## sanity check + timing on a single head: range(1) -> only j = 0
for j in range(1):
    
    ## both methods are given the same per-head slices (axis 1 of Wq / Wk
    ## selects the head); the two printed scores should agree up to
    ## floating-point error
    print(scores_norm(Wq[:,j,:], Wk[:,j,:]))
    print(scores_trace(Wq[:,j,:], Wk[:,j,:]))

    ## IPython line magics: each one times the full statement that follows
    ## it, i.e. slicing out head j plus the score computation
    %timeit scores_norm(Wq[:,j,:], Wk[:,j,:])
    %timeit scores_trace(Wq[:,j,:], Wk[:,j,:])

0.5953140874939474
0.5953149540985507
4.26 ms ± 328 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
634 µs ± 32.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [30]:
## test with LLAMA 2
## LLAMA 2 7b (l = 32, d = 4096, h = 32 ; tot num parameters 7B)

from transformers import AutoModelForCausalLM

dh = 128   ## head dimension
l = 32     ## number of layers (not used in this cell)
d = 4096   ## embedding dimension
h = 32     ## number of heads (d // dh)

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

## query / key projection weights of the first decoder layer
## - read them through the public `.weight` attribute rather than the
##   private `_parameters` dict (same tensor, supported API)
## - transpose, then split the second axis into h heads of size dh
##   -> arrays of shape (d, h, dh), indexed per head as W[:, j, :]
## - assumes q_proj / k_proj are nn.Linear modules, whose weight is stored
##   as (out_features, in_features) -- TODO confirm
Wq = model.model.layers[0].self_attn.q_proj.weight.T.view(d,h,dh).detach().numpy()
Wk = model.model.layers[0].self_attn.k_proj.weight.T.view(d,h,dh).detach().numpy()

Downloading shards: 100%|██████████| 2/2 [00:00<00:00,  5.90it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:22<00:00, 11.37s/it]


In [31]:
## sanity check + timing on a single head: range(1) -> only j = 0
for j in range(1):
    
    ## both methods are given the same per-head slices (axis 1 of Wq / Wk
    ## selects the head); the two printed scores should agree up to
    ## floating-point error
    print(scores_norm(Wq[:,j,:], Wk[:,j,:]))
    print(scores_trace(Wq[:,j,:], Wk[:,j,:]))

    ## IPython line magics: each one times the full statement that follows
    ## it, i.e. slicing out head j plus the score computation
    %timeit scores_norm(Wq[:,j,:], Wk[:,j,:])
    %timeit scores_trace(Wq[:,j,:], Wk[:,j,:])

0.5261237925015685
0.5260768656918057
123 ms ± 5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
4.94 ms ± 622 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [43]:
## test with TinyGPT 8M parameters
## TinyGPT 8m (l = 8, d = 256, h = 4 ; 8M parameters)

from transformers import AutoModelForCausalLM

## NOTE(review): dh is hard-coded here -- confirm it matches the checkpoint's
## actual head size (embedding dimension / number of heads in model.config);
## if it does not, the per-head slices W[:, j, :] used downstream do not
## correspond to real attention heads
dh = 64      ## head dimension
l = 8        ## number of layers (not used in this cell)
d = 256      ## embedding dimension
h = d // dh  ## number of heads (= 4 with the values above)

model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-8M")

In [44]:
layers = model.transformer.h

for i, layer in enumerate(layers):

    # access self-attention module within layer
    self_attention = layer.attn.attention

    # access parameters of Conv1D
    Wq = self_attention.q_proj._parameters["weight"].T.view(d,h,dh).detach().numpy()
    Wk = self_attention.k_proj._parameters["weight"].T.view(d,h,dh).detach().numpy()

    for j in range(h):
            
        print(scores_norm(Wq[:,j,:], Wk[:,j,:]))
        print(scores_trace(Wq[:,j,:], Wk[:,j,:]))

        %timeit scores_norm(Wq[:,j,:], Wk[:,j,:])
        %timeit scores_trace(Wq[:,j,:], Wk[:,j,:])

0.5100861340260628
0.5100859254598618
289 µs ± 26.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
205 µs ± 21.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
0.5002862416666248
0.5002863103582058
325 µs ± 29.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
203 µs ± 12.7 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
0.5011066375787423
0.5011066528968513
305 µs ± 24.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
197 µs ± 13.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
0.5099731264739339
0.5099732391536236
273 µs ± 27.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
187 µs ± 14.8 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
0.5171045660543462
0.5171047709882259
252 µs ± 13.7 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
