**Goal**
Optimize the follow-up attention calculation in matrix form.

In [127]:
import os

import numpy as np
from typing import List, Tuple, Any
import json
import seaborn as sns
import matplotlib.pyplot as plt
import torch

import project_path

import matplotlib as mpl
import seaborn as sns

from attwizard.decoder import get_attention_tensor
from attwizard.decoder import merge_attention_prompt_and_new_tokens
from attwizard.decoder import get_attention_matrix
from attwizard.decoder import condense_attention
from attwizard.decoder import heatmap_visualize
from attwizard.decoder import normalize_less_attention_on_early_tokens

import torch

In [128]:
%load_ext autoreload
%autoreload 2

## Parameters

In [129]:
MODEL_FOLDER = "../huggingface_models"
HUGGING_FACE_REPO = "Salesforce/codegen-350M-mono"
FOLDER_WITH_SAMPLES = "code_snippet_samples"
OUTPUT_FOLDER = "tmp"

EXP_NAME = "exp_v07"
EXP_FOLDER = os.path.join("..", "data", "model_output", EXP_NAME)

# Query Model (CodeGen 350M)

In [5]:
from transformers import AutoTokenizer
from attwizard.models.modeling_codegen import CodeGenForCausalLM
from attwizard.script.utils import get_model_folder_path
from pprint import pprint

model_folder_path = get_model_folder_path(
    model_folder=MODEL_FOLDER,
    hugging_short_repo=HUGGING_FACE_REPO
)

if os.path.exists(os.path.join(model_folder_path, "pytorch_model.bin")):
    print("Model loaded locally...")
    tokenizer = AutoTokenizer.from_pretrained(model_folder_path)
    model = CodeGenForCausalLM.from_pretrained(model_folder_path)
    pprint(str(model)[:500])
else:
    print("You must download the model first with: attwizard.script.download_model.py.")

Model loaded locally...
('CodeGenForCausalLM(\n'
 '  (transformer): CodeGenModel(\n'
 '    (wte): Embedding(51200, 1024)\n'
 '    (drop): Dropout(p=0.0, inplace=False)\n'
 '    (h): ModuleList(\n'
 '      (0): CodeGenBlock(\n'
 '        (ln_1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n'
 '        (attn): CodeGenAttention(\n'
 '          (attn_dropout): Dropout(p=0.0, inplace=False)\n'
 '          (resid_dropout): Dropout(p=0.0, inplace=False)\n'
 '          (qkv_proj): Linear(in_features=1024, out_features=3072, '
 'bias=False)\n'
 '          (out_proj): Linear(in')


In [6]:
prompt = """a = 3
b = 5
c = a
d = 7
e = d
f = b
g = a
h = 4 + d
i = d
j = h
print(f) # prints the value '5'
print(g) # prints the value '3'
print(f) # prints the value '5'
print(e) # prints the value '7'
print(c) # prints the value '3'
print(h) # prints the value '11'
print(j) # prints the value '4'
print(i) # prints the value '7'
"""


In [7]:
tmp = tokenizer(prompt, return_tensors="pt")
input_ids = tmp['input_ids']
attention_mask = tmp['attention_mask']
torch.manual_seed(37)  # 42

N_NEW_TOKENS = 20

model_output = model.generate(
    input_ids,
    attention_mask=attention_mask,
    do_sample=True,
    max_time=3,
    eos_token_id=tokenizer.eos_token_id,
    temperature=0.9,
    output_attentions=True,
    max_length=len(input_ids[0]) + N_NEW_TOKENS, 
    return_dict_in_generate=True
)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [8]:
generated_text = tokenizer.decode(model_output["sequences"][0])
print(generated_text)

a = 3
b = 5
c = a
d = 7
e = d
f = b
g = a
h = 4 + d
i = d
j = h
print(f) # prints the value '5'
print(g) # prints the value '3'
print(f) # prints the value '5'
print(e) # prints the value '7'
print(c) # prints the value '3'
print(h) # prints the value '11'
print(j) # prints the value '4'
print(i) # prints the value '7'
print(k) # prints the value '1'
print(l) # prints the value


## Backup Attention Tensor

In [10]:
from attwizard.decoder import get_attention_tensor
from attwizard.decoder import merge_attention_prompt_and_new_tokens
from attwizard.decoder import get_attention_matrix
from attwizard.decoder import condense_attention
from attwizard.decoder import heatmap_visualize
from attwizard.decoder import normalize_less_attention_on_early_tokens

import torch

tokens_all_attended = tokenizer.convert_ids_to_tokens(model_output["sequences"][0])
tokens_prompt = tokenizer.convert_ids_to_tokens(input_ids[0])
tokens_generated = tokens_all_attended[len(tokens_prompt):]

att_tensor = get_attention_tensor(
    model_output=model_output
)
att_tensor = normalize_less_attention_on_early_tokens(
    att_tensor=att_tensor
)


In [13]:
# Save tensor
att_tensor_np = att_tensor.numpy()
# Save tensor as npy file
with open(os.path.join(OUTPUT_FOLDER, "raw_attention.npy"), "wb") as f:
    np.save(f, att_tensor_np)

### Compute Albert Attention

In [36]:

import typing
from typing import ForwardRef
def get_follow_up_attention_matrix_v2(
        attention_tensor: typing.Union[np.ndarray, ForwardRef('torch.Tensor')],
        normalise: bool = True, minus_mean = False):
    """
    Extract the follow-up attention for a layer and head. 
    Assume that dims are:
    0 is irrelevant (sum over it!)
    1 is layer
    2 isn’t there
    3 is head
    4 is attended from or to
    5 is attended to or from
    """
    assert len(attention_tensor.shape) == 6
    print("Start: ", attention_tensor.shape)
    n_layers = attention_tensor.shape[1]
    n_heads = attention_tensor.shape[3]
    n_tokens = attention_tensor.shape[4]
    assert n_tokens == attention_tensor.shape[5], "to and from dimension mismatch"
    attention_tensor = attention_tensor.sum(2).sum(0)  # sum over the heads
    output = np.zeros((n_tokens, n_tokens, n_layers - 1))
    print("Condensed start: ", attention_tensor.shape)
    # Condensed start:  (20, 16, 157, 157)

    print("Start output: ", output.shape)
    for k_layer in range(n_layers - 1):
        print("Layer ", k_layer, ")")
        print("Select a layer and condense on the head dimension")
        print("Before size: ", attention_tensor.shape)
        layer_matrix = attention_tensor[k_layer, :, :, :].sum(0)
        next_layer_matrix = attention_tensor[k_layer + 1, :, :, :].sum(0)
        print("After size: ", layer_matrix.shape)
        # make it so each recipient has sum of attention = 1

        print("Make is so that each person (recipient) has sum of attention of 1")
        for i in range(n_tokens):
            print("layer_matrix[i, :].sum(): ", layer_matrix[i, :].sum())
            layer_matrix[i, :] = layer_matrix[i, :] / layer_matrix[i, ].sum()
            next_layer_matrix[i, :] = next_layer_matrix[i, :] / next_layer_matrix[i, ].sum()
            print("layer_matrix[i, :].sum(): ", layer_matrix[i, :].sum())
        
        # Condensed start:  (20, 16, 157, 157)
        for i_first_attended in range(n_tokens):
            for j_second_attended in range(n_tokens):
                # only tokens after i and j can be compared:
                max_i_j = max(i_first_attended, j_second_attended)
                # who attended to it in layer k?
                attend_to_i_in_layer = attention_tensor[k_layer, :, max_i_j:, i_first_attended].sum(0)
                # who attended to it in layer k+1?
                attend_to_j_in_next_layer = attention_tensor[k_layer + 1, :, max_i_j:, j_second_attended].sum(0)
                if minus_mean:
                    attend_to_i_in_layer = attend_to_i_in_layer - torch.mean(attend_to_i_in_layer)
                    attend_to_j_in_next_layer = attend_to_j_in_next_layer - torch.mean(attend_to_j_in_next_layer)
                if normalise:
                    attend_to_i_in_layer = attend_to_i_in_layer / np.linalg.norm(attend_to_i_in_layer)
                    attend_to_j_in_next_layer = attend_to_j_in_next_layer / np.linalg.norm(attend_to_j_in_next_layer)
                # take the dot product of the two
                dotproduct = np.dot(attend_to_i_in_layer, attend_to_j_in_next_layer)
                output[i_first_attended, j_second_attended, k_layer] = \
                    0 if np.isnan(dotproduct) else dotproduct
    return output

In [37]:
att_matrix_3 = get_follow_up_attention_matrix_v2(att_tensor_np, normalise=True) # best so far

Start:  (20, 20, 1, 16, 157, 157)
Condensed start:  (20, 16, 157, 157)
Start output:  (157, 157, 19)
Layer  0 )
Select a layer and condense on the head dimension
Before size:  (20, 16, 157, 157)
After size:  (157, 157)
Make is so that each person (recipient) has sum of attention of 1
layer_matrix[i, :].sum():  16.0
layer_matrix[i, :].sum():  1.0


In [17]:
def normalize_matrix_by_lines(att_matrix):
    att_matrix[0,:] = 0.0000001
    att_matrix[:,0] = 0.0000001
    for i in range(att_matrix.shape[0]):
        att_matrix[i, :] = att_matrix[i, :] / att_matrix[i, :].sum()
    return att_matrix
att_matrix_3_processed = normalize_matrix_by_lines(att_matrix_3.sum(2))

# Save tensor as npy file
with open(os.path.join(OUTPUT_FOLDER, "attention_v3.npy"), "wb") as f:
    np.save(f, att_matrix_3_processed)


## Read (Pre-computed) Attention tensor

In [130]:
# read raw attention
with open(os.path.join(OUTPUT_FOLDER, "raw_attention.npy"), "rb") as f:
    att_tensor_np = np.load(f)

In [131]:
# read attention matrix v3
with open(os.path.join(OUTPUT_FOLDER, "attention_v3.npy"), "rb") as f:
    att_matrix_3_processed = np.load(f)

## Simple Example

In [84]:
import torch
import numpy as np
from pprint import pprint

In [None]:
# 1 - slice/expand a matrix removing x lines aka [1..n] from the matrix

# 2 - normalize every column (divide by the 2-norm) 
    # compute the 2-norm for each column
        # compute the dot product of each column (with itself)
        # compute the square root of each result

# repeat this for two consecutive layers.

# 3 - multiply each matrix slice by the the corresponding slice of the consecutive level matrix
# (to get the similarity between token A and token B)
# note that this multiplication will create the final results only for the 
# form token X (line) to token Y (column), where either X or Y is the first not empty line of the matrix slice

# 4 - Merge the results for all the matrix multiplications


# 5 - generalize in parallel for multiple layers

# 6 - 

In [38]:
def crete_random_matrix(size: int):
    """Create a random attention matrix with the given size."""
    base = torch.rand(size, size)
    mask = torch.tril(torch.ones(size, size))
    return base * mask

N_TOKENS = 7
# set torch seed 
torch.manual_seed(42)
a = crete_random_matrix(N_TOKENS)
b = crete_random_matrix(N_TOKENS)
b


tensor([[0.1952, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6581, 0.4913, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3278, 0.6532, 0.3958, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.9497, 0.6666, 0.9811, 0.0874, 0.0000, 0.0000, 0.0000],
        [0.7025, 0.6790, 0.9155, 0.2418, 0.1591, 0.0000, 0.0000],
        [0.8035, 0.3813, 0.7860, 0.1115, 0.2477, 0.6524, 0.0000],
        [0.3725, 0.7980, 0.8399, 0.1374, 0.2331, 0.9578, 0.3313]])

In [39]:
# 1 - slice/expand a matrix removing x lines aka [1..n] from the matrix
slicer_mat_rows = torch.triu(torch.ones(N_TOKENS, N_TOKENS))
# add dimension
slicer_mat_rows = slicer_mat_rows.unsqueeze(2)
slicer_mat_rows = slicer_mat_rows.expand(-1, -1, N_TOKENS)
slicer_mat_rows


tensor([[[1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1.]],

        [[0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1.]],

        [[0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1., 1., 1.]],

        [[0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0.],
         [1., 1., 1., 1., 1., 1., 1.],
         [1., 1., 1

In [40]:
slicer_mat_col = torch.transpose(slicer_mat_rows, 2, 1)
slicer_mat_col = torch.flip(slicer_mat_col, [2])
slicer_mat_col = torch.flip(slicer_mat_col, [0])
slicer_mat_col

tensor([[[1., 0., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0., 0.]],

        [[1., 1., 0., 0., 0., 0., 0.],
         [1., 1., 0., 0., 0., 0., 0.],
         [1., 1., 0., 0., 0., 0., 0.],
         [1., 1., 0., 0., 0., 0., 0.],
         [1., 1., 0., 0., 0., 0., 0.],
         [1., 1., 0., 0., 0., 0., 0.],
         [1., 1., 0., 0., 0., 0., 0.]],

        [[1., 1., 1., 0., 0., 0., 0.],
         [1., 1., 1., 0., 0., 0., 0.],
         [1., 1., 1., 0., 0., 0., 0.],
         [1., 1., 1., 0., 0., 0., 0.],
         [1., 1., 1., 0., 0., 0., 0.],
         [1., 1., 1., 0., 0., 0., 0.],
         [1., 1., 1., 0., 0., 0., 0.]],

        [[1., 1., 1., 1., 0., 0., 0.],
         [1., 1., 1., 1., 0., 0., 0.],
         [1., 1., 1., 1., 0., 0., 0.],
         [1., 1., 1., 1., 0., 0., 0.],
         [1., 1., 1

In [42]:
# replicate existing base matrix
a_stacked = a.unsqueeze(0)
a_stacked = a_stacked.expand(N_TOKENS, -1, -1)
a_stacked
# filter out previous rows
a_stacked_filtered = a_stacked
a_stacked_filtered = a_stacked_filtered * slicer_mat_rows
a_stacked_filtered = a_stacked_filtered * slicer_mat_col
a_stacked_filtered

tensor([[[0.8823, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.7936, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.7411, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4414, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5472, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.8090, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.7104, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.7936, 0.9408, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.7411, 0.4294, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4414, 0.2969, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.5472, 0.0062, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.8090, 0.5779, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.7104, 0.9464, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 

In [56]:
v = [0.8823, 0.7936, 0.7411, 0.4414, 0.5472, 0.8090, 0.7104]
np.linalg.norm(v)

1.9002353591068661

In [44]:
a_stacked_filtered.shape

torch.Size([7, 7, 7])

In [59]:
# normalize each column with euclidean norm
a_stacked_norm = torch.norm(a_stacked_filtered, dim=1)
a_stacked_norm = a_stacked_norm.unsqueeze(1)
a_stacked_norm = a_stacked_norm.expand(-1, N_TOKENS, -1)
a_stacked_norm

tensor([[[1.9002, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [1.9002, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [1.9002, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [1.9002, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [1.9002, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [1.9002, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [1.9002, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[1.6830, 1.5451, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [1.6830, 1.5451, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [1.6830, 1.5451, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [1.6830, 1.5451, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [1.6830, 1.5451, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [1.6830, 1.5451, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [1.6830, 1.5451, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[1.4841, 1.2257, 1.9547, 0.0000, 0.0000, 0.0000, 

In [63]:
# Normalize the each value 
a_stacked_normalized = a_stacked_filtered / a_stacked_norm
# replace nan
a_stacked_normalized = torch.nan_to_num(a_stacked_normalized, nan=0.0)

a_stacked_normalized


tensor([[[0.4643, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4177, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3900, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2323, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2880, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4257, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3739, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4716, 0.6089, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4403, 0.2779, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2622, 0.1922, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3251, 0.0040, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4807, 0.3740, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4221, 0.6125, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 

In [144]:
def expand_and_normalized(a: torch.Tensor):
    """Expand the N_tokens x N_tokens matrix for the number of tokens.
    
    Expand a matrix adding a new dimension N_tokens and remove
    - x lines from the top and 
    - x columns from the bottom, 
    so to leave only a rectangular non-zero section.

    Then normalize these rectangular section in a column wise fashion.
    """
    n_tokens = a.shape[0]
    # mask on rows
    slicer_mat_rows = torch.triu(torch.ones(n_tokens, n_tokens))
    slicer_mat_rows = slicer_mat_rows.unsqueeze(2)
    slicer_mat_rows = slicer_mat_rows.expand(-1, -1, n_tokens)
    # mask on columns
    slicer_mat_col = torch.transpose(slicer_mat_rows, 2, 1)
    slicer_mat_col = torch.flip(slicer_mat_col, [2])
    slicer_mat_col = torch.flip(slicer_mat_col, [0])
    # replicate existing base matrix
    a_stacked = a.unsqueeze(0)
    a_stacked = a_stacked.expand(n_tokens, -1, -1)
    # keep only rectangular sections
    a_stacked_filtered = a_stacked
    a_stacked_filtered = a_stacked_filtered * slicer_mat_rows
    a_stacked_filtered = a_stacked_filtered * slicer_mat_col
    # normalize each column with euclidean norm
    normalization_coeff = torch.norm(a_stacked_filtered, dim=1)
    normalization_coeff = normalization_coeff.unsqueeze(1)
    normalization_coeff = normalization_coeff.expand(-1, n_tokens, -1)
    # Normalize the each value 
    a_stacked_normalized = a_stacked_filtered / normalization_coeff
    # replace nan
    a_stacked_normalized = torch.nan_to_num(a_stacked_normalized, nan=0.0)
    return a_stacked_normalized

current_level_stacked = expand_and_normalized(a)
next_level_stacked = expand_and_normalized(b)

In [142]:
def compute_from_a_to_b(
        current_level_stacked: torch.Tensor,
        next_level_stacked: torch.Tensor,):
    """Compute a token -> token matrix for consecutive layers."""
    n_tokens = current_level_stacked.shape[0]
    assert current_level_stacked.shape[0] == next_level_stacked.shape[0], "the two inputs must have the same dimensions"
    # generalized multiplication to compute the dot products in parallel
    # this compares the number of followers of a token pairs in two 
    # consecutive layers
    res = torch.einsum(
        'bji,bjk->bik', 
        current_level_stacked, next_level_stacked)
    # create mask to remove the extra values in the upper left section
    # this is needed since the complete list of followers of a token 
    # are exactly all those tokens that follows in the sequence.
    # Note that his sequence is decided by the token of the pair which comes
    # last in the pair (aka we compare the largest set of followers which had
    # the possibility to follow both tokens in the pairs).
    slicer_mat_rows = torch.tril(torch.ones(n_tokens + 1, n_tokens + 1))[:-1, 1:]
    slicer_mat_rows = slicer_mat_rows.unsqueeze(2)
    slicer_mat_rows = slicer_mat_rows.expand(-1, -1, n_tokens)
    transposed = torch.transpose(slicer_mat_rows, 2, 1)
    mask_keep_only_last_slide = 1 - (slicer_mat_rows * transposed)
    mask_keep_only_last_slide
    res = res * mask_keep_only_last_slide
    # condense the stacked version
    return res.sum(dim=0)

In [74]:
TARGET_TOKEN_POS = 3

In [94]:
current_level_slice = current_level_stacked[TARGET_TOKEN_POS]
current_level_slice

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3433, 0.2586, 0.4772, 0.1658, 0.0000, 0.0000, 0.0000],
        [0.4256, 0.0054, 0.5460, 0.1185, 0.0000, 0.0000, 0.0000],
        [0.6291, 0.5034, 0.5187, 0.8731, 0.0000, 0.0000, 0.0000],
        [0.5525, 0.8244, 0.4528, 0.4430, 0.0000, 0.0000, 0.0000]])

In [95]:
next_level_slice = next_level_stacked[TARGET_TOKEN_POS]
next_level_slice

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7386, 0.5807, 0.5630, 0.1375, 0.0000, 0.0000, 0.0000],
        [0.5464, 0.5915, 0.5253, 0.3806, 0.0000, 0.0000, 0.0000],
        [0.6249, 0.3322, 0.4510, 0.1755, 0.0000, 0.0000, 0.0000],
        [0.2897, 0.6952, 0.4820, 0.2163, 0.0000, 0.0000, 0.0000]])

In [96]:
d_current = current_level_slice[:, TARGET_TOKEN_POS]
pprint(d_current)
d_next = next_level_slice[:, TARGET_TOKEN_POS]
pprint(d_next)

tensor([0.0000, 0.0000, 0.0000, 0.1658, 0.1185, 0.8731, 0.4430])
tensor([0.0000, 0.0000, 0.0000, 0.1375, 0.3806, 0.1755, 0.2163])


In [90]:
torch.dot(d_current, d_next)

tensor(0.3170)

In [97]:
torch.matmul(d_current, next_level_slice)

tensor([0.8611, 0.7643, 0.7629, 0.3170, 0.0000, 0.0000, 0.0000])

In [101]:
torch.matmul(d_next, current_level_slice)

tensor([0.4391, 0.3043, 0.4624, 0.3170, 0.0000, 0.0000, 0.0000])

In [100]:
torch.matmul(current_level_slice.t(), next_level_slice)
# the only valid data are those on the edge of the section, thus 
# the last non-zero row 
# and the last non-zero column

tensor([[1.0392, 1.0441, 0.9669, 0.4391, 0.0000, 0.0000, 0.0000],
        [0.7474, 0.8937, 0.7728, 0.3043, 0.0000, 0.0000, 0.0000],
        [1.1061, 1.0872, 1.0077, 0.4624, 0.0000, 0.0000, 0.0000],
        [0.8611, 0.7643, 0.7629, 0.3170, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])

In [98]:
current_level_stacked[:3, :3, :3]

tensor([[[0.4643, 0.0000, 0.0000],
         [0.4177, 0.0000, 0.0000],
         [0.3900, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000],
         [0.4716, 0.6089, 0.0000],
         [0.4403, 0.2779, 0.0000]],

        [[0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.4994, 0.3503, 0.4530]]])

In [125]:
# generalize for all the tokens

# create mask to remove the extra values in the upper left section
slicer_mat_rows = torch.tril(torch.ones(N_TOKENS + 1, N_TOKENS + 1))[:-1, 1:]
slicer_mat_rows = slicer_mat_rows.unsqueeze(2)
slicer_mat_rows = slicer_mat_rows.expand(-1, -1, N_TOKENS)
transposed = torch.transpose(slicer_mat_rows, 2, 1)
mask_keep_only_last_slide = 1 - (slicer_mat_rows * transposed)
mask_keep_only_last_slide

res = torch.einsum(
    'bji,bjk->bik', 
    current_level_stacked, next_level_stacked)
res = res * mask_keep_only_last_slide
res[TARGET_TOKEN_POS, :, :]

tensor([[0.0000, 0.0000, 0.0000, 0.4391, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.3043, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.4624, 0.0000, 0.0000, 0.0000],
        [0.8611, 0.7643, 0.7629, 0.3170, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])

In [124]:
res.sum(dim=0)

tensor([[0.7355, 0.9288, 0.8479, 0.4391, 0.3038, 1.2960, 0.4394],
        [0.7165, 0.8044, 0.7163, 0.3043, 0.2664, 1.3367, 0.4394],
        [0.9544, 1.1492, 0.8927, 0.4624, 0.2958, 1.2950, 0.4394],
        [0.8611, 0.7643, 0.7629, 0.3170, 0.2779, 1.1724, 0.4394],
        [0.7990, 0.9945, 0.9217, 0.4664, 0.2687, 1.3146, 0.4394],
        [0.7822, 0.7418, 0.9564, 0.2817, 0.3956, 1.3049, 0.4394],
        [0.5244, 0.8432, 1.0645, 0.4883, 0.2955, 1.6249, 0.4394]])

In [138]:
compute_from_a_to_b(
    current_level_stacked, next_level_stacked)


tensor([[0.7355, 0.9288, 0.8479, 0.4391, 0.3038, 1.2960, 0.4394],
        [0.7165, 0.8044, 0.7163, 0.3043, 0.2664, 1.3367, 0.4394],
        [0.9544, 1.1492, 0.8927, 0.4624, 0.2958, 1.2950, 0.4394],
        [0.8611, 0.7643, 0.7629, 0.3170, 0.2779, 1.1724, 0.4394],
        [0.7990, 0.9945, 0.9217, 0.4664, 0.2687, 1.3146, 0.4394],
        [0.7822, 0.7418, 0.9564, 0.2817, 0.3956, 1.3049, 0.4394],
        [0.5244, 0.8432, 1.0645, 0.4883, 0.2955, 1.6249, 0.4394]])

## Real Example

In [132]:
def normalize_matrix_by_lines(att_matrix):
    att_matrix[0,:] = 0.0000001
    att_matrix[:,0] = 0.0000001
    for i in range(att_matrix.shape[0]):
        att_matrix[i, :] = att_matrix[i, :] / att_matrix[i, :].sum()
    return att_matrix

In [136]:
INPUT = torch.tensor(att_tensor_np)
INPUT = INPUT.sum(2).sum(0)
# sum all heads
INPUT = INPUT.sum(1)
INPUT.shape

torch.Size([20, 157, 157])

In [153]:
from tqdm import tqdm
# consider each pair of two layers

all_layers_results = []

n_layers = INPUT.shape[0]
for i in tqdm(range(n_layers-1)):
    c_layer = INPUT[i]
    n_layer = INPUT[i + 1]

    current_level_stacked = expand_and_normalized(c_layer)
    next_level_stacked = expand_and_normalized(n_layer)

    res = compute_from_a_to_b(
        current_level_stacked, next_level_stacked)
    
    all_layers_results.append(res)

# stack results
all_layers_results = torch.stack(all_layers_results, dim=0)
all_layers_results.shape

100%|██████████| 19/19 [00:01<00:00, 18.52it/s]


torch.Size([19, 157, 157])

In [151]:
OUTPUT = all_layers_results.sum(0)
OUTPUT =  normalize_matrix_by_lines(OUTPUT)
OUTPUT.shape

torch.Size([157, 157])

In [152]:
np.allclose(OUTPUT.numpy(), att_matrix_3_processed)

True

In [149]:
att_matrix_3_processed.shape

(157, 157)

In [150]:
import seaborn as sns

In [None]:
diff = OUTPUT.numpy() - att_matrix_3_processed