In [1]:
import torch 
import torch.nn as nn
from model import Transformer
from config import get_config,get_weights_file_path
from train import get_model,get_ds,greedy_decode
import altair as alt
import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings("ignore")

config = get_config()

config["seq_len"] = 256  # The model was trained with seq_len=256

print("Loading dataset")
train_dataloader, test_dataloader, tokenizer_src, tokenizer_tgt= get_ds(config)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = get_model(config,tokenizer_src.get_vocab_size(),tokenizer_tgt.get_vocab_size())
model.to(device)
model_filename = get_weights_file_path(config, f"{8}")
checkpoint = torch.load(model_filename,map_location=device)
# Load only the model state dict, not the entire checkpoint
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()




A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.1.3 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/lakshmanan/Library/Python/3.11/lib/python/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/Users/lakshmanan/Library/Python/3.11/lib/python/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/Users/lakshmanan/Library/Python/3.11/lib/python/site-packages/ipykernel/kernelapp.py", line 739, in start
    self.io

Loading dataset


Calculating max lengths: 100%|██████████| 257067/257067 [01:03<00:00, 4041.98it/s]


Maximum length of source sentences: 554
Maximum length of target sentences: 655


Transformer(
  (encoder): Encoder(
    (layers): ModuleList(
      (0-5): 6 x EncoderBlock(
        (self_attention): MultiHeadAttention(
          (w_q): Linear(in_features=512, out_features=512, bias=True)
          (w_k): Linear(in_features=512, out_features=512, bias=True)
          (w_v): Linear(in_features=512, out_features=512, bias=True)
          (w_o): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (feed_forward): FeedForwardBlock(
          (linear_1): Linear(in_features=512, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear_2): Linear(in_features=2048, out_features=512, bias=True)
        )
        (residual_1): ResidualConnection(
          (dropout): Dropout(p=0.1, inplace=False)
          (norm): LayerNormalization()
        )
        (residual_2): ResidualConnection(
          (dropout): Dropout(p=0.1, inplace=False)
          (norm): LayerNormaliz

In [2]:

def load_next_batch():
    # Load a sample batch from the test set
    batch = next(iter(test_dataloader))
    encoder_input = batch["encoder_input"].to(device)
    encoder_mask = batch["encoder_mask"].to(device)
    decoder_input = batch["decoder_input"].to(device)
    decoder_mask = batch["decoder_mask"].to(device)

    encoder_input_tokens = [tokenizer_src.id_to_token(idx) for idx in encoder_input[0].cpu().numpy()]
    decoder_input_tokens = [tokenizer_tgt.id_to_token(idx) for idx in decoder_input[0].cpu().numpy()]

    # check that the batch size is 1
    assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

    model_out = greedy_decode(model, encoder_input, encoder_mask, config["seq_len"], tokenizer_tgt, config)
    
    return batch, encoder_input_tokens, decoder_input_tokens

def mtx2df(m, max_row, max_col, row_tokens, col_tokens):
    return pd.DataFrame(
        [
            (
                r,
                c,
                float(m[r, c]),
                "%.3d %s" % (r, row_tokens[r] if len(row_tokens) > r else "<blank>"),
                "%.3d %s" % (c, col_tokens[c] if len(col_tokens) > c else "<blank>"),
            )
            for r in range(m.shape[0])
            for c in range(m.shape[1])
            if r < max_row and c < max_col
        ],
        columns=["row", "column", "value", "row_token", "col_token"],
    )

def get_attn_map(attn_type: str, layer: int, head: int):
    if attn_type == "encoder":
        attn = model.encoder.layers[layer].self_attention.attention_scores
    elif attn_type == "decoder":
        attn = model.decoder.layers[layer].self_attention.attention_scores
    elif attn_type == "encoder-decoder":
        attn = model.decoder.layers[layer].cross_attention.attention_scores
    return attn[0, head].data

def attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len):
    df = mtx2df(
        get_attn_map(attn_type, layer, head),
        max_sentence_len,
        max_sentence_len,
        row_tokens,
        col_tokens,
    )
    return (
        alt.Chart(data=df)
        .mark_rect()
        .encode(
            x=alt.X("col_token", axis=alt.Axis(title="")),
            y=alt.Y("row_token", axis=alt.Axis(title="")),
            color="value",
            tooltip=["row", "column", "value", "row_token", "col_token"],
        )
        #.title(f"Layer {layer} Head {head}")
        .properties(height=400, width=400, title=f"Layer {layer} Head {head}")
        .interactive()
    )

def get_all_attention_maps(attn_type: str, layers: list[int], heads: list[int], row_tokens: list, col_tokens, max_sentence_len: int):
    charts = []
    for layer in layers:
        rowCharts = []
        for head in heads:
            rowCharts.append(attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len))
        charts.append(alt.hconcat(*rowCharts))
    return alt.vconcat(*charts)

# Run the attention visualization
batch, encoder_input_tokens, decoder_input_tokens = load_next_batch()
print(f'Source: {batch["src_text"][0]}')
print(f'Target: {batch["tgt_text"][0]}')
sentence_len = encoder_input_tokens.index("[PAD]")


Source: All these are the twelve tribes of Israel: and this is it that their father spoke to them, and blessed them; every one according to his blessing he blessed them.

Target: இவர்கள் எல்லாரும் இஸ்ரவேலின் பன்னிரண்டு கோத்திரத்தார்; அவர்களுடைய தகப்பன் அவர்களை ஆசீர்வதிக்கையில், அவர்களுக்குச் சொன்னது இதுதான்; அவனவனுக்குரிய ஆசீர்வாதம் சொல்லி அவனவனை ஆசீர்வதித்தான்.



In [3]:

layers = [0, 1, 2]
heads = [0, 1, 2, 3, 4, 5, 6, 7]

# Encoder Self-Attention
encoder_self_attn = get_all_attention_maps("encoder", layers, heads, encoder_input_tokens, encoder_input_tokens, min(20, sentence_len))



In [4]:
encoder_self_attn

In [5]:
get_all_attention_maps("decoder", layers, heads, decoder_input_tokens, decoder_input_tokens, min(20, sentence_len))

In [7]:
get_all_attention_maps("encoder-decoder", layers, heads, decoder_input_tokens, encoder_input_tokens, min(20, sentence_len))
