In [1]:
import torch 
import matplotlib.pyplot as plt 
import seaborn as sns
# Global matplotlib font for every figure drawn below. Presumably chosen so that
# Korean (Hangul) token labels in the attention heatmaps render instead of boxes.
# NOTE(review): AppleGothic typically ships only with macOS — on other platforms
# matplotlib falls back to its default font (with a warning); confirm target OS.
plt.rcParams['font.family'] = 'AppleGothic'

In [4]:
def draw_one_attention_map(idx, attn_weights, query, key, max_len=20,
                           title='Average Attention Map'):
    """Plot one example's attention map from a batch as an annotated heatmap.

    Only the top-left ``max_len x max_len`` corner is drawn so that the
    per-cell value annotations stay readable.

    Args:
        idx: Index of the example to plot along the batch dimension.
        attn_weights: 3-D attention weights unpacked as
            ``(batch_size, query_len, key_len)``. Any device, dtype or autograd
            state is accepted; a detached float32 CPU copy is plotted.
        query: Already-decoded query tokens, one label per query position.
        key: Already-decoded key tokens, one label per key position.
        max_len: Maximum number of query rows and key columns to draw
            (default 20, the previously hard-coded crop size).
        title: Figure title.

    Raises:
        IndexError: If ``idx`` is outside ``[0, batch_size)``.
        ValueError: If the cropped map would be empty (no rows or columns).
    """
    batch_size, query_len, key_len = attn_weights.shape

    if idx < 0 or idx >= batch_size:
        # Previously this only printed and carried on: a negative idx silently
        # plotted an example counted from the end of the batch.
        raise IndexError(
            f'Invalid idx given. batch_size: {batch_size}, given index: {idx}')

    n_rows = min(max_len, query_len)
    n_cols = min(max_len, key_len)
    if n_rows <= 0 or n_cols <= 0:
        raise ValueError(
            f'Nothing to plot: cropped attention map is {n_rows}x{n_cols}')

    # Scalar indexing on the batch axis already yields a 2-D (rows, cols)
    # matrix, so no squeeze is needed (squeeze(0) would collapse a single-row
    # map to 1-D and break the heatmap).
    # detach(): tensors that require grad cannot be converted to numpy.
    # float(): dtypes such as bfloat16 have no numpy equivalent.
    attn_weight = (
        attn_weights[idx, :n_rows, :n_cols].detach().float().cpu().numpy()
    )

    # Crop the labels to exactly the plotted rows/columns so ticks line up.
    query = query[:n_rows]
    key = key[:n_cols]

    # Width scales with the number of key columns, height with query rows.
    _, ax = plt.subplots(figsize=[n_cols * 1.5, n_rows])
    sns.heatmap(attn_weight, annot=True, fmt=".2f", cbar=False, ax=ax,
                yticklabels=query, xticklabels=key)

    # Re-apply the same labels to control their rotation and font size.
    ax.set_yticklabels(query, rotation=45, fontsize=15)
    ax.set_xticklabels(key, rotation=60, fontsize=15)
    ax.xaxis.tick_top()
    ax.set_title(title, fontsize=12)

    # Draw a solid black frame around the heatmap.
    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(2)
        spine.set_color('black')

    plt.tight_layout()
    plt.show()