In [2]:
! pip install altair

Collecting altair
  Downloading altair-5.5.0-py3-none-any.whl.metadata (11 kB)
Collecting narwhals>=1.14.2 (from altair)
  Downloading narwhals-2.2.0-py3-none-any.whl.metadata (11 kB)
Downloading altair-5.5.0-py3-none-any.whl (731 kB)
   ---------------------------------------- 0.0/731.2 kB ? eta -:--:--
   ---------------------------------------- 731.2/731.2 kB 6.0 MB/s  0:00:00
Downloading narwhals-2.2.0-py3-none-any.whl (401 kB)
Installing collected packages: narwhals, altair

   ---------------------------------------- 0/2 [narwhals]
   ---------------------------------------- 0/2 [narwhals]
   ---------------------------------------- 0/2 [narwhals]
   ---------------------------------------- 0/2 [narwhals]
   ---------------------------------------- 0/2 [narwhals]
   ---------------------------------------- 0/2 [narwhals]
   ---------------------------------------- 0/2 [narwhals]
   ---------------------------------------- 0/2 [narwhals]
   ----------------------------------------

In [3]:
import warnings

import altair as alt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn

from config import get_config, get_weights_file_path
from model import Transformer
from train import get_ds, get_model, greedy_decode, run_validation

warnings.filterwarnings("ignore", category=UserWarning)

In [13]:
# --- Helpers ---
def load_next_batch(val_dl, device):
    """Pull one batch from ``val_dl`` and move its tensors to ``device``.

    A fresh iterator is created on every call, so this returns whatever
    ``iter(val_dl)`` yields first rather than advancing through the loader.

    Args:
        val_dl: Iterable of batch dicts with the keys ``encoder_input``,
            ``decoder_input``, ``encoder_mask`` and ``decoder_mask``.
        device: Target passed to each tensor's ``.to(...)``.

    Returns:
        Tuple ``(batch, encoder_input, decoder_input, encoder_mask,
        decoder_mask)``: the untouched batch dict followed by the moved
        tensors. The two input tensors are additionally cast with ``.long()``.
    """
    batch = next(iter(val_dl))

    def _moved(key):
        # Look the tensor up in the batch and place it on the target device.
        return batch[key].to(device)

    return (
        batch,
        _moved('encoder_input').long(),
        _moved('decoder_input').long(),
        _moved('encoder_mask'),
        _moved('decoder_mask'),
    )

def ids_to_tokens(tokenizer, ids_1d):
    """Translate a 1-D sequence of token ids into token strings.

    Args:
        tokenizer: Object exposing ``id_to_token(int)``.
        ids_1d: Tensor or plain iterable of ids. Tensors are detached, moved
            to the CPU and converted to a Python list first.

    Returns:
        List with one ``tokenizer.id_to_token(...)`` result per id, in order.
    """
    raw_ids = ids_1d.detach().cpu().tolist() if torch.is_tensor(ids_1d) else ids_1d
    tokens = []
    for token_id in raw_ids:
        tokens.append(tokenizer.id_to_token(int(token_id)))
    return tokens

# --- Fetch one batch ---
# NOTE(review): val_dataloader, device, model, config, tokenizer_src and
# tokenizer_tgt are not created in the cells visible here; they presumably come
# from an earlier setup cell. Confirm the notebook survives Restart & Run All.
batch, encoder_input, decoder_input, encoder_mask, decoder_mask = load_next_batch(val_dataloader, device)

# Raw sentences are optional in the batch dict; fall back to an empty string
# when the dataset does not include them.
src_text = batch.get('src_text', [''])[0]
tgt_text = batch.get('tgt_text', [''])[0]
print("Source text:", src_text)
print("Target text:", tgt_text)

# Token strings for the first example of the batch (reused as axis labels by
# the attention plots further down).
encoder_input_tokens = ids_to_tokens(tokenizer_src, encoder_input[0])
decoder_input_tokens = ids_to_tokens(tokenizer_tgt, decoder_input[0])
print("Encoder tokens:", encoder_input_tokens[:50], " ...")
print("Decoder tokens:", decoder_input_tokens[:50], " ...")

# --- Greedy decode one example ---
# Prefer the dataset's own seq_len and fall back to the configured value.
max_len = getattr(val_dataloader.dataset, "seq_len", config['seq_len'])

# Inference only: eval mode switches off dropout (a no-op if the model is
# already in eval mode) and no_grad skips autograd bookkeeping during decoding.
model.eval()
with torch.no_grad():
    pred_ids = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)
pred_text = tokenizer_tgt.decode(pred_ids.detach().cpu().tolist())

print("Predicted:", pred_text)


Source text: It contained a bookcase: I soon possessed myself of a volume, taking care that it should be one stored with pictures.
Target text: Vi era una biblioteca e io m'impossessai di un libro, cercando che fosse ornato d'incisioni.
Encoder tokens: ['[SOS]', 'It', 'contained', 'a', 'bookcase', ':', 'I', 'soon', 'possessed', 'myself', 'of', 'a', 'volume', ',', 'taking', 'care', 'that', 'it', 'should', 'be', 'one', 'stored', 'with', 'pictures', '.', '[EOS]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']  ...
Decoder tokens: ['[SOS]', 'Vi', 'era', 'una', 'biblioteca', 'e', 'io', 'm', "'", 'impossessai', 'di', 'un', 'libro', ',', 'cercando', 'che', 'fosse', 'ornato', 'd', "'", 'incisioni', '.', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', 

In [14]:
def mtx2df(m, max_row, max_col, row_tokens, col_tokens):
    """Flatten the top-left corner of a 2-D matrix into a long-format DataFrame.

    Only cells with ``row < max_row`` and ``column < max_col`` are emitted, in
    row-major order.

    Args:
        m: 2-D array-like exposing ``.shape`` and ``m[r, c]`` indexing.
        max_row: Maximum number of rows to include (integer).
        max_col: Maximum number of columns to include (integer).
        row_tokens: Labels for the rows; indices past its end are labelled
            ``"<blank>"``.
        col_tokens: Labels for the columns; same fallback as ``row_tokens``.

    Returns:
        ``pandas.DataFrame`` with the columns ``row``, ``column``, ``value``,
        ``row_token`` and ``col_token``. The token columns carry a zero-padded
        three-digit index prefix (for example ``"007 word"``).
    """
    # Clamp the loop bounds up front instead of walking the whole matrix and
    # discarding most cells: the matrix can be far larger than the plotted corner.
    n_rows = min(m.shape[0], max_row)
    n_cols = min(m.shape[1], max_col)

    def _label(idx, tokens):
        # Zero-padded index prefix, with a placeholder when tokens run out.
        token = tokens[idx] if len(tokens) > idx else "<blank>"
        return "%.3d %s" % (idx, token)

    # Labels depend on a single index only, so build each one once.
    row_labels = [_label(r, row_tokens) for r in range(n_rows)]
    col_labels = [_label(c, col_tokens) for c in range(n_cols)]

    records = [
        (r, c, float(m[r, c]), row_labels[r], col_labels[c])
        for r in range(n_rows)
        for c in range(n_cols)
    ]
    return pd.DataFrame(
        records,
        columns=["row", "column", "value", "row_token", "col_token"],
    )

def get_attn_map(attn_type: str, layer: int, head: int):
    """Return the stored attention matrix for one layer and head.

    Reads ``attention_scores`` from the selected attention block of the
    notebook-global ``model`` and returns ``attention_scores[0, head].data``.

    NOTE(review): ``attention_scores`` is presumably populated by the most
    recent forward pass and laid out as (batch, head, query, key) -- confirm
    in model.py.

    Args:
        attn_type: ``"encoder"`` (encoder self-attention), ``"decoder"``
            (decoder self-attention) or ``"encoder-decoder"`` (decoder
            cross-attention).
        layer: Index into ``model.encoder.layers`` / ``model.decoder.layers``.
        head: Index of the attention head.

    Raises:
        ValueError: If ``attn_type`` is not one of the three supported values.
    """
    # Validate first: an unknown type used to fall through every branch and
    # fail later with an UnboundLocalError on `attn`.
    if attn_type not in ("encoder", "decoder", "encoder-decoder"):
        raise ValueError(
            f"Unknown attn_type {attn_type!r}; expected 'encoder', 'decoder' or 'encoder-decoder'"
        )

    if attn_type == "encoder":
        attn = model.encoder.layers[layer].self_attention_block.attention_scores
    elif attn_type == "decoder":
        attn = model.decoder.layers[layer].self_attention_block.attention_scores
    else:  # "encoder-decoder"
        attn = model.decoder.layers[layer].cross_attention_block.attention_scores
    return attn[0, head].data

def attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len):
    """Build an interactive Altair heatmap for one attention head.

    The matrix returned by ``get_attn_map`` is converted with ``mtx2df``
    (limited to ``max_sentence_len`` rows and columns) and drawn as a 400x400
    rectangle-mark chart coloured by ``value``.

    Args:
        attn_type: Forwarded to ``get_attn_map``.
        layer: Layer index, also shown in the chart title.
        head: Head index, also shown in the chart title.
        row_tokens: Labels for the y axis.
        col_tokens: Labels for the x axis.
        max_sentence_len: Cap applied to both matrix dimensions.

    Returns:
        The Altair chart with pan/zoom interactivity enabled.
    """
    scores = get_attn_map(attn_type, layer, head)
    cells = mtx2df(scores, max_sentence_len, max_sentence_len, row_tokens, col_tokens)

    heatmap = alt.Chart(data=cells).mark_rect()
    heatmap = heatmap.encode(
        x=alt.X("col_token", axis=alt.Axis(title="")),
        y=alt.Y("row_token", axis=alt.Axis(title="")),
        color="value",
        tooltip=["row", "column", "value", "row_token", "col_token"],
    )
    heatmap = heatmap.properties(height=400, width=400, title=f"Layer {layer} Head {head}")
    return heatmap.interactive()

def get_all_attention_maps(attn_type: str, layers: list[int], heads: list[int], row_tokens: list, col_tokens, max_sentence_len: int):
    """Arrange attention heatmaps in a grid: one row per layer, one column per head.

    Args:
        attn_type: Forwarded to ``attn_map``.
        layers: Layer indices; each becomes one horizontally concatenated row.
        heads: Head indices plotted within every row.
        row_tokens: y-axis labels forwarded to ``attn_map``.
        col_tokens: x-axis labels forwarded to ``attn_map``.
        max_sentence_len: Cap on rows/columns forwarded to ``attn_map``.

    Returns:
        The vertical concatenation of the per-layer rows.
    """
    layer_rows = [
        alt.hconcat(
            *[
                attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len)
                for head in heads
            ]
        )
        for layer in layers
    ]
    return alt.vconcat(*layer_rows)

In [19]:
# Layers 0-2 and heads 0-7 are plotted by every attention grid below.
layers = list(range(3))
heads = list(range(8))

# Encoder self-attention: encoder tokens label both axes.
get_all_attention_maps(
    attn_type="encoder",
    layers=layers,
    heads=heads,
    row_tokens=encoder_input_tokens,
    col_tokens=encoder_input_tokens,
    max_sentence_len=min(20, 350),
)

In [17]:
# Decoder self-attention: decoder tokens label both axes.
get_all_attention_maps(
    attn_type="decoder",
    layers=layers,
    heads=heads,
    row_tokens=decoder_input_tokens,
    col_tokens=decoder_input_tokens,
    max_sentence_len=min(20, 350),
)

In [18]:
# Encoder-decoder (cross) attention.
# NOTE(review): encoder tokens label the rows and decoder tokens the columns
# here. If the cross-attention scores are laid out as (decoder query, encoder
# key), the two token lists should be swapped -- confirm against model.py.
get_all_attention_maps(
    attn_type="encoder-decoder",
    layers=layers,
    heads=heads,
    row_tokens=encoder_input_tokens,
    col_tokens=decoder_input_tokens,
    max_sentence_len=min(20, 350),
)