In [None]:
from model import PositionalEncoding
import matplotlib.pyplot as plt

# Model width, used both to build the encoding and to bound the x-axis;
# keeping it in one place stops the two from drifting apart.
D_MODEL = 512
SEQ_LEN = 100

positional_encoding = PositionalEncoding(d_model=D_MODEL, seq_len=SEQ_LEN, dropout=0.1)

# Flatten any leading singleton axes (e.g. a batch/broadcast dim) so pcolormesh
# receives a 2-D (position, embedding-dim) matrix. A plain (seq_len, d_model)
# buffer passes through unchanged.
pe_matrix = positional_encoding.pos_encoding.detach().cpu().numpy().reshape(-1, D_MODEL)

fig, ax = plt.subplots(figsize=(10, 5))
mesh = ax.pcolormesh(pe_matrix, cmap="viridis")
ax.set_xlabel("Embedding Dimensions")
ax.set_xlim((0, D_MODEL))
ax.set_ylabel("Sequence Position")
ax.set_title("Positional encoding")
fig.colorbar(mesh, ax=ax)
plt.show()

In [None]:
from config import get_config_no_parser
from model import Transformer
from dataset import get_tokenizer, get_dataset
import torch

config_path = "configs/laptop_wmt14.yaml"
SEED = 0  # fixes weight initialisation so the plots below are reproducible

ds_config, model_config, _, _ = get_config_no_parser(config_path)

src_dataset, tgt_dataset = get_dataset(ds_config, model_config)
src_tokenizer = get_tokenizer(
    src_dataset, ds_config.src_lang, model_config.src_vocab_size
)
tgt_tokenizer = get_tokenizer(
    tgt_dataset, ds_config.tgt_lang, model_config.tgt_vocab_size
)

torch.manual_seed(SEED)
model = Transformer.from_config(model_config)
# NOTE(review): no checkpoint is loaded in this cell. Unless from_config loads
# trained weights, the decoded output and attention maps below come from an
# untrained model. TODO: load a checkpoint if trained behaviour is wanted.

# Inference/inspection only: switch off training-mode behaviour such as dropout
# (presumably present — PositionalEncoding above takes a dropout rate) so it
# does not perturb the cached attention scores. Assumes Transformer is a
# torch nn.Module; confirm.
model.eval()

In [None]:
phrase = "<s> Yesterday I went to the park. </s>"

# Encode once; both the ids (model input) and the token strings (plot labels)
# come from the same encoding.
src_encoding = src_tokenizer.encode(phrase)
src = torch.tensor(src_encoding.ids).unsqueeze(0)  # add batch dim -> (1, src_len)
tgt = torch.tensor(tgt_tokenizer.encode("<s>").ids).unsqueeze(0)

# No masks are passed. That is only safe for a single unpadded source and a
# one-token target. NOTE(review): if the tokenizer adds extra special tokens so
# that tgt is longer than 1, a causal target mask is needed — confirm.
with torch.no_grad():  # pure inference; no autograd graph needed
    output = model(src, tgt, None, None)

predicted_ids = output.argmax(dim=-1).squeeze(0).tolist()
print("Decoded output:", tgt_tokenizer.decode(predicted_ids))

tokens = src_encoding.tokens

In [None]:


def plot_layer(layer, tokens, num_cols=2):
    """Plot every attention head of one layer as a grid of heatmaps.

    Args:
        layer: Module exposing ``multi_head_attention.attn_scores``, the scores
            cached by its most recent forward pass.
        tokens: Token strings used to label both the key (x) and query (y)
            axes. Presumably the sequence that was fed through the layer
            (self-attention); verify before using this for cross-attention.
        num_cols: Number of heatmaps per grid row (default 2).
    """
    # Drop the batch axis; the indexing below treats the result as
    # (heads, query, key).
    attn_scores = (
        layer.multi_head_attention.attn_scores.squeeze(0).detach().cpu().numpy()
    )
    num_heads = attn_scores.shape[0]
    num_rows = (num_heads + num_cols - 1) // num_cols  # ceiling division

    # squeeze=False keeps `axes` 2-D even when there is only one row/column.
    fig, axes = plt.subplots(
        num_rows, num_cols, figsize=(15, 5 * num_rows), squeeze=False
    )
    fig.subplots_adjust(hspace=0.5)

    # pcolormesh draws cell i over [i, i + 1], so labels go at the cell centres.
    tick_positions = [i + 0.5 for i in range(len(tokens))]

    # axes.flat walks the grid row-major, matching head order.
    for n_head, ax in enumerate(axes.flat):
        if n_head >= num_heads:
            ax.axis("off")  # hide padding cells when num_heads % num_cols != 0
            continue

        mesh = ax.pcolormesh(attn_scores[n_head], cmap="viridis")
        ax.set_xlabel("Key")
        ax.set_ylabel("Query")
        ax.set_title(f"Attention Matrix - Head {n_head+1}")

        ax.set_xticks(tick_positions)
        ax.set_xticklabels(tokens, rotation=90)
        ax.set_yticks(tick_positions)
        ax.set_yticklabels(tokens)

        # Reuse the mesh already drawn instead of drawing a second one.
        fig.colorbar(mesh, ax=ax)

    plt.show()

In [None]:
# Which encoder layer to inspect (zero-based index: 3 selects the fourth layer).
N = 3
selected_layer = model.encoder.layers[N]
plot_layer(selected_layer, tokens)
