In [None]:
from pathlib import Path

import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
from spacy.lang.en import English
from spacy.lang.hi import Hindi

from transformer import Transformer, MHA
from utils import sample

__author__ = "__Girish_Hegde__"



In [None]:
def attn_hook(module, input, output):
    """Forward hook that stashes the attention weights on the hooked module.

    The hooked module's forward output is unpacked as a 2-tuple ``(out, attn)``.
    The attention part is detached from the autograd graph and stored as
    ``module.attn``, replacing whatever a previous forward call stored there.
    ``input`` is unused.

    Refs:
        https://www.youtube.com/watch?v=1ZbLA7ofasY
    """
    _out, weights = output  # weights: [bs, h, i, j]
    module.attn = weights.detach()

In [None]:
def attach_hook(net, starts_with='dec.cross_attn_layers', layers=(0, ), type=MHA):
    """Register `attn_hook` on selected attention modules of `net` for visualisation.

    Modules are enumerated in ``net.named_modules()`` order. Every module that is an
    instance of `type` and whose qualified name starts with `starts_with` receives a
    running index 0, 1, 2, ...; those whose index is in `layers` get the hook.

    Hooks registered by a previous call on the same `net` are removed first, so
    re-running the cell (or switching targets) does not leave stale or duplicate
    hooks behind.

    Args:
        net: model to instrument (anything exposing ``named_modules()``).
        starts_with: prefix of the qualified module name to match.
        layers: indices, counted among the matching modules only, to hook.
        type: module class to match. Shadows the builtin name; kept for
            backward compatibility with existing callers.

    Side effects:
        Sets ``net.attn_viz_layers`` to the list of hooked modules. Each hooked
        module gets ``.name`` (qualified name), ``.attn`` (reset to None, filled in
        by the hook) and ``.firing_hook`` (the removable hook handle).
    """
    # Remove hooks left over from an earlier call; otherwise they keep firing forever.
    for old_module in getattr(net, 'attn_viz_layers', []):
        handle = getattr(old_module, 'firing_hook', None)
        if handle is not None:
            handle.remove()
            old_module.firing_hook = None

    wanted = set(layers)
    net.attn_viz_layers = []
    matches = (
        (name, module)
        for name, module in net.named_modules()
        if isinstance(module, type) and name.startswith(starts_with)
    )
    for idx, (name, module) in enumerate(matches):
        if idx in wanted:
            net.attn_viz_layers.append(module)
            module.name = name
            module.attn = None
            module.firing_hook = module.register_forward_hook(attn_hook)

In [None]:

CKPT = Path('./data/eng_hindi/runs/best.pt')
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# map_location remaps stored tensors onto DEVICE (e.g. a GPU checkpoint on a CPU box).
# NOTE(review): torch.load unpickles the file - only load checkpoints you trust.
ckpt = torch.load(CKPT, map_location=DEVICE)
# NOTE(review): these unpackings depend on the insertion order of each saved dict,
# so they must mirror the order used when the checkpoint was written - TODO confirm
# against the training script, or switch to explicit keys.
kwargs, state_dict = ckpt['net'].values()
epoch, loss, best = ckpt['training'].values()
in_int2tk, out_int2tk, start_token, end_token, pad_token, ukn_token = ckpt['dataset'].values()

net = Transformer(**kwargs)
net.load_state_dict(state_dict)
# This notebook only runs inference: eval() switches off training-only behaviour
# (e.g. dropout, if the model uses it) so predictions and captured attention are stable.
net = net.to(DEVICE).eval()
tokenizer = English()

In [None]:
# Hook the first three matching decoder cross-attention modules (indices 0, 1, 2).
# The same helper can target other attention blocks instead, e.g.
# starts_with='enc' or starts_with='dec.self_attn_layers'.
attach_hook(net, layers=range(3), starts_with='dec.cross_attn_layers', type=MHA)

In [None]:
inp = "The monkeys jump from branch to branch."
# Inference only: no autograd graph is needed, so disable gradient tracking to save
# memory during generation. Forward hooks still fire and record attention.
# top_k=1 presumably selects greedy decoding - see utils.sample.
with torch.no_grad():
    pred = sample(
        inp, net, tokenizer,
        in_int2tk, out_int2tk,
        start_token, end_token,
        pad_token, ukn_token,
        top_k=1, max_size=100,
        device=DEVICE,
    )
print(pred)

In [None]:
# Index into net.attn_viz_layers (the hooked modules, in named_modules() order);
# 1 selects the second hooked module.
VIZ_LAYER = 1
viz_module = net.attn_viz_layers[VIZ_LAYER]
# The hook overwrites `.attn` on every forward call, so this is the attention from the
# last forward pass made while sampling; it stays None if no forward pass has run yet.
assert viz_module.attn is not None, (
    f"no attention captured for {viz_module.name!r}; run the sampling cell first"
)
attn_table = viz_module.attn.cpu().numpy()  # [bs, h, i, j]

In [None]:
font_prop = FontProperties(fname='./data/devanagari.ttf', size=11)
# Tick labels: the input is split with the same tokenizer that was passed to `sample`;
# the prediction is split with spaCy's Hindi tokenizer.
inp_tks = [str(tk) for tk in tokenizer(inp)]
pred_tks = [str(tk) for tk in Hindi()(pred)]
n_heads = attn_table.shape[1]
for head in range(n_heads):
    # A new figure per head. plt.show() consumes the current figure, so a figure created
    # once before the loop would give its size/layout to the first head only.
    fig, ax = plt.subplots(figsize=(6, 6), constrained_layout=True)
    # Rows are labelled with predicted tokens, columns with input tokens. The last row and
    # the first/last columns are dropped - presumably special start/end tokens, TODO confirm.
    ax.imshow(attn_table[0, head, :-1, 1:-1], cmap='Reds')
    ax.set_xticks(range(attn_table.shape[-1] - 2))
    ax.set_xticklabels(inp_tks, rotation=20)
    ax.set_yticks(range(attn_table.shape[-2] - 1))
    ax.set_yticklabels(pred_tks, rotation=20, fontproperties=font_prop)
    ax.set(title=f'head {head}', xlabel='input tokens', ylabel='predicted tokens')
    plt.show()
    plt.close(fig)  # free the figure; avoids accumulating memory over many heads
