# Analyze Learned Token Position Embeddings

In [None]:
import matplotlib.pyplot as plt
import import_ipynb # using this to import the modules notebook
import modules # importing the notebook
import tensorflow as tf
from tensorboard.plugins import projector
import os
import torch
import numpy as np

# Analyze Transformer Behavior


Set Model Configuration (these values must match the ones used in train_model.ipynb, otherwise the saved weights will not load)

In [None]:
# num_tokens: the number of different tokens in the corpus
# t: the length of the sequences as input to the model
# depth: depth of the network (number of transformer blocks)
# heads: number of attention heads in the multi-head attention mechanism
# k: embedding dimension (needs to be a multiple of heads)

k = 6  # must be a multiple of heads (each head gets k // heads dimensions)
num_tokens = 10  # integers from 0 to 9
heads = 3
depth = 2
t = 5

# Fail fast here rather than with an opaque shape error when the model is built.
assert k % heads == 0, f"k ({k}) must be a multiple of heads ({heads})"

Load Model

In [None]:
# Sanity check: the notebook-import of `modules` must expose GTransformer before we try to build it.
assert hasattr(modules, 'GTransformer'), "modules notebook does not define GTransformer"

In [None]:
# Load trained model. The hyperparameters above must match those used at training time;
# with the default strict=True, load_state_dict raises on missing/unexpected keys or shape mismatches.
model = modules.GTransformer(k=k, heads=heads, depth=depth, t=t, num_tokens=num_tokens)
# map_location='cpu' lets a checkpoint saved from a GPU run load on a CPU-only kernel;
# the weights are converted to NumPy on the CPU further down anyway.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
model.load_state_dict(torch.load('gtransformer.pth', map_location='cpu'))

Analyze Position Embeddings

In [None]:
# Vocabulary ids 0 .. num_tokens-1 (these identify tokens, not sequence positions).
tokens = np.arange(0, num_tokens)
print(tokens)

In [None]:
# Directory that receives the checkpoint, projector config and metadata for TensorBoard.
log_dir = './torchlogs/pos-tokens/'

# exist_ok=True makes this idempotent on re-run and avoids the exists()/makedirs() race.
os.makedirs(log_dir, exist_ok=True)

In [None]:
# Get the learned *position* embedding matrix as a NumPy array.
# NOTE(review): if model.pos_embedding is nn.Embedding(t, k), the shape is (t, k) — confirm from the print below.
# .detach() is the supported way to leave autograd (instead of .data); .cpu() is required before .numpy()
# if the model was moved to a GPU.
pos_embeddings = model.pos_embedding.weight.detach().cpu().numpy()
print(pos_embeddings.shape)


In [None]:
# Wrap the position-embedding matrix in a tf.Variable and write it to a TF checkpoint in log_dir,
# which is where TensorBoard's projector looks for tensors. (Metadata labels are written separately.)
embeddings = tf.Variable(pos_embeddings, name='pos_embeddings')
# The attribute name "embedding" becomes the tensor path prefix referenced by the projector config.
checkpoint = tf.train.Checkpoint(embedding=embeddings)
checkpoint.save(file_prefix=os.path.join(log_dir, "token_embedding.ckpt"))

In [None]:
# Set up config.
config = projector.ProjectorConfig()
embedding = config.embeddings.add()

# The name of the tensor will be suffixed by `/.ATTRIBUTES/VARIABLE_VALUE`.
embedding.tensor_name = "embedding/.ATTRIBUTES/VARIABLE_VALUE"
embedding.metadata_path = 'metadata.tsv'
projector.visualize_embeddings(log_dir, config)

In [None]:
# Write projector metadata: exactly one label per row of the checkpointed embedding matrix, in row order.
# Labels are derived from pos_embeddings itself (not num_tokens) so the line count always matches the
# tensor — the projector reports an error when metadata and tensor row counts differ. For a position
# embedding, the label of row i is the position index i. Single-column metadata takes no header line.
with open(os.path.join(log_dir, 'metadata.tsv'), 'w', encoding='utf-8') as f:
    for position in range(pos_embeddings.shape[0]):
        f.write(f"{position}\n")

In [None]:
# Start TensorBoard (or use the command line: tensorboard --logdir=./src/model-basic/torchlogs/pos-tokens/)
%load_ext tensorboard
%tensorboard --logdir ./torchlogs/vocab-tokens/