# Analyze Inner Representations

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch

import import_ipynb # using this to import the modules notebook
import modules # importing the notebook

Set Model Configuration

In [None]:
# Model configuration. The trained checkpoint is loaded into a model built from these
# values, so they have to match the hyperparameters the model was trained with.
# num_tokens: the number of different tokens in the corpus
# t: the length of the sequences as input to the model
# depth: depth of the network (number of transformer blocks)
# heads: number of attention heads in the multi-head attention mechanism
# k: embedding dimension (needs to be a multiple of heads)

k = 6 # x * heads
num_tokens = 10 # integers from 0 to 9
heads = 3
depth = 2
t = 5

# Each head works on k // heads dimensions, so k must split evenly across the heads.
assert k % heads == 0, f"k ({k}) must be a multiple of heads ({heads})"

Load Model

In [None]:
# Rebuild the architecture from the configuration above and load the trained weights.
model = modules.GTransformer(k=k, heads=heads, depth=depth, t=t, num_tokens=num_tokens)
# The example input below is a CPU tensor, so remap the checkpoint's storages to CPU;
# without map_location a checkpoint saved from a GPU fails to load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints
# (weights_only=True is available on torch >= 1.13).
state_dict = torch.load('gtransformer.pth', map_location="cpu")
model.load_state_dict(state_dict)

Define Vocabulary Tokens

In [None]:
# Vocabulary: token ids 0 .. num_tokens - 1
tokens = np.arange(0, num_tokens, 1)
print(tokens)

Analyze Attention Maps and Key, Query, and Value Matrices for an Example Input

In [None]:
# Example input: a batch containing one sequence of length 5 (matching t in the config cell).
# NOTE(review): the class of this example is unclear - [4, 5, 5, 5, 4] is not an increasing
# sequence and figures are written to misaligned_class2/; confirm the intended class label.
input = torch.tensor([[4, 5, 5, 5, 4]], dtype=torch.long)
save_fig_path = "./images/misaligned_class2/"
# savefig does not create missing directories, so make sure the output folder exists.
os.makedirs(save_fig_path, exist_ok=True)

print("input:", input)
print("input size", input[0].size())

In [None]:
# Run the example input through the model in eval mode, without tracking gradients.
model.eval()
with torch.no_grad():
    output, _, _ = model(input)
    print("output (vocab distribution for each output dim)", output.shape)
    # The output holds log-probabilities. exp is monotonic, so the most probable token
    # per position is simply the argmax over the vocabulary dimension of the log-probs.
    print("output", output.argmax(dim=2))

In [None]:
# Collect the model's intermediate tensors for the example input; each of the six
# results is indexed per transformer block in the cells below.
(
    out_matrices,
    probs_matrices,
    attention_maps,
    key_matrices,
    query_matrices,
    value_matrices,
) = model.get_respresentations(input)

In [None]:
# Sanity-check the attention maps: one entry per transformer block, each stacking one
# square (t x t) attention matrix per head, i.e. shape (heads, t, t).
print(f"number of attention maps (concat over all heads in each layer): {len(attention_maps)} - heads used: {heads}, attention layers used: {depth}")
attention_map = attention_maps[0]
map_shape = attention_map.shape
print("Shape of each attention map: ", map_shape)
assert map_shape[0] == heads
assert map_shape[1] == map_shape[2] == t  # square: every position attends over all t positions
print("Number of heads: ", map_shape[0])
print("Sequence length: ", map_shape[1])


In [None]:
# Plot the attention map of every head in every transformer block for the single example
# sequence: one subplot per (block, head), each cell annotated with its weight.
# Per the tick labels, rows are output positions and columns are input positions.
nlayers = len(attention_maps)
nheads = attention_maps[0].shape[0]
# Local name so the config value `t` is not silently overwritten.
seq_len = attention_maps[0].shape[1]

print(f"Number of displayable layers: {nlayers}")
print(f"Number of displayable heads: {nheads}")

# squeeze=False keeps `ax` two-dimensional even when depth == 1 or heads == 1.
fig, ax = plt.subplots(nlayers, nheads, figsize=(15, 15), squeeze=False)

for row, attention_map in enumerate(attention_maps):
    for col in range(nheads):
        data = attention_map[col, :, :].detach().cpu().numpy()
        axis = ax[row, col]
        axis.imshow(data, origin="lower", vmin=data.min(), vmax=data.max())
        axis.set_xticks(list(range(seq_len)))
        axis.set_xticklabels([f"In-Position {i}" for i in range(seq_len)], rotation=90)
        axis.set_yticks(list(range(seq_len)))
        axis.set_yticklabels([f"Out-Position {i}" for i in range(seq_len)])
        axis.set_title(f"TF block {row}, Head {col}")
        for (i, j), weight in np.ndenumerate(data):
            axis.text(j, i, f"{weight:.2f}", ha="center", va="center", color="w")

fig.tight_layout()
fig.suptitle("Attention Maps", y=1.02)
# bbox_inches="tight" keeps the suptitle (placed above the axes at y=1.02) in the saved file.
fig.savefig(save_fig_path + "attention_maps.png", bbox_inches="tight")
plt.show()

In [None]:
# Sanity-check the key tensors: one entry per transformer block, each stacking the
# per-head keys as (heads, t, k // heads), i.e. one key vector of size k // heads per
# position and head.
print(f"number of key matrices (concat over all heads in each layer): {len(key_matrices)} - heads used: {heads}, attention layers used: {depth}")
key_matrice = key_matrices[0]
print("Shape of each key matrix: ", key_matrice.shape)
assert key_matrice.shape[0] == heads
assert key_matrice.shape[1] == t
assert key_matrice.shape[2] == k // heads
print("Number of heads: ", key_matrice.shape[0])
print("Sequence length: ", key_matrice.shape[1])
print("k // heads: ", key_matrice.shape[2])


In [None]:
# Plot the per-head keys of every transformer block for the single example sequence:
# one subplot per (block, head); rows are sequence positions, columns the key dimensions.
nlayers = len(key_matrices)
nheads = key_matrices[0].shape[0]
seq_len = key_matrices[0].shape[1]
head_dim = key_matrices[0].shape[2]

print(f"Number of displayable layers: {nlayers}")
print(f"Number of displayable heads: {nheads}")

# squeeze=False keeps `ax` two-dimensional even when depth == 1 or heads == 1.
fig, ax = plt.subplots(nlayers, nheads, figsize=(15, 15), squeeze=False)

for row, key_matrice in enumerate(key_matrices):
    for col in range(nheads):
        data = key_matrice[col, :, :].detach().cpu().numpy()
        axis = ax[row, col]
        axis.imshow(data, origin="lower", vmin=data.min(), vmax=data.max())
        axis.set_xticks(list(range(head_dim)))
        axis.set_yticks(list(range(seq_len)))
        axis.set_yticklabels([f"Position: {i}" for i in range(seq_len)])
        axis.set_title(f"TF block {row}, Head {col}")
        for (i, j), value in np.ndenumerate(data):
            axis.text(j, i, f"{value:.2f}", ha="center", va="center", color="w")

fig.tight_layout()
fig.suptitle("Key Matrices (Keys of dimension k//heads for each token)", y=1.02)
# bbox_inches="tight" keeps the suptitle (placed above the axes at y=1.02) in the saved file.
fig.savefig(save_fig_path + "key_matrices.png", bbox_inches="tight")
plt.show()

In [None]:
# Sanity-check the query tensors: one entry per transformer block, each stacking the
# per-head queries as (heads, t, k // heads).
print(f"number of query matrices (concat over all heads in each layer): {len(query_matrices)} - heads used: {heads}, attention layers used: {depth}")
query_matrice = query_matrices[0]
print("Shape of each query matrix: ", query_matrice.shape)
assert query_matrice.shape[0] == heads
assert query_matrice.shape[1] == t
assert query_matrice.shape[2] == k // heads
# Report the dimensions of the query tensor itself (not the key tensor).
print("Number of heads: ", query_matrice.shape[0])
print("Sequence length: ", query_matrice.shape[1])
print("k // heads: ", query_matrice.shape[2])


In [None]:
# Plot the per-head queries of every transformer block for the single example sequence:
# one subplot per (block, head); rows are sequence positions, columns the query dimensions.
nlayers = len(query_matrices)
nheads = query_matrices[0].shape[0]
# Local name so the config value `t` is not silently overwritten.
seq_len = query_matrices[0].shape[1]
head_dim = query_matrices[0].shape[2]

print(f"Number of displayable layers: {nlayers}")
print(f"Number of displayable heads: {nheads}")

# squeeze=False keeps `ax` two-dimensional even when depth == 1 or heads == 1.
fig, ax = plt.subplots(nlayers, nheads, figsize=(15, 15), squeeze=False)

for row, query_matrice in enumerate(query_matrices):
    for col in range(nheads):
        data = query_matrice[col, :, :].detach().cpu().numpy()
        axis = ax[row, col]
        axis.imshow(data, origin="lower", vmin=data.min(), vmax=data.max())
        axis.set_xticks(list(range(head_dim)))
        axis.set_yticks(list(range(seq_len)))
        axis.set_yticklabels([f"Position: {i}" for i in range(seq_len)])
        axis.set_title(f"TF block {row}, Head {col}")
        for (i, j), value in np.ndenumerate(data):
            axis.text(j, i, f"{value:.2f}", ha="center", va="center", color="w")

fig.tight_layout()
fig.suptitle("Query Matrices (Queries of dimension k//heads for each token)", y=1.02)
# bbox_inches="tight" keeps the suptitle (placed above the axes at y=1.02) in the saved file.
fig.savefig(save_fig_path + "query_matrices.png", bbox_inches="tight")
plt.show()

In [None]:
# Sanity-check the value tensors: one entry per transformer block, one stack per head.
# NOTE(review): assumes values share the keys' per-head layout (heads, t, k // heads),
# as in standard multi-head attention - confirm against GTransformer.
print(f"number of value matrices (concat over all heads in each layer): {len(value_matrices)} - heads used: {heads}, attention layers used: {depth}")
value_matrice = value_matrices[0]
print("Shape of each value matrix: ", value_matrice.shape)
assert value_matrice.shape[0] == heads
assert value_matrice.shape[1] == t
assert value_matrice.shape[2] == k // heads
print("Number of heads: ", value_matrice.shape[0])
print("Sequence length: ", value_matrice.shape[1])
print("k // heads: ", value_matrice.shape[2])

In [None]:
# Plot the per-head values of every transformer block for the single example sequence:
# one subplot per (block, head); rows are sequence positions, columns the value dimensions.
nlayers = len(value_matrices)
nheads = value_matrices[0].shape[0]
# Local name so the config value `t` is not silently overwritten.
seq_len = value_matrices[0].shape[1]
head_dim = value_matrices[0].shape[2]

print(f"Number of displayable layers: {nlayers}")
print(f"Number of displayable heads: {nheads}")

# squeeze=False keeps `ax` two-dimensional even when depth == 1 or heads == 1.
fig, ax = plt.subplots(nlayers, nheads, figsize=(15, 15), squeeze=False)

for row, value_matrice in enumerate(value_matrices):
    for col in range(nheads):
        data = value_matrice[col, :, :].detach().cpu().numpy()
        axis = ax[row, col]
        axis.imshow(data, origin="lower", vmin=data.min(), vmax=data.max())
        axis.set_xticks(list(range(head_dim)))
        axis.set_yticks(list(range(seq_len)))
        axis.set_yticklabels([f"Position: {i}" for i in range(seq_len)])
        axis.set_title(f"TF block {row}, Head {col}")
        for (i, j), value in np.ndenumerate(data):
            axis.text(j, i, f"{value:.2f}", ha="center", va="center", color="w")

fig.tight_layout()
fig.suptitle("Value Matrices (Values of dimension k//heads for each token)", y=1.02)
# bbox_inches="tight" keeps the suptitle (placed above the axes at y=1.02) in the saved file.
fig.savefig(save_fig_path + "value_matrices.png", bbox_inches="tight")
plt.show()

In [None]:
# Sanity-check the block outputs: one tensor per transformer block, shaped (1, t, k) -
# a batch of one sequence with a k-dimensional embedding per position.
print(f"number of out matrices: {len(out_matrices)}")
out_matrice = out_matrices[0]
out_shape = out_matrice.shape
print("Shape of each out matrix: ", out_shape)
assert out_shape[0] == 1
assert out_shape[1] == t
assert out_shape[2] == k 
print("Number of output: ", out_shape[0])
print("Sequence length: ", out_shape[1])
print("k : ", out_shape[2])

In [None]:
# Plot the output embeddings of every transformer block for the single example sequence:
# one subplot per block; rows are sequence positions, columns embedding dimensions.
nlayers = len(out_matrices)
# Local names so the config values `t` and `k` are not silently overwritten.
seq_len = out_matrices[0].shape[1]
emb_dim = out_matrices[0].shape[2]

print(f"Number of displayable layers: {nlayers}")

# squeeze=False keeps `ax` two-dimensional even when there is only one block.
fig, ax = plt.subplots(nlayers, 1, figsize=(15, 15), squeeze=False)

for row, out_matrice in enumerate(out_matrices):
    # Drop the leading batch dimension (batch of one) -> (seq_len, emb_dim).
    data = out_matrice.detach().cpu().squeeze(0).numpy()
    axis = ax[row, 0]
    axis.imshow(data, origin="lower", vmin=data.min(), vmax=data.max())
    axis.set_xticks(list(range(emb_dim)))
    axis.set_yticks(list(range(seq_len)))
    axis.set_yticklabels([f"Position: {i}" for i in range(seq_len)])
    axis.set_title(f"TF block {row}")
    for (i, j), value in np.ndenumerate(data):
        axis.text(j, i, f"{value:.2f}", ha="center", va="center", color="w")

fig.tight_layout()
fig.suptitle("Out Matrices (Output Embeddings of dimension k for each output position)", y=1.02)
# bbox_inches="tight" keeps the suptitle (placed above the axes at y=1.02) in the saved file.
fig.savefig(save_fig_path + "out_matrices.png", bbox_inches="tight")
plt.show()

In [None]:
# Sanity-check the per-position vocabulary scores: each entry is shaped (1, t, vocabulary
# size) - a batch of one sequence with one score per token for every position.
print(f"number of prob matrices: {len(probs_matrices)}")
probs_matrice = probs_matrices[0]
probs_shape = probs_matrice.shape
print("Shape of each prob matrix: ", probs_shape)
assert probs_shape[0] == 1
assert probs_shape[1] == t
assert probs_shape[2] == len(tokens)
print("Number of output: ", probs_shape[0])
print("Sequence length: ", probs_shape[1])
print("Number of tokens: ", probs_shape[2])

In [None]:
# Plot per-position token probabilities for every entry of probs_matrices for the single
# example sequence. The stored values are log-probabilities; they are exponentiated here,
# so the figure shows probabilities (each row sums to 1).
nlayers = len(probs_matrices)
# Local names so the config values `t` and `num_tokens` are not silently overwritten.
seq_len = probs_matrices[0].shape[1]
vocab_size = probs_matrices[0].shape[2]

print(f"Number of displayable layers: {nlayers}")

# squeeze=False keeps `ax` two-dimensional even when there is only one entry.
fig, ax = plt.subplots(nlayers, 1, figsize=(15, 15), squeeze=False)

for row, probs_matrice in enumerate(probs_matrices):
    # Drop the leading batch dimension (batch of one) -> (seq_len, vocab_size).
    log_probs = probs_matrice.detach().cpu().squeeze(0)
    probs = torch.exp(log_probs)
    # Every output position must hold a proper distribution over the vocabulary;
    # ones_like matches the dtype of the sums, whatever precision the model uses.
    row_sums = probs.sum(dim=1)
    assert torch.allclose(row_sums, torch.ones_like(row_sums), atol=1e-6)
    data = probs.numpy()
    axis = ax[row, 0]
    axis.imshow(data, origin="lower", vmin=data.min(), vmax=data.max())
    axis.set_xticks(list(range(vocab_size)))
    axis.set_xticklabels([f"Vocab-Token: {i}" for i in range(vocab_size)])
    axis.set_yticks(list(range(seq_len)))
    axis.set_yticklabels([f"Position: {i}" for i in range(seq_len)])
    axis.set_title(f"TF block {row}")
    for (i, j), p in np.ndenumerate(data):
        axis.text(j, i, f"{p:.3f}", ha="center", va="center", color="w")

fig.tight_layout()
fig.suptitle("Probs Matrices (Token Probabilities for each output position)", y=1.02)
# bbox_inches="tight" keeps the suptitle (placed above the axes at y=1.02) in the saved file.
fig.savefig(save_fig_path + "prob_matrices.png", bbox_inches="tight")
plt.show()