In [46]:
import torch
import torch.nn as nn
from model import Transformer
from train import get_model, get_dataset, greedy_decode
from config import config, get_weights_file_path
import altair as alt
import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings("ignore")

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [47]:
# load config
# `config` is imported as a factory function and then rebound to the settings
# it returns.  Guarding with callable() keeps this cell re-runnable: on a second
# execution the name already holds the loaded settings and must not be called.
if callable(config):
    config = config()

# Datasets and tokenizers for both languages, as returned by get_dataset.
train_ds, val_ds, tokenizer_src, tokenizer_tgt = get_dataset(config)

# Build the model with vocabulary sizes taken from the two tokenizers and
# place it on the selected device.
model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size())
model = model.to(device)

Max length of source sentence: 104
Max length of target sentence: 99


In [48]:
# weights
# Checkpoint written after the tenth epoch.
model_filename = get_weights_file_path(config, "10")
print(f"Loading {model_filename}")

# map_location remaps the stored tensors onto the device used in this session,
# so a checkpoint saved on a GPU can still be restored on a CPU-only machine.
state = torch.load(model_filename, map_location=device)

# This notebook only runs inference: put train-time layers (e.g. dropout, if
# the model has any) into evaluation behaviour before attention is recorded.
model.eval()

# Kept as the last expression so the cell still displays the key-matching report.
model.load_state_dict(state['model_state_dict'])


Loading Helsinki-NLP\opus-100_weights\tmodel__{lang_src}__{lang_tgt}.pt10.pt


<All keys matched successfully>

In [49]:
def fetch_validation_sample():
    """Fetch one validation example and run the model on it.

    Takes the first batch yielded by ``val_ds``, moves its encoder/decoder
    inputs to ``device`` and runs ``greedy_decode`` on the source sentence.
    The decoded output itself is discarded; the call is kept for its forward
    pass (presumably this is what populates the ``attention_scores`` read by
    ``extract_attention_scores`` -- confirm against the model implementation).

    Returns:
        tuple: ``(validation_batch, encoder_input_tokens, decoder_input_tokens)``
        where the two token lists are the string tokens of the first (only)
        sample's encoder and decoder inputs.

    Raises:
        ValueError: if the batch does not contain exactly one sample.
    """
    validation_batch = next(iter(val_ds))

    src_tensor = validation_batch['encoder_input'].to(device)
    src_padding = validation_batch['encoder_mask'].to(device)
    tgt_tensor = validation_batch['decoder_input'].to(device)

    # Check the batch size before anything indexes into the batch.
    if src_tensor.shape[0] != 1:
        raise ValueError("Expected batch size of 1 for validation")

    # One bulk .tolist() per tensor instead of an .item() call per token;
    # `token_id` avoids shadowing the builtin `id`.
    encoder_input_tokens = [tokenizer_src.id_to_token(token_id) for token_id in src_tensor[0].tolist()]
    decoder_input_tokens = [tokenizer_tgt.id_to_token(token_id) for token_id in tgt_tensor[0].tolist()]

    # Inference only: no autograd graph is needed for the decoding pass.
    with torch.no_grad():
        greedy_decode(model, src_tensor, src_padding,
                      tokenizer_src, tokenizer_tgt,
                      device, config['seq_len'])

    return validation_batch, encoder_input_tokens, decoder_input_tokens


In [50]:
def attention_matrix_to_dataframe(matrix, rows_limit, cols_limit, row_vocab, col_vocab):
    """Flatten the top-left corner of an attention matrix into a long-form DataFrame.

    Args:
        matrix: 2-D array-like exposing ``.shape``, slicing and ``.tolist()``
            (e.g. a torch tensor or a numpy array).
        rows_limit: maximum number of rows to include.
        cols_limit: maximum number of columns to include.
        row_vocab: sequence of token strings used to label rows by position.
        col_vocab: sequence of token strings used to label columns by position.

    Returns:
        pandas.DataFrame with one record per (row, column) cell and the columns
        ``i``, ``j``, ``attention``, ``row_label`` and ``col_label``.  Labels are
        the zero-padded position followed by the token, or ``<pad>`` when the
        position lies beyond the supplied vocabulary list.  The columns are
        present even when no cells are selected.
    """
    columns = ["i", "j", "attention", "row_label", "col_label"]

    def _label(position, vocab):
        # Zero-padded position prefix makes every label unique per position.
        if position < len(vocab):
            return f"{position:03d} {vocab[position]}"
        return f"{position:03d} <pad>"

    # Clamp at zero so a non-positive limit selects nothing instead of being
    # interpreted as a negative slice bound.
    n_rows = max(0, min(matrix.shape[0], rows_limit))
    n_cols = max(0, min(matrix.shape[1], cols_limit))

    # Convert the cropped window to nested Python lists in one call rather than
    # indexing the matrix element by element.
    values = matrix[:n_rows, :n_cols].tolist()

    # Labels depend on a single index only, so build each of them once.
    row_labels = [_label(i, row_vocab) for i in range(n_rows)]
    col_labels = [_label(j, col_vocab) for j in range(n_cols)]

    records = [
        {
            "i": i,
            "j": j,
            "attention": float(values[i][j]),
            "row_label": row_labels[i],
            "col_label": col_labels[j],
        }
        for i in range(n_rows)
        for j in range(n_cols)
    ]
    return pd.DataFrame(records, columns=columns)

def extract_attention_scores(model_type: str, layer_idx: int, head_idx: int):
    """Return the stored attention scores of one head in one layer.

    ``"encoder"`` selects the encoder layer's self-attention block and
    ``"decoder"`` the decoder layer's self-attention block; any other value
    selects the decoder layer's cross-attention block.  The result is the
    ``.data`` of entry ``[0, head_idx]`` of that block's ``attention_scores``.
    """
    stack = model.encoder if model_type == "encoder" else model.decoder
    target_layer = stack.layers[layer_idx]

    if model_type in ("encoder", "decoder"):
        attention_block = target_layer.self_attention_block
    else:
        # Everything that is not plain self-attention falls through to cross-attention.
        attention_block = target_layer.cross_attention_block

    return attention_block.attention_scores[0, head_idx].data

def visualize_attention(model_type, layer, head, input_tokens, output_tokens, max_len):
    """Build an interactive Altair heatmap for one attention head.

    The scores selected by ``model_type``/``layer``/``head`` are cropped to at
    most ``max_len`` rows and columns.  Rows are labelled from ``input_tokens``
    and columns from ``output_tokens``.
    """
    head_scores = extract_attention_scores(model_type, layer, head)
    frame = attention_matrix_to_dataframe(
        head_scores, max_len, max_len, input_tokens, output_tokens
    )

    # One rectangle per (row, column) cell, coloured by its attention weight.
    heatmap = alt.Chart(frame).mark_rect().encode(
        x=alt.X("col_label", title=None),
        y=alt.Y("row_label", title=None),
        color="attention",
        tooltip=["i", "j", "attention", "row_label", "col_label"],
    )

    sized = heatmap.properties(
        title=f"Attention Layer {layer} Head {head}",
        width=400,
        height=400,
    )
    return sized.interactive()

def create_attention_grid(model_type: str, layer_indices: list[int], head_indices: list[int], 
                         input_tokens: list, output_tokens: list, max_len: int):
    """Arrange per-head attention heatmaps in a grid.

    Each entry of ``layer_indices`` becomes one horizontal strip of charts (one
    chart per entry of ``head_indices``); the strips are stacked vertically.
    """
    layer_strips = [
        alt.hconcat(*[
            visualize_attention(model_type, layer, head, input_tokens, output_tokens, max_len)
            for head in head_indices
        ])
        for layer in layer_indices
    ]
    return alt.vconcat(*layer_strips)

In [51]:
batch, encoder_input_tokens, decoder_input_tokens = fetch_validation_sample()
print(f'Source: {batch["src_text"][0]}')
print(f'Target: {batch["tgt_text"][0]}')

# Number of source tokens before the first padding token.  A sentence that
# fills the whole sequence contains no "[PAD]" entry, in which case
# list.index would raise ValueError, so fall back to the full length.
sentence_len = (
    encoder_input_tokens.index("[PAD]")
    if "[PAD]" in encoder_input_tokens
    else len(encoder_input_tokens)
)

Source: ممباتو) أكل للحم البشر) (ولكنه التحق بجامعة (هارفرد
Target: Moombata ist ein Kannibale, aber er hat in Harvard studiert.


In [52]:
# Layers and heads to plot: the first three layers, eight heads each.
layers = list(range(3))
heads = list(range(8))

# Encoder self-attention: rows and columns are both labelled with source
# tokens, cropped to at most 20 positions.
create_attention_grid(
    model_type="encoder",
    layer_indices=layers,
    head_indices=heads,
    input_tokens=encoder_input_tokens,
    output_tokens=encoder_input_tokens,
    max_len=min(20, sentence_len),
)


In [53]:
# Decoder self-attention: rows and columns are both labelled with the decoder
# input tokens.  The crop length is derived from the source sentence length
# (`sentence_len`), capped at 20 positions.
create_attention_grid(
    model_type="decoder",
    layer_indices=layers,
    head_indices=heads,
    input_tokens=decoder_input_tokens,
    output_tokens=decoder_input_tokens,
    max_len=min(20, sentence_len),
)
