In [1]:
import numpy as np

In [2]:
### numerical evaluation for the time complexity of computing the symmetry score
## -----------
### - d:   embedding dimension
### - d_h: head dimension

## method with numpy.linalg.norm
## ----------
## - compute the matrix M = Wq @ Wk.T explicitly -> O(d^2 d_h)
## - compute the symmetric part .5 * (M + M.T)
## - compute np.linalg.norm (Frobenius norm) for M and .5 * (M + M.T)
## - compute the ratio of the two squared norms
def scores_norm(Wq, Wk):
    """Symmetry score of ``M = Wq @ Wk.T`` computed from explicit Frobenius norms.

    Forms M, takes its symmetric part ``.5 * (M + M.T)`` and returns the ratio
    of the squared Frobenius norms ``||.5 * (M + M.T)||^2 / ||M||^2``.

    Wq, Wk are expected to be 2-D arrays of the same shape so that M is square.
    """
    product = np.matmul(Wq, Wk.T)
    symmetric_part = .5 * (product + product.T)

    symmetric_sq_norm = np.linalg.norm(symmetric_part, 'fro')**2
    total_sq_norm = np.linalg.norm(product, 'fro')**2

    return symmetric_sq_norm / total_sq_norm

## method with np.einsum (leverage properties of M and cyclicity of trace)
## ----------
## - compute A = Wq.T @ Wq, B = Wk.T @ Wk, and C = Wk.T @ Wq (without computing M explicitly)  -> O(d d_h^2)
## - compute np.einsum of C with itself (fast way to compute the trace of C @ C)
## - compute np.einsum of A and B (fast way to compute the trace of A @ B)
## - compute the ratio of the two squared norms as .5 * (1 + tr(C @ C) / tr(A @ B))

def scores_trace(Wq, Wk):
    """Symmetry score of ``M = Wq @ Wk.T`` obtained without forming M.

    Only the small products ``Wq.T @ Wq``, ``Wk.T @ Wk`` and ``Wk.T @ Wq`` are
    built; the score is ``.5 * (1 + tr(C @ C) / tr(A @ B))``.

    Wq, Wk are expected to be 2-D arrays of the same shape.
    """
    ## (k x n) @ (n x k) -> O(nk^2)
    gram_q = np.matmul(Wq.T, Wq)
    gram_k = np.matmul(Wk.T, Wk)
    cross = np.matmul(Wk.T, Wq)

    # 'ij,ji->' sums X[i, j] * Y[j, i], i.e. the trace of X @ Y
    trace_cross = np.einsum('ij,ji->', cross, cross)
    trace_grams = np.einsum('ij,ji->', gram_q, gram_k)

    return .5 * (1 + (trace_cross / trace_grams))

def scores_trace_heads(Wq, Wk):
    """Trace-based symmetry score for one head or for all heads at once.

    Parameters
    ----------
    Wq, Wk : arrays of identical shape, either
        * 2-D ``(n, k)``    -- a single head; behaves exactly like the
          single-head trace formula and returns a scalar, or
        * 3-D ``(n, h, k)`` -- ``h`` heads stacked along axis 1; returns an
          array of shape ``(h,)`` with one score per head.

    Returns
    -------
    ``.5 * (1 + tr(C @ C) / tr(A @ B))`` with ``A = Wq.T @ Wq``,
    ``B = Wk.T @ Wk`` and ``C = Wk.T @ Wq`` evaluated per head.

    Raises
    ------
    ValueError
        If ``Wq`` is neither 2-D nor 3-D.
    """
    if Wq.ndim == 2:
        ## single head: (k x n) @ (n x k) -> O(nk^2)
        A = Wq.T @ Wq
        B = Wk.T @ Wk
        C = Wk.T @ Wq

        return .5 * (1 + (np.einsum('ij,ji->', C, C) / np.einsum('ij,ji->', A, B)))

    if Wq.ndim != 3:
        raise ValueError(
            "scores_trace_heads expects 2-D (n, k) or 3-D (n, h, k) inputs, "
            f"got an array with {Wq.ndim} dimensions"
        )

    ## all heads: for every head, (k x n) @ (n x k) -> O(hnk^2)
    # `.T` would reverse all three axes, so the row axis n is contracted
    # explicitly with einsum while the head axis h is kept as a batch axis.
    A = np.einsum('nhi,nhj->hij', Wq, Wq)
    B = np.einsum('nhi,nhj->hij', Wk, Wk)
    C = np.einsum('nhi,nhj->hij', Wk, Wq)

    # per-head traces: tr(C @ C) and tr(A @ B)
    r = .5 * (1 + np.einsum('hij,hji->h', C, C) / np.einsum('hij,hji->h', A, B))

    return r

In [3]:
## test with BERT large

from transformers import BertModel

'BERT large (l = 24, d = 1024, h = 16 ; 340M parameters)'
# NOTE(review): l = 14 disagrees with the description string above (l = 24).
# l is not used in this cell; confirm the intended value where it is consumed.
l = 14
dh = 64      # head dimension
d = 1024     # embedding dimension
h = d // dh  # number of heads (1024 // 64 = 16)

# load the pretrained checkpoint by name
model = BertModel.from_pretrained("bert-large-uncased")

# layer-0 query / key weights: transposed, viewed as (d, h, dh) so that axis 1
# indexes the head, then detached and converted to NumPy arrays
Wq = model.encoder.layer[0].attention.self.query._parameters["weight"].T.view(d,h,dh).detach().numpy()
Wk = model.encoder.layer[0].attention.self.key._parameters["weight"].T.view(d,h,dh).detach().numpy()

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
for j in range(l):
    
    print(scores_norm(Wq[:,j,:], Wk[:,j,:]))
    print(scores_trace(Wq[:,j,:], Wk[:,j,:]))

    %timeit scores_norm(Wq[:,j,:], Wk[:,j,:])
    %timeit scores_trace(Wq[:,j,:], Wk[:,j,:])

0.5953140874939474
0.5953149423003197
3.78 ms ± 241 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
593 µs ± 38.3 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
0.5205281040301025
0.5205290913581848
3.95 ms ± 365 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
690 µs ± 64.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
0.5762675599554818
0.5762671381235123
3.95 ms ± 208 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
826 µs ± 252 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
0.5275472278991179
0.5275460779666901
4.16 ms ± 467 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
703 µs ± 44.8 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
0.5186810611568575
0.5186807066202164
4.09 ms ± 246 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
800 µs ± 107 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
0.6056445305491756
0.6056437641382217
3.76 ms ± 579 µs per loop (mean ± std. dev. of 7 runs,

In [30]:
## test with LLAMA 2

from transformers import AutoModelForCausalLM

'LLAMA 2 7b (l = 32, d = 4096, h = 32 ; tot num parameters 7B)'
dh = 128  # head dimension
l = 32    # number of layers
d = 4096  # embedding dimension
h = 32    # number of heads

# load the pretrained checkpoint by name
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# layer-0 query / key projection weights: transposed, viewed as (d, h, dh) so
# that axis 1 indexes the head, then detached and converted to NumPy arrays
Wq = model.model.layers[0].self_attn.q_proj._parameters["weight"].T.view(d,h,dh).detach().numpy()
Wk = model.model.layers[0].self_attn.k_proj._parameters["weight"].T.view(d,h,dh).detach().numpy()

Downloading shards: 100%|██████████| 2/2 [00:00<00:00,  5.90it/s]
Loading checkpoint shards: 100%|██████████| 2/2 [00:22<00:00, 11.37s/it]


In [31]:
## range(1): only the first head (j = 0) is scored and timed here
for j in range(1):
    
    # print both scores so the two methods can be compared
    print(scores_norm(Wq[:,j,:], Wk[:,j,:]))
    print(scores_trace(Wq[:,j,:], Wk[:,j,:]))

    # time each method on the same head
    %timeit scores_norm(Wq[:,j,:], Wk[:,j,:])
    %timeit scores_trace(Wq[:,j,:], Wk[:,j,:])

0.5261237925015685
0.5260768656918057
123 ms ± 5 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
4.94 ms ± 622 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [4]:
### numerical evaluation for grouping the computation of the symmetry scores
### across heads and layers
### -----------
### - d:        embedding dimension
### - d_h:      head dimension
### - layers:   # of layers

def scores_trace_full(layers):
    """Symmetry scores for every (layer, head) pair in one batched computation.

    ``stack_layers(layers)`` (defined at notebook level) supplies Wq and Wk;
    the einsum subscripts below index them as (layer l, row n, head h, head-dim).
    The result has one entry per (layer, head): ``.5 * (1 + tr(C C) / tr(A B))``.
    """
    queries, keys = stack_layers(layers)

    gram_spec = 'lnhi,lnhj->lhij'   # contract the row axis n, keep layer and head
    trace_spec = 'lhij,lhji->lh'    # trace of the product of two per-head matrices

    A = np.einsum(gram_spec, queries, queries)
    B = np.einsum(gram_spec, keys, keys)
    C = np.einsum(gram_spec, keys, queries)

    trace_ratio = np.einsum(trace_spec, C, C) / np.einsum(trace_spec, A, B)

    return 0.5 * (1 + trace_ratio)

def scores_trace_loop(h,l,layers):
    """Symmetry scores computed one (layer, head) pair at a time.

    Stacks the weights with ``stack_layers(layers)`` and calls ``scores_trace``
    for each of the first ``l`` layers and first ``h`` heads.

    Returns a flat list ordered layer-major (all heads of layer 0, then layer 1, ...).
    """
    Wq, Wk = stack_layers(layers)

    return [
        scores_trace(Wq[layer_idx,:,head_idx,:], Wk[layer_idx,:,head_idx,:])
        for layer_idx in range(l)
        for head_idx in range(h)
    ]

In [39]:
## test with TinyGPT 8M parameters

from transformers import AutoModelForCausalLM

# NOTE(review): the description string below says d = 64, but the code sets
# d = 256 — confirm against the model configuration.
'TinyGPT 8m (l = 8, d = 64, h = 4 ; 8M parameters)'
dh = 64      # head dimension
l = 8        # number of layers
d = 256      # embedding dimension
h = d // dh  # number of heads (256 // 64 = 4)

# load the pretrained checkpoint by name
model = AutoModelForCausalLM.from_pretrained("roneneldan/TinyStories-8M")

In [40]:
def stack_layers(layers):
    """Stack the query / key projection weights of every layer.

    For each layer, reads ``layer.attn.attention.q_proj`` and ``k_proj``,
    transposes the weight, views it as (d, h, dh) using the notebook-level
    globals ``d``, ``h``, ``dh``, and converts it to a NumPy array.

    Returns ``(Wq, Wk)``, each stacked along a new leading layer axis.
    """
    def split_heads(projection):
        # weight -> transpose -> (d, h, dh) view -> NumPy
        return projection._parameters["weight"].T.view(d,h,dh).detach().numpy()

    attention_modules = [layer.attn.attention for layer in layers]

    stacked_q = np.stack([split_heads(attn.q_proj) for attn in attention_modules], axis=0)
    stacked_k = np.stack([split_heads(attn.k_proj) for attn in attention_modules], axis=0)

    return stacked_q, stacked_k

In [41]:
# sequence of transformer blocks handed to stack_layers
layers = model.transformer.h

# batched einsum version: one score per (layer, head)
print(scores_trace_full(layers))
%timeit scores_trace_full(layers)

# reference version: Python loop over l layers and h heads calling scores_trace
print(scores_trace_loop(h,l,layers))
%timeit scores_trace_loop(h,l,layers)

[[0.51008594 0.5002863  0.5011067  0.5099732 ]
 [0.51710474 0.48365727 0.5046482  0.5378101 ]
 [0.41431645 0.4889804  0.44898984 0.4257167 ]
 [0.47386387 0.52357817 0.47452947 0.52988505]
 [0.5297684  0.41397804 0.38179463 0.39657295]
 [0.45648208 0.4570157  0.50635654 0.40751418]
 [0.5407863  0.56023765 0.5436243  0.5434173 ]
 [0.5298016  0.5333205  0.5603063  0.5317767 ]]
20.9 ms ± 184 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
[0.5100859254598618, 0.5002863103582058, 0.5011066528968513, 0.5099732391536236, 0.5171047709882259, 0.4836572855710983, 0.5046482090838253, 0.537810105830431, 0.41431643068790436, 0.48898041900247335, 0.4489898346364498, 0.42571668326854706, 0.4738638661801815, 0.5235781688243151, 0.47452946938574314, 0.5298850629478693, 0.5297684259712696, 0.41397804766893387, 0.38179463893175125, 0.39657294005155563, 0.45648208260536194, 0.4570157080888748, 0.5063565503805876, 0.40751418471336365, 0.5407863259315491, 0.5602376125752926, 0.5436243042349815, 0.54

In [42]:
## test with BERT large

from transformers import BertModel

'BERT large (l = 24, d = 1024, h = 16 ; 340M parameters)'
# l must be set here: scores_trace_loop(h, l, layers) is called with it, and
# without this assignment the value left over from a previously run model
# cell would be used (too few layers scored, or an out-of-range layer index).
l = 24       # number of layers
dh = 64      # head dimension
d = 1024     # embedding dimension
h = d // dh  # number of heads (1024 // 64 = 16)

# load the pretrained checkpoint by name
model = BertModel.from_pretrained("bert-large-uncased")

In [45]:
def stack_layers(layers):
    """Stack the query / key weights of every layer.

    For each layer, reads ``layer.attention.self.query`` and ``.key``,
    transposes the weight, views it as (d, h, dh) using the notebook-level
    globals ``d``, ``h``, ``dh``, and converts it to a NumPy array.

    Returns ``(Wq, Wk)``, each stacked along a new leading layer axis.
    """
    def split_heads(projection):
        # weight -> transpose -> (d, h, dh) view -> NumPy
        return projection._parameters["weight"].T.view(d,h,dh).detach().numpy()

    query_weights, key_weights = [], []
    for layer in layers:
        # self-attention module inside the layer
        attn = layer.attention.self
        query_weights.append(split_heads(attn.query))
        key_weights.append(split_heads(attn.key))

    return np.stack(query_weights, axis=0), np.stack(key_weights, axis=0)

In [46]:
# sequence of encoder layers handed to stack_layers
layers = model.encoder.layer

# batched einsum version: one score per (layer, head)
print(scores_trace_full(layers))
%timeit scores_trace_full(layers)

# reference version: Python loop over l layers and h heads calling scores_trace
print(scores_trace_loop(h,l,layers))
%timeit scores_trace_loop(h,l,layers)

[[0.5953149  0.5205291  0.5762671  0.52754605 0.5186807  0.60564375
  0.7040896  0.5424154  0.5100304  0.650583   0.5411442  0.65075475
  0.7321492  0.5443711  0.51514816 0.506263  ]
 [0.6182766  0.7586884  0.6002054  0.65335536 0.71380043 0.53553736
  0.59238815 0.52918947 0.59624547 0.5107676  0.55948055 0.5727372
  0.52321815 0.6255044  0.5412552  0.5369339 ]
 [0.5764037  0.5160598  0.6525016  0.5320777  0.6415573  0.5513461
  0.5023911  0.5294934  0.740103   0.50999993 0.52891695 0.47562772
  0.5361897  0.6094643  0.5196909  0.53254163]
 [0.54349697 0.64425504 0.52394927 0.5461073  0.53182757 0.55093443
  0.6887869  0.5313994  0.51421505 0.40598178 0.5647216  0.5434591
  0.62864274 0.66327137 0.7294234  0.533965  ]
 [0.5858928  0.5748431  0.6538075  0.6340821  0.52380526 0.59078705
  0.5359635  0.5951824  0.55511755 0.5569792  0.52062106 0.5553155
  0.78356344 0.5621277  0.5184722  0.53117317]
 [0.5476827  0.5499055  0.57550657 0.53467035 0.61678123 0.61833227
  0.62975156 0.615020

In [5]:
## test with LLAMA 2

from transformers import AutoModelForCausalLM

'LLAMA 2 7b (l = 32, d = 4096, h = 32 ; tot num parameters 7B)'
dh = 128  # head dimension
l = 32    # number of layers
d = 4096  # embedding dimension
h = 32    # number of heads

# load the pretrained checkpoint by name
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 2/2 [00:20<00:00, 10.09s/it]


In [6]:
def stack_layers(layers):
    """Stack the query / key projection weights of every layer.

    For each layer, reads ``layer.self_attn.q_proj`` and ``k_proj``,
    transposes the weight, views it as (d, h, dh) using the notebook-level
    globals ``d``, ``h``, ``dh``, and converts it to a NumPy array.

    Returns ``(Wq, Wk)``, each stacked along a new leading layer axis.
    """
    def split_heads(projection):
        # weight -> transpose -> (d, h, dh) view -> NumPy
        return projection._parameters["weight"].T.view(d,h,dh).detach().numpy()

    # self-attention module of each layer, in layer order
    attention_modules = [layer.self_attn for layer in layers]

    per_layer_q = [split_heads(attn.q_proj) for attn in attention_modules]
    per_layer_k = [split_heads(attn.k_proj) for attn in attention_modules]

    return np.stack(per_layer_q, axis=0), np.stack(per_layer_k, axis=0)

In [7]:
# sequence of decoder layers handed to stack_layers
layers = model.model.layers

# batched einsum version: one score per (layer, head)
print(scores_trace_full(layers))
%timeit scores_trace_full(layers)

# reference version: Python loop over l layers and h heads calling scores_trace
print(scores_trace_loop(h,l,layers))
%timeit scores_trace_loop(h,l,layers)

[[0.52607685 0.53095376 0.50443894 ... 0.5699742  0.503946   0.51258475]
 [0.5351433  0.85908824 0.51482266 ... 0.51724714 0.5058044  0.5098453 ]
 [0.5556293  0.5044855  0.54950416 ... 0.66939104 0.5937965  0.5016917 ]
 ...
 [0.5050277  0.5271805  0.50715715 ... 0.489089   0.5075681  0.53097606]
 [0.50809693 0.5172196  0.47805548 ... 0.56132686 0.55978334 0.5196364 ]
 [0.505103   0.5139302  0.5019172  ... 0.55089754 0.5231217  0.48070505]]
