In [1]:
from util import load_config, load_tokenizers
from model import Transformer
import transformer_components
import matplotlib.pyplot as plt
import torch
import ipywidgets as widgets

### Setup:

In [2]:
# Read the notebook configuration and keep a handle on the model hyper-parameters,
# which are needed both for the constructor and for the stack size below.
config = load_config("config.yaml")
transformer_params = config["transformer_params"]

# Source- and target-side tokenizers, configured from the "tokenizer" section
src_tokenizer, tgt_tokenizer = load_tokenizers(**config["tokenizer"])

# Build the model: vocabulary sizes and BOS/EOS ids are taken from the tokenizers,
# all remaining constructor arguments come straight from the config.
transformer = Transformer(
    source_vocab_size=src_tokenizer.vocab_size(),
    target_vocab_size=tgt_tokenizer.vocab_size(),
    bos_idx=tgt_tokenizer.bos_id(),
    eos_idx=tgt_tokenizer.eos_id(),
    **transformer_params
)

stack_size = transformer_params["stack_size"]

# Three attention maps are expected per stack level: the notebook later reads the
# first `stack_size` queue entries as encoder attention and the remaining entries as
# alternating masked-decoder / decoder-encoder attention.
# NOTE(review): presumably the argument is the queue capacity - confirm in transformer_components.
transformer_components.initialize_attention_weights_queue(3 * stack_size)

In [3]:
# Sentence pair (German source, English target) whose attention maps are visualised
src_sentence = "hallo welt, aufmerksamkeit ist alles was du brauchst!"
tgt_sentence = "hello world, attention is all you need!"

src_pad_id = src_tokenizer.pad_id()
tgt_pad_id = tgt_tokenizer.pad_id()

# Tokenize with BOS/EOS markers, then append trailing padding ids
# (4 on the source side, 3 on the target side).
src_sequence_ids = src_tokenizer.encode(src_sentence, add_bos=True, add_eos=True)
src_sequence_ids.extend([src_pad_id] * 4)

tgt_sequence_ids = tgt_tokenizer.encode(tgt_sentence, add_bos=True, add_eos=True)
tgt_sequence_ids.extend([tgt_pad_id] * 3)

# Token strings (padding included); the plotting functions use them as tick labels
src_tokens = src_tokenizer.id_to_piece(src_sequence_ids)
tgt_tokens = tgt_tokenizer.id_to_piece(tgt_sequence_ids)

# Wrapping each id list in an outer list yields a batch containing one sequence
src_sequence_ids = torch.tensor([src_sequence_ids])
tgt_sequence_ids = torch.tensor([tgt_sequence_ids])

# Boolean mask that is True at every padded source position
src_key_padding_mask = src_sequence_ids == src_pad_id

In [4]:
# Put the model in inference mode so that dropout (if configured) is disabled and
# the recorded attention weights are reproducible between runs.
# NOTE(review): assumes Transformer is a torch.nn.Module - confirm in model.py.
transformer.eval()

# Only the forward activations are needed for plotting, so skip autograd tracking.
with torch.no_grad():
    # Encode
    encoded_src = transformer.encode_source(src_sequence_ids, src_key_padding_mask)

    # Decode; the decoder input is the target sequence without its last position
    predictions = transformer(encoded_src, tgt_sequence_ids[..., :-1], src_key_padding_mask)

# Snapshot the attention maps recorded during the two forward passes above.
# The slicing assumes the queue holds `stack_size` encoder entries first, followed by
# alternating (masked self-attention, decoder-encoder attention) entries for the
# decoder - TODO confirm against transformer_components.
attention_weights = list(transformer_components.attention_weights_queue)
encoder_attention = attention_weights[:stack_size]
decoder_masked_attention = attention_weights[stack_size::2]
decoder_encoder_attention = attention_weights[stack_size+1::2]

In [5]:
def get_widgets():
    """Create the pair of sliders that select which attention map is plotted.

    Reads the notebook-level names ``stack_size`` and ``encoder_attention``.

    Returns:
        tuple: ``(stack_idx, attention_head_idx)`` - two ``IntSlider`` widgets, both
        starting at 0. The first ranges up to ``stack_size - 1``; the second ranges
        up to ``encoder_attention[0].shape[1] - 1``.
    """
    def _index_slider(label, upper):
        # Both sliders share every setting except their caption and upper bound.
        return widgets.IntSlider(
            value=0,
            min=0,
            max=upper,
            step=1,
            description=label,
            style={'description_width': '200px'},
            layout={'width': '400px'}
        )

    stack_slider = _index_slider('Select the stack index:', stack_size - 1)
    # Dimension 1 of a recorded attention tensor is used as the number of heads
    head_slider = _index_slider(
        'Select the attention head index:',
        encoder_attention[0].shape[1] - 1,
    )
    return stack_slider, head_slider

In [6]:
def _plot_attention(weights, x_tokens, y_tokens, title, x_label, y_label):
    """Draw one attention map as a heat-map with token tick labels.

    Args:
        weights: 2-D torch tensor of attention weights; rows are labelled with
            ``y_tokens`` and columns with ``x_tokens``.
        x_tokens: Token strings for the x-axis ticks.
        y_tokens: Token strings for the y-axis ticks.
        title: Figure title.
        x_label: Caption for the x-axis.
        y_label: Caption for the y-axis.
    """
    # .cpu() is a no-op for CPU tensors and keeps the conversion working if the
    # model is ever moved to an accelerator.
    weights = weights.detach().cpu().numpy()

    # Fixed colour range so maps from different layers/heads are comparable
    plt.imshow(weights, vmin=0, vmax=1)

    # Set the xticks and yticks labels
    plt.xticks(range(len(x_tokens)), x_tokens, rotation=90, fontsize=8)
    plt.yticks(range(len(y_tokens)), y_tokens, fontsize=8)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.title(title)

    plt.colorbar()
    plt.tight_layout()
    plt.show()


def plot_encoder_attention(stack_idx, attention_head_idx):
    """Plot the encoder attention of one stack level and head.

    Args:
        stack_idx: Index into the notebook-level ``encoder_attention`` list.
        attention_head_idx: Index of the attention head to show.
    """
    # [0, head]: first (only) batch element, selected head
    # (assumes dimension 0 is the batch - TODO confirm).
    _plot_attention(
        encoder_attention[stack_idx][0, attention_head_idx],
        x_tokens=src_tokens,
        y_tokens=src_tokens,
        title="Encoder attention",
        x_label="Source tokens",
        y_label="Source tokens",
    )


def plot_masked_decoder_attention(stack_idx, attention_head_idx):
    """Plot the masked decoder attention of one stack level and head.

    Args:
        stack_idx: Index into the notebook-level ``decoder_masked_attention`` list.
        attention_head_idx: Index of the attention head to show.
    """
    # The decoder was fed the target sequence without its last position,
    # so the tick labels drop the last target token as well.
    decoder_tokens = tgt_tokens[:-1]
    _plot_attention(
        decoder_masked_attention[stack_idx][0, attention_head_idx],
        x_tokens=decoder_tokens,
        y_tokens=decoder_tokens,
        title="Masked decoder attention",
        x_label="Target tokens",
        y_label="Target tokens",
    )


def plot_decoder_encoder_attention(stack_idx, attention_head_idx):
    """Plot the decoder-encoder attention of one stack level and head.

    Args:
        stack_idx: Index into the notebook-level ``decoder_encoder_attention`` list.
        attention_head_idx: Index of the attention head to show.
    """
    # Columns are labelled with source tokens, rows with the decoder-input tokens
    # (target sequence without its last position).
    _plot_attention(
        decoder_encoder_attention[stack_idx][0, attention_head_idx],
        x_tokens=src_tokens,
        y_tokens=tgt_tokens[:-1],
        title="Decoder-encoder attention",
        x_label="Source tokens",
        y_label="Target tokens",
    )

### Encoder Attention:

In [7]:
# Fresh sliders for this section's plot
stack_idx, attention_head_idx = get_widgets()

# The trailing semicolon stops the notebook from displaying interact's return value,
# without printing a stray blank line.
widgets.interact(plot_encoder_attention, stack_idx=stack_idx, attention_head_idx=attention_head_idx);

interactive(children=(IntSlider(value=0, description='Select the stack index:', layout=Layout(width='400px'), …




### Masked Decoder Attention:

In [8]:
# Fresh sliders for this section's plot
stack_idx, attention_head_idx = get_widgets()

# The trailing semicolon stops the notebook from displaying interact's return value,
# without printing a stray blank line.
widgets.interact(plot_masked_decoder_attention, stack_idx=stack_idx, attention_head_idx=attention_head_idx);

interactive(children=(IntSlider(value=0, description='Select the stack index:', layout=Layout(width='400px'), …




### Decoder-Encoder Attention:

In [9]:
# Fresh sliders for this section's plot
stack_idx, attention_head_idx = get_widgets()

# The trailing semicolon stops the notebook from displaying interact's return value,
# without printing a stray blank line.
widgets.interact(plot_decoder_encoder_attention, stack_idx=stack_idx, attention_head_idx=attention_head_idx);

interactive(children=(IntSlider(value=0, description='Select the stack index:', layout=Layout(width='400px'), …


